# aim.py (6.7 KB)
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import numpy as np

from ray.air.constants import TRAINING_ITERATION
from ray.tune.logger.logger import LoggerCallback
from ray.tune.result import TIME_TOTAL_S, TIMESTEPS_TOTAL
from ray.tune.utils import flatten_dict
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.tune.experiment.trial import Trial

try:
    from aim.sdk import Repo, Run
except ImportError:
    Repo, Run = None, None
  15. logger = logging.getLogger(__name__)
  16. VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32, np.int64]
  17. @PublicAPI
  18. class AimLoggerCallback(LoggerCallback):
  19. """Aim Logger: logs metrics in Aim format.
  20. Aim is an open-source, self-hosted ML experiment tracking tool.
  21. It's good at tracking lots (thousands) of training runs, and it allows you to
  22. compare them with a performant and well-designed UI.
  23. Source: https://github.com/aimhubio/aim
  24. Args:
  25. repo: Aim repository directory or a `Repo` object that the Run object will
  26. log results to. If not provided, a default repo will be set up in the
  27. experiment directory (one level above trial directories).
  28. experiment: Sets the `experiment` property of each Run object, which is the
  29. experiment name associated with it. Can be used later to query
  30. runs/sequences.
  31. If not provided, the default will be the Tune experiment name set
  32. by `RunConfig(name=...)`.
  33. metrics: List of metric names (out of the metrics reported by Tune) to
  34. track in Aim. If no metric are specified, log everything that
  35. is reported.
  36. aim_run_kwargs: Additional arguments that will be passed when creating the
  37. individual `Run` objects for each trial. For the full list of arguments,
  38. please see the Aim documentation:
  39. https://aimstack.readthedocs.io/en/latest/refs/sdk.html
  40. """
  41. VALID_HPARAMS = (str, bool, int, float, list, type(None))
  42. VALID_NP_HPARAMS = (np.bool_, np.float32, np.float64, np.int32, np.int64)
  43. def __init__(
  44. self,
  45. repo: Optional[Union[str, "Repo"]] = None,
  46. experiment_name: Optional[str] = None,
  47. metrics: Optional[List[str]] = None,
  48. **aim_run_kwargs,
  49. ):
  50. """
  51. See help(AimLoggerCallback) for more information about parameters.
  52. """
  53. assert Run is not None, (
  54. "aim must be installed!. You can install aim with"
  55. " the command: `pip install aim`."
  56. )
  57. self._repo_path = repo
  58. self._experiment_name = experiment_name
  59. if not (bool(metrics) or metrics is None):
  60. raise ValueError(
  61. "`metrics` must either contain at least one metric name, or be None, "
  62. "in which case all reported metrics will be logged to the aim repo."
  63. )
  64. self._metrics = metrics
  65. self._aim_run_kwargs = aim_run_kwargs
  66. self._trial_to_run: Dict["Trial", Run] = {}
  67. def _create_run(self, trial: "Trial") -> Run:
  68. """Initializes an Aim Run object for a given trial.
  69. Args:
  70. trial: The Tune trial that aim will track as a Run.
  71. Returns:
  72. Run: The created aim run for a specific trial.
  73. """
  74. experiment_dir = trial.local_experiment_path
  75. run = Run(
  76. repo=self._repo_path or experiment_dir,
  77. experiment=self._experiment_name or trial.experiment_dir_name,
  78. **self._aim_run_kwargs,
  79. )
  80. # Attach a few useful trial properties
  81. run["trial_id"] = trial.trial_id
  82. run["trial_log_dir"] = trial.path
  83. trial_ip = trial.get_ray_actor_ip()
  84. if trial_ip:
  85. run["trial_ip"] = trial_ip
  86. return run
  87. def log_trial_start(self, trial: "Trial"):
  88. if trial in self._trial_to_run:
  89. # Cleanup an existing run if the trial has been restarted
  90. self._trial_to_run[trial].close()
  91. trial.init_local_path()
  92. self._trial_to_run[trial] = self._create_run(trial)
  93. if trial.evaluated_params:
  94. self._log_trial_hparams(trial)
  95. def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
  96. tmp_result = result.copy()
  97. step = result.get(TIMESTEPS_TOTAL, None) or result[TRAINING_ITERATION]
  98. for k in ["config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION]:
  99. tmp_result.pop(k, None) # not useful to log these
  100. # `context` and `epoch` are special keys that users can report,
  101. # which are treated as special aim metrics/configurations.
  102. context = tmp_result.pop("context", None)
  103. epoch = tmp_result.pop("epoch", None)
  104. trial_run = self._trial_to_run[trial]
  105. path = ["ray", "tune"]
  106. flat_result = flatten_dict(tmp_result, delimiter="/")
  107. valid_result = {}
  108. for attr, value in flat_result.items():
  109. if self._metrics and attr not in self._metrics:
  110. continue
  111. full_attr = "/".join(path + [attr])
  112. if isinstance(value, tuple(VALID_SUMMARY_TYPES)) and not (
  113. np.isnan(value) or np.isinf(value)
  114. ):
  115. valid_result[attr] = value
  116. trial_run.track(
  117. value=value,
  118. name=full_attr,
  119. epoch=epoch,
  120. step=step,
  121. context=context,
  122. )
  123. elif (isinstance(value, (list, tuple, set)) and len(value) > 0) or (
  124. isinstance(value, np.ndarray) and value.size > 0
  125. ):
  126. valid_result[attr] = value
  127. def log_trial_end(self, trial: "Trial", failed: bool = False):
  128. trial_run = self._trial_to_run.pop(trial)
  129. trial_run.close()
  130. def _log_trial_hparams(self, trial: "Trial"):
  131. params = flatten_dict(trial.evaluated_params, delimiter="/")
  132. flat_params = flatten_dict(params)
  133. scrubbed_params = {
  134. k: v for k, v in flat_params.items() if isinstance(v, self.VALID_HPARAMS)
  135. }
  136. np_params = {
  137. k: v.tolist()
  138. for k, v in flat_params.items()
  139. if isinstance(v, self.VALID_NP_HPARAMS)
  140. }
  141. scrubbed_params.update(np_params)
  142. removed = {
  143. k: v
  144. for k, v in flat_params.items()
  145. if not isinstance(v, self.VALID_HPARAMS + self.VALID_NP_HPARAMS)
  146. }
  147. if removed:
  148. logger.info(
  149. "Removed the following hyperparameter values when "
  150. "logging to aim: %s",
  151. str(removed),
  152. )
  153. run = self._trial_to_run[trial]
  154. run["hparams"] = scrubbed_params