| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251 |
- """Metrics collection and management system for Dynamo.
- This module provides context managers for gathering and reporting metrics during
- compilation and runtime.
- It includes two main components:
- - MetricsContext: A context manager for collecting metrics during compilation, supporting
- nested contexts and various metric types (counters, sets, key-value pairs)
- - RuntimeMetricsContext: A specialized context for runtime metrics collection that doesn't
- require explicit context management
- The metrics system enables comprehensive monitoring and analysis of both compilation and
- execution performance.
- """
- from __future__ import annotations
- import heapq
- import logging
- import time
- from collections.abc import Callable
- from typing import Any, Optional, TYPE_CHECKING, TypeAlias
- from typing_extensions import Self
- if TYPE_CHECKING:
- from collections.abc import Iterator
- from torch.utils._traceback import CapturedTraceback
# Module-level logger; the metrics contexts use it to report (rather than
# propagate) exceptions raised by their on_exit callbacks.
log: logging.Logger = logging.getLogger(__name__)
class TopN:
    """
    Bounded collection that retains only the ``at_most`` entries with the largest
    values (the "most expensive" elements) out of everything added to it.
    """

    def __init__(self, at_most: int = 25) -> None:
        self.at_most = at_most
        # Min-heap of (val, key) tuples: the root is the cheapest retained entry,
        # i.e. the one that gets evicted once the collection is full.
        self.heap: list[tuple[int, Any]] = []

    def add(self, key: Any, val: int) -> None:
        entry = (val, key)
        if len(self.heap) >= self.at_most:
            # Full: insert and drop the smallest in one step (this may drop the
            # new entry itself when it is the cheapest).
            heapq.heappushpop(self.heap, entry)
        else:
            heapq.heappush(self.heap, entry)

    def __len__(self) -> int:
        return len(self.heap)

    def __iter__(self) -> Iterator[tuple[Any, int]]:
        # Yield (key, val) pairs from most to least expensive.
        ranked = sorted(self.heap, reverse=True)
        return ((name, cost) for cost, name in ranked)
- OnExitType: TypeAlias = Callable[
- [int, int, dict[str, Any], Optional[type[BaseException]], Optional[BaseException]],
- None,
- ]
class MetricsContext:
    def __init__(self, on_exit: OnExitType) -> None:
        """
        Use this class as a contextmanager to create a context under which to accumulate
        a set of metrics, e.g., metrics gathered during a compilation. On exit of the
        outermost contextmanager, call the provided 'on_exit' function and pass a
        dictionary of all metrics set during the lifetime of the contextmanager.

        The context may be re-entered while active; only the outermost enter/exit
        pair resets the gathered state and invokes 'on_exit'.
        """
        self._on_exit = on_exit
        self._metrics: dict[str, Any] = {}
        self._start_time_ns: int = 0
        # Current nesting depth; 0 means no context is active.
        self._level: int = 0
        # One (traceback, metric names) entry per set()/update() call made in the
        # current outermost context; used only to build duplicate-set error messages.
        self._edits: list[tuple[CapturedTraceback, set[str]]] = []

    def __enter__(self) -> Self:
        """
        Initialize metrics recording. Only the outermost entry resets state.
        """
        if self._level == 0:
            # In case of recursion, track at the outermost context.
            self._metrics = {}
            # Forget edits from any previous context: they refer to metrics that were
            # just discarded, would appear as misleading "Previous Traceback" entries,
            # and would otherwise accumulate for the lifetime of this object.
            self._edits = []
            self._start_time_ns = time.time_ns()
        self._level += 1
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        _traceback: Any,
    ) -> None:
        """
        Leave one nesting level. When the outermost context exits, call the provided
        on_exit function with (start_ns, end_ns, metrics, exc_type, exc_value).
        Exceptions raised by on_exit are logged rather than propagated, and any
        exception that ended the context is not suppressed.
        """
        self._level -= 1
        assert self._level >= 0
        if self._level == 0:
            try:
                end_time_ns = time.time_ns()
                self._on_exit(
                    self._start_time_ns, end_time_ns, self._metrics, exc_type, exc_value
                )
            except Exception:
                # Metrics reporting is best-effort; don't let it mask the outcome of
                # the code running inside the context.
                log.exception("Unexpected exception logging compilation metrics")

    def in_progress(self) -> bool:
        """
        True if we've entered the context.
        """
        return self._level > 0

    def increment(self, metric: str, value: int) -> None:
        """
        Increment a metric by a given amount (starting from 0 if it is unset).

        Raises:
            RuntimeError: if called outside of an active context.
        """
        if self._level == 0:
            raise RuntimeError(f"Cannot increment {metric} outside of a MetricsContext")
        if metric not in self._metrics:
            self._metrics[metric] = 0
        self._metrics[metric] += value

    def _render_edits(self, pred: set[str]) -> str:
        """
        Format the recorded tracebacks of every set()/update() call in the current
        context that touched at least one metric name in ``pred``.
        """
        return "\n\n" + "\n\n".join(
            "Previous Traceback:\n" + "".join(e.format())
            for e, k in self._edits
            if k & pred
        )

    def set(self, metric: str, value: Any, overwrite: bool = False) -> None:
        """
        Set a metric to a given value. Raises if the metric has been assigned previously
        in the current context, unless overwrite is True.

        Raises:
            RuntimeError: if called outside of an active context, or if the metric
                is already present and overwrite is False.
        """
        if self._level == 0:
            raise RuntimeError(f"Cannot set {metric} outside of a MetricsContext")
        if metric in self._metrics and not overwrite:
            raise RuntimeError(
                self._render_edits({metric})
                + f"\n\nRuntimeError: Metric '{metric}' has already been set in the current context "
                "(see above for current and previous traceback)."
            )
        # Remember where this assignment happened for future duplicate-set errors.
        self._edits.append((CapturedTraceback.extract(skip=1), {metric}))
        self._metrics[metric] = value

    def set_key_value(self, metric: str, key: str, value: Any) -> None:
        """
        Treat the given metric as a dictionary and set ``key`` to ``value`` within it.
        Note that the metric must be a dictionary or not present.
        We allow this to be called multiple times (i.e. for features, it's not uncommon
        for them to be used multiple times within a single compilation).

        Raises:
            RuntimeError: if called outside of an active context.
        """
        if self._level == 0:
            raise RuntimeError(f"Cannot set {metric} outside of a MetricsContext")
        if metric not in self._metrics:
            self._metrics[metric] = {}
        self._metrics[metric][key] = value

    def update(self, values: dict[str, Any], overwrite: bool = False) -> None:
        """
        Set multiple metrics directly. This method does NOT increment. Raises if any
        metric has been assigned previously in the current context and overwrite is
        not set to True.

        Raises:
            RuntimeError: if called outside of an active context, or if any key in
                ``values`` is already present and overwrite is False.
        """
        if self._level == 0:
            raise RuntimeError("Cannot update metrics outside of a MetricsContext")
        existing = self._metrics.keys() & values.keys()
        if existing and not overwrite:
            raise RuntimeError(
                # Only show where the *conflicting* metrics were previously set.
                self._render_edits(existing)
                + f"\n\nRuntimeError: Metric(s) {existing} have already been set in the current context. "
                "(see above for current and previous traceback)."
            )
        self._edits.append((CapturedTraceback.extract(skip=1), set(values.keys())))
        self._metrics.update(values)

    def update_outer(self, values: dict[str, Any]) -> None:
        """
        Update, but only when at the outermost context; a no-op in nested contexts.

        Raises:
            RuntimeError: if called outside of an active context.
        """
        if self._level == 0:
            raise RuntimeError("Cannot update metrics outside of a MetricsContext")
        if self._level == 1:
            self.update(values)

    def add_to_set(self, metric: str, value: Any) -> None:
        """
        Records a metric as a set() of values.

        Raises:
            RuntimeError: if called outside of an active context.
        """
        if self._level == 0:
            raise RuntimeError(f"Cannot add {metric} outside of a MetricsContext")
        if metric not in self._metrics:
            self._metrics[metric] = set()
        self._metrics[metric].add(value)

    def add_top_n(self, metric: str, key: Any, val: int) -> None:
        """
        Records a metric as a TopN set of values.

        Unlike the other mutators, this silently does nothing when called outside
        of an active context instead of raising.
        """
        if self._level == 0:
            return
        if metric not in self._metrics:
            self._metrics[metric] = TopN()
        self._metrics[metric].add(key, val)
class RuntimeMetricsContext:
    def __init__(self, on_exit: OnExitType) -> None:
        """
        Gather runtime metrics that are decoupled from compilation, where there is no
        natural place to wrap a context manager. Metrics accumulate across calls to
        increment() and are handed to on_exit (then cleared) by finish().
        """
        self._on_exit = on_exit
        self._metrics: dict[str, Any] = {}
        self._start_time_ns: int = 0

    def increment(
        self, metric: str, value: int, extra: Optional[dict[str, Any]] = None
    ) -> None:
        """
        Add ``value`` to ``metric`` (treated as 0 when absent).

        Each ``extra`` entry is recorded only if its key is not already present and
        its value is not None. The first metric recorded after a flush starts the
        timing window reported to on_exit.
        """
        gathered = self._metrics
        if len(gathered) == 0:
            # Nothing recorded since the last flush: this opens a new window.
            self._start_time_ns = time.time_ns()
        gathered.setdefault(metric, 0)
        gathered[metric] += value
        for name, payload in (extra or {}).items():
            if name in gathered or payload is None:
                continue
            gathered[name] = payload

    def finish(self) -> None:
        """
        Hand everything gathered since the last flush to on_exit, then clear it.

        Does nothing when no metrics have been recorded. Exceptions from on_exit are
        logged instead of propagated, and the metrics are cleared either way.
        """
        if not self._metrics:
            return
        try:
            self._on_exit(
                self._start_time_ns, time.time_ns(), self._metrics, None, None
            )
        except Exception:
            log.exception("Unexpected exception logging runtime metrics")
        finally:
            self._metrics = {}
|