wandb_metric.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. """metric."""
  2. from __future__ import annotations
  3. import logging
  4. from collections.abc import Sequence
  5. from typing import Callable
  6. from wandb.proto import wandb_internal_pb2 as pb
  7. logger = logging.getLogger("wandb")
  8. class Metric:
  9. """Metric object."""
  10. _callback: Callable[[pb.MetricRecord], None] | None
  11. _name: str
  12. _step_metric: str | None
  13. _step_sync: bool | None
  14. _hidden: bool | None
  15. _summary: Sequence[str] | None
  16. _goal: str | None
  17. _overwrite: bool | None
  18. def __init__(
  19. self,
  20. name: str,
  21. step_metric: str | None = None,
  22. step_sync: bool | None = None,
  23. hidden: bool | None = None,
  24. summary: Sequence[str] | None = None,
  25. goal: str | None = None,
  26. overwrite: bool | None = None,
  27. ) -> None:
  28. self._callback = None
  29. self._name = name
  30. self._step_metric = step_metric
  31. # default to step_sync=True if step metric is set
  32. step_sync = step_sync if step_sync is not None else step_metric is not None
  33. self._step_sync = step_sync
  34. self._hidden = hidden
  35. self._summary = summary
  36. self._goal = goal
  37. self._overwrite = overwrite
  38. def _set_callback(self, cb: Callable[[pb.MetricRecord], None]) -> None:
  39. self._callback = cb
  40. @property
  41. def name(self) -> str:
  42. return self._name
  43. @property
  44. def step_metric(self) -> str | None:
  45. return self._step_metric
  46. @property
  47. def step_sync(self) -> bool | None:
  48. return self._step_sync
  49. @property
  50. def summary(self) -> tuple[str, ...] | None:
  51. if self._summary is None:
  52. return None
  53. return tuple(self._summary)
  54. @property
  55. def hidden(self) -> bool | None:
  56. return self._hidden
  57. @property
  58. def goal(self) -> str | None:
  59. goal_dict = dict(min="minimize", max="maximize")
  60. return goal_dict[self._goal] if self._goal else None
  61. def _commit(self) -> None:
  62. m = pb.MetricRecord()
  63. m.options.defined = True
  64. if self._name.endswith("*"):
  65. m.glob_name = self._name
  66. else:
  67. m.name = self._name
  68. if self._step_metric:
  69. m.step_metric = self._step_metric
  70. if self._step_sync:
  71. m.options.step_sync = self._step_sync
  72. if self._hidden:
  73. m.options.hidden = self._hidden
  74. if self._summary:
  75. summary_set = set(self._summary)
  76. if "min" in summary_set:
  77. m.summary.min = True
  78. if "max" in summary_set:
  79. m.summary.max = True
  80. if "mean" in summary_set:
  81. m.summary.mean = True
  82. if "last" in summary_set:
  83. m.summary.last = True
  84. if "copy" in summary_set:
  85. m.summary.copy = True
  86. if "none" in summary_set:
  87. m.summary.none = True
  88. if "best" in summary_set:
  89. m.summary.best = True
  90. if "first" in summary_set:
  91. m.summary.first = True
  92. if self._goal == "min":
  93. m.goal = m.GOAL_MINIMIZE
  94. if self._goal == "max":
  95. m.goal = m.GOAL_MAXIMIZE
  96. if self._overwrite:
  97. m._control.overwrite = self._overwrite
  98. if self._callback:
  99. self._callback(m)