debugging.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675
"""
This module provides debugging backends for TorchDynamo to help diagnose and troubleshoot
compilation and execution issues.

Key Debugging Backends:
- eager: Simple pass-through backend that runs models in eager mode
- eager_noexcept: Like eager, but re-raises any exception thrown by the generated
  GraphModule as a TorchDynamoException
- eager_debug: Adds schema validation checks for custom operators
- aot_eager: Uses AOT Autograd with a nop compiler for debugging
- aot_eager_decomp_partition: Uses TorchInductor decompositions and partitioner with a
  nop compiler for debugging
- torchscript (registered as "ts"): Compiles using TorchScript for debugging JIT-related issues

Testing and Development Tools:
- Backends for inducing specific errors (compile/runtime/accuracy)
- ExplainOutput class for detailed graph compilation analysis
- Utilities for cross-referencing and mode management
- Tools for graph detail inspection and break reason analysis

These backends are primarily used for:
1. Debugging graph breaks and compilation failures
2. Testing error handling and recovery mechanisms
3. Analyzing performance bottlenecks
4. Validating operator schemas and decompositions
"""
  22. import dataclasses
  23. import functools
  24. import logging
  25. from collections.abc import Callable, Iterable
  26. from importlib import import_module
  27. from typing import Any, Optional, TYPE_CHECKING, Union
  28. import torch
  29. from functorch.compile import min_cut_rematerialization_partition
  30. from torch import _guards
  31. from torch._dynamo.output_graph import GraphCompileReason
  32. from torch._functorch import config as functorch_config
  33. from torch._functorch.compilers import ts_compile
  34. from .common import aot_autograd
  35. from .registry import CompiledFn, CompilerFn, register_debug_backend as register_backend
  36. if TYPE_CHECKING:
  37. from torch.fx.node import Target
# Module-level logger; the backends below use it to warn when they are given
# extra kwargs that they ignore.
log = logging.getLogger(__name__)
  39. @register_backend
  40. def eager(
  41. gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
  42. ) -> Callable[..., Any]:
  43. if kwargs:
  44. log.warning("eager backend ignoring extra kwargs %s", kwargs)
  45. if torch._functorch.config.force_autograd_cache:
  46. from torch._dynamo.aot_compile_types import GraphModuleSerializableCallable
  47. return GraphModuleSerializableCallable(gm)
  48. return gm.forward
  49. def make_eager_backend_with_torch_function_mode(
  50. mode: torch.overrides.TorchFunctionMode,
  51. ) -> Callable[..., Any]:
  52. return make_eager_backend_with_torch_function_modes([mode])
  53. def make_eager_backend_with_torch_function_modes(
  54. modes: Iterable[torch.overrides.TorchFunctionMode],
  55. ) -> Callable[..., Any]:
  56. """Used to trace HOPs (cond and while) for eager execution, the metadata
  57. TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks
  58. in the HOP, so we need to externally run this mode and not trace it."""
  59. from contextlib import ExitStack
  60. def fn(
  61. gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
  62. ) -> Callable[..., Any]:
  63. def wrapper(*args: Any, **kwargs: Any) -> Any:
  64. with ExitStack() as stack:
  65. for mode in modes:
  66. stack.enter_context(mode)
  67. return gm.forward(*args, **kwargs)
  68. return wrapper
  69. return fn
  70. @register_backend
  71. def eager_noexcept(
  72. gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
  73. ) -> Callable[..., Any]:
  74. if kwargs:
  75. log.warning("eager_noexcept backend ignoring extra kwargs %s", kwargs)
  76. # This backend is intended to check that dynamo-generated GraphModules
  77. # do not cause errors.
  78. def inner(*args: Any) -> Any:
  79. try:
  80. return gm(*args)
  81. except Exception as e:
  82. raise torch._dynamo.exc.TorchDynamoException(
  83. "Unexpected exception when running generated GraphModule"
  84. ) from e
  85. return inner
  86. @register_backend
  87. def pre_dispatch_eager(
  88. gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
  89. ) -> torch.fx.GraphModule:
  90. if kwargs:
  91. log.warning("pre_dispatch_eager backend ignoring extra kwargs %s", kwargs)
  92. from torch.fx.experimental.proxy_tensor import make_fx
  93. def runnable_gm(*args: Any) -> Any:
  94. return torch.fx.Interpreter(gm).run(*args)
  95. pre_dispatch_gm = make_fx(runnable_gm, pre_dispatch=True)(*fake_tensor_inputs)
  96. pre_dispatch_gm.print_readable()
  97. return pre_dispatch_gm
  98. @register_backend
  99. def eager_debug(
  100. gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
  101. ) -> Callable[..., Any]:
  102. if kwargs:
  103. log.warning("eager_debug backend ignoring extra kwargs %s", kwargs)
  104. from torch._subclasses.schema_check_mode import SchemaCheckMode
  105. # We could add more debugging bits here.
  106. # Right now, this backend can be used to check for and error on
  107. # custom dispatcher ops that have incorrect schemas.
  108. def inner(*args: Any) -> Any:
  109. with SchemaCheckMode():
  110. return torch.fx.Interpreter(gm).run(*args)
  111. return inner
  112. @register_backend(name="ts") # type: ignore[misc]
  113. def torchscript(
  114. gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor]
  115. ) -> torch.jit.ScriptModule:
  116. return torch.jit.script(gm)
def invoke_subgraph_inner_compiler(
    subgraph: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
) -> Callable[..., Any]:
    """Inner compiler that wraps forward/backward graphs in invoke_subgraph HOP.

    This is used as the fw_compiler/bw_compiler for aot_autograd. When the resulting
    function is traced by make_fx, it emits an invoke_subgraph HOP instead of inlining.

    Args:
        subgraph: The graph handed over by aot_autograd (forward or backward).
        example_inputs: Not used.

    Returns:
        A callable using the boxed calling convention: it takes a single list of
        arguments, is marked ``_boxed_call = True``, and forwards to
        ``invoke_subgraph_infer(subgraph, *args)``.
    """
    from torch._dynamo import disable
    from torch._higher_order_ops.invoke_subgraph import invoke_subgraph_infer

    # NOTE(review): decorator order (disable outermost, allow_in_graph inner)
    # is kept deliberately; this path has proven fragile, so confirm with tests
    # before changing it.
    @disable
    @torch._dynamo.allow_in_graph
    def invoke_subgraph_wrapper_unboxed(*operands: Any) -> Any:
        return invoke_subgraph_infer(subgraph, *operands)

    # NB: The direct to unboxed path is broken, you MUST DO THIS
    # (i.e. expose a boxed entry point that unpacks the argument list itself).
    def invoke_subgraph_wrapper(args: list[Any]) -> Any:
        return invoke_subgraph_wrapper_unboxed(*args)

    invoke_subgraph_wrapper._boxed_call = True  # type: ignore[attr-defined]
    return invoke_subgraph_wrapper
# NOTE: The string literal below preserves an earlier implementation of the
# invoke_subgraph inner compiler. That version worked repeatedly in cases where
# cleaner rewrites did not. It is never executed (it is a bare string
# expression) and is kept for reference until the remaining bugs are fixed.
'''
# Counter for unique subgraph names in invoke_subgraph backend
_invoke_subgraph_counter = 0
def invoke_subgraph_inner_compiler_good(
    fx_g: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
) -> Callable[..., Any]:
    """Inner compiler that wraps forward/backward graphs in invoke_subgraph HOP.
    This is used as the fw_compiler/bw_compiler for aot_autograd. When the resulting
    function is traced by make_fx, it emits an invoke_subgraph HOP instead of inlining.
    """
    from torch._higher_order_ops.invoke_subgraph import (
        invoke_subgraph as invoke_subgraph_hop,
    )
    from torch.fx.experimental.proxy_tensor import get_proxy_mode
    global _invoke_subgraph_counter
    _invoke_subgraph_counter += 1
    name = f"invoke_subgraph_{_invoke_subgraph_counter}"
    from torch._dynamo import disable
    # Check if fx_g uses boxed calling convention
    fx_g_is_boxed = getattr(fx_g, "_boxed_call", False)
    @disable
    @torch._dynamo.allow_in_graph
    def invoke_subgraph_wrapper_unboxed(*args: Any) -> Any:
        proxy_mode = get_proxy_mode()
        if proxy_mode is not None:
            # When being traced by make_fx, emit invoke_subgraph HOP
            return invoke_subgraph_hop(fx_g, name, *args)  # type: ignore[arg-type]
        else:
            # Normal execution path - call fx_g with proper calling convention
            if fx_g_is_boxed:
                return fx_g(list(args))
            else:
                return fx_g(*args)
    # Wrap to handle boxed arguments (list of args) as expected by AOTAutograd
    def invoke_subgraph_wrapper(args: list[Any]) -> Any:
        return invoke_subgraph_wrapper_unboxed(*args)
    invoke_subgraph_wrapper._boxed_call = True  # type: ignore[attr-defined]
    return invoke_subgraph_wrapper
'''
  177. @register_backend
  178. def invoke_subgraph(
  179. gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
  180. ) -> Callable[..., Any]:
  181. """Backend that wraps forward/backward graphs in invoke_subgraph HOP when traced by make_fx.
  182. This backend uses AOTAutograd to partition into forward/backward graphs, then wraps
  183. each in an invoke_subgraph HOP. This is useful for recursive Dynamo tracing scenarios
  184. where you want the compiled subgraph to appear as invoke_subgraph HOPs in the outer
  185. trace rather than being inlined.
  186. Requires:
  187. - torch._dynamo.config.force_compile_during_fx_trace = True
  188. (this implicitly overrides error_on_nested_fx_trace)
  189. """
  190. if kwargs:
  191. log.warning("invoke_subgraph backend ignoring extra kwargs %s", kwargs)
  192. # Use AOTAutograd to partition into forward/backward
  193. return aot_autograd(
  194. fw_compiler=invoke_subgraph_inner_compiler,
  195. bw_compiler=invoke_subgraph_inner_compiler,
  196. partition_fn=min_cut_rematerialization_partition,
  197. keep_inference_input_mutations=True,
  198. )(gm, fake_tensor_inputs)
  199. # used boxed call to discard inputs when they are no longer needed
  200. def boxed_nop(
  201. fx_g: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  202. ) -> Callable[..., Any]:
  203. from torch.fx.graph import _BoxedCodeGen
  204. # Set the graph to use boxed codegen
  205. fx_g.graph.set_codegen(_BoxedCodeGen())
  206. fx_g.recompile()
  207. # Wrap the forward method in a function so we can set _boxed_call attribute
  208. forward_fn = fx_g.forward
  209. def run(args: Any) -> Any:
  210. from torch.utils._debug_mode import DebugInterpreter, get_active_debug_mode
  211. if (
  212. debug_mode := get_active_debug_mode()
  213. ) is not None and debug_mode.run_compile_with_interpreter:
  214. return DebugInterpreter(fx_g, backend="aot_eager").run(*args)
  215. return forward_fn(args)
  216. run._boxed_call = True # type: ignore[attr-defined]
  217. return run
  218. def boxed_nop_with_mode(
  219. fx_g: torch.fx.GraphModule,
  220. example_inputs: list[torch.Tensor],
  221. *,
  222. mode: torch.overrides.TorchFunctionMode,
  223. ) -> Callable[..., Any]:
  224. from torch.fx.graph import _BoxedCodeGen
  225. # Set the graph to use boxed codegen
  226. fx_g.graph.set_codegen(_BoxedCodeGen())
  227. fx_g.recompile()
  228. # Create a wrapper that runs with the mode
  229. forward_fn = fx_g.forward
  230. def run(args: Any) -> Any:
  231. with mode:
  232. return forward_fn(args)
  233. run._boxed_call = True # type: ignore[attr-defined]
  234. return run
  235. def fake_crossref_boxed_nop(
  236. fx_g: torch.fx.GraphModule,
  237. example_inputs: list[torch.Tensor],
  238. ignore_op_fn: Optional[Callable[[torch._ops.OpOverload], bool]] = None,
  239. ) -> Callable[..., Any]:
  240. from torch.fx.graph import _BoxedCodeGen
  241. # Set the graph to use boxed codegen
  242. fx_g.graph.set_codegen(_BoxedCodeGen())
  243. fx_g.recompile()
  244. # Create a wrapper that runs with the mode
  245. forward_fn = fx_g.forward
  246. def run(args: Any) -> Any:
  247. with torch._subclasses.CrossRefFakeMode(ignore_op_fn):
  248. return forward_fn(args)
  249. run._boxed_call = True # type: ignore[attr-defined]
  250. return run
  251. def ignore_builtins(op: torch._ops.OpOverload) -> bool:
  252. return op.namespace in ("aten", "prims", "prim")
  253. def get_nop_func() -> Callable[
  254. [torch.fx.GraphModule, list[torch.Tensor]], Callable[..., Any]
  255. ]:
  256. if not torch._functorch.config.fake_tensor_crossref:
  257. return boxed_nop
  258. elif torch._functorch.config.fake_tensor_crossref == "all":
  259. return fake_crossref_boxed_nop
  260. else:
  261. assert torch._functorch.config.fake_tensor_crossref == "custom_ops"
  262. return functools.partial(fake_crossref_boxed_nop, ignore_op_fn=ignore_builtins)
  263. # Useful for debugging purpose
  264. # aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging.
  265. def aot_eager(
  266. gm: torch.fx.GraphModule,
  267. fake_tensor_inputs: list[torch.Tensor],
  268. fw_compiler: Optional[Callable[..., Any]] = None,
  269. bw_compiler: Optional[Callable[..., Any]] = None,
  270. **kwargs: Any,
  271. ) -> Callable[..., Any]:
  272. return aot_autograd(
  273. fw_compiler=fw_compiler or boxed_nop,
  274. bw_compiler=bw_compiler or boxed_nop,
  275. partition_fn=min_cut_rematerialization_partition,
  276. keep_inference_input_mutations=True,
  277. )(gm, fake_tensor_inputs, **kwargs)
  278. register_backend(name="aot_eager", compiler_fn=aot_eager)
  279. aot_eager_default_partitioner = aot_autograd(
  280. fw_compiler=boxed_nop, keep_inference_input_mutations=True
  281. )
  282. register_backend(
  283. name="aot_eager_default_partitioner", compiler_fn=aot_eager_default_partitioner
  284. )
  285. # Uses TorchInductor AOT Autograd decomps and partitioner to isolate aot vs
  286. # inductor problems.
  287. # aot_eager_decomp_partition just replaces the inductor compiler with nop to help
  288. # isolate inductor vs aot_eager errors
  289. def aot_eager_decomp_partition(
  290. gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
  291. ) -> Callable[..., Any]:
  292. if kwargs:
  293. log.warning(
  294. "aot_eager_decomp_partition backend ignoring extra kwargs %s", kwargs
  295. )
  296. from torch._inductor.compiler_bisector import CompilerBisector
  297. config_patches = {"unlift_effect_tokens": True}
  298. if bisect_changes := CompilerBisector.get_config_change(
  299. "aot_eager_decomp_partition"
  300. ):
  301. config_patches.update(bisect_changes) # type: ignore[arg-type]
  302. with functorch_config.patch(config_patches):
  303. return aot_autograd(
  304. # these are taken from memory_efficient_fusion()
  305. fw_compiler=get_nop_func(),
  306. bw_compiler=get_nop_func(),
  307. # NB: lambda here is to delay import of inductor
  308. decompositions=lambda: import_module(
  309. "torch._inductor.compile_fx"
  310. ).select_decomp_table(),
  311. partition_fn=functools.partial(
  312. min_cut_rematerialization_partition, compiler="inductor"
  313. ),
  314. )(gm, fake_tensor_inputs)
  315. register_backend(
  316. name="aot_eager_decomp_partition", compiler_fn=aot_eager_decomp_partition
  317. )
  318. # aot_eager_decomp_partition_with_mode is similar as aot_eager_decomp_partition,
  319. # except that it takes a TorchDispatchMode mode and run the fw/bw in the mode
  320. def aot_eager_decomp_partition_with_mode(
  321. gm: torch.fx.GraphModule,
  322. fake_tensor_inputs: list[torch.Tensor],
  323. mode: Any,
  324. **kwarg: Any,
  325. ) -> Callable[..., Any]:
  326. return aot_autograd(
  327. # these are taken from memory_efficient_fusion()
  328. fw_compiler=functools.partial(boxed_nop_with_mode, mode=mode),
  329. bw_compiler=functools.partial(boxed_nop_with_mode, mode=mode),
  330. # NB: lambda here is to delay import of inductor
  331. decompositions=lambda: import_module(
  332. "torch._inductor.compile_fx"
  333. ).select_decomp_table(),
  334. partition_fn=functools.partial(
  335. min_cut_rematerialization_partition, compiler="inductor"
  336. ),
  337. )(gm, fake_tensor_inputs)
  338. register_backend(
  339. name="aot_eager_decomp_partition_with_mode",
  340. compiler_fn=aot_eager_decomp_partition_with_mode, # type: ignore[arg-type]
  341. )
  342. def aot_eager_decomp_partition_crossref(
  343. gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
  344. ) -> Callable[..., Any]:
  345. # if the config is set, respect it, otherwise only test custom_ops.
  346. # custom_op bad metas always manifest as an error whereas aten will only sometimes.
  347. # by default, use the less noisy option
  348. config_val = (
  349. "custom_ops"
  350. if not functorch_config.fake_tensor_crossref
  351. else functorch_config.fake_tensor_crossref
  352. )
  353. with functorch_config.patch(fake_tensor_crossref=config_val):
  354. return aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs)
  355. register_backend(
  356. name="aot_eager_decomp_partition_crossref",
  357. compiler_fn=aot_eager_decomp_partition_crossref,
  358. )
  359. # AOT Autograd with torchscript backend. Default partitioner.
  360. # aot_ts uses torchscript backend. We can use this with both nnc and nvfuser
  361. # by using the relevant fuser with torch.jit.fuser(...)
  362. aot_ts = aot_autograd(fw_compiler=ts_compile)
  363. register_backend(name="aot_ts", compiler_fn=aot_ts)
  364. # These buggy backends are used for inducing bugs so that we can test
  365. # our repro extraction / minifier scripts
  366. class ReluCompileError(Exception):
  367. pass
  368. class TestingOnlyCompileError(Exception):
  369. pass
  370. @register_backend
  371. def relu_compile_error_TESTING_ONLY(
  372. gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  373. ) -> torch.fx.GraphModule:
  374. for node in gm.graph.nodes:
  375. if node.target is torch.relu:
  376. raise ReluCompileError
  377. return gm
  378. @register_backend
  379. def relu_runtime_error_TESTING_ONLY(
  380. gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  381. ) -> torch.fx.GraphModule:
  382. for node in gm.graph.nodes:
  383. if node.target is torch.relu:
  384. node.target = torch._assert
  385. node.args = (False, "ReluRuntimeError")
  386. gm.recompile()
  387. return gm
  388. @register_backend
  389. def relu_accuracy_error_TESTING_ONLY(
  390. gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  391. ) -> torch.fx.GraphModule:
  392. for node in gm.graph.nodes:
  393. if node.target is torch.relu:
  394. node.target = torch.add
  395. node.args = (node.args[0], 1)
  396. gm.recompile()
  397. return gm
  398. @register_backend
  399. def non_leaf_compile_error_TESTING_ONLY(
  400. gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  401. ) -> torch.fx.GraphModule:
  402. # Require at least one non-trivial thing in the graph,
  403. # see https://github.com/pytorch/pytorch/issues/102898
  404. for node in gm.graph.nodes:
  405. if node.op == "call_function":
  406. break
  407. else:
  408. return gm
  409. for t in example_inputs:
  410. if not t.is_leaf:
  411. raise TestingOnlyCompileError
  412. return gm
  413. @dataclasses.dataclass
  414. class ExplainOutput:
  415. """
  416. This is the output of :func:`torch._dynamo.explain()`
  417. There is no reason to create this class directly.
  418. """
  419. graphs: list[torch.fx.GraphModule]
  420. graph_count: int
  421. graph_break_count: int
  422. break_reasons: list[GraphCompileReason]
  423. op_count: int
  424. ops_per_graph: Optional[list[list["Target"]]] = None
  425. out_guards: Optional[list[_guards.Guard]] = None
  426. compile_times: Optional[str] = None
  427. def __str__(self) -> str:
  428. output = f"Graph Count: {self.graph_count}\n"
  429. output += f"Graph Break Count: {self.graph_break_count}\n"
  430. output += f"Op Count: {self.op_count}\n"
  431. output += "Break Reasons:\n"
  432. for idx, break_reason in enumerate(self.break_reasons):
  433. output += f" Break Reason {idx + 1}:\n"
  434. output += f" Reason: {break_reason.reason}\n"
  435. output += " User Stack:\n"
  436. for frame_summary in break_reason.user_stack:
  437. output += f" {frame_summary}\n"
  438. if self.ops_per_graph is not None:
  439. output += "Ops per Graph:\n"
  440. for idx, ops in enumerate(self.ops_per_graph):
  441. output += f" Ops {idx + 1}:\n"
  442. for op in ops:
  443. output += f" {op}\n"
  444. if self.out_guards is not None:
  445. output += "Out Guards:\n"
  446. for i, guard in enumerate(self.out_guards):
  447. output += f" Guard {i + 1}:\n"
  448. output += f" {str(guard)}"
  449. if self.compile_times is not None:
  450. output += f"Compile Times: {self.compile_times}\n"
  451. return output
  452. def _explain_graph_detail(
  453. gm: torch.fx.GraphModule,
  454. graphs: list[torch.fx.GraphModule],
  455. op_count: int,
  456. ops_per_graph: list[list["Target"]],
  457. break_reasons: list[GraphCompileReason],
  458. ) -> tuple[
  459. torch.fx.GraphModule,
  460. list[torch.fx.GraphModule],
  461. int,
  462. list[list["Target"]],
  463. list[GraphCompileReason],
  464. ]:
  465. """
  466. This function is a utility which processes a torch.fx.GraphModule and
  467. accumulates information about its ops, graph breaks, and other details. It
  468. is intended to be used by the ExplainWithBackend class and
  469. `torch._dynamo.explain()` to provide details from Dynamo's graph capture.
  470. Parameters:
  471. gm (torch.fx.GraphModule): The GraphModule to be processed.
  472. graphs (list): A list that accumulates all the GraphModules processed.
  473. op_count (int): The total count of operations in all GraphModules processed so far.
  474. ops_per_graph (list): A list that accumulates the operations of each GraphModule.
  475. break_reasons (list): A list that accumulates the reasons for breaks in each GraphModule.
  476. Returns:
  477. tuple: A tuple containing the processed GraphModule, the updated lists of graphs,
  478. operations per graph, and break reasons, and the updated operation count.
  479. """
  480. graphs.append(gm)
  481. ops = [node.target for node in gm.graph.nodes if node.op == "call_function"]
  482. op_count += len(ops)
  483. ops_per_graph.append(ops)
  484. if gm.compile_subgraph_reason.graph_break: # type: ignore[union-attr]
  485. break_reasons.append(gm.compile_subgraph_reason) # type: ignore[arg-type]
  486. return gm, graphs, op_count, ops_per_graph, break_reasons
  487. class ExplainWithBackend:
  488. """
  489. This class is intended to be used as a backend for `torch.compile`. It is
  490. composable with other backends. When used in this way, it accumulates
  491. information about graph breaks, ops, and other info and provides a string
  492. representation summarizing this information.
  493. Attributes:
  494. backend (str): The name of the backend to use for optimization.
  495. graphs (list): A list of the graphs captured by TorchDynamo.
  496. op_count (int): The total number of operations in all optimized graphs.
  497. break_reasons (list): A list of graph break reasons with stack traces.
  498. Example Usage:
  499. def fn(x):
  500. x = torch.sigmoid(x)
  501. return x
  502. torch._dynamo.reset()
  503. eb = ExplainWithBackend("inductor")
  504. optimized_fn = torch.compile(fn, backend=eb)
  505. result = optimized_fn(torch.randn(5))
  506. print(eb.output())
  507. """
  508. def __init__(self, backend: Union[CompilerFn, str]) -> None:
  509. from .registry import lookup_backend
  510. self.backend = lookup_backend(backend)
  511. self.graphs: list[torch.fx.GraphModule] = []
  512. self.op_count = 0
  513. self.break_reasons: list[GraphCompileReason] = []
  514. def __call__(
  515. self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  516. ) -> CompiledFn:
  517. ops_per_graph: list[list[Target]] = []
  518. gm, self.graphs, self.op_count, _, self.break_reasons = _explain_graph_detail(
  519. gm, self.graphs, self.op_count, ops_per_graph, self.break_reasons
  520. )
  521. return self.backend(gm, example_inputs)
  522. def output(self) -> ExplainOutput:
  523. graph_count = len(self.graphs)
  524. output = ExplainOutput(
  525. self.graphs,
  526. graph_count,
  527. graph_count - 1,
  528. self.break_reasons,
  529. self.op_count,
  530. )
  531. return output