function.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869
  1. # mypy: allow-untyped-defs
  2. import functools
  3. import inspect
  4. import itertools
  5. import warnings
  6. from collections import OrderedDict
  7. from collections.abc import Callable
  8. from typing import Any, Concatenate, Optional, TypeVar
  9. from typing_extensions import deprecated, ParamSpec
  10. import torch
  11. import torch._C as _C
  12. import torch._functorch as _functorch
  13. import torch.utils.hooks as hooks
  14. from torch._C import _functions
  15. from torch._functorch.autograd_function import custom_function_call
  16. __all__ = [
  17. "FunctionCtx",
  18. "BackwardCFunction",
  19. "FunctionMeta",
  20. "Function",
  21. "once_differentiable",
  22. "InplaceFunction",
  23. "NestedIOFunction",
  24. ]
  25. # Unique id provider for each class inheriting from Function
  26. # This is incremented in FunctionMeta during class definition
  27. AUTOGRAD_FUNCTION_COUNTER = itertools.count()
  28. _T = TypeVar("_T")
  29. _R = TypeVar("_R")
  30. _P = ParamSpec("_P")
  31. # Formerly known as: _ContextMethodMixin
  32. class FunctionCtx:
  33. def save_for_backward(self, *tensors: torch.Tensor):
  34. r"""Save given tensors for a future call to :func:`~Function.backward`.
  35. ``save_for_backward`` should be called at most once, in either the
  36. :func:`setup_context` or :func:`forward` methods, and only with tensors.
  37. All tensors intended to be used in the backward pass should be saved
  38. with ``save_for_backward`` (as opposed to directly on ``ctx``) to prevent
  39. incorrect gradients and memory leaks, and enable the application of saved
  40. tensor hooks. See :class:`torch.autograd.graph.saved_tensors_hooks`.
  41. See :ref:`extending-autograd` for more details.
  42. Note that if intermediary tensors, tensors that are neither inputs
  43. nor outputs of :func:`forward`, are saved for backward, your custom Function
  44. may not support double backward.
  45. Custom Functions that do not support double backward should decorate their
  46. :func:`backward` method with ``@once_differentiable`` so that performing
  47. double backward raises an error. If you'd like to support double backward,
  48. you can either recompute intermediaries based on the inputs during backward
  49. or return the intermediaries as the outputs of the custom Function. See the
  50. `double backward tutorial <https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html>`_
  51. for more details.
  52. In :func:`backward`, saved tensors can be accessed through the :attr:`saved_tensors`
  53. attribute. Before returning them to the user, a check is made to ensure
  54. they weren't used in any in-place operation that modified their content.
  55. Arguments can also be ``None``. This is a no-op.
  56. See :ref:`extending-autograd` for more details on how to use this method.
  57. Example::
  58. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
  59. >>> class Func(Function):
  60. >>> @staticmethod
  61. >>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
  62. >>> w = x * z
  63. >>> out = x * y + y * z + w * y
  64. >>> ctx.save_for_backward(x, y, w, out)
  65. >>> ctx.z = z # z is not a tensor
  66. >>> return out
  67. >>>
  68. >>> @staticmethod
  69. >>> @once_differentiable
  70. >>> def backward(ctx, grad_out):
  71. >>> x, y, w, out = ctx.saved_tensors
  72. >>> z = ctx.z
  73. >>> gx = grad_out * (y + y * z)
  74. >>> gy = grad_out * (x + z + w)
  75. >>> gz = None
  76. >>> return gx, gy, gz
  77. >>>
  78. >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
  79. >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
  80. >>> c = 4
  81. >>> d = Func.apply(a, b, c)
  82. """
  83. self.to_save = tensors
  84. def save_for_forward(self, *tensors: torch.Tensor):
  85. r"""Save given tensors for a future call to :func:`~Function.jvp`.
  86. ``save_for_forward`` should be called at most once, in either the
  87. :func:`setup_context` or :func:`forward` methods, and all arguments
  88. should be tensors.
  89. In :func:`jvp`, saved objects can be accessed through the :attr:`saved_tensors`
  90. attribute.
  91. Arguments can also be ``None``. This is a no-op.
  92. See :ref:`extending-autograd` for more details on how to use this method.
  93. Example::
  94. >>> # xdoctest: +SKIP
  95. >>> class Func(torch.autograd.Function):
  96. >>> @staticmethod
  97. >>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
  98. >>> ctx.save_for_backward(x, y)
  99. >>> ctx.save_for_forward(x, y)
  100. >>> ctx.z = z
  101. >>> return x * y * z
  102. >>>
  103. >>> @staticmethod
  104. >>> def jvp(ctx, x_t, y_t, _):
  105. >>> x, y = ctx.saved_tensors
  106. >>> z = ctx.z
  107. >>> return z * (y * x_t + x * y_t)
  108. >>>
  109. >>> @staticmethod
  110. >>> def vjp(ctx, grad_out):
  111. >>> x, y = ctx.saved_tensors
  112. >>> z = ctx.z
  113. >>> return z * grad_out * y, z * grad_out * x, None
  114. >>>
  115. >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
  116. >>> t = torch.tensor(1., dtype=torch.double)
  117. >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
  118. >>> c = 4
  119. >>>
  120. >>> with fwAD.dual_level():
  121. >>> a_dual = fwAD.make_dual(a, t)
  122. >>> d = Func.apply(a_dual, b, c)
  123. """
  124. for tensor in tensors:
  125. if not (isinstance(tensor, torch.Tensor) or tensor is None):
  126. raise AssertionError(
  127. "save_for_forward expects all arguments to be tensors; you should "
  128. "save non-tensors as attributes on ctx."
  129. )
  130. self.saved_for_forward = tensors
  131. def mark_dirty(self, *args: torch.Tensor):
  132. r"""Mark given tensors as modified in an in-place operation.
  133. This should be called at most once, in either the :func:`setup_context`
  134. or :func:`forward` methods, and all arguments should be inputs.
  135. Every tensor that's been modified in-place in a call to :func:`forward`
  136. should be given to this function, to ensure correctness of our checks.
  137. It doesn't matter whether the function is called before or after
  138. modification.
  139. Examples::
  140. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
  141. >>> class Inplace(Function):
  142. >>> @staticmethod
  143. >>> def forward(ctx, x):
  144. >>> x_npy = x.numpy() # x_npy shares storage with x
  145. >>> x_npy += 1
  146. >>> ctx.mark_dirty(x)
  147. >>> return x
  148. >>>
  149. >>> @staticmethod
  150. >>> @once_differentiable
  151. >>> def backward(ctx, grad_output):
  152. >>> return grad_output
  153. >>>
  154. >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
  155. >>> b = a * a
  156. >>> Inplace.apply(a) # This would lead to wrong gradients!
  157. >>> # but the engine would not know unless we mark_dirty
  158. >>> # xdoctest: +SKIP
  159. >>> b.backward() # RuntimeError: one of the variables needed for gradient
  160. >>> # computation has been modified by an inplace operation
  161. """
  162. self.dirty_tensors = args
  163. @deprecated(
  164. "`mark_shared_storage` is deprecated. "
  165. "Tensors with shared storages are automatically tracked. "
  166. "Note that calls to `set_()` are not tracked",
  167. category=FutureWarning,
  168. )
  169. def mark_shared_storage(self, *pairs):
  170. pass
  171. def mark_non_differentiable(self, *args: torch.Tensor):
  172. r"""Mark outputs as non-differentiable.
  173. This should be called at most once, in either the :func:`setup_context`
  174. or :func:`forward` methods, and all arguments should be tensor outputs.
  175. This will mark outputs as not requiring gradients, increasing the
  176. efficiency of backward computation. You still need to accept a gradient
  177. for each output in :meth:`~Function.backward`, but it's always going to
  178. be a zero tensor with the same shape as the shape of a corresponding
  179. output.
  180. This is used e.g. for indices returned from a sort. See example::
  181. >>> class Func(Function):
  182. >>> @staticmethod
  183. >>> def forward(ctx, x):
  184. >>> sorted, idx = x.sort()
  185. >>> ctx.mark_non_differentiable(idx)
  186. >>> ctx.save_for_backward(x, idx)
  187. >>> return sorted, idx
  188. >>>
  189. >>> @staticmethod
  190. >>> @once_differentiable
  191. >>> def backward(ctx, g1, g2): # still need to accept g2
  192. >>> x, idx = ctx.saved_tensors
  193. >>> grad_input = torch.zeros_like(x)
  194. >>> grad_input.index_add_(0, idx, g1)
  195. >>> return grad_input
  196. """
  197. self.non_differentiable = args
  198. def set_materialize_grads(self, value: bool):
  199. r"""Set whether to materialize grad tensors. Default is ``True``.
  200. This should be called only from either the :func:`setup_context` or
  201. :func:`forward` methods.
  202. If ``True``, undefined grad tensors will be expanded to tensors full of zeros
  203. prior to calling the :func:`backward` and :func:`jvp` methods.
  204. Example::
  205. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
  206. >>> class SimpleFunc(Function):
  207. >>> @staticmethod
  208. >>> def forward(ctx, x):
  209. >>> return x.clone(), x.clone()
  210. >>>
  211. >>> @staticmethod
  212. >>> @once_differentiable
  213. >>> def backward(ctx, g1, g2):
  214. >>> return g1 + g2 # No check for None necessary
  215. >>>
  216. >>> # We modify SimpleFunc to handle non-materialized grad outputs
  217. >>> class Func(Function):
  218. >>> @staticmethod
  219. >>> def forward(ctx, x):
  220. >>> ctx.set_materialize_grads(False)
  221. >>> ctx.save_for_backward(x)
  222. >>> return x.clone(), x.clone()
  223. >>>
  224. >>> @staticmethod
  225. >>> @once_differentiable
  226. >>> def backward(ctx, g1, g2):
  227. >>> x, = ctx.saved_tensors
  228. >>> grad_input = torch.zeros_like(x)
  229. >>> if g1 is not None: # We must check for None now
  230. >>> grad_input += g1
  231. >>> if g2 is not None:
  232. >>> grad_input += g2
  233. >>> return grad_input
  234. >>>
  235. >>> a = torch.tensor(1., requires_grad=True)
  236. >>> b, _ = Func.apply(a) # induces g2 to be undefined
  237. """
  238. self.materialize_grads = value
  239. # DO NOT USE: This is only defined to be able to load old serialized models
  240. _ContextMethodMixin = FunctionCtx
  241. class _HookMixin:
  242. @staticmethod
  243. def _register_hook(backward_hooks, hook):
  244. if backward_hooks is None:
  245. backward_hooks = OrderedDict()
  246. handle = hooks.RemovableHandle(backward_hooks)
  247. backward_hooks[handle.id] = hook
  248. return backward_hooks, handle
  249. class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
  250. r"""
  251. This class is used for internal autograd work. Do not use.
  252. """
  253. def apply(self, *args):
  254. r"""
  255. Apply method used when executing this Node during the backward
  256. """
  257. # _forward_cls is defined by derived class
  258. # The user should define either backward or vjp but never both.
  259. backward_fn = self._forward_cls.backward # type: ignore[attr-defined]
  260. vjp_fn = self._forward_cls.vjp # type: ignore[attr-defined]
  261. if backward_fn is not Function.backward and vjp_fn is not Function.vjp:
  262. raise RuntimeError(
  263. "Implementing both 'backward' and 'vjp' for a custom "
  264. "Function is not allowed. You should only implement one "
  265. "of them."
  266. )
  267. user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
  268. return user_fn(self, *args)
  269. def apply_jvp(self, *args):
  270. r"""
  271. Apply method used when executing forward mode AD during the forward
  272. """
  273. # _forward_cls is defined by derived class
  274. return self._forward_cls.jvp(self, *args) # type: ignore[attr-defined]
  275. def _compiled_autograd_key(self):
  276. return self._forward_cls._compiled_autograd_key(self) # type: ignore[attr-defined]
  277. class FunctionMeta(type):
  278. """Function metaclass.
  279. This metaclass sets up the following properties:
  280. _backward_cls: The Function class corresponding to the differentiated
  281. version of this function (which is generated on the fly by this
  282. metaclass).
  283. """
  284. def __init__(cls, name, bases, attrs):
  285. backward_fn = type(
  286. name + "Backward", (BackwardCFunction,), {"_forward_cls": cls}
  287. )
  288. backward_fn._autograd_function_id = next(AUTOGRAD_FUNCTION_COUNTER) # type: ignore[attr-defined]
  289. cls._backward_cls = backward_fn
  290. super().__init__(name, bases, attrs)
  291. class _SingleLevelFunction(
  292. _C._FunctionBase, FunctionCtx, _HookMixin, metaclass=FunctionMeta
  293. ):
  294. @staticmethod
  295. def forward(*args: Any, **kwargs: Any) -> Any:
  296. r"""Define the forward of the custom autograd Function.
  297. This function is to be overridden by all subclasses.
  298. There are two ways to define forward:
  299. Usage 1 (Combined forward and ctx)::
  300. @staticmethod
  301. def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
  302. pass
  303. - It must accept a context ctx as the first argument, followed by any
  304. number of arguments (tensors or other types).
  305. - See :ref:`combining-forward-context` for more details
  306. Usage 2 (Separate forward and ctx)::
  307. @staticmethod
  308. def forward(*args: Any, **kwargs: Any) -> Any:
  309. pass
  310. @staticmethod
  311. def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
  312. pass
  313. - The forward no longer accepts a ctx argument.
  314. - Instead, you must also override the :meth:`torch.autograd.Function.setup_context`
  315. staticmethod to handle setting up the ``ctx`` object.
  316. ``output`` is the output of the forward, ``inputs`` are a Tuple of inputs
  317. to the forward.
  318. - See :ref:`extending-autograd` for more details
  319. The context can be used to store arbitrary data that can be then
  320. retrieved during the backward pass. Tensors should not be stored
  321. directly on `ctx` (though this is not currently enforced for
  322. backward compatibility). Instead, tensors should be saved either with
  323. :func:`ctx.save_for_backward` if they are intended to be used in
  324. ``backward`` (equivalently, ``vjp``) or :func:`ctx.save_for_forward`
  325. if they are intended to be used for in ``jvp``.
  326. """
  327. raise NotImplementedError(
  328. "You must implement the forward function for custom autograd.Function."
  329. )
  330. @staticmethod
  331. def setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> Any:
  332. r"""There are two ways to define the forward pass of an autograd.Function.
  333. Either:
  334. 1. Override forward with the signature ``forward(ctx, *args, **kwargs)``.
  335. ``setup_context`` is not overridden. Setting up the ctx for backward
  336. happens inside the ``forward``.
  337. 2. Override forward with the signature ``forward(*args, **kwargs)`` and
  338. override ``setup_context``. Setting up the ctx for backward happens
  339. inside ``setup_context`` (as opposed to inside the ``forward``)
  340. See :meth:`torch.autograd.Function.forward` and :ref:`extending-autograd` for more details.
  341. """
  342. raise NotImplementedError("setup_context is not implemented.")
  343. @staticmethod
  344. def backward(ctx: Any, *grad_outputs: Any) -> Any:
  345. r"""Define a formula for differentiating the operation with backward mode automatic differentiation.
  346. This function is to be overridden by all subclasses.
  347. (Defining this function is equivalent to defining the ``vjp`` function.)
  348. It must accept a context :attr:`ctx` as the first argument, followed by
  349. as many outputs as the :func:`forward` returned (None will be passed in
  350. for non tensor outputs of the forward function),
  351. and it should return as many tensors, as there were inputs to
  352. :func:`forward`. Each argument is the gradient w.r.t the given output,
  353. and each returned value should be the gradient w.r.t. the
  354. corresponding input. If an input is not a Tensor or is a Tensor not
  355. requiring grads, you can just pass None as a gradient for that input.
  356. The context can be used to retrieve tensors saved during the forward
  357. pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple
  358. of booleans representing whether each input needs gradient. E.g.,
  359. :func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the
  360. first input to :func:`forward` needs gradient computed w.r.t. the
  361. output.
  362. """
  363. raise NotImplementedError(
  364. "You must implement either the backward or vjp method for "
  365. "your custom autograd.Function to use it with backward "
  366. "mode AD."
  367. )
  368. # vjp and backward are alias of each other
  369. vjp = backward
  370. """
  371. Bool that specifies if PyTorch should clear saved tensors after the first
  372. access to ``ctx.saved_tensors``. When set to True, accessing saved_tensors
  373. clears the internal references, allowing the tensors to be cleared as soon
  374. as the Tensor returned by saved_tensors is deleted.
  375. This is useful for reducing memory pressure in backward passes when you
  376. only need to access saved tensors once.
  377. Default is False.
  378. """
  379. clear_saved_tensors_on_access = False
  380. @staticmethod
  381. def jvp(ctx: Any, *grad_inputs: Any) -> Any:
  382. r"""Define a formula for differentiating the operation with forward mode automatic differentiation.
  383. This function is to be overridden by all subclasses.
  384. It must accept a context :attr:`ctx` as the first argument, followed by
  385. as many inputs as the :func:`forward` got (None will be passed in
  386. for non tensor inputs of the forward function),
  387. and it should return as many tensors as there were outputs to
  388. :func:`forward`. Each argument is the gradient w.r.t the given input,
  389. and each returned value should be the gradient w.r.t. the
  390. corresponding output. If an output is not a Tensor or the function is not
  391. differentiable with respect to that output, you can just pass None as a
  392. gradient for that input.
  393. You can use the :attr:`ctx` object to pass any value from the forward to this
  394. functions.
  395. """
  396. raise NotImplementedError(
  397. "You must implement the jvp function for custom "
  398. "autograd.Function to use it with forward mode AD."
  399. )
  400. class Function(_SingleLevelFunction):
  401. r"""Base class to create custom `autograd.Function`.
  402. To create a custom `autograd.Function`, subclass this class and implement
  403. the :meth:`forward` and :meth:`backward` static methods. Then, to use your custom
  404. op in the forward pass, call the class method ``apply``. Do not call
  405. :meth:`forward` directly.
  406. To ensure correctness and best performance, make sure you are calling the
  407. correct methods on ``ctx`` and validating your backward function using
  408. :func:`torch.autograd.gradcheck`.
  409. See :ref:`extending-autograd` for more details on how to use this class.
  410. Examples::
  411. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
  412. >>> class Exp(Function):
  413. >>> @staticmethod
  414. >>> def forward(ctx, i):
  415. >>> result = i.exp()
  416. >>> ctx.save_for_backward(result)
  417. >>> return result
  418. >>>
  419. >>> @staticmethod
  420. >>> def backward(ctx, grad_output):
  421. >>> result, = ctx.saved_tensors
  422. >>> return grad_output * result
  423. >>>
  424. >>> # Use it by calling the apply method:
  425. >>> # xdoctest: +SKIP
  426. >>> output = Exp.apply(input)
  427. """
  428. def __init__(self, *args, **kwargs):
  429. warnings.warn(
  430. f"{self.__class__} should not be instantiated. Methods on autograd functions "
  431. "are all static, so you should invoke them on the class itself. "
  432. "Instantiating an autograd function will raise an "
  433. "error in a future version of PyTorch.",
  434. DeprecationWarning,
  435. stacklevel=2,
  436. )
  437. def __call__(self, *args, **kwargs):
  438. raise RuntimeError(
  439. "Legacy autograd function with non-static forward method is deprecated. "
  440. "Please use new-style autograd function with static forward method. "
  441. "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)"
  442. )
  443. """
  444. Bool that specifies if PyTorch should attempt to autogenerate
  445. :func:`torch.vmap` support for this autograd.Function. You may set this to
  446. True only if this autograd.Function's forward, backward, and jvp (if they
  447. exist) are written using PyTorch operations; otherwise, please override
  448. :meth:`torch.autograd.Function.vmap` to add support for :func:`torch.vmap`.
  449. Please see :ref:`func-autograd-function` for more details.
  450. """
  451. generate_vmap_rule = False
  452. @staticmethod
  453. def vmap(info, in_dims, *args):
  454. r"""Define the behavior for this autograd.Function underneath :func:`torch.vmap`.
  455. For a :func:`torch.autograd.Function` to support
  456. :func:`torch.vmap`, you must either override this static method, or set
  457. ``generate_vmap_rule`` to ``True`` (you may not do both).
  458. If you choose to override this staticmethod: it must accept
  459. - an ``info`` object as the first argument. ``info.batch_size``
  460. specifies the size of the dimension being vmapped over,
  461. while ``info.randomness`` is the randomness option passed to
  462. :func:`torch.vmap`.
  463. - an ``in_dims`` tuple as the second argument.
  464. For each arg in ``args``, ``in_dims`` has a corresponding
  465. ``Optional[int]``. It is ``None`` if the arg is not a Tensor or if
  466. the arg is not being vmapped over, otherwise, it is an integer
  467. specifying what dimension of the Tensor is being vmapped over.
  468. - ``*args``, which is the same as the args to :meth:`~Function.forward`.
  469. The return of the vmap staticmethod is a tuple of ``(output, out_dims)``.
  470. Similar to ``in_dims``, ``out_dims`` should be of the same structure as
  471. ``output`` and contain one ``out_dim`` per output that specifies if the
  472. output has the vmapped dimension and what index it is in.
  473. Please see :ref:`func-autograd-function` for more details.
  474. """
  475. raise NotImplementedError(
  476. "To use autograd.Function with vmap, you must either override the "
  477. "vmap staticmethod or set generate_vmap_rule=True."
  478. )
  479. @classmethod
  480. def apply(cls, *args, **kwargs):
  481. def bind_default_args(func, *args, **kwargs):
  482. signature = inspect.signature(func)
  483. bound_args = signature.bind(*args, **kwargs)
  484. bound_args.apply_defaults()
  485. return bound_args.args
  486. is_setup_ctx_defined = _is_setup_context_defined(cls.setup_context)
  487. if is_setup_ctx_defined:
  488. args = bind_default_args(cls.forward, *args, **kwargs)
  489. if not torch._C._are_functorch_transforms_active():
  490. # See NOTE: [functorch vjp and autograd interaction]
  491. args = _functorch.utils.unwrap_dead_wrappers(args)
  492. return super().apply(*args, **kwargs) # type: ignore[misc]
  493. if not is_setup_ctx_defined:
  494. raise RuntimeError(
  495. "In order to use an autograd.Function with functorch transforms "
  496. "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
  497. "staticmethod. For more details, please see "
  498. "https://pytorch.org/docs/main/notes/extending.func.html"
  499. )
  500. return custom_function_call(cls, *args, **kwargs)
  501. @staticmethod
  502. def _compiled_autograd_key(ctx):
  503. return (ctx._autograd_function_id,)
  504. def _is_setup_context_defined(fn):
  505. return fn != _SingleLevelFunction.setup_context
  506. def once_differentiable(
  507. fn: Callable[Concatenate[_T, _P], _R],
  508. ) -> Callable[Concatenate[_T, _P], _R]:
  509. @functools.wraps(fn)
  510. def wrapper(ctx: _T, *args: _P.args, **kwargs: _P.kwargs) -> _R:
  511. with torch.no_grad():
  512. outputs = fn(ctx, *args, **kwargs)
  513. if not torch.is_grad_enabled():
  514. return outputs
  515. # If any of the inputs have requires_grad=True, we force the outputs
  516. # to have requires_grad=True but point to a grad_fn which throws an
  517. # error message during (double) back-propagation.
  518. # XXX: this is only an approximation of requires_grad - there's no way
  519. # to figure out if fn didn't use ctx.saved_tensors and as a result
  520. # some Tensors might require grad, even if no args do.
  521. # Unfortunately, this leads to unexpected error messages ("no nodes
  522. # require computing gradients"), but I don't have a better idea.
  523. # These functions would raise an error in backward anyway.
  524. requires_grad = any(
  525. isinstance(arg, torch.Tensor) and arg.requires_grad for arg in args
  526. )
  527. if not requires_grad:
  528. return outputs
  529. if not isinstance(outputs, tuple):
  530. outputs_ = (outputs,)
  531. else:
  532. outputs_ = outputs
  533. err_fn = _functions.DelayedError(
  534. b"trying to differentiate twice a function that was marked "
  535. b"with @once_differentiable",
  536. len(outputs_),
  537. )
  538. # Create aliases of each output that has requires_grad=True. We need
  539. # at least one of the inputs to err_fn to require grad so that the
  540. # output will have a grad_fn.
  541. def fake_requires_grad(var):
  542. if var is not None:
  543. var = var.detach()
  544. var.requires_grad = True
  545. return var
  546. return err_fn(*[fake_requires_grad(v) for v in outputs_]) # type: ignore[return-value]
  547. return wrapper
  548. class InplaceFunction(Function):
  549. r"""
  550. This class is here only for backward compatibility reasons.
  551. Use :class:`Function` instead of this for any new use case.
  552. """
  553. def __init__(self, inplace=False):
  554. super().__init__()
  555. self.inplace = inplace
  556. def _nested_map(condition, fn, condition_msg=None):
  557. def _map(obj):
  558. if condition(obj):
  559. return fn(obj)
  560. elif obj is None:
  561. return None
  562. elif isinstance(obj, (list, tuple)):
  563. mapped = (_map(x) for x in obj)
  564. if hasattr(obj, "_fields"):
  565. # obj is namedtuple
  566. return type(obj)(*mapped)
  567. return type(obj)(mapped)
  568. elif isinstance(obj, dict):
  569. return {x: _map(obj[x]) for x in obj}
  570. else:
  571. raise ValueError(
  572. "Auto nesting doesn't know how to process "
  573. "an input object of type "
  574. + torch.typename(obj)
  575. + (
  576. ". Accepted types: " + condition_msg + ", or lists/tuples of them"
  577. if condition_msg
  578. else ""
  579. )
  580. )
  581. return _map
  582. def _jit_unwrap_structured(obj):
  583. if hasattr(obj, "_jit_unwrap"):
  584. return obj._jit_unwrap()
  585. return obj
  586. def _iter_filter(condition, allow_unknown=False, condition_msg=None, conversion=None):
  587. def _iter(obj):
  588. if conversion is not None:
  589. obj = conversion(obj)
  590. if condition(obj):
  591. yield obj
  592. elif obj is None:
  593. return
  594. elif isinstance(obj, (list, tuple)):
  595. for o in obj:
  596. yield from _iter(o)
  597. elif isinstance(obj, dict):
  598. # We only accept primitive key types, so we needn't inspect them
  599. for o in obj.values():
  600. yield from _iter(o)
  601. elif allow_unknown:
  602. yield obj
  603. else:
  604. raise ValueError(
  605. "Auto nesting doesn't know how to process "
  606. "an input object of type "
  607. + torch.typename(obj)
  608. + (
  609. ". Accepted types: " + condition_msg + ", or lists/tuples of them"
  610. if condition_msg
  611. else ""
  612. )
  613. )
  614. return _iter
  615. def _unflatten(input, proto):
  616. # unflatten a list or tuple input into a nested list/tuple structure
  617. # specified by proto
  618. def unflatten_helper(input, proto):
  619. res: list[Optional[torch.Tensor]] = []
  620. if hasattr(proto, "_jit_wrap"):
  621. return proto._jit_wrap(input)
  622. if not isinstance(proto, (list, tuple)):
  623. return input[0], input[1:]
  624. for e in proto:
  625. if e is None:
  626. res.append(e)
  627. else:
  628. res_e, input = unflatten_helper(input, e)
  629. res.append(res_e)
  630. return type(proto)(res), input
  631. return unflatten_helper(input, proto)[0]
  632. _iter_jit_values = _iter_filter(
  633. lambda o: o is None or isinstance(o, torch._C.Value),
  634. condition_msg="jit's Values or None",
  635. )
  636. _iter_tensors = _iter_filter(
  637. lambda x: isinstance(x, torch.Tensor),
  638. condition_msg="Tensors",
  639. conversion=_jit_unwrap_structured,
  640. )
  641. _iter_tensors_permissive = _iter_filter(
  642. lambda x: isinstance(x, torch.Tensor),
  643. allow_unknown=True,
  644. condition_msg="Tensors (permissive)",
  645. )
  646. _iter_None_tensors = _iter_filter(
  647. lambda o: o is None or isinstance(o, torch.Tensor), condition_msg="Tensors or None"
  648. )
  649. _map_tensor_data = _nested_map(
  650. lambda x: isinstance(x, torch.Tensor), lambda o: o.data, condition_msg="Tensors"
  651. )
  652. class NestedIOFunction(Function):
  653. r"""
  654. This class is here only for backward compatibility reasons.
  655. Use :class:`Function` instead of this for any new use case.
  656. """
  657. # The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
  658. # superclass (Function) but are instance methods here, which mypy reports as incompatible.
  659. def _do_forward(self, *input):
  660. self._nested_input = input
  661. flat_input = tuple(_iter_tensors(input))
  662. flat_output = super()._do_forward(*flat_input) # type: ignore[misc]
  663. nested_tensors = _unflatten(flat_output, self._nested_output)
  664. return nested_tensors
  665. def _do_backward(self, gradients, retain_variables):
  666. self.retain_variables = retain_variables
  667. result = super()._do_backward(gradients, retain_variables) # type: ignore[misc]
  668. if not retain_variables:
  669. del self._nested_output
  670. del self._to_save_nested
  671. return result
  672. def backward(self, *gradients: Any) -> Any: # type: ignore[override]
  673. r"""
  674. Shared backward utility.
  675. """
  676. nested_gradients = _unflatten(gradients, self._nested_output)
  677. result = self.backward_extended(*nested_gradients) # type: ignore[func-returns-value]
  678. return tuple(_iter_None_tensors(result))
  679. __call__ = _do_forward
  680. def forward(self, *args: Any) -> Any: # type: ignore[override]
  681. r"""
  682. Shared forward utility.
  683. """
  684. nested_tensors = _map_tensor_data(self._nested_input)
  685. result = self.forward_extended(*nested_tensors) # type: ignore[func-returns-value]
  686. del self._nested_input
  687. self._nested_output = result
  688. return tuple(_iter_tensors(result))
  689. def save_for_backward(self, *args: Any) -> None:
  690. r"""
  691. See :meth:`Function.save_for_backward`.
  692. """
  693. self.to_save = tuple(_iter_tensors(args))
  694. self._to_save_nested = args
  695. @property
  696. def saved_tensors(self): # type: ignore[override]
  697. r"""
  698. See :meth:`Function.saved_tensors`.
  699. """
  700. flat_tensors = super().saved_tensors # type: ignore[misc]
  701. return _unflatten(flat_tensors, self._to_save_nested)
  702. def mark_dirty(self, *args: Any, **kwargs: Any) -> None:
  703. r"""
  704. See :meth:`Function.mark_dirty`.
  705. """
  706. self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))
  707. def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None:
  708. r"""
  709. See :meth:`Function.mark_non_differentiable`.
  710. """
  711. self.non_differentiable = tuple(_iter_tensors((args, kwargs)))
  712. def forward_extended(self, *input: Any) -> None:
  713. r"""
  714. User defined forward.
  715. """
  716. raise NotImplementedError
  717. def backward_extended(self, *grad_output: Any) -> None:
  718. r"""
  719. User defined backward.
  720. """
  721. raise NotImplementedError