debug_utils.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943
  1. """
  2. Debug utilities for TorchDynamo compilation and execution.
  3. This module provides various debugging tools and utilities for TorchDynamo, including:
  4. - Minification support for reducing test cases while preserving bugs
  5. - Input/output handling via InputReader and InputWriter for reproducible testing
  6. - Accuracy checking between original and compiled models
  7. - Neural network module string conversion via NNModuleToString
  8. - Profiling tools and system information collection
  9. - Buck build system integration for Meta-internal testing
  10. Key classes:
  11. - InputReader/InputWriter: Handle serialization of model inputs/outputs
  12. - NNModuleToString: Converts nn.Modules to string representations
  13. - BuckTargetWriter: Manages Buck build system integration
  14. """
  15. from __future__ import annotations
  16. import atexit
  17. import copy
  18. import cProfile
  19. import functools
  20. import getpass
  21. import inspect
  22. import itertools
  23. import logging
  24. import os
  25. import re
  26. import subprocess
  27. import sys
  28. import tempfile
  29. import textwrap
  30. from collections import Counter
  31. from importlib import import_module
  32. from typing import Any, Optional, TYPE_CHECKING, TypeVar
  33. import torch
  34. import torch._prims_common as utils
  35. import torch._subclasses.meta_utils
  36. from torch import Tensor
  37. from torch._dynamo.testing import rand_strided
  38. from torch._inductor.cpp_builder import normalize_path_separator
  39. from torch._prims_common import is_float_dtype
  40. from torch.multiprocessing.reductions import StorageWeakRef
  41. from torch.utils._content_store import ContentStoreReader, ContentStoreWriter
  42. from . import config
  43. from .utils import clone_inputs, get_debug_dir, warn_once
  44. if TYPE_CHECKING:
  45. from collections.abc import Callable, Sequence
  46. from torch.hub import tqdm
  47. from torch.storage import UntypedStorage
log = logging.getLogger(__name__)

# Generic type variable used in the annotations of ``_mk_defaulter`` and
# ``profile_to_file``.
T = TypeVar("T")

# ``is_fbcode()`` from the inductor config decides whether the Meta-internal
# Buck integration below is enabled.
inductor_config = import_module("torch._inductor.config")
use_buck = inductor_config.is_fbcode()

if use_buck:
    # Imported only when ``use_buck`` is set (fbcode builds).
    import libfb.py.build_info

# Buck-related globals. They keep these empty defaults unless ``use_buck`` is set:
#   extra_deps    - extra op library targets added to generated TARGETS files
#   extra_imports - matching ``torch.ops.load_library(...)`` source lines
#   cur_target    - presumably the build rule of the running binary, rewritten
#                   from "fbcode:..." to "//..." form (TODO confirm)
# pyrefly: ignore [implicit-any]
extra_deps = []
extra_imports = ""
cur_target = ""
if use_buck:
    extra_deps = [
        "//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu",
        "//caffe2/torch/fb/sparsenn:sparsenn_operators",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops",
    ]
    cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//")  # type: ignore[possibly-undefined]
    extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps])

# Command prefix prepended to a target spec to run a generated repro via Buck.
BUCK_CMD_PREFIX = ["buck2", "run", "@mode/dev-nosan"]
  68. class BuckTargetWriter:
  69. def __init__(self, filename: str) -> None:
  70. self.subdir, self.py_file = os.path.split(os.path.abspath(filename))
  71. self.target = self.py_file.replace(".py", "")
  72. # Get main_module path from fbcode
  73. self.path = f"{self.subdir.replace('/', '.')}.{self.target}"
  74. self.path = self.path[self.path.find("fbcode.") :]
  75. self.path = self.path[7:]
  76. # Get cmd line path
  77. tmp = self.subdir
  78. tmp = tmp[tmp.find("fbcode/") :][7:]
  79. self.cmd_line_path = f"//{tmp}:{self.target}"
  80. def build(self) -> str:
  81. extra_cpp_deps = "\n".join([f' "{x}",' for x in extra_deps])
  82. return textwrap.dedent(
  83. f"""
  84. load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
  85. python_binary(
  86. name="{self.target}",
  87. srcs = ["{self.py_file}"],
  88. compile = False,
  89. deps = [
  90. "//caffe2:torch",
  91. "//caffe2:libtorch",
  92. "//caffe2/functorch:functorch",
  93. "//triton:triton",
  94. "{cur_target}",
  95. ],
  96. cpp_deps = [
  97. {extra_cpp_deps}
  98. ],
  99. main_module = "{self.path}",
  100. par_style = "xar",
  101. )
  102. """
  103. )
  104. def write(self, print_msg: bool = True) -> list[str]:
  105. target_file = os.path.join(self.subdir, "TARGETS")
  106. with open(target_file, "w") as fd:
  107. fd.write(self.build())
  108. # log.warning("Wrote isolation TARGETS file at %s", target_file)
  109. cmd_split = BUCK_CMD_PREFIX + [self.cmd_line_path]
  110. if print_msg:
  111. log.warning(
  112. "Found an example that reproduces the error. Run this cmd to repro - %s",
  113. " ".join(cmd_split),
  114. )
  115. return cmd_split
  116. def minifier_dir() -> str:
  117. path = os.path.join(get_debug_dir(), "minifier")
  118. if path is None:
  119. path = f"{tempfile.gettempdir()}/minifier_{getpass.getuser()}"
  120. if not os.path.exists(path):
  121. os.makedirs(path, exist_ok=True)
  122. return path
  123. MAX_CONSTANT_NUMEL_INLINE = 4
  124. class NNModuleToString:
  125. safe_reprs = [
  126. torch.nn.Linear,
  127. torch.nn.Conv1d,
  128. torch.nn.Conv2d,
  129. torch.nn.Conv3d,
  130. torch.nn.BatchNorm1d,
  131. torch.nn.BatchNorm2d,
  132. torch.nn.BatchNorm3d,
  133. torch.nn.LayerNorm,
  134. torch.nn.Dropout,
  135. torch.nn.Softmax,
  136. torch.nn.ReLU,
  137. torch.nn.GELU,
  138. torch.nn.Identity,
  139. torch.nn.MaxPool2d,
  140. torch.nn.Embedding,
  141. torch.nn.Tanh,
  142. torch.nn.ConvTranspose1d,
  143. torch.nn.GLU,
  144. torch.nn.LSTM,
  145. torch.nn.Flatten,
  146. torch.nn.AdaptiveAvgPool2d,
  147. ]
  148. @staticmethod
  149. def can_convert_to_string(gm: torch.fx.GraphModule) -> bool:
  150. cant_convert = set()
  151. for _, module in gm.named_children():
  152. if type(module) not in NNModuleToString.safe_reprs:
  153. cant_convert.add(module)
  154. if len(cant_convert) > 0:
  155. log.warning("We have not tested reprs of some modules - %s", cant_convert)
  156. # TODO - Assuming that all modules can be safely repr'd. Check if that assumption is correct.
  157. return True
  158. @staticmethod
  159. def convert(gm: torch.fx.GraphModule) -> str:
  160. from torch.nn.modules.module import _addindent
  161. tab = " " * 4
  162. model_str = textwrap.dedent(
  163. """
  164. from torch.nn import *
  165. class Repro(torch.nn.Module):
  166. def __init__(self) -> None:
  167. super().__init__()
  168. """
  169. )
  170. for module_name, module in gm.named_children():
  171. module_str = f"{module.__repr__()}"
  172. # module should be a core torch.nn.Module, so all parameters
  173. # should be on the same device.
  174. example_param = next(module.parameters(), None)
  175. if example_param is not None and example_param.is_cuda:
  176. module_str = f"{module_str}.cuda()"
  177. model_str += f"{tab * 2}self.{module_name} = {module_str}\n"
  178. for buffer_name, buffer in gm._buffers.items():
  179. if buffer is None:
  180. continue
  181. # Serialize full data for small buffers
  182. if buffer.numel() <= MAX_CONSTANT_NUMEL_INLINE:
  183. from torch._tensor_str import PRINT_OPTS
  184. assert PRINT_OPTS.threshold >= MAX_CONSTANT_NUMEL_INLINE
  185. tensor_str = repr(buffer)
  186. elif torch.is_floating_point(buffer):
  187. tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})"
  188. else:
  189. tensor_str = (
  190. f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})"
  191. )
  192. if buffer.is_cuda:
  193. tensor_str = f"{tensor_str}.cuda()"
  194. model_str += (
  195. f"{tab * 2}self.register_buffer('{buffer_name}', {tensor_str})\n"
  196. )
  197. for param_name, param in gm._parameters.items():
  198. if param is None:
  199. continue
  200. maybe_device = ""
  201. if param.is_cuda:
  202. maybe_device = ', device="cuda"'
  203. tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}{maybe_device}))"
  204. model_str += f"{tab * 2}self.{param_name} = {tensor_str}\n"
  205. # TODO - Keep this code for now. But, I don't think we will need this.
  206. # attrs = dir(gm)
  207. # for attr in attrs:
  208. # if "_tensor_constant" in attr:
  209. # val = getattr(gm, attr)
  210. # model_str += f" {attr} = {val!r}\n"
  211. model_str += f"{_addindent(gm.code, 4)}\n"
  212. return model_str
  213. @functools.cache # subprocess is expensive
  214. def _cuda_system_info_comment() -> str:
  215. if not torch.cuda.is_available():
  216. return "# torch.cuda.is_available()==False, no GPU info collected\n"
  217. model_str = "# CUDA Info: \n"
  218. try:
  219. if torch.version.hip is None:
  220. cuda_version_out = subprocess.check_output(["nvcc", "--version"])
  221. cuda_version_lines = cuda_version_out.decode().split("\n")
  222. comment = "".join([f"# {s} \n" for s in cuda_version_lines if s != ""])
  223. model_str += f"{comment}\n"
  224. else:
  225. model_str += "# Not searching for nvcc on ROCM setup\n"
  226. except (FileNotFoundError, subprocess.CalledProcessError):
  227. model_str += "# nvcc not found\n"
  228. gpu_names = Counter(
  229. torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())
  230. )
  231. model_str += "# GPU Hardware Info: \n"
  232. for name, count in gpu_names.items():
  233. model_str += f"# {name} : {count} \n"
  234. model_str += "\n"
  235. return model_str
  236. def generate_env_vars_string(*, stable_output: bool = False) -> str:
  237. """
  238. Generate a string configuration for environment variables related to Dynamo, Inductor, and Triton.
  239. """
  240. if stable_output:
  241. return "# env var omitted due to stable_output=True"
  242. allow_list = ["TORCH", "DYNAMO", "INDUCTOR", "TRITON"]
  243. skip_list = ["TRITON_LIBDEVICE_PATH", "TRITON_PTXAS_PATH", "TRITON_LIBCUDA_PATH"]
  244. def filter(key: str) -> bool:
  245. return any(string in key for string in allow_list) and key not in skip_list
  246. config_lines = [
  247. f"""os.environ['{key}'] = '{value.replace("'", '"')}'"""
  248. for key, value in os.environ.items()
  249. if filter(key)
  250. ]
  251. config_string = "\n".join(config_lines)
  252. return normalize_path_separator(f"""\
  253. import os
  254. {config_string}
  255. """)
  256. def generate_config_string(*, stable_output: bool = False) -> str:
  257. import torch._functorch.config
  258. import torch._inductor.config
  259. if stable_output:
  260. return "# config omitted due to stable_output=True"
  261. experimental_config = torch.fx.experimental._config.codegen_config() # type: ignore[attr-defined]
  262. return f"""\
  263. import torch._dynamo.config
  264. import torch._inductor.config
  265. import torch._functorch.config
  266. import torch.fx.experimental._config
  267. {torch._dynamo.config.codegen_config()}
  268. {torch._inductor.config.codegen_config()}
  269. {torch._functorch.config.codegen_config()}
  270. {experimental_config}
  271. """
  272. def get_minifier_repro_path() -> str:
  273. return os.path.join(minifier_dir(), "minifier_launcher.py")
  274. def helper_for_dump_minify(contents: str) -> None:
  275. minified_repro_path = get_minifier_repro_path()
  276. log.warning("Writing minified repro to:\n%s", minified_repro_path)
  277. if use_buck:
  278. BuckTargetWriter(minified_repro_path).write()
  279. try:
  280. with open(minified_repro_path, "w") as fd:
  281. fd.write(contents)
  282. except OSError as e:
  283. log.exception("")
  284. raise NotImplementedError(f"Could not write to {minified_repro_path}") from e
  285. class AccuracyError(Exception):
  286. pass
  287. def clone_inputs_retaining_gradness(example_inputs: Sequence[Any]) -> list[Any]:
  288. """
  289. This clone inputs is different from utils clone_input. In case of minifier,
  290. all the tensors are leaf tensors while creating a new graph. So, we set the
  291. requires_grad field w/o checking the leafness of the tensor.
  292. """
  293. cloned_inputs = clone_inputs(example_inputs)
  294. for idx in range(len(example_inputs)):
  295. if isinstance(cloned_inputs[idx], torch.Tensor):
  296. cloned_inputs[idx].requires_grad_(example_inputs[idx].requires_grad)
  297. return cloned_inputs # type: ignore[return-value]
  298. def run_fwd_maybe_bwd(
  299. gm: torch.fx.GraphModule,
  300. args: Sequence[Any],
  301. only_fwd: bool = False,
  302. disable_clone: bool = False,
  303. ) -> Any:
  304. """
  305. Runs a forward and possibly backward iteration for a given mod and args.
  306. When disable_clone is True, we will use args as-is without cloning.
  307. This is higher fidelity but we may destroy the args in the process.
  308. """
  309. from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass
  310. gm = copy.deepcopy(gm)
  311. if not disable_clone:
  312. args = clone_inputs_retaining_gradness(args)
  313. if hasattr(gm, "zero_grad"):
  314. gm.zero_grad(True)
  315. # TorchInductor returned callable expects lists. So, may need a boxed calling convention.
  316. out = gm(args) if getattr(gm, "_boxed_call", False) else gm(*args)
  317. if only_fwd:
  318. return out
  319. if requires_bwd_pass(out):
  320. loss = reduce_to_scalar_loss(out)
  321. loss.backward()
  322. return collect_results(gm, out, None, args)
  323. def same_two_models(
  324. gm: torch.fx.GraphModule,
  325. opt_gm: torch.fx.GraphModule,
  326. example_inputs: Sequence[Any],
  327. only_fwd: bool = False,
  328. *,
  329. require_fp64: bool = False,
  330. ignore_non_fp: bool = False,
  331. ) -> bool:
  332. """
  333. Check two models have same accuracy.
  334. require_fp64: if True, raise an error if we unable to calculate the fp64 reference
  335. ignore_non_fp: if True, do not compare outputs which are not floating point. This
  336. is mostly useful for the minifier (which wants to avoid quantizing floating point
  337. error into integer/boolean error)
  338. """
  339. from .utils import same
  340. ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)
  341. fp64_ref = None
  342. if config.same_two_models_use_fp64:
  343. try:
  344. fp64_model, fp64_examples = cast_to_fp64(
  345. copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
  346. )
  347. fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd)
  348. except Exception:
  349. if require_fp64:
  350. raise RuntimeError( # noqa: B904
  351. "Could not generate fp64 outputs, workaround with torch._dynamo.config.same_two_models_use_fp64 = False"
  352. )
  353. log.warning("Could not generate fp64 outputs")
  354. try:
  355. res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd)
  356. except Exception:
  357. # This means that the minified graph is bad/exposes a different problem.
  358. # As we are checking accuracy here, lets log the exception and return True.
  359. log.exception(
  360. "While minifying the program in accuracy minification mode, "
  361. "ran into a runtime exception which is likely an unrelated issue."
  362. " Skipping this graph."
  363. )
  364. return True
  365. passing = same(
  366. ref,
  367. res,
  368. fp64_ref,
  369. tol=config.repro_tolerance,
  370. equal_nan=True,
  371. ignore_non_fp=ignore_non_fp,
  372. )
  373. return passing
  374. def cast_dtype_args_to_fp64(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
  375. for node in model.graph.nodes:
  376. if (
  377. node.op == "call_function"
  378. and node.target is torch.ops.prims.convert_element_type.default
  379. ):
  380. assert len(node.args) == 2
  381. if is_float_dtype(node.args[1]) and node.args[1] != torch.float64:
  382. node.args = (node.args[0], torch.float64)
  383. if node.op == "call_function":
  384. dtype = node.kwargs.get("dtype")
  385. if dtype is not None and is_float_dtype(dtype):
  386. new_kwargs = dict(node.kwargs)
  387. new_kwargs["dtype"] = torch.float64
  388. node.kwargs = new_kwargs
  389. model.graph.lint()
  390. model.recompile()
  391. return model
  392. def cast_to(
  393. dtype: torch.dtype, model: torch.fx.GraphModule, inputs: list[Any]
  394. ) -> tuple[torch.fx.GraphModule, list[Any]]:
  395. from torch.utils._pytree import tree_map
  396. model = model.to(dtype)
  397. if dtype == torch.float64:
  398. # If casting to fp64 for accuracy comparison, we need to
  399. # replace dtype arguments embedded in the graph with fp64
  400. model = cast_dtype_args_to_fp64(model)
  401. inputs = tree_map(
  402. lambda x: x.to(dtype)
  403. if isinstance(x, torch.Tensor) and x.is_floating_point()
  404. else x,
  405. inputs,
  406. )
  407. return model, inputs
  408. def cast_to_fp64(
  409. model: torch.fx.GraphModule, inputs: list[Any]
  410. ) -> tuple[torch.fx.GraphModule, list[Any]]:
  411. return cast_to(torch.float64, model, inputs)
  412. def backend_accuracy_fails(
  413. gm: torch.fx.GraphModule,
  414. example_inputs: Sequence[Any],
  415. compiler_fn: Callable[[torch.fx.GraphModule, list[Any]], torch.fx.GraphModule],
  416. only_fwd: bool = False,
  417. *,
  418. require_fp64: bool = False,
  419. ignore_non_fp: bool = False,
  420. ) -> bool:
  421. try:
  422. compiled_gm = compiler_fn(
  423. copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
  424. )
  425. return not same_two_models(
  426. gm,
  427. compiled_gm,
  428. example_inputs,
  429. only_fwd,
  430. require_fp64=require_fp64,
  431. ignore_non_fp=ignore_non_fp,
  432. )
  433. except Exception:
  434. # This means that the minified graph is bad/exposes a different problem.
  435. # As we are checking accuracy here, lets log the exception and return False.
  436. log.exception(
  437. "While minifying the program in accuracy minification mode, "
  438. "ran into a runtime exception which is likely an unrelated issue."
  439. " Skipping this graph"
  440. )
  441. return False
  442. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  443. # REPRO SUPPORT CODE
  444. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  445. # Helper functions for computing what the default values of tensor
  446. # values should be. These all coincide with factory functions, e.g., torch.empty
  447. def _stride_or_default(
  448. stride: Optional[torch._prims_common.StrideType],
  449. *,
  450. shape: torch._prims_common.ShapeType,
  451. ) -> torch._prims_common.StrideType:
  452. return stride if stride is not None else utils.make_contiguous_strides_for(shape)
  453. def _mk_defaulter(d: T) -> Callable[[Optional[T]], T]:
  454. return lambda x: x if x is not None else d
  455. _dtype_or_default = _mk_defaulter(torch.float32)
  456. _device_or_default = _mk_defaulter(torch.device("cpu"))
  457. _storage_offset_or_default = _mk_defaulter(0)
  458. _requires_grad_or_default = _mk_defaulter(False)
  459. _is_leaf_or_default = _mk_defaulter(False)
  460. class NopInputReader:
  461. def __init__(self) -> None:
  462. self.total = 0
  463. def storage(
  464. self,
  465. storage_hash: Optional[str],
  466. nbytes: int,
  467. *,
  468. device: Optional[torch._prims_common.DeviceLikeType] = None,
  469. dtype_hint: Optional[torch.dtype] = None,
  470. ) -> None:
  471. self.total += 1
  472. def tensor(self, *args: Any, **kwargs: Any) -> Optional[torch.Tensor]:
  473. pass
  474. def symint(self, *args: Any, **kwargs: Any) -> Optional[int]:
  475. pass
# TODO: Support bundling the entire repro into a zip file for ease of
# transferring around
class InputReader:
    """Rebuild repro inputs from the ``reader.*`` calls of a generated ``load_args``.

    ``storage`` returns stored bytes from a ``ContentStoreReader`` when
    available, or random data otherwise. ``tensor`` and ``symint`` append
    their result to ``self.args`` in call order.

    NOTE(review): ``InputWriter.const`` emits ``reader.const(...)`` lines, but
    this class defines no ``const`` method here; confirm it is provided
    elsewhere before relying on such repros.
    """

    def __init__(
        self, save_dir: str | None = None, *, pbar: tqdm | None = None
    ) -> None:
        # If None, we will generate random data instead. It's important
        # to natively support this use case as it will allow people to
        # share repros without including the real data, if the problem
        # reproduces even on random data.
        if save_dir is None:
            log.warning("no save_dir specified, will generate random data")
        self.store = ContentStoreReader(save_dir) if save_dir is not None else None
        # Values produced by tensor()/symint(), in call order.
        self.args: list[Any] = []
        # Optional progress bar, advanced once per storage() call.
        self.pbar = pbar

    def storage(
        self,
        storage_hash: Optional[str],
        nbytes: int,
        *,
        device: Optional[torch._prims_common.DeviceLikeType] = None,
        dtype_hint: Optional[torch.dtype] = None,
    ) -> UntypedStorage:
        """Return the storage for ``storage_hash``, or random bytes if unavailable.

        The stored data is used when a content store is configured, the hash
        is not None, and the store has it. Otherwise a warning is issued via
        ``warn_once``, and a random storage holding ``nbytes // itemsize``
        elements of ``dtype_hint`` (default float32) is created on ``device``
        (default CPU).
        """
        if self.pbar is not None:
            self.pbar.update(1)
        device = _device_or_default(device)  # type: ignore[arg-type]
        dtype_hint = _dtype_or_default(dtype_hint)
        if self.store is not None and storage_hash is not None:
            try:
                storage = self.store.read_storage(storage_hash)
            except FileNotFoundError:
                pass
            else:
                if device != storage.device:
                    log.warning("device mismatch: %s != %s", device, storage.device)
                    # TODO: transfer it to the right device? But failing this
                    # way would be very mysterious! Would have been better
                    # not to store device in the serialized format...
                return storage
        warn_once(f"could not load {storage_hash}, generating random data instead")
        # Floor division: a trailing partial element of nbytes is dropped.
        shape = (nbytes // dtype_hint.itemsize,)
        stride = _stride_or_default(None, shape=shape)
        return rand_strided(shape, stride, dtype_hint, device).untyped_storage()

    def tensor(
        self,
        storage: UntypedStorage,
        shape: torch._prims_common.ShapeType,
        stride: Optional[torch._prims_common.StrideType] = None,
        *,
        storage_offset: Optional[int] = None,
        dtype: Optional[torch.dtype] = None,
        requires_grad: Optional[bool] = None,
        is_leaf: Optional[bool] = None,
        **metadata: Any,
    ) -> torch.Tensor:
        """Create a tensor viewing ``storage``, append it to ``self.args`` and return it.

        Unspecified ``stride``, ``storage_offset``, ``dtype``, ``requires_grad``
        and ``is_leaf`` default to contiguous strides, 0, float32, False and
        False. Extra keyword arguments are applied with
        ``torch._utils.set_tensor_metadata``.
        """
        stride = _stride_or_default(stride, shape=shape)
        storage_offset = _storage_offset_or_default(storage_offset)
        dtype = _dtype_or_default(dtype)
        is_leaf = _is_leaf_or_default(is_leaf)
        requires_grad = _requires_grad_or_default(requires_grad)
        t = torch.tensor(
            [], dtype=dtype, device=storage.device, requires_grad=requires_grad
        )
        # Re-point the empty tensor at the requested storage view. This is done
        # under no_grad, presumably because autograd would reject an in-place
        # set_ on a leaf that requires grad.
        with torch.no_grad():
            t.set_(storage, storage_offset, shape, stride)
        if not is_leaf:
            # Fake up some autograd history in a very naughty way
            # (the clone gives t a grad_fn; set_ then restores the view).
            with torch.enable_grad():
                t = t.clone(memory_format=torch.preserve_format)
            with torch.no_grad():
                t.set_(storage, storage_offset, shape, stride)
        assert torch._subclasses.meta_utils.safe_is_leaf(t) == is_leaf
        torch._utils.set_tensor_metadata(t, metadata)
        self.args.append(t)
        return t  # for BC

    def symint(self, val: Any) -> Any:
        """Append ``val`` (a concrete value for a symbolic int) to ``self.args`` and return it."""
        self.args.append(val)
        return val  # for BC
  554. # Here is our writer strategy:
  555. # 1. We will stream all of the inputs to disk
  556. # 2. You can now deterministically randomize the inputs, or reload
  557. # the inputs from disk
  558. # 3. You can YOLO run the script without the inputs, in which case
  559. # we'll fill the inputs with random data and pray. This is the
  560. # legacy behavior, but it's also useful if you want to find out
  561. # if we're so broken even random inputs trigger it
  562. # 4. We could offer an in process "check if the randomized thing
  563. # works too" but this is delicate so we don't do it
  564. class InputWriter:
  565. def __init__(self, save_dir: Optional[str], *, stable_hash: bool = False) -> None:
  566. self._lines: list[str] = []
  567. # TODO: consider ensuring tensor and storage counters line up?
  568. self.storage_counter = itertools.count()
  569. self.save_dir = save_dir
  570. self.store = (
  571. ContentStoreWriter(save_dir, stable_hash=stable_hash)
  572. if save_dir is not None
  573. else None
  574. )
  575. self.seen_storages: dict[StorageWeakRef, str] = {}
  576. def lines(self) -> list[str]:
  577. r = [
  578. "def load_args(reader):",
  579. ]
  580. r.extend(f" {l}" for l in self._lines)
  581. # In case we need to change the internal format of load_args
  582. # in an FC-breaking way
  583. r.append("load_args._version = 0")
  584. return r
  585. # Storages are untyped, but we need to initialize them with data if
  586. # we don't have the real data, so we give a hint saying what kind
  587. # of initialization may be appropriate
  588. #
  589. # If we had a FakeTensor, device_hint tells us what device should be
  590. def storage(
  591. self,
  592. untyped_storage: UntypedStorage,
  593. *,
  594. device_hint: Optional[torch._prims_common.DeviceLikeType] = None,
  595. dtype_hint: Optional[torch.dtype] = None,
  596. ) -> str:
  597. ws = StorageWeakRef(untyped_storage)
  598. v = self.seen_storages.get(ws)
  599. if v is not None:
  600. return v
  601. v = f"buf{next(self.storage_counter)}"
  602. maybe_dtype_hint = ""
  603. if _dtype_or_default(None) != _dtype_or_default(dtype_hint):
  604. maybe_dtype_hint = f", dtype_hint={dtype_hint!r}"
  605. # TODO: being optional on device is kind of pointless as the default
  606. # is CPU but most repros we care about are CUDA
  607. maybe_device = ""
  608. device = untyped_storage.device
  609. if device.type == "meta":
  610. assert device_hint is not None
  611. device = device_hint # type: ignore[assignment]
  612. if _device_or_default(None) != device:
  613. maybe_device = f", device={device!r}"
  614. nbytes = untyped_storage.nbytes()
  615. storage_hash = None
  616. if self.store is not None and untyped_storage.device.type != "meta":
  617. storage_hash = self.store.write_storage(untyped_storage)
  618. self._lines.append(
  619. f"{v} = reader.storage({storage_hash!r}, {nbytes!r}{maybe_device}{maybe_dtype_hint})"
  620. )
  621. self.seen_storages[ws] = v
  622. return v
  623. def tensor(self, name: str, t: torch.Tensor) -> None:
  624. from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
  625. storage = self.storage(
  626. t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device
  627. )
  628. args = []
  629. # NB: this is positional, must come first
  630. if not statically_known_true(
  631. sym_eq(_stride_or_default(None, shape=t.shape), t.stride())
  632. ):
  633. args.append(str(tuple(t.stride())))
  634. if _dtype_or_default(None) != t.dtype:
  635. args.append(f"dtype={t.dtype!r}")
  636. if not statically_known_true(
  637. _storage_offset_or_default(None) == t.storage_offset()
  638. ):
  639. args.append(f"storage_offset={t.storage_offset()!r}")
  640. tensor_metadata = torch._utils.get_tensor_metadata(t)
  641. if tensor_metadata:
  642. args.extend(f"{k}={v!r}" for k, v in tensor_metadata.items())
  643. if _requires_grad_or_default(None) != t.requires_grad:
  644. args.append(f"requires_grad={t.requires_grad!r}")
  645. is_leaf = torch._subclasses.meta_utils.safe_is_leaf(t)
  646. if _is_leaf_or_default(None) != is_leaf:
  647. args.append(f"is_leaf={is_leaf!r}")
  648. self._lines.append(
  649. "reader.tensor("
  650. + ", ".join([storage, str(tuple(t.shape)), *args])
  651. + f") # {name}"
  652. )
  653. def unsupported(self, name: str, arg: Any) -> None:
  654. # NB: Try hard not to /print/ a tensor, that will be very slow
  655. self._lines.append(f"# {name} was unsupported type for dumping: {type(arg)}")
  656. # Best effort dump as much useful stuff we can lol, in case you want
  657. # to repair the repro
  658. if isinstance(arg, (list, tuple)):
  659. self._lines.append('"""')
  660. for i, a in enumerate(arg):
  661. name_i = f"{name}[{i}]"
  662. if isinstance(a, torch.Tensor):
  663. self.tensor(name_i, a)
  664. elif isinstance(a, (int, torch.SymInt)):
  665. self.symint(name_i, a)
  666. else:
  667. self.unsupported(name_i, a)
  668. self._lines.append('"""')
  669. # write out that the arg was filtered out as it is constant
  670. def const(self, name: str) -> None:
  671. self._lines.append(
  672. f"reader.const({name!r}) # {name}, filtered out during compilation"
  673. )
  674. # TODO: this doesn't actually symint atm
  675. def symint(self, name: str, val: Any) -> None:
  676. if isinstance(val, torch.SymInt):
  677. val = val.node.hint
  678. self._lines.append(f"reader.symint({val!r}) # {name}")
  679. def aot_graph_input_parser(
  680. func: Callable[[list[Tensor]], list[Tensor]],
  681. device: str = "cuda",
  682. sym_shapes: Optional[dict[str, int]] = None,
  683. default_sym_shape: Optional[int] = None,
  684. ) -> dict[str, Any]:
  685. """
  686. Takes in a function which has been printed with print_readable() and constructs kwargs to run it.
  687. Handles Tensor inputs, Symints, and a graph module which might have tensor constants.
  688. Consider a function `forward` defined as follows:
  689. def forward(self, primals_1: "f32[1001, 6]", primals_2: "f32[s0]", primals_3: "Sym(s0)",):
  690. _tensor_constant0: "i64[4190]" = self._tensor_constant0
  691. # Further implementation
  692. kwargs = aot_graph_input_parser(forward)
  693. forward(**kwargs)
  694. """
  695. from torch.utils._dtype_abbrs import dtype_abbrs
  696. dtype_map: dict[str, torch.dtype] = {
  697. value: key for key, value in dtype_abbrs.items()
  698. }
  699. dtype_pattern: str = "|".join(dtype_abbrs.values())
  700. # Extracting the source code from the function
  701. source = inspect.getsource(func)
  702. # Regular expressions
  703. tensor_assignment_regex = rf"(_tensor_constant\d+): \"({dtype_pattern})\[\s*(.*?)\s*\]\" = self\.(_tensor_constant\d+)"
  704. tensor_regex = rf"({dtype_pattern})\[\s*(.*?)\s*\]"
  705. sym_shape_regex = r"Sym\((s\d+)\)"
  706. class TensorContainer:
  707. "Container for tensors as attributes"
  708. # Dictionary for tensors from annotations
  709. kwargs: dict[str, Any] = {}
  710. sym_shapes_dict: dict[str, int] = sym_shapes or {}
  711. def get_sym_int(symint: str) -> int:
  712. torch._check(
  713. symint in sym_shapes_dict or default_sym_shape is not None,
  714. lambda: f"{symint} not in symbolic_shapes and default sym shape not passed in",
  715. )
  716. return sym_shapes_dict.get(symint, default_sym_shape) # type: ignore[return-value]
  717. def gen_tensor(shape: torch._prims_common.ShapeType, dtype: torch.dtype) -> Tensor:
  718. # Resolve symbolic shapes to concrete values
  719. resolved_shape = []
  720. dynamic_dims = []
  721. for i, dim in enumerate(shape):
  722. dim = dim.strip() # type: ignore[attr-defined]
  723. if "s" in dim:
  724. s = get_sym_int(dim)
  725. resolved_shape.append(s)
  726. dynamic_dims.append(i)
  727. else:
  728. if dim:
  729. resolved_shape.append(int(dim))
  730. constructor = torch.randn if dtype.is_floating_point else torch.zeros
  731. out = constructor(resolved_shape, dtype=dtype, device=device) # type: ignore[call-arg]
  732. for d in dynamic_dims:
  733. torch._dynamo.mark_dynamic(out, d)
  734. return out
  735. # Parse function annotations for tensor generation
  736. annotations = func.__annotations__
  737. for param, annotation in annotations.items():
  738. # Skip 'return' annotation
  739. if param == "return":
  740. continue
  741. match = re.search(tensor_regex, annotation)
  742. if match:
  743. data_type, shape_str = match.groups()
  744. shape = tuple(shape_str.split(","))
  745. dtype = dtype_map[data_type]
  746. # pyrefly: ignore [bad-argument-type]
  747. kwargs[param] = gen_tensor(shape, dtype)
  748. match = re.search(sym_shape_regex, annotation)
  749. if match:
  750. kwargs[param] = get_sym_int(match.group(1))
  751. if "self" in inspect.signature(func).parameters:
  752. container = TensorContainer()
  753. kwargs["self"] = container
  754. for match in re.finditer(tensor_assignment_regex, source):
  755. attr_name, data_type, shape_str, _ = match.groups()
  756. shape = tuple(shape_str.split(","))
  757. dtype = dtype_map[data_type]
  758. # pyrefly: ignore [bad-argument-type]
  759. setattr(container, attr_name, gen_tensor(shape, dtype))
  760. return kwargs
  761. def profile_to_file(filename: str) -> Callable[[T], T]:
  762. """
  763. Decorator to cProfile a given function and save the result to disk on process exit.
  764. Args:
  765. filename: filename to save profile to
  766. """
  767. prof = cProfile.Profile()
  768. filename = os.path.abspath(os.path.expanduser(filename))
  769. def decorator(fn: Any) -> Any:
  770. @functools.wraps(fn)
  771. def wrapper(*args: Any, **kwargs: Any) -> Any:
  772. prof.enable()
  773. try:
  774. return fn(*args, **kwargs)
  775. finally:
  776. prof.disable()
  777. return wrapper
  778. def save_it() -> None:
  779. prof.dump_stats(filename)
  780. sys.stderr.write(
  781. textwrap.dedent(
  782. f"""\
  783. Wrote profile to {filename}, view with:
  784. snakeviz {filename}
  785. """
  786. )
  787. )
  788. atexit.register(save_it)
  789. return decorator