traceback.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. # mypy: allow-untyped-defs
  2. import copy
  3. import logging
  4. import traceback
  5. from collections import defaultdict
  6. from contextlib import contextmanager
  7. from enum import Enum
  8. from typing import Any, Optional, Union
  9. from torch._utils_internal import signpost_event
  10. from ._compatibility import compatibility
  11. from .graph import Graph
  12. from .graph_module import GraphModule
  13. from .node import Node
  14. log = logging.getLogger(__name__)
  15. __all__ = [
  16. "annotate",
  17. "annotate_fn",
  18. "preserve_node_meta",
  19. "has_preserved_node_meta",
  20. "set_stack_trace",
  21. "set_grad_fn_seq_nr",
  22. "reset_grad_fn_seq_nr",
  23. "format_stack",
  24. "set_current_meta",
  25. "get_current_meta",
  26. "NodeSource",
  27. "NodeSourceAction",
  28. "get_graph_provenance_json",
  29. "set_current_replay_node",
  30. "get_current_replay_node",
  31. ]
  32. current_meta: dict[str, Any] = {}
  33. current_replay_node: Optional[Node] = None
  34. # Preserve the node meta fields in torch.fx.proxy._COPY_META_FIELDS
  35. should_preserve_node_meta = False
  36. # Preserve the "seq_nr" node meta field
  37. _should_preserve_node_meta = False
  38. GRADIENT_ACC_SPECIAL_STACK = (
  39. "Gradient addition node due to multiple use of tensor around:"
  40. )
  41. # =============================================================================
  42. # FX Metadata Registry for Memory Profiler
  43. # =============================================================================
  44. # Global in-memory registry for FX metadata
  45. # Maps module_name -> metadata dict containing lineno_map and node_metadata
  46. _FX_METADATA_REGISTRY: dict[str, dict[str, Any]] = {}
  47. def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:
  48. """
  49. Register FX metadata in the global in-memory registry.
  50. This is called automatically during graph module compilation to store metadata
  51. for later use by memory profiler augmentation.
  52. Args:
  53. module_name: The module identifier (content-addressed filename)
  54. metadata: Metadata dict containing lineno_map, node_metadata, and source_code
  55. """
  56. # TODO: add logging to tlparse
  57. _FX_METADATA_REGISTRY[module_name] = metadata
  58. @compatibility(is_backward_compatible=False)
  59. class NodeSourceAction(Enum):
  60. CREATE = "create"
  61. REPLACE = "replace"
  62. @compatibility(is_backward_compatible=False)
  63. class NodeSource:
  64. """
  65. NodeSource is a data structure that contains the provenance information of a node.
  66. If node `a` is created from node `b`, then `a.meta["from_node"]` may contain NodeSource(b).
  67. """
  68. class NodeInfo:
  69. def __init__(self, name: str, target: str, graph_id: int):
  70. self.name = name
  71. self.target = target
  72. self.graph_id = graph_id
  73. pass_name: str
  74. action: list["NodeSourceAction"]
  75. from_node: list["NodeSource"]
  76. node_info: Optional["NodeInfo"]
  77. _dict: Optional[dict[str, Any]]
  78. _action_string: Optional[str]
  79. def __init__(
  80. self,
  81. node: Optional[Node],
  82. pass_name: str = "",
  83. action: Optional[Union["NodeSourceAction", list["NodeSourceAction"]]] = None,
  84. ):
  85. self.pass_name = pass_name
  86. if action is None:
  87. action = []
  88. elif not isinstance(action, list):
  89. action = [action]
  90. for a in action:
  91. if not isinstance(a, NodeSourceAction):
  92. raise AssertionError(f"Expected NodeSourceAction, got {type(a)}")
  93. self.action = action
  94. if node:
  95. self.node_info = self.NodeInfo(
  96. name=node.name, target=str(node.target), graph_id=id(node.graph)
  97. )
  98. self.from_node = (
  99. copy.deepcopy(node.meta["from_node"])
  100. if "from_node" in node.meta
  101. else []
  102. )
  103. else:
  104. self.node_info = None
  105. self.from_node = []
  106. # cache the action string and dict representation for performance.
  107. self._action_string: Optional[str] = None
  108. self._dict: Optional[dict[str, Any]] = None
  109. @property
  110. def name(self) -> str:
  111. return self.node_info.name if self.node_info else ""
  112. @property
  113. def target(self) -> str:
  114. return self.node_info.target if self.node_info else ""
  115. @property
  116. def graph_id(self) -> int:
  117. return self.node_info.graph_id if self.node_info else -1
  118. def __repr__(self):
  119. return self.print_readable()
  120. def _get_action_string(self):
  121. if self._action_string is None:
  122. self._action_string = "+".join([a.name.lower() for a in self.action])
  123. return self._action_string
  124. def print_readable(self, indent=0):
  125. if indent > 9:
  126. return ""
  127. result = ""
  128. action_string = self._get_action_string()
  129. result += (
  130. " " * indent * 4
  131. + f"(name={self.name}, pass_name={self.pass_name}, action={action_string}, graph_id={self.graph_id})\n"
  132. )
  133. for item in self.from_node:
  134. result += item.print_readable(indent + 1)
  135. return result
  136. def to_dict(self) -> dict:
  137. if self._dict is None:
  138. # Convert the object to a dictionary
  139. action_string = self._get_action_string()
  140. self._dict = {
  141. "name": self.name,
  142. "target": self.target,
  143. "graph_id": self.graph_id,
  144. "pass_name": self.pass_name,
  145. "action": action_string,
  146. "from_node": [node.to_dict() for node in self.from_node],
  147. }
  148. if self._dict is None:
  149. raise AssertionError("_dict is None after initialization")
  150. return self._dict
  151. def __eq__(self, other: object):
  152. if not isinstance(other, NodeSource):
  153. return False
  154. return self.to_dict() == other.to_dict()
  155. def __hash__(self):
  156. # Create a hash based on the dictionary representation
  157. # We need to convert the dict to a hashable form
  158. def _make_hashable(obj):
  159. if isinstance(obj, dict):
  160. return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items()))
  161. elif isinstance(obj, list):
  162. return tuple(_make_hashable(item) for item in obj)
  163. else:
  164. return obj
  165. return hash(_make_hashable(self.to_dict()))
  166. @classmethod
  167. def _from_dict(cls, d: Optional[dict]) -> Optional["NodeSource"]:
  168. """
  169. Recursively deserialize from_node metadata from dictionary data.
  170. It is used to deserialize the from_node field from serialized metadata.
  171. Please use constructor NodeSource(node, ...) to create a NodeSource object.
  172. """
  173. if d is None:
  174. return None
  175. if not isinstance(d, dict):
  176. raise AssertionError(f"Expected a dict, got {type(d)}")
  177. # Create a NodeSource object directly without going through the constructor
  178. # to avoid issues with graph ID and node creation
  179. node_source = NodeSource.__new__(NodeSource)
  180. # Reset the cached properties
  181. node_source._action_string = None
  182. node_source._dict = None
  183. # Set the basic attributes
  184. node_source.pass_name = d.get("pass_name", "")
  185. # Parse action string back to NodeSourceAction enum list
  186. action_str = d.get("action", "")
  187. actions = []
  188. if action_str:
  189. for action_name in action_str.split("+"):
  190. if action_name.upper() == "CREATE":
  191. actions.append(NodeSourceAction.CREATE)
  192. elif action_name.upper() == "REPLACE":
  193. actions.append(NodeSourceAction.REPLACE)
  194. node_source.action = actions
  195. # Create the NodeInfo object directly
  196. if "name" in d and "target" in d and "graph_id" in d:
  197. node_info = NodeSource.NodeInfo(
  198. d.get("name", ""), d.get("target", ""), d.get("graph_id", -1)
  199. )
  200. node_source.node_info = node_info
  201. else:
  202. node_source.node_info = None
  203. # Recursively deserialize nested from_node
  204. if d.get("from_node", None) is not None:
  205. node_source.from_node = [
  206. result
  207. for fn in d.get("from_node", [])
  208. if (result := cls._from_dict(fn)) is not None
  209. ]
  210. else:
  211. node_source.from_node = []
  212. return node_source
  213. @compatibility(is_backward_compatible=False)
  214. @contextmanager
  215. def preserve_node_meta(enable=True):
  216. global should_preserve_node_meta
  217. global current_meta
  218. saved_should_preserve_node_meta = should_preserve_node_meta
  219. # Shallow copy is OK since fields of current_meta are not mutated
  220. saved_current_meta = current_meta.copy()
  221. try:
  222. should_preserve_node_meta = enable
  223. yield
  224. finally:
  225. should_preserve_node_meta = saved_should_preserve_node_meta
  226. current_meta = saved_current_meta
  227. @contextmanager
  228. def _preserve_node_seq_nr(preserve_seq_nr=True):
  229. """
  230. Temporarily enables or disables the preservation of node.meta["seq_nr"] in the
  231. tracing context.
  232. """
  233. global _should_preserve_node_meta
  234. saved = _should_preserve_node_meta
  235. try:
  236. _should_preserve_node_meta = preserve_seq_nr
  237. yield
  238. finally:
  239. _should_preserve_node_meta = saved
  240. @compatibility(is_backward_compatible=False)
  241. def set_stack_trace(stack: list[str]):
  242. global current_meta
  243. if should_preserve_node_meta and stack:
  244. current_meta["stack_trace"] = "".join(stack)
  245. @compatibility(is_backward_compatible=False)
  246. @contextmanager
  247. def annotate(annotation_dict: dict):
  248. """
  249. Temporarily adds custom annotations to the current tracing context.
  250. The fx_node produced from this tracing context will have the
  251. custom annotations in node.metadata["custom"] field.
  252. This context manager allows you to insert arbitrary metadata into the PT2
  253. tracing system by updating the global `current_meta["custom"]` dictionary.
  254. The annotations are automatically reverted after the context exits.
  255. Gradient accumulation nodes will not be annotated.
  256. This is intended for advanced users who need to attach additional metadata to the fx nodes
  257. (e.g., for debugging, analysis, or external tooling) during export tracing.
  258. Note:
  259. This API is **not backward compatible** and may evolve in future releases.
  260. Note:
  261. This API is not compatible with fx.symbolic_trace or jit.trace. It's intended
  262. to be used with PT2 family of tracers, e.g. torch.export and dynamo.
  263. Args:
  264. annotation_dict (dict): A dictionary of custom key-value pairs to inject
  265. into the FX trace metadata.
  266. Example:
  267. After exiting the context, custom annotations are removed.
  268. >>> with annotate({"source": "custom_pass", "tag": 42}):
  269. ... pass # Your computation here
  270. """
  271. global current_meta
  272. has_custom = "custom" in current_meta
  273. old_custom = copy.copy(current_meta.get("custom", {}))
  274. try:
  275. if not has_custom:
  276. current_meta["custom"] = {}
  277. # Update with all key-value pairs from the input dict
  278. current_meta["custom"].update(annotation_dict)
  279. yield
  280. finally:
  281. if has_custom:
  282. # Restore the original custom dict
  283. current_meta["custom"] = old_custom
  284. else:
  285. del current_meta["custom"]
  286. @compatibility(is_backward_compatible=False)
  287. def annotate_fn(annotation_dict: dict):
  288. """
  289. A decorator that wraps a function with the annotate context manager.
  290. Use this when you want to annotate an entire function instead of a specific code block.
  291. Note:
  292. This API is **not backward compatible** and may evolve in future releases.
  293. Note:
  294. This API is not compatible with fx.symbolic_trace or jit.trace. It's intended
  295. to be used with PT2 family of tracers, e.g. torch.export and dynamo.
  296. Args:
  297. annotation_dict (dict): A dictionary of custom key-value pairs to inject
  298. into the FX trace metadata for all operations in the function.
  299. Example:
  300. All operations in my_function will have {"pp_stage": 1} in their metadata.
  301. >>> @annotate_fn({"pp_stage": 1})
  302. ... def my_function(x):
  303. ... return x + 1
  304. """
  305. from functools import wraps
  306. def decorator(func):
  307. @wraps(func)
  308. def wrapper(*args, **kwargs):
  309. with annotate(annotation_dict):
  310. return func(*args, **kwargs)
  311. return wrapper
  312. return decorator
  313. @compatibility(is_backward_compatible=False)
  314. def set_grad_fn_seq_nr(seq_nr):
  315. global current_meta
  316. if should_preserve_node_meta:
  317. # The seq_nr is captured by eager mode in the grad_fn during forward
  318. current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [
  319. seq_nr
  320. ]
  321. current_meta["in_grad_fn"] = current_meta.get("in_grad_fn", 0) + 1
  322. @compatibility(is_backward_compatible=False)
  323. def reset_grad_fn_seq_nr():
  324. # NB: reset state properly, this would be helpful towards supporting
  325. # reentrant autograd if we actually wanted to do that.
  326. global current_meta
  327. if should_preserve_node_meta:
  328. current_level = current_meta.get("in_grad_fn", 0)
  329. if current_level <= 0:
  330. raise AssertionError(f"Expected current_level > 0, got {current_level}")
  331. if current_level == 1:
  332. del current_meta["in_grad_fn"]
  333. del current_meta["grad_fn_seq_nr"]
  334. else:
  335. current_meta["in_grad_fn"] = current_level - 1
  336. current_meta["grad_fn_seq_nr"] = current_meta["grad_fn_seq_nr"][:-1]
  337. @compatibility(is_backward_compatible=False)
  338. def format_stack() -> list[str]:
  339. if should_preserve_node_meta:
  340. return [current_meta.get("stack_trace", "")]
  341. else:
  342. # fallback to traceback.format_stack()
  343. return traceback.format_list(traceback.extract_stack()[:-1])
  344. @compatibility(is_backward_compatible=False)
  345. def has_preserved_node_meta() -> bool:
  346. return should_preserve_node_meta
  347. def _is_preserving_node_seq_nr() -> bool:
  348. return _should_preserve_node_meta
  349. @compatibility(is_backward_compatible=False)
  350. @contextmanager
  351. def set_current_meta(node, pass_name=""):
  352. global current_meta
  353. if should_preserve_node_meta and node.meta:
  354. saved_meta = current_meta
  355. try:
  356. current_meta = node.meta.copy()
  357. # Update the "from_node" field in current_meta for provenance tracking.
  358. # Instead of appending, overwrite the "from_node" field because current_meta
  359. # will be assigned to the new node. The new NodeSource(node, ...) will
  360. # include the information from the previous current_meta["from_node"].
  361. current_meta["from_node"] = [
  362. NodeSource(node, pass_name, NodeSourceAction.CREATE)
  363. ]
  364. yield
  365. finally:
  366. current_meta = saved_meta
  367. else:
  368. yield
  369. @compatibility(is_backward_compatible=False)
  370. def get_current_meta() -> dict[str, Any]:
  371. return current_meta
  372. @compatibility(is_backward_compatible=False)
  373. @contextmanager
  374. def set_current_replay_node(node):
  375. """
  376. Set the currently replay node. If `current_replay_node` is not None,
  377. then we're re-generating the `current_replay_node` in FunctionalTensorMode.
  378. """
  379. # See [Note] annotation for more details.
  380. global current_replay_node
  381. saved_current_replay_node = current_replay_node
  382. try:
  383. current_replay_node = node
  384. yield
  385. finally:
  386. current_replay_node = saved_current_replay_node
  387. @compatibility(is_backward_compatible=False)
  388. def get_current_replay_node():
  389. """
  390. Get the currently replay node
  391. """
  392. return current_replay_node
  393. @compatibility(is_backward_compatible=False)
  394. def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
  395. """
  396. Given an fx.Graph, return a json that contains the provenance information of each node.
  397. """
  398. try:
  399. provenance_tracking_json = {}
  400. for node in graph.nodes:
  401. if node.op == "call_function":
  402. provenance_tracking_json[node.name] = (
  403. [source.to_dict() for source in node.meta["from_node"]]
  404. if "from_node" in node.meta
  405. else []
  406. )
  407. return provenance_tracking_json
  408. except Exception as e:
  409. # Since this is just debugging, it should never interfere with regular
  410. # program execution, so we use this try-except to guard against any error
  411. signpost_event(
  412. "inductor",
  413. "provenance_tracking_error",
  414. {
  415. "function": "get_graph_provenance_json",
  416. "error_msg": str(e),
  417. "stack_trace": traceback.format_exc(),
  418. },
  419. )
  420. return {}
  421. def _get_custom_metadata(gm: GraphModule) -> str:
  422. if not isinstance(gm, GraphModule):
  423. raise AssertionError(f"Expected GraphModule, got {type(gm)}")
  424. def helper(gm: GraphModule):
  425. custom_metadata = []
  426. for node in gm.graph.nodes:
  427. if hasattr(node, "meta") and node.meta.get("custom", None):
  428. custom_metadata.append((node.op, node.name, node.meta["custom"]))
  429. if node.op == "get_attr" and isinstance(
  430. getattr(gm, node.target), GraphModule
  431. ):
  432. custom_metadata.append(helper(getattr(gm, node.target)))
  433. return custom_metadata
  434. return "\n".join(str(x) for x in helper(gm))
  435. def _get_ordered_seq_nr_groups(
  436. gm: Union[GraphModule, list[GraphModule]],
  437. ) -> list[list[str]]:
  438. """
  439. Group call_function nodes by seq_nr, order by seq_nr value,
  440. and return a list of lists of node names (sorted alphabetically).
  441. Args:
  442. gm: A single GraphModule or a list of GraphModules to process.
  443. When a list is provided, nodes from all graphs are grouped together.
  444. Returns:
  445. A list of lists, where each inner list contains node names that share the same seq_nr,
  446. sorted alphabetically. The outer list is ordered by seq_nr value.
  447. """
  448. # Normalize input to a list
  449. if isinstance(gm, GraphModule):
  450. gms = [gm]
  451. else:
  452. gms = gm
  453. seq_nr_dict: dict[int, list[str]] = defaultdict(list)
  454. for graph_module in gms:
  455. for node in graph_module.graph.nodes:
  456. if node.op == "call_function":
  457. seq_nr = node.meta.get("seq_nr")
  458. if seq_nr is not None:
  459. seq_nr_dict[seq_nr].append(node.name)
  460. # Sort by seq_nr and return list of sorted lists
  461. return [sorted(seq_nr_dict[k]) for k in sorted(seq_nr_dict.keys())]