logger.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. import abc
  2. import json
  3. import logging
  4. from pathlib import Path
  5. from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Type
  6. import pyarrow
  7. import yaml
  8. from ray.air._internal.json import SafeFallbackEncoder
  9. from ray.tune.callback import Callback
  10. from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI
  11. if TYPE_CHECKING:
  12. from ray.tune.experiment.trial import Trial # noqa: F401
  13. logger = logging.getLogger(__name__)
  14. # Apply flow style for sequences of this length
  15. _SEQUENCE_LEN_FLOW_STYLE = 3
  16. _LOGGER_DEPRECATION_WARNING = (
  17. "The `{old} interface is deprecated in favor of the "
  18. "`{new}` interface and will be removed in Ray 2.7."
  19. )
  20. @Deprecated(
  21. message=_LOGGER_DEPRECATION_WARNING.format(
  22. old="Logger", new="ray.tune.logger.LoggerCallback"
  23. ),
  24. )
  25. @DeveloperAPI
  26. class Logger(abc.ABC):
  27. """Logging interface for ray.tune.
  28. By default, the UnifiedLogger implementation is used which logs results in
  29. multiple formats (TensorBoard, rllab/viskit, plain json, custom loggers)
  30. at once.
  31. Arguments:
  32. config: Configuration passed to all logger creators.
  33. logdir: Directory for all logger creators to log to.
  34. trial: Trial object for the logger to access.
  35. """
  36. def __init__(self, config: Dict, logdir: str, trial: Optional["Trial"] = None):
  37. self.config = config
  38. self.logdir = logdir
  39. self.trial = trial
  40. self._init()
  41. def _init(self):
  42. pass
  43. def on_result(self, result):
  44. """Given a result, appends it to the existing log."""
  45. raise NotImplementedError
  46. def update_config(self, config):
  47. """Updates the config for logger."""
  48. pass
  49. def close(self):
  50. """Releases all resources used by this logger."""
  51. pass
  52. def flush(self):
  53. """Flushes all disk writes to storage."""
  54. pass
  55. @PublicAPI
  56. class LoggerCallback(Callback):
  57. """Base class for experiment-level logger callbacks
  58. This base class defines a general interface for logging events,
  59. like trial starts, restores, ends, checkpoint saves, and receiving
  60. trial results.
  61. Callbacks implementing this interface should make sure that logging
  62. utilities are cleaned up properly on trial termination, i.e. when
  63. ``log_trial_end`` is received. This includes e.g. closing files.
  64. """
  65. def log_trial_start(self, trial: "Trial"):
  66. """Handle logging when a trial starts.
  67. Args:
  68. trial: Trial object.
  69. """
  70. pass
  71. def log_trial_restore(self, trial: "Trial"):
  72. """Handle logging when a trial restores.
  73. Args:
  74. trial: Trial object.
  75. """
  76. pass
  77. def log_trial_save(self, trial: "Trial"):
  78. """Handle logging when a trial saves a checkpoint.
  79. Args:
  80. trial: Trial object.
  81. """
  82. pass
  83. def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
  84. """Handle logging when a trial reports a result.
  85. Args:
  86. trial: Trial object.
  87. result: Result dictionary.
  88. """
  89. pass
  90. def log_trial_end(self, trial: "Trial", failed: bool = False):
  91. """Handle logging when a trial ends.
  92. Args:
  93. trial: Trial object.
  94. failed: True if the Trial finished gracefully, False if
  95. it failed (e.g. when it raised an exception).
  96. """
  97. pass
  98. def on_trial_result(
  99. self,
  100. iteration: int,
  101. trials: List["Trial"],
  102. trial: "Trial",
  103. result: Dict,
  104. **info,
  105. ):
  106. self.log_trial_result(iteration, trial, result)
  107. def on_trial_start(
  108. self, iteration: int, trials: List["Trial"], trial: "Trial", **info
  109. ):
  110. self.log_trial_start(trial)
  111. def on_trial_restore(
  112. self, iteration: int, trials: List["Trial"], trial: "Trial", **info
  113. ):
  114. self.log_trial_restore(trial)
  115. def on_trial_save(
  116. self, iteration: int, trials: List["Trial"], trial: "Trial", **info
  117. ):
  118. self.log_trial_save(trial)
  119. def on_trial_complete(
  120. self, iteration: int, trials: List["Trial"], trial: "Trial", **info
  121. ):
  122. self.log_trial_end(trial, failed=False)
  123. def on_trial_error(
  124. self, iteration: int, trials: List["Trial"], trial: "Trial", **info
  125. ):
  126. self.log_trial_end(trial, failed=True)
  127. def _restore_from_remote(self, file_name: str, trial: "Trial") -> None:
  128. if not trial.checkpoint:
  129. # If there's no checkpoint, there's no logging artifacts to restore
  130. # since we're starting from scratch.
  131. return
  132. local_file = Path(trial.local_path, file_name).as_posix()
  133. remote_file = Path(trial.storage.trial_fs_path, file_name).as_posix()
  134. try:
  135. pyarrow.fs.copy_files(
  136. remote_file,
  137. local_file,
  138. source_filesystem=trial.storage.storage_filesystem,
  139. )
  140. logger.debug(f"Copied {remote_file} to {local_file}")
  141. except FileNotFoundError:
  142. logger.warning(f"Remote file not found: {remote_file}")
  143. except Exception:
  144. logger.exception(f"Error downloading {remote_file}")
  145. @DeveloperAPI
  146. class LegacyLoggerCallback(LoggerCallback):
  147. """Supports logging to trial-specific `Logger` classes.
  148. Previously, Ray Tune logging was handled via `Logger` classes that have
  149. been instantiated per-trial. This callback is a fallback to these
  150. `Logger`-classes, instantiating each `Logger` class for each trial
  151. and logging to them.
  152. Args:
  153. logger_classes: Logger classes that should
  154. be instantiated for each trial.
  155. """
  156. def __init__(self, logger_classes: Iterable[Type[Logger]]):
  157. self.logger_classes = list(logger_classes)
  158. self._class_trial_loggers: Dict[Type[Logger], Dict["Trial", Logger]] = {}
  159. def log_trial_start(self, trial: "Trial"):
  160. trial.init_local_path()
  161. for logger_class in self.logger_classes:
  162. trial_loggers = self._class_trial_loggers.get(logger_class, {})
  163. if trial not in trial_loggers:
  164. logger = logger_class(trial.config, trial.local_path, trial)
  165. trial_loggers[trial] = logger
  166. self._class_trial_loggers[logger_class] = trial_loggers
  167. def log_trial_restore(self, trial: "Trial"):
  168. for logger_class, trial_loggers in self._class_trial_loggers.items():
  169. if trial in trial_loggers:
  170. trial_loggers[trial].flush()
  171. def log_trial_save(self, trial: "Trial"):
  172. for logger_class, trial_loggers in self._class_trial_loggers.items():
  173. if trial in trial_loggers:
  174. trial_loggers[trial].flush()
  175. def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
  176. for logger_class, trial_loggers in self._class_trial_loggers.items():
  177. if trial in trial_loggers:
  178. trial_loggers[trial].on_result(result)
  179. def log_trial_end(self, trial: "Trial", failed: bool = False):
  180. for logger_class, trial_loggers in self._class_trial_loggers.items():
  181. if trial in trial_loggers:
  182. trial_loggers[trial].close()
  183. class _RayDumper(yaml.SafeDumper):
  184. def represent_sequence(self, tag, sequence, flow_style=None):
  185. if len(sequence) > _SEQUENCE_LEN_FLOW_STYLE:
  186. return super().represent_sequence(tag, sequence, flow_style=True)
  187. return super().represent_sequence(tag, sequence, flow_style=flow_style)
  188. @DeveloperAPI
  189. def pretty_print(result, exclude: Optional[Set[str]] = None):
  190. result = result.copy()
  191. result.update(config=None) # drop config from pretty print
  192. result.update(hist_stats=None) # drop hist_stats from pretty print
  193. out = {}
  194. for k, v in result.items():
  195. if v is not None and (exclude is None or k not in exclude):
  196. out[k] = v
  197. cleaned = json.dumps(out, cls=SafeFallbackEncoder)
  198. return yaml.dump(json.loads(cleaned), Dumper=_RayDumper, default_flow_style=False)