callback.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. import glob
  2. import warnings
  3. from abc import ABCMeta
  4. from pathlib import Path
  5. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
  6. import ray.tune
  7. from ray.tune.utils.util import _atomic_save, _load_newest_checkpoint
  8. from ray.util.annotations import DeveloperAPI, PublicAPI
  9. if TYPE_CHECKING:
  10. from ray.tune.experiment import Trial
  11. from ray.tune.stopper import Stopper
  12. class _CallbackMeta(ABCMeta):
  13. """A helper metaclass to ensure container classes (e.g. CallbackList) have
  14. implemented all the callback methods (e.g. `on_*`).
  15. """
  16. def __new__(mcs, name: str, bases: Tuple[type], attrs: Dict[str, Any]) -> type:
  17. cls = super().__new__(mcs, name, bases, attrs)
  18. if mcs.need_check(cls, name, bases, attrs):
  19. mcs.check(cls, name, bases, attrs)
  20. return cls
  21. @classmethod
  22. def need_check(
  23. mcs, cls: type, name: str, bases: Tuple[type], attrs: Dict[str, Any]
  24. ) -> bool:
  25. return attrs.get("IS_CALLBACK_CONTAINER", False)
  26. @classmethod
  27. def check(
  28. mcs, cls: type, name: str, bases: Tuple[type], attrs: Dict[str, Any]
  29. ) -> None:
  30. methods = set()
  31. for base in bases:
  32. methods.update(
  33. attr_name
  34. for attr_name, attr in vars(base).items()
  35. if mcs.need_override_by_subclass(attr_name, attr)
  36. )
  37. overridden = {
  38. attr_name
  39. for attr_name, attr in attrs.items()
  40. if mcs.need_override_by_subclass(attr_name, attr)
  41. }
  42. missing = methods.difference(overridden)
  43. if missing:
  44. raise TypeError(
  45. f"Found missing callback method: {missing} "
  46. f"in class {cls.__module__}.{cls.__qualname__}."
  47. )
  48. @classmethod
  49. def need_override_by_subclass(mcs, attr_name: str, attr: Any) -> bool:
  50. return (
  51. (
  52. attr_name.startswith("on_")
  53. and not attr_name.startswith("on_trainer_init")
  54. )
  55. or attr_name == "setup"
  56. ) and callable(attr)
  57. @PublicAPI(stability="beta")
  58. class Callback(metaclass=_CallbackMeta):
  59. """Tune base callback that can be extended and passed to a ``TrialRunner``
  60. Tune callbacks are called from within the ``TrialRunner`` class. There are
  61. several hooks that can be used, all of which are found in the submethod
  62. definitions of this base class.
  63. The parameters passed to the ``**info`` dict vary between hooks. The
  64. parameters passed are described in the docstrings of the methods.
  65. This example will print a metric each time a result is received:
  66. .. testcode::
  67. from ray import tune
  68. from ray.tune import Callback
  69. class MyCallback(Callback):
  70. def on_trial_result(self, iteration, trials, trial, result,
  71. **info):
  72. print(f"Got result: {result['metric']}")
  73. def train_func(config):
  74. for i in range(10):
  75. tune.report(metric=i)
  76. tuner = tune.Tuner(
  77. train_func,
  78. run_config=tune.RunConfig(
  79. callbacks=[MyCallback()]
  80. )
  81. )
  82. tuner.fit()
  83. .. testoutput::
  84. :hide:
  85. ...
  86. """
  87. # File templates for any artifacts written by this callback
  88. # These files should live in the `trial.local_path` for each trial.
  89. # TODO(ml-team): Make this more visible to users to override. Internal use for now.
  90. _SAVED_FILE_TEMPLATES = []
  91. # arguments here match Experiment.public_spec
  92. def setup(
  93. self,
  94. stop: Optional["Stopper"] = None,
  95. num_samples: Optional[int] = None,
  96. total_num_samples: Optional[int] = None,
  97. **info,
  98. ):
  99. """Called once at the very beginning of training.
  100. Any Callback setup should be added here (setting environment
  101. variables, etc.)
  102. Arguments:
  103. stop: Stopping criteria.
  104. If ``time_budget_s`` was passed to ``tune.RunConfig``, a
  105. ``TimeoutStopper`` will be passed here, either by itself
  106. or as a part of a ``CombinedStopper``.
  107. num_samples: Number of times to sample from the
  108. hyperparameter space. Defaults to 1. If `grid_search` is
  109. provided as an argument, the grid will be repeated
  110. `num_samples` of times. If this is -1, (virtually) infinite
  111. samples are generated until a stopping condition is met.
  112. total_num_samples: Total number of samples factoring
  113. in grid search samplers.
  114. **info: Kwargs dict for forward compatibility.
  115. """
  116. pass
  117. def on_step_begin(self, iteration: int, trials: List["Trial"], **info):
  118. """Called at the start of each tuning loop step.
  119. Arguments:
  120. iteration: Number of iterations of the tuning loop.
  121. trials: List of trials.
  122. **info: Kwargs dict for forward compatibility.
  123. """
  124. pass
  125. def on_step_end(self, iteration: int, trials: List["Trial"], **info):
  126. """Called at the end of each tuning loop step.
  127. The iteration counter is increased before this hook is called.
  128. Arguments:
  129. iteration: Number of iterations of the tuning loop.
  130. trials: List of trials.
  131. **info: Kwargs dict for forward compatibility.
  132. """
  133. pass
  134. def on_trial_start(
  135. self, iteration: int, trials: List["Trial"], trial: "Trial", **info
  136. ):
  137. """Called after starting a trial instance.
  138. Arguments:
  139. iteration: Number of iterations of the tuning loop.
  140. trials: List of trials.
  141. trial: Trial that just has been started.
  142. **info: Kwargs dict for forward compatibility.
  143. """
  144. pass
  145. def on_trial_restore(
  146. self, iteration: int, trials: List["Trial"], trial: "Trial", **info
  147. ):
  148. """Called after restoring a trial instance.
  149. Arguments:
  150. iteration: Number of iterations of the tuning loop.
  151. trials: List of trials.
  152. trial: Trial that just has been restored.
  153. **info: Kwargs dict for forward compatibility.
  154. """
  155. pass
  156. def on_trial_save(
  157. self, iteration: int, trials: List["Trial"], trial: "Trial", **info
  158. ):
  159. """Called after receiving a checkpoint from a trial.
  160. Arguments:
  161. iteration: Number of iterations of the tuning loop.
  162. trials: List of trials.
  163. trial: Trial that just saved a checkpoint.
  164. **info: Kwargs dict for forward compatibility.
  165. """
  166. pass
  167. def on_trial_result(
  168. self,
  169. iteration: int,
  170. trials: List["Trial"],
  171. trial: "Trial",
  172. result: Dict,
  173. **info,
  174. ):
  175. """Called after receiving a result from a trial.
  176. The search algorithm and scheduler are notified before this
  177. hook is called.
  178. Arguments:
  179. iteration: Number of iterations of the tuning loop.
  180. trials: List of trials.
  181. trial: Trial that just sent a result.
  182. result: Result that the trial sent.
  183. **info: Kwargs dict for forward compatibility.
  184. """
  185. pass
  186. def on_trial_complete(
  187. self, iteration: int, trials: List["Trial"], trial: "Trial", **info
  188. ):
  189. """Called after a trial instance completed.
  190. The search algorithm and scheduler are notified before this
  191. hook is called.
  192. Arguments:
  193. iteration: Number of iterations of the tuning loop.
  194. trials: List of trials.
  195. trial: Trial that just has been completed.
  196. **info: Kwargs dict for forward compatibility.
  197. """
  198. pass
  199. def on_trial_recover(
  200. self, iteration: int, trials: List["Trial"], trial: "Trial", **info
  201. ):
  202. """Called after a trial instance failed (errored) but the trial is scheduled
  203. for retry.
  204. The search algorithm and scheduler are not notified.
  205. Arguments:
  206. iteration: Number of iterations of the tuning loop.
  207. trials: List of trials.
  208. trial: Trial that just has errored.
  209. **info: Kwargs dict for forward compatibility.
  210. """
  211. pass
  212. def on_trial_error(
  213. self, iteration: int, trials: List["Trial"], trial: "Trial", **info
  214. ):
  215. """Called after a trial instance failed (errored).
  216. The search algorithm and scheduler are notified before this
  217. hook is called.
  218. Arguments:
  219. iteration: Number of iterations of the tuning loop.
  220. trials: List of trials.
  221. trial: Trial that just has errored.
  222. **info: Kwargs dict for forward compatibility.
  223. """
  224. pass
  225. def on_checkpoint(
  226. self,
  227. iteration: int,
  228. trials: List["Trial"],
  229. trial: "Trial",
  230. checkpoint: "ray.tune.Checkpoint",
  231. **info,
  232. ):
  233. """Called after a trial saved a checkpoint with Tune.
  234. Arguments:
  235. iteration: Number of iterations of the tuning loop.
  236. trials: List of trials.
  237. trial: Trial that just has errored.
  238. checkpoint: Checkpoint object that has been saved
  239. by the trial.
  240. **info: Kwargs dict for forward compatibility.
  241. """
  242. pass
  243. def on_experiment_end(self, trials: List["Trial"], **info):
  244. """Called after experiment is over and all trials have concluded.
  245. Arguments:
  246. trials: List of trials.
  247. **info: Kwargs dict for forward compatibility.
  248. """
  249. pass
  250. def get_state(self) -> Optional[Dict]:
  251. """Get the state of the callback.
  252. This method should be implemented by subclasses to return a dictionary
  253. representation of the object's current state.
  254. This is called automatically by Tune to periodically checkpoint callback state.
  255. Upon :ref:`Tune experiment restoration <tune-experiment-level-fault-tolerance>`,
  256. callback state will be restored via :meth:`~ray.tune.Callback.set_state`.
  257. .. testcode::
  258. from typing import Dict, List, Optional
  259. from ray.tune import Callback
  260. from ray.tune.experiment import Trial
  261. class MyCallback(Callback):
  262. def __init__(self):
  263. self._trial_ids = set()
  264. def on_trial_start(
  265. self, iteration: int, trials: List["Trial"], trial: "Trial", **info
  266. ):
  267. self._trial_ids.add(trial.trial_id)
  268. def get_state(self) -> Optional[Dict]:
  269. return {"trial_ids": self._trial_ids.copy()}
  270. def set_state(self, state: Dict) -> Optional[Dict]:
  271. self._trial_ids = state["trial_ids"]
  272. Returns:
  273. dict: State of the callback. Should be `None` if the callback does not
  274. have any state to save (this is the default).
  275. """
  276. return None
  277. def set_state(self, state: Dict):
  278. """Set the state of the callback.
  279. This method should be implemented by subclasses to restore the callback's
  280. state based on the given dict state.
  281. This is used automatically by Tune to restore checkpoint callback state
  282. on :ref:`Tune experiment restoration <tune-experiment-level-fault-tolerance>`.
  283. See :meth:`~ray.tune.Callback.get_state` for an example implementation.
  284. Args:
  285. state: State of the callback.
  286. """
  287. pass
  288. @DeveloperAPI
  289. class CallbackList(Callback):
  290. """Call multiple callbacks at once."""
  291. IS_CALLBACK_CONTAINER = True
  292. CKPT_FILE_TMPL = "callback-states-{}.pkl"
  293. def __init__(self, callbacks: List[Callback]):
  294. self._callbacks = callbacks
  295. def setup(self, **info):
  296. for callback in self._callbacks:
  297. try:
  298. callback.setup(**info)
  299. except TypeError as e:
  300. if "argument" in str(e):
  301. warnings.warn(
  302. "Please update `setup` method in callback "
  303. f"`{callback.__class__}` to match the method signature"
  304. " in `ray.tune.callback.Callback`.",
  305. FutureWarning,
  306. )
  307. callback.setup()
  308. else:
  309. raise e
  310. def on_step_begin(self, **info):
  311. for callback in self._callbacks:
  312. callback.on_step_begin(**info)
  313. def on_step_end(self, **info):
  314. for callback in self._callbacks:
  315. callback.on_step_end(**info)
  316. def on_trial_start(self, **info):
  317. for callback in self._callbacks:
  318. callback.on_trial_start(**info)
  319. def on_trial_restore(self, **info):
  320. for callback in self._callbacks:
  321. callback.on_trial_restore(**info)
  322. def on_trial_save(self, **info):
  323. for callback in self._callbacks:
  324. callback.on_trial_save(**info)
  325. def on_trial_result(self, **info):
  326. for callback in self._callbacks:
  327. callback.on_trial_result(**info)
  328. def on_trial_complete(self, **info):
  329. for callback in self._callbacks:
  330. callback.on_trial_complete(**info)
  331. def on_trial_recover(self, **info):
  332. for callback in self._callbacks:
  333. callback.on_trial_recover(**info)
  334. def on_trial_error(self, **info):
  335. for callback in self._callbacks:
  336. callback.on_trial_error(**info)
  337. def on_checkpoint(self, **info):
  338. for callback in self._callbacks:
  339. callback.on_checkpoint(**info)
  340. def on_experiment_end(self, **info):
  341. for callback in self._callbacks:
  342. callback.on_experiment_end(**info)
  343. def get_state(self) -> Optional[Dict]:
  344. """Gets the state of all callbacks contained within this list.
  345. If there are no stateful callbacks, then None will be returned in order
  346. to avoid saving an unnecessary callback checkpoint file."""
  347. state = {}
  348. any_stateful_callbacks = False
  349. for i, callback in enumerate(self._callbacks):
  350. callback_state = callback.get_state()
  351. if callback_state:
  352. any_stateful_callbacks = True
  353. state[i] = callback_state
  354. if not any_stateful_callbacks:
  355. return None
  356. return state
  357. def set_state(self, state: Dict):
  358. """Sets the state for all callbacks contained within this list.
  359. Skips setting state for all stateless callbacks where `get_state`
  360. returned None."""
  361. for i, callback in enumerate(self._callbacks):
  362. callback_state = state.get(i, None)
  363. if callback_state:
  364. callback.set_state(callback_state)
  365. def save_to_dir(self, checkpoint_dir: str, session_str: str = "default"):
  366. """Save the state of the callback list to the checkpoint_dir.
  367. Args:
  368. checkpoint_dir: directory where the checkpoint is stored.
  369. session_str: Unique identifier of the current run session (ex: timestamp).
  370. """
  371. state_dict = self.get_state()
  372. if state_dict:
  373. file_name = self.CKPT_FILE_TMPL.format(session_str)
  374. tmp_file_name = f"tmp-{file_name}"
  375. _atomic_save(
  376. state=state_dict,
  377. checkpoint_dir=checkpoint_dir,
  378. file_name=file_name,
  379. tmp_file_name=tmp_file_name,
  380. )
  381. def restore_from_dir(self, checkpoint_dir: str):
  382. """Restore the state of the list of callbacks from the checkpoint_dir.
  383. You should check if it's possible to restore with `can_restore`
  384. before calling this method.
  385. Args:
  386. checkpoint_dir: directory where the checkpoint is stored.
  387. Raises:
  388. RuntimeError: if unable to find checkpoint.
  389. NotImplementedError: if the `set_state` method is not implemented.
  390. """
  391. state_dict = _load_newest_checkpoint(
  392. checkpoint_dir, self.CKPT_FILE_TMPL.format("*")
  393. )
  394. if not state_dict:
  395. raise RuntimeError(
  396. "Unable to find checkpoint in {}.".format(checkpoint_dir)
  397. )
  398. self.set_state(state_dict)
  399. def can_restore(self, checkpoint_dir: str) -> bool:
  400. """Check if the checkpoint_dir contains the saved state for this callback list.
  401. Returns:
  402. can_restore: True if the checkpoint_dir contains a file of the
  403. format `CKPT_FILE_TMPL`. False otherwise.
  404. """
  405. return any(
  406. glob.iglob(Path(checkpoint_dir, self.CKPT_FILE_TMPL.format("*")).as_posix())
  407. )
  408. def __len__(self) -> int:
  409. return len(self._callbacks)
  410. def __getitem__(self, i: int) -> "Callback":
  411. return self._callbacks[i]