keras.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. import shutil
  2. from abc import abstractmethod
  3. from typing import Dict, List, Optional, Union
  4. from tensorflow.keras.callbacks import Callback as KerasCallback
  5. import ray
  6. from ray.train.tensorflow import TensorflowCheckpoint
  7. from ray.util.annotations import PublicAPI
  8. class _Callback(KerasCallback):
  9. """Base class for Ray Train's Keras callbacks."""
  10. _allowed = [
  11. "epoch_begin",
  12. "epoch_end",
  13. "train_batch_begin",
  14. "train_batch_end",
  15. "test_batch_begin",
  16. "test_batch_end",
  17. "predict_batch_begin",
  18. "predict_batch_end",
  19. "train_begin",
  20. "train_end",
  21. "test_begin",
  22. "test_end",
  23. "predict_begin",
  24. "predict_end",
  25. ]
  26. def __init__(self, on: Union[str, List[str]] = "validation_end"):
  27. super(_Callback, self).__init__()
  28. if not isinstance(on, list):
  29. on = [on]
  30. if any(w not in self._allowed for w in on):
  31. raise ValueError(
  32. "Invalid trigger time selected: {}. Must be one of {}".format(
  33. on, self._allowed
  34. )
  35. )
  36. self._on = on
  37. def _handle(self, logs: Dict, when: str):
  38. raise NotImplementedError
  39. def on_epoch_begin(self, epoch, logs=None):
  40. if "epoch_begin" in self._on:
  41. self._handle(logs, "epoch_begin")
  42. def on_epoch_end(self, epoch, logs=None):
  43. if "epoch_end" in self._on:
  44. self._handle(logs, "epoch_end")
  45. def on_train_batch_begin(self, batch, logs=None):
  46. if "train_batch_begin" in self._on:
  47. self._handle(logs, "train_batch_begin")
  48. def on_train_batch_end(self, batch, logs=None):
  49. if "train_batch_end" in self._on:
  50. self._handle(logs, "train_batch_end")
  51. def on_test_batch_begin(self, batch, logs=None):
  52. if "test_batch_begin" in self._on:
  53. self._handle(logs, "test_batch_begin")
  54. def on_test_batch_end(self, batch, logs=None):
  55. if "test_batch_end" in self._on:
  56. self._handle(logs, "test_batch_end")
  57. def on_predict_batch_begin(self, batch, logs=None):
  58. if "predict_batch_begin" in self._on:
  59. self._handle(logs, "predict_batch_begin")
  60. def on_predict_batch_end(self, batch, logs=None):
  61. if "predict_batch_end" in self._on:
  62. self._handle(logs, "predict_batch_end")
  63. def on_train_begin(self, logs=None):
  64. if "train_begin" in self._on:
  65. self._handle(logs, "train_begin")
  66. def on_train_end(self, logs=None):
  67. if "train_end" in self._on:
  68. self._handle(logs, "train_end")
  69. def on_test_begin(self, logs=None):
  70. if "test_begin" in self._on:
  71. self._handle(logs, "test_begin")
  72. def on_test_end(self, logs=None):
  73. if "test_end" in self._on:
  74. self._handle(logs, "test_end")
  75. def on_predict_begin(self, logs=None):
  76. if "predict_begin" in self._on:
  77. self._handle(logs, "predict_begin")
  78. def on_predict_end(self, logs=None):
  79. if "predict_end" in self._on:
  80. self._handle(logs, "predict_end")
  81. class RayReportCallback(_Callback):
  82. def __init__(
  83. self,
  84. checkpoint_on: Union[str, List[str]] = "epoch_end",
  85. report_metrics_on: Union[str, List[str]] = "epoch_end",
  86. metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
  87. ):
  88. if isinstance(checkpoint_on, str):
  89. checkpoint_on = [checkpoint_on]
  90. if isinstance(report_metrics_on, str):
  91. report_metrics_on = [report_metrics_on]
  92. on = list(set(checkpoint_on + report_metrics_on))
  93. super().__init__(on=on)
  94. self._checkpoint_on: List[str] = checkpoint_on
  95. self._report_metrics_on: List[str] = report_metrics_on
  96. self._metrics = metrics
  97. def _get_reported_metrics(self, logs: Dict) -> Dict:
  98. assert isinstance(self._metrics, (type(None), str, list, dict))
  99. if self._metrics is None:
  100. reported_metrics = logs
  101. elif isinstance(self._metrics, str):
  102. reported_metrics = {self._metrics: logs[self._metrics]}
  103. elif isinstance(self._metrics, list):
  104. reported_metrics = {metric: logs[metric] for metric in self._metrics}
  105. elif isinstance(self._metrics, dict):
  106. reported_metrics = {
  107. key: logs[metric] for key, metric in self._metrics.items()
  108. }
  109. assert isinstance(reported_metrics, dict)
  110. return reported_metrics
  111. @abstractmethod
  112. def _save_and_report_checkpoint(
  113. self, metrics: Dict, checkpoint: TensorflowCheckpoint
  114. ):
  115. """Save checkpoint and report metrics corresonding to this checkpoint."""
  116. raise NotImplementedError
  117. @abstractmethod
  118. def _report_metrics(self, metrics: Dict):
  119. """Report metrics."""
  120. raise NotImplementedError
  121. def _handle(self, logs: Dict, when: str):
  122. assert when in self._checkpoint_on or when in self._report_metrics_on
  123. metrics = self._get_reported_metrics(logs)
  124. should_checkpoint = when in self._checkpoint_on
  125. if should_checkpoint:
  126. checkpoint = TensorflowCheckpoint.from_model(self.model)
  127. self._save_and_report_checkpoint(metrics, checkpoint)
  128. # Clean up temporary checkpoint
  129. shutil.rmtree(checkpoint.path, ignore_errors=True)
  130. else:
  131. self._report_metrics(metrics)
  132. @PublicAPI(stability="alpha")
  133. class ReportCheckpointCallback(RayReportCallback):
  134. """Keras callback for Ray Train reporting and checkpointing.
  135. .. note::
  136. Metrics are always reported with checkpoints, even if the event isn't specified
  137. in ``report_metrics_on``.
  138. Example:
  139. .. testcode:: python
  140. ############# Using it in TrainSession ###############
  141. from ray.air.integrations.keras import ReportCheckpointCallback
  142. def train_loop_per_worker():
  143. strategy = tf.distribute.MultiWorkerMirroredStrategy()
  144. with strategy.scope():
  145. model = build_model()
  146. model.fit(dataset_shard, callbacks=[ReportCheckpointCallback()])
  147. Args:
  148. metrics: Metrics to report. If this is a list, each item describes
  149. the metric key reported to Keras, and it's reported under the
  150. same name. If this is a dict, each key is the name reported
  151. and the respective value is the metric key reported to Keras.
  152. If this is None, all Keras logs are reported.
  153. report_metrics_on: When to report metrics. Must be one of
  154. the Keras event hooks (less the ``on_``), e.g.
  155. "train_start" or "predict_end". Defaults to "epoch_end".
  156. checkpoint_on: When to save checkpoints. Must be one of the Keras event hooks
  157. (less the ``on_``), e.g. "train_start" or "predict_end". Defaults to
  158. "epoch_end".
  159. """
  160. def _save_and_report_checkpoint(
  161. self, metrics: Dict, checkpoint: TensorflowCheckpoint
  162. ):
  163. """Save checkpoint and report metrics corresonding to this checkpoint."""
  164. ray.train.report(metrics, checkpoint=checkpoint)
  165. def _report_metrics(self, metrics: Dict):
  166. """Report metrics."""
  167. ray.train.report(metrics, checkpoint=None)