| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187 |
- import logging
- from typing import TYPE_CHECKING, Dict, List, Optional, Union
- import numpy as np
- from ray.air.constants import TRAINING_ITERATION
- from ray.tune.logger.logger import LoggerCallback
- from ray.tune.result import TIME_TOTAL_S, TIMESTEPS_TOTAL
- from ray.tune.utils import flatten_dict
- from ray.util.annotations import PublicAPI
- if TYPE_CHECKING:
- from ray.tune.experiment.trial import Trial
- try:
- from aim.sdk import Repo, Run
- except ImportError:
- Repo, Run = None, None
# Module-level logger used by the callback below.
logger = logging.getLogger(__name__)

# Scalar types that ``AimLoggerCallback.log_trial_result`` tracks as Aim
# metrics; result values of any other type are not tracked.
VALID_SUMMARY_TYPES = [
    int,
    float,
    np.float32,
    np.float64,
    np.int32,
    np.int64,
]
@PublicAPI
class AimLoggerCallback(LoggerCallback):
    """Aim Logger: logs metrics in Aim format.

    Aim is an open-source, self-hosted ML experiment tracking tool.
    It's good at tracking lots (thousands) of training runs, and it allows you to
    compare them with a performant and well-designed UI.

    Source: https://github.com/aimhubio/aim

    Args:
        repo: Aim repository directory or a `Repo` object that the Run object will
            log results to. If not provided, a default repo will be set up in the
            experiment directory (one level above trial directories).
        experiment_name: Sets the `experiment` property of each Run object, which
            is the experiment name associated with it. Can be used later to query
            runs/sequences.
            If not provided, the default will be the Tune experiment name set
            by `RunConfig(name=...)`.
        metrics: List of metric names (out of the metrics reported by Tune) to
            track in Aim. If no metrics are specified, log everything that
            is reported.
        aim_run_kwargs: Additional arguments that will be passed when creating the
            individual `Run` objects for each trial. For the full list of arguments,
            please see the Aim documentation:
            https://aimstack.readthedocs.io/en/latest/refs/sdk.html
    """

    # Hyperparameter value types that are stored on the run unchanged.
    VALID_HPARAMS = (str, bool, int, float, list, type(None))
    # NumPy scalar types that are converted to native Python values with
    # ``.tolist()`` before being stored.
    VALID_NP_HPARAMS = (np.bool_, np.float32, np.float64, np.int32, np.int64)

    def __init__(
        self,
        repo: Optional[Union[str, "Repo"]] = None,
        experiment_name: Optional[str] = None,
        metrics: Optional[List[str]] = None,
        **aim_run_kwargs,
    ):
        """
        See help(AimLoggerCallback) for more information about parameters.
        """
        assert Run is not None, (
            "aim must be installed!. You can install aim with"
            " the command: `pip install aim`."
        )
        self._repo_path = repo
        self._experiment_name = experiment_name
        # An empty list is falsy and would behave exactly like ``None`` (log
        # everything) in ``log_trial_result``, so reject it as ambiguous.
        if metrics is not None and not metrics:
            raise ValueError(
                "`metrics` must either contain at least one metric name, or be None, "
                "in which case all reported metrics will be logged to the aim repo."
            )
        self._metrics = metrics
        self._aim_run_kwargs = aim_run_kwargs
        # One open Aim run per trial; closed in ``log_trial_end`` or on restart.
        self._trial_to_run: Dict["Trial", Run] = {}

    def _create_run(self, trial: "Trial") -> Run:
        """Initializes an Aim Run object for a given trial.

        Args:
            trial: The Tune trial that aim will track as a Run.

        Returns:
            Run: The created aim run for a specific trial.
        """
        experiment_dir = trial.local_experiment_path
        run = Run(
            repo=self._repo_path or experiment_dir,
            experiment=self._experiment_name or trial.experiment_dir_name,
            **self._aim_run_kwargs,
        )
        # Attach a few useful trial properties
        run["trial_id"] = trial.trial_id
        run["trial_log_dir"] = trial.path
        trial_ip = trial.get_ray_actor_ip()
        if trial_ip:
            run["trial_ip"] = trial_ip
        return run

    def log_trial_start(self, trial: "Trial"):
        """Create the Aim run for ``trial`` and log its hyperparameters.

        If the trial already has a run (e.g. it was restarted), the old run is
        closed and replaced by a fresh one.
        """
        if trial in self._trial_to_run:
            # Cleanup an existing run if the trial has been restarted
            self._trial_to_run[trial].close()

        # NOTE(review): presumably makes sure the trial's local directory
        # exists before ``_create_run`` reads its paths — verify.
        trial.init_local_path()
        self._trial_to_run[trial] = self._create_run(trial)

        if trial.evaluated_params:
            self._log_trial_hparams(trial)

    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        """Track the finite numeric scalars of ``result`` on the trial's run.

        Each tracked value is named ``ray/tune/<flattened key>``. Non-scalar
        values (lists, arrays, strings, ...), NaN and infinite values are
        skipped. If ``metrics`` was given, only those keys are tracked.

        Args:
            iteration: Not used; the Aim step is derived from ``result``.
            trial: The trial that reported ``result``.
            result: The result dict reported by Tune.
        """
        # Prefer timesteps as the step axis; fall back to the training
        # iteration when timesteps are missing or falsy (e.g. 0).
        step = result.get(TIMESTEPS_TOTAL, None) or result[TRAINING_ITERATION]

        if trial not in self._trial_to_run:
            # No start event was seen for this trial: create its run lazily
            # instead of failing with a KeyError.
            self.log_trial_start(trial)
        trial_run = self._trial_to_run[trial]

        tmp_result = result.copy()
        for k in ["config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION]:
            tmp_result.pop(k, None)  # not useful to log these

        # `context` and `epoch` are special keys that users can report,
        # which are treated as special aim metrics/configurations.
        context = tmp_result.pop("context", None)
        epoch = tmp_result.pop("epoch", None)

        summary_types = tuple(VALID_SUMMARY_TYPES)
        flat_result = flatten_dict(tmp_result, delimiter="/")
        for attr, value in flat_result.items():
            if self._metrics and attr not in self._metrics:
                continue
            if not isinstance(value, summary_types):
                continue
            if np.isnan(value) or np.isinf(value):
                continue
            trial_run.track(
                value=value,
                name=f"ray/tune/{attr}",
                epoch=epoch,
                step=step,
                context=context,
            )

    def log_trial_end(self, trial: "Trial", failed: bool = False):
        """Close and forget the Aim run of ``trial``, if one was created."""
        trial_run = self._trial_to_run.pop(trial, None)
        if trial_run is not None:
            trial_run.close()

    def _log_trial_hparams(self, trial: "Trial"):
        """Store the trial's flattened hyperparameters under ``run["hparams"]``.

        NumPy scalars are converted to native Python values; values of any
        other unsupported type are dropped and reported via ``logger.info``.
        """
        flat_params = flatten_dict(trial.evaluated_params, delimiter="/")

        scrubbed_params = {}
        removed = {}
        for key, value in flat_params.items():
            # Check NumPy types first: np.float64 also subclasses ``float`` and
            # must still be converted to a plain Python value.
            if isinstance(value, self.VALID_NP_HPARAMS):
                scrubbed_params[key] = value.tolist()
            elif isinstance(value, self.VALID_HPARAMS):
                scrubbed_params[key] = value
            else:
                removed[key] = value

        if removed:
            logger.info(
                "Removed the following hyperparameter values when "
                "logging to aim: %s",
                removed,
            )

        run = self._trial_to_run[trial]
        run["hparams"] = scrubbed_params
|