# _debug_mode.py
  1. # mypy: allow-untyped-defs
  2. """
DebugMode is a debugging TorchDispatchMode that intercepts and logs runtime calls
to a hierarchical string dump. It logs real tensor, DTensor, and optionally FakeTensor
operations, with some additional handling for DTensor internals.

An example dump from an eager mode DTensor matmul:

    torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) -> dt$6: f32[8, 32]| S(0)
      aten::mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0))
        redistribute_input(1, S(0) -> R)
          redistribute_input(t$2: f32[1, 32], trace: S(0)->R)
            _c10d_functional::all_gather_into_tensor(t$2: f32[1, 32], 8, 0) -> t$3: f32[8, 32]
            _c10d_functional::wait_tensor(t$3: f32[8, 32]) -> t$3: f32[8, 32]
        aten::mm(t$4: f32[1, 8], t$3: f32[8, 32]) -> t$5: f32[1, 32]

This mode runs "under" compile, which means it hides itself during compilation and is re-enabled
at runtime, so DebugMode-related operations won't show up in the compiled region.

DebugMode also provides some visibility into non-torch-dispatch calls (e.g. DTensor redistribute calls,
inductor-generated triton kernels), but these require special handling, since dispatch modes
can't intercept them by default.

The mode also provides some extensions for custom debugging (e.g. adding custom dispatch call hooks
via dispatch_hooks), or numerics debugging (e.g. tensor hashing for bitwise equivalence/closeness,
via log_tensor_hashes). These extensions allow annotating string dumps with additional per-call
information, for any region of runtime code.

Usage::

    with DebugMode() as debug_mode:
        result = some_pytorch_operation(tensor_input)
    print(debug_mode.debug_string())
  27. """
  28. import contextlib
  29. import functools
  30. import inspect
  31. import logging
  32. import os
  33. import traceback
  34. import weakref
  35. from collections.abc import Callable
  36. from typing import Any, TYPE_CHECKING
  37. import torch
  38. from torch._logging import warning_once
  39. from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
  40. from torch.fx.graph import _parse_stack_trace
  41. from torch.utils._dtype_abbrs import dtype_abbrs
  42. from torch.utils._python_dispatch import (
  43. _get_current_dispatch_mode,
  44. _get_current_dispatch_mode_stack,
  45. TorchDispatchMode,
  46. )
  47. from torch.utils._pytree import (
  48. keystr,
  49. tree_all,
  50. tree_map,
  51. tree_map_only,
  52. tree_map_with_path,
  53. )
  54. from torch.utils._traceback import CapturedTraceback
  55. from torch.utils.weak import WeakIdRef
  56. if TYPE_CHECKING:
  57. from torch._dynamo.device_interface import DeviceInterface
  58. from torch.distributed._tools.mod_tracker import ModTracker
# Module-level logger.
log = logging.getLogger(__name__)

# Public names exported by this module.
__all__ = ["DebugMode", "get_active_debug_mode"]

# Name under which redistribute calls are rendered in the dump.
REDISTRIBUTE_FUNC = "redistribute_input"

# registered dispatch call hooks
# (record hooks populate call.record; log and pre-log hooks populate call.log)
_DISPATCH_RECORD_HOOKS: list[Callable] = []
_DISPATCH_LOG_HOOKS: list[Callable] = []
_DISPATCH_PRE_LOG_HOOKS: list[Callable] = []

# Tracks if we're in inductor benchmarking, and temporarily disables logging
# (for ignoring autotuning kernel launches which don't affect the user-facing result)
_IN_INDUCTOR_BENCHMARK = False

# For record_outputs, log_tensor_hashes hooks for triton kernels.
# Stores kernel outputs in call.record["output"]
_RECORD_TRITON_OUTPUTS = False

# Annotates kernel output hashes, and stores them in call.post_hashes
_TRITON_OUTPUT_HASH_FN = None

# Annotates kernel input hashes, and stores them in call.pre_hashes
_TRITON_INPUT_HASH_FN = None

# Counter for active DebugMode instances (fast path for get_active_debug_mode)
_ACTIVE_DEBUG_MODE_COUNT = 0
  78. def _stringify_shape(shape) -> str:
  79. return f"[{', '.join([str(x) for x in shape])}]"
  80. def _stringify_device_mesh(mesh) -> str:
  81. return f"DM({', '.join([str(s) for s in mesh.shape])})"
  82. def _stringify_placement(placement) -> str:
  83. return f"[{', '.join([str(p) for p in placement])}]"
  84. def _stringify_attributes(tensor, attributes) -> str:
  85. pairs = {}
  86. for attr in attributes:
  87. if hasattr(tensor, attr):
  88. pairs[attr] = getattr(tensor, attr)
  89. if len(pairs) == 0:
  90. return ""
  91. return f"{{{', '.join([f'{k}={v}' for k, v in pairs.items()])}}}"
  92. def _stringify_dtensor_spec(spec) -> str:
  93. from torch.distributed.tensor._dtensor_spec import DTensorSpec
  94. return DTensorSpec.format_shard_order_str(spec.placements, spec.shard_order)
  95. class TensorIdTracker:
  96. def __init__(self) -> None:
  97. self.tensor_memo: dict[WeakIdRef, int] = {}
  98. self.next_tensor_id = 0
  99. def _id(self, tensor) -> int:
  100. with torch._C._DisablePythonDispatcher():
  101. o = WeakIdRef(tensor)
  102. def del_memo() -> None:
  103. self.tensor_memo.pop(o, None)
  104. weakref.finalize(tensor, del_memo)
  105. if o not in self.tensor_memo:
  106. self.tensor_memo[o] = self.next_tensor_id
  107. self.next_tensor_id += 1
  108. return self.tensor_memo[o]
  109. def _tensor_debug_string(tensor, attributes, tensor_memo=None) -> str:
  110. """Convert tensor to debug string representation."""
  111. if isinstance(tensor, torch.Tensor):
  112. tensor_debug_str = f"{dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_attributes(tensor, attributes)}"
  113. id_str = f"${tensor_memo._id(tensor)}" if tensor_memo is not None else ""
  114. if isinstance(tensor, torch.distributed.tensor.DTensor):
  115. # omitted device mesh
  116. return f"dt{id_str}: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}"
  117. elif isinstance(tensor, FakeTensor):
  118. return f"ft{id_str}: {tensor_debug_str}"
  119. else:
  120. return f"t{id_str}: {tensor_debug_str}"
  121. else:
  122. raise RuntimeError(f"Unsupported tensor type: {type(tensor)}")
  123. def _arg_to_str(arg, attributes, tensor_memo=None) -> str:
  124. from torch.distributed.tensor._dtensor_spec import DTensorSpec
  125. def to_str(x):
  126. if isinstance(x, torch.Tensor):
  127. return _tensor_debug_string(x, attributes, tensor_memo)
  128. elif isinstance(x, DTensorSpec):
  129. return _stringify_dtensor_spec(x)
  130. return x
  131. arg = tree_map(to_str, arg)
  132. return str(arg)
  133. def norm_hash_fn(t: torch.Tensor, use_scalar: bool = False) -> torch.Tensor | float:
  134. """
  135. from Observer. Computes a hash for a tensor by converting it to float (if needed), making it contiguous,
  136. replacing NaN/inf values with fixed numbers, and then computing the L1 norm in float64 or complex128.
  137. This is used to generate a deterministic summary value for tensor comparison.
  138. """
  139. with torch._C._DisablePythonDispatcher():
  140. if not (t.is_floating_point() or t.is_complex()):
  141. t = t.float()
  142. t = t.contiguous()
  143. if t.is_complex():
  144. t_float = t.to(dtype=torch.complex128)
  145. else:
  146. t_float = t.to(dtype=torch.float64)
  147. out = t_float.norm(p=1)
  148. if use_scalar:
  149. return out.item()
  150. return out
  151. def _compute_rel_diff(hash1, hash2):
  152. # Relative difference: |hash1 - hash2| / max(|hash1|, |hash2|, eps)
  153. numerator = abs(hash1 - hash2)
  154. denominator = max(abs(hash1), abs(hash2), 1e-10)
  155. return numerator / denominator
  156. def hash_tensor_fn(t: torch.Tensor, use_scalar: bool = False) -> torch.Tensor | int:
  157. """
  158. wrapper over torch.hash_tensor
  159. """
  160. if isinstance(t, torch.distributed.tensor.DTensor):
  161. t = t.to_local()
  162. if t.is_floating_point():
  163. t_clean = t.to(dtype=torch.float64)
  164. elif t.is_complex():
  165. t_clean = t.to(dtype=torch.complex128).view(torch.float64)
  166. else:
  167. t_clean = t.to(dtype=torch.int64)
  168. if t.numel() > 0:
  169. out = torch.hash_tensor(t_clean)
  170. else:
  171. out = torch.zeros((), device=t_clean.device, dtype=torch.uint64)
  172. if use_scalar:
  173. return out.item() # type: ignore[attribute]
  174. return out
  175. def _get_stack_trace() -> str:
  176. from torch.fx.experimental.symbolic_shapes import uninteresting_files
  177. summary = CapturedTraceback.extract().summary()
  178. summary = summary[:-4] # filter out DebugMode frames
  179. summary = [
  180. frame for frame in summary if frame.filename not in uninteresting_files()
  181. ]
  182. summary = traceback.StackSummary.from_list(summary)
  183. return "".join(summary.format())
  184. def _get_user_stack_trace(stack_trace_str: str) -> str | None:
  185. # Extract user code stack trace, filtering out torch internals.
  186. torch_dir = os.path.dirname(inspect.getfile(torch))
  187. filter_fn = lambda file, name, code: not file.startswith(torch_dir + os.path.sep) # noqa: E731
  188. trace = _parse_stack_trace(stack_trace_str, filter_fn=filter_fn)
  189. if trace:
  190. return f"File: {trace.file}:{trace.lineno} in {trace.name}, code: {trace.code}"
  191. return None
  192. def _maybe_get_autograd_trace() -> str | None:
  193. if torch._C._current_autograd_node() is not None:
  194. tb = torch._C._current_autograd_node().metadata.get("traceback_") # type: ignore[attr-defined]
  195. if tb:
  196. return "".join(tb)
  197. return None
  198. def _get_op_name(op) -> str:
  199. if isinstance(op, torch._ops.OpOverload):
  200. op_name = op.__qualname__
  201. elif hasattr(op, "__module__") and hasattr(op, "__name__"):
  202. op_name = f"{op.__module__}.{op.__name__}"
  203. else:
  204. op_name = str(op)
  205. return op_name
# Flipped to True once the one-time setup in _ensure_annotate_decorated() has run.
_annotate_decorated = False


def _ensure_annotate_decorated():
    """
    Lazily apply dont_skip_tracing decorator to DebugMode._annotate, to avoid circular import/initialization issues.

    The first call also adds the annotate op to FX's side-effectful functions
    and registers an inductor lowering for it. Subsequent calls do nothing,
    guarded by the module-level ``_annotate_decorated`` flag.
    """
    global _annotate_decorated
    if not _annotate_decorated:
        DebugMode._annotate = torch._dynamo.dont_skip_tracing(DebugMode._annotate)  # type: ignore[has-type]

        # Mark annotate as side-effectful so aot_eager doesn't DCE it.
        from torch.fx.node import _side_effectful_functions

        _side_effectful_functions.add(torch.ops.debug_mode_ops.annotate.default)

        # Register no-op lowering for inductor backend
        from torch._inductor.lowering import register_lowering

        @register_lowering(torch.ops.debug_mode_ops.annotate)
        def _annotate_lowering(tag: str) -> None:
            # The lowering only warns (once); it emits nothing for the op.
            warning_once(log, 'DebugMode._annotate() is a no-op for backend="inductor"')
            return None

        # Set last, so the flag is only flipped after every step above succeeded.
        _annotate_decorated = True
  224. class _DebugCall:
  225. """Base class for tracking operator calls in DebugMode"""
  226. def __init__(
  227. self,
  228. call_depth: int,
  229. record: dict[str, Any] | None = None,
  230. log: dict[str, Any] | None = None,
  231. stack: bool = False,
  232. ) -> None:
  233. self.call_depth = call_depth
  234. if stack:
  235. self.stack_trace = _get_stack_trace()
  236. self.fwd_stack_trace = _maybe_get_autograd_trace()
  237. # results from dispatch hooks
  238. self.record = record
  239. self.log = log
  240. self.output_str: str | None = None
  241. def stringify_args(
  242. self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
  243. ) -> None:
  244. """
  245. To reduce memory consumption, this method stringifies args/kwargs, stores the result, and deletes original args/kwargs.
  246. """
  247. raise NotImplementedError(
  248. "Subclasses must implement stringify_args(), even if no-op"
  249. )
  250. def stringify_output(
  251. self,
  252. output: Any,
  253. attributes: list[str],
  254. tensor_memo: TensorIdTracker | None = None,
  255. ) -> None:
  256. """Store stringified version of call output in self.output_str"""
  257. if tree_all(lambda x: x is None, output):
  258. return
  259. output_str = tree_map(lambda x: _arg_to_str(x, attributes, tensor_memo), output)
  260. self.output_str = f" -> {str(output_str)}"
  261. def render(self, attributes: list[str]) -> str:
  262. raise NotImplementedError("Subclasses must implement string render()")
  263. def __repr__(self) -> str:
  264. return self.render([])
  265. class _OpCall(_DebugCall):
  266. """Normal operator call"""
  267. def __init__(
  268. self,
  269. op,
  270. args: tuple,
  271. kwargs: dict,
  272. call_depth: int,
  273. stack: bool = False,
  274. ) -> None:
  275. super().__init__(call_depth, stack=stack)
  276. self.op = op
  277. self.args = args
  278. self.kwargs = kwargs
  279. self.args_str: str | None = None
  280. self.kwargs_str: str | None = None
  281. def stringify_args(
  282. self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
  283. ) -> None:
  284. self.args_str = ", ".join(
  285. _arg_to_str(arg, attributes, tensor_memo) for arg in self.args
  286. )
  287. if self.kwargs:
  288. self.kwargs_str = ", " + ", ".join(
  289. f"{k}={_arg_to_str(v, attributes, tensor_memo)}"
  290. for k, v in self.kwargs.items()
  291. )
  292. else:
  293. self.kwargs_str = ""
  294. del self.args
  295. del self.kwargs
  296. def render(self, attributes: list[str]) -> str:
  297. if self.args_str is not None:
  298. args_str = self.args_str
  299. else:
  300. args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args)
  301. if self.kwargs_str is not None:
  302. kwargs_str = self.kwargs_str
  303. else:
  304. if self.kwargs:
  305. kwargs_str = ", " + ", ".join(
  306. f"{k}={_arg_to_str(v, attributes)}" for k, v in self.kwargs.items()
  307. )
  308. else:
  309. kwargs_str = ""
  310. if isinstance(self.op, torch._ops.OpOverload):
  311. op_name = self.op.__qualname__
  312. elif hasattr(self.op, "__module__") and hasattr(self.op, "__name__"):
  313. op_name = f"{self.op.__module__}.{self.op.__name__}"
  314. else:
  315. op_name = str(self.op)
  316. base_str = f"{op_name}({args_str}{kwargs_str})"
  317. if self.output_str:
  318. base_str += self.output_str
  319. if self.log:
  320. base_str += f" # {self.log}"
  321. return base_str
  322. def __iter__(self):
  323. # for BC; tuple(self) returns (op, args, kwargs, call_depth)
  324. if self.args_str is not None:
  325. yield from [self.op, self.args_str, self.kwargs_str, self.call_depth]
  326. else:
  327. yield from [self.op, self.args, self.kwargs, self.call_depth]
  328. class _RedistributeCall(_DebugCall):
  329. def __init__(
  330. self,
  331. arg,
  332. src_placement,
  333. dst_placement,
  334. transform_info_str,
  335. call_depth,
  336. stack=False,
  337. is_explicit=False,
  338. ) -> None:
  339. super().__init__(call_depth, stack=stack)
  340. self.arg = arg
  341. self.src_placement = src_placement
  342. self.dst_placement = dst_placement
  343. self.transform_info_str = transform_info_str
  344. self.is_explicit = is_explicit
  345. self.is_outer_call = isinstance(arg, int)
  346. self.arg_str: str | None = None
  347. def stringify_args(
  348. self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
  349. ) -> None:
  350. self.arg_str = f"{_arg_to_str(self.arg, attributes, tensor_memo)}"
  351. del self.arg
  352. def render(self, attributes: list[str]) -> str:
  353. if self.arg_str is not None:
  354. arg_str = self.arg_str
  355. else:
  356. arg_str = f"{_arg_to_str(self.arg, attributes)}"
  357. if self.transform_info_str is not None: # prioritize over src/dst placements
  358. placement_str = f"trace: {self.transform_info_str}"
  359. else:
  360. src_placement_str = _arg_to_str(self.src_placement, attributes)
  361. dst_placement_str = _arg_to_str(self.dst_placement, attributes)
  362. placement_str = f"{src_placement_str} -> {dst_placement_str}"
  363. # DebugMode will add redistribute_input logs at 2 levels,
  364. # once per redistribute decision, and once per redistributed input.
  365. # We only annotate [implicit/explicit] logs on the former (outer-level call).
  366. if self.is_outer_call:
  367. annotation = " [implicit] "
  368. elif self.is_explicit:
  369. annotation = " [explicit] "
  370. else:
  371. annotation = ""
  372. base_str = f"{REDISTRIBUTE_FUNC}{annotation}({arg_str}, {placement_str})"
  373. if self.output_str:
  374. base_str += self.output_str
  375. return base_str
  376. def __iter__(self):
  377. # for BC; tuple(self) returns (op, placement info, kwargs, call_depth)
  378. if self.arg_str is not None:
  379. arg = self.arg_str
  380. else:
  381. arg = self.arg
  382. yield REDISTRIBUTE_FUNC
  383. if self.transform_info_str:
  384. yield [arg, self.transform_info_str]
  385. else:
  386. yield [arg, self.src_placement, self.dst_placement]
  387. yield {}
  388. yield self.call_depth
  389. class _OutputPlacementCall(_DebugCall):
  390. """Records output placement for a DTensor op."""
  391. def __init__(self, placements_str: str, call_depth: int) -> None:
  392. super().__init__(call_depth)
  393. self.placements_str = placements_str
  394. def stringify_args(
  395. self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
  396. ) -> None:
  397. pass # Already stringified
  398. def render(self, attributes: list[str]) -> str:
  399. return f"-> output: {self.placements_str}"
  400. class _TritonKernelCall(_DebugCall):
  401. """Triton kernel call from Inductor"""
  402. def __init__(
  403. self,
  404. kernel_name: str,
  405. kwargs: dict[str, Any],
  406. call_depth: int,
  407. ):
  408. super().__init__(call_depth)
  409. self.kernel_name = kernel_name
  410. self.kwargs = kwargs
  411. self.kwargs_str: str | None = None
  412. self.pre_hashes: dict[str, Any] | None = None
  413. self.post_hashes: dict[str, Any] | None = None
  414. def stringify_args(
  415. self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
  416. ) -> None:
  417. # Optionally hash kernel inputs before launch
  418. global _TRITON_INPUT_HASH_FN
  419. if hash_fn := _TRITON_INPUT_HASH_FN:
  420. self.pre_hashes = {
  421. k: hash_fn(v)
  422. for k, v in self.kwargs.items()
  423. if isinstance(v, torch.Tensor)
  424. }
  425. if self.kwargs:
  426. self.kwargs_str = ", ".join(
  427. f"{k}={_arg_to_str(v, attributes, tensor_memo)}"
  428. for k, v in self.kwargs.items()
  429. )
  430. else:
  431. self.kwargs_str = ""
  432. def render(self, attributes: list[str]) -> str:
  433. base_str = f"[triton] {self.kernel_name}({self.kwargs_str})"
  434. if self.pre_hashes:
  435. pre_hashes_str = ", ".join(f"{k}: {v}" for k, v in self.pre_hashes.items())
  436. pre_hashes_str = (
  437. "\n "
  438. + " " * self.call_depth
  439. + f"# pre-kernel hashes: {{{pre_hashes_str}}}"
  440. )
  441. else:
  442. pre_hashes_str = ""
  443. if self.post_hashes:
  444. post_hashes_str = ", ".join(
  445. f"{k}: {v}" for k, v in self.post_hashes.items()
  446. )
  447. post_hashes_str = (
  448. "\n "
  449. + " " * self.call_depth
  450. + f"# post-kernel hashes: {{{post_hashes_str}}}"
  451. )
  452. else:
  453. post_hashes_str = ""
  454. return f"{base_str}{pre_hashes_str}{post_hashes_str}\n"
  455. def finalize(self, device_interface: "DeviceInterface"):
  456. # synchronize -> hash/store kernel results
  457. global _RECORD_TRITON_OUTPUTS, _TRITON_OUTPUT_HASH_FN
  458. device_interface.synchronize(device_interface.current_device())
  459. if _RECORD_TRITON_OUTPUTS:
  460. self.record = {
  461. "output": {
  462. k: v.clone() if isinstance(v, torch.Tensor) else v
  463. for k, v in self.kwargs.items()
  464. }
  465. }
  466. if hash_fn := _TRITON_OUTPUT_HASH_FN:
  467. self.post_hashes = {
  468. k: hash_fn(v)
  469. for k, v in self.kwargs.items()
  470. if isinstance(v, torch.Tensor)
  471. }
  472. # don't store tensors
  473. del self.kwargs
  474. def __iter__(self):
  475. yield from [self.kernel_name, (), self.kwargs_str, self.call_depth]
  476. class _AnnotateCall(_DebugCall):
  477. """Custom annotation call"""
  478. def __init__(
  479. self, tag: Any, header: str, call_depth: int, stack: bool = False
  480. ) -> None:
  481. super().__init__(call_depth, stack=stack)
  482. self.tag = tag
  483. self.header = header
  484. def render(self, attributes: list[str]) -> str:
  485. return f"[{self.header}] {self.tag}"
  486. def __iter__(self):
  487. yield from [
  488. f"[{self.header}] {self.tag}",
  489. (),
  490. {},
  491. self.call_depth,
  492. ]
  493. def _run_hook(hook, *args):
  494. out = hook(*args)
  495. if out is not None and not isinstance(out, dict):
  496. raise AssertionError(f"hook must return None or dict, got {type(out).__name__}")
  497. return out
  498. def _run_dispatch_pre_log_hooks(call: _DebugCall, func, types, args, kwargs) -> None:
  499. global _DISPATCH_PRE_LOG_HOOKS
  500. if _DISPATCH_PRE_LOG_HOOKS:
  501. for hook in _DISPATCH_PRE_LOG_HOOKS:
  502. hook_out = _run_hook(hook, func, types, args, kwargs, call)
  503. if hook_out is not None:
  504. # Store pre-hook results in call.log
  505. if call.log is None:
  506. call.log = {}
  507. call.log.update(hook_out)
  508. def _run_dispatch_hooks(call: _DebugCall, func, types, args, kwargs, result) -> None:
  509. global _DISPATCH_RECORD_HOOKS, _DISPATCH_LOG_HOOKS
  510. if _DISPATCH_RECORD_HOOKS:
  511. record = {}
  512. for hook in _DISPATCH_RECORD_HOOKS:
  513. hook_out = _run_hook(hook, func, types, args, kwargs, result)
  514. if hook_out is not None:
  515. record.update(hook_out)
  516. if record:
  517. call.record = record
  518. if _DISPATCH_LOG_HOOKS:
  519. # Preserve existing log from pre-hooks (e.g., input_hash)
  520. if call.log is None:
  521. call.log = {}
  522. for hook in _DISPATCH_LOG_HOOKS:
  523. hook_out = _run_hook(hook, func, types, args, kwargs, result)
  524. if hook_out is not None:
  525. call.log.update(hook_out)
  526. def _get_call_name(call: _DebugCall) -> str:
  527. """String identifying _DebugCall (e.g. func, kernel, module name)"""
  528. if isinstance(call, _OpCall):
  529. return _get_op_name(call.op)
  530. elif isinstance(call, _TritonKernelCall):
  531. return call.kernel_name
  532. elif isinstance(call, _AnnotateCall):
  533. return f"[{call.header}] {call.tag}"
  534. elif isinstance(call, _RedistributeCall):
  535. return REDISTRIBUTE_FUNC
  536. else:
  537. return str(call)
# Custom op "debug_mode_ops::annotate": takes a string tag, mutates no arguments,
# and returns nothing.
@torch.library.custom_op("debug_mode_ops::annotate", mutates_args=())
def _annotate(tag: str) -> None:
    # This is special-cased in DebugMode.__torch_dispatch__
    return None


# Fake implementation registered for the op above; like the real one it does nothing.
@_annotate.register_fake
def _annotate_fake(tag: str) -> None:
    return None
  545. class DebugInterpreter(torch.fx.Interpreter):
  546. """
  547. Interpreter class for running aot_eager compiled regions when DebugMode is active,
  548. instead of using the compiled code. This gives us access to fx.Node metadata to decorate
  549. and contextualize DebugMode logs (e.g. nn_module_stack, stack_trace, compiled region boundaries).
  550. Note: this is currently only enabled with DebugMode(run_compile_with_interpreter=True).
  551. """
  552. def __init__(self, module, backend):
  553. super().__init__(module)
  554. self.mode = get_active_debug_mode()
  555. if self.mode is None:
  556. raise RuntimeError("No DebugMode is currently active")
  557. # for tracking initial nn_module_stack
  558. self.base_nn_module_stack = list(self.mode.current_nn_module_stack)
  559. # annotate start of region
  560. self.backend = backend
  561. self.mode.operators.append(
  562. _AnnotateCall(
  563. "enter", f"{self.backend} region (compile)", self.mode.call_depth
  564. )
  565. )
  566. def run_node(self, n: torch.fx.Node) -> Any:
  567. if self.mode is None:
  568. raise RuntimeError("No DebugMode is currently active")
  569. # handling of nn.Module stack
  570. if self.mode.record_nn_module and n.op not in ["placeholder", "output"]:
  571. self.mode._handle_fx_nn_module_stack(
  572. self.base_nn_module_stack,
  573. n.meta.get("nn_module_stack", {}),
  574. n.meta.get("fwd_nn_module_stack", {}),
  575. )
  576. # override stack trace with n.meta
  577. if (
  578. self.mode.record_stack_trace
  579. and n.op not in ["placeholder", "output"]
  580. and (stack_trace := n.meta.get("stack_trace", None)) is not None
  581. ):
  582. with self.mode.set_fx_stack_trace(stack_trace):
  583. return super().run_node(n)
  584. else:
  585. return super().run_node(n)
  586. def run(self, *args, **kwargs):
  587. if self.mode is None:
  588. raise RuntimeError("No DebugMode is currently active")
  589. result = super().run(*args)
  590. # reset nn.Module stack to pre-compiled region value
  591. if len(self.mode.current_nn_module_stack) < len(self.base_nn_module_stack):
  592. warning_once(
  593. log, "unexpected handling of nn_module_stack in DebugInterpreter"
  594. )
  595. while len(self.mode.current_nn_module_stack) > len(self.base_nn_module_stack):
  596. self.mode._exit_nn_module_call()
  597. # annotate end of region
  598. self.mode.operators.append(
  599. _AnnotateCall(
  600. "exit", f"{self.backend} region (compile)", self.mode.call_depth
  601. )
  602. )
  603. return result
  604. class DebugMode(TorchDispatchMode):
    def __init__(
        self,
        *,
        record_torchfunction=False,
        record_faketensor=False,
        record_realtensor=True,
        record_tensor_attributes=None,
        record_nn_module=False,
        store_original_args=False,
        record_stack_trace=False,
        record_output=True,
        record_ids=False,
        record_profiler_context=True,
        record_localtensor=True,
        run_compile_with_interpreter=False,
    ) -> None:
        """
        Configure what this DebugMode records; all options are keyword-only.

        Args:
            record_torchfunction: also record ``__torch_function__`` calls.
            record_faketensor: record ``__torch_dispatch__`` calls on FakeTensors.
            record_realtensor: record ``__torch_dispatch__`` calls on real tensors.
            record_tensor_attributes: optional list of tensor attribute names
                to annotate in the string dump (None means no attributes).
            record_nn_module: record nn.Module entrances via ModTracker.
            store_original_args: keep call args/kwargs instead of stringifying
                them immediately.
            record_stack_trace: store stack traces on logged calls.
            record_output: record call outputs in the logs.
            record_ids: annotate the dump with graph-style tensor ids.
            record_profiler_context: annotate the dump with
                profiler.record_function contexts.
            record_localtensor: record ``__torch_dispatch__`` calls on LocalTensor.
            run_compile_with_interpreter: run aot_eager compiled regions
                through a DebugInterpreter.
        """
        super().__init__()
        # Imported for its side effect of making torch.distributed.tensor available.
        import torch.distributed.tensor  # noqa: F401

        # One-time global setup for the annotate op (idempotent).
        _ensure_annotate_decorated()

        self.supports_higher_order_operators = True

        # Pushes DebugMode onto the torchfunction stack, and records __torch_function__ calls as well.
        # WARNING: currently incompatible with torch.compile due to dynamo guard failures.
        self.record_torchfunction = record_torchfunction

        # Records __torch_dispatch__ calls on FakeTensors.
        self.record_faketensor = record_faketensor

        # Records __torch_dispatch__ calls on real tensors.
        self.record_realtensor = record_realtensor

        # Records __torch_dispatch__ calls on LocalTensor.
        self.record_localtensor = record_localtensor

        # Optional list[str] of tensor attributes, to be annotated in the string dump.
        self.record_tensor_attributes = record_tensor_attributes or []

        # Uses ModTracker to record nn.Module entrances.
        # This flag currently has no effect on torch.compiled-regions.
        self.record_nn_module = record_nn_module

        # Stays None unless record_nn_module is set, in which case
        # module_tracker_setup() is called right away.
        self.module_tracker: ModTracker | None = None
        if self.record_nn_module:
            self.module_tracker_setup()

        # If True, stores call args/kwargs in logs, without immediately stringifying.
        # Defaults to False for memory concerns.
        self.store_original_args = store_original_args

        # For stack trace recording, stores log call stack traces in .stack_trace.
        # For backward graph nodes, will also store the corresponding forward stack traces in .fwd_stack_trace.
        # NOTE: this is only available if autograd tracebacks are being set during the forward pass,
        # e.g. via DebugMode(record_stack_trace=True), or torch.autograd.set_detect_anomaly().
        self.record_stack_trace = record_stack_trace

        # Records call outputs in logs (e.g. for __torch_dispatch__, __torch_function__, redistribute_input)
        self.record_output: bool = record_output

        # Annotates string dumps with graph-style tensor ids, e.g. op($1, $2) -> $3.
        self.record_ids: bool = record_ids

        # Annotates string dumps with profiler.record_function contexts from runtime code.
        # Currently does not preserve contexts inside torch.compile-d regions.
        self.record_profiler_context: bool = record_profiler_context

        # For aot_eager compiled regions, wraps the compiled fx.GraphModule with a DebugInterpreter,
        # and uses it at runtime for node metadata visibility.
        self.run_compile_with_interpreter: bool = run_compile_with_interpreter

        # Initialize all per-run recording state.
        self.reset()
  661. def reset(self) -> None:
  662. self.operators = []
  663. self.call_depth = 0
  664. self._tensor_memo = TensorIdTracker()
  665. self._output_info: dict[int, object] = {}
  666. self.ignored_record_functions = 0
  667. self.current_nn_module_stack = []
  668. self.fx_stack_trace = None
  669. def _track_op_output(self, op_index, result) -> None:
  670. """Assign IDs to output tensors and store in output_info"""
  671. self._output_info[op_index] = result
  672. # Without this override, running torch.compile under DebugMode
  673. # will force torch.compile to always use the “eager” backend
  674. # With this, DebugMode will not take effect on torch.compile
  675. @classmethod
  676. def ignore_compile_internals(cls) -> bool:
  677. return True
  678. def _record_call(self, call) -> None:
  679. global _IN_INDUCTOR_BENCHMARK
  680. if _IN_INDUCTOR_BENCHMARK:
  681. return
  682. if str(call).startswith("profiler::_record_function"):
  683. return
  684. if not self.store_original_args:
  685. call.stringify_args(
  686. self.record_tensor_attributes,
  687. self._tensor_memo if self.record_ids else None,
  688. )
  689. if self.fx_stack_trace:
  690. call.stack_trace = call.fwd_stack_trace = self.fx_stack_trace
  691. self.operators.append(call)
  692. def _record_call_output(self, call, output) -> None:
  693. if not self.record_output:
  694. return
  695. call.stringify_output(
  696. output,
  697. self.record_tensor_attributes,
  698. self._tensor_memo if self.record_ids else None,
  699. )
  700. def __torch_function__(self, func, types, args=(), kwargs=None):
  701. if kwargs is None:
  702. kwargs = {}
  703. call = _OpCall(
  704. func, args, kwargs, self.call_depth, stack=self.record_stack_trace
  705. )
  706. self._record_call(call)
  707. try:
  708. self.call_depth += 1
  709. result = func(*args, **kwargs)
  710. self._record_call_output(call, result)
  711. return result
  712. finally:
  713. self.call_depth -= 1
  714. def _maybe_record_function(self, tag):
  715. # filter out tags that appear noisy, or aren't runtime-related
  716. if any(
  717. tag.startswith(prefix)
  718. for prefix in [
  719. # assuming these are from benchmarking, not the actual runtime call
  720. "CachingAutotuner.",
  721. "InductorBenchmarker.",
  722. # inductor compilation
  723. "compile_fx.<locals>.",
  724. ]
  725. ):
  726. self.ignored_record_functions += 1
  727. return
  728. call = _AnnotateCall(
  729. tag, "record function", self.call_depth, stack=self.record_stack_trace
  730. )
  731. self.operators.append(call)
  732. self.call_depth += 1
  733. def _maybe_exit_record_function(self):
  734. if self.ignored_record_functions < 0:
  735. raise AssertionError(
  736. f"ignored_record_functions is negative: {self.ignored_record_functions}"
  737. )
  738. if self.ignored_record_functions > 0:
  739. self.ignored_record_functions -= 1
  740. else:
  741. self.call_depth -= 1
  742. def __torch_dispatch__(self, func, types, args=(), kwargs=None):
  743. if kwargs is None:
  744. kwargs = {}
  745. # Handle record_function entries
  746. if self.record_profiler_context:
  747. if func == torch.ops.profiler._record_function_enter_new.default:
  748. if len(args) != 1:
  749. raise AssertionError(f"expected 1 arg, got {len(args)}")
  750. self._maybe_record_function(args[0])
  751. elif func == torch.ops.profiler._record_function_exit._RecordFunction:
  752. self._maybe_exit_record_function()
  753. # Handle DebugMode._annotate()
  754. if func is torch.ops.debug_mode_ops.annotate.default:
  755. if len(args) != 1:
  756. raise AssertionError(f"expected 1 arg, got {len(args)}")
  757. self._handle_annotate(args[0])
  758. return
  759. from torch.distributed._local_tensor import LocalTensor
  760. # Record the operation with its call depth
  761. call = None
  762. if torch.distributed.tensor.DTensor in types:
  763. call = _OpCall(
  764. func, args, kwargs, self.call_depth, stack=self.record_stack_trace
  765. )
  766. self._record_call(call)
  767. return NotImplemented
  768. elif FakeTensor in types or isinstance(
  769. _get_current_dispatch_mode(), FakeTensorMode
  770. ):
  771. if self.record_faketensor:
  772. if func != torch.ops.prim.device.default:
  773. call = _OpCall(
  774. func,
  775. args,
  776. kwargs,
  777. self.call_depth + 1,
  778. stack=self.record_stack_trace,
  779. )
  780. self._record_call(call)
  781. # TODO: check the context manager
  782. elif LocalTensor in types:
  783. if self.record_localtensor:
  784. call = _OpCall(
  785. func,
  786. args,
  787. kwargs,
  788. self.call_depth + 1,
  789. stack=self.record_stack_trace,
  790. )
  791. self._record_call(call)
  792. elif len(types) == 0:
  793. if self.record_realtensor:
  794. call = _OpCall(
  795. func,
  796. args,
  797. kwargs,
  798. self.call_depth + 1,
  799. stack=self.record_stack_trace,
  800. )
  801. self._record_call(call)
  802. # Run pre-hooks before executing the operation to hash inputs
  803. # We have to run becore the func() call in case there's any
  804. # in-place mutation
  805. if call:
  806. _run_dispatch_pre_log_hooks(call, func, types, args, kwargs)
  807. result = func(*args, **kwargs)
  808. if call:
  809. self._record_call_output(call, result)
  810. _run_dispatch_hooks(call, func, types, args, kwargs, result)
  811. return result
  812. def __enter__(self):
  813. global _ACTIVE_DEBUG_MODE_COUNT
  814. _ACTIVE_DEBUG_MODE_COUNT += 1
  815. if self.record_torchfunction:
  816. torch._C._push_on_torch_function_stack(self)
  817. super().__enter__()
  818. if self.record_nn_module:
  819. self.module_tracker.__enter__() # type: ignore[attribute, union-attr]
  820. if self.record_stack_trace:
  821. self.anomaly_for_traces = torch.autograd.set_detect_anomaly(
  822. True, check_nan=False
  823. )
  824. self.anomaly_for_traces.__enter__()
  825. return self
  826. # pyrefly: ignore [bad-override]
  827. def __exit__(self, *args):
  828. global _ACTIVE_DEBUG_MODE_COUNT
  829. _ACTIVE_DEBUG_MODE_COUNT -= 1
  830. super().__exit__(*args)
  831. if self.record_nn_module:
  832. self.module_tracker.__exit__() # type: ignore[attribute, union-attr]
  833. if self.record_torchfunction:
  834. torch._C._pop_torch_function_stack()
  835. if self.record_stack_trace:
  836. self.anomaly_for_traces.__exit__(*args)
  837. @contextlib.contextmanager
  838. def set_fx_stack_trace(self, stack_trace):
  839. self.fx_stack_trace = stack_trace
  840. try:
  841. yield
  842. finally:
  843. self.fx_stack_trace = None
  844. def _enter_nn_module_call(self, fqn, header):
  845. call = _AnnotateCall(
  846. fqn, header, self.call_depth + 1, stack=self.record_stack_trace
  847. )
  848. self.operators.append(call)
  849. self.current_nn_module_stack.append(fqn)
  850. self.call_depth += 1
  851. def _exit_nn_module_call(self):
  852. self.call_depth -= 1
  853. self.current_nn_module_stack.pop()
  854. def module_tracker_setup(self) -> None:
  855. from torch.distributed._tools.mod_tracker import ModTracker
  856. self.module_tracker = ModTracker()
  857. # module pre-fw hook: record module call
  858. def pre_fw_hook(module, input) -> None:
  859. fqn = self.module_tracker._get_mod_name(module) # type: ignore[attribute, union-attr]
  860. self._enter_nn_module_call(fqn, "nn.Mod")
  861. # module post-fw hook: decrement call depth
  862. def post_fw_hook(module, input, output) -> None:
  863. self._exit_nn_module_call()
  864. self.module_tracker.register_user_hooks(pre_fw_hook, post_fw_hook)
  865. def _handle_fx_nn_module_stack(
  866. self,
  867. base_stack: list[str],
  868. nn_module_stack: dict[str, tuple[str, Any]] | None,
  869. fwd_nn_module_stack: dict[str, tuple[str, Any]] | None,
  870. ) -> None:
  871. """
  872. Called when DebugInterpreter observes nn_module_stack or fwd_nn_module_stack metadata
  873. from executing the compiled GraphModule.
  874. If the current module stack is mismatched with what's currently tracked in DebugMode
  875. (current_nn_module_stack), we adjust call depth and add new [nn.Module] log entries accordingly.
  876. """
  877. nn_module_stack = nn_module_stack or {}
  878. fwd_nn_module_stack = fwd_nn_module_stack or {}
  879. if nn_module_stack and fwd_nn_module_stack:
  880. raise AssertionError(
  881. "Expecting at most one of nn_module_stack and fwd_nn_module_stack."
  882. )
  883. is_fwd = nn_module_stack
  884. stack = nn_module_stack if is_fwd else fwd_nn_module_stack
  885. # forward stack
  886. current_stack = self.current_nn_module_stack
  887. new_stack = base_stack + [v[0] for v in stack.values()]
  888. entered = set(new_stack) - set(current_stack)
  889. exited = set(current_stack) - set(new_stack)
  890. # Decrement depth for exited modules
  891. for _ in exited:
  892. self._exit_nn_module_call()
  893. if self.call_depth < 0:
  894. raise AssertionError("Unexpectedly, DebugMode call_depth is negative")
  895. # Add [nn.Module] entries for newly entered modules
  896. for fqn in sorted(entered):
  897. self._enter_nn_module_call(
  898. fqn, "nn.Mod (compile)" if is_fwd else "nn.Mod (compile bwd)"
  899. )
  900. self.current_nn_module_stack = new_stack
  901. @contextlib.contextmanager
  902. def record_redistribute_calls(
  903. self,
  904. arg,
  905. src_placement,
  906. dst_placement,
  907. transform_info_str: str | None = None,
  908. is_explicit: bool = False,
  909. ):
  910. try:
  911. self._record_call(
  912. _RedistributeCall(
  913. arg,
  914. src_placement=src_placement,
  915. dst_placement=dst_placement,
  916. transform_info_str=transform_info_str,
  917. call_depth=self.call_depth + 1,
  918. stack=self.record_stack_trace,
  919. is_explicit=is_explicit,
  920. )
  921. )
  922. self.call_depth += 1
  923. yield
  924. finally:
  925. self.call_depth -= 1
  926. def record_output_placements(self, output_spec) -> None:
  927. """Record output placements for a DTensor op as a separate line."""
  928. if not self.record_output:
  929. return
  930. from torch.distributed.tensor._dtensor_spec import DTensorSpec
  931. placements_str = str(
  932. tree_map_only(DTensorSpec, _stringify_dtensor_spec, output_spec)
  933. )
  934. call = _OutputPlacementCall(placements_str, self.call_depth + 1)
  935. self._record_call(call)
  936. def record_triton_kernel(
  937. self, kernel_name: str, kwargs: dict[str, Any]
  938. ) -> _TritonKernelCall:
  939. call = _TritonKernelCall(kernel_name, kwargs, self.call_depth + 1)
  940. call.stringify_args(self.record_tensor_attributes)
  941. self.operators.append(call)
  942. return call
  943. def debug_string(self, show_stack_trace: bool | None = None) -> str:
  944. """
  945. show_stack_trace: option to display one-line stack trace summaries above groups
  946. of operations (similar to gm.print_readable() style).
  947. Requires record_stack_trace=True.
  948. if None, uses self.record_stack_trace, otherwise overrides it.
  949. """
  950. show_stack_trace = (
  951. self.record_stack_trace if show_stack_trace is None else show_stack_trace
  952. )
  953. with torch._C.DisableTorchFunction():
  954. if not show_stack_trace:
  955. result = "\n".join(
  956. " "
  957. + " " * op.call_depth
  958. + op.render(self.record_tensor_attributes)
  959. for op in self.operators
  960. )
  961. return result
  962. # Group operations by stack trace
  963. lines = []
  964. prev_stack_summary = None
  965. for op in self.operators:
  966. # Get the stack trace: prefer fwd_stack_trace, fallback to stack_trace
  967. stack_trace = None
  968. if hasattr(op, "fwd_stack_trace") and op.fwd_stack_trace:
  969. stack_trace = op.fwd_stack_trace
  970. elif hasattr(op, "stack_trace") and op.stack_trace:
  971. stack_trace = op.stack_trace
  972. stack_summary = None
  973. if stack_trace:
  974. stack_summary = _get_user_stack_trace(stack_trace)
  975. if stack_summary and stack_summary != prev_stack_summary:
  976. # add blank line before stack trace comment for readability
  977. if lines: # don't add blank line at the very start
  978. lines.append("")
  979. indent = " " * (op.call_depth + 1)
  980. lines.append(indent + "# " + stack_summary)
  981. prev_stack_summary = stack_summary
  982. # Add the operation line
  983. line = (
  984. " "
  985. + " " * op.call_depth
  986. + op.render(self.record_tensor_attributes)
  987. )
  988. lines.append(line)
  989. return "\n".join(lines)
  990. @staticmethod
  991. @contextlib.contextmanager
  992. def dispatch_hooks(
  993. record_hook: Callable | None = None,
  994. log_hook: Callable | None = None,
  995. pre_log_hook: Callable | None = None,
  996. ):
  997. """
  998. Allows installing post-hooks on arguments to intercepted __torch_dispatch__ calls;
  999. hook signatures are expected as (func, types, args, kwargs, result),
  1000. i.e. __torch_dispatch__ args + return value.
  1001. Logging hook outputs are stored in call.log and annotate calls in debug_string(),
  1002. while recording hook outputs are just stored in call.record.
  1003. For now hooks are expected to return dictionaries.
  1004. pre_log_hook signature is (func, types, args, kwargs, call) and is executed before
  1005. the operation. It allows capturing state before in-place mutations.
  1006. """
  1007. global _DISPATCH_RECORD_HOOKS, _DISPATCH_LOG_HOOKS, _DISPATCH_PRE_LOG_HOOKS
  1008. if record_hook:
  1009. _DISPATCH_RECORD_HOOKS.append(record_hook)
  1010. if log_hook:
  1011. _DISPATCH_LOG_HOOKS.append(log_hook)
  1012. if pre_log_hook:
  1013. _DISPATCH_PRE_LOG_HOOKS.append(pre_log_hook)
  1014. try:
  1015. yield
  1016. finally:
  1017. if record_hook:
  1018. _DISPATCH_RECORD_HOOKS.pop()
  1019. if log_hook:
  1020. _DISPATCH_LOG_HOOKS.pop()
  1021. if pre_log_hook:
  1022. _DISPATCH_PRE_LOG_HOOKS.pop()
  1023. @staticmethod
  1024. @contextlib.contextmanager
  1025. def record_outputs():
  1026. """
  1027. Hook for storing cloned output tensors in .record["output"].
  1028. """
  1029. def dispatch_hook(func, types, args, kwargs, result):
  1030. out = tree_map(
  1031. lambda x: x.clone() if isinstance(x, torch.Tensor) else x, result
  1032. )
  1033. return {"output": out}
  1034. global _RECORD_TRITON_OUTPUTS
  1035. try:
  1036. _old_record_triton = _RECORD_TRITON_OUTPUTS
  1037. _RECORD_TRITON_OUTPUTS = True
  1038. with DebugMode.dispatch_hooks(record_hook=dispatch_hook):
  1039. yield
  1040. finally:
  1041. _RECORD_TRITON_OUTPUTS = _old_record_triton
  1042. @staticmethod
  1043. @contextlib.contextmanager
  1044. def log_tensor_hashes(
  1045. hash_fn: Callable | str | list[str] = "norm", hash_inputs: bool = False
  1046. ):
  1047. """
  1048. Installs hook for tensor hash logging.
  1049. hash_fn: One of:
  1050. - Custom-defined hash function
  1051. - String: one of ("norm", "hash_tensor")
  1052. - "norm": uses norm_hash_fn; basically tensor's L1 norm
  1053. - "hash_tensor": uses torch.hash_tensor (XOR sum reduction)
  1054. - List of strings: returns tuple of hashes from above options
  1055. hash_inputs: if True, also hashes tensors in (args, kwargs), storing them in "input_hash".
  1056. Input hashes are captured before the operation executes, so they reflect the state before
  1057. any in-place mutations.
  1058. """
  1059. def hash_fn_option(hash_type):
  1060. if not isinstance(hash_type, str) or hash_type not in [
  1061. "norm",
  1062. "hash_tensor",
  1063. ]:
  1064. raise AssertionError(
  1065. f"hash_type must be 'norm' or 'hash_tensor', got {hash_type!r}"
  1066. )
  1067. return functools.partial(
  1068. norm_hash_fn if hash_type == "norm" else hash_tensor_fn, use_scalar=True
  1069. )
  1070. if callable(hash_fn):
  1071. fn = hash_fn
  1072. elif isinstance(hash_fn, str):
  1073. fn = hash_fn_option(hash_fn)
  1074. elif isinstance(hash_fn, list):
  1075. fns = [hash_fn_option(fn) for fn in hash_fn]
  1076. fn = lambda x: tuple(fn(x) for fn in fns) # noqa: E731
  1077. else:
  1078. raise NotImplementedError(
  1079. f"log_tensor_hashes() expected hash_fn to be callable, str, or list[str], but found {type(hash_fn)}"
  1080. )
  1081. def _tree_hash(obj):
  1082. return tree_map(
  1083. lambda x: fn(x) if isinstance(x, torch.Tensor) else None, obj
  1084. )
  1085. def _dispatch_pre_log_hook(func, types, args, kwargs, call):
  1086. """Pre-hook to capture input hashes before operation executes"""
  1087. if "empty" in str(func) or "profiler" in str(func):
  1088. return None
  1089. if hash_inputs:
  1090. # Capture input hashes before the operation
  1091. input_hash = _tree_hash((args, kwargs))
  1092. if not tree_all(lambda x: x is None, input_hash):
  1093. return {"input_hash": input_hash}
  1094. return None
  1095. def _dispatch_post_hook(func, types, args, kwargs, result):
  1096. """Post-hook to capture output hashes after operation executes"""
  1097. if "empty" in str(func) or "profiler" in str(func):
  1098. return None
  1099. out = {}
  1100. out["hash"] = _tree_hash(result)
  1101. if tree_all(lambda x: x is None, out.values()):
  1102. return None
  1103. return out
  1104. global _TRITON_INPUT_HASH_FN, _TRITON_OUTPUT_HASH_FN
  1105. try:
  1106. if hash_inputs:
  1107. _old_input_hfn = _TRITON_INPUT_HASH_FN
  1108. _TRITON_INPUT_HASH_FN = fn
  1109. _old_output_hfn = _TRITON_OUTPUT_HASH_FN
  1110. _TRITON_OUTPUT_HASH_FN = fn
  1111. with DebugMode.dispatch_hooks(
  1112. log_hook=_dispatch_post_hook,
  1113. pre_log_hook=_dispatch_pre_log_hook if hash_inputs else None,
  1114. ):
  1115. yield
  1116. finally:
  1117. if hash_inputs:
  1118. _TRITON_INPUT_HASH_FN = _old_input_hfn # type: ignore[assignment]
  1119. _TRITON_OUTPUT_HASH_FN = _old_output_hfn
  1120. @staticmethod
  1121. @contextlib.contextmanager
  1122. def _benchmarking_inductor():
  1123. """
  1124. Context manager for disabling logging during inductor benchmarking,
  1125. so logs don't contain all kernels launched from autotuning.
  1126. """
  1127. global _IN_INDUCTOR_BENCHMARK
  1128. try:
  1129. _IN_INDUCTOR_BENCHMARK = True
  1130. yield
  1131. finally:
  1132. _IN_INDUCTOR_BENCHMARK = False
  1133. @property
  1134. def logs(self):
  1135. return list(self.operators)
  1136. def _handle_annotate(self, tag):
  1137. """Handles DebugMode._annotate()"""
  1138. call = _AnnotateCall(tag, "annotate", self.call_depth, self.record_stack_trace)
  1139. self.operators.append(call)
  1140. @staticmethod
  1141. def _annotate(tag: Any) -> None:
  1142. """
  1143. If an active DebugMode exists, adds an "[annotate] <tag>" entry to the logs. Useful for contextualizing logs.
  1144. Implemented with a custom op.
  1145. """
  1146. torch.ops.debug_mode_ops.annotate(tag)
  1147. @staticmethod
  1148. def check_hash_mismatches(
  1149. logs1: list, logs2: list, compare_inputs: bool = False
  1150. ) -> list[dict]:
  1151. """
  1152. Compares tensor hashes between two DebugMode runs, for checking run-to-run numerical divergence.
  1153. This first validates the two log sequences have identical structure (same operations, input shapes/dtypes, etc.),
  1154. then compares tensor hash values, and returns a list of call outputs where mismatches were found.
  1155. Expects input logs to have been run with log_tensor_hashes, and looks for hashes in .log["hash"] & .log["input_hash"]
  1156. (or .post_hashes & .pre_hashes for triton kernels).
  1157. note: skips checking log pairs where hashes aren't present, but will raise if present in one & not the other.
  1158. Args:
  1159. logs1: logs from the first DebugMode run (from debug_mode.logs)
  1160. logs2: logs from the second DebugMode run
  1161. compare_inputs: If True, also compare input tensor hashes (default: only output checking)
  1162. Returns:
  1163. List of dictionaries describing hash mismatches. Each dict contains:
  1164. - call_type: "torch op" or "triton kernel"
  1165. - call: Operator/kernel name
  1166. - arg_name: For triton kernels, the argument name; None for torch ops
  1167. - pytree_path: For torch ops, the pytree path to the differing tensor; None for kernels
  1168. - hash1: Hash value from the first run
  1169. - hash2: Hash value from the second run
  1170. - rel_diff: Relative difference between hash values
  1171. - is_input_hash: True if this is an input hash, False for output hash
  1172. Raises:
  1173. ValueError: If logs have different lengths, call types, operator names, or call depths
  1174. Usage::
  1175. # Run model first time
  1176. with DebugMode() as debug_mode, DebugMode.log_tensor_hashes():
  1177. model(x)
  1178. logs1 = debug_mode.logs
  1179. # Run again, in exactly the same way
  1180. with DebugMode() as debug_mode, DebugMode.log_tensor_hashes():
  1181. model(x)
  1182. logs2 = debug_mode.logs
  1183. mismatches = DebugMode.check_hash_mismatches(logs1, logs2)
  1184. for m in mismatches:
  1185. print(f"{m['call']}: hash diff {m['rel_diff']:.2e}")
  1186. """
  1187. if len(logs1) != len(logs2):
  1188. raise ValueError(f"Log lengths don't match: {len(logs1)} vs {len(logs2)}")
  1189. difference_info = []
  1190. for i, (log1, log2) in enumerate(zip(logs1, logs2)):
  1191. # check call type
  1192. call1_type = type(log1).__name__
  1193. call2_type = type(log2).__name__
  1194. if call1_type != call2_type:
  1195. raise ValueError(
  1196. f"Call types don't match at index {i}: {call1_type} vs {call2_type}"
  1197. )
  1198. call_type = call1_type
  1199. # check call name
  1200. op1_name, op2_name = _get_call_name(log1), _get_call_name(log2)
  1201. if op1_name != op2_name:
  1202. raise ValueError(
  1203. f"Operators don't match at index {i}: {call_type}[{op1_name}] vs {call_type}[{op2_name}]"
  1204. )
  1205. op_name = op1_name
  1206. # check call depth
  1207. if log1.call_depth != log2.call_depth:
  1208. raise ValueError(
  1209. f"Call depths for {call_type}[{op_name}] don't match at index {i}: {log1.call_depth} vs {log2.call_depth}"
  1210. )
  1211. # Redistribute: call args should be the same
  1212. if isinstance(log1, _RedistributeCall):
  1213. if tuple(log1) != tuple(log2):
  1214. raise ValueError(
  1215. f"Redistribute calls don't match at index {i}: {log1} vs {log2}"
  1216. )
  1217. # Triton kernel: same arg names, arg types
  1218. elif isinstance(log1, _TritonKernelCall):
  1219. if log1.kwargs_str != log2.kwargs_str:
  1220. raise ValueError(
  1221. f"Triton kernel call args don't match for {log1.kernel_name} at index {i}:"
  1222. f"\n\nlog1: {log1.kwargs_str}\n\nlog2: {log2.kwargs_str}"
  1223. )
  1224. def compare_triton_hashes(hashes1, hashes2, is_input):
  1225. if set(hashes1.keys()) != set(hashes2.keys()): # type: ignore[union-attr]
  1226. raise AssertionError(
  1227. f"hash key mismatch: {set(hashes1.keys())} vs {set(hashes2.keys())}"
  1228. )
  1229. for key in hashes1:
  1230. if hashes1[key] != hashes2[key]:
  1231. difference_info.append(
  1232. {
  1233. "call_type": "triton kernel",
  1234. "call": op_name,
  1235. "arg_name": key,
  1236. "pytree_path": None,
  1237. "hash1": hashes1[key],
  1238. "hash2": hashes2[key],
  1239. "rel_diff": _compute_rel_diff(
  1240. hashes1[key], hashes2[key]
  1241. ),
  1242. "is_input_hash": is_input,
  1243. }
  1244. )
  1245. # check output hashes
  1246. has_post_1, has_post_2 = (
  1247. log1.post_hashes is not None,
  1248. log2.post_hashes is not None,
  1249. )
  1250. if has_post_1 != has_post_2:
  1251. raise ValueError(
  1252. f"Triton kernel post-hash presence inconsistent for {log1.kernel_name} "
  1253. f"at index {i}: log1 has post_hashes={has_post_1}, log2 has post_hashes={has_post_2}"
  1254. )
  1255. if has_post_1:
  1256. compare_triton_hashes(
  1257. log1.post_hashes, log2.post_hashes, is_input=False
  1258. )
  1259. # maybe check input hashes
  1260. if compare_inputs:
  1261. has_pre_1, has_pre_2 = (
  1262. log1.pre_hashes is not None,
  1263. log2.pre_hashes is not None,
  1264. )
  1265. if has_pre_1 != has_pre_2:
  1266. raise ValueError(
  1267. f"Triton kernel pre-hash presence inconsistent for {log1.kernel_name} "
  1268. f"at index {i}: log1 has pre_hashes={has_pre_1}, log2 has pre_hashes={has_pre_2}"
  1269. )
  1270. if has_pre_1:
  1271. compare_triton_hashes(
  1272. log1.pre_hashes, log2.pre_hashes, is_input=True
  1273. )
  1274. # regular log calls
  1275. elif isinstance(log1, _OpCall):
  1276. def compare_op_hashes(hashes1, hashes2, is_input):
  1277. def _helper(keypath, hash1, hash2):
  1278. if hash1 != hash2:
  1279. difference_info.append(
  1280. {
  1281. "call_type": "torch op",
  1282. "call": op_name,
  1283. "arg_name": None,
  1284. "pytree_path": keystr(keypath),
  1285. "hash1": hash1,
  1286. "hash2": hash2,
  1287. "rel_diff": _compute_rel_diff(hash1, hash2),
  1288. "is_input_hash": is_input,
  1289. }
  1290. )
  1291. tree_map_with_path(_helper, hashes1, hashes2)
  1292. # check output hashes
  1293. has_hash1 = log1.log is not None and "hash" in log1.log
  1294. has_hash2 = log2.log is not None and "hash" in log2.log
  1295. if has_hash1 != has_hash2:
  1296. raise ValueError(
  1297. f"Output hash presence inconsistent for triton kernel {call_type}[{op_name}] "
  1298. f"at index {i}: log1 has hash={has_hash1}, log2 has hash={has_hash2}"
  1299. )
  1300. if has_hash1:
  1301. compare_op_hashes(
  1302. log1.log["hash"], # type: ignore[union-attr]
  1303. log2.log["hash"],
  1304. is_input=False,
  1305. )
  1306. # maybe check input hashes
  1307. if compare_inputs:
  1308. has_hash1 = log1.log is not None and "input_hash" in log1.log
  1309. has_hash2 = log2.log is not None and "input_hash" in log2.log
  1310. if has_hash1 != has_hash2:
  1311. raise ValueError(
  1312. f"Input hash presence inconsistent for triton kernel {call_type}[{op_name}] "
  1313. f"at index {i}: log1 has input_hash={has_hash1}, log2 has input_hash={has_hash2}"
  1314. )
  1315. if has_hash1:
  1316. compare_op_hashes(
  1317. log1.log["input_hash"], # type: ignore[union-attr]
  1318. log2.log["input_hash"],
  1319. is_input=True,
  1320. )
  1321. return difference_info
  1322. def get_active_debug_mode() -> DebugMode | None:
  1323. # Fast path: if no DebugMode is active, skip the stack walk
  1324. if _ACTIVE_DEBUG_MODE_COUNT == 0:
  1325. return None
  1326. debug_mode = None
  1327. for mode in _get_current_dispatch_mode_stack():
  1328. if isinstance(mode, DebugMode):
  1329. debug_mode = mode
  1330. break
  1331. return debug_mode