| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259 |
- import abc
- import json
- import logging
- from pathlib import Path
- from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Type
- import pyarrow
- import yaml
- from ray.air._internal.json import SafeFallbackEncoder
- from ray.tune.callback import Callback
- from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI
- if TYPE_CHECKING:
- from ray.tune.experiment.trial import Trial # noqa: F401
- logger = logging.getLogger(__name__)
- # Apply flow style for sequences of this length
- _SEQUENCE_LEN_FLOW_STYLE = 3
- _LOGGER_DEPRECATION_WARNING = (
- "The `{old} interface is deprecated in favor of the "
- "`{new}` interface and will be removed in Ray 2.7."
- )
- @Deprecated(
- message=_LOGGER_DEPRECATION_WARNING.format(
- old="Logger", new="ray.tune.logger.LoggerCallback"
- ),
- )
- @DeveloperAPI
- class Logger(abc.ABC):
- """Logging interface for ray.tune.
- By default, the UnifiedLogger implementation is used which logs results in
- multiple formats (TensorBoard, rllab/viskit, plain json, custom loggers)
- at once.
- Arguments:
- config: Configuration passed to all logger creators.
- logdir: Directory for all logger creators to log to.
- trial: Trial object for the logger to access.
- """
- def __init__(self, config: Dict, logdir: str, trial: Optional["Trial"] = None):
- self.config = config
- self.logdir = logdir
- self.trial = trial
- self._init()
- def _init(self):
- pass
- def on_result(self, result):
- """Given a result, appends it to the existing log."""
- raise NotImplementedError
- def update_config(self, config):
- """Updates the config for logger."""
- pass
- def close(self):
- """Releases all resources used by this logger."""
- pass
- def flush(self):
- """Flushes all disk writes to storage."""
- pass
- @PublicAPI
- class LoggerCallback(Callback):
- """Base class for experiment-level logger callbacks
- This base class defines a general interface for logging events,
- like trial starts, restores, ends, checkpoint saves, and receiving
- trial results.
- Callbacks implementing this interface should make sure that logging
- utilities are cleaned up properly on trial termination, i.e. when
- ``log_trial_end`` is received. This includes e.g. closing files.
- """
- def log_trial_start(self, trial: "Trial"):
- """Handle logging when a trial starts.
- Args:
- trial: Trial object.
- """
- pass
- def log_trial_restore(self, trial: "Trial"):
- """Handle logging when a trial restores.
- Args:
- trial: Trial object.
- """
- pass
- def log_trial_save(self, trial: "Trial"):
- """Handle logging when a trial saves a checkpoint.
- Args:
- trial: Trial object.
- """
- pass
- def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
- """Handle logging when a trial reports a result.
- Args:
- trial: Trial object.
- result: Result dictionary.
- """
- pass
- def log_trial_end(self, trial: "Trial", failed: bool = False):
- """Handle logging when a trial ends.
- Args:
- trial: Trial object.
- failed: True if the Trial finished gracefully, False if
- it failed (e.g. when it raised an exception).
- """
- pass
- def on_trial_result(
- self,
- iteration: int,
- trials: List["Trial"],
- trial: "Trial",
- result: Dict,
- **info,
- ):
- self.log_trial_result(iteration, trial, result)
- def on_trial_start(
- self, iteration: int, trials: List["Trial"], trial: "Trial", **info
- ):
- self.log_trial_start(trial)
- def on_trial_restore(
- self, iteration: int, trials: List["Trial"], trial: "Trial", **info
- ):
- self.log_trial_restore(trial)
- def on_trial_save(
- self, iteration: int, trials: List["Trial"], trial: "Trial", **info
- ):
- self.log_trial_save(trial)
- def on_trial_complete(
- self, iteration: int, trials: List["Trial"], trial: "Trial", **info
- ):
- self.log_trial_end(trial, failed=False)
- def on_trial_error(
- self, iteration: int, trials: List["Trial"], trial: "Trial", **info
- ):
- self.log_trial_end(trial, failed=True)
- def _restore_from_remote(self, file_name: str, trial: "Trial") -> None:
- if not trial.checkpoint:
- # If there's no checkpoint, there's no logging artifacts to restore
- # since we're starting from scratch.
- return
- local_file = Path(trial.local_path, file_name).as_posix()
- remote_file = Path(trial.storage.trial_fs_path, file_name).as_posix()
- try:
- pyarrow.fs.copy_files(
- remote_file,
- local_file,
- source_filesystem=trial.storage.storage_filesystem,
- )
- logger.debug(f"Copied {remote_file} to {local_file}")
- except FileNotFoundError:
- logger.warning(f"Remote file not found: {remote_file}")
- except Exception:
- logger.exception(f"Error downloading {remote_file}")
- @DeveloperAPI
- class LegacyLoggerCallback(LoggerCallback):
- """Supports logging to trial-specific `Logger` classes.
- Previously, Ray Tune logging was handled via `Logger` classes that have
- been instantiated per-trial. This callback is a fallback to these
- `Logger`-classes, instantiating each `Logger` class for each trial
- and logging to them.
- Args:
- logger_classes: Logger classes that should
- be instantiated for each trial.
- """
- def __init__(self, logger_classes: Iterable[Type[Logger]]):
- self.logger_classes = list(logger_classes)
- self._class_trial_loggers: Dict[Type[Logger], Dict["Trial", Logger]] = {}
- def log_trial_start(self, trial: "Trial"):
- trial.init_local_path()
- for logger_class in self.logger_classes:
- trial_loggers = self._class_trial_loggers.get(logger_class, {})
- if trial not in trial_loggers:
- logger = logger_class(trial.config, trial.local_path, trial)
- trial_loggers[trial] = logger
- self._class_trial_loggers[logger_class] = trial_loggers
- def log_trial_restore(self, trial: "Trial"):
- for logger_class, trial_loggers in self._class_trial_loggers.items():
- if trial in trial_loggers:
- trial_loggers[trial].flush()
- def log_trial_save(self, trial: "Trial"):
- for logger_class, trial_loggers in self._class_trial_loggers.items():
- if trial in trial_loggers:
- trial_loggers[trial].flush()
- def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
- for logger_class, trial_loggers in self._class_trial_loggers.items():
- if trial in trial_loggers:
- trial_loggers[trial].on_result(result)
- def log_trial_end(self, trial: "Trial", failed: bool = False):
- for logger_class, trial_loggers in self._class_trial_loggers.items():
- if trial in trial_loggers:
- trial_loggers[trial].close()
- class _RayDumper(yaml.SafeDumper):
- def represent_sequence(self, tag, sequence, flow_style=None):
- if len(sequence) > _SEQUENCE_LEN_FLOW_STYLE:
- return super().represent_sequence(tag, sequence, flow_style=True)
- return super().represent_sequence(tag, sequence, flow_style=flow_style)
- @DeveloperAPI
- def pretty_print(result, exclude: Optional[Set[str]] = None):
- result = result.copy()
- result.update(config=None) # drop config from pretty print
- result.update(hist_stats=None) # drop hist_stats from pretty print
- out = {}
- for k, v in result.items():
- if v is not None and (exclude is None or k not in exclude):
- out[k] = v
- cleaned = json.dumps(out, cls=SafeFallbackEncoder)
- return yaml.dump(json.loads(cleaned), Dumper=_RayDumper, default_flow_style=False)
|