| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314 |
- # mypy: allow-untyped-defs
- import functools
- from collections import namedtuple
- import torch
- import torch.utils._pytree as pytree
# NOTE [CustomOp autograd kernel indirection]
# We register `inner` as the autograd kernel for this custom_op.
# `inner` either calls the autograd formula registered by the user,
# or goes into an `autograd_not_implemented` kernel.
#
# The reason why this indirection exists is
# so that we can swap out the autograd kernel (the PyTorch dispatcher
# doesn't actually allow us to do this). By default, we want
# the `autograd_not_implemented` behavior, but then the user may come
# and register something that is actually a backward formula.
def autograd_kernel_indirection(custom_op):
    """Return the kernel to register as ``custom_op``'s autograd kernel.

    The returned ``inner(*args, **kwargs)`` re-checks the registered impls
    on every call, so an "autograd" impl added after registration is still
    used. On each call it:

      * forwards to the registered "autograd" impl if there is one;
      * raises ``RuntimeError`` if "backward" or "save_for_backward" has
        been registered but no "autograd" impl exists (i.e. only one half
        of the backward pair was provided);
      * otherwise defers to ``autograd_not_implemented(custom_op)``.
    """
    autograd_fallback = autograd_not_implemented(custom_op)

    def inner(*args, **kwargs):
        if custom_op._has_impl("autograd"):
            autograd_impl = custom_op._get_impl("autograd").func
            return autograd_impl(*args, **kwargs)

        # As explained in NOTE ["backward", "save_for_backward", and "autograd"],
        # the "autograd" impl is generated only once the user has given us both
        # "backward" and "save_for_backward"; having just one is a user error.
        has_backward_piece = custom_op._has_impl(
            "save_for_backward"
        ) or custom_op._has_impl("backward")
        if not has_backward_piece:
            return autograd_fallback(*args, **kwargs)

        if custom_op._has_impl("backward"):
            missing, found = "save_for_backward", "backward"
        else:
            missing, found = "backward", "save_for_backward"
        loc = custom_op._get_impl(found).location
        raise RuntimeError(
            f"We found a '{found}' registration for {custom_op} at "
            f"{loc} but were unable to find a '{missing}' registration. "
            f"To use the CustomOp API to register a backward formula, "
            f"please provide us both a backward function and a "
            f"'save for backward' function via `impl_backward` and "
            f"`impl_save_for_backward` respectively."
        )

    return inner
- # TODO(#101191): Use the actual C++ autograd not implemented fallback,
- # or change the default autograd fallback to the autograd not implemented fallback.
- def autograd_not_implemented(custom_op):
- def kernel(*args, **kwargs):
- if torch.is_grad_enabled() and pytree.tree_any(
- lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
- ):
- raise RuntimeError("Autograd has not been implemented for operator")
- with torch._C._AutoDispatchBelowAutograd():
- return custom_op(*args, **kwargs)
- return kernel
def mark_non_differentiable(ctx, output, output_differentiability):
    """Mark the outputs flagged as non-differentiable on an autograd ctx.

    Output types are restricted to be:
      - Tensor
      - Tensor[]
      - int, bool, Scalar, float
    See _check_can_register_backward.

    Args:
        ctx: autograd.Function context; ``ctx.mark_non_differentiable`` is
            called at most once, with every flagged Tensor.
        output: the op's output, a single value or a tuple of values.
        output_differentiability: ``None`` to do nothing, otherwise one bool
            per output (a bare non-tuple output counts as one output).

    Raises:
        AssertionError: if the flags and the outputs differ in length.
        RuntimeError: if a non-Tensor output is flagged as differentiable.
    """
    if output_differentiability is None:
        return

    tuple_output = output if isinstance(output, tuple) else (output,)
    if len(output_differentiability) != len(tuple_output):
        raise AssertionError(
            f"output_differentiability length {len(output_differentiability)} "
            f"!= output length {len(tuple_output)}"
        )

    non_differentiable_tensors = []
    for idx, (differentiable, out) in enumerate(
        zip(output_differentiability, tuple_output)
    ):
        if isinstance(out, torch.Tensor):
            if not differentiable:
                non_differentiable_tensors.append(out)
            continue
        if isinstance(out, list):
            # A Tensor[] output: every element shares the single flag.
            if not differentiable:
                non_differentiable_tensors.extend(out)
            continue
        if differentiable:
            raise RuntimeError(
                f"With output_differentiability={output_differentiability}. "
                f"At idx {idx}, we received an object of type {type(out)} that "
                f"is not a Tensor, so it cannot be marked as differentiable in "
                f"output_differentiability."
            )

    if non_differentiable_tensors:
        ctx.mark_non_differentiable(*non_differentiable_tensors)
def construct_autograd_kernel(
    schema,
    output_differentiability,
    custom_op,
    op_overload,
    save_for_backward_fn,
    backward_fn,
):
    """Build the "autograd" kernel for a CustomOp from user-supplied formulas.

    The returned ``apply(*args)`` runs ``op_overload`` inside a freshly
    generated ``torch.autograd.Function`` subclass (named
    ``custom_op._opname + "_customop"``; a new class is generated on every
    ``apply`` call) whose backward calls ``backward_fn``.

    Args:
        schema: the op's schema; its argument names become the fields of the
            namedtuples passed to ``save_for_backward_fn`` and of the
            ``args_info`` used to validate gradients in backward.
        output_differentiability: ``None`` or one bool per output, forwarded
            to ``mark_non_differentiable``.
        custom_op: the CustomOp being differentiated; used for naming the
            generated class and in error messages.
        op_overload: callable that computes the forward result.
        save_for_backward_fn: called as ``save_for_backward_fn(inputs, output)``
            with ``inputs`` a namedtuple of the forward args; its return value
            is saved on ctx and later handed to ``backward_fn``.
        backward_fn: called as ``backward_fn(ctx, saved, *grads)``; must
            return a dict mapping argument names to gradients.

    Returns:
        The ``apply`` function described above.
    """

    def apply(*args):
        # autograd.Function.apply receives a flat argument list, so nested
        # args (e.g. Tensor[]) are flattened here and rebuilt in forward().
        flat_args, spec = pytree.tree_flatten(args)
        # Output structure: assigned by forward(), read by backward() and by
        # the final unflatten below.
        out_spec = None

        def forward(ctx, *flat_args):
            # Undefined output grads reach backward as zero-filled tensors
            # rather than None.
            ctx.set_materialize_grads(True)
            args = pytree.tree_unflatten(list(flat_args), spec)
            # Dispatch below the Autograd key (presumably so this call does
            # not re-enter the autograd kernel being built here).
            with torch._C._AutoDispatchBelowAutograd():
                output = op_overload(*args)

            # We use the info about args to give better error messages in backward
            args_info = namedtuple_args(schema, pytree.tree_map(type, args))

            save_for_backward_fn_inputs = namedtuple_args(schema, args)
            to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)

            save_pytree_for_backward(ctx, (to_save, args_info))
            mark_non_differentiable(ctx, output, output_differentiability)

            nonlocal out_spec
            flat_output, out_spec = pytree.tree_flatten(output)
            return tuple(flat_output)

        def backward(ctx, *flat_grad_output):
            if out_spec is None:
                raise AssertionError("out_spec is unexpectedly None")
            grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
            saved, args_info = unpack_saved(ctx)
            # There is nothing on the ctx object for now, it is just there so
            # that we can add additional things in the future.
            inner_ctx = object()
            # A non-tuple output unflattens to a single value; wrap it so
            # backward_fn always gets one positional grad per output.
            if not isinstance(grads, tuple):
                grads = (grads,)
            grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)

            # Massage the grad_inputs_dict to a form acceptable by
            # autograd.Function.
            validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
            return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)

        generated_cls = gen_autograd_function(
            custom_op._opname + "_customop", forward, backward
        )

        flat_output = generated_cls.apply(*flat_args)
        if out_spec is None:
            raise AssertionError("out_spec is unexpectedly None")
        # Restore the op's original output structure.
        return pytree.tree_unflatten(list(flat_output), out_spec)

    return apply
def gen_autograd_function(name, forward, backward):
    """Create a ``torch.autograd.Function`` subclass at runtime.

    Args:
        name: ``__name__`` of the generated class.
        forward: installed as the class's ``forward`` staticmethod.
        backward: installed as the class's ``backward`` staticmethod.

    Returns:
        The new subclass; call ``.apply(...)`` on it to run it.
    """
    class_namespace = {
        "forward": staticmethod(forward),
        "backward": staticmethod(backward),
    }
    return type(name, (torch.autograd.Function,), class_namespace)
@functools.lru_cache
def namedtuple_args_cls(schema):
    """Return (and memoize per schema) a namedtuple class for the op's args.

    The class is named ``str(schema.name) + "_args"`` and has one field per
    entry of ``schema.arguments.flat_all``, in that order. ``schema`` must
    be hashable because of the cache.
    """
    field_names = [argument.name for argument in schema.arguments.flat_all]
    class_name = str(schema.name) + "_args"
    # mypy doesn't support dynamic namedtuple name
    return namedtuple(class_name, field_names)  # type: ignore[misc]
def namedtuple_args(schema, args):
    """Pack the positional ``args`` tuple into the schema's args namedtuple.

    Raises:
        AssertionError: if ``args`` is not a tuple.
    """
    if isinstance(args, tuple):
        return namedtuple_args_cls(schema)(*args)
    raise AssertionError(f"expected tuple, got {type(args)}")
def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
    """Check that a user backward function returned well-formed gradients.

    Args:
        grad_inputs_dict: the value returned by the user's backward function.
            It must be a dict with exactly one key per Tensor-like argument of
            ``forward_op``'s schema.
        forward_op: the CustomOp whose backward is being validated. Its schema
            supplies the expected keys, and the location of its "backward"
            impl is quoted in error messages.
        args_info: namedtuple holding, per argument name, the type of the
            value passed in forward (a list of types for list arguments).

    Raises:
        RuntimeError: if the result is not a dict, has the wrong keys, or any
            gradient has the wrong container type, length, or element type.
    """

    def error(what):
        backward = forward_op._get_impl("backward")
        raise RuntimeError(
            f"In the backward function defined for {forward_op} at "
            f"{backward.location} using the CustomOp API, {what}"
        )

    if not isinstance(grad_inputs_dict, dict):
        error(
            f"expected the output of the backward function to be a dict but "
            f"got {type(grad_inputs_dict)}"
        )

    expected_keys = {
        arg.name
        for arg in forward_op._schema.arguments.flat_all
        if arg.type.is_tensor_like()
    }
    actual_keys = grad_inputs_dict.keys()
    if expected_keys != actual_keys:
        error(
            f"expected the returned grad_input dict to have keys "
            f"{expected_keys} but got {actual_keys}. The backward "
            f"function must return a gradient (can be None) for each arg "
            f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
            f"Args declared to be non-Tensor-like types should not appear "
            f"in the grad_input dict"
        )

    for name, grad in grad_inputs_dict.items():
        arg_info = getattr(args_info, name)

        if isinstance(arg_info, list):
            # List argument: expect a same-length sequence of None/Tensor.
            if not isinstance(grad, (tuple, list)):
                error(
                    f"for input '{name}' expected the grad_input dict to "
                    f"hold a list of gradients but got object of type "
                    f"{type(grad)}."
                )
            if len(grad) != len(arg_info):
                error(
                    f"for input '{name}' expected the grad_input dict to "
                    f"hold a list of {len(arg_info)} gradients but got "
                    f"{len(grad)}"
                )
            for idx, (g, info) in enumerate(zip(grad, arg_info)):
                if g is None:
                    continue
                if not isinstance(g, torch.Tensor):
                    error(
                        f"for input '{name}' expected the grad_input dict to "
                        f"hold a list of None or Tensor gradients but got "
                        f"object of {type(g)} at index {idx}"
                    )
                if not issubclass(info, torch.Tensor):
                    # Report the offending element's type, not the whole list.
                    error(
                        f"for input '{name}', got a Tensor as the gradient "
                        f"for the {idx}-th value but expected None because "
                        f"the {idx}-th value was not a Tensor (it was "
                        f"type {info})."
                    )
            continue

        if grad is None:
            continue
        if not isinstance(grad, torch.Tensor):
            error(
                f"got object of type {type(grad)} as the gradient for input "
                f"'{name}', "
                f"but expected the gradient to be either None or a Tensor"
            )
        if not issubclass(arg_info, torch.Tensor):
            error(
                f"got a Tensor as the gradient for input '{name}' but "
                f"expected None as the gradient because input '{name}' "
                f"was not a Tensor (it was type {arg_info})."
            )
- def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
- result = []
- for name, arg_info in args_info._asdict().items():
- if name not in grad_inputs_dict:
- result.append(pytree.tree_map(lambda x: None, arg_info))
- continue
- result.append(grad_inputs_dict[name])
- return tuple(pytree.tree_leaves(result))
# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
# autograd.Function prefers that users use ctx.save_for_backward to
# save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
# ctx object.
def save_pytree_for_backward(ctx, stuff):
    """Store the pytree ``stuff`` on ``ctx``, routing Tensors via save_for_backward.

    Sets ``ctx.spec``, ``ctx.num_elts``, ``ctx.tensor_idxs``,
    ``ctx.saved_non_tensors`` and ``ctx.non_tensor_idxs``, and passes the
    Tensor leaves to ``ctx.save_for_backward``. Indices are positions in the
    flattened leaf list.
    """
    leaves, treespec = pytree.tree_flatten(stuff)

    tensor_positions, tensor_leaves = [], []
    other_positions, other_leaves = [], []
    for position, leaf in enumerate(leaves):
        if isinstance(leaf, torch.Tensor):
            tensor_positions.append(position)
            tensor_leaves.append(leaf)
        else:
            other_positions.append(position)
            other_leaves.append(leaf)

    ctx.spec = treespec
    ctx.num_elts = len(leaves)
    ctx.save_for_backward(*tensor_leaves)
    ctx.tensor_idxs = tensor_positions
    ctx.saved_non_tensors = other_leaves
    ctx.non_tensor_idxs = other_positions
# Inverse operation to save_pytree_for_backward
def unpack_saved(ctx):
    """Rebuild the pytree stored on ``ctx`` by ``save_pytree_for_backward``.

    Tensor leaves come from ``ctx.saved_tensors`` and everything else from
    ``ctx.saved_non_tensors``; each is placed back at its recorded flat
    index before unflattening with ``ctx.spec``.
    """
    flat = [None] * ctx.num_elts
    sources = (
        (ctx.saved_tensors, ctx.tensor_idxs),
        (ctx.saved_non_tensors, ctx.non_tensor_idxs),
    )
    for values, positions in sources:
        for value, position in zip(values, positions):
            flat[position] = value
    return pytree.tree_unflatten(flat, ctx.spec)
|