cudagraphs.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. """
  2. This module implements CUDA graphs support for TorchDynamo backends.
  3. CUDA graphs allow for capturing and replaying GPU operations, which can significantly
  4. reduce CPU overhead in GPU-accelerated PyTorch models. This module provides:
  5. - CUDA graph creation and management for both forward and backward passes
  6. - Input mutation detection and handling
  7. - Device compatibility checking
  8. - Stack trace management for debugging
  9. - Integration with TorchInductor's cudagraph trees
  10. The backend supports two main modes:
  11. 1. cudagraphs: Full CUDA graph support with both forward and backward pass optimization
  12. 2. cudagraphs_inner: Lower-level CUDA graph implementation used for benchmarking
  13. Key components:
  14. - CudagraphsBackend: Main backend class for CUDA graph integration
  15. - Mutation detection utilities to ensure graph safety
  16. - Device mapping and compatibility checks
  17. - Stack trace collection for debugging
  18. """
  19. import functools
  20. from collections import defaultdict
  21. from collections.abc import Callable, Sequence
  22. from typing import Any, Optional
  23. import torch
  24. import torch.fx
  25. from torch._dynamo import config
  26. from torch._dynamo.backends.common import aot_autograd
  27. from torch._dynamo.backends.debugging import boxed_nop
  28. from torch._inductor.cudagraph_utils import (
  29. BoxedDeviceIndex,
  30. check_multiple_devices_or_any_cpu_nodes,
  31. format_default_skip_message,
  32. get_mutation_stack_trace,
  33. get_placeholder_info,
  34. log_cudagraph_skip_and_bump_counter,
  35. )
  36. from torch._inductor.utils import (
  37. BoxedBool,
  38. count_tangents,
  39. get_first_incompatible_cudagraph_node,
  40. num_fw_fixed_arguments,
  41. output_node,
  42. )
  43. from torch.multiprocessing.reductions import StorageWeakRef
  44. from .registry import register_backend
  45. def find_input_mutations(g: torch.fx.Graph) -> set[int]:
  46. def meta_fk(meta: dict[str, Any]) -> Any:
  47. return meta["val"] if "val" in meta else meta["fake_result"]
  48. inputs = defaultdict(set)
  49. input_idx = 0
  50. mutated_inputs = set()
  51. for n in g.nodes:
  52. if n.op == "placeholder":
  53. if isinstance(meta_fk(n.meta), torch.Tensor):
  54. inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
  55. input_idx += 1
  56. elif n.op == "call_function":
  57. if not hasattr(n.target, "_schema"):
  58. continue
  59. schema = n.target._schema
  60. for i, arg in enumerate(schema.arguments):
  61. if i < len(n.args):
  62. argument = n.args[i]
  63. else:
  64. if arg.name not in n.kwargs:
  65. continue
  66. argument = n.kwargs[arg.name]
  67. mut_arg = False
  68. if arg.alias_info:
  69. if arg.alias_info.is_write:
  70. mut_arg = True
  71. if mut_arg:
  72. # TODO: not correct for args that contain tensors in a struct
  73. # like list
  74. mutated_inputs |= inputs[
  75. StorageWeakRef(meta_fk(argument.meta)._typed_storage())
  76. ]
  77. # TODO: error on unrecognized nodes
  78. return mutated_inputs
  79. def get_device_node_mapping(
  80. gm: torch.fx.GraphModule,
  81. ) -> dict[torch.device, torch.fx.Node]:
  82. device_node_mapping: dict[torch.device, torch.fx.Node] = {}
  83. for n in gm.graph.nodes:
  84. t = n.meta.get("val", None)
  85. if isinstance(t, torch.Tensor) and t.device not in device_node_mapping:
  86. device_node_mapping[t.device] = n
  87. return device_node_mapping
  88. def check_for_mutation_ignore_cuda_graph_managed_tensor(
  89. aot_model: torch.fx.GraphModule, num_fixed: int
  90. ) -> Optional[str]:
  91. mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed))
  92. if not mutation_indices:
  93. return None
  94. placeholders = get_placeholder_info(aot_model.graph)
  95. return get_mutation_stack_trace(placeholders, mutation_indices)
  96. def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed: int) -> Optional[str]:
  97. if not config.cudagraph_backend_support_input_mutation:
  98. if mut_skip := check_for_mutation_ignore_cuda_graph_managed_tensor(
  99. aot_model, num_fixed
  100. ):
  101. return mut_skip
  102. if skip := check_multiple_devices_or_any_cpu_nodes(
  103. get_device_node_mapping(aot_model)
  104. ):
  105. return skip
  106. if node := get_first_incompatible_cudagraph_node(aot_model):
  107. return format_default_skip_message(f"incompatible op ({node.name})")
  108. return None
  109. def get_device_index(gm: torch.fx.GraphModule) -> int:
  110. device = next(iter(get_device_node_mapping(gm)))
  111. assert device.type == "cuda"
  112. return device.index
  113. def get_stack_traces(gm: torch.fx.GraphModule) -> list[Optional[str]]:
  114. output = output_node(gm)
  115. assert len(output.args) == 1
  116. args = output.args[0]
  117. if not hasattr(args, "__iter__"):
  118. return []
  119. return [
  120. (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
  121. for arg in args # type: ignore[union-attr]
  122. ]
  123. def cudagraphs(dynamo_model: torch.fx.GraphModule, dynamo_inputs: Sequence[Any]) -> Any:
  124. from torch._inductor.cudagraph_trees import cudagraphify_impl
  125. do_cudagraphs = BoxedBool(True)
  126. boxed_device_index = BoxedDeviceIndex(None)
  127. def forward_cudagraphs(
  128. aot_model: torch.fx.GraphModule,
  129. aot_inputs: list[Any],
  130. is_inference: bool = False,
  131. ) -> Any:
  132. interp = boxed_nop(aot_model, aot_inputs)
  133. fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
  134. if skip_msg := check_for_skip(aot_model, fixed):
  135. BoxedBool.disable(do_cudagraphs)
  136. log_cudagraph_skip_and_bump_counter(
  137. f"skipping cudagraphs due to {skip_msg}"
  138. )
  139. return interp
  140. boxed_device_index.set(get_device_index(aot_model))
  141. out = cudagraphify_impl(
  142. interp,
  143. aot_inputs,
  144. range(fixed),
  145. device_index=boxed_device_index.value,
  146. is_backward=False,
  147. is_inference=is_inference,
  148. stack_traces=get_stack_traces(aot_model),
  149. placeholders=get_placeholder_info(aot_model.graph),
  150. mutated_input_idxs=find_input_mutations(aot_model.graph),
  151. )
  152. out._boxed_call = True # type: ignore[attr-defined]
  153. return out
  154. def backward_cudagraphs(
  155. aot_model: torch.fx.GraphModule, aot_inputs: list[Any]
  156. ) -> Any:
  157. interp = boxed_nop(aot_model, aot_inputs)
  158. if not do_cudagraphs:
  159. return aot_model
  160. fixed = count_tangents(aot_model)
  161. if skip_msg := check_for_skip(aot_model, fixed):
  162. log_cudagraph_skip_and_bump_counter(
  163. f"skipping cudagraphs due to {skip_msg}"
  164. )
  165. # See [Backward Generation Handling]
  166. device_idx = boxed_device_index.value
  167. if device_idx is None:
  168. device_idx = 0 # Default to device 0 if not set
  169. manager = torch._inductor.cudagraph_trees.get_manager(
  170. device_idx, create_if_none_exists=False
  171. )
  172. assert manager is not None
  173. def fn(inputs: list[Any]) -> Any:
  174. # pyrefly: ignore [missing-attribute]
  175. manager.set_to_running_backward()
  176. return aot_model(inputs)
  177. fn._boxed_call = True # type: ignore[attr-defined]
  178. return fn
  179. out = cudagraphify_impl(
  180. interp,
  181. aot_inputs,
  182. range(fixed),
  183. device_index=get_device_index(aot_model),
  184. is_backward=True,
  185. is_inference=False,
  186. stack_traces=get_stack_traces(aot_model),
  187. placeholders=get_placeholder_info(aot_model.graph),
  188. mutated_input_idxs=find_input_mutations(aot_model.graph),
  189. )
  190. out._boxed_call = True # type: ignore[attr-defined]
  191. return out
  192. aot_cudagraphs = aot_autograd(
  193. fw_compiler=forward_cudagraphs,
  194. bw_compiler=backward_cudagraphs,
  195. inference_compiler=functools.partial(forward_cudagraphs, is_inference=True),
  196. keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation,
  197. )
  198. return aot_cudagraphs(dynamo_model, dynamo_inputs)
class CudagraphsBackend:
    """Callable backend object that forwards to :func:`cudagraphs`.

    Both methods are static, so an instance carries no state of its own.
    """

    # Name this backend identifies itself with.
    compiler_name = "cudagraphs"

    @staticmethod
    def reset() -> None:
        """Reset cudagraph-tree state via ``reset_cudagraph_trees``."""
        # Imported lazily, inside the call, rather than at module import time.
        from torch._inductor.cudagraph_trees import reset_cudagraph_trees

        reset_cudagraph_trees()

    @staticmethod
    def __call__(model: torch.fx.GraphModule, inputs: Sequence[Any]) -> Any:
        """Compile ``model`` with example ``inputs`` using :func:`cudagraphs`."""
        return cudagraphs(model, inputs)
# Register a CudagraphsBackend instance under the backend name "cudagraphs".
# aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful
# for debugging and can serve as a perf baseline.
register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend())
def cudagraphs_inner(
    model: Callable[..., Any],
    inputs: Sequence[Any],
    copy_outputs: bool = True,
    copy_inputs: bool = True,
) -> Callable[..., Sequence[Any]]:
    """Capture ``model`` into one CUDA graph and return a function replaying it.

    This isn't registered as a backend, but is used in some benchmarks.

    Args:
        model: Callable invoked as ``model(*inputs)``.
        inputs: Example inputs; must be a list or tuple.
        copy_outputs: When true, each replay returns clones of the captured
            outputs; otherwise the captured outputs themselves are returned.
        copy_inputs: When true, the graph is captured on freshly allocated
            ``zeros_like`` buffers and every replay copies the new inputs into
            them. When false, the graph is captured directly on ``inputs`` and
            the arguments passed at replay time are only length-checked.

    Returns:
        ``run(*new_inputs)``, which replays the captured graph and returns its
        outputs as a list (``copy_outputs=True``) or as the captured
        list/tuple of outputs.
    """
    assert isinstance(inputs, (list, tuple))
    if copy_inputs:
        # pyrefly: ignore [bad-argument-type]
        static_inputs = [torch.zeros_like(x) for x in inputs]
    else:
        static_inputs = list(inputs)

    # warmup: run the model once on a side stream before capturing
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        model(*inputs)
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # record: capture one call of the model on the static inputs
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        static_outputs = model(*static_inputs)
    # Normalise a single return value to a 1-tuple so `run` can iterate it.
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    def run(*new_inputs: Any) -> Sequence[Any]:
        """Replay the captured graph, optionally copying inputs and outputs."""
        assert len(static_inputs) == len(new_inputs)
        if copy_inputs:
            # Write the new values into the buffers the graph was captured on.
            for dst, src in zip(static_inputs, new_inputs):
                dst.copy_(src)
        graph.replay()
        if copy_outputs:
            return [x.clone() for x in static_outputs]
        else:
            return static_outputs

    return run