callback.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. import importlib
  2. import logging
  3. import os
  4. from typing import TYPE_CHECKING, Collection, List, Optional, Type, Union
  5. from ray.tune.callback import Callback, CallbackList
  6. from ray.tune.constants import RAY_TUNE_CALLBACKS_ENV_VAR
  7. from ray.tune.logger import (
  8. CSVLogger,
  9. CSVLoggerCallback,
  10. JsonLogger,
  11. JsonLoggerCallback,
  12. LegacyLoggerCallback,
  13. TBXLogger,
  14. TBXLoggerCallback,
  15. )
# Module-level logger; the functions in this file use it to emit warnings.
logger = logging.getLogger(__name__)

# Imported for static type checkers only. At runtime the name is referenced
# solely through the string annotation "AirVerbosity".
if TYPE_CHECKING:
    from ray.tune.experimental.output import AirVerbosity

# The CSV, JSON and TensorboardX logger callback classes. These are the same
# classes `_create_default_callbacks` instantiates when no equivalent logger
# has been passed in.
# NOTE(review): no use of this tuple is visible in this file - presumably it
# is consumed by other modules; confirm before changing its contents.
DEFAULT_CALLBACK_CLASSES = (
    CSVLoggerCallback,
    JsonLoggerCallback,
    TBXLoggerCallback,
)
  24. def _get_artifact_templates_for_callbacks(
  25. callbacks: Union[List[Callback], List[Type[Callback]], CallbackList]
  26. ) -> List[str]:
  27. templates = []
  28. for callback in callbacks:
  29. templates += list(callback._SAVED_FILE_TEMPLATES)
  30. return templates
  31. def _create_default_callbacks(
  32. callbacks: Optional[List[Callback]],
  33. *,
  34. air_verbosity: Optional["AirVerbosity"] = None,
  35. entrypoint: Optional[str] = None,
  36. metric: Optional[str] = None,
  37. mode: Optional[str] = None,
  38. config: Optional[dict] = None,
  39. progress_metrics: Optional[Collection[str]] = None,
  40. ) -> List[Callback]:
  41. """Create default callbacks for `Tuner.fit()`.
  42. This function takes a list of existing callbacks and adds default
  43. callbacks to it.
  44. Specifically, three kinds of callbacks will be added:
  45. 1. Loggers. Ray Tune's experiment analysis relies on CSV and JSON logging.
  46. 2. Syncer. Ray Tune synchronizes logs and checkpoint between workers and
  47. the head node.
  48. 2. Trial progress reporter. For reporting intermediate progress, like trial
  49. results, Ray Tune uses a callback.
  50. These callbacks will only be added if they don't already exist, i.e. if
  51. they haven't been passed (and configured) by the user. A notable case
  52. is when a Logger is passed, which is not a CSV or JSON logger - then
  53. a CSV and JSON logger will still be created.
  54. Lastly, this function will ensure that the Syncer callback comes after all
  55. Logger callbacks, to ensure that the most up-to-date logs and checkpoints
  56. are synced across nodes.
  57. """
  58. callbacks = callbacks or []
  59. # Initialize callbacks from environment variable
  60. env_callbacks = _initialize_env_callbacks()
  61. callbacks.extend(env_callbacks)
  62. has_csv_logger = False
  63. has_json_logger = False
  64. has_tbx_logger = False
  65. from ray.tune.progress_reporter import TrialProgressCallback
  66. has_trial_progress_callback = any(
  67. isinstance(c, TrialProgressCallback) for c in callbacks
  68. )
  69. if has_trial_progress_callback and air_verbosity is not None:
  70. logger.warning(
  71. "AIR_VERBOSITY is set, ignoring passed-in TrialProgressCallback."
  72. )
  73. new_callbacks = [
  74. c for c in callbacks if not isinstance(c, TrialProgressCallback)
  75. ]
  76. callbacks = new_callbacks
  77. if air_verbosity is not None: # new flow
  78. from ray.tune.experimental.output import (
  79. _detect_reporter as _detect_air_reporter,
  80. )
  81. air_progress_reporter = _detect_air_reporter(
  82. air_verbosity,
  83. num_samples=1, # Update later with setup()
  84. entrypoint=entrypoint,
  85. metric=metric,
  86. mode=mode,
  87. config=config,
  88. progress_metrics=progress_metrics,
  89. )
  90. callbacks.append(air_progress_reporter)
  91. elif not has_trial_progress_callback: # old flow
  92. trial_progress_callback = TrialProgressCallback(
  93. metric=metric, progress_metrics=progress_metrics
  94. )
  95. callbacks.append(trial_progress_callback)
  96. # Check if we have a CSV, JSON and TensorboardX logger
  97. for i, callback in enumerate(callbacks):
  98. if isinstance(callback, LegacyLoggerCallback):
  99. if CSVLogger in callback.logger_classes:
  100. has_csv_logger = True
  101. if JsonLogger in callback.logger_classes:
  102. has_json_logger = True
  103. if TBXLogger in callback.logger_classes:
  104. has_tbx_logger = True
  105. elif isinstance(callback, CSVLoggerCallback):
  106. has_csv_logger = True
  107. elif isinstance(callback, JsonLoggerCallback):
  108. has_json_logger = True
  109. elif isinstance(callback, TBXLoggerCallback):
  110. has_tbx_logger = True
  111. # If CSV, JSON or TensorboardX loggers are missing, add
  112. if os.environ.get("TUNE_DISABLE_AUTO_CALLBACK_LOGGERS", "0") != "1":
  113. if not has_csv_logger:
  114. callbacks.append(CSVLoggerCallback())
  115. if not has_json_logger:
  116. callbacks.append(JsonLoggerCallback())
  117. if not has_tbx_logger:
  118. try:
  119. callbacks.append(TBXLoggerCallback())
  120. except ImportError:
  121. logger.warning(
  122. "The TensorboardX logger cannot be instantiated because "
  123. "either TensorboardX or one of it's dependencies is not "
  124. "installed. Please make sure you have the latest version "
  125. "of TensorboardX installed: `pip install -U tensorboardx`"
  126. )
  127. return callbacks
  128. def _initialize_env_callbacks() -> List[Callback]:
  129. """Initialize callbacks from environment variable.
  130. Returns:
  131. List of callbacks initialized from environment variable.
  132. """
  133. callbacks = []
  134. callbacks_str = os.environ.get(RAY_TUNE_CALLBACKS_ENV_VAR, "")
  135. if not callbacks_str:
  136. return callbacks
  137. for callback_path in callbacks_str.split(","):
  138. callback_path = callback_path.strip()
  139. if not callback_path:
  140. continue
  141. try:
  142. module_path, class_name = callback_path.rsplit(".", 1)
  143. module = importlib.import_module(module_path)
  144. callback_cls = getattr(module, class_name)
  145. if not issubclass(callback_cls, Callback):
  146. raise TypeError(
  147. f"Callback class '{callback_path}' must be a subclass of "
  148. f"Callback, got {type(callback_cls).__name__}"
  149. )
  150. callback = callback_cls()
  151. callbacks.append(callback)
  152. except (ImportError, AttributeError, ValueError, TypeError) as e:
  153. raise ValueError(f"Failed to import callback from '{callback_path}'") from e
  154. return callbacks