aoti.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662
  1. """
  2. Utilities for debugging and reproducing issues in Ahead of Time with Inductor (AOTI) compilation.
  3. This file provides tools and utilities for:
  4. - Generating minimal reproducible test cases (minification)
  5. - Handling exported programs and graph modules
  6. - Creating debug repros for AOTI compilation issues
  7. - Supporting both accuracy testing and error reproduction
  8. - Managing configuration and environment for repro cases
  9. The main components include:
  10. - Minification tools to reduce test cases while preserving errors
  11. - Repro generation utilities for exported programs
  12. - Error handling specific to AOTI compilation
  13. - Command-line interface for running and managing repros
  14. """
  15. import argparse
  16. import functools
  17. import io
  18. import logging
  19. import os
  20. import re
  21. import shutil
  22. import sys
  23. import textwrap
  24. from collections.abc import Sequence
  25. from importlib import import_module
  26. from typing import Any, IO, Optional, Union
  27. import torch
  28. from torch._dynamo.debug_utils import (
  29. _cuda_system_info_comment,
  30. BuckTargetWriter,
  31. extra_imports,
  32. generate_config_string,
  33. generate_env_vars_string,
  34. helper_for_dump_minify,
  35. InputReader,
  36. minifier_dir,
  37. NNModuleToString,
  38. NopInputReader,
  39. )
  40. from torch.export import ExportedProgram
  41. from torch.hub import tqdm
# Module-level logger used for the warnings emitted by the dump/repro helpers below.
log = logging.getLogger(__name__)
# The `torch._inductor.config` module object, loaded by its dotted name.
# NOTE(review): import_module is presumably used to sidestep import-order issues
# with `torch._inductor`; confirm before replacing it with a plain import.
inductor_config = import_module("torch._inductor.config")
# When truthy, the dump helpers additionally write a Buck target for each repro
# file via BuckTargetWriter (presumably only meaningful in fbcode builds).
use_buck = inductor_config.is_fbcode()
  45. class AOTIMinifierError(Exception):
  46. def __init__(self, original_exception: Union[str, Exception]) -> None:
  47. additional_message = "This error is caused by a bug in the AOTI minifier, please report a bug to PyTorch"
  48. full_message = f"{additional_message}: {str(original_exception)}"
  49. super().__init__(full_message)
  50. self.original_exception = original_exception
  51. def dump_to_minify(
  52. exported_program: ExportedProgram,
  53. compiler_name: str,
  54. command: str = "minify",
  55. options: Optional[dict[str, Any]] = None,
  56. ) -> None:
  57. """
  58. If command is "minify":
  59. Dump exported_program to `debug_dir/minifier/minifier_launcher.py`, with minify command.
  60. If command is "run":
  61. Dump exported_program to `cwd/repro.py`, with run command.
  62. """
  63. assert command in ["minify", "run"]
  64. subdir = os.path.join(minifier_dir(), "checkpoints")
  65. if not os.path.exists(subdir):
  66. os.makedirs(subdir, exist_ok=True)
  67. if command == "minify":
  68. out = io.StringIO()
  69. save_graph_repro_ep(
  70. out,
  71. compiler_name,
  72. exported_program=exported_program,
  73. save_dir=subdir,
  74. command="minify",
  75. config_patches=options,
  76. )
  77. return helper_for_dump_minify(out.getvalue())
  78. else:
  79. curdir = os.getcwd()
  80. file_name = os.path.join(curdir, "repro.py")
  81. try:
  82. with open(file_name, "w") as fd:
  83. save_graph_repro_ep(
  84. fd,
  85. compiler_name,
  86. exported_program=exported_program,
  87. config_patches=options,
  88. save_dir=subdir,
  89. command="run",
  90. module_in_comment=True,
  91. )
  92. log.warning("Writing repro file to %s", file_name)
  93. if use_buck:
  94. BuckTargetWriter(file_name).write()
  95. except OSError:
  96. log.warning("No write permissions for %s", file_name)
  97. def get_module_string(gm: torch.fx.GraphModule) -> str:
  98. def _convert_to_comment(s_: str) -> str:
  99. s = s_.split("\n")
  100. if len(s) == 1:
  101. return "# " + s_
  102. first = s.pop(0)
  103. for i in range(len(s)):
  104. line = s[i]
  105. if line.strip() != "":
  106. s[i] = "# " + line
  107. else:
  108. s[i] = ""
  109. s = "\n".join(s)
  110. s = first + "\n" + s
  111. return s
  112. module_string = NNModuleToString.convert(gm)
  113. return _convert_to_comment(module_string)
  114. def save_graph_repro_ep(
  115. fd: IO[Any],
  116. compiler_name: str,
  117. *,
  118. exported_program: Optional[ExportedProgram] = None,
  119. gm: Optional[torch.nn.Module] = None,
  120. args: Optional[tuple[Any]] = None,
  121. config_patches: Optional[dict[str, str]] = None,
  122. stable_output: bool = False,
  123. save_dir: Optional[str] = None,
  124. command: str = "run",
  125. accuracy: Optional[Union[str, bool]] = None,
  126. check_str: Optional[str] = None,
  127. module_in_comment: bool = False,
  128. strict: bool = False,
  129. ) -> None:
  130. # Save graph for reproducing the error.
  131. # Either exported_program or gm will be saved, depending on which one is defined.
  132. # Only one of exported_program and gm should be defined.
  133. if exported_program is None and gm is None:
  134. raise AOTIMinifierError("One of exported_program and gm must be defined")
  135. if exported_program is not None and gm is not None:
  136. raise AOTIMinifierError("Only one of exported_program and gm can be defined")
  137. if gm is not None and args is None:
  138. raise AOTIMinifierError("If gm is defined, args should also be defined")
  139. if exported_program is None:
  140. assert gm is not None
  141. assert args is not None
  142. exported_program = torch.export.export(gm, args, strict=strict)
  143. elif gm is None:
  144. gm = exported_program.module(check_guards=False)
  145. # save a graph preview using gm
  146. module_string = get_module_string(gm) # type: ignore[arg-type]
  147. fd.write(module_string)
  148. # save a graph repro using exported_program
  149. fd.write(
  150. generate_compiler_repro_exported_program(
  151. exported_program,
  152. options=config_patches,
  153. stable_output=stable_output,
  154. save_dir=save_dir,
  155. )
  156. )
  157. if accuracy is None:
  158. accuracy = "_accuracy" in compiler_name
  159. fd.write("if __name__ == '__main__':\n")
  160. fd.write(" from torch._dynamo.repro.aoti import run_repro\n")
  161. fd.write(
  162. f" with torch.no_grad():\n"
  163. f" run_repro(exported_program, config_patches=config_patches, accuracy={accuracy!r}, command={command!r}, "
  164. f"save_dir={save_dir!r}, check_str={check_str!r})\n"
  165. )
  166. def dump_compiler_graph_state(
  167. gm: torch.fx.GraphModule,
  168. args: Sequence[Any],
  169. compiler_name: str,
  170. *,
  171. config_patches: Optional[dict[str, str]] = None,
  172. accuracy: Optional[Union[str, bool]] = None,
  173. strict: bool = False,
  174. ) -> None:
  175. subdir = os.path.join(minifier_dir(), "checkpoints")
  176. if not os.path.exists(subdir):
  177. os.makedirs(subdir, exist_ok=True)
  178. file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py")
  179. log.warning(
  180. "Writing checkpoint with %s nodes to %s", len(gm.graph.nodes), file_name
  181. )
  182. with open(file_name, "w") as fd:
  183. save_graph_repro_ep(
  184. fd,
  185. compiler_name,
  186. gm=gm,
  187. args=tuple(args),
  188. config_patches=config_patches,
  189. save_dir=subdir,
  190. accuracy=accuracy,
  191. module_in_comment=True,
  192. strict=strict,
  193. )
  194. curdir = os.getcwd()
  195. repro_path = os.path.join(curdir, "repro.py")
  196. try:
  197. shutil.copyfile(file_name, repro_path)
  198. log.warning("Copying repro file for convenience to %s", repro_path)
  199. if use_buck:
  200. BuckTargetWriter(file_name).write()
  201. except OSError:
  202. log.warning("No write permissions for %s", repro_path)
  203. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  204. # DUMP REPROS
  205. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  206. def generate_compiler_repro_exported_program(
  207. exported_program: ExportedProgram,
  208. *,
  209. options: Optional[dict[str, str]] = None,
  210. stable_output: bool = False,
  211. save_dir: Optional[str] = None,
  212. ) -> str:
  213. model_str = textwrap.dedent(
  214. f"""
  215. {generate_env_vars_string(stable_output=stable_output)}
  216. import torch
  217. import torch._inductor.inductor_prims
  218. {generate_config_string(stable_output=stable_output)}
  219. isolate_fails_code_str = None
  220. {extra_imports}
  221. """
  222. )
  223. if not stable_output:
  224. model_str += f"# torch version: {torch.version.__version__}\n"
  225. if hasattr(torch.version, "cuda"):
  226. model_str += f"# torch cuda version: {torch.version.cuda}\n"
  227. if hasattr(torch.version, "git_version"):
  228. model_str += f"# torch git version: {torch.version.git_version}\n\n\n"
  229. model_str += _cuda_system_info_comment()
  230. if save_dir:
  231. ep_path = os.path.join(save_dir, "exported_program.pt2")
  232. else:
  233. ep_path = "exported_program.pt2"
  234. torch.export.save(exported_program, ep_path)
  235. model_str += f"exported_program = torch.export.load('{ep_path}')\n"
  236. model_str += "# print(exported_program.graph)\n"
  237. model_str += f"config_patches={options}\n"
  238. return model_str
  239. def repro_load_args(load_args: Any, save_dir: Optional[str]) -> tuple[Any]:
  240. if not hasattr(load_args, "_version"):
  241. log.warning(
  242. "load_args does not have a _version attribute, please file a bug to PyTorch "
  243. "and describe how you generate this repro script"
  244. )
  245. else:
  246. if load_args._version > 0:
  247. log.warning(
  248. "load_args is version %s, but this version of PyTorch only supports "
  249. "version 0. We will try to run it anyway but there may be an incompatibility; "
  250. "if so, try upgrading your version of PyTorch.",
  251. load_args._version,
  252. )
  253. nop_reader = NopInputReader()
  254. load_args(nop_reader)
  255. with tqdm(desc="Loading inputs", total=nop_reader.total) as pbar:
  256. input_reader = InputReader(save_dir=save_dir, pbar=pbar)
  257. load_args(input_reader)
  258. args = input_reader.args
  259. return tuple(args)
  260. def repro_common(
  261. options: Any, exported_program: ExportedProgram
  262. ) -> tuple[torch.fx.GraphModule, Any, Any]:
  263. # pyrefly: ignore [bad-assignment]
  264. torch._inductor.config.generate_intermediate_hooks = True
  265. mod = exported_program.module(check_guards=False)
  266. args, kwargs = exported_program.example_inputs
  267. return mod, args, kwargs # type: ignore[return-value]
  268. def repro_get_args(
  269. options: Any,
  270. exported_program: ExportedProgram,
  271. config_patches: Optional[dict[str, Any]],
  272. ) -> tuple[torch.fx.GraphModule, Any, Any]:
  273. mod, args, kwargs = repro_common(options, exported_program)
  274. return mod, args, kwargs
  275. def repro_run(
  276. options: Any,
  277. exported_program: ExportedProgram,
  278. config_patches: Optional[dict[str, Any]],
  279. ) -> None:
  280. from torch._inductor import _aoti_compile_and_package_inner
  281. gm, args, kwargs = repro_common(options, exported_program)
  282. from torch.cuda import synchronize
  283. _aoti_compile_and_package_inner(
  284. gm,
  285. args,
  286. kwargs,
  287. load_and_run=True,
  288. check_accuracy=options.accuracy,
  289. inductor_configs=config_patches,
  290. )
  291. need_sync = False
  292. for arg in args:
  293. if isinstance(arg, torch.Tensor) and arg.is_cuda:
  294. need_sync = True
  295. break
  296. if need_sync:
  297. synchronize() # ensure segfaults are surfaced
  298. def export_for_aoti_minifier(
  299. gm: torch.nn.Module,
  300. tuple_inputs: tuple[Any],
  301. strict: bool = False,
  302. skip_export_error: bool = True,
  303. ) -> Optional[torch.nn.Module]:
  304. # Some graphs cannot be used for AOTI/export (illegal graphs), these should be
  305. # considered as graphs that don't fail in the minifier, so the minifier keeps searching.
  306. # In these case, we return None. Otherwise, we return the exported graph module.
  307. # This won't affect the minifier result because the minifier is only responsible for catching
  308. # errors in AOTI, not export.
  309. #
  310. # Please add to this list of illegal graphs if you change the implementation here.
  311. # - graph output is not allowed by export
  312. #
  313. # If skip_export_error=True, then the errors in export will not be raised, and the minifier
  314. # will keep exploring and ignore this graph.
  315. from torch._dynamo.exc import UserError, UserErrorType
  316. try:
  317. ep = torch.export.export(gm, tuple_inputs, strict=strict)
  318. gm = ep.module(check_guards=False)
  319. return gm
  320. except Exception as e:
  321. if skip_export_error:
  322. return None
  323. if isinstance(e, UserError) and e.error_type == UserErrorType.INVALID_OUTPUT:
  324. # graph output is not allowed by export when strict=True
  325. return None
  326. if isinstance(e, RuntimeError):
  327. # graph output is not allowed by export when strict=False
  328. pattern = r"Found .* in output, which is not a known type\."
  329. if re.search(pattern, str(e)) is not None:
  330. return None
  331. raise AOTIMinifierError(e) from e
  332. # we should never reach here
  333. # pyrefly: ignore [unreachable]
  334. return None
  335. def repro_minify(
  336. options: Any,
  337. exported_program: ExportedProgram,
  338. config_patches: Optional[dict[str, Any]],
  339. ) -> None:
  340. from functorch.compile import minifier
  341. from torch._inductor import _aoti_compile_and_package_inner
  342. from torch._inductor.compile_fx import _aoti_flatten_inputs
  343. mod, args, kwargs = repro_common(options, exported_program)
  344. # update serialized_in_spec and serialized_out_spec
  345. flat_example_inputs, inductor_configs = _aoti_flatten_inputs(
  346. mod, args, kwargs, options=config_patches
  347. )
  348. compiler_name = "aot_inductor"
  349. assert options.minifier_export_mode in ["dynamo", "python"]
  350. strict = options.minifier_export_mode == "dynamo"
  351. skip_export_error = options.skip_export_error
  352. from torch.cuda import synchronize
  353. need_sync = False
  354. for arg in args:
  355. if isinstance(arg, torch.Tensor) and arg.is_cuda:
  356. need_sync = True
  357. break
  358. def module_fails(
  359. gm: torch.fx.GraphModule,
  360. flat_example_inputs: list[Any],
  361. check_str: Optional[str] = None,
  362. ) -> bool:
  363. # Need to export first so the in_spec and out_spec are populated
  364. tuple_inputs = tuple(flat_example_inputs)
  365. # pyrefly: ignore [bad-assignment]
  366. gm = export_for_aoti_minifier(
  367. gm, tuple_inputs, strict=strict, skip_export_error=skip_export_error
  368. )
  369. # Some graphs cannot be used for AOTI/export (illegal graphs), these should be
  370. # considered as graphs that don't fail in the minifier, so the minifier keeps searching.
  371. if gm is None:
  372. return False
  373. assert isinstance(gm, torch.fx.GraphModule)
  374. try:
  375. _aoti_compile_and_package_inner(
  376. gm,
  377. tuple_inputs,
  378. load_and_run=True,
  379. check_accuracy=options.accuracy,
  380. inductor_configs=inductor_configs,
  381. )
  382. if need_sync:
  383. synchronize() # ensure segfaults are surfaced
  384. return False
  385. except Exception as e:
  386. if check_str is not None and check_str not in repr(e):
  387. return False
  388. return True
  389. minifier(
  390. mod,
  391. flat_example_inputs,
  392. module_fails=functools.partial(module_fails, check_str=options.check_str),
  393. dump_state=functools.partial(
  394. dump_compiler_graph_state,
  395. compiler_name=compiler_name,
  396. config_patches=config_patches,
  397. accuracy=options.accuracy,
  398. strict=strict,
  399. ),
  400. save_dir=options.save_dir,
  401. offload_to_disk=options.offload_to_disk,
  402. skip_offload=options.skip_saving_eager_intermediates,
  403. skip_sanity=options.skip_sanity,
  404. max_granularity=options.max_granularity,
  405. )
  406. def run_repro(
  407. exported_program: ExportedProgram,
  408. *,
  409. config_patches: Optional[dict[str, str]] = None,
  410. command: str = "run",
  411. accuracy: Union[bool, str] = "",
  412. save_dir: Optional[str] = None,
  413. tracing_mode: Optional[str] = None,
  414. check_str: Optional[str] = None,
  415. minifier_export_mode: str = "python",
  416. skip_export_error: bool = True,
  417. **more_kwargs: Any,
  418. ) -> Any:
  419. for k in more_kwargs:
  420. log.warning(
  421. "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch",
  422. k,
  423. )
  424. if accuracy is True:
  425. accuracy = "accuracy"
  426. elif accuracy is False:
  427. accuracy = ""
  428. parser = argparse.ArgumentParser(
  429. description=f"""\
  430. An AOTI repro script, typically triggering a bug in PyTorch AOTInductor.
  431. When run with no arguments, this script defaults to running '{command}'.
  432. Extra flags may be available; to find out more, try '{command} --help'.
  433. There are also alternate subcommands available, see below.
  434. default settings on this script:
  435. {accuracy=}
  436. {tracing_mode=}
  437. {save_dir=}
  438. {check_str=}
  439. """,
  440. formatter_class=argparse.RawTextHelpFormatter,
  441. )
  442. def common_flags(parser: argparse.ArgumentParser) -> None:
  443. accuracy_group = parser.add_mutually_exclusive_group()
  444. accuracy_group.add_argument(
  445. "--no-accuracy",
  446. dest="accuracy",
  447. action="store_const",
  448. const="",
  449. default=accuracy,
  450. help="do not test accuracy, just run the module and see if it errors",
  451. )
  452. accuracy_group.add_argument(
  453. "--accuracy",
  454. action="store_const",
  455. const="accuracy",
  456. default=accuracy,
  457. help="""\
  458. test if the RMSE between the compiled module and the fp64 reference is greater
  459. than eager and the fp64 reference. This is usually more reliable than the
  460. standard allclose test, as we expect numeric differences from compiling, often
  461. improving accuracy over eager. RMSE test allows for compiled module to
  462. diverge greatly from eager, as long as this divergence moves it closer to the
  463. 'true' mathematical value of the network. Caveats: (1) double precision can
  464. still suffer from rounding error, so it is not a perfect reference (see for
  465. example 'Herbie: Automatically Improving Floating Point Accuracy') for
  466. approaches that detect the necessary working precision and compute it in
  467. arbitrary precision floating point; unfortunately, this is not practical for
  468. tensor computation; (2) if there are not enough samples in the output being
  469. compared, we may get unlucky and have an unlucky greater RMSE than eager; this
  470. could be overcome by applying a more rigorous statistical test at some
  471. p-value, which we leave for future work.
  472. """,
  473. )
  474. accuracy_group.add_argument(
  475. "--strict-accuracy",
  476. dest="accuracy",
  477. action="store_const",
  478. const="strict_accuracy",
  479. default=accuracy,
  480. help="""\
  481. by default, when doing accuracy minification we will reject reductions which
  482. change the divergence from a floating point divergence to a integral/boolean
  483. divergence. This is because some operations like ReLU involve temporarily
  484. sharp boundaries that smooth out again afterwards; without requiring
  485. divergence on floating point, the minifier will often fixate on divergent
  486. boolean tensor even though this is not the true source of the divergence.
  487. However, rejecting these reductions makes it more difficult for the minifier
  488. to make process. Using this option will let the minifier progress for ALL
  489. divergences--you just might not end up with a useful repro in the end.""",
  490. )
  491. parser.add_argument(
  492. "--save-dir",
  493. type=str,
  494. default=save_dir,
  495. metavar="DIR",
  496. help="directory where saved inputs live",
  497. )
  498. parser.add_argument(
  499. "--no-save-dir",
  500. dest="save_dir",
  501. action="store_const",
  502. const=None,
  503. help="don't use any directory for saved inputs",
  504. )
  505. subparsers = parser.add_subparsers(
  506. dest="command", metavar="{run,minify}", required=True
  507. )
  508. parser_run = subparsers.add_parser(
  509. "run",
  510. help="just run the repro",
  511. )
  512. common_flags(parser_run)
  513. parser_minify = subparsers.add_parser(
  514. "minify", help="run the minifier on the repro"
  515. )
  516. common_flags(parser_minify)
  517. parser_get_args = subparsers.add_parser("get_args", help="get the args")
  518. common_flags(parser_get_args)
  519. parser_minify.add_argument(
  520. "--skip-saving-eager-intermediates",
  521. action="store_true",
  522. help="skip saving eager intermediates on --minify",
  523. )
  524. parser_minify.add_argument(
  525. "--offload-to-disk",
  526. action="store_true",
  527. help="during minification, offload delta debugging intermediates to disk. Use if you're OOMing",
  528. )
  529. parser_minify.add_argument(
  530. "--skip-sanity",
  531. action="store_true",
  532. help="skip sanity check at beginning of minification on original graph",
  533. )
  534. parser_minify.add_argument(
  535. "--max-granularity",
  536. type=int,
  537. default=None,
  538. help="start at this granularity and work down; must be power of 2",
  539. )
  540. parser_minify.add_argument(
  541. "--check-str",
  542. type=str,
  543. default=check_str,
  544. help="require minified program to fail with error containing this string",
  545. )
  546. parser_minify.add_argument(
  547. "--minifier-export-mode",
  548. type=str,
  549. default=minifier_export_mode,
  550. help=(
  551. "The export mode used in minifier, either dynamo or python."
  552. "`dynamo` corresponds to strict=True, and `python` corresponds to strict=False."
  553. ),
  554. )
  555. parser_minify.add_argument(
  556. "--skip-export-error",
  557. type=bool,
  558. default=skip_export_error,
  559. help="Skip intermediate graphs that cannot be exported.",
  560. )
  561. # Run the repro in the context of minification, inverting exit code meaning
  562. parser_minifier_query = subparsers.add_parser(
  563. "minifier-query",
  564. )
  565. common_flags(parser_minifier_query)
  566. parser_minifier_query.add_argument(
  567. "--check-str",
  568. type=str,
  569. default=check_str,
  570. help="require minified program to fail with error containing this string",
  571. )
  572. args = None
  573. if len(sys.argv) <= 1:
  574. args = [command, *sys.argv[1:]]
  575. options = parser.parse_args(args)
  576. COMMAND_FNS = {
  577. "minify": repro_minify,
  578. "run": repro_run,
  579. "get_args": repro_get_args,
  580. }
  581. return COMMAND_FNS[options.command](
  582. options, exported_program, config_patches=config_patches
  583. )