_utils.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. # mypy: allow-untyped-defs
  2. import sys
  3. from typing import Optional
  4. import torch
  5. from torch._logging import LazyString
  6. def lazy_format_graph_code(name, gm, maybe_id=None, **kwargs):
  7. """
  8. Returns a LazyString that formats the graph code.
  9. """
  10. def format_name():
  11. if maybe_id is not None:
  12. return f"{name} {maybe_id}"
  13. else:
  14. return name
  15. if "print_output" not in kwargs:
  16. kwargs["print_output"] = False
  17. if "colored" in kwargs:
  18. try:
  19. if not sys.stdout.isatty():
  20. kwargs["colored"] = False
  21. except AttributeError:
  22. kwargs["colored"] = False
  23. return LazyString(
  24. lambda: _format_graph_code(
  25. f"===== {format_name()} =====\n",
  26. gm.forward.__code__.co_filename,
  27. gm.print_readable(**kwargs),
  28. )
  29. )
  30. def _format_graph_code(name, filename, graph_str):
  31. """
  32. Returns a string that formats the graph code.
  33. """
  34. return f"TRACED GRAPH\n {name} {filename} {graph_str}\n"
  35. def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[dict]:
  36. """
  37. Returns the nn_module_stack of the first call_function node.
  38. """
  39. for node in graph.nodes:
  40. if node.op == "call_function" and "nn_module_stack" in node.meta:
  41. return node.meta["nn_module_stack"]
  42. return None
  43. def get_node_context(node, num_nodes=2) -> str:
  44. """
  45. Returns a string of the last num_nodes nodes in the graph.
  46. """
  47. node_contexts = []
  48. cur = node
  49. for _ in range(num_nodes):
  50. node_contexts.append(cur.format_node())
  51. if cur.op == "root":
  52. break
  53. cur = cur.prev
  54. return "\n".join(node_contexts[::-1])