metrics_context.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. """Metrics collection and management system for Dynamo.
  2. This module provides context managers for gathering and reporting metrics during
  3. compilation and runtime.
  4. It includes two main components:
  5. - MetricsContext: A context manager for collecting metrics during compilation, supporting
  6. nested contexts and various metric types (counters, sets, key-value pairs)
  7. - RuntimeMetricsContext: A specialized context for runtime metrics collection that doesn't
  8. require explicit context management
  9. The metrics system enables comprehensive monitoring and analysis of both compilation and
  10. execution performance.
  11. """
  12. from __future__ import annotations
  13. import heapq
  14. import logging
  15. import time
  16. from collections.abc import Callable
  17. from typing import Any, Optional, TYPE_CHECKING, TypeAlias
  18. from typing_extensions import Self
  19. if TYPE_CHECKING:
  20. from collections.abc import Iterator
  21. from torch.utils._traceback import CapturedTraceback
# Module-level logger. The context classes in this module use it to report, via
# log.exception, failures raised by their on_exit callbacks instead of letting
# those failures propagate.
log = logging.getLogger(__name__)
  23. class TopN:
  24. """
  25. Helper to record a list of metrics, keeping only the top N "most expensive" elements.
  26. """
  27. def __init__(self, at_most: int = 25) -> None:
  28. self.at_most = at_most
  29. self.heap: list[tuple[int, Any]] = []
  30. def add(self, key: Any, val: int) -> None:
  31. # Push if we haven't reached the max size, else push and pop the smallest
  32. fn = heapq.heappush if len(self.heap) < self.at_most else heapq.heappushpop
  33. fn(self.heap, (val, key))
  34. def __len__(self) -> int:
  35. return len(self.heap)
  36. def __iter__(self) -> Iterator[tuple[Any, int]]:
  37. return ((key, val) for val, key in sorted(self.heap, reverse=True))
  38. OnExitType: TypeAlias = Callable[
  39. [int, int, dict[str, Any], Optional[type[BaseException]], Optional[BaseException]],
  40. None,
  41. ]
  42. class MetricsContext:
  43. def __init__(self, on_exit: OnExitType) -> None:
  44. """
  45. Use this class as a contextmanager to create a context under which to accumulate
  46. a set of metrics, e.g., metrics gathered during a compilation. On exit of the
  47. contextmanager, call the provided 'on_exit' function and pass a dictionary of
  48. all metrics set during the lifetime of the contextmanager.
  49. """
  50. self._on_exit = on_exit
  51. self._metrics: dict[str, Any] = {}
  52. self._start_time_ns: int = 0
  53. self._level: int = 0
  54. self._edits: list[tuple[CapturedTraceback, set[str]]] = []
  55. def __enter__(self) -> Self:
  56. """
  57. Initialize metrics recording.
  58. """
  59. if self._level == 0:
  60. # In case of recursion, track at the outermost context.
  61. self._metrics = {}
  62. self._start_time_ns = time.time_ns()
  63. self._level += 1
  64. return self
  65. def __exit__(
  66. self,
  67. exc_type: Optional[type[BaseException]],
  68. exc_value: Optional[BaseException],
  69. _traceback: Any,
  70. ) -> None:
  71. """
  72. At exit, call the provided on_exit function.
  73. """
  74. self._level -= 1
  75. assert self._level >= 0
  76. if self._level == 0:
  77. try:
  78. end_time_ns = time.time_ns()
  79. self._on_exit(
  80. self._start_time_ns, end_time_ns, self._metrics, exc_type, exc_value
  81. )
  82. except Exception:
  83. log.exception("Unexpected exception logging compilation metrics")
  84. def in_progress(self) -> bool:
  85. """
  86. True if we've entered the context.
  87. """
  88. return self._level > 0
  89. def increment(self, metric: str, value: int) -> None:
  90. """
  91. Increment a metric by a given amount.
  92. """
  93. if self._level == 0:
  94. raise RuntimeError(f"Cannot increment {metric} outside of a MetricsContext")
  95. if metric not in self._metrics:
  96. self._metrics[metric] = 0
  97. self._metrics[metric] += value
  98. def _render_edits(self, pred: set[str]) -> str:
  99. return "\n\n" + "\n\n".join(
  100. "Previous Traceback:\n" + "".join(e.format())
  101. for e, k in self._edits
  102. if k & pred
  103. )
  104. def set(self, metric: str, value: Any, overwrite: bool = False) -> None:
  105. """
  106. Set a metric to a given value. Raises if the metric has been assigned previously
  107. in the current context.
  108. """
  109. if self._level == 0:
  110. raise RuntimeError(f"Cannot set {metric} outside of a MetricsContext")
  111. if metric in self._metrics and not overwrite:
  112. raise RuntimeError(
  113. self._render_edits({metric})
  114. + f"\n\nRuntimeError: Metric '{metric}' has already been set in the current context "
  115. "(see above for current and previous traceback)."
  116. )
  117. self._edits.append((CapturedTraceback.extract(skip=1), {metric}))
  118. self._metrics[metric] = value
  119. def set_key_value(self, metric: str, key: str, value: Any) -> None:
  120. """
  121. Treats a give metric as a dictionary and set the k and value within it.
  122. Note that the metric must be a dictionary or not present.
  123. We allow this to be called multiple times (i.e. for features, it's not uncommon
  124. for them to be used multiple times within a single compilation).
  125. """
  126. if self._level == 0:
  127. raise RuntimeError(f"Cannot set {metric} outside of a MetricsContext")
  128. if metric not in self._metrics:
  129. self._metrics[metric] = {}
  130. self._metrics[metric][key] = value
  131. def update(self, values: dict[str, Any], overwrite: bool = False) -> None:
  132. """
  133. Set multiple metrics directly. This method does NOT increment. Raises if any
  134. metric has been assigned previously in the current context and overwrite is
  135. not set to True.
  136. """
  137. if self._level == 0:
  138. raise RuntimeError("Cannot update metrics outside of a MetricsContext")
  139. existing = self._metrics.keys() & values.keys()
  140. if existing and not overwrite:
  141. raise RuntimeError(
  142. self._render_edits(set(values.keys()))
  143. + f"\n\nRuntimeError: Metric(s) {existing} have already been set in the current context. "
  144. "(see above for current and previous traceback)."
  145. )
  146. self._edits.append((CapturedTraceback.extract(skip=1), set(values.keys())))
  147. self._metrics.update(values)
  148. def update_outer(self, values: dict[str, Any]) -> None:
  149. """
  150. Update, but only when at the outermost context.
  151. """
  152. if self._level == 0:
  153. raise RuntimeError("Cannot update metrics outside of a MetricsContext")
  154. if self._level == 1:
  155. self.update(values)
  156. def add_to_set(self, metric: str, value: Any) -> None:
  157. """
  158. Records a metric as a set() of values.
  159. """
  160. if self._level == 0:
  161. raise RuntimeError(f"Cannot add {metric} outside of a MetricsContext")
  162. if metric not in self._metrics:
  163. self._metrics[metric] = set()
  164. self._metrics[metric].add(value)
  165. def add_top_n(self, metric: str, key: Any, val: int) -> None:
  166. """
  167. Records a metric as a TopN set of values.
  168. """
  169. if self._level == 0:
  170. return
  171. if metric not in self._metrics:
  172. self._metrics[metric] = TopN()
  173. self._metrics[metric].add(key, val)
  174. class RuntimeMetricsContext:
  175. def __init__(self, on_exit: OnExitType) -> None:
  176. """
  177. Similar to MetricsContext, but used to gather the runtime metrics that are
  178. decoupled from compilation, where there's not a natural place to insert a
  179. context manager.
  180. """
  181. self._on_exit = on_exit
  182. self._metrics: dict[str, Any] = {}
  183. self._start_time_ns: int = 0
  184. def increment(
  185. self, metric: str, value: int, extra: Optional[dict[str, Any]] = None
  186. ) -> None:
  187. """
  188. Increment a metric by a given amount.
  189. """
  190. if not self._metrics:
  191. # Start timing on the first entry
  192. self._start_time_ns = time.time_ns()
  193. if metric not in self._metrics:
  194. self._metrics[metric] = 0
  195. self._metrics[metric] += value
  196. if extra:
  197. for k, v in extra.items():
  198. if k not in self._metrics and v is not None:
  199. self._metrics[k] = v
  200. def finish(self) -> None:
  201. """
  202. Call the on_exit function with the metrics gathered so far and reset.
  203. """
  204. if self._metrics:
  205. try:
  206. end_time_ns = time.time_ns()
  207. self._on_exit(
  208. self._start_time_ns, end_time_ns, self._metrics, None, None
  209. )
  210. except Exception:
  211. log.exception("Unexpected exception logging runtime metrics")
  212. finally:
  213. self._metrics = {}