graph.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874
  1. import abc
  2. import contextlib
  3. import functools
  4. import logging
  5. import threading
  6. from collections import defaultdict, deque
  7. from collections.abc import (
  8. Callable,
  9. Generator,
  10. Iterable,
  11. Iterator,
  12. MutableMapping,
  13. Sequence,
  14. )
  15. from typing import (
  16. Any,
  17. cast,
  18. Literal,
  19. NamedTuple,
  20. Optional,
  21. TYPE_CHECKING,
  22. TypeAlias,
  23. Union,
  24. )
  25. from weakref import WeakKeyDictionary, WeakValueDictionary
  26. import torch
  27. from torch.autograd.variable import Variable
  28. from torch.utils._python_dispatch import TorchDispatchMode
  29. from torch.utils.hooks import RemovableHandle
  30. if TYPE_CHECKING:
  31. from torch._ops import OpOverload
  32. __all__ = [
  33. "saved_tensors_hooks",
  34. "save_on_cpu",
  35. "disable_saved_tensors_hooks",
  36. "register_multi_grad_hook",
  37. "allow_mutation_on_saved_tensors",
  38. "Node",
  39. "GradientEdge",
  40. "get_gradient_edge",
  41. "increment_version",
  42. "set_warn_on_accumulate_grad_stream_mismatch",
  43. ]
  44. log = logging.getLogger(__name__)
  45. class Node(abc.ABC):
  46. @abc.abstractmethod
  47. def name(self) -> str:
  48. r"""Return the name.
  49. Example::
  50. >>> import torch
  51. >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
  52. >>> b = a.clone()
  53. >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
  54. >>> print(b.grad_fn.name())
  55. CloneBackward0
  56. """
  57. raise NotImplementedError
  58. @property
  59. @abc.abstractmethod
  60. def next_functions(self) -> tuple[tuple[Optional["Node"], int], ...]:
  61. raise NotImplementedError
  62. @abc.abstractmethod
  63. def metadata(self) -> dict:
  64. r"""Return the metadata."""
  65. raise NotImplementedError
  66. @abc.abstractmethod
  67. def _sequence_nr(self) -> int:
  68. raise NotImplementedError
  69. @property
  70. @abc.abstractmethod
  71. def _input_metadata(self) -> list[Any]:
  72. raise NotImplementedError
  73. @abc.abstractmethod
  74. def _register_hook_dict(self, tensor: torch.Tensor) -> None:
  75. raise NotImplementedError
  76. @abc.abstractmethod
  77. def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle:
  78. r"""Register a backward hook.
  79. The hook will be called every time a gradient with respect to the
  80. Node is computed. The hook should have the following signature::
  81. hook(grad_inputs: Tuple[Tensor], grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None
  82. The hook should not modify its argument, but it can optionally return
  83. a new gradient which will be used in place of :attr:`grad_inputs`.
  84. This function returns a handle with a method ``handle.remove()``
  85. that removes the hook from the module.
  86. .. note::
  87. See :ref:`backward-hooks-execution` for more information on how when this hook
  88. is executed, and how its execution is ordered relative to other hooks.
  89. .. note::
  90. In the rare case where the hook is registered while the Node has already
  91. begun execution, there is no longer any guarantee on :attr:`grad_outputs`
  92. content (it might be as usual or empty depending on other factors). The
  93. hook can still optionally return a new gradient to be used in place of
  94. :attr:`grad_inputs` independent of :attr:`grad_outputs`.
  95. Example::
  96. >>> import torch
  97. >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
  98. >>> b = a.clone()
  99. >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
  100. >>> handle = b.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,))
  101. >>> b.sum().backward(retain_graph=True)
  102. >>> print(a.grad)
  103. tensor([2., 2., 2.])
  104. >>> handle.remove() # Removes the hook
  105. >>> a.grad = None
  106. >>> b.sum().backward(retain_graph=True)
  107. >>> print(a.grad)
  108. tensor([1., 1., 1.])
  109. """
  110. raise NotImplementedError
  111. @abc.abstractmethod
  112. def register_prehook(self, fn: Callable[..., Any]) -> RemovableHandle:
  113. r"""Register a backward pre-hook.
  114. The hook will be called every time a gradient with respect to the
  115. Node is computed. The hook should have the following signature::
  116. hook(grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None
  117. The hook should not modify its argument, but it can optionally return
  118. a new gradient which will be used in place of :attr:`grad_outputs`.
  119. This function returns a handle with a method ``handle.remove()``
  120. that removes the hook from the module.
  121. .. note::
  122. See :ref:`backward-hooks-execution` for more information on how when this hook
  123. is executed, and how its execution is ordered relative to other hooks.
  124. Example::
  125. >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
  126. >>> b = a.clone()
  127. >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
  128. >>> handle = b.grad_fn.register_prehook(lambda gI: (gI[0] * 2,))
  129. >>> b.sum().backward(retain_graph=True)
  130. >>> print(a.grad)
  131. tensor([2., 2., 2.])
  132. >>> handle.remove()
  133. >>> a.grad = None
  134. >>> b.sum().backward(retain_graph=True)
  135. >>> print(a.grad)
  136. tensor([1., 1., 1.])
  137. """
  138. raise NotImplementedError
  139. @classmethod
  140. def __subclasshook__(cls, subclass: type) -> bool:
  141. if cls is Node and (
  142. (
  143. subclass is not None
  144. and subclass is getattr(torch._C._functions, subclass.__name__, None)
  145. )
  146. or issubclass(subclass, torch.autograd.function.BackwardCFunction)
  147. ):
  148. return True
  149. return NotImplemented
  150. def _get_grad_fn_or_grad_acc(t: Union[torch.Tensor, "GradientEdge"]) -> Node:
  151. if isinstance(t, GradientEdge):
  152. return t.node
  153. if t.requires_grad and t.grad_fn is None:
  154. with torch.enable_grad():
  155. node = t.view_as(t).grad_fn.next_functions[0][0] # type: ignore[union-attr]
  156. else:
  157. node = t.grad_fn
  158. if node is None:
  159. raise AssertionError("Expected gradient function to be set")
  160. return node
  161. class GradientEdge(NamedTuple):
  162. """Object representing a given gradient edge within the autograd graph.
  163. To get the gradient edge where a given Tensor gradient will be computed,
  164. you can do ``edge = autograd.graph.get_gradient_edge(tensor)``.
  165. """
  166. node: Node
  167. output_nr: int
  168. # This token can be used to ensure the graph stays alive when it cannot be
  169. # done via the node field
  170. ownership_token: Optional[Node] = None
  171. def get_gradient_edge(tensor: torch.Tensor) -> GradientEdge:
  172. """Get the gradient edge for computing the gradient of the given Tensor.
  173. In particular, it is equivalent to call
  174. ``g = autograd.grad(loss, input)`` and ``g = autograd.grad(loss, get_gradient_edge(input))``.
  175. """
  176. if not tensor.requires_grad:
  177. raise RuntimeError(
  178. "It is not possible to get the gradient edge for a Tensor "
  179. "that does not require gradients",
  180. )
  181. grad_fn = _get_grad_fn_or_grad_acc(tensor)
  182. # Python-based Node are owned by the C++ side meaning the python grad_fn
  183. # object we hold here does NOT keep the C++ graph alive.
  184. # Create an ownership token by creating a new C++ node that own the graph
  185. # we care about here.
  186. token = None
  187. if isinstance(grad_fn, torch._C._FunctionBase):
  188. with torch.enable_grad():
  189. token = tensor.view_as(tensor).grad_fn
  190. # Note that output_nr default to 0 which is the right value
  191. # for the AccumulateGrad node.
  192. # pyrefly: ignore [bad-argument-type]
  193. return GradientEdge(grad_fn, tensor.output_nr, ownership_token=token)
  194. def increment_version(tensor: Union[torch.Tensor, Iterable[torch.Tensor]]) -> None:
  195. """Update autograd metadata tracking whether the given Tensor was modified in place.
  196. This is to enable more accurate error checking within the autograd engine.
  197. It is already done automatically by PyTorch functions and within custom Function
  198. when mark_dirty() is called appropriately so you only need to call this explicitly
  199. if you are doing inplace operation on the Tensor data in a way that Pytorch doesn't
  200. know about. For example a custom kernel that reads the Tensor data_ptr and modifies
  201. the memory inplace based on this pointer. Can accept either a tensor, or a list of tensors.
  202. Note that incrementing the version counter multiple times for a single inplace operation
  203. is not problematic.
  204. Note that if you pass in tensor constructed under torch.inference_mode(),
  205. we will not bump its version counter (because your tensor does not have one).
  206. """
  207. if isinstance(tensor, torch.Tensor):
  208. tensor = (tensor,)
  209. torch._C._increment_version(tensor)
  210. class saved_tensors_hooks:
  211. """Context-manager that sets a pair of pack / unpack hooks for saved tensors.
  212. Use this context-manager to define how intermediary results of an operation
  213. should be packed before saving, and unpacked on retrieval.
  214. In that context, the ``pack_hook`` function will be called every time an
  215. operation saves a tensor for backward (this includes intermediary results
  216. saved using
  217. :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
  218. also those recorded by a PyTorch-defined operation). The output of
  219. ``pack_hook`` is then stored in the computation graph instead of the
  220. original tensor.
  221. The ``unpack_hook`` is called when the saved tensor needs to be accessed,
  222. namely when executing :func:`torch.Tensor.backward()` or
  223. :func:`torch.autograd.grad()`. It takes as argument the *packed* object
  224. returned by ``pack_hook`` and should return a tensor which has the same
  225. content as the original tensor (passed as input to the corresponding
  226. ``pack_hook``).
  227. The hooks should have the following signatures:
  228. pack_hook(tensor: Tensor) -> Any
  229. unpack_hook(Any) -> Tensor
  230. where the return value of ``pack_hook`` is a valid input to ``unpack_hook``.
  231. In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms
  232. of value, size, dtype and device.
  233. Example::
  234. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
  235. >>> def pack_hook(x):
  236. ... print("Packing", x)
  237. ... return x.detach()
  238. >>>
  239. >>> def unpack_hook(x):
  240. ... print("Unpacking", x)
  241. ... return x
  242. >>>
  243. >>> a = torch.ones(5, requires_grad=True)
  244. >>> b = torch.ones(5, requires_grad=True) * 2
  245. >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
  246. ... y = a * b
  247. Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
  248. Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
  249. >>> y.sum().backward()
  250. Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)
  251. Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
  252. .. warning ::
  253. Performing an inplace operation on the input to either hooks may lead
  254. to undefined behavior.
  255. .. warning ::
  256. Only one pair of hooks is allowed at a time. When recursively nesting this
  257. context-manager, only the inner-most pair of hooks will be applied.
  258. .. warning ::
  259. To avoid reference cycle, the return value of ``pack_hook`` cannot hold a
  260. reference to the input tensor. For example, use `lambda x: x.detach()`
  261. instead of `lambda x: x` as the pack hook.
  262. """
  263. def __init__(
  264. self,
  265. pack_hook: Callable[[torch.Tensor], Any],
  266. unpack_hook: Callable[[Any], torch.Tensor],
  267. ) -> None:
  268. self.pack_hook = pack_hook
  269. self.unpack_hook = unpack_hook
  270. def __enter__(self) -> None:
  271. torch._C._autograd._push_saved_tensors_default_hooks(
  272. self.pack_hook, self.unpack_hook
  273. )
  274. def __exit__(self, *args: object) -> None:
  275. torch._C._autograd._pop_saved_tensors_default_hooks()
  276. class save_on_cpu(saved_tensors_hooks):
  277. """Context manager under which tensors saved by the forward pass will be stored on cpu, then retrieved for backward.
  278. When performing operations within this context manager, intermediary
  279. results saved in the graph during the forward pass will be moved to CPU,
  280. then copied back to the original device when needed for the backward pass.
  281. If the graph was already on CPU, no tensor copy is performed.
  282. Use this context-manager to trade compute for GPU memory usage (e.g.
  283. when your model doesn't fit in GPU memory during training).
  284. Args:
  285. pin_memory (bool): If ``True`` tensors will be saved to CPU pinned memory
  286. during packing and copied to GPU asynchronously during unpacking.
  287. Defaults to ``False``.
  288. Also see :ref:`cuda-memory-pinning`.
  289. Example::
  290. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  291. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
  292. >>> a = torch.randn(5, requires_grad=True, device="cuda")
  293. >>> b = torch.randn(5, requires_grad=True, device="cuda")
  294. >>> c = torch.randn(5, requires_grad=True, device="cuda")
  295. >>>
  296. >>> def f(a, b, c):
  297. ... prod_1 = a * b # a and b are saved on GPU
  298. ... with torch.autograd.graph.save_on_cpu():
  299. ... prod_2 = prod_1 * c # prod_1 and c are saved on CPU
  300. ... y = prod_2 * a # prod_2 and a are saved on GPU
  301. ... return y
  302. >>>
  303. >>> y = f(a, b, c)
  304. >>> del a, b, c # for illustration only
  305. >>> # the content of a, b, and prod_2 are still alive on GPU
  306. >>> # the content of prod_1 and c only live on CPU
  307. >>> y.sum().backward() # all CPU tensors are moved back to GPU, for backward
  308. >>> # all intermediary tensors are released (deleted) after the call to backward
  309. """
  310. def __init__(self, pin_memory: bool = False, device_type: str = "cuda") -> None:
  311. device_module = getattr(torch, device_type, torch.cuda)
  312. def pack_to_cpu(tensor: torch.Tensor) -> tuple[torch.device, torch.Tensor]:
  313. if not pin_memory:
  314. return (tensor.device, tensor.cpu())
  315. packed = torch.empty(
  316. tensor.size(),
  317. dtype=tensor.dtype,
  318. layout=tensor.layout,
  319. pin_memory=(device_module.is_available() and not tensor.is_sparse),
  320. )
  321. packed.copy_(tensor)
  322. return (tensor.device, packed)
  323. def unpack_from_cpu(packed: tuple[torch.device, torch.Tensor]) -> torch.Tensor:
  324. device, tensor = packed
  325. return tensor.to(device, non_blocking=pin_memory)
  326. super().__init__(pack_to_cpu, unpack_from_cpu)
  327. @contextlib.contextmanager
  328. def disable_saved_tensors_hooks(error_message: str) -> Generator[None, None, None]:
  329. """Context-manager that disables the saved tensors default hooks feature.
  330. Useful for if you are creating a feature that does not work with saved
  331. tensors default hooks.
  332. Args:
  333. error_message (str): When saved tensors default hooks are used when they
  334. have been are disabled, a RuntimeError with this
  335. error message gets raised.
  336. Example::
  337. >>> # xdoctest: +SKIP(failing)
  338. >>> message = "saved tensors default hooks are disabled"
  339. >>> with torch.autograd.graph.disable_saved_tensors_hooks(message):
  340. ... # Raises RuntimeError: saved tensors default hooks are disabled
  341. ... with torch.autograd.graph.save_on_cpu():
  342. ... pass
  343. """
  344. maybe_prev_message = None
  345. try:
  346. maybe_prev_message = (
  347. torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
  348. )
  349. torch._C._autograd._saved_tensors_hooks_disable(error_message)
  350. yield
  351. finally:
  352. # See NOTE: [disabled_error_message invariant]
  353. if maybe_prev_message is None:
  354. torch._C._autograd._saved_tensors_hooks_enable()
  355. else:
  356. torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)
  357. def set_warn_on_accumulate_grad_stream_mismatch(enabled: bool) -> None:
  358. """Whether to warn when the AccumulateGrad node's stream does not match the stream
  359. of the node that produced the incoming gradient.
  360. """
  361. return torch._C._set_warn_on_accumulate_grad_stream_mismatch(enabled)
  362. class _MultiHandle(RemovableHandle):
  363. handles: tuple[RemovableHandle, ...]
  364. def __init__(self, handles: tuple[RemovableHandle, ...]) -> None:
  365. self.handles = handles
  366. def remove(self) -> None:
  367. for handle in self.handles:
  368. handle.remove()
  369. def __getstate__(self) -> tuple[RemovableHandle, ...]:
  370. return self.handles
  371. def __setstate__(self, state: tuple[RemovableHandle, ...]) -> None:
  372. self.handles = state
  373. def register_multi_grad_hook(
  374. tensors: Sequence[torch.Tensor],
  375. fn: Union[
  376. Callable[[Sequence[Optional[torch.Tensor]]], None],
  377. Callable[[torch.Tensor], None],
  378. ],
  379. *,
  380. mode: Literal["all", "any"] = "all",
  381. ) -> RemovableHandle:
  382. r"""Register a multi-grad backward hook.
  383. There are two supported modes: ``"all"`` and ``"any"``.
  384. Under the ``"all"`` mode, the hook will be called after gradients with respect to every tensor in
  385. :attr:`tensors` have been computed. If a tensor is in :attr:`tensors` but
  386. is not part of the graph, or if a tensor is not needed to compute the gradients
  387. for any ``inputs`` specified for the current ``.backward()`` or ``.grad()`` call,
  388. this tensor will be ignored and the hook will not wait for its gradient to be
  389. computed.
  390. After every non-ignored tensor's gradient has been computed, :attr:`fn` will be
  391. called with those gradients. ``None`` will be passed for tensors that did not
  392. have their gradients computed.
  393. Under the ``"any"`` mode, the hook will be called after the first gradient
  394. with respect to a tensor in :attr:`tensors` has been computed. The hook
  395. will be called with that gradient as its argument.
  396. The hook should not modify its arguments.
  397. This function returns a handle with a method ``handle.remove()`` that removes the hook.
  398. .. note::
  399. See :ref:`backward-hooks-execution` for more information on how when this hook
  400. is executed, and how its execution is ordered relative to other hooks.
  401. Example::
  402. >>> import torch
  403. >>>
  404. >>> a = torch.rand(2, 3, requires_grad=True)
  405. >>> b = torch.rand(2, 3, requires_grad=True)
  406. >>> c = a * b
  407. >>> d = a * b
  408. >>>
  409. >>> def fn(grads):
  410. ... print([g is not None for g in grads])
  411. ...
  412. >>> torch.autograd.graph.register_multi_grad_hook((a, b, c, d), fn)
  413. >>>
  414. >>> c.sum().backward(retain_graph=True)
  415. [True, True, True, False]
  416. >>> c.sum().backward(inputs=(a,), retain_graph=True)
  417. [True, False, True, False]
  418. >>>
  419. """
  420. supported_modes = ("all", "any")
  421. lock = threading.Lock()
  422. if mode not in supported_modes:
  423. raise ValueError(f"Expects mode to be one of {supported_modes} but got {mode}")
  424. if mode == "all":
  425. count: dict[int, int] = {}
  426. nb_calls = None
  427. buffer: dict[int, list[Optional[torch.Tensor]]] = {}
  428. grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors))
  429. len_tensors = len(tensors)
  430. def get_inner_hook(idx: int) -> Callable[[torch.Tensor], None]:
  431. def inner_hook(grad: torch.Tensor) -> None:
  432. nonlocal count, nb_calls, buffer, fn
  433. id = torch._C._current_graph_task_id()
  434. if id == -1:
  435. raise AssertionError(
  436. "expected this hook to be called inside a backward call"
  437. )
  438. count[id] = count.get(id, 0)
  439. # pyrefly: ignore [unsupported-operation]
  440. buffer[id] = buffer.get(id, [None] * len_tensors)
  441. with lock:
  442. curr_count, count[id] = count[id], count[id] + 1
  443. if curr_count == 0:
  444. # On the first call, compute the actual nb_calls and buffer
  445. nb_calls = sum(
  446. map(torch._C._will_engine_execute_node, grad_fns)
  447. )
  448. buffer[id][idx] = grad
  449. if nb_calls is None:
  450. raise AssertionError("Expected nb_calls to be set")
  451. if curr_count == nb_calls - 1:
  452. fn = cast(Callable[[Sequence[Optional[torch.Tensor]]], None], fn)
  453. fn(buffer[id])
  454. del count[id]
  455. del buffer[id]
  456. return inner_hook
  457. handles = tuple(
  458. t.register_hook(get_inner_hook(i)) for i, t in enumerate(tensors)
  459. )
  460. elif mode == "any":
  461. fn = cast(Callable[[torch.Tensor], None], fn)
  462. ran_hook: dict[int, bool] = defaultdict(bool)
  463. @functools.wraps(fn)
  464. def wrapped_fn(grad: torch.Tensor) -> None:
  465. nonlocal ran_hook
  466. id = torch._C._current_graph_task_id()
  467. if id == -1:
  468. raise AssertionError(
  469. "expected this hook to be called inside a backward call"
  470. )
  471. with lock:
  472. prev, ran_hook[id] = ran_hook[id], True
  473. if prev:
  474. return
  475. fn(grad)
  476. handles = tuple(
  477. tensor.register_hook(wrapped_fn)
  478. for tensor in tensors
  479. if tensor.requires_grad
  480. )
  481. return _MultiHandle(handles) # type: ignore[possibly-undefined]
  482. # NOTE [Allow mutation on tensors saved for backward]
  483. #
  484. # 1. Tensor gets saved for backward
  485. # - remember the python object id and the version of the tensor
  486. # - remember aliasing information (data_ptr of base + version)
  487. # - save the original so we control its lifetime
  488. # 2. Any time a tensor gets in-placed
  489. # - for each tensor aliased to it:
  490. # - check using its object id and version to see if it has been saved
  491. # - if it has been saved, clone it
  492. # - delete the reference to the original
  493. # 3. during backward
  494. # - if the clone exists, the tensor must've been modified in-place
  495. _allow_mutation_on_saved_tensors_enabled: bool = False
  496. _TID: TypeAlias = tuple[int, int, int]
  497. _SID: TypeAlias = tuple[int, int]
  498. def _get_tid(tensor: torch.Tensor) -> _TID:
  499. # FIXME: This is almost definitely a bug.
  500. if isinstance(
  501. tensor,
  502. (
  503. torch._subclasses.fake_tensor.FakeTensor,
  504. torch._subclasses.functional_tensor.FunctionalTensor,
  505. ),
  506. ):
  507. data_ptr = 0
  508. else:
  509. data_ptr = tensor.data_ptr()
  510. return (id(tensor), data_ptr, tensor._version)
  511. def _get_sid(tensor: torch.Tensor) -> _SID:
  512. # FIXME: This is almost definitely a bug.
  513. if isinstance(
  514. tensor,
  515. (
  516. torch._subclasses.fake_tensor.FakeTensor,
  517. torch._subclasses.functional_tensor.FunctionalTensor,
  518. ),
  519. ):
  520. data_ptr = 0
  521. else:
  522. data_ptr = tensor.data_ptr()
  523. return (data_ptr, tensor._version)
  524. class _Handle:
  525. pass
  526. class _swap_with_cloned(saved_tensors_hooks):
  527. def __init__(self, ctx: "_AllowMutationOnSavedContext") -> None:
  528. def pack_hook(tensor: torch.Tensor) -> _Handle:
  529. tid = _get_tid(tensor)
  530. sid = _get_sid(tensor)
  531. # Tensors saved for backward have an entry in _tid_to_weakhandle
  532. handle: Optional[_Handle] = None
  533. # Save aliasing information
  534. ctx.sid_to_tid[sid].add(tid)
  535. # NB: The same tensor (of the same version) can be saved multiple times
  536. if tid not in ctx.tid_to_weakhandle:
  537. handle = _Handle()
  538. ctx.tid_to_weakhandle[tid] = handle
  539. ctx.original[handle] = tensor
  540. else:
  541. # Store an additional strong reference to the handle
  542. handle = ctx.tid_to_weakhandle[tid]
  543. return handle
  544. def unpack_hook(handle: _Handle) -> torch.Tensor:
  545. error_msg = (
  546. "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context"
  547. "in which the graph was originally recorded."
  548. )
  549. if not _allow_mutation_on_saved_tensors_enabled:
  550. raise AssertionError(error_msg)
  551. if handle in ctx.cloned:
  552. res = ctx.cloned[handle]
  553. else:
  554. if handle not in ctx.original:
  555. raise AssertionError(error_msg)
  556. res = ctx.original[handle]
  557. return res
  558. super().__init__(pack_hook, unpack_hook)
  559. class _CloneArgBeforeMutateMode(TorchDispatchMode):
  560. def __init__(self, ctx: "_AllowMutationOnSavedContext") -> None:
  561. self.ctx = ctx
  562. def __torch_dispatch__(
  563. self,
  564. func: "OpOverload",
  565. types: Iterable[type],
  566. args: tuple[Any, ...] = (),
  567. kwargs: Optional[dict[Any, Any]] = None,
  568. ) -> Any:
  569. kwargs = kwargs or {}
  570. def maybe_clone(t: torch.Tensor) -> None:
  571. tid = _get_tid(t)
  572. sid = _get_sid(t)
  573. ctx = self.ctx
  574. if sid in ctx.sid_to_tid:
  575. for tid in ctx.sid_to_tid[sid]:
  576. if tid not in ctx.tid_to_weakhandle:
  577. # We know that if tid is in sid_to_tid, then it must also be in
  578. # tid_to_weakhandle. However, it is possible for the tensor to be
  579. # saved at one point, but cleared by backward before it is modified
  580. # in-place. Consider the following example:
  581. #
  582. # >>> a = torch.randn(2, 3, requires_grad=True).clone()
  583. # >>> out = (a**2).sum()
  584. # >>> out.backward()
  585. # >>> a.sin_()
  586. continue
  587. handle = ctx.tid_to_weakhandle[tid]
  588. if handle in ctx.cloned:
  589. # The same exact tensor has been cloned already
  590. continue
  591. ctx.cloned[handle] = ctx.original[handle].clone()
  592. del ctx.original[handle]
  593. for idx, arg in enumerate(func._schema.arguments):
  594. if arg.alias_info is not None and arg.alias_info.is_write:
  595. if arg.is_out:
  596. maybe_clone(kwargs["out"])
  597. elif isinstance(args[idx], list):
  598. # Foreach case. (Possible optimization: if most of the
  599. # tensors need to be cloned, use a for each clone?)
  600. for t in args[idx]:
  601. maybe_clone(t)
  602. else:
  603. maybe_clone(args[idx])
  604. return func(*args, **kwargs)
  605. class _AllowMutationOnSavedContext:
  606. def __init__(self) -> None:
  607. self.cloned: MutableMapping[_Handle, torch.Tensor] = WeakKeyDictionary()
  608. self.original: MutableMapping[_Handle, torch.Tensor] = WeakKeyDictionary()
  609. self.tid_to_weakhandle: MutableMapping[_TID, _Handle] = WeakValueDictionary()
  610. self.sid_to_tid: dict[_SID, set[_TID]] = defaultdict(set)
  611. def clear(self) -> None:
  612. self.cloned.clear()
  613. self.original.clear()
  614. self.tid_to_weakhandle.clear()
  615. self.sid_to_tid.clear()
  616. @contextlib.contextmanager
  617. def allow_mutation_on_saved_tensors() -> Generator[
  618. _AllowMutationOnSavedContext, None, None
  619. ]:
  620. """Context manager under which mutating tensors saved for backward is allowed.
  621. Under this context manager, tensors saved for backward are cloned on mutation,
  622. so the original version can still be used during backward. Normally, mutating a tensor
  623. saved for backward will result in an error raised when it's used during backward.
  624. To ensure the correct behavior, both the forward and backward should be run under
  625. the same context manager.
  626. Returns:
  627. An _AllowMutationOnSavedContext object storing the state managed by this
  628. context manager. This object can be useful for debugging purposes. The state
  629. managed by the context manager is automatically cleared upon exiting.
  630. Example::
  631. >>> import torch
  632. >>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
  633. ... # forward
  634. ... a = torch.ones(2, 3, requires_grad=True)
  635. ... b = a.clone()
  636. ... out = (b**2).sum()
  637. ... b.sin_()
  638. ... # backward
  639. ... out.sum().backward()
  640. ...
  641. tensor([[0.8415, 0.8415, 0.8415],
  642. [0.8415, 0.8415, 0.8415]], grad_fn=<SinBackward0>)
  643. """
  644. global _allow_mutation_on_saved_tensors_enabled
  645. ctx = _AllowMutationOnSavedContext()
  646. with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx):
  647. try:
  648. if _allow_mutation_on_saved_tensors_enabled:
  649. raise RuntimeError(
  650. "allow_mutation_on_saved_tensors contexts cannot be nested"
  651. )
  652. _allow_mutation_on_saved_tensors_enabled = True
  653. yield ctx
  654. finally:
  655. ctx.clear()
  656. _allow_mutation_on_saved_tensors_enabled = False
  657. def _register_logging_hooks_on_whole_graph(
  658. t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
  659. ) -> Callable[[], None]:
  660. grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))
  661. def iter_graph(roots: list[Node]) -> Iterator[Node]:
  662. if not roots:
  663. return
  664. seen: set[Node] = set()
  665. q: deque[Node] = deque()
  666. for node in roots:
  667. if node is not None:
  668. seen.add(node)
  669. q.append(node)
  670. while q:
  671. node = q.popleft()
  672. for fn, _ in node.next_functions:
  673. if fn in seen or fn is None:
  674. continue
  675. seen.add(fn)
  676. q.append(fn)
  677. yield node
  678. def fmt(t: Optional[torch.Tensor]) -> str:
  679. # Avoid circular import
  680. from torch.utils._dtype_abbrs import dtype_abbrs
  681. if t is None:
  682. return "None"
  683. return f"{dtype_abbrs[t.dtype]}[{', '.join(map(str, t.shape))}]"
  684. def prehook(grad_outputs: Sequence[Optional[torch.Tensor]]) -> None:
  685. node = torch._C._current_autograd_node()
  686. grad_outputs_str = f"[{','.join(fmt(t) for t in grad_outputs)}]"
  687. log_str = f"Executing: {node} with grad_outputs: {grad_outputs_str}"
  688. log.debug(log_str)
  689. handles = [node.register_prehook(prehook) for node in iter_graph(grad_fns)]
  690. def unregister_hooks() -> None:
  691. for handle in handles:
  692. handle.remove()
  693. return unregister_hooks
  694. def _engine_run_backward(
  695. t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
  696. *args: Any,
  697. **kwargs: Any,
  698. ) -> tuple[torch.Tensor, ...]:
  699. attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
  700. if attach_logging_hooks:
  701. unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
  702. try:
  703. return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
  704. t_outputs, *args, **kwargs
  705. ) # Calls into the C++ engine to run the backward pass
  706. finally:
  707. if attach_logging_hooks:
  708. unregister_hooks() # type: ignore[possibly-undefined]