api.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. #!/usr/bin/env python3
  2. # mypy: allow-untyped-defs
  3. # Copyright (c) Facebook, Inc. and its affiliates.
  4. # All rights reserved.
  5. #
  6. # This source code is licensed under the BSD-style license found in the
  7. # LICENSE file in the root directory of this source tree.
  8. import abc
  9. import time
  10. from collections import namedtuple
  11. from functools import wraps
  12. from typing_extensions import deprecated
  13. __all__ = [
  14. "MetricsConfig",
  15. "MetricHandler",
  16. "ConsoleMetricHandler",
  17. "NullMetricHandler",
  18. "MetricStream",
  19. "configure",
  20. "getStream",
  21. "prof",
  22. "profile",
  23. "put_metric",
  24. "publish_metric",
  25. "get_elapsed_time_ms",
  26. "MetricData",
  27. ]
  28. MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"])
  29. class MetricsConfig:
  30. __slots__ = ["params"]
  31. def __init__(self, params: dict[str, str] | None = None):
  32. self.params = params
  33. if self.params is None:
  34. self.params = {}
  35. class MetricHandler(abc.ABC):
  36. @abc.abstractmethod
  37. def emit(self, metric_data: MetricData):
  38. pass
  39. class ConsoleMetricHandler(MetricHandler):
  40. def emit(self, metric_data: MetricData):
  41. print(
  42. f"[{metric_data.timestamp}][{metric_data.group_name}]: {metric_data.name}={metric_data.value}"
  43. )
  44. class NullMetricHandler(MetricHandler):
  45. def emit(self, metric_data: MetricData):
  46. pass
  47. class MetricStream:
  48. def __init__(self, group_name: str, handler: MetricHandler):
  49. self.group_name = group_name
  50. self.handler = handler
  51. def add_value(self, metric_name: str, metric_value: int):
  52. self.handler.emit(
  53. MetricData(time.time(), self.group_name, metric_name, metric_value)
  54. )
  55. _metrics_map: dict[str, MetricHandler] = {}
  56. _default_metrics_handler: MetricHandler = NullMetricHandler()
  57. # pyre-fixme[9]: group has type `str`; used as `None`.
  58. def configure(handler: MetricHandler, group: str | None = None):
  59. if group is None:
  60. global _default_metrics_handler
  61. # pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used
  62. # as `MetricHandler`.
  63. _default_metrics_handler = handler
  64. else:
  65. _metrics_map[group] = handler
  66. def getStream(group: str):
  67. handler = _metrics_map.get(group, _default_metrics_handler)
  68. return MetricStream(group, handler)
  69. def _get_metric_name(fn):
  70. qualname = fn.__qualname__
  71. split = qualname.split(".")
  72. if len(split) == 1:
  73. module = fn.__module__
  74. if module:
  75. return module.split(".")[-1] + "." + split[0]
  76. else:
  77. return split[0]
  78. else:
  79. return qualname
  80. def prof(fn=None, group: str = "torchelastic"):
  81. r"""
  82. @profile decorator publishes duration.ms, count, success, failure metrics for the function that it decorates.
  83. The metric name defaults to the qualified name (``class_name.def_name``) of the function.
  84. If the function does not belong to a class, it uses the leaf module name instead.
  85. Usage
  86. ::
  87. @metrics.prof
  88. def x():
  89. pass
  90. @metrics.prof(group="agent")
  91. def y():
  92. pass
  93. """
  94. def wrap(f):
  95. @wraps(f)
  96. def wrapper(*args, **kwargs):
  97. key = _get_metric_name(f)
  98. try:
  99. start = time.time()
  100. result = f(*args, **kwargs)
  101. put_metric(f"{key}.success", 1, group)
  102. except Exception:
  103. put_metric(f"{key}.failure", 1, group)
  104. raise
  105. finally:
  106. put_metric(f"{key}.duration.ms", get_elapsed_time_ms(start), group) # type: ignore[possibly-undefined]
  107. return result
  108. return wrapper
  109. if fn:
  110. return wrap(fn)
  111. else:
  112. return wrap
  113. @deprecated("Deprecated, use `@prof` instead", category=FutureWarning)
  114. def profile(group=None):
  115. """
  116. @profile decorator adds latency and success/failure metrics to any given function.
  117. Usage
  118. ::
  119. @metrics.profile("my_metric_group")
  120. def some_function(<arguments>):
  121. """
  122. def wrap(func):
  123. @wraps(func)
  124. def wrapper(*args, **kwargs):
  125. try:
  126. start_time = time.time()
  127. result = func(*args, **kwargs)
  128. # pyrefly: ignore [bad-argument-type]
  129. publish_metric(group, f"{func.__name__}.success", 1)
  130. except Exception:
  131. # pyrefly: ignore [bad-argument-type]
  132. publish_metric(group, f"{func.__name__}.failure", 1)
  133. raise
  134. finally:
  135. publish_metric(
  136. # pyrefly: ignore [bad-argument-type]
  137. group,
  138. f"{func.__name__}.duration.ms",
  139. get_elapsed_time_ms(start_time), # type: ignore[possibly-undefined]
  140. )
  141. return result
  142. return wrapper
  143. return wrap
  144. def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchelastic"):
  145. """
  146. Publish a metric data point.
  147. Usage
  148. ::
  149. put_metric("metric_name", 1)
  150. put_metric("metric_name", 1, "metric_group_name")
  151. """
  152. getStream(metric_group).add_value(metric_name, metric_value)
  153. @deprecated(
  154. "Deprecated, use `put_metric(metric_group)(metric_name, metric_value)` instead",
  155. category=FutureWarning,
  156. )
  157. def publish_metric(metric_group: str, metric_name: str, metric_value: int):
  158. metric_stream = getStream(metric_group)
  159. metric_stream.add_value(metric_name, metric_value)
  160. def get_elapsed_time_ms(start_time_in_seconds: float):
  161. """Return the elapsed time in millis from the given start time."""
  162. end_time = time.time()
  163. return int((end_time - start_time_in_seconds) * 1000)