profiler.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. """
  2. Dynamo profiling implementation.
  3. This module provides profiling functionality for Dynamo, including:
  4. - ProfileMetrics: Class for collecting and aggregating performance metrics like
  5. execution time, operator counts, and fusion statistics
  6. - ProfileResult: Class for analyzing and reporting profiling results
  7. - Utilities for tracking missed/uncaptured operations
  8. - Functions for instrumenting FX graphs with profiling capabilities
  9. The profiler helps measure and optimize the performance of Dynamo-compiled code
  10. by tracking both captured and total operations, timing, and graph statistics.
  11. """
  12. from __future__ import annotations
  13. import dataclasses
  14. import os
  15. from typing import Any
  16. from typing_extensions import Self
  17. import torch
  18. from .utils import print_once
  19. @dataclasses.dataclass
  20. class ProfileMetrics:
  21. microseconds: float = 0.0
  22. operators: int = 0
  23. fusions: int = 0
  24. graphs: int = 0
  25. def __iadd__(self, other: Self) -> Self:
  26. self.microseconds += other.microseconds
  27. self.operators += other.operators
  28. self.fusions += other.fusions
  29. return self
  30. def __add__(self, other: ProfileMetrics) -> ProfileMetrics:
  31. assert isinstance(other, ProfileMetrics)
  32. return ProfileMetrics(
  33. self.microseconds + other.microseconds,
  34. self.operators + other.operators,
  35. self.fusions + other.fusions,
  36. )
  37. def __truediv__(self, other: Any) -> ProfileMetrics:
  38. if isinstance(other, int):
  39. other = ProfileMetrics(other, other, other)
  40. return ProfileMetrics(
  41. self.microseconds / max(1, other.microseconds),
  42. # pyrefly: ignore [bad-argument-type]
  43. self.operators / max(1, other.operators),
  44. # pyrefly: ignore [bad-argument-type]
  45. self.fusions / max(1, other.fusions),
  46. )
  47. def __str__(self) -> str:
  48. return f"{self.operators:4.0%} ops {self.microseconds:4.0%} time"
  49. def tocsv(self) -> list[float]:
  50. return [self.operators, self.microseconds]
  51. class ProfileResult:
  52. def __init__(
  53. self, captured: ProfileMetrics, total: ProfileMetrics, unique_graphs: int
  54. ) -> None:
  55. self.captured: ProfileMetrics = captured or ProfileMetrics()
  56. self.total: ProfileMetrics = total or ProfileMetrics()
  57. self.unique_graphs: int = unique_graphs
  58. def __iadd__(self, other: Self) -> Self:
  59. self.captured += other.captured
  60. self.total += other.total
  61. self.unique_graphs += other.unique_graphs
  62. return self
  63. def percent(self) -> ProfileMetrics:
  64. return self.captured / self.total
  65. def __str__(self) -> str:
  66. return (
  67. f"{self.unique_graphs:2} graphs {self.captured.graphs:2} graph calls "
  68. f"{self.captured.operators:4}/{self.total.operators:4} = "
  69. + str(self.percent())
  70. )
  71. def tocsv(self) -> list[Any]:
  72. return [
  73. self.unique_graphs,
  74. self.captured.graphs,
  75. self.captured.operators,
  76. self.total.operators,
  77. ] + self.percent().tocsv()
  78. def should_print_missing() -> bool:
  79. return os.environ.get("TORCHDYNAMO_PRINT_MISSING") == "1"
  80. def print_missing(stack: list[str]) -> None:
  81. if any("/torch/autograd/profiler.py" in x for x in stack):
  82. return
  83. stack = [
  84. x for x in stack if ("<built-in" not in x and "site-packages/torch/" not in x)
  85. ]
  86. print_once("MISSING", " >> ".join(stack[-3:]))
  87. class Profiler:
  88. unique_graphs: int = 0
  89. def __init__(self) -> None:
  90. self.prof = torch.profiler.profile(
  91. activities=[torch.profiler.ProfilerActivity.CPU],
  92. with_stack=should_print_missing(),
  93. )
  94. def results(self) -> ProfileResult:
  95. captured_regions = 0
  96. captured_ops = 0
  97. captured_microseconds = 0
  98. total_ops = 0
  99. total_microseconds = 0
  100. last_op_end_time = -1
  101. captured_region_end_time = -1
  102. events = sorted(self.prof.events(), key=lambda x: x.time_range.start)
  103. for e in events:
  104. if e.name == "TORCHDYNAMO":
  105. captured_region_end_time = e.time_range.end
  106. captured_regions += 1
  107. # ignore `handle = torch.zeros(1)` in record_function.__init__()
  108. total_ops -= 1
  109. elif e.time_range.start >= last_op_end_time:
  110. last_op_end_time = e.time_range.end
  111. if e.time_range.end <= captured_region_end_time:
  112. captured_ops += 1
  113. captured_microseconds += e.time_range.elapsed_us()
  114. elif should_print_missing():
  115. print_missing(e.stack)
  116. total_ops += 1
  117. total_microseconds += e.time_range.elapsed_us()
  118. else:
  119. pass # ops recursively called from other ops (ignored)
  120. unique_graphs = Profiler.unique_graphs
  121. Profiler.unique_graphs = 0
  122. # we counted one extra op that is part of the profiler setup code
  123. total_ops -= 1
  124. return ProfileResult(
  125. captured=ProfileMetrics(
  126. microseconds=captured_microseconds,
  127. operators=captured_ops,
  128. fusions=captured_ops - captured_regions,
  129. graphs=captured_regions,
  130. ),
  131. total=ProfileMetrics(
  132. microseconds=total_microseconds,
  133. operators=total_ops,
  134. fusions=total_ops - 1,
  135. ),
  136. unique_graphs=unique_graphs,
  137. )
  138. def fx_insert_profiling(gm: torch.fx.GraphModule, example_inputs: list[Any]) -> Any:
  139. def _wrapped(*args: Any) -> Any:
  140. with torch.profiler.record_function("TORCHDYNAMO"):
  141. return gm.forward(*args)
  142. Profiler.unique_graphs += 1
  143. return _wrapped