structured.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. """
  2. Utilities for converting data types into structured JSON for dumping.
  3. """
  4. import inspect
  5. import os
  6. import traceback
  7. from collections.abc import Sequence
  8. from typing import Any, Optional
  9. import torch._logging._internal
  10. INTERN_TABLE: dict[str, int] = {}
  11. DUMPED_FILES: set[str] = set()
  12. def intern_string(s: Optional[str]) -> int:
  13. if s is None:
  14. return -1
  15. r = INTERN_TABLE.get(s)
  16. if r is None:
  17. r = len(INTERN_TABLE)
  18. INTERN_TABLE[s] = r
  19. torch._logging._internal.trace_structured(
  20. "str", lambda: (s, r), suppress_context=True
  21. )
  22. return r
  23. def dump_file(filename: str) -> None:
  24. if "eval_with_key" not in filename:
  25. return
  26. if filename in DUMPED_FILES:
  27. return
  28. DUMPED_FILES.add(filename)
  29. from torch.fx.graph_module import _loader
  30. torch._logging._internal.trace_structured(
  31. "dump_file",
  32. metadata_fn=lambda: {
  33. "name": filename,
  34. },
  35. payload_fn=lambda: _loader.get_source(filename),
  36. )
  37. def from_traceback(tb: Sequence[traceback.FrameSummary]) -> list[dict[str, Any]]:
  38. # dict naming convention here coincides with
  39. # python/combined_traceback.cpp
  40. r = [
  41. {
  42. "line": frame.lineno,
  43. "name": frame.name,
  44. "filename": intern_string(frame.filename),
  45. "loc": frame.line,
  46. }
  47. for frame in tb
  48. ]
  49. return r
  50. def get_user_stack(num_frames: int) -> list[dict[str, Any]]:
  51. from torch._guards import TracingContext
  52. from torch.utils._traceback import CapturedTraceback
  53. user_tb = TracingContext.extract_stack()
  54. if user_tb:
  55. return from_traceback(user_tb[-1 * num_frames :])
  56. tb = CapturedTraceback.extract().summary()
  57. # Filter out frames that are within the torch/ codebase
  58. torch_filepath = os.path.dirname(inspect.getfile(torch)) + os.path.sep
  59. for i, frame in enumerate(reversed(tb)):
  60. if torch_filepath not in frame.filename:
  61. # Only display `num_frames` frames in the traceback
  62. filtered_tb = tb[len(tb) - i - num_frames : len(tb) - i]
  63. return from_traceback(filtered_tb)
  64. return from_traceback(tb[-1 * num_frames :])
  65. def get_framework_stack(
  66. num_frames: int = 25, cpp: bool = False
  67. ) -> list[dict[str, Any]]:
  68. """
  69. Returns the traceback for the user stack and the framework stack
  70. """
  71. from torch.fx.experimental.symbolic_shapes import uninteresting_files
  72. from torch.utils._traceback import CapturedTraceback
  73. tb = CapturedTraceback.extract(cpp=cpp).summary()
  74. tb = [
  75. frame
  76. for frame in tb
  77. if (
  78. (
  79. frame.filename.endswith(".py")
  80. and frame.filename not in uninteresting_files()
  81. )
  82. or ("at::" in frame.name or "torch::" in frame.name)
  83. )
  84. ]
  85. return from_traceback(tb[-1 * num_frames :])