autograd.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. # mypy: allow-untyped-defs
  2. import functools
  3. from collections import namedtuple
  4. import torch
  5. import torch.utils._pytree as pytree
  # NOTE [CustomOp autograd kernel indirection]
  # We register `inner` as the autograd kernel for this custom_op.
  # `inner` either calls the autograd formula registered by the user,
  # or goes into an `autograd_not_implemented` kernel.
  #
  # The reason why this indirection exists is
  # so that we can swap out the autograd kernel (the PyTorch dispatcher
  # doesn't actually allow us to do this). By default, we want
  # the `autograd_not_implemented` behavior, but then the user may come
  # and register something that is actually a backward formula.
  def autograd_kernel_indirection(custom_op):
      """Build the single autograd kernel registered for ``custom_op``.

      The returned ``inner`` looks up the op's registrations on every call, so
      an "autograd" impl registered after this kernel was installed is still
      used.

      Args:
          custom_op: the CustomOp; only ``_has_impl`` / ``_get_impl`` are used.

      Returns:
          ``inner(*args, **kwargs)``. It calls the "autograd" impl's ``func``
          when one exists. It raises ``RuntimeError`` when a "backward" or
          "save_for_backward" registration exists without an "autograd" impl.
          Otherwise it delegates to ``autograd_not_implemented(custom_op)``.
      """
      # Created once, up front, and reused by every call of `inner`.
      fallback_kernel = autograd_not_implemented(custom_op)

      def inner(*args, **kwargs):
          if custom_op._has_impl("autograd"):
              return custom_op._get_impl("autograd").func(*args, **kwargs)

          # As explained in NOTE ["backward", "save_for_backward", and "autograd"],
          # the "autograd" impl is generated only once BOTH "backward" and
          # "save_for_backward" exist; reaching here with just one of them
          # registered is a user error.
          if not (
              custom_op._has_impl("save_for_backward")
              or custom_op._has_impl("backward")
          ):
              return fallback_kernel(*args, **kwargs)

          if custom_op._has_impl("backward"):
              missing, found = "save_for_backward", "backward"
          else:
              missing, found = "backward", "save_for_backward"
          location = custom_op._get_impl(found).location
          raise RuntimeError(
              f"We found a '{found}' registration for {custom_op} at "
              f"{location} but were unable to find a '{missing}' registration. "
              f"To use the CustomOp API to register a backward formula, "
              f"please provide us both a backward function and a "
              f"'save for backward' function via `impl_backward` and "
              f"`impl_save_for_backward` respectively."
          )

      return inner
  42. # TODO(#101191): Use the actual C++ autograd not implemented fallback,
  43. # or change the default autograd fallback to the autograd not implemented fallback.
  44. def autograd_not_implemented(custom_op):
  45. def kernel(*args, **kwargs):
  46. if torch.is_grad_enabled() and pytree.tree_any(
  47. lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
  48. ):
  49. raise RuntimeError("Autograd has not been implemented for operator")
  50. with torch._C._AutoDispatchBelowAutograd():
  51. return custom_op(*args, **kwargs)
  52. return kernel
  def mark_non_differentiable(ctx, output, output_differentiability):
      """Mark the outputs flagged as non-differentiable on an autograd ``ctx``.

      Output types are restricted to be:
      - Tensor
      - Tensor[]
      - int, bool, Scalar, float
      See _check_can_register_backward.

      Args:
          ctx: the autograd.Function context; only ``mark_non_differentiable``
              is called on it.
          output: the forward output, either a single value or a tuple.
          output_differentiability: ``None`` (nothing is marked) or a sequence
              of booleans with one entry per element of ``output``.

      Raises:
          AssertionError: if the lengths of ``output`` and
              ``output_differentiability`` differ.
          RuntimeError: if a non-Tensor output is flagged differentiable.
      """
      if output_differentiability is None:
          return

      tuple_output = output if isinstance(output, tuple) else (output,)
      if len(output_differentiability) != len(tuple_output):
          raise AssertionError(
              f"output_differentiability length {len(output_differentiability)} "
              f"!= output length {len(tuple_output)}"
          )

      non_differentiable_tensors = []
      for idx, (differentiable, out) in enumerate(
          zip(output_differentiability, tuple_output)
      ):
          if isinstance(out, torch.Tensor):
              if not differentiable:
                  non_differentiable_tensors.append(out)
              continue
          if isinstance(out, list):
              # A Tensor[] output is marked (or not) as a whole.
              if not differentiable:
                  non_differentiable_tensors.extend(out)
              continue
          if differentiable:
              raise RuntimeError(
                  f"With output_differentiability={output_differentiability}. "
                  f"At idx {idx}, we received an object of type {type(out)} that "
                  f"is not a Tensor, so it cannot be marked as differentiable in "
                  f"output_differentiability."
              )

      if non_differentiable_tensors:
          ctx.mark_non_differentiable(*non_differentiable_tensors)
  def construct_autograd_kernel(
      schema,
      output_differentiability,
      custom_op,
      op_overload,
      save_for_backward_fn,
      backward_fn,
  ):
      """Build an autograd kernel from user "save_for_backward"/"backward" fns.

      Args:
          schema: the op's schema, used to build namedtuples of the arguments.
          output_differentiability: forwarded to ``mark_non_differentiable``.
          custom_op: the CustomOp; ``_opname`` names the generated Function
              class, and it is passed to ``validate_grad_inputs_dict``.
          op_overload: the operator called (below autograd) in forward.
          save_for_backward_fn: called as ``fn(args_namedtuple, output)``; its
              result is stashed on ctx and handed to ``backward_fn``.
          backward_fn: called as ``fn(ctx, saved, *grads)``; must return a dict
              of per-argument gradients.

      Returns:
          ``apply(*args)``, which runs the op through a freshly generated
          ``torch.autograd.Function`` subclass and returns the op's output in
          its original pytree structure.
      """

      def apply(*args):
          flat_inputs, in_spec = pytree.tree_flatten(args)
          # Set by forward() and read afterwards by backward() and by apply().
          output_spec = None

          def forward(ctx, *flat_fwd_inputs):
              nonlocal output_spec
              ctx.set_materialize_grads(True)
              op_args = pytree.tree_unflatten(list(flat_fwd_inputs), in_spec)
              with torch._C._AutoDispatchBelowAutograd():
                  result = op_overload(*op_args)
              # Argument types are kept so backward can give better errors.
              arg_types = namedtuple_args(schema, pytree.tree_map(type, op_args))
              user_inputs = namedtuple_args(schema, op_args)
              saved = save_for_backward_fn(user_inputs, result)
              save_pytree_for_backward(ctx, (saved, arg_types))
              mark_non_differentiable(ctx, result, output_differentiability)
              flat_result, output_spec = pytree.tree_flatten(result)
              return tuple(flat_result)

          def backward(ctx, *flat_grads):
              if output_spec is None:
                  raise AssertionError("out_spec is unexpectedly None")
              grads = pytree.tree_unflatten(list(flat_grads), output_spec)
              saved, arg_types = unpack_saved(ctx)
              # Placeholder ctx for the user's backward; nothing lives on it
              # yet, it exists so fields can be added later.
              user_ctx = object()
              grad_tuple = grads if isinstance(grads, tuple) else (grads,)
              grad_inputs = backward_fn(user_ctx, saved, *grad_tuple)
              # Convert the user's dict into what autograd.Function expects.
              validate_grad_inputs_dict(grad_inputs, custom_op, arg_types)
              return grad_inputs_dict_to_flat_tuple(grad_inputs, arg_types)

          function_cls = gen_autograd_function(
              custom_op._opname + "_customop", forward, backward
          )
          flat_out = function_cls.apply(*flat_inputs)
          if output_spec is None:
              raise AssertionError("out_spec is unexpectedly None")
          return pytree.tree_unflatten(list(flat_out), output_spec)

      return apply
  def gen_autograd_function(name, forward, backward):
      """Create a ``torch.autograd.Function`` subclass called ``name``.

      ``forward`` and ``backward`` are installed as static methods, matching
      the ``Function`` protocol.
      """
      namespace = {
          "forward": staticmethod(forward),
          "backward": staticmethod(backward),
      }
      return type(name, (torch.autograd.Function,), namespace)
  @functools.lru_cache(maxsize=128)
  def namedtuple_args_cls(schema):
      """Return a namedtuple class with one field per schema argument.

      The class is named ``<schema.name>_args`` and its fields follow
      ``schema.arguments.flat_all``. Results are memoized per ``schema``, so
      ``schema`` must be hashable.
      """
      field_names = tuple(arg.name for arg in schema.arguments.flat_all)
      type_name = str(schema.name) + "_args"
      # mypy doesn't support dynamic namedtuple name
      return namedtuple(type_name, field_names)  # type: ignore[misc]
  def namedtuple_args(schema, args):
      """Pack the positional ``args`` tuple into the schema's namedtuple class.

      Raises:
          AssertionError: if ``args`` is not a tuple.
      """
      if isinstance(args, tuple):
          return namedtuple_args_cls(schema)(*args)
      raise AssertionError(f"expected tuple, got {type(args)}")
  def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
      """Check the value returned by a user's CustomOp backward function.

      Args:
          grad_inputs_dict: the backward function's return value. It must be a
              dict whose keys are exactly the Tensor-like argument names of
              ``forward_op``'s schema.
          forward_op: the CustomOp. ``_schema`` gives the expected keys; the
              "backward" impl's ``location`` is quoted in error messages.
          args_info: namedtuple of the forward arguments' types (a list of
              types for list-valued arguments).

      Raises:
          RuntimeError: describing the first problem found.
      """

      def error(what):
          backward = forward_op._get_impl("backward")
          raise RuntimeError(
              f"In the backward function defined for {forward_op} at "
              f"{backward.location} using the CustomOp API, {what}"
          )

      if not isinstance(grad_inputs_dict, dict):
          error(
              f"expected the output of the backward function to be a dict but "
              f"got {type(grad_inputs_dict)}"
          )

      expected_keys = {
          arg.name
          for arg in forward_op._schema.arguments.flat_all
          if arg.type.is_tensor_like()
      }
      actual_keys = grad_inputs_dict.keys()
      if expected_keys != actual_keys:
          error(
              f"expected the returned grad_input dict to have keys "
              f"{expected_keys} but got {actual_keys}. The backward "
              f"function must return a gradient (can be None) for each arg "
              f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
              f"Args declared to be non-Tensor-like types should not appear "
              f"in the grad_input dict"
          )

      for name, grad in grad_inputs_dict.items():
          arg_info = getattr(args_info, name)

          if isinstance(arg_info, list):
              # List argument: expect one gradient (or None) per element.
              if not isinstance(grad, (tuple, list)):
                  error(
                      f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of gradients but got object of type "
                      f"{type(grad)}."
                  )
              if len(grad) != len(arg_info):
                  error(
                      f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of {len(arg_info)} gradients but got "
                      f"{len(grad)}"
                  )
              for idx, (g, info) in enumerate(zip(grad, arg_info)):
                  if g is None:
                      continue
                  if not isinstance(g, torch.Tensor):
                      error(
                          f"for input '{name}' expected the grad_input dict to "
                          f"hold a list of None or Tensor gradients but got "
                          f"object of {type(g)} at index {idx}"
                      )
                  if not issubclass(info, torch.Tensor):
                      # Report the offending element's own type, not the
                      # whole list's types.
                      error(
                          f"for input '{name}', got a Tensor as the gradient "
                          f"for the {idx}-th value but expected None because "
                          f"the {idx}-th value was not a Tensor (it was "
                          f"type {info})."
                      )
              continue

          if grad is None:
              continue
          if not isinstance(grad, torch.Tensor):
              error(
                  f"got object of type {type(grad)} as the gradient for input "
                  f"'{name}', "
                  f"but expected the gradient to be either None or a Tensor"
              )
          if not issubclass(arg_info, torch.Tensor):
              error(
                  f"got a Tensor as the gradient for input '{name}' but "
                  f"expected None as the gradient because input '{name}' "
                  f"was not a Tensor (it was type {arg_info})."
              )
  233. def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
  234. result = []
  235. for name, arg_info in args_info._asdict().items():
  236. if name not in grad_inputs_dict:
  237. result.append(pytree.tree_map(lambda x: None, arg_info))
  238. continue
  239. result.append(grad_inputs_dict[name])
  240. return tuple(pytree.tree_leaves(result))
  # Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
  # autograd.Function prefers that users use ctx.save_for_backward to
  # save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
  # ctx object.
  def save_pytree_for_backward(ctx, stuff):
      """Store the pytree ``stuff`` on ``ctx``, splitting Tensor and non-Tensor leaves.

      Tensor leaves go through ``ctx.save_for_backward``. Non-Tensor leaves,
      the leaf positions of both groups, the leaf count and the tree spec are
      stored as plain attributes on ``ctx``.
      """
      flat_stuff, spec = pytree.tree_flatten(stuff)

      tensor_idxs, tensors = [], []
      non_tensor_idxs, non_tensors = [], []
      for position, leaf in enumerate(flat_stuff):
          if isinstance(leaf, torch.Tensor):
              tensor_idxs.append(position)
              tensors.append(leaf)
          else:
              non_tensor_idxs.append(position)
              non_tensors.append(leaf)

      ctx.spec = spec
      ctx.num_elts = len(flat_stuff)
      ctx.save_for_backward(*tensors)
      ctx.tensor_idxs = tensor_idxs
      ctx.saved_non_tensors = non_tensors
      ctx.non_tensor_idxs = non_tensor_idxs
  # Inverse operation to save_pytree_for_backward
  def unpack_saved(ctx):
      """Rebuild the pytree stored on ``ctx`` by ``save_pytree_for_backward``.

      Each saved Tensor and non-Tensor is placed back at its recorded leaf
      position, and the leaves are unflattened with ``ctx.spec``.
      """
      slots = [None for _ in range(ctx.num_elts)]
      groups = (
          (ctx.saved_tensors, ctx.tensor_idxs),
          (ctx.saved_non_tensors, ctx.non_tensor_idxs),
      )
      for values, positions in groups:
          for value, position in zip(values, positions):
              slots[position] = value
      return pytree.tree_unflatten(slots, ctx.spec)