dynamo_profiler.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. """
  2. Dynamo Profiler - tracks where Dynamo spends time during compilation.
  3. This module provides profiling functionality for Dynamo tracing, showing per-function
  4. cumtime (inclusive) and tottime (exclusive) in a cProfile-compatible format.
  5. The output can be visualized with tools like snakeviz.
  6. Usage:
  7. # Enable via config (prints pstats output):
  8. torch._dynamo.config.dynamo_profiler = True
  9. # Or save to file for snakeviz:
  10. torch._dynamo.config.dynamo_profiler = "/tmp/dynamo.prof"
  11. # Then: snakeviz /tmp/dynamo.prof
  12. """
  13. from __future__ import annotations
  14. from dataclasses import dataclass
  15. from typing import Any, TYPE_CHECKING
  16. if TYPE_CHECKING:
  17. import pstats
  18. @dataclass
  19. class FunctionTraceTiming:
  20. """
  21. Timing data for a single inlined function trace.
  22. Follows cProfile conventions:
  23. - cumtime: total time in function including all subcalls (inclusive)
  24. - tottime: time in function excluding subcalls (exclusive)
  25. - caller info: who called this function (for building call graph)
  26. """
  27. # Function identification
  28. func_name: str
  29. filename: str
  30. firstlineno: int
  31. # Timing data (in nanoseconds) - cProfile-style
  32. cumtime_ns: int # Inclusive time (includes subcalls)
  33. tottime_ns: int # Exclusive time (excludes subcalls)
  34. # Code stats (for comparing tracing overhead vs function complexity)
  35. bytecode_count: int
  36. # Nesting depth when this function was traced
  37. inline_depth: int
  38. # Caller information (for building call graph edges)
  39. caller_func_name: str | None = None
  40. caller_filename: str | None = None
  41. caller_firstlineno: int | None = None
  42. # Whether this is a primitive (non-recursive) call
  43. # A call is primitive if the function doesn't appear anywhere in the call stack
  44. is_primitive_call: bool = True
  45. # Full call stack at the time of this call (for proper snakeviz drill-down)
  46. # Each entry is (func_name, filename, firstlineno)
  47. call_stack: tuple[tuple[str, str, int], ...] = ()
  48. # Backwards compatibility alias
  49. @property
  50. def trace_time_ns(self) -> int:
  51. return self.cumtime_ns
  52. @property
  53. def trace_time_ms(self) -> float:
  54. return self.cumtime_ns / 1e6
  55. @property
  56. def cumtime_ms(self) -> float:
  57. return self.cumtime_ns / 1e6
  58. @property
  59. def tottime_ms(self) -> float:
  60. return self.tottime_ns / 1e6
  61. @property
  62. def caller_key(self) -> tuple[str, int, str] | None:
  63. """Return caller as a pstats-compatible key tuple."""
  64. if self.caller_func_name is not None:
  65. return (
  66. self.caller_filename or "",
  67. self.caller_firstlineno or 0,
  68. self.caller_func_name,
  69. )
  70. return None
  71. @property
  72. def func_key(self) -> tuple[str, int, str]:
  73. """Return this function as a pstats-compatible key tuple."""
  74. return (self.filename, self.firstlineno, self.func_name)
  75. def __repr__(self) -> str:
  76. return (
  77. f"FunctionTraceTiming({self.func_name} at {self.filename}:{self.firstlineno}, "
  78. f"cumtime={self.cumtime_ms:.2f}ms, tottime={self.tottime_ms:.2f}ms, "
  79. f"bytecode={self.bytecode_count}, depth={self.inline_depth})"
  80. )
  81. @dataclass
  82. class ProfilerStackEntry:
  83. """Stack entry for tracking function timing in the Dynamo profiler."""
  84. func_name: str
  85. filename: str
  86. firstlineno: int
  87. start_time_ns: int
  88. child_time_ns: int # Accumulated time spent in traced children
  89. is_primitive_call: bool = True # Whether this is a non-recursive call
  90. class DynamoProfilerState:
  91. """State for Dynamo profiler tracking function trace timings."""
  92. def __init__(self) -> None:
  93. self.timings: list[FunctionTraceTiming] = []
  94. self.stack: list[ProfilerStackEntry] = []
  95. def record_timing(self, timing: FunctionTraceTiming) -> None:
  96. """Record timing data for a traced function."""
  97. self.timings.append(timing)
  98. def get_timings(self) -> list[FunctionTraceTiming]:
  99. """Get all recorded timings."""
  100. return self.timings
  101. def push(
  102. self, func_name: str, filename: str, firstlineno: int, start_time_ns: int
  103. ) -> None:
  104. """Push a new entry onto the timing stack."""
  105. # Check if this function already exists in the stack (indirect recursion)
  106. is_primitive = not any(
  107. entry.func_name == func_name
  108. and entry.filename == filename
  109. and entry.firstlineno == firstlineno
  110. for entry in self.stack
  111. )
  112. self.stack.append(
  113. ProfilerStackEntry(
  114. func_name=func_name,
  115. filename=filename,
  116. firstlineno=firstlineno,
  117. start_time_ns=start_time_ns,
  118. child_time_ns=0,
  119. is_primitive_call=is_primitive,
  120. )
  121. )
  122. def pop(self) -> ProfilerStackEntry | None:
  123. """Pop the top entry from the timing stack."""
  124. if self.stack:
  125. return self.stack.pop()
  126. return None
  127. def add_child_time(self, child_cumtime_ns: int) -> None:
  128. """Add the child's cumulative time to the parent's child_time accumulator."""
  129. if self.stack:
  130. self.stack[-1].child_time_ns += child_cumtime_ns
  131. def get_current_caller(self) -> tuple[str, str, int] | None:
  132. """Get the current caller (top of stack) as (func_name, filename, firstlineno)."""
  133. if self.stack:
  134. entry = self.stack[-1]
  135. return (entry.func_name, entry.filename, entry.firstlineno)
  136. return None
  137. def get_call_stack(self) -> tuple[tuple[str, str, int], ...]:
  138. """Get the full current call stack as tuple of (func_name, filename, firstlineno)."""
  139. return tuple(
  140. (entry.func_name, entry.filename, entry.firstlineno) for entry in self.stack
  141. )
  142. def generate_pstats(
  143. self, output_file: str | None = None, print_raw: bool = False
  144. ) -> pstats.Stats:
  145. """Generate pstats.Stats object from recorded timings.
  146. Args:
  147. output_file: Optional file path to save the stats.
  148. print_raw: If True, print raw aggregate timings before returning.
  149. """
  150. import cProfile
  151. import io
  152. import logging
  153. import pstats
  154. log = logging.getLogger(__name__)
  155. # Aggregate by (filename, lineno, func_name)
  156. aggregated: dict[tuple[str, int, str], dict[str, Any]] = {}
  157. # caller_edges[callee_key][caller_key] -> edge stats
  158. caller_edges: dict[
  159. tuple[str, int, str], dict[tuple[str, int, str], dict[str, Any]]
  160. ] = {}
  161. for t in self.timings:
  162. key = (t.filename, t.firstlineno, t.func_name)
  163. if key not in aggregated:
  164. aggregated[key] = {
  165. "ncalls": 0,
  166. "pcalls": 0,
  167. "tottime": 0.0,
  168. "cumtime": 0.0,
  169. }
  170. caller_edges[key] = {}
  171. agg = aggregated[key]
  172. agg["ncalls"] += 1
  173. agg["tottime"] += t.tottime_ns / 1e9
  174. if t.is_primitive_call:
  175. agg["pcalls"] += 1
  176. agg["cumtime"] += t.cumtime_ns / 1e9
  177. # Build caller edge
  178. if t.caller_filename is not None:
  179. caller_key = (
  180. t.caller_filename,
  181. t.caller_firstlineno or 0,
  182. t.caller_func_name or "",
  183. )
  184. if caller_key not in caller_edges[key]:
  185. caller_edges[key][caller_key] = {
  186. "ncalls": 0,
  187. "pcalls": 0,
  188. "tottime": 0.0,
  189. "cumtime": 0.0,
  190. }
  191. edge = caller_edges[key][caller_key]
  192. edge["ncalls"] += 1
  193. edge["tottime"] += t.tottime_ns / 1e9
  194. # Always add cumtime to edges for visualization (gprof2dot)
  195. # Function-level cumtime is already correct (only primitive calls)
  196. edge["cumtime"] += t.cumtime_ns / 1e9
  197. if t.is_primitive_call:
  198. edge["pcalls"] += 1
  199. if print_raw:
  200. sorted_items = sorted(
  201. aggregated.items(), key=lambda x: x[1]["cumtime"], reverse=True
  202. )
  203. print("\n=== Aggregate Timings (raw) ===")
  204. print(
  205. f"{'ncalls':>8} {'pcalls':>8} {'tottime':>12} {'cumtime':>12} function"
  206. )
  207. print("-" * 80)
  208. total_cumtime = 0.0
  209. total_tottime = 0.0
  210. for (filename, lineno, func_name), agg in sorted_items:
  211. ncalls = agg["ncalls"]
  212. pcalls = agg["pcalls"]
  213. tottime = agg["tottime"] * 1000 # convert to ms
  214. cumtime = agg["cumtime"] * 1000
  215. total_cumtime += cumtime
  216. total_tottime += tottime
  217. short_file = filename.split("/")[-1] if "/" in filename else filename
  218. print(
  219. f"{ncalls:>8} {pcalls:>8} {tottime:>10.2f}ms {cumtime:>10.2f}ms "
  220. f"{func_name} ({short_file}:{lineno})"
  221. )
  222. print("-" * 80)
  223. print(
  224. f"Total timings: {len(self.timings)}, unique functions: {len(aggregated)}"
  225. )
  226. print(
  227. f"Sum tottime: {total_tottime:.2f}ms, Sum cumtime: {total_cumtime:.2f}ms"
  228. )
  229. # Ensure caller-only functions have a top-level entry.
  230. # gprof2dot expects every function referenced as a caller to also
  231. # exist as a top-level entry in the stats dict with timing data.
  232. for key in list(caller_edges.keys()):
  233. for caller_key in caller_edges[key]:
  234. if caller_key not in aggregated:
  235. aggregated[caller_key] = {
  236. "ncalls": 0,
  237. "pcalls": 0,
  238. "tottime": 0.0,
  239. "cumtime": 0.0,
  240. }
  241. caller_edges[caller_key] = {}
  242. # Build the stats dict in pstats format
  243. stats_dict: dict[
  244. tuple[str, int, str], tuple[int, int, float, float, dict[Any, Any]]
  245. ] = {}
  246. for key, agg in aggregated.items():
  247. callers: dict[tuple[str, int, str], tuple[int, int, float, float]] = {}
  248. for caller_key, edge in caller_edges[key].items():
  249. callers[caller_key] = (
  250. edge["ncalls"],
  251. edge["pcalls"],
  252. edge["tottime"],
  253. edge["cumtime"],
  254. )
  255. stats_dict[key] = (
  256. agg["pcalls"],
  257. agg["ncalls"],
  258. agg["tottime"],
  259. agg["cumtime"],
  260. callers,
  261. )
  262. # Create a pstats.Stats object
  263. dummy_profile = cProfile.Profile()
  264. dummy_profile.enable()
  265. dummy_profile.disable()
  266. stats = pstats.Stats(dummy_profile, stream=io.StringIO())
  267. stats.stats = stats_dict # type: ignore[attr-defined]
  268. stats.total_calls = sum(s[1] for s in stats_dict.values()) # type: ignore[attr-defined]
  269. stats.prim_calls = sum(s[0] for s in stats_dict.values()) # type: ignore[attr-defined]
  270. stats.total_tt = sum(s[2] for s in stats_dict.values()) # type: ignore[attr-defined]
  271. if output_file:
  272. stats.dump_stats(output_file)
  273. log.info(
  274. "Saved pstats to %s. Visualize with: snakeviz %s",
  275. output_file,
  276. output_file,
  277. )
  278. return stats
  279. def generate_svg(
  280. self, profile_file: str, svg_file: str | None = None
  281. ) -> str | None:
  282. """Generate an SVG call graph from a profile file using gprof2dot and graphviz.
  283. Args:
  284. profile_file: Path to the pstats profile file.
  285. svg_file: Optional path for the output SVG. If not provided, uses
  286. profile_file with .svg extension.
  287. Returns:
  288. Path to the generated SVG file, or None if generation failed.
  289. """
  290. import os
  291. import shutil
  292. import subprocess
  293. if not shutil.which("gprof2dot"):
  294. print("gprof2dot not found. Install with: pip install gprof2dot")
  295. return None
  296. if not shutil.which("dot"):
  297. print("graphviz 'dot' not found. Install graphviz package.")
  298. return None
  299. if svg_file is None:
  300. svg_file = profile_file.rsplit(".", 1)[0] + ".svg"
  301. try:
  302. # gprof2dot -f pstats profile.prof | dot -Tsvg -o profile.svg
  303. gprof2dot = subprocess.Popen(
  304. [
  305. "gprof2dot",
  306. "-f",
  307. "pstats",
  308. "--node-label=total-time-percentage",
  309. "--node-label=self-time-percentage",
  310. "--node-label=total-time",
  311. profile_file,
  312. ],
  313. stdout=subprocess.PIPE,
  314. stderr=subprocess.PIPE,
  315. )
  316. dot = subprocess.Popen(
  317. ["dot", "-Tsvg", "-o", svg_file],
  318. stdin=gprof2dot.stdout,
  319. stdout=subprocess.PIPE,
  320. stderr=subprocess.PIPE,
  321. )
  322. gprof2dot.stdout.close() # type: ignore[union-attr]
  323. _, dot_err = dot.communicate()
  324. _, gprof2dot_err = gprof2dot.communicate()
  325. if gprof2dot.returncode != 0:
  326. print(
  327. f"gprof2dot failed: {gprof2dot_err.decode()}" # noqa: B950
  328. )
  329. return None
  330. if dot.returncode != 0:
  331. print(f"graphviz dot failed: {dot_err.decode()}")
  332. return None
  333. if not os.path.isfile(svg_file):
  334. print(f"SVG file was not created: {svg_file}")
  335. return None
  336. print(f"SVG call graph saved to: {svg_file}")
  337. return svg_file
  338. except Exception as e:
  339. print(f"Failed to generate SVG: {e}")
  340. return None
  341. def dump_stats(
  342. self, output_file: str | None = None, generate_svg: bool = True
  343. ) -> None:
  344. """Print profiler stats to stdout and optionally save to file.
  345. Args:
  346. output_file: Optional path to save the pstats profile.
  347. generate_svg: If True and output_file is provided, also generate an SVG
  348. call graph using gprof2dot and graphviz.
  349. """
  350. import sys
  351. if not self.timings:
  352. return
  353. stats = self.generate_pstats(output_file, print_raw=True)
  354. print("\n=== Dynamo Profiler (pstats) ===")
  355. stats.stream = sys.stdout # type: ignore[attr-defined]
  356. stats.sort_stats("cumulative").print_stats()
  357. if output_file:
  358. print(f"\nProfile saved to: {output_file}")
  359. print(f"Visualize with: snakeviz {output_file}")
  360. if generate_svg:
  361. self.generate_svg(output_file)