_traceback.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import inspect
  4. import os.path
  5. import tempfile
  6. import traceback
  7. from types import TracebackType
  8. # This file contains utilities for ensuring dynamically compile()'d
  9. # code fragments display their line numbers in backtraces.
  10. #
  11. # The constraints:
  12. #
  13. # - We don't have control over the user exception printer (in particular,
  14. # we cannot assume the linecache trick will work, c.f.
  15. # https://stackoverflow.com/q/50515651/23845 )
  16. #
# - We don't want to create temporary files every time we compile()
# some code; file creation should happen lazily only at exception
# time. Arguably, you *should* be willing to write out your
# generated Python code to the file system, but in some situations
# (esp. library code) it would violate user expectations to write
# to the file system, so we try to avoid it. In particular, we'd
# like to keep the files around so users can open up the files
# mentioned in the trace; but if no trace ever mentions a file, we
# want to avoid clogging up the filesystem with it.
  26. #
# If this is not a constraint for you, there is a substantially simpler
# way to implement the functionality in this file: instead of using
# eval/exec directly, just always write a Python file to the filesystem
# and compile that.
  31. #
  32. # - You have control over a context where the compiled code will get
  33. # executed, so that we can interpose while the stack is unwinding
  34. # (otherwise, we have no way to interpose on the exception printing
  35. # process.)
  36. #
  37. # There are two things you have to do to make use of the utilities here:
  38. #
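# - When you compile your source code, you must use the filename
# "<string>" (the default when a string is passed straight to
# exec()/eval()) and save the source string in the code's globals
# (its frames' f_globals) under the magic name "__compile_source__"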
  41. #
  42. # - Before running the compiled code, enter the
  43. # report_compile_source_on_error() context manager.
  44. @contextlib.contextmanager
  45. def report_compile_source_on_error():
  46. try:
  47. yield
  48. except Exception as exc:
  49. tb = exc.__traceback__
  50. # Walk the traceback, looking for frames that have
  51. # source attached
  52. stack = []
  53. while tb is not None:
  54. filename = tb.tb_frame.f_code.co_filename
  55. source = tb.tb_frame.f_globals.get("__compile_source__")
  56. if filename == "<string>" and source is not None:
  57. # What black magic are we doing here? Intuitively, what
  58. # we would like to do is overwrite the co_filename on any
  59. # frames that were generated from exec/eval so that they
  60. # point to a temporary file that has the actual line
  61. # information, so Python's default error printer can print
  62. # useful line information on it.
  63. #
  64. # Writing out the temporary file is easy. But overwriting
  65. # co_filename is not! You can't modify the code object
  66. # associated with a frame. You can, however, reconstruct
  67. # a traceback with entirely new frames from scratch, so that's
  68. # what we do. But there's another problem, which is how to
  69. # make the frame?
  70. #
  71. # The black magic is we make a frankenstein frame and code
  72. # object which resembles the original frame/code enough so
  73. # that it will print properly under traceback and the default
  74. # error printer, but IT IS NOT THE ORIGINAL FRAME (you
  75. # couldn't, e.g., execute its code with different variables
  76. # and expect it to work.)
  77. # Don't delete the temporary file so the user can inspect it
  78. # TODO: This creates a temporary file for every frame, but we
  79. # technically only need one per distinct __compile_source__
  80. with tempfile.NamedTemporaryFile(
  81. mode="w", delete=False, suffix=".py"
  82. ) as f:
  83. f.write(source)
  84. # Create a frame. Python doesn't let you construct
  85. # FrameType directly, so just make one with compile
  86. frame = tb.tb_frame
  87. code = compile("__inspect_currentframe()", f.name, "eval")
  88. code = code.replace(co_name=frame.f_code.co_name)
  89. # Python 3.11 only
  90. if hasattr(frame.f_code, "co_linetable"):
  91. # We can't copy ALL of the metadata over, because you
  92. # can cause Python to segfault this way. What exactly
  93. # do we need? We need enough information for
  94. # traceback to be able to print the exception
  95. # correctly. Code reading Lib/traceback.py reveals
  96. # that traceback calls code.co_positions() in order to
  97. # get the augmented line/col numbers. Objects/codeobject.c,
  98. # specifically _PyCode_InitAddressRange, reveals that
  99. # this iterator is initialized from co_linetable and
  100. # co_firstfileno. So copy these we must!
  101. code = code.replace( # type: ignore[call-arg]
  102. co_linetable=frame.f_code.co_linetable, # type: ignore[attr-defined]
  103. co_firstlineno=frame.f_code.co_firstlineno, # type: ignore[attr-defined]
  104. )
  105. fake_frame = eval(
  106. code,
  107. frame.f_globals,
  108. {**frame.f_locals, "__inspect_currentframe": inspect.currentframe},
  109. )
  110. fake_tb = TracebackType(None, fake_frame, tb.tb_lasti, tb.tb_lineno)
  111. stack.append(fake_tb)
  112. else:
  113. stack.append(tb)
  114. tb = tb.tb_next
  115. # Reconstruct the linked list
  116. tb_next = None
  117. for tb in reversed(stack):
  118. tb.tb_next = tb_next
  119. tb_next = tb
  120. raise exc.with_traceback(tb_next) # noqa: B904
  121. def shorten_filename(fn, *, base=None):
  122. """Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to user."""
  123. if base is None:
  124. base = os.path.dirname(os.path.dirname(__file__))
  125. # Truncate torch/foo.py to foo.py
  126. try:
  127. prefix = os.path.commonpath([fn, base])
  128. except ValueError:
  129. return fn
  130. else:
  131. return fn[len(prefix) + 1 :]
  132. def format_frame(frame, *, base=None, line=False) -> str:
  133. """
  134. Format a FrameSummary in a short way, without printing full absolute path or code.
  135. The idea is the result fits on a single line.
  136. """
  137. extra_line = ""
  138. if line:
  139. extra_line = f"{frame.line} # "
  140. return f"{extra_line}{shorten_filename(frame.filename, base=base)}:{frame.lineno} in {frame.name}"
  141. def format_traceback_short(tb):
  142. """Format a TracebackType in a short way, printing only the inner-most frame."""
  143. return format_frame(traceback.extract_tb(tb)[-1])
  144. class CapturedTraceback:
  145. __slots__ = ["tb", "skip"]
  146. def __init__(self, tb, skip=0) -> None:
  147. self.tb = tb
  148. self.skip = skip
  149. def cleanup(self) -> None:
  150. self.tb = None
  151. def summary(self):
  152. import torch._C._profiler
  153. if self.tb is None:
  154. # TODO: Maybe indicate that the traceback was elided?
  155. return traceback.StackSummary()
  156. return _extract_symbolized_tb(
  157. torch._C._profiler.symbolize_tracebacks([self.tb])[0], self.skip
  158. )
  159. def __getstate__(self):
  160. return (
  161. None,
  162. {
  163. "tb": None, # TB is not pickleable
  164. "skip": self.skip,
  165. },
  166. )
  167. @staticmethod
  168. def extract(*, script=False, cpp=False, skip=0):
  169. """
  170. Like traceback.extract_stack(), but faster (approximately 20x faster); it
  171. is fast enough that you can unconditionally log stacks this way as part of
  172. normal execution. It returns a torch._C._profiler.CapturedTraceback
  173. object that must be formatted specially with format_captured_tb.
  174. By default, this only reports Python backtraces (like extract_stack). You
  175. can set the script/cpp kwargs to also turn on TorchScript/C++ trace
  176. reporting.
  177. """
  178. import torch._C._profiler
  179. if script or cpp:
  180. if skip != 0:
  181. raise AssertionError("skip with script/cpp NYI")
  182. return CapturedTraceback(
  183. torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp),
  184. # Elide extract() frame if we don't have script/cpp frames. If
  185. # we do have those frames, it doesn't work so force zero.
  186. 0 if script or cpp else skip + 1,
  187. )
  188. def format(self):
  189. """
  190. Formats a single torch._C._profiler.CapturedTraceback into a list of
  191. strings equivalent to the output of traceback.format_list. Note that if
  192. pass it CapturedTraceback with C++ traces, it is better not to use this
  193. function and use the batch formatting API format_captured_tbs to amortize
  194. the cost of symbolization
  195. """
  196. return traceback.format_list(self.summary())
  197. @staticmethod
  198. def format_all(tbs):
  199. """
  200. Bulk version of CapturedTraceback.format. Returns a list of list of strings.
  201. """
  202. import torch._C._profiler
  203. # Directly populate tracebacks that already have cached summaries
  204. rs: list[list[str] | None] = []
  205. delayed_idxs = []
  206. for i, tb in enumerate(tbs):
  207. if tb.tb is None:
  208. rs.append([])
  209. else:
  210. rs.append(None)
  211. delayed_idxs.append(i)
  212. torch._C._profiler.symbolize_tracebacks([tbs[i].tb for i in delayed_idxs])
  213. for i in delayed_idxs:
  214. rs[i] = traceback.format_list(tbs[i].summary())
  215. return rs
  216. def _extract_symbolized_tb(tb, skip):
  217. """
  218. Given a symbolized traceback from symbolize_tracebacks, return a StackSummary object of
  219. pre-processed stack trace entries.
  220. """
  221. stack = traceback.StackSummary()
  222. for f in reversed(tb[skip:]):
  223. stack.append(traceback.FrameSummary(f["filename"], f["line"], f["name"]))
  224. return stack