graphs.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534
  1. from __future__ import annotations
  2. import typing
  3. from collections.abc import Callable
  4. from typing import Optional, overload, TYPE_CHECKING, TypeAlias, Union
  5. from typing_extensions import ParamSpec, Self, TypeVar
  6. import torch
  7. from torch import Tensor
  8. if TYPE_CHECKING:
  9. from torch.xpu import _POOL_HANDLE
  10. from .._utils import _dummy_type
  11. __all__ = [
  12. "is_current_stream_capturing",
  13. "graph_pool_handle",
  14. "XPUGraph",
  15. "graph",
  16. "make_graphed_callables",
  17. ]
  18. _R = TypeVar("_R")
  19. _P = ParamSpec("_P")
  20. if not hasattr(torch._C, "_XpuStreamBase"):
  21. # Define dummy base classes
  22. torch._C.__dict__["_XPUGraph"] = _dummy_type("_XPUGraph")
  23. torch._C.__dict__["_xpu_graph_pool_handle"] = _dummy_type("_xpu_graph_pool_handle")
  24. torch._C.__dict__["_xpu_isCurrentStreamCapturing"] = _dummy_type(
  25. "_xpu_isCurrentStreamCapturing"
  26. )
  27. # pyrefly: ignore [missing-module-attribute]
  28. from torch._C import _xpu_graph_pool_handle, _xpu_isCurrentStreamCapturing, _XPUGraph
  29. def is_current_stream_capturing() -> bool:
  30. r"""Return True if XPU graph capture is underway on the current XPU stream, False otherwise.
  31. If a XPU context does not exist on the current device, returns False without initializing the context.
  32. """
  33. return _xpu_isCurrentStreamCapturing()
  34. def graph_pool_handle() -> _POOL_HANDLE:
  35. r"""Return an opaque token representing the id of a graph memory pool."""
  36. return torch.xpu._POOL_HANDLE(_xpu_graph_pool_handle())
  37. class XPUGraph(_XPUGraph):
  38. r"""Wrapper around a XPU graph.
  39. Arguments:
  40. keep_graph (bool, optional): If ``keep_graph=False``, the
  41. executable command graph will be instantiated on GPU at the end of
  42. ``capture_end`` and the underlying modifiable command graph will be
  43. destroyed. Note that the executable command graph will not be
  44. instantiated at the end of ``capture_end`` in this
  45. case. Instead, it will be instantiated via an explicit called
  46. to ``instantiate`` or automatically on the first call to
  47. ``replay`` if ``instantiate`` was not already called. Calling
  48. ``instantiate`` manually before ``replay`` is recommended to
  49. prevent increased latency on the first call to ``replay``.
  50. """
  51. def __new__(cls, keep_graph: bool = False) -> Self:
  52. return super().__new__(cls, keep_graph)
  53. def capture_begin(self, pool: Optional[_POOL_HANDLE] = None) -> None:
  54. r"""Begin capturing XPU work on the current xpu stream.
  55. Typically, you shouldn't call ``capture_begin`` yourself.
  56. Use :class:`~torch.xpu.graph`, which call ``capture_begin`` internally.
  57. Arguments:
  58. pool (optional): Token (returned by :func:`~torch.xpu.graph_pool_handle` or
  59. :meth:`other_Graph_instance.pool()<torch.xpu.XPUGraph.pool>`) that hints this graph may share memory
  60. with the indicated pool.
  61. """
  62. super().capture_begin(pool=pool)
  63. def capture_end(self) -> None:
  64. r"""End XPU graph capture on the current stream.
  65. After ``capture_end``, ``replay`` may be called on this instance.
  66. Typically, you shouldn't call ``capture_end`` yourself.
  67. Use :class:`~torch.xpu.graph`, which call ``capture_end`` internally.
  68. """
  69. super().capture_end()
  70. def instantiate(self) -> None:
  71. r"""Instantiate the XPU graph. Will be called by
  72. ``capture_end`` if ``keep_graph=False``, or by ``replay`` if
  73. ``keep_graph=True`` and ``instantiate`` has not already been
  74. explicitly called. Does not destroy the xpu modify command graph returned
  75. by ``raw_xpu_graph``.
  76. """
  77. super().instantiate()
  78. def replay(self) -> None:
  79. r"""Replay the XPU work captured by this graph."""
  80. super().replay()
  81. def reset(self) -> None:
  82. r"""Delete the graph currently held by this instance."""
  83. super().reset()
  84. def pool(self) -> _POOL_HANDLE:
  85. r"""Return an opaque token representing the id of this graph's memory pool.
  86. This id can optionally be passed to another graph's ``capture_begin``,
  87. which hints the other graph may share the same memory pool.
  88. """
  89. return super().pool()
  90. def enable_debug_mode(self) -> None:
  91. r"""Enable debugging mode for XPUGraph.debug_dump."""
  92. return super().enable_debug_mode()
  93. def debug_dump(self, debug_path: str) -> None:
  94. r"""
  95. Arguments:
  96. debug_path (required): Path to dump the graph to.
  97. Calls a debugging function to dump the graph if the debugging is
  98. enabled via XPUGraph.enable_debug_mode()
  99. """
  100. return super().debug_dump(debug_path)
  101. def raw_xpu_graph(self) -> int:
  102. r"""Returns the underlying xpuGraph_t. ``keep_graph`` must be True.
  103. XPU doesn't provide APIs to manipulate this object.
  104. """ # noqa: B950
  105. return super().raw_xpu_graph()
  106. def raw_xpu_graph_exec(self) -> int:
  107. r"""Returns the underlying xpuGraphExec_t. ``instantiate`` must have been called if ``keep_graph`` is True, or ``capture_end`` must have been called if ``keep_graph`` is False. If you call ``instantiate()`` after ``raw_xpu_graph_exec()``, the previously returned xpuGraphExec_t will be destroyed. It is your responsibility not to use this object after destruction.
  108. XPU doesn't provide APIs to manipulate this object.
  109. """ # noqa: B950
  110. return super().raw_xpu_graph_exec()
  111. class graph:
  112. r"""Context-manager that captures XPU work into a :class:`torch.xpu.XPUGraph` object for later replay.
  113. Arguments:
  114. xpu_graph (torch.xpu.XPUGraph): Graph object used for capture.
  115. pool (optional): Opaque token (returned by a call to :func:`~torch.xpu.graph_pool_handle()` or
  116. :meth:`other_Graph_instance.pool()<torch.xpu.XPUGraph.pool>`) hinting this graph's capture
  117. may share memory from the specified pool.
  118. stream (torch.xpu.Stream, optional): If supplied, will be set as the current stream in the context.
  119. If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
  120. .. note::
  121. For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
  122. used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture.
  123. """ # noqa: B950
  124. default_capture_stream: Optional[torch.xpu.Stream] = None
  125. def __init__(
  126. self,
  127. xpu_graph: XPUGraph,
  128. pool: Optional[_POOL_HANDLE] = None,
  129. stream: Optional[torch.xpu.Stream] = None,
  130. ):
  131. # Lazy-init of default_capture_stream helps avoid circular-import errors.
  132. # Not thread safe, but graphs already have the general (explicitly documented)
  133. # restriction that only one capture may be underway at a time in the process.
  134. if self.__class__.default_capture_stream is None:
  135. self.__class__.default_capture_stream = torch.xpu.Stream()
  136. self.pool: Union[tuple[()], tuple[_POOL_HANDLE]] = (
  137. () if pool is None else (pool,)
  138. )
  139. self.capture_stream = (
  140. stream if stream is not None else self.__class__.default_capture_stream
  141. )
  142. if self.capture_stream is None:
  143. raise AssertionError("capture_stream must not be None")
  144. self.stream_ctx = self.capture_stream
  145. self.xpu_graph = xpu_graph
  146. def __enter__(self) -> None:
  147. # Free as much memory as we can for the graph
  148. torch.xpu.synchronize()
  149. torch.xpu.empty_cache()
  150. self.stream_ctx.__enter__()
  151. self.xpu_graph.capture_begin(*self.pool)
  152. def __exit__(self, *args: object) -> None:
  153. self.xpu_graph.capture_end()
  154. self.stream_ctx.__exit__(*args)
  155. _ModuleOrCallable: TypeAlias = Union["torch.nn.Module", Callable[..., object]]
  156. @overload
  157. def make_graphed_callables(
  158. callables: _ModuleOrCallable,
  159. sample_args: tuple[Tensor, ...],
  160. num_warmup_iters: int = 3,
  161. allow_unused_input: bool = False,
  162. pool: Optional[_POOL_HANDLE] = None,
  163. ) -> _ModuleOrCallable: ...
  164. @overload
  165. def make_graphed_callables(
  166. callables: tuple[_ModuleOrCallable, ...],
  167. sample_args: tuple[tuple[Tensor, ...], ...],
  168. num_warmup_iters: int = 3,
  169. allow_unused_input: bool = False,
  170. pool: Optional[_POOL_HANDLE] = None,
  171. ) -> tuple[_ModuleOrCallable, ...]: ...
  172. def make_graphed_callables(
  173. callables: Union[_ModuleOrCallable, tuple[_ModuleOrCallable, ...]],
  174. sample_args: Union[tuple[Tensor, ...], tuple[tuple[Tensor, ...], ...]],
  175. num_warmup_iters: int = 3,
  176. allow_unused_input: bool = False,
  177. pool: Optional[_POOL_HANDLE] = None,
  178. ) -> Union[_ModuleOrCallable, tuple[_ModuleOrCallable, ...]]:
  179. r"""Accept callables (functions or :class:`nn.Module<torch.nn.Module>`\ s) and returns graphed versions.
  180. Each graphed callable's forward pass runs its source callable's
  181. forward XPU work as a XPU graph inside a single autograd node.
  182. The graphed callable's forward pass also appends
  183. a backward node to the autograd graph. During backward, this node runs the
  184. callable's backward work as a XPU graph.
  185. Therefore, each graphed callable should be a drop-in replacement for its source callable
  186. in an autograd-enabled training loop.
  187. See :ref:`Partial-network capture<partial-network-capture>` for detailed use and constraints.
  188. If you pass a tuple of several callables, their captures will use the same memory pool.
  189. Arguments:
  190. callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph.
  191. If you pass a tuple of callables, their order in the tuple must be the same order they'll run
  192. in the live workload.
  193. sample_args (tuple of Tensors, or tuple of tuples of Tensors): Samples args for each callable.
  194. If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors.
  195. If a tuple of callables was passed, ``sample_args`` must be tuple of tuples of argument Tensors.
  196. num_warmup_iters (int): The number of warmup iterations. Currently, ``DataDistributedParallel`` needs
  197. 11 iterations for warm up. Default: ``3``.
  198. allow_unused_input (bool): If False, specifying inputs that were not used when computing outputs
  199. (and therefore their grad is always zero) is an error. Defaults to False.
  200. pool (optional): Token (returned by :func:`~torch.xpu.graph_pool_handle` or
  201. :meth:`other_Graph_instance.pool()<torch.xpu.XPUGraph.pool>`) that hints this graph may share memory
  202. with the indicated pool.
  203. .. note::
  204. The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
  205. that's expected for the corresponding real input in the training loop.
  206. .. warning::
  207. This API is in beta and may change in future releases.
  208. .. warning::
  209. ``sample_args`` for each callable must contain only Tensors. Other types are not allowed.
  210. .. warning::
  211. Returned callables do not support higher order differentiation (e.g., double backward).
  212. .. warning::
  213. In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters
  214. may be trainable. Buffers must have ``requires_grad=False``.
  215. .. warning::
  216. After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`,
  217. you may not add or remove any of that Module's parameters or buffers.
  218. .. warning::
  219. :class:`torch.nn.Module`\s passed to :func:`~torch.xpu.make_graphed_callables` must not have module hooks
  220. registered on them at the time they are passed. However, registering hooks on modules *after* passing them
  221. through :func:`~torch.xpu.make_graphed_callables` is allowed.
  222. .. warning::
  223. When running a graphed callable, you must pass its arguments in the same order and format
  224. they appeared in that callable's ``sample_args``.
  225. .. warning::
  226. The automatic mixed precision is supported in :func:`~torch.xpu.make_graphed_callables` only with disabled
  227. caching. The context manager `torch.amp.autocast()` must have `cache_enabled=False`.
  228. """
  229. if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
  230. raise RuntimeError(
  231. "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
  232. )
  233. just_one_callable = False
  234. _sample_args: tuple[tuple[Tensor, ...], ...]
  235. if not isinstance(callables, tuple):
  236. just_one_callable = True
  237. callables = (callables,)
  238. _sample_args = (typing.cast(tuple[Tensor, ...], sample_args),)
  239. else:
  240. _sample_args = typing.cast(tuple[tuple[Tensor, ...], ...], sample_args)
  241. flatten_sample_args = []
  242. for c, args in zip(callables, _sample_args):
  243. if isinstance(c, torch.nn.Module):
  244. if not (
  245. len(c._backward_hooks) == 0
  246. and len(c._forward_hooks) == 0
  247. and len(c._forward_pre_hooks) == 0
  248. ):
  249. raise RuntimeError(
  250. "Modules must not have hooks registered at the time they are passed. However, registering hooks "
  251. + "on modules after passing them through make_graphed_callables is allowed."
  252. )
  253. if not all(b.requires_grad is False for b in c.buffers()):
  254. raise RuntimeError(
  255. "In any :class:`~torch.nn.Module` passed to "
  256. + ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have "
  257. + "``requires_grad=False``."
  258. )
  259. flatten_arg = torch.utils._pytree.arg_tree_leaves(*args)
  260. flatten_sample_args.append(tuple(flatten_arg))
  261. if not all(isinstance(arg, torch.Tensor) for arg in flatten_arg):
  262. raise TypeError(
  263. "In the beta API, sample_args "
  264. + "for each callable must contain only Tensors. Other types are not allowed."
  265. )
  266. # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
  267. # passes to forward (ie, its sample_args) AND the module's parameter attributes.
  268. per_callable_len_user_args = [len(args) for args in flatten_sample_args]
  269. per_callable_module_params = [
  270. tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
  271. for c in callables
  272. ]
  273. per_callable_static_input_surfaces = [
  274. flatten_sample_args[i] + per_callable_module_params[i]
  275. for i in range(len(callables))
  276. ]
  277. fwd_graphs = [torch.xpu.XPUGraph() for _ in range(len(callables))]
  278. bwd_graphs = [torch.xpu.XPUGraph() for _ in range(len(callables))]
  279. mempool = graph_pool_handle() if pool is None else pool
  280. # Warmup
  281. torch.xpu.synchronize()
  282. with torch.xpu.stream(torch.xpu.Stream()):
  283. for func, args, static_input_surface in zip(
  284. callables, _sample_args, per_callable_static_input_surfaces
  285. ):
  286. grad_inputs, outputs, outputs_grad = None, None, None
  287. for _ in range(num_warmup_iters):
  288. outputs = torch.utils._pytree.tree_leaves(func(*args))
  289. outputs_grad = tuple(o for o in outputs if o.requires_grad)
  290. if len(outputs_grad) > 0:
  291. grad_inputs = torch.autograd.grad(
  292. outputs=outputs_grad,
  293. inputs=tuple(
  294. i for i in static_input_surface if i.requires_grad
  295. ),
  296. grad_outputs=tuple(
  297. torch.empty_like(o) for o in outputs if o.requires_grad
  298. ),
  299. only_inputs=True,
  300. allow_unused=allow_unused_input,
  301. )
  302. for v in [outputs, outputs_grad, grad_inputs]:
  303. del v
  304. torch.xpu.synchronize()
  305. # Capture forward graphs
  306. per_callable_static_outputs = []
  307. per_callable_output_unflatten_spec = []
  308. for func, args, fwd_graph in zip(callables, _sample_args, fwd_graphs):
  309. # each graph uses the same mempool
  310. with torch.xpu.graph(fwd_graph, pool=mempool):
  311. func_outputs = func(*args)
  312. flatten_outputs, spec = torch.utils._pytree.tree_flatten(func_outputs)
  313. per_callable_static_outputs.append(tuple(flatten_outputs))
  314. per_callable_output_unflatten_spec.append(spec)
  315. # Capture backward graphs in reverse order
  316. per_callable_static_grad_outputs = []
  317. per_callable_static_grad_inputs = []
  318. for static_input_surface, static_outputs, bwd_graph in zip(
  319. reversed(per_callable_static_input_surfaces),
  320. reversed(per_callable_static_outputs),
  321. reversed(bwd_graphs),
  322. ):
  323. static_grad_outputs = tuple(
  324. torch.empty_like(o) if o.requires_grad else None for o in static_outputs
  325. )
  326. outputs_grad = tuple(o for o in static_outputs if o.requires_grad)
  327. grad_inputs = None
  328. if len(outputs_grad) > 0:
  329. with torch.xpu.graph(bwd_graph, pool=mempool):
  330. grad_inputs = torch.autograd.grad(
  331. outputs=outputs_grad,
  332. inputs=tuple(i for i in static_input_surface if i.requires_grad),
  333. grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
  334. only_inputs=True,
  335. allow_unused=allow_unused_input,
  336. )
  337. static_grad_inputs = []
  338. grad_idx = 0
  339. for arg in static_input_surface:
  340. if arg.requires_grad and grad_inputs is not None:
  341. static_grad_inputs.append(grad_inputs[grad_idx])
  342. grad_idx += 1
  343. else:
  344. static_grad_inputs.append(None) # type: ignore[arg-type]
  345. static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment]
  346. per_callable_static_grad_outputs.append(static_grad_outputs)
  347. per_callable_static_grad_inputs.append(static_grad_inputs)
  348. # Reverses the most recent two lists
  349. per_callable_static_grad_outputs.reverse()
  350. per_callable_static_grad_inputs.reverse()
  351. def make_graphed_autograd_function(
  352. fwd_graph: XPUGraph,
  353. bwd_graph: XPUGraph,
  354. module_params: tuple[torch.nn.Parameter, ...],
  355. len_user_args: int,
  356. output_unflatten_spec: torch.utils._pytree.TreeSpec,
  357. static_input_surface: tuple[Tensor, ...],
  358. static_outputs: tuple[Tensor, ...],
  359. static_grad_outputs: tuple[Optional[Tensor], ...],
  360. static_grad_inputs: tuple[Tensor, ...],
  361. ) -> Callable[..., object]:
  362. class Graphed(torch.autograd.Function):
  363. @staticmethod
  364. # pyrefly: ignore [bad-override]
  365. def forward(ctx: object, *inputs: Tensor) -> tuple[Tensor, ...]:
  366. # At this stage, only the user args may (potentially) be new tensors.
  367. for i in range(len_user_args):
  368. if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
  369. static_input_surface[i].copy_(inputs[i])
  370. fwd_graph.replay()
  371. if not isinstance(static_outputs, tuple):
  372. raise RuntimeError("static_outputs must be a tuple")
  373. return tuple(o.detach() for o in static_outputs)
  374. @staticmethod
  375. @torch.autograd.function.once_differentiable
  376. # pyrefly: ignore [bad-override]
  377. def backward(ctx: object, *grads: Tensor) -> tuple[Tensor, ...]:
  378. if len(grads) != len(static_grad_outputs):
  379. raise RuntimeError(
  380. f"Expected {len(static_grad_outputs)} gradients but got {len(grads)}"
  381. )
  382. for g, grad in zip(static_grad_outputs, grads):
  383. if g is not None:
  384. if g.data_ptr() != grad.data_ptr():
  385. g.copy_(grad)
  386. bwd_graph.replay()
  387. if not isinstance(static_grad_inputs, tuple):
  388. raise RuntimeError("static_grad_inputs must be a tuple")
  389. return tuple(
  390. # pyrefly: ignore [bad-argument-type]
  391. b.detach() if b is not None else b
  392. for b in static_grad_inputs
  393. )
  394. def functionalized(*user_args: object) -> object:
  395. # Runs the new autograd function which replays the XPU graphs
  396. flatten_user_args = torch.utils._pytree.arg_tree_leaves(*user_args)
  397. out = Graphed.apply(*(tuple(flatten_user_args) + module_params))
  398. return torch.utils._pytree.tree_unflatten(out, output_unflatten_spec)
  399. return functionalized
  400. ret: list[_ModuleOrCallable] = []
  401. for i, func in enumerate(callables):
  402. graphed = make_graphed_autograd_function(
  403. fwd_graphs[i],
  404. bwd_graphs[i],
  405. per_callable_module_params[i],
  406. per_callable_len_user_args[i],
  407. per_callable_output_unflatten_spec[i],
  408. per_callable_static_input_surfaces[i],
  409. per_callable_static_outputs[i],
  410. per_callable_static_grad_outputs[i],
  411. per_callable_static_grad_inputs[i],
  412. )
  413. if isinstance(func, torch.nn.Module):
  414. def make_graphed_forward(
  415. func: torch.nn.Module,
  416. graph_training_state: bool,
  417. graphed: Callable[_P, _R],
  418. orig_fwd: Callable[_P, _R],
  419. ) -> Callable[_P, _R]:
  420. def new_fwd(*user_args: _P.args, **user_kwargs: _P.kwargs) -> _R:
  421. if func.training == graph_training_state:
  422. return graphed(*user_args, **user_kwargs)
  423. else:
  424. return orig_fwd(*user_args, **user_kwargs)
  425. return new_fwd
  426. func.forward = make_graphed_forward(
  427. func, func.training, graphed, func.forward
  428. )
  429. ret.append(func)
  430. else:
  431. ret.append(graphed)
  432. if just_one_callable:
  433. return ret[0]
  434. return tuple(ret)