pytorch_lightning.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. import inspect
  2. import logging
  3. import os
  4. import tempfile
  5. import warnings
  6. from contextlib import contextmanager
  7. from typing import Dict, List, Optional, Type, Union
  8. import ray.tune
  9. from ray.tune import Checkpoint
  10. from ray.util import log_once
  11. from ray.util.annotations import Deprecated, PublicAPI
  12. try:
  13. from lightning import Callback, LightningModule, Trainer
  14. except ModuleNotFoundError:
  15. from pytorch_lightning import Callback, LightningModule, Trainer
  16. logger = logging.getLogger(__name__)
  17. # Get all Pytorch Lightning Callback hooks based on whatever PTL version is being used.
  18. _allowed_hooks = {
  19. name
  20. for name, fn in inspect.getmembers(Callback, predicate=inspect.isfunction)
  21. if name.startswith("on_")
  22. }
  23. def _override_ptl_hooks(callback_cls: Type["TuneCallback"]) -> Type["TuneCallback"]:
  24. """Overrides all allowed PTL Callback hooks with our custom handle logic."""
  25. def generate_overridden_hook(fn_name):
  26. def overridden_hook(
  27. self,
  28. trainer: Trainer,
  29. *args,
  30. pl_module: Optional[LightningModule] = None,
  31. **kwargs,
  32. ):
  33. if fn_name in self._on:
  34. self._handle(trainer=trainer, pl_module=pl_module)
  35. return overridden_hook
  36. # Set the overridden hook to all the allowed hooks in TuneCallback.
  37. for fn_name in _allowed_hooks:
  38. setattr(callback_cls, fn_name, generate_overridden_hook(fn_name))
  39. return callback_cls
  40. @_override_ptl_hooks
  41. class TuneCallback(Callback):
  42. """Base class for Tune's PyTorch Lightning callbacks.
  43. Args:
  44. on: When to trigger checkpoint creations. Must be one of
  45. the PyTorch Lightning event hooks (less the ``on_``), e.g.
  46. "train_batch_start", or "train_end". Defaults to "validation_end"
  47. """
  48. def __init__(self, on: Union[str, List[str]] = "validation_end"):
  49. if not isinstance(on, list):
  50. on = [on]
  51. for hook in on:
  52. if f"on_{hook}" not in _allowed_hooks:
  53. raise ValueError(
  54. f"Invalid hook selected: {hook}. Must be one of "
  55. f"{_allowed_hooks}"
  56. )
  57. # Add back the "on_" prefix for internal consistency.
  58. on = [f"on_{hook}" for hook in on]
  59. self._on = on
  60. def _handle(self, trainer: Trainer, pl_module: Optional[LightningModule]):
  61. raise NotImplementedError
  62. @PublicAPI
  63. class TuneReportCheckpointCallback(TuneCallback):
  64. """PyTorch Lightning report and checkpoint callback
  65. Saves checkpoints after each validation step. Also reports metrics to Tune,
  66. which is needed for checkpoint registration.
  67. Args:
  68. metrics: Metrics to report to Tune. If this is a list,
  69. each item describes the metric key reported to PyTorch Lightning,
  70. and it will reported under the same name to Tune. If this is a
  71. dict, each key will be the name reported to Tune and the respective
  72. value will be the metric key reported to PyTorch Lightning.
  73. filename: Filename of the checkpoint within the checkpoint
  74. directory. Defaults to "checkpoint".
  75. save_checkpoints: If True (default), checkpoints will be saved and
  76. reported to Ray. If False, only metrics will be reported.
  77. on: When to trigger checkpoint creations and metric reports. Must be one of
  78. the PyTorch Lightning event hooks (less the ``on_``), e.g.
  79. "train_batch_start", or "train_end". Defaults to "validation_end".
  80. Example:
  81. .. code-block:: python
  82. import pytorch_lightning as pl
  83. from ray.tune.integration.pytorch_lightning import (
  84. TuneReportCheckpointCallback)
  85. # Save checkpoint after each training batch and after each
  86. # validation epoch.
  87. trainer = pl.Trainer(callbacks=[TuneReportCheckpointCallback(
  88. metrics={"loss": "val_loss", "mean_accuracy": "val_acc"},
  89. filename="trainer.ckpt", on="validation_end")])
  90. """
  91. def __init__(
  92. self,
  93. metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
  94. filename: str = "checkpoint",
  95. save_checkpoints: bool = True,
  96. on: Union[str, List[str]] = "validation_end",
  97. ):
  98. super(TuneReportCheckpointCallback, self).__init__(on=on)
  99. if isinstance(metrics, str):
  100. metrics = [metrics]
  101. self._save_checkpoints = save_checkpoints
  102. self._filename = filename
  103. self._metrics = metrics
  104. def _get_report_dict(self, trainer: Trainer, pl_module: LightningModule):
  105. # Don't report if just doing initial validation sanity checks.
  106. if trainer.sanity_checking:
  107. return
  108. if not self._metrics:
  109. report_dict = {k: v.item() for k, v in trainer.callback_metrics.items()}
  110. else:
  111. report_dict = {}
  112. for key in self._metrics:
  113. if isinstance(self._metrics, dict):
  114. metric = self._metrics[key]
  115. else:
  116. metric = key
  117. if metric in trainer.callback_metrics:
  118. report_dict[key] = trainer.callback_metrics[metric].item()
  119. else:
  120. logger.warning(
  121. f"Metric {metric} does not exist in "
  122. "`trainer.callback_metrics."
  123. )
  124. return report_dict
  125. @contextmanager
  126. def _get_checkpoint(self, trainer: Trainer) -> Optional[Checkpoint]:
  127. if not self._save_checkpoints:
  128. yield None
  129. return
  130. with tempfile.TemporaryDirectory() as checkpoint_dir:
  131. trainer.save_checkpoint(os.path.join(checkpoint_dir, self._filename))
  132. checkpoint = Checkpoint.from_directory(checkpoint_dir)
  133. yield checkpoint
  134. def _handle(self, trainer: Trainer, pl_module: LightningModule):
  135. if trainer.sanity_checking:
  136. return
  137. report_dict = self._get_report_dict(trainer, pl_module)
  138. if not report_dict:
  139. return
  140. with self._get_checkpoint(trainer) as checkpoint:
  141. ray.tune.report(report_dict, checkpoint=checkpoint)
  142. class _TuneCheckpointCallback(TuneCallback):
  143. def __init__(self, *args, **kwargs):
  144. raise DeprecationWarning(
  145. "`ray.tune.integration.pytorch_lightning._TuneCheckpointCallback` "
  146. "is deprecated."
  147. )
  148. @Deprecated
  149. class TuneReportCallback(TuneReportCheckpointCallback):
  150. def __init__(
  151. self,
  152. metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
  153. on: Union[str, List[str]] = "validation_end",
  154. ):
  155. if log_once("tune_ptl_report_deprecated"):
  156. warnings.warn(
  157. "`ray.tune.integration.pytorch_lightning.TuneReportCallback` "
  158. "is deprecated. Use "
  159. "`ray.tune.integration.pytorch_lightning.TuneReportCheckpointCallback`"
  160. " instead."
  161. )
  162. super(TuneReportCallback, self).__init__(
  163. metrics=metrics, save_checkpoints=False, on=on
  164. )