debug.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334
  1. import collections
  2. import contextlib
  3. import copy
  4. import dataclasses
  5. import functools
  6. import io
  7. import itertools
  8. import json
  9. import logging
  10. import os
  11. import os.path
  12. import pickle
  13. import pstats
  14. import shutil
  15. import tempfile
  16. import traceback
  17. from collections.abc import Callable, Iterator, Sequence
  18. from typing import Any, IO, Optional, Union
  19. from unittest.mock import patch
  20. import torch
  21. from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled
  22. from torch import fx
  23. from torch._dynamo.repro.after_aot import save_graph_repro
  24. from torch._dynamo.utils import get_debug_dir
  25. from torch._inductor import utils
  26. from torch._logging import getArtifactLogger
  27. from torch._logging._internal import trace_structured
  28. from torch._utils_internal import signpost_event
  29. from torch.fx.graph_module import GraphModule
  30. from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
  31. from torch.fx.passes.tools_common import legalize_graph
  32. from torch.types import FileLike
  33. from torch.utils._ordered_set import OrderedSet
  34. from torch.utils._pytree import tree_map
  35. from . import config, ir # noqa: F811, this is needed
  36. from .ir import ExternKernel
  37. from .scheduler import (
  38. BaseSchedulerNode,
  39. FusedSchedulerNode,
  40. NopKernelSchedulerNode,
  41. OutputNode,
  42. SchedulerNode,
  43. )
  44. from .virtualized import V
log = logging.getLogger(__name__)

# Graph execution tracking for debugging.
# record_and_log_graph_execution_order() initializes these three on entry and
# puts them back to the values below on exit.
GRAPH_EXECUTION_ORDER: Optional[list[dict[str, object]]] = None
RECORD_GRAPH_EXECUTION: bool = False
GRAPH_COMPILE_IDS: Optional[dict[int, Optional[str]]] = None

# Artifact loggers consulted by log_ir_pre_fusion() / log_ir_post_fusion().
ir_pre_fusion_log = getArtifactLogger(__name__, "ir_pre_fusion")
ir_post_fusion_log = getArtifactLogger(__name__, "ir_post_fusion")

# Loose alias for "a list of scheduler nodes"; element type is not enforced.
SchedulerNodeList = list[Any]
# Annotation attached to an original FX node: the buffer name it is attributed
# to, and how many origin node names are attributed to that same buffer
# (computed in get_node_name_to_buf_meta()).
BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"])
# Graphviz command line passed as `prog` when drawing the original FX graph
# (see DebugFormatter.draw_orig_fx_graph).
GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"]
  55. @functools.cache
  56. def has_dot() -> bool:
  57. return shutil.which("dot") is not None
  58. def draw_buffers(
  59. nodes: list[BaseSchedulerNode],
  60. print_graph: bool = False,
  61. fname: Optional[str] = None,
  62. ) -> None:
  63. """
  64. Draw a graph in fname.svg.
  65. """
  66. if not has_dot():
  67. log.warning("draw_buffers() requires `graphviz` package")
  68. return
  69. if fname is None:
  70. fname = get_graph_being_compiled()
  71. graph = create_fx_from_snodes(nodes)
  72. for node in graph.nodes:
  73. if "fusion_meta" not in node.meta:
  74. continue
  75. group = node.meta["fusion_meta"].group
  76. if isinstance(group, tuple):
  77. if isinstance(group[1], int):
  78. group = (group[1],)
  79. else:
  80. group = group[1]
  81. # gather meta data
  82. dtype = None
  83. if isinstance(node, ir.ComputedBuffer):
  84. dtype = node.data.dtype
  85. metadata = TensorMetadata(group, dtype, None, None, None, None, None) # type: ignore[arg-type]
  86. node.meta["tensor_meta"] = metadata
  87. if print_graph:
  88. print(graph)
  89. gm = GraphModule({}, graph)
  90. legalize_graph(gm)
  91. gm.graph.lint()
  92. draw_graph(
  93. gm, fname, clear_meta=False, dot_graph_shape=config.trace.dot_graph_shape
  94. )
  95. def create_fx_from_snodes(snodes: list[BaseSchedulerNode]) -> fx.Graph:
  96. """
  97. Creates a FX Graph from a list of SchedulerNode objects.
  98. """
  99. def get_fake_func(name: str) -> Callable[..., int]:
  100. def func1(*args: Any) -> int:
  101. return 0
  102. func1.__name__ = name
  103. return func1
  104. FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"])
  105. buf_to_fx_node = {}
  106. node_to_fx_node = {}
  107. graph = torch.fx.Graph()
  108. first_node = None
  109. outputs = []
  110. group: Any = None
  111. # create call_function node for each Buffer and Kernel
  112. # pyrefly: ignore [bad-assignment]
  113. for snode in snodes:
  114. if snode.is_extern():
  115. node_type = "extern"
  116. group = node_type
  117. elif snode.is_template():
  118. node_type = "template"
  119. group = node_type
  120. elif isinstance(snode, NopKernelSchedulerNode):
  121. node_type = "nop"
  122. group = node_type
  123. elif isinstance(snode, SchedulerNode):
  124. node_type = "compute"
  125. group = snode.group
  126. elif isinstance(snode, FusedSchedulerNode):
  127. node_type = "fused"
  128. group = snode.group
  129. else:
  130. raise RuntimeError("Unknown node type")
  131. fused_name = torch._inductor.utils.get_fused_kernel_name(
  132. snode.get_nodes(), "original_aten"
  133. )
  134. func_name = f"{node_type}: {fused_name}"
  135. node_func = get_fake_func(func_name)
  136. kwargs = {}
  137. if hasattr(snode, "get_device"):
  138. kwargs = {"device": snode.get_device()}
  139. fx_node = graph.call_function(node_func, args=(), kwargs=kwargs) # type: ignore[arg-type]
  140. def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool:
  141. if isinstance(snode, FusedSchedulerNode):
  142. return any(in_output(x) for x in snode.snodes)
  143. return any(
  144. isinstance(user.node, OutputNode)
  145. for buf in snode.get_outputs()
  146. for user in buf.users
  147. )
  148. if in_output(snode):
  149. outputs.append(fx_node)
  150. name = snode.get_name()
  151. fx_node.name = name
  152. fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type)
  153. node_to_fx_node[name] = fx_node
  154. for buf in snode.get_outputs():
  155. buf_to_fx_node[buf.get_name()] = fx_node
  156. if first_node is None:
  157. first_node = fx_node
  158. # create edges between nodes
  159. for snode in snodes:
  160. name = snode.get_name()
  161. deps = snode.read_writes.reads
  162. fx_node = node_to_fx_node[name]
  163. new_args = []
  164. for dep in deps:
  165. if dep.name in buf_to_fx_node:
  166. dep_node = buf_to_fx_node[dep.name]
  167. else:
  168. with graph.inserting_before(first_node):
  169. dep_node = graph.placeholder(dep.name)
  170. buf_to_fx_node[dep.name] = dep_node
  171. if dep_node == fx_node: # to avoid cycles
  172. continue
  173. new_args.append(dep_node)
  174. fx_node.args = tuple(new_args)
  175. graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
  176. return graph
  177. def update_orig_fx_node_name_to_buf_name(
  178. nodes: Optional[SchedulerNodeList],
  179. node_name_to_buf_name: dict[str, str],
  180. parent_buf_name: Optional[str] = None,
  181. n_origins: int = 0,
  182. ) -> None:
  183. if nodes is None:
  184. return
  185. for node in nodes:
  186. # for FusedSchedulerNode, traverse recursively into get_nodes()
  187. buf_name = node.get_name()
  188. children_nodes = node.get_nodes()
  189. if children_nodes is not None and len(children_nodes) > 1:
  190. update_orig_fx_node_name_to_buf_name(
  191. children_nodes,
  192. node_name_to_buf_name,
  193. buf_name if parent_buf_name is None else parent_buf_name,
  194. )
  195. continue
  196. else:
  197. # pyrefly: ignore [bad-argument-type, unsupported-operation]
  198. assert len(children_nodes) == 1 and children_nodes[0] == node
  199. ir_node = node.node
  200. if ir_node is None or ir_node.origins is None:
  201. continue
  202. for origin in ir_node.origins:
  203. node_name = origin.name
  204. # when buf1 and buf2 both have origin=node1
  205. # we draw node1 according to buf1
  206. if node_name not in node_name_to_buf_name:
  207. node_name_to_buf_name[node_name] = (
  208. buf_name if parent_buf_name is None else parent_buf_name
  209. )
  210. def get_node_name_to_buf_meta(
  211. node_name_to_buf_name: dict[str, str],
  212. ) -> dict[str, BufMeta]:
  213. buf_name_to_n_node = {}
  214. for node_name, buf_name in node_name_to_buf_name.items():
  215. if buf_name not in buf_name_to_n_node:
  216. buf_name_to_n_node[buf_name] = OrderedSet([node_name])
  217. else:
  218. # pyrefly: ignore [missing-attribute]
  219. buf_name_to_n_node[buf_name].add(node_name)
  220. node_name_to_buf_meta = {}
  221. for node_name, buf_name in node_name_to_buf_name.items():
  222. n_node = len(buf_name_to_n_node[buf_name])
  223. node_name_to_buf_meta[node_name] = BufMeta(buf_name, n_node)
  224. return node_name_to_buf_meta
  225. def annotate_orig_fx_with_snodes(
  226. gm: torch.fx.GraphModule,
  227. snodes: SchedulerNodeList,
  228. ) -> None:
  229. """
  230. Creates a FX Graph from a list of SchedulerNode objects.
  231. """
  232. node_name_to_buf_name: dict[str, str] = {}
  233. update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name)
  234. if node_name_to_buf_name is None:
  235. return
  236. node_name_to_buf_meta = get_node_name_to_buf_meta(node_name_to_buf_name)
  237. for node in gm.graph.nodes:
  238. if node.name in node_name_to_buf_meta:
  239. node.meta["buf_meta"] = node_name_to_buf_meta.get(node.name)
  240. @contextlib.contextmanager
  241. def enable_aot_logging() -> Iterator[None]:
  242. compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"
  243. import torch._functorch.aot_autograd
  244. log = logging.getLogger(torch._functorch.aot_autograd.__name__)
  245. stack = contextlib.ExitStack()
  246. if not compile_debug:
  247. try:
  248. yield
  249. finally:
  250. stack.close()
  251. return
  252. # Enable all graphs to be logged to a file by setting the flags to True
  253. # and the log level of the file logger to DEBUG
  254. stack.enter_context(patch("functorch.compile.config.debug_partitioner", True))
  255. path = os.path.join(get_debug_dir(), "torchinductor")
  256. os.makedirs(path, exist_ok=True)
  257. fh = logging.FileHandler(
  258. os.path.join(
  259. path,
  260. f"aot_{get_aot_graph_name()}_debug.log",
  261. )
  262. )
  263. fh.setLevel(logging.DEBUG)
  264. fh.setFormatter(
  265. logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
  266. )
  267. log.addHandler(fh)
  268. try:
  269. yield
  270. finally:
  271. log.removeHandler(fh)
  272. stack.close()
# Used for provenance tracking
# They are not stored in DebugContext because they are not set in
# _inductor_triton_kernel_to_post_grad_node_info's Debug Context
#
# reset_provenance_globals() snapshots, clears and later restores all of the
# values below; reset_inductor_kernel_provenance_debug_handle() zeroes the
# debug handle only.
_inductor_post_to_pre_grad_nodes: dict[str, dict[str, list[str]]] = {}
_inductor_triton_kernel_to_post_grad_node_info: dict[str, list[str]] = {}
_pre_grad_graph_id: Optional[int] = None
_inductor_pre_grad_node_stack_trace: dict[str, str] = {}
_inductor_kernel_stack_trace: dict[str, list[str]] = {}
_inductor_kernel_provenance_debug_handle: int = 0
  282. def reset_inductor_kernel_provenance_debug_handle() -> None:
  283. global _inductor_kernel_provenance_debug_handle
  284. _inductor_kernel_provenance_debug_handle = 0
  285. @contextlib.contextmanager
  286. def reset_provenance_globals() -> Iterator[None]:
  287. """Context manager that resets provenance tracking globals upon entering
  288. and restores their original values when exiting."""
  289. global _pre_grad_graph_id
  290. global _inductor_post_to_pre_grad_nodes
  291. global _inductor_triton_kernel_to_post_grad_node_info
  292. global _inductor_pre_grad_node_stack_trace
  293. global _inductor_kernel_stack_trace
  294. global _inductor_kernel_provenance_debug_handle
  295. # Store original values
  296. original_pre_grad_graph_id = _pre_grad_graph_id
  297. original_post_to_pre_grad_nodes = _inductor_post_to_pre_grad_nodes.copy()
  298. original_triton_kernel_to_post_grad_node_info = (
  299. _inductor_triton_kernel_to_post_grad_node_info.copy()
  300. )
  301. original_inductor_pre_grad_node_stack_trace = (
  302. _inductor_pre_grad_node_stack_trace.copy()
  303. )
  304. original_inductor_kernel_stack_trace = _inductor_kernel_stack_trace.copy()
  305. original_inductor_kernel_provenance_debug_handle = (
  306. _inductor_kernel_provenance_debug_handle
  307. )
  308. # Reset to default values
  309. _pre_grad_graph_id = -1
  310. _inductor_post_to_pre_grad_nodes = {}
  311. _inductor_triton_kernel_to_post_grad_node_info = {}
  312. _inductor_pre_grad_node_stack_trace = {}
  313. _inductor_kernel_stack_trace = {}
  314. _inductor_kernel_provenance_debug_handle = 0
  315. try:
  316. yield
  317. finally:
  318. # Restore original values
  319. _pre_grad_graph_id = original_pre_grad_graph_id
  320. _inductor_post_to_pre_grad_nodes = original_post_to_pre_grad_nodes
  321. _inductor_triton_kernel_to_post_grad_node_info = (
  322. original_triton_kernel_to_post_grad_node_info
  323. )
  324. _inductor_kernel_stack_trace = original_inductor_kernel_stack_trace
  325. _inductor_pre_grad_node_stack_trace = (
  326. original_inductor_pre_grad_node_stack_trace
  327. )
  328. _inductor_kernel_provenance_debug_handle = (
  329. original_inductor_kernel_provenance_debug_handle
  330. )
  331. class DebugContext:
  332. _counter = itertools.count()
  333. @staticmethod
  334. def create_debug_dir(folder_name: str) -> Optional[str]:
  335. debug_dir = config.trace.debug_dir or get_debug_dir()
  336. for n in DebugContext._counter:
  337. dirname = os.path.join(
  338. debug_dir,
  339. "torchinductor",
  340. f"{folder_name}.{n}",
  341. )
  342. if not os.path.exists(dirname):
  343. os.makedirs(dirname)
  344. return dirname
  345. return None
  346. def __init__(self) -> None:
  347. self._prof = None
  348. self._path = None
  349. self._stack = contextlib.ExitStack()
  350. def copy(self, new_path: str) -> None:
  351. if not self._path:
  352. return
  353. assert new_path.endswith(".debug"), new_path
  354. from filelock import FileLock
  355. try:
  356. with FileLock(f"{new_path}.lock"):
  357. if os.path.exists(new_path):
  358. shutil.rmtree(new_path)
  359. shutil.copytree(self._path, new_path)
  360. except OSError:
  361. log.warning(
  362. "Failed to copy debug files from %s to %s", self._path, new_path
  363. )
  364. def fopen(
  365. self,
  366. filename: str,
  367. write_mode: str = "w",
  368. *args: Any,
  369. **kwargs: Any,
  370. ) -> IO[Any]:
  371. assert self._path
  372. return open(os.path.join(self._path, filename), write_mode, *args, **kwargs)
  373. @contextlib.contextmanager
  374. def fopen_context(
  375. self,
  376. filename: str,
  377. write_mode: str = "w",
  378. *args: Any,
  379. **kwargs: Any,
  380. ) -> Iterator[IO[Any]]:
  381. assert self._path
  382. with open(os.path.join(self._path, filename), write_mode, *args, **kwargs) as f:
  383. yield f
  384. def filename(self, suffix: str) -> str:
  385. assert self._path
  386. return os.path.join(self._path, suffix)
  387. def upload_tar(self) -> None:
  388. if config.trace.upload_tar is not None:
  389. import tarfile
  390. assert self._path
  391. tar_file = os.path.join(
  392. self._path, f"{os.path.basename(self._path)}.tar.gz"
  393. )
  394. with tarfile.open(tar_file, "w:gz") as tar:
  395. tar.add(self._path, arcname=os.path.basename(self._path))
  396. config.trace.upload_tar(tar_file)
  397. def __enter__(self) -> None:
  398. if config.debug:
  399. log = logging.getLogger("torch._dynamo")
  400. prev_level = log.level
  401. log.setLevel(logging.DEBUG)
  402. def reset_log_level(level: Any) -> None:
  403. log.setLevel(level)
  404. self._stack.callback(reset_log_level, prev_level)
  405. self._stack.enter_context(V.set_debug_handler(self))
  406. if not config.trace.enabled:
  407. return
  408. self._path = self.create_debug_dir(get_aot_graph_name()) # type: ignore[assignment]
  409. if config.trace.debug_log:
  410. self._setup_log_capture("debug.log", logging.DEBUG)
  411. if config.trace.info_log:
  412. self._setup_log_capture("info.log", logging.INFO)
  413. def _setup_log_capture(
  414. self,
  415. filename: str,
  416. level: int,
  417. ) -> None:
  418. log = logging.getLogger("torch._inductor")
  419. fd = self._stack.enter_context(self.fopen(filename))
  420. ch = logging.StreamHandler(fd)
  421. ch.setLevel(level)
  422. ch.setFormatter(
  423. logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
  424. )
  425. log.addHandler(ch)
  426. log.setLevel(min(log.level, level))
  427. self._stack.callback(log.removeHandler, ch)
  428. def __exit__(
  429. self,
  430. exc_type: Optional[type[BaseException]],
  431. exc_val: Optional[BaseException],
  432. exc_tb: Optional[Any],
  433. ) -> None:
  434. if self._prof:
  435. self._prof.disable()
  436. self._save_profile_data()
  437. if self._path:
  438. self.upload_tar()
  439. log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path)
  440. self._stack.close()
  441. def _save_profile_data(self) -> None:
  442. assert self._prof
  443. self._prof.dump_stats(self.filename("compile.prof"))
  444. with self.fopen("compile.stats") as fd:
  445. stats = pstats.Stats(self._prof, stream=fd)
  446. stats.strip_dirs()
  447. stats.sort_stats("cumtime")
  448. stats.print_stats(100)
  449. stats.sort_stats("tottime")
  450. stats.print_stats(100)
  451. def __getattr__(self, name: str) -> Optional[Callable[..., None]]:
  452. if config.trace.enabled and getattr(config.trace, name):
  453. try:
  454. return getattr(DebugFormatter(self), name)
  455. except Exception:
  456. log.warning("Ignoring exception in debug code", exc_info=True)
  457. return None
  458. else:
  459. def ignored(*args: Any, **kwargs: Any) -> None:
  460. pass
  461. return ignored
  462. class DebugFormatter:
  463. def __init__(self, handler: DebugContext) -> None:
  464. self.fopen = handler.fopen
  465. self.fopen_context = handler.fopen_context
  466. self.filename = handler.filename
  467. self.handler = handler
  468. def fx_graph(
  469. self,
  470. gm: torch.fx.GraphModule,
  471. inputs: list[torch.Tensor],
  472. ) -> None:
  473. with self.fopen("fx_graph_runnable.py") as fd:
  474. save_dir = None
  475. if torch._inductor.config.trace.save_real_tensors:
  476. inputs = torch._subclasses.fake_utils.try_convert_fake_to_real(inputs)
  477. save_dir = os.path.dirname(fd.name)
  478. # dont try to use stable hash torchinductor compilation if saving real tensors
  479. # and avoid recursively trying to save real tensors inside of the inductor compilation
  480. # regardless
  481. stable_hash = torch._inductor.config.trace.save_real_tensors
  482. with torch._inductor.config.patch(
  483. {"trace.enabled": False, "trace.save_real_tensors": False}
  484. ):
  485. save_graph_repro(
  486. fd,
  487. gm,
  488. inputs,
  489. "inductor",
  490. save_dir=save_dir,
  491. stable_hash=stable_hash,
  492. )
  493. with self.fopen("fx_graph_readable.py") as fd:
  494. fd.write(gm.print_readable(print_output=False))
  495. def fx_graph_transformed(
  496. self,
  497. gm: torch.fx.GraphModule,
  498. inputs: list[torch.Tensor],
  499. ) -> None:
  500. with self.fopen("fx_graph_transformed.py") as fd:
  501. fd.write(gm.print_readable(print_output=False))
  502. def ir_pre_fusion(self, nodes: SchedulerNodeList) -> None:
  503. with self.fopen("ir_pre_fusion.txt") as fd:
  504. fd.write(self._write_ir(nodes))
  505. def ir_post_fusion(self, nodes: SchedulerNodeList) -> None:
  506. with self.fopen("ir_post_fusion.txt") as fd:
  507. fd.write(self._write_ir(nodes))
  508. @staticmethod
  509. def _write_ir(nodes: SchedulerNodeList) -> str:
  510. buf = io.StringIO()
  511. for node in nodes:
  512. buf.write(node.debug_str())
  513. buf.write("\n\n\n")
  514. return buf.getvalue()
  515. def graph_diagram(self, nodes: SchedulerNodeList) -> None:
  516. draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))
  517. def draw_orig_fx_graph(
  518. self,
  519. gm: torch.fx.GraphModule,
  520. nodes: SchedulerNodeList,
  521. ) -> None:
  522. annotate_orig_fx_with_snodes(gm, nodes)
  523. draw_graph(
  524. gm,
  525. fname=self.filename("orig_fx_graph_diagram.svg"),
  526. clear_meta=False,
  527. prog=GRAPHVIZ_COMMAND_SCALABLE,
  528. parse_stack_trace=True,
  529. dot_graph_shape=config.trace.dot_graph_shape,
  530. )
  531. def output_code(self, filename: str, extension: str = "py") -> None:
  532. shutil.copy(filename, self.filename(f"output_code.{extension}"))
  533. def log_autotuning_results(
  534. self,
  535. name: str,
  536. input_nodes: list[ir.IRNode],
  537. timings: dict["ChoiceCaller", float], # type: ignore[name-defined] # noqa: F821
  538. elapse: float,
  539. precompile_elapse: float,
  540. prescreening_elapse: Optional[float],
  541. ) -> None:
  542. from .ir import FixedLayout
  543. def build_node_info(node: ir.IRNode) -> dict[str, str]:
  544. if hasattr(node, "name"):
  545. node_name = node.name
  546. else:
  547. node_name = ""
  548. node_info = {
  549. "name": node_name,
  550. "type": type(node).__name__,
  551. }
  552. try:
  553. layout = node.get_output_spec()
  554. if isinstance(layout, FixedLayout):
  555. static_layout = FixedLayout(
  556. layout.device,
  557. dtype=layout.dtype,
  558. size=V.graph.sizevars.optimization_hints(layout.size),
  559. stride=V.graph.sizevars.optimization_hints(layout.stride),
  560. offset=V.graph.sizevars.optimization_hint(
  561. layout.offset, fallback=0
  562. ),
  563. )
  564. node_info["layout"] = str(static_layout)
  565. else:
  566. node_info["layout"] = str(layout)
  567. except Exception:
  568. pass
  569. try:
  570. node_info["dtype"] = str(node.get_dtype())
  571. except Exception:
  572. pass
  573. try:
  574. node_info["device"] = str(node.get_device())
  575. except Exception:
  576. pass
  577. try:
  578. node_info["stride"] = str(
  579. V.graph.sizevars.optimization_hints(node.get_stride())
  580. )
  581. except Exception:
  582. pass
  583. try:
  584. node_info["size"] = str(
  585. V.graph.sizevars.optimization_hints(node.get_size())
  586. ) # type: ignore[arg-type]
  587. except Exception:
  588. pass
  589. try:
  590. node_info["numel"] = str(
  591. V.graph.sizevars.optimization_hint(node.get_numel())
  592. )
  593. except Exception:
  594. pass
  595. if hasattr(node, "data") and isinstance(node.data, ir.IRNode):
  596. node_info["data"] = build_node_info(node.data)
  597. return node_info
  598. general_properties = {
  599. "op_name": name,
  600. "cuda_device_name": torch.cuda.get_device_name(),
  601. "cuda_device_count": torch.cuda.device_count(),
  602. "input_nodes": [build_node_info(node) for node in input_nodes],
  603. "autotuning_time": elapse,
  604. "precompile_time": precompile_elapse,
  605. "prescreening_time": prescreening_elapse,
  606. }
  607. with self.fopen_context(
  608. "autotuning_result_json_list.txt", "at", encoding="utf-8"
  609. ) as fd:
  610. for caller, time in timings.items():
  611. info_dict = dict(caller.info_dict())
  612. info_dict.update(general_properties)
  613. info_dict["benchmark_result"] = time
  614. json.dump(info_dict, fd)
  615. fd.write("\n")
  616. def log_ir_pre_fusion(nodes: SchedulerNodeList) -> None:
  617. if ir_pre_fusion_log.isEnabledFor(logging.INFO):
  618. ir_pre_fusion_log.info("BEFORE FUSION\n%s", DebugFormatter._write_ir(nodes))
  619. V.debug.ir_pre_fusion(nodes)
  620. def log_ir_post_fusion(nodes: SchedulerNodeList) -> None:
  621. if ir_post_fusion_log.isEnabledFor(logging.INFO):
  622. ir_post_fusion_log.info("AFTER FUSION\n%s", DebugFormatter._write_ir(nodes))
  623. V.debug.ir_post_fusion(nodes)
  624. def _dump_collective_schedule(schedule: list[Union[str, None]]) -> None:
  625. try:
  626. trace_structured(
  627. "artifact",
  628. metadata_fn=lambda: {
  629. "name": "inductor_collective_schedule",
  630. "encoding": "json",
  631. },
  632. payload_fn=lambda: schedule,
  633. )
  634. except Exception:
  635. log.debug(
  636. "Failed to log inductor_collective_schedule via structured logging",
  637. exc_info=True,
  638. )
  639. def log_collective_schedule(nodes: Sequence[BaseSchedulerNode]) -> None:
  640. schedule = [
  641. getattr(op, "python_kernel_name", None)
  642. for node in nodes
  643. if isinstance(op := getattr(node, "node", None), ir._CollectiveKernel)
  644. ]
  645. # Only log when there is at least one collective op
  646. if schedule:
  647. _dump_collective_schedule(schedule)
  648. def log_runtime_and_tensor_meta(node_runtimes: Sequence[tuple[Any, float]]) -> None:
  649. """Log per-op runtime estimates and output tensor metadata for TLParse."""
  650. try:
  651. to_optimization_hints = V.graph.sizevars.optimization_hints
  652. def to_list(x: Optional[Sequence[Any]]) -> list[Any]:
  653. return list(to_optimization_hints(x)) if x is not None else []
  654. def dtype_to_str(dtype: Any) -> Optional[str]:
  655. if dtype is None:
  656. return None
  657. s = str(dtype)
  658. s = s.removeprefix("torch.")
  659. return s
  660. ops: list[dict[str, Any]] = []
  661. for s, runtime_ns in node_runtimes:
  662. name = getattr(s.node, "python_kernel_name", s.get_name())
  663. op_type = "collective" if utils.is_collective(s.node) else "compute"
  664. # Build outputs metadata if available
  665. outputs: list[dict[str, Any]] = []
  666. try:
  667. for buf in s.get_outputs():
  668. irnode = buf.node
  669. shape = irnode.maybe_get_size()
  670. stride = (
  671. irnode.get_stride()
  672. if isinstance(irnode.layout, ir.Layout)
  673. else None
  674. )
  675. dtype = irnode.maybe_get_dtype()
  676. outputs.append(
  677. {
  678. "shape": to_list(shape),
  679. "stride": to_list(stride),
  680. "dtype": dtype_to_str(dtype),
  681. }
  682. )
  683. except Exception:
  684. pass
  685. ops.append(
  686. {
  687. "name": name,
  688. "type": op_type,
  689. "estimated_runtime_ns": runtime_ns,
  690. "outputs": outputs,
  691. }
  692. )
  693. trace_structured(
  694. "artifact",
  695. metadata_fn=lambda: {
  696. "name": "inductor_runtime_and_tensor_meta",
  697. "encoding": "json",
  698. },
  699. payload_fn=lambda: {"ops": ops},
  700. )
  701. except Exception:
  702. log.debug("Failed to log inductor_runtime_and_tensor_meta", exc_info=True)
  703. def log_graph_execution() -> None:
  704. """Emit a structured artifact with the graph execution order."""
  705. if not GRAPH_EXECUTION_ORDER:
  706. return
  707. try:
  708. trace_structured(
  709. "artifact",
  710. metadata_fn=lambda: {
  711. "name": "graph_execution",
  712. "encoding": "json",
  713. },
  714. payload_fn=lambda: {"graph_execution_order": GRAPH_EXECUTION_ORDER},
  715. )
  716. except Exception:
  717. log.debug("Failed to log graph_execution", exc_info=True)
  718. @contextlib.contextmanager
  719. def record_and_log_graph_execution_order() -> Iterator[None]:
  720. """Record graph execution order and log it once on exit."""
  721. global RECORD_GRAPH_EXECUTION, GRAPH_EXECUTION_ORDER, GRAPH_COMPILE_IDS
  722. GRAPH_EXECUTION_ORDER = []
  723. GRAPH_COMPILE_IDS = {}
  724. RECORD_GRAPH_EXECUTION = True
  725. try:
  726. yield
  727. finally:
  728. log_graph_execution()
  729. RECORD_GRAPH_EXECUTION = False
  730. GRAPH_EXECUTION_ORDER = None
  731. GRAPH_COMPILE_IDS = None
@dataclasses.dataclass
class TensorMetadataHolder:
    """Pairs a ``TensorMetadata`` record with a ``torch.device``."""

    # Tensor metadata record (type imported from torch.fx.passes.shape_prop).
    tensor_metadata: TensorMetadata
    # Device stored alongside the metadata.
    device: torch.device
# Module-level counter yielding 0, 1, 2, ...; its consumer is not in this part
# of the file (presumably numbers saved-argument dumps — TODO confirm).
save_args_cnt = itertools.count()
  737. def create_mapping_pre_post_grad_nodes(
  738. pre_grad_graph_id: Optional[int],
  739. post_to_pre_grad_nodes_json: dict[str, Any],
  740. ) -> dict[str, dict[str, list[str]]]:
  741. """
  742. Create bidirectional mappings between pre_grad graph nodes
  743. and post_grad graph code nodes, and vice versa.
  744. """
  745. # return a dummy dict if there's any error
  746. empty_return: dict[str, dict[str, list[str]]] = {
  747. "preToPost": {},
  748. "postToPre": {},
  749. }
  750. if not isinstance(post_to_pre_grad_nodes_json, dict):
  751. log.error("Provenance tacking error: post_to_pre_grad_nodes_json is not a dict")
  752. return empty_return
  753. if not isinstance(pre_grad_graph_id, int):
  754. # pre_grad_graph_id may be empty if there's no pre_grad graph
  755. # and there's only a backward graph from backward pass engine
  756. return empty_return
  757. pre_to_post: dict[str, Any] = collections.defaultdict(OrderedSet)
  758. post_to_pre: dict[str, Any] = collections.defaultdict(OrderedSet)
  759. try:
  760. def check_format(node: dict[str, Any]) -> bool:
  761. if not isinstance(node, dict):
  762. log.error(
  763. "Provenance tacking error: node provenance in post_to_pre_grad_nodes_json is not a dict"
  764. )
  765. return False
  766. if "graph_id" not in node or "name" not in node or "from_node" not in node:
  767. log.error(
  768. "Provenance tacking error: node provenance in post_to_pre_grad_nodes_json has wrong format"
  769. )
  770. return False
  771. return True
  772. for outer_key, node_array in post_to_pre_grad_nodes_json.items():
  773. if not isinstance(node_array, list):
  774. log.error(
  775. "Provenance tacking error: post_to_pre_grad_nodes_json value is not a list"
  776. )
  777. return empty_return
  778. for node in node_array:
  779. if not check_format(node):
  780. return empty_return
  781. # Check the current node first
  782. if node.get("graph_id") == pre_grad_graph_id:
  783. pre_to_post[node["name"]].add(outer_key)
  784. post_to_pre[outer_key].add(node["name"])
  785. # Check nested from_node array recursively, add node with the right graph_id to the map
  786. stack = [(n, outer_key) for n in node.get("from_node", [])]
  787. while stack:
  788. current_node, parent_key = stack.pop()
  789. if not check_format(current_node):
  790. return empty_return
  791. if current_node.get("graph_id") == pre_grad_graph_id:
  792. pre_to_post[current_node["name"]].add(parent_key)
  793. post_to_pre[parent_key].add(current_node["name"])
  794. stack.extend(
  795. (n, parent_key) for n in current_node.get("from_node", [])
  796. )
  797. def convert_sets_to_lists(d: dict[str, Any]) -> None:
  798. for key in d:
  799. d[key] = list(d[key])
  800. d = dict(d)
  801. # convert to list because set is not JSON serializable
  802. convert_sets_to_lists(pre_to_post)
  803. convert_sets_to_lists(post_to_pre)
  804. return {
  805. "preToPost": pre_to_post,
  806. "postToPre": post_to_pre,
  807. }
  808. except Exception as e:
  809. # Since this is just logging code, it should never interfere with regular
  810. # program execution, so we use this try-except to guard against any error
  811. signpost_event(
  812. "inductor",
  813. "provenance_tracking_error",
  814. {
  815. "function": "create_mapping_pre_post_grad_nodes",
  816. "error_msg": str(e),
  817. "stack_trace": traceback.format_exc(),
  818. },
  819. )
  820. log.error("post_to_pre_grad_nodes_json: %s", post_to_pre_grad_nodes_json)
  821. log.error("pre_grad_graph_id: %s", pre_grad_graph_id)
  822. return empty_return
  823. def create_node_mapping_kernel_to_post_grad(
  824. triton_kernel_to_post_grad_json: dict[str, Any],
  825. ) -> dict[str, dict[str, Any]]:
  826. """Create bidirectional mappings between triton kernel name and post_grad
  827. graph code nodes, and vice versa.
  828. """
  829. # return a dummy dict if there's any error
  830. empty_return: dict[str, dict[str, Any]] = {
  831. "cppCodeToPost": {},
  832. "postToCppCode": {},
  833. }
  834. if not isinstance(triton_kernel_to_post_grad_json, dict):
  835. log.error(
  836. "Provenance tacking error: triton_kernel_to_post_grad_json is not a dict"
  837. )
  838. return empty_return
  839. post_to_cpp_code: dict[str, Any] = collections.defaultdict(OrderedSet)
  840. try:
  841. for outer_key, node_array in triton_kernel_to_post_grad_json.items():
  842. if not isinstance(node_array, list):
  843. log.error(
  844. "Provenance tacking error: triton_kernel_to_post_grad_json value is not a list"
  845. )
  846. return empty_return
  847. for curr_node in node_array:
  848. post_to_cpp_code[curr_node].add(outer_key)
  849. def convert_sets_to_lists(d: dict[str, Any]) -> None:
  850. for key in d:
  851. d[key] = list(d[key])
  852. d = dict(d)
  853. # convert to list because set is not JSON serializable
  854. convert_sets_to_lists(post_to_cpp_code)
  855. return {
  856. "cppCodeToPost": triton_kernel_to_post_grad_json,
  857. "postToCppCode": post_to_cpp_code,
  858. }
  859. except Exception as e:
  860. # Since this is just logging code, it should never interfere with regular
  861. # program execution, so we use this try-except to guard against any error
  862. signpost_event(
  863. "inductor",
  864. "provenance_tracking_error",
  865. {
  866. "function": "create_mapping_kernel_to_post_grad",
  867. "error_msg": str(e),
  868. "stack_trace": traceback.format_exc(),
  869. },
  870. )
  871. log.error(
  872. "triton_kernel_to_post_grad_json: %s", triton_kernel_to_post_grad_json
  873. )
  874. return empty_return
def dump_inductor_provenance_info() -> dict[str, Any]:
    """Assemble the provenance node mappings and return them.

    The result merges ``_inductor_post_to_pre_grad_nodes`` with the
    kernel <-> post_grad mapping built by
    ``create_node_mapping_kernel_to_post_grad`` and carries a "version" entry.
    When ``config.trace.enabled`` is set, the merged mapping is also written to
    ``inductor_provenance_tracking_node_mappings.json`` through
    ``V.debug.fopen``. Any exception is reported through ``signpost_event`` and
    an empty dict is returned instead.
    """
    # NOTE(review): this block's indentation was reconstructed from a flattened
    # listing. Confirm against the real file that the trace dump sits inside
    # the `_pre_grad_graph_id` branch and that the version is set outside it.
    try:
        global _pre_grad_graph_id
        global _inductor_post_to_pre_grad_nodes
        global _inductor_triton_kernel_to_post_grad_node_info

        node_mapping: dict[str, Any] = {}
        # Only build the mappings when a pre_grad graph id has been recorded.
        if _pre_grad_graph_id:
            node_mapping_kernel = create_node_mapping_kernel_to_post_grad(
                _inductor_triton_kernel_to_post_grad_node_info
            )
            node_mapping = {
                **_inductor_post_to_pre_grad_nodes,
                **node_mapping_kernel,
            }
            if config.trace.enabled:
                # The file is written before "version" is added below, so the
                # dumped JSON does not contain the version entry.
                with V.debug.fopen(
                    "inductor_provenance_tracking_node_mappings.json", "w"
                ) as fd:
                    json.dump(node_mapping, fd)
        # we need to update the node mapping version when node mapping format changes
        # so the tlparse tool knows which node mapping version it is looking at
        node_mapping["version"] = 2.0
        return node_mapping
    except Exception as e:
        # Since this is just debugging, it should never interfere with regular
        # program execution, so we use this try-except to guard against any error
        signpost_event(
            "inductor",
            "provenance_tracking_error",
            {
                "function": "dump_inductor_provenance_info",
                "error_msg": str(e),
                "stack_trace": traceback.format_exc(),
            },
        )
        return {}
  911. def create_kernel_information_json() -> dict[str, dict[str, list[str]]]:
  912. """Create kernel information JSON"""
  913. try:
  914. global _inductor_post_to_pre_grad_nodes
  915. global _inductor_kernel_stack_trace
  916. global _inductor_triton_kernel_to_post_grad_node_info
  917. post_to_pre = _inductor_post_to_pre_grad_nodes.get("postToPre", {})
  918. all_kernels = OrderedSet(_inductor_kernel_stack_trace.keys()) | OrderedSet(
  919. _inductor_triton_kernel_to_post_grad_node_info.keys()
  920. )
  921. result = {}
  922. for kernel_name in all_kernels:
  923. post_grad_nodes = _inductor_triton_kernel_to_post_grad_node_info.get(
  924. kernel_name, []
  925. )
  926. pre_grad_nodes: OrderedSet[str] = OrderedSet()
  927. for post_node in post_grad_nodes:
  928. pre_grad_nodes.update(post_to_pre.get(post_node, []))
  929. result[kernel_name] = {
  930. "stack_traces": _inductor_kernel_stack_trace.get(kernel_name, []),
  931. "post_grad_nodes": post_grad_nodes,
  932. "pre_grad_nodes": list(pre_grad_nodes),
  933. }
  934. return result
  935. except Exception as e:
  936. signpost_event(
  937. "inductor",
  938. "provenance_tracking_error",
  939. {
  940. "function": "create_kernel_information_json",
  941. "error_msg": str(e),
  942. "stack_trace": traceback.format_exc(),
  943. },
  944. )
  945. return {}
  946. def set_kernel_post_grad_provenance_tracing(
  947. node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
  948. kernel_name: str,
  949. is_extern: bool = False,
  950. ) -> Optional[int]:
  951. """
  952. Set the mapping between `kernel_name` and the post_grad nodes in `node_schedule`.
  953. Returns a unique int debug handler for each call to this function.
  954. """
  955. if config.trace.provenance_tracking_level == 0:
  956. return None
  957. try:
  958. from .codegen.simd_kernel_features import DisableReduction, EnableReduction
  959. global _inductor_triton_kernel_to_post_grad_node_info
  960. global _inductor_kernel_stack_trace
  961. global _inductor_kernel_provenance_debug_handle
  962. _inductor_kernel_provenance_debug_handle += 1
  963. stack_traces: list[str] = []
  964. kernel_name = f"{kernel_name}:{_inductor_kernel_provenance_debug_handle}"
  965. if is_extern:
  966. assert isinstance(node_schedule, ExternKernel)
  967. curr_node_info = _inductor_triton_kernel_to_post_grad_node_info.setdefault(
  968. kernel_name, []
  969. )
  970. # 'origins' on IR nodes gives what FX IR nodes contributed to any given fused kernel.
  971. # "origin_node" is more precise and says that the contents of this node corresponds
  972. # EXACTLY to the output of a particular FX node, but it's not always available
  973. if node_schedule.origin_node:
  974. origin_node_name = node_schedule.origin_node.name
  975. if origin_node_name not in curr_node_info:
  976. curr_node_info.append(origin_node_name)
  977. else:
  978. curr_node_info.extend(
  979. origin.name
  980. for origin in node_schedule.origins
  981. if origin.name not in curr_node_info
  982. )
  983. stack_traces = list(node_schedule.get_stack_traces())
  984. else:
  985. assert isinstance(node_schedule, list)
  986. stack_traces_set: OrderedSet[str] = OrderedSet()
  987. for snode in node_schedule:
  988. if snode not in (EnableReduction, DisableReduction):
  989. if snode.node is not None:
  990. curr_node_info = (
  991. _inductor_triton_kernel_to_post_grad_node_info.setdefault(
  992. kernel_name, []
  993. )
  994. )
  995. # pyrefly: ignore [missing-attribute]
  996. stack_traces_set.update(snode.node.get_stack_traces())
  997. curr_node_info.extend(
  998. origin.name
  999. # pyrefly: ignore [missing-attribute]
  1000. for origin in snode.node.origins
  1001. if origin.name not in curr_node_info
  1002. )
  1003. stack_traces = list(stack_traces_set)
  1004. _inductor_kernel_stack_trace.setdefault(kernel_name, []).extend(stack_traces)
  1005. return _inductor_kernel_provenance_debug_handle
  1006. except Exception as e:
  1007. # Since this is just debugging, it should never interfere with regular
  1008. # program execution, so we use this try-except to guard against any error
  1009. signpost_event(
  1010. "inductor",
  1011. "provenance_tracking_error",
  1012. {
  1013. "function": "set_kernel_post_grad_provenance_tracing",
  1014. "error_msg": str(e),
  1015. "stack_trace": traceback.format_exc(),
  1016. },
  1017. )
  1018. return None
  1019. def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None:
  1020. """
  1021. This function is used to save arguments for a compile_fx_inner function call
  1022. to the file system. Later on one can replay the compile_fx_inner call
  1023. with the saved arguments using load_args_and_run_compile_fx_inner.
  1024. """
  1025. folder = os.path.join(tempfile.gettempdir(), "inductor_saved_args")
  1026. if not os.path.exists(folder):
  1027. os.mkdir(folder)
  1028. def handle_tensor(x: Any) -> Any:
  1029. """
  1030. Pickle FakeTensor will result in error:
  1031. AttributeError: Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'
  1032. Convert all Tensor to metadata. This may also makes pickle faster.
  1033. """
  1034. if isinstance(x, torch.Tensor):
  1035. return TensorMetadataHolder(_extract_tensor_metadata(x), x.device)
  1036. else:
  1037. return x
  1038. args_to_save, kwargs_to_save = tree_map(handle_tensor, (args, kwargs))
  1039. fn_name = "compile_fx_inner"
  1040. path = f"{folder}/{fn_name}_{next(save_args_cnt)}.pkl"
  1041. with open(path, "wb") as f:
  1042. pickle.dump((args_to_save, kwargs_to_save), f)
  1043. if log.isEnabledFor(logging.DEBUG):
  1044. message = f"""
  1045. Arguments for a compile_fx_inner call is saved to {path}. To replay the call,
  1046. run the following:
  1047. from torch._inductor.debug import load_args_and_run_compile_fx_inner
  1048. load_args_and_run_compile_fx_inner({path!r})
  1049. """
  1050. # call print rather than log.debug. log.debug will print message
  1051. # prefix for each line which makes the code snippet harder to be
  1052. # copied.
  1053. # Not a big deal since the code is already been guarded by checking
  1054. # the log level.
  1055. print(message)
  1056. def load_args_and_run_compile_fx_inner(path: str) -> Any:
  1057. from torch._inductor.compile_fx import compile_fx_inner
  1058. with open(path, "rb") as f:
  1059. args, kwargs = pickle.load(f)
  1060. def handle_tensor(x: Any) -> Any:
  1061. if isinstance(x, TensorMetadataHolder):
  1062. return torch._dynamo.testing.rand_strided(
  1063. x.tensor_metadata.shape,
  1064. x.tensor_metadata.stride,
  1065. x.tensor_metadata.dtype,
  1066. x.device,
  1067. )
  1068. else:
  1069. return x
  1070. fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
  1071. with fake_mode, config.patch("save_args", False):
  1072. args, kwargs = tree_map(handle_tensor, (args, kwargs))
  1073. return compile_fx_inner(*args, **kwargs)
  1074. def aot_inductor_minifier_wrapper(
  1075. func: Callable[..., str],
  1076. exported_program: torch.export.ExportedProgram,
  1077. *,
  1078. inductor_configs: dict[str, Any],
  1079. package_path: Optional[FileLike] = None,
  1080. ) -> str:
  1081. from torch._dynamo.debug_utils import AccuracyError
  1082. from torch._dynamo.repro.aoti import dump_to_minify
  1083. from torch._inductor import config
  1084. from torch._inductor.compile_fx import _aoti_flatten_inputs
  1085. use_minifier = config.aot_inductor.dump_aoti_minifier
  1086. gm = exported_program.module(check_guards=False)
  1087. assert isinstance(gm, torch.fx.GraphModule)
  1088. args, kwargs = exported_program.example_inputs
  1089. try:
  1090. if use_minifier and config.aot_inductor.repro_level == 3:
  1091. # Always dump the original module in case we have segfaults
  1092. dump_to_minify(
  1093. exported_program,
  1094. "aot_inductor",
  1095. options=inductor_configs,
  1096. )
  1097. if use_minifier and config.aot_inductor.repro_level == 4:
  1098. # Check for accuracy
  1099. # We will first flatten the inputs before compiling and checking for accuracy.
  1100. # This is ok because we will flatten the inputs in the minifier anyway.
  1101. gm_copy = copy.deepcopy(gm)
  1102. example_inputs_copy = copy.deepcopy(exported_program.example_inputs)
  1103. config_copy = copy.deepcopy(inductor_configs)
  1104. flat_example_inputs, config_copy = _aoti_flatten_inputs(
  1105. gm_copy,
  1106. example_inputs_copy[0],
  1107. example_inputs_copy[1],
  1108. options=config_copy,
  1109. )
  1110. tuple_inputs = tuple(flat_example_inputs)
  1111. flattened_ep = torch.export.export(gm_copy, tuple_inputs, strict=False)
  1112. func(
  1113. flattened_ep.module(check_guards=False),
  1114. tuple_inputs,
  1115. inductor_configs=config_copy,
  1116. package_path=package_path,
  1117. load_and_run=True,
  1118. check_accuracy="accuracy",
  1119. )
  1120. return func(
  1121. gm,
  1122. args,
  1123. kwargs,
  1124. inductor_configs=inductor_configs,
  1125. package_path=package_path,
  1126. load_and_run=use_minifier,
  1127. )
  1128. except AccuracyError as e:
  1129. dump_to_minify(
  1130. exported_program,
  1131. "aot_inductor_accuracy",
  1132. command="minify",
  1133. options=inductor_configs,
  1134. )
  1135. log.warning("Accuracy failed")
  1136. raise e
  1137. except Exception as e:
  1138. if use_minifier:
  1139. command = "minify"
  1140. if config.aot_inductor.repro_level == 1:
  1141. command = "run"
  1142. dump_to_minify(
  1143. exported_program,
  1144. "aot_inductor",
  1145. command=command,
  1146. options=inductor_configs,
  1147. )
  1148. raise e