graphs.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622
  1. # pylint: disable=useless-parent-delegation
  2. from __future__ import annotations
  3. import gc
  4. import typing
  5. from collections.abc import Callable
  6. from typing import overload, TYPE_CHECKING, TypeAlias, Union
  7. from typing_extensions import ParamSpec, Self, TypeVar
  8. import torch
  9. from torch import Tensor
  10. if TYPE_CHECKING:
  11. # importing _POOL_HANDLE at runtime toplevel causes an import cycle
  12. from torch.cuda import _POOL_HANDLE
  13. from .._utils import _dummy_type
  14. __all__ = [
  15. "is_current_stream_capturing",
  16. "graph_pool_handle",
  17. "CUDAGraph",
  18. "graph",
  19. "make_graphed_callables",
  20. ]
  21. _R = TypeVar("_R")
  22. _P = ParamSpec("_P")
  23. if not hasattr(torch._C, "_CudaStreamBase"):
  24. # Define dummy base classes
  25. torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph")
  26. torch._C.__dict__["_graph_pool_handle"] = _dummy_type("_graph_pool_handle")
  27. torch._C.__dict__["_cuda_isCurrentStreamCapturing"] = _dummy_type(
  28. "_cuda_isCurrentStreamCapturing"
  29. )
  30. from torch._C import _cuda_isCurrentStreamCapturing, _CUDAGraph, _graph_pool_handle
  31. def is_current_stream_capturing() -> bool:
  32. r"""Return True if CUDA graph capture is underway on the current CUDA stream, False otherwise.
  33. If a CUDA context does not exist on the current device, returns False without initializing the context.
  34. """
  35. return _cuda_isCurrentStreamCapturing()
  36. # Python shim helps Sphinx process docstrings more reliably.
  37. def graph_pool_handle() -> _POOL_HANDLE:
  38. r"""Return an opaque token representing the id of a graph memory pool.
  39. See :ref:`Graph memory management<graph-memory-management>`.
  40. .. warning::
  41. This API is in beta and may change in future releases.
  42. """
  43. return torch.cuda._POOL_HANDLE(_graph_pool_handle())
  44. # Python shim helps Sphinx process docstrings more reliably.
  45. class CUDAGraph(_CUDAGraph):
  46. r"""Wrapper around a CUDA graph.
  47. Arguments:
  48. keep_graph (bool, optional): If ``keep_graph=False``, the
  49. cudaGraphExec_t will be instantiated on GPU at the end of
  50. ``capture_end`` and the underlying cudaGraph_t will be
  51. destroyed. Users who want to query or otherwise modify the
  52. underlying cudaGraph_t before instantiation can set
  53. ``keep_graph=True`` and access it via ``raw_cuda_graph`` after
  54. ``capture_end``. Note that the cudaGraphExec_t will not be
  55. instantiated at the end of ``capture_end`` in this
  56. case. Instead, it will be instantiated via an explicit called
  57. to ``instantiate`` or automatically on the first call to
  58. ``replay`` if ``instantiate`` was not already called. Calling
  59. ``instantiate`` manually before ``replay`` is recommended to
  60. prevent increased latency on the first call to ``replay``. It
  61. is allowed to modify the raw cudaGraph_t after first calling
  62. ``instantiate``, but the user must call ``instantiate`` again
  63. manually to make sure the instantiated graph has these
  64. changes. Pytorch has no means of tracking these changes.
  65. .. warning::
  66. This API is in beta and may change in future releases.
  67. """
  68. def __new__(cls, keep_graph: bool = False) -> Self:
  69. return super().__new__(cls, keep_graph)
  70. def capture_begin(
  71. self, pool: _POOL_HANDLE | None = None, capture_error_mode: str = "global"
  72. ) -> None:
  73. r"""Begin capturing CUDA work on the current stream.
  74. Typically, you shouldn't call ``capture_begin`` yourself.
  75. Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
  76. which call ``capture_begin`` internally.
  77. Arguments:
  78. pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
  79. :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
  80. with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
  81. capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
  82. Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
  83. may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
  84. actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
  85. unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
  86. """ # noqa: B950
  87. super().capture_begin(pool=pool, capture_error_mode=capture_error_mode)
  88. def capture_end(self) -> None:
  89. r"""End CUDA graph capture on the current stream.
  90. After ``capture_end``, ``replay`` may be called on this instance.
  91. Typically, you shouldn't call ``capture_end`` yourself.
  92. Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
  93. which call ``capture_end`` internally.
  94. """
  95. super().capture_end()
  96. def instantiate(self) -> None:
  97. r"""Instantiate the CUDA graph. Will be called by
  98. ``capture_end`` if ``keep_graph=False``, or by ``replay`` if
  99. ``keep_graph=True`` and ``instantiate`` has not already been
  100. explicitly called. Does not destroy the cudaGraph_t returned
  101. by ``raw_cuda_graph``.
  102. """
  103. super().instantiate()
  104. def replay(self) -> None:
  105. r"""Replay the CUDA work captured by this graph."""
  106. super().replay()
  107. def reset(self) -> None:
  108. r"""Delete the graph currently held by this instance."""
  109. super().reset()
  110. def pool(self) -> _POOL_HANDLE:
  111. r"""Return an opaque token representing the id of this graph's memory pool.
  112. This id can optionally be passed to another graph's ``capture_begin``,
  113. which hints the other graph may share the same memory pool.
  114. """
  115. return super().pool()
  116. def enable_debug_mode(self) -> None:
  117. r"""Enable debugging mode for CUDAGraph.debug_dump."""
  118. return super().enable_debug_mode()
  119. def debug_dump(self, debug_path: str) -> None:
  120. r"""
  121. Arguments:
  122. debug_path (required): Path to dump the graph to.
  123. Calls a debugging function to dump the graph if the debugging is
  124. enabled via CUDAGraph.enable_debug_mode()
  125. """
  126. return super().debug_dump(debug_path)
  127. def raw_cuda_graph(self) -> int:
  128. r"""Returns the underlying cudaGraph_t. ``keep_graph`` must be True.
  129. See the following for APIs for how to manipulate this object: `Graph Managmement <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html>`_ and `cuda-python Graph Management bindings <https://nvidia.github.io/cuda-python/cuda-bindings/latest/module/runtime.html#graph-management>`_
  130. """ # noqa: B950
  131. return super().raw_cuda_graph()
  132. def raw_cuda_graph_exec(self) -> int:
  133. r"""Returns the underlying cudaGraphExec_t. ``instantiate`` must have been called if ``keep_graph`` is True, or ``capture_end`` must have been called if ``keep_graph`` is False. If you call ``instantiate()`` after ``raw_cuda_graph_exec()``, the previously returned cudaGraphExec_t will be destroyed. It is your responsibility not to use this object after destruction.
  134. See the following for APIs for how to manipulate this object: `Graph Execution <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH__EXEC.html>`_ and `cuda-python Graph Execution bindings <https://nvidia.github.io/cuda-python/cuda-bindings/latest/module/runtime.html#graph-execution>`_
  135. """ # noqa: B950
  136. return super().raw_cuda_graph_exec()
  137. class graph:
  138. r"""Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph` object for later replay.
  139. See :ref:`CUDA Graphs <cuda-graph-semantics>` for a general introduction,
  140. detailed use, and constraints.
  141. Arguments:
  142. cuda_graph (torch.cuda.CUDAGraph): Graph object used for capture.
  143. pool (optional): Opaque token (returned by a call to :func:`~torch.cuda.graph_pool_handle()` or
  144. :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) hinting this graph's capture
  145. may share memory from the specified pool. See :ref:`Graph memory management<graph-memory-management>`.
  146. stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context.
  147. If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
  148. capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
  149. Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
  150. may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
  151. actions in the current thread, and "relaxed" will not error on actions. Do NOT change this setting
  152. unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
  153. .. note::
  154. For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
  155. used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture.
  156. .. warning::
  157. This API is in beta and may change in future releases.
  158. .. _cudaStreamCaptureMode:
  159. https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
  160. """ # noqa: B950
  161. default_capture_stream: torch.cuda.Stream | None = None
  162. def __init__(
  163. self,
  164. cuda_graph: CUDAGraph,
  165. pool: _POOL_HANDLE | None = None,
  166. stream: torch.cuda.Stream | None = None,
  167. capture_error_mode: str = "global",
  168. ):
  169. # Lazy-init of default_capture_stream helps avoid circular-import errors.
  170. # Not thread safe, but graphs already have the general (explicitly documented)
  171. # restriction that only one capture may be underway at a time in the process.
  172. if stream is None and self.__class__.default_capture_stream is None:
  173. self.__class__.default_capture_stream = torch.cuda.Stream()
  174. self.pool: tuple[()] | tuple[_POOL_HANDLE] = () if pool is None else (pool,)
  175. self.capture_stream = (
  176. stream if stream is not None else self.__class__.default_capture_stream
  177. )
  178. if self.capture_stream is None:
  179. raise AssertionError("capture_stream must not be None")
  180. self.stream_ctx = torch.cuda.stream(self.capture_stream)
  181. self.cuda_graph = cuda_graph
  182. self.capture_error_mode = capture_error_mode
  183. def __enter__(self) -> None:
  184. # Free as much memory as we can for the graph
  185. torch.cuda.synchronize()
  186. if torch.compiler.config.force_cudagraph_gc:
  187. # Originally we unconditionally garbage collected here. On one hand
  188. # that's nice because we have a chance to collect more memory, but
  189. # on the other hand it is REALLY expensive, especially for doing
  190. # multiple cudagraph captures in a row. In theory it will only help
  191. # when a dead python cycle is holding onto CUDA memory.
  192. gc.collect()
  193. torch.cuda.empty_cache()
  194. # pyrefly: ignore [missing-attribute]
  195. torch._C._host_emptyCache()
  196. # Stackoverflow seems comfortable with this pattern
  197. # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487
  198. self.stream_ctx.__enter__()
  199. self.cuda_graph.capture_begin(
  200. # type: ignore[misc]
  201. *self.pool,
  202. # pyrefly: ignore [bad-keyword-argument]
  203. capture_error_mode=self.capture_error_mode,
  204. )
  205. def __exit__(self, *args: object) -> None:
  206. self.cuda_graph.capture_end()
  207. self.stream_ctx.__exit__(*args)
  208. # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()
# Anything make_graphed_callables accepts as a "callable": an nn.Module or a
# plain Python callable. ``torch.nn.Module`` is spelled as a string forward
# reference. NOTE(review): presumably this avoids touching ``torch.nn`` while
# this module is imported; confirm.
_ModuleOrCallable: TypeAlias = Union["torch.nn.Module", Callable[..., object]]


# Typing overload: one callable plus one flat tuple of sample Tensors yields a
# single graphed callable.
@overload
def make_graphed_callables(
    callables: _ModuleOrCallable,
    sample_args: tuple[Tensor, ...],
    num_warmup_iters: int = 3,
    allow_unused_input: bool = False,
    pool: _POOL_HANDLE | None = None,
) -> _ModuleOrCallable: ...


# Typing overload: a tuple of callables plus one sample-args tuple per
# callable yields a tuple of graphed callables, in the same order.
@overload
def make_graphed_callables(
    callables: tuple[_ModuleOrCallable, ...],
    sample_args: tuple[tuple[Tensor, ...], ...],
    num_warmup_iters: int = 3,
    allow_unused_input: bool = False,
    pool: _POOL_HANDLE | None = None,
) -> tuple[_ModuleOrCallable, ...]: ...
  226. def make_graphed_callables(
  227. callables: _ModuleOrCallable | tuple[_ModuleOrCallable, ...],
  228. sample_args: tuple[Tensor, ...] | tuple[tuple[Tensor, ...], ...],
  229. num_warmup_iters: int = 3,
  230. allow_unused_input: bool = False,
  231. pool: _POOL_HANDLE | None = None,
  232. ) -> _ModuleOrCallable | tuple[_ModuleOrCallable, ...]:
  233. r"""Accept callables (functions or :class:`nn.Module<torch.nn.Module>`\ s) and returns graphed versions.
  234. Each graphed callable's forward pass runs its source callable's
  235. forward CUDA work as a CUDA graph inside a single autograd node.
  236. The graphed callable's forward pass also appends
  237. a backward node to the autograd graph. During backward, this node runs the
  238. callable's backward work as a CUDA graph.
  239. Therefore, each graphed callable should be a drop-in replacement for its source callable
  240. in an autograd-enabled training loop.
  241. See :ref:`Partial-network capture<partial-network-capture>` for detailed use and constraints.
  242. If you pass a tuple of several callables, their captures will use the same memory pool.
  243. See :ref:`Graph memory management<graph-memory-management>` for when this is appropriate.
  244. Arguments:
  245. callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph.
  246. See :ref:`Graph memory management<graph-memory-management>` for when passing a tuple of callables
  247. is appropriate. If you pass a tuple of callables, their order in the tuple must be the same order
  248. they'll run in the live workload.
  249. sample_args (tuple of Tensors, or tuple of tuples of Tensors): Samples args for each callable.
  250. If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors.
  251. If a tuple of callables was passed, ``sample_args`` must be tuple of tuples of argument Tensors.
  252. num_warmup_iters (int): The number of warmup iterations. Currently, ``DataDistributedParallel`` needs
  253. 11 iterations for warm up. Default: ``3``.
  254. allow_unused_input (bool): If False, specifying inputs that were not used when computing outputs
  255. (and therefore their grad is always zero) is an error. Defaults to False.
  256. pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
  257. :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
  258. with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
  259. .. note::
  260. The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
  261. that's expected for the corresponding real input in the training loop.
  262. .. warning::
  263. This API is in beta and may change in future releases.
  264. .. warning::
  265. ``sample_args`` for each callable must contain only Tensors. Other types are not allowed.
  266. .. warning::
  267. Returned callables do not support higher order differentiation (e.g., double backward).
  268. .. warning::
  269. In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters
  270. may be trainable. Buffers must have ``requires_grad=False``.
  271. .. warning::
  272. After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`,
  273. you may not add or remove any of that Module's parameters or buffers.
  274. .. warning::
  275. :class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks
  276. registered on them at the time they are passed. However, registering hooks on modules *after* passing them
  277. through :func:`~torch.cuda.make_graphed_callables` is allowed.
  278. .. warning::
  279. When running a graphed callable, you must pass its arguments in the same order and format
  280. they appeared in that callable's ``sample_args``.
  281. .. warning::
  282. The automatic mixed precision is supported in :func:`~torch.cuda.make_graphed_callables` only with disabled
  283. caching. The context manager `torch.cuda.amp.autocast()` must have `cache_enabled=False`.
  284. """
  285. if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
  286. raise RuntimeError(
  287. "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
  288. )
  289. just_one_callable = False
  290. _sample_args: tuple[tuple[Tensor, ...], ...]
  291. if not isinstance(callables, tuple):
  292. just_one_callable = True
  293. callables = (callables,)
  294. _sample_args = (typing.cast(tuple[Tensor, ...], sample_args),)
  295. else:
  296. _sample_args = typing.cast(tuple[tuple[Tensor, ...], ...], sample_args)
  297. flatten_sample_args = []
  298. for c, args in zip(callables, _sample_args):
  299. if isinstance(c, torch.nn.Module):
  300. if not (
  301. len(c._backward_hooks) == 0
  302. and len(c._forward_hooks) == 0
  303. and len(c._forward_pre_hooks) == 0
  304. ):
  305. raise AssertionError(
  306. "Modules must not have hooks registered at the time they are passed. However, registering hooks "
  307. + "on modules after passing them through make_graphed_callables is allowed."
  308. )
  309. if not all(b.requires_grad is False for b in c.buffers()):
  310. raise AssertionError(
  311. "In any :class:`~torch.nn.Module` passed to "
  312. + ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have "
  313. + "``requires_grad=False``."
  314. )
  315. flatten_arg = torch.utils._pytree.arg_tree_leaves(*args)
  316. flatten_sample_args.append(tuple(flatten_arg))
  317. if not all(isinstance(arg, torch.Tensor) for arg in flatten_arg):
  318. raise AssertionError(
  319. "In the beta API, sample_args "
  320. + "for each callable must contain only Tensors. Other types are not allowed."
  321. )
  322. # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
  323. # passes to forward (ie, its sample_args) AND the module's parameter attributes.
  324. per_callable_len_user_args = [len(args) for args in flatten_sample_args]
  325. per_callable_module_params = [
  326. tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
  327. for c in callables
  328. ]
  329. per_callable_static_input_surfaces = [
  330. flatten_sample_args[i] + per_callable_module_params[i]
  331. for i in range(len(callables))
  332. ]
  333. fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
  334. bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
  335. mempool = graph_pool_handle() if pool is None else pool
  336. # Warmup
  337. # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
  338. # from ending up in any captures.
  339. torch.cuda.synchronize()
  340. with torch.cuda.stream(torch.cuda.Stream()):
  341. for func, args, static_input_surface in zip(
  342. callables, _sample_args, per_callable_static_input_surfaces
  343. ):
  344. grad_inputs, outputs, outputs_grad = None, None, None
  345. for _ in range(num_warmup_iters):
  346. outputs = torch.utils._pytree.tree_leaves(func(*args))
  347. outputs_grad = tuple(o for o in outputs if o.requires_grad)
  348. if len(outputs_grad) > 0:
  349. grad_inputs = torch.autograd.grad(
  350. outputs=outputs_grad,
  351. inputs=tuple(
  352. i for i in static_input_surface if i.requires_grad
  353. ),
  354. grad_outputs=tuple(
  355. torch.empty_like(o) for o in outputs if o.requires_grad
  356. ),
  357. only_inputs=True,
  358. allow_unused=allow_unused_input,
  359. )
  360. for v in [outputs, outputs_grad, grad_inputs]:
  361. del v
  362. torch.cuda.synchronize()
  363. # All captures here share a mempool. To avoid replays corrupting each other's memory,
  364. # the safest approach is to capture all passes in the same order they'll run:
  365. # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.
  366. # Capture forward graphs
  367. per_callable_static_outputs = []
  368. per_callable_output_unflatten_spec = []
  369. for func, args, fwd_graph in zip(callables, _sample_args, fwd_graphs):
  370. with torch.cuda.graph(fwd_graph, pool=mempool):
  371. func_outputs = func(*args)
  372. flatten_outputs, spec = torch.utils._pytree.tree_flatten(func_outputs)
  373. per_callable_static_outputs.append(tuple(flatten_outputs))
  374. per_callable_output_unflatten_spec.append(spec)
  375. # Capture backward graphs in reverse order
  376. per_callable_static_grad_outputs = []
  377. per_callable_static_grad_inputs = []
  378. for static_input_surface, static_outputs, bwd_graph in zip(
  379. reversed(per_callable_static_input_surfaces),
  380. reversed(per_callable_static_outputs),
  381. reversed(bwd_graphs),
  382. ):
  383. # For now, assumes all static_outputs require grad
  384. # assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad."
  385. static_grad_outputs = tuple(
  386. torch.empty_like(o) if o.requires_grad else None for o in static_outputs
  387. )
  388. outputs_grad = tuple(o for o in static_outputs if o.requires_grad)
  389. grad_inputs = None
  390. if len(outputs_grad) > 0:
  391. with torch.cuda.graph(bwd_graph, pool=mempool):
  392. grad_inputs = torch.autograd.grad(
  393. outputs=outputs_grad,
  394. inputs=tuple(i for i in static_input_surface if i.requires_grad),
  395. grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
  396. only_inputs=True,
  397. allow_unused=allow_unused_input,
  398. )
  399. # Constructs a tuple suitable for returning from Graphed.backward:
  400. # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad.
  401. # I couldn't think of a slick one-liner for this pattern.
  402. static_grad_inputs = []
  403. grad_idx = 0
  404. for arg in static_input_surface:
  405. if arg.requires_grad and grad_inputs is not None:
  406. static_grad_inputs.append(grad_inputs[grad_idx])
  407. grad_idx += 1
  408. else:
  409. static_grad_inputs.append(None) # type: ignore[arg-type]
  410. static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment]
  411. per_callable_static_grad_outputs.append(static_grad_outputs)
  412. per_callable_static_grad_inputs.append(static_grad_inputs)
  413. # Reverses the most recent two lists
  414. per_callable_static_grad_outputs.reverse()
  415. per_callable_static_grad_inputs.reverse()
  416. # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
  417. def make_graphed_autograd_function(
  418. fwd_graph: CUDAGraph,
  419. bwd_graph: CUDAGraph,
  420. module_params: tuple[torch.nn.Parameter, ...],
  421. len_user_args: int,
  422. output_unflatten_spec: torch.utils._pytree.TreeSpec,
  423. static_input_surface: tuple[Tensor, ...],
  424. static_outputs: tuple[Tensor, ...],
  425. static_grad_outputs: tuple[Tensor | None, ...],
  426. static_grad_inputs: tuple[Tensor, ...],
  427. ) -> Callable[..., object]:
  428. class Graphed(torch.autograd.Function):
  429. @staticmethod
  430. # pyrefly: ignore [bad-override]
  431. def forward(ctx: object, *inputs: Tensor) -> tuple[Tensor, ...]:
  432. # At this stage, only the user args may (potentially) be new tensors.
  433. for i in range(len_user_args):
  434. if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
  435. static_input_surface[i].copy_(inputs[i])
  436. fwd_graph.replay()
  437. if not isinstance(static_outputs, tuple):
  438. raise AssertionError(
  439. f"static_outputs must be tuple, got {type(static_outputs)}"
  440. )
  441. return tuple(o.detach() for o in static_outputs)
  442. @staticmethod
  443. @torch.autograd.function.once_differentiable
  444. # pyrefly: ignore [bad-override]
  445. def backward(ctx: object, *grads: Tensor) -> tuple[Tensor, ...]:
  446. if len(grads) != len(static_grad_outputs):
  447. raise AssertionError(
  448. f"len(grads)={len(grads)} != len(static_grad_outputs)={len(static_grad_outputs)}"
  449. )
  450. for g, grad in zip(static_grad_outputs, grads):
  451. if g is not None:
  452. # don't copy if autograd gods have been kind and the
  453. # incoming grad is already in the right place
  454. if g.data_ptr() != grad.data_ptr():
  455. g.copy_(grad)
  456. bwd_graph.replay()
  457. # Input args that didn't require grad expect a None gradient.
  458. if not isinstance(static_grad_inputs, tuple):
  459. raise AssertionError(
  460. f"static_grad_inputs must be tuple, got {type(static_grad_inputs)}"
  461. )
  462. return tuple(
  463. # pyrefly: ignore [bad-argument-type]
  464. b.detach() if b is not None else b
  465. for b in static_grad_inputs
  466. )
  467. def functionalized(*user_args: object) -> object:
  468. # Runs the autograd function with inputs == all inputs to the graph that might require grad
  469. # (explicit user args + module parameters)
  470. # Assumes module params didn't change since capture.
  471. flatten_user_args = torch.utils._pytree.arg_tree_leaves(*user_args)
  472. out = Graphed.apply(*(tuple(flatten_user_args) + module_params))
  473. return torch.utils._pytree.tree_unflatten(out, output_unflatten_spec)
  474. return functionalized
  475. # Put together the final graphed callables
  476. ret: list[_ModuleOrCallable] = []
  477. for i, func in enumerate(callables):
  478. graphed = make_graphed_autograd_function(
  479. fwd_graphs[i],
  480. bwd_graphs[i],
  481. per_callable_module_params[i],
  482. per_callable_len_user_args[i],
  483. per_callable_output_unflatten_spec[i],
  484. per_callable_static_input_surfaces[i],
  485. per_callable_static_outputs[i],
  486. per_callable_static_grad_outputs[i],
  487. per_callable_static_grad_inputs[i],
  488. )
  489. if isinstance(func, torch.nn.Module):
  490. def make_graphed_forward(
  491. func: torch.nn.Module,
  492. graph_training_state: bool,
  493. graphed: Callable[_P, _R],
  494. orig_fwd: Callable[_P, _R],
  495. ) -> Callable[_P, _R]:
  496. def new_fwd(*user_args: _P.args, **user_kwargs: _P.kwargs) -> _R:
  497. # If the module's training-or-eval state matches what we graphed,
  498. # run the graph, otherwise run the original forward method
  499. if func.training == graph_training_state:
  500. return graphed(*user_args, **user_kwargs)
  501. else:
  502. return orig_fwd(*user_args, **user_kwargs)
  503. return new_fwd
  504. func.forward = make_graphed_forward(
  505. func, func.training, graphed, func.forward
  506. )
  507. ret.append(func)
  508. else:
  509. ret.append(graphed)
  510. if just_one_callable:
  511. return ret[0]
  512. return tuple(ret)