# autograd.py (8.9 KB)
  1. # mypy: allow-untyped-defs
  2. import dataclasses
  3. from collections.abc import Callable
  4. from dataclasses import dataclass
  5. from typing import Any, Optional, Protocol
  6. from torch import _C, _ops, autograd, Tensor
  7. from torch.utils import _pytree
  8. from . import utils
  9. class InfoProtocol(Protocol):
  10. _backward_fn: Optional[Callable]
  11. _setup_context_fn: Optional[Callable]
  12. @dataclasses.dataclass
  13. class Info:
  14. _backward_fn: Optional[Callable]
  15. _setup_context_fn: Optional[Callable]
  16. def make_autograd_impl(op: _ops.OpOverload, info: InfoProtocol) -> Callable:
  17. name: str = f"GeneratedBackwardFor_{op._namespace}_{op._opname}_{op._overloadname}"
  18. has_kwarg_only_args = utils.has_kwarg_only_args(op._schema)
  19. @dataclass
  20. class Metadata:
  21. keyset: _C.DispatchKeySet
  22. keyword_only_args: dict[str, Any]
  23. def forward_no_grad(*args):
  24. metadata = args[-1]
  25. args = args[:-1]
  26. with _C._AutoDispatchBelowAutograd():
  27. keyset = metadata.keyset
  28. kwargs = metadata.keyword_only_args
  29. result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
  30. return result
  31. def forward(ctx, *args):
  32. metadata = args[-1]
  33. args = args[:-1]
  34. with _C._AutoDispatchBelowAutograd():
  35. keyset = metadata.keyset
  36. kwargs = metadata.keyword_only_args
  37. result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
  38. if info._setup_context_fn:
  39. # The Dispatcher will remove args that are equal to their default
  40. # values from (args, kwargs). We're going to add it back so that
  41. # the user can access them.
  42. #
  43. # This is OK to do: The Dispatcher removed the args for serialization
  44. # FC/BC reasons (that is, a graph will not store args that are equal
  45. # to their default values), but that doesn't matter here. If the user
  46. # adds a new default arg, then they must update
  47. # their setup_context (along with the rest of their operator
  48. # registrations)
  49. args, kwargs = utils.fill_defaults(op._schema, args, kwargs)
  50. if has_kwarg_only_args:
  51. info._setup_context_fn(
  52. ctx=ctx, inputs=args, keyword_only_inputs=kwargs, output=result
  53. )
  54. else:
  55. info._setup_context_fn(ctx=ctx, inputs=args, output=result)
  56. return result
  57. def backward(ctx, *grads):
  58. if info._backward_fn:
  59. try:
  60. prev_needs_input_grad = ctx.needs_input_grad
  61. ctx.needs_input_grad = ctx.needs_input_grad[:-1]
  62. result = info._backward_fn(ctx, *grads)
  63. finally:
  64. ctx.needs_input_grad = prev_needs_input_grad
  65. if isinstance(result, tuple):
  66. return (*result, None)
  67. return result, None
  68. raise RuntimeError(
  69. f"Trying to backward through {op} but no autograd "
  70. f"formula was registered. "
  71. f"Please use register_autograd to add one."
  72. )
  73. Generated = type(
  74. name,
  75. (autograd.Function,),
  76. {
  77. "forward": staticmethod(forward),
  78. "backward": staticmethod(backward),
  79. },
  80. )
  81. schema = op._schema
  82. if any(
  83. utils.is_tensorlist_like_type(a.type)
  84. for a in (*schema.arguments, *schema.returns)
  85. ):
  86. Generated = supports_tensorlist(Generated)
  87. # The dispatcher passes any keyword-only-args as kwargs and the
  88. # rest of the args (even if specified as kwargs) as args.
  89. def autograd_impl(keyset, *args, **keyword_only_args):
  90. if _C.is_grad_enabled() and _C._any_requires_grad(*args):
  91. result = Generated.apply(*args, Metadata(keyset, keyword_only_args)) # type: ignore[attr-defined]
  92. else:
  93. result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
  94. return result
  95. return autograd_impl
  96. def supports_tensorlist(cls: Any) -> Any:
  97. """Allows a given autograd.Function class to support List[Tensor] inputs/outputs.
  98. Regular autograd.Function has a constraint that it only directly supports autograd for
  99. Tensors. Applying @supports_tensorlist enables an autograd.Function to support
  100. autograd for List[Tensor] inputs and outputs.
  101. """
  102. orig_forward = cls.forward
  103. orig_backward = cls.backward
  104. orig_apply = cls.apply
  105. @dataclass
  106. class Metadata:
  107. input_spec: _pytree.TreeSpec
  108. output_spec: Optional[_pytree.TreeSpec] = None
  109. result_is_tuple: Optional[bool] = None
  110. def new_forward(ctx, *args):
  111. metadata = args[-1]
  112. args = args[:-1]
  113. if not isinstance(metadata, Metadata):
  114. raise NotImplementedError(
  115. "NYI: calling supports_tensorlist autograd.Function.forward directly. "
  116. "You should probably be calling .apply instead. "
  117. "Please file an issue if not."
  118. )
  119. args = _pytree.tree_unflatten(list(args), metadata.input_spec)
  120. result = orig_forward(ctx, *args)
  121. metadata.result_is_tuple = isinstance(result, tuple)
  122. if not metadata.result_is_tuple:
  123. result = (result,)
  124. flat_result, output_spec = _pytree.tree_flatten(result, not_list_of_tensor)
  125. metadata.output_spec = output_spec
  126. if hasattr(ctx, "_pt_metadata"):
  127. raise RuntimeError(
  128. "Please don't set ctx._pt_metadata; PyTorch uses it to store info"
  129. )
  130. ctx._pt_metadata = metadata
  131. return tuple(flat_result)
  132. def new_backward(ctx, *grads):
  133. if not hasattr(ctx, "_pt_metadata"):
  134. raise NotImplementedError(
  135. "NYI: calling supports_tensorlist autograd.Function.backward directly. "
  136. "This will automatically get called by PyTorch autograd. "
  137. "Please file an issue if you need this."
  138. )
  139. metadata = ctx._pt_metadata
  140. grads = _pytree.tree_unflatten(list(grads), metadata.output_spec)
  141. # If the user's input is ([x, y, z], w),
  142. # then needs_input_grad is (bool, bool, bool, bool, bool).
  143. # We need to
  144. # 1. get rid of the additional bool (which comes from the extra
  145. # `metadata input`)
  146. # 2. _pytree.tree_unflatten to get the right structure.
  147. prev_needs_input_grad = ctx.needs_input_grad
  148. try:
  149. ctx.needs_input_grad = _pytree.tree_unflatten(
  150. list(ctx.needs_input_grad[:-1]), metadata.input_spec
  151. )
  152. grad_inputs = orig_backward(ctx, *grads)
  153. finally:
  154. ctx.needs_input_grad = prev_needs_input_grad
  155. if not isinstance(grad_inputs, tuple):
  156. grad_inputs = (grad_inputs,)
  157. # Assume that any Nones in the backward are Tensors.
  158. # If the forward has an arg that is [1, 2, 3], the backward should
  159. # return None as the grad.
  160. # If the forward has an arg that is [tensor, tensor], the backward
  161. # may return [None, None], [grad, None], [None, grad], or [grad, grad].
  162. flat_grad_inputs, grad_inputs_spec = _pytree.tree_flatten(
  163. grad_inputs, not_list_of_optional_tensor
  164. )
  165. if grad_inputs_spec != metadata.input_spec:
  166. raise RuntimeError(
  167. f"Expected the return from backward to be of the same structure "
  168. f"as the inputs. Got: {grad_inputs_spec} (return from backward), "
  169. f"{metadata.input_spec} (inputs)"
  170. )
  171. return tuple(flat_grad_inputs + [None])
  172. def new_apply(*args):
  173. flat_args, input_spec = _pytree.tree_flatten(args, is_leaf=not_list_of_tensor)
  174. metadata = Metadata(input_spec)
  175. result = orig_apply(*flat_args, metadata) # type: ignore[misc]
  176. if metadata.output_spec is None:
  177. raise AssertionError("metadata.output_spec must not be None")
  178. result = _pytree.tree_unflatten(list(result), metadata.output_spec)
  179. if not metadata.result_is_tuple:
  180. if not isinstance(result, tuple):
  181. raise AssertionError(f"result must be tuple, got {type(result)}")
  182. if len(result) != 1:
  183. raise AssertionError(
  184. f"result tuple must have length 1, got {len(result)}"
  185. )
  186. return result[0]
  187. return result
  188. cls.forward = new_forward
  189. cls.backward = new_backward
  190. cls.apply = new_apply
  191. return cls
  192. def not_list_of_tensor(tree):
  193. if isinstance(tree, tuple):
  194. return False
  195. if isinstance(tree, list):
  196. return any(not isinstance(l, Tensor) for l in tree)
  197. return True
  198. def not_list_of_optional_tensor(tree):
  199. if isinstance(tree, tuple):
  200. return False
  201. if isinstance(tree, list):
  202. return any(l is not None and not isinstance(l, Tensor) for l in tree)
  203. return True