| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328 |
- import logging
- from typing import TYPE_CHECKING, Dict
- import numpy as np
- from ray.air.constants import TRAINING_ITERATION
- from ray.tune.logger.logger import _LOGGER_DEPRECATION_WARNING, Logger, LoggerCallback
- from ray.tune.result import TIME_TOTAL_S, TIMESTEPS_TOTAL
- from ray.tune.utils import flatten_dict
- from ray.util.annotations import Deprecated, PublicAPI
- from ray.util.debug import log_once
- if TYPE_CHECKING:
- from ray.tune.experiment.trial import Trial # noqa: F401
logger = logging.getLogger(__name__)

# Value types that are written to TensorBoard as scalar summaries (and that
# are kept as metrics when hparams are logged at trial end).
VALID_SUMMARY_TYPES = [
    int,
    float,
    np.float32,
    np.float64,
    np.int32,
    np.int64,
]
@Deprecated(
    message=_LOGGER_DEPRECATION_WARNING.format(
        old="TBXLogger", new="ray.tune.tensorboardx.TBXLoggerCallback"
    ),
    warning=True,
)
@PublicAPI
class TBXLogger(Logger):
    """TensorBoardX Logger.

    Note that hparams will be written only after a trial has terminated.
    This logger automatically flattens nested dicts to show on TensorBoard:

        {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
    """

    # Hyperparameter value types forwarded to tensorboardX ``hparams``
    # unchanged (``None`` included). Values matching neither this tuple nor
    # VALID_NP_HPARAMS are dropped and reported with an info log.
    VALID_HPARAMS = (str, bool, int, float, list, type(None))
    # NumPy scalar types converted to native Python values with ``.tolist()``
    # before being forwarded to ``hparams``.
    VALID_NP_HPARAMS = (np.bool_, np.float32, np.float64, np.int32, np.int64)

    def _init(self):
        """Open a tensorboardX ``SummaryWriter`` in ``self.logdir``.

        Raises:
            ImportError: If ``tensorboardX`` cannot be imported.
        """
        try:
            from tensorboardX import SummaryWriter
        except ImportError:
            if log_once("tbx-install"):
                logger.info('pip install "ray[tune]" to see TensorBoard files.')
            raise
        self._file_writer = SummaryWriter(self.logdir, flush_secs=30)
        # Values written by the most recent ``on_result`` call; used by
        # ``close`` as the metrics attached to the hparams summary.
        self.last_result = None

    def on_result(self, result: Dict):
        """Write one reported result to TensorBoard.

        Scalars of ``VALID_SUMMARY_TYPES`` (except NaN) become scalar
        summaries. Non-empty lists / NumPy arrays are written by
        ``_log_array``. Everything else is skipped. All written values are
        remembered in ``self.last_result`` under their ``ray/tune/...`` tag.

        Args:
            result: The result dict reported by the trial.

        Raises:
            KeyError: If ``result`` has neither a truthy ``TIMESTEPS_TOTAL``
                nor a ``TRAINING_ITERATION`` entry.
        """
        # Falls back to the training iteration when timesteps are missing
        # or falsy (None / 0).
        step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]

        tmp = result.copy()
        for k in ["config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION]:
            tmp.pop(k, None)  # not useful to log these

        flat_result = flatten_dict(tmp, delimiter="/")
        path = ["ray", "tune"]
        # Loop-invariant: build the isinstance() type tuple once.
        scalar_types = tuple(VALID_SUMMARY_TYPES)
        valid_result = {}

        for attr, value in flat_result.items():
            full_attr = "/".join(path + [attr])
            if isinstance(value, scalar_types) and not np.isnan(value):
                valid_result[full_attr] = value
                self._file_writer.add_scalar(full_attr, value, global_step=step)
            elif (isinstance(value, list) and len(value) > 0) or (
                isinstance(value, np.ndarray) and value.size > 0
            ):
                valid_result[full_attr] = value
                self._log_array(full_attr, value, step)

        self.last_result = valid_result
        self._file_writer.flush()

    def _log_array(self, full_attr: str, value, step):
        """Write a non-empty list/array as image(s), video, or histogram.

        For NumPy arrays the rank selects the summary kind: 3 -> single
        image, 4 -> batch of images, 5 -> video (at 20 fps). Anything else,
        including plain lists, is written as a histogram. Values that
        tensorboardX rejects are skipped with a one-time warning.
        """
        if isinstance(value, np.ndarray):
            if value.ndim == 3:  # Must be a single image.
                self._file_writer.add_image(full_attr, value, global_step=step)
                return
            if value.ndim == 4:  # Must be a batch of images.
                self._file_writer.add_images(full_attr, value, global_step=step)
                return
            if value.ndim == 5:  # Must be video.
                self._file_writer.add_video(
                    full_attr, value, global_step=step, fps=20
                )
                return

        try:
            self._file_writer.add_histogram(full_attr, value, global_step=step)
        except (ValueError, TypeError):
            # In case TensorboardX still doesn't think it's a valid value
            # (e.g. `[[]]`), warn and move on.
            if log_once("invalid_tbx_value"):
                logger.warning(
                    "You are trying to log an invalid value ({}={}) "
                    "via {}!".format(full_attr, value, type(self).__name__)
                )

    def flush(self):
        """Flush buffered events of the underlying writer, if there is one."""
        if self._file_writer is not None:
            self._file_writer.flush()

    def close(self):
        """Log hparams for the trial (when possible) and close the writer.

        Hparams are written only if the trial has evaluated params and at
        least one result was logged. Only scalar metrics from the last
        result are attached. The writer is closed even if hparams logging
        raises.
        """
        if self._file_writer is None:
            return
        try:
            if self.trial and self.trial.evaluated_params and self.last_result:
                flat_result = flatten_dict(self.last_result, delimiter="/")
                scalar_types = tuple(VALID_SUMMARY_TYPES)
                scrubbed_result = {
                    k: value
                    for k, value in flat_result.items()
                    if isinstance(value, scalar_types)
                }
                self._try_log_hparams(scrubbed_result)
        finally:
            # Always release the event file, even if hparams logging failed.
            self._file_writer.close()

    def _split_hparams(self, flat_params):
        """Partition flat hyperparameters into loggable and removed values.

        NumPy scalars are checked first and converted with ``.tolist()``.
        This means NumPy types that also subclass a Python type (e.g.
        ``np.float64`` is a ``float``) are still converted.

        Args:
            flat_params: Flat mapping of hyperparameter name to value.

        Returns:
            A ``(scrubbed_params, removed)`` tuple of dicts.
        """
        scrubbed_params = {}
        removed = {}
        for k, v in flat_params.items():
            if isinstance(v, self.VALID_NP_HPARAMS):
                scrubbed_params[k] = v.tolist()
            elif isinstance(v, self.VALID_HPARAMS):
                scrubbed_params[k] = v
            else:
                removed[k] = v
        return scrubbed_params, removed

    def _try_log_hparams(self, result):
        """Best-effort write of the trial's hyperparameters and final metrics.

        Failures inside tensorboardX's ``hparams`` / ``add_summary`` are
        logged and swallowed.

        Args:
            result: Flat mapping of metric tag to scalar value.
        """
        # Drop values tensorboardX cannot encode; see VALID_HPARAMS.
        scrubbed_params, removed = self._split_hparams(
            flatten_dict(self.trial.evaluated_params)
        )
        if removed:
            logger.info(
                "Removed the following hyperparameter values when "
                "logging to tensorboard: %s",
                str(removed),
            )

        from tensorboardX.summary import hparams

        try:
            experiment_tag, session_start_tag, session_end_tag = hparams(
                hparam_dict=scrubbed_params, metric_dict=result
            )
            self._file_writer.file_writer.add_summary(experiment_tag)
            self._file_writer.file_writer.add_summary(session_start_tag)
            self._file_writer.file_writer.add_summary(session_end_tag)
        except Exception:
            logger.exception(
                "TensorboardX failed to log hparams. "
                "This may be due to an unsupported type "
                "in the hyperparameter values."
            )
@PublicAPI
class TBXLoggerCallback(LoggerCallback):
    """TensorBoardX Logger.

    Note that hparams will be written only after a trial has terminated.
    This logger automatically flattens nested dicts to show on TensorBoard:

        {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
    """

    # Glob pattern(s) for the event files this callback writes; presumably
    # consumed by Tune's callback machinery to locate them (confirm).
    _SAVED_FILE_TEMPLATES = ["events.out.tfevents.*"]

    # Hyperparameter value types forwarded to tensorboardX ``hparams``
    # unchanged (``None`` included). Values matching neither this tuple nor
    # VALID_NP_HPARAMS are dropped and reported with an info log.
    VALID_HPARAMS = (str, bool, int, float, list, type(None))
    # NumPy scalar types converted to native Python values with ``.tolist()``
    # before being forwarded to ``hparams``.
    VALID_NP_HPARAMS = (np.bool_, np.float32, np.float64, np.int32, np.int64)

    def __init__(self):
        """Resolve the tensorboardX ``SummaryWriter`` class.

        Raises:
            ImportError: If ``tensorboardX`` cannot be imported.
        """
        try:
            from tensorboardX import SummaryWriter

            self._summary_writer_cls = SummaryWriter
        except ImportError:
            if log_once("tbx-install"):
                logger.info('pip install "ray[tune]" to see TensorBoard files.')
            raise
        # One open writer per trial, plus the values last written for it
        # (used as the hparams metrics when the trial ends).
        self._trial_writer: Dict["Trial", SummaryWriter] = {}
        self._trial_result: Dict["Trial", Dict] = {}

    def log_trial_start(self, trial: "Trial"):
        """Open a fresh writer in ``trial.local_path``.

        Any writer already open for the trial is closed first.
        """
        if trial in self._trial_writer:
            self._trial_writer[trial].close()
        trial.init_local_path()
        self._trial_writer[trial] = self._summary_writer_cls(
            trial.local_path, flush_secs=30
        )
        self._trial_result[trial] = {}

    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        """Write one reported result of ``trial`` to TensorBoard.

        Scalars of ``VALID_SUMMARY_TYPES`` (except NaN) become scalar
        summaries. Non-empty lists / NumPy arrays are written by
        ``_log_array``. Everything else is skipped. A writer is opened
        lazily if the trial has none yet.

        Args:
            iteration: Unused here; the step comes from ``result``.
            trial: The trial that reported the result.
            result: The reported result dict.

        Raises:
            KeyError: If ``result`` has neither a truthy ``TIMESTEPS_TOTAL``
                nor a ``TRAINING_ITERATION`` entry.
        """
        if trial not in self._trial_writer:
            self.log_trial_start(trial)
        writer = self._trial_writer[trial]

        # Falls back to the training iteration when timesteps are missing
        # or falsy (None / 0).
        step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]

        tmp = result.copy()
        for k in ["config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION]:
            tmp.pop(k, None)  # not useful to log these

        flat_result = flatten_dict(tmp, delimiter="/")
        path = ["ray", "tune"]
        # Loop-invariant: build the isinstance() type tuple once.
        scalar_types = tuple(VALID_SUMMARY_TYPES)
        valid_result = {}

        for attr, value in flat_result.items():
            full_attr = "/".join(path + [attr])
            if isinstance(value, scalar_types) and not np.isnan(value):
                valid_result[full_attr] = value
                writer.add_scalar(full_attr, value, global_step=step)
            elif (isinstance(value, list) and len(value) > 0) or (
                isinstance(value, np.ndarray) and value.size > 0
            ):
                valid_result[full_attr] = value
                self._log_array(writer, full_attr, value, step)

        self._trial_result[trial] = valid_result
        writer.flush()

    def _log_array(self, writer, full_attr: str, value, step):
        """Write a non-empty list/array to ``writer`` as image(s), video or histogram.

        For NumPy arrays the rank selects the summary kind: 3 -> single
        image, 4 -> batch of images, 5 -> video (at 20 fps). Anything else,
        including plain lists, is written as a histogram. Values that
        tensorboardX rejects are skipped with a one-time warning.
        """
        if isinstance(value, np.ndarray):
            if value.ndim == 3:  # Must be a single image.
                writer.add_image(full_attr, value, global_step=step)
                return
            if value.ndim == 4:  # Must be a batch of images.
                writer.add_images(full_attr, value, global_step=step)
                return
            if value.ndim == 5:  # Must be video.
                writer.add_video(full_attr, value, global_step=step, fps=20)
                return

        try:
            writer.add_histogram(full_attr, value, global_step=step)
        except (ValueError, TypeError):
            # In case TensorboardX still doesn't think it's a valid value
            # (e.g. `[[]]`), warn and move on.
            if log_once("invalid_tbx_value"):
                logger.warning(
                    "You are trying to log an invalid value ({}={}) "
                    "via {}!".format(full_attr, value, type(self).__name__)
                )

    def log_trial_end(self, trial: "Trial", failed: bool = False):
        """Write hparams for ``trial`` (when possible) and release its writer.

        The writer is closed and forgotten even if hparams logging raises.

        Args:
            trial: The trial that ended.
            failed: Whether the trial failed (not used here).
        """
        if trial not in self._trial_writer:
            return
        writer = self._trial_writer[trial]
        try:
            if trial and trial.evaluated_params and self._trial_result[trial]:
                flat_result = flatten_dict(self._trial_result[trial], delimiter="/")
                scalar_types = tuple(VALID_SUMMARY_TYPES)
                scrubbed_result = {
                    k: value
                    for k, value in flat_result.items()
                    if isinstance(value, scalar_types)
                }
                self._try_log_hparams(trial, scrubbed_result)
        finally:
            # Always release the event file and drop per-trial state, so a
            # failure while logging hparams does not leak the writer.
            writer.close()
            self._trial_writer.pop(trial, None)
            self._trial_result.pop(trial, None)

    def _split_hparams(self, flat_params):
        """Partition flat hyperparameters into loggable and removed values.

        NumPy scalars are checked first and converted with ``.tolist()``.
        This means NumPy types that also subclass a Python type (e.g.
        ``np.float64`` is a ``float``) are still converted.

        Args:
            flat_params: Flat mapping of hyperparameter name to value.

        Returns:
            A ``(scrubbed_params, removed)`` tuple of dicts.
        """
        scrubbed_params = {}
        removed = {}
        for k, v in flat_params.items():
            if isinstance(v, self.VALID_NP_HPARAMS):
                scrubbed_params[k] = v.tolist()
            elif isinstance(v, self.VALID_HPARAMS):
                scrubbed_params[k] = v
            else:
                removed[k] = v
        return scrubbed_params, removed

    def _try_log_hparams(self, trial: "Trial", result: Dict):
        """Best-effort write of ``trial``'s hyperparameters and final metrics.

        Failures inside tensorboardX's ``hparams`` / ``add_summary`` are
        logged and swallowed.

        Args:
            trial: Trial whose ``evaluated_params`` are logged; its writer
                must still be registered in ``self._trial_writer``.
            result: Flat mapping of metric tag to scalar value.
        """
        # Drop values tensorboardX cannot encode; see VALID_HPARAMS.
        scrubbed_params, removed = self._split_hparams(
            flatten_dict(trial.evaluated_params)
        )
        if removed:
            logger.info(
                "Removed the following hyperparameter values when "
                "logging to tensorboard: %s",
                str(removed),
            )

        from tensorboardX.summary import hparams

        try:
            experiment_tag, session_start_tag, session_end_tag = hparams(
                hparam_dict=scrubbed_params, metric_dict=result
            )
            file_writer = self._trial_writer[trial].file_writer
            file_writer.add_summary(experiment_tag)
            file_writer.add_summary(session_start_tag)
            file_writer.add_summary(session_end_tag)
        except Exception:
            logger.exception(
                "TensorboardX failed to log hparams. "
                "This may be due to an unsupported type "
                "in the hyperparameter values."
            )
|