keras.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import shutil
  2. from typing import Dict, List, Optional, Union
  3. from tensorflow.keras.callbacks import Callback as KerasCallback
  4. import ray
  5. from ray.train.tensorflow import TensorflowCheckpoint
  6. from ray.util.annotations import PublicAPI
  7. class _Callback(KerasCallback):
  8. """Base class for Air's Keras callbacks."""
  9. _allowed = [
  10. "epoch_begin",
  11. "epoch_end",
  12. "train_batch_begin",
  13. "train_batch_end",
  14. "test_batch_begin",
  15. "test_batch_end",
  16. "predict_batch_begin",
  17. "predict_batch_end",
  18. "train_begin",
  19. "train_end",
  20. "test_begin",
  21. "test_end",
  22. "predict_begin",
  23. "predict_end",
  24. ]
  25. def __init__(self, on: Union[str, List[str]] = "validation_end"):
  26. super(_Callback, self).__init__()
  27. if not isinstance(on, list):
  28. on = [on]
  29. if any(w not in self._allowed for w in on):
  30. raise ValueError(
  31. "Invalid trigger time selected: {}. Must be one of {}".format(
  32. on, self._allowed
  33. )
  34. )
  35. self._on = on
  36. def _handle(self, logs: Dict, when: str):
  37. raise NotImplementedError
  38. def on_epoch_begin(self, epoch, logs=None):
  39. if "epoch_begin" in self._on:
  40. self._handle(logs, "epoch_begin")
  41. def on_epoch_end(self, epoch, logs=None):
  42. if "epoch_end" in self._on:
  43. self._handle(logs, "epoch_end")
  44. def on_train_batch_begin(self, batch, logs=None):
  45. if "train_batch_begin" in self._on:
  46. self._handle(logs, "train_batch_begin")
  47. def on_train_batch_end(self, batch, logs=None):
  48. if "train_batch_end" in self._on:
  49. self._handle(logs, "train_batch_end")
  50. def on_test_batch_begin(self, batch, logs=None):
  51. if "test_batch_begin" in self._on:
  52. self._handle(logs, "test_batch_begin")
  53. def on_test_batch_end(self, batch, logs=None):
  54. if "test_batch_end" in self._on:
  55. self._handle(logs, "test_batch_end")
  56. def on_predict_batch_begin(self, batch, logs=None):
  57. if "predict_batch_begin" in self._on:
  58. self._handle(logs, "predict_batch_begin")
  59. def on_predict_batch_end(self, batch, logs=None):
  60. if "predict_batch_end" in self._on:
  61. self._handle(logs, "predict_batch_end")
  62. def on_train_begin(self, logs=None):
  63. if "train_begin" in self._on:
  64. self._handle(logs, "train_begin")
  65. def on_train_end(self, logs=None):
  66. if "train_end" in self._on:
  67. self._handle(logs, "train_end")
  68. def on_test_begin(self, logs=None):
  69. if "test_begin" in self._on:
  70. self._handle(logs, "test_begin")
  71. def on_test_end(self, logs=None):
  72. if "test_end" in self._on:
  73. self._handle(logs, "test_end")
  74. def on_predict_begin(self, logs=None):
  75. if "predict_begin" in self._on:
  76. self._handle(logs, "predict_begin")
  77. def on_predict_end(self, logs=None):
  78. if "predict_end" in self._on:
  79. self._handle(logs, "predict_end")
  80. @PublicAPI(stability="alpha")
  81. class ReportCheckpointCallback(_Callback):
  82. """Keras callback for Ray Train reporting and checkpointing.
  83. .. note::
  84. Metrics are always reported with checkpoints, even if the event isn't specified
  85. in ``report_metrics_on``.
  86. Example:
  87. .. code-block:: python
  88. ############# Using it in TrainSession ###############
  89. from ray.air.integrations.keras import ReportCheckpointCallback
  90. def train_loop_per_worker():
  91. strategy = tf.distribute.MultiWorkerMirroredStrategy()
  92. with strategy.scope():
  93. model = build_model()
  94. model.fit(dataset_shard, callbacks=[ReportCheckpointCallback()])
  95. Args:
  96. metrics: Metrics to report. If this is a list, each item describes
  97. the metric key reported to Keras, and it's reported under the
  98. same name. If this is a dict, each key is the name reported
  99. and the respective value is the metric key reported to Keras.
  100. If this is None, all Keras logs are reported.
  101. report_metrics_on: When to report metrics. Must be one of
  102. the Keras event hooks (less the ``on_``), e.g.
  103. "train_start" or "predict_end". Defaults to "epoch_end".
  104. checkpoint_on: When to save checkpoints. Must be one of the Keras event hooks
  105. (less the ``on_``), e.g. "train_start" or "predict_end". Defaults to
  106. "epoch_end".
  107. """
  108. def __init__(
  109. self,
  110. checkpoint_on: Union[str, List[str]] = "epoch_end",
  111. report_metrics_on: Union[str, List[str]] = "epoch_end",
  112. metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
  113. ):
  114. if isinstance(checkpoint_on, str):
  115. checkpoint_on = [checkpoint_on]
  116. if isinstance(report_metrics_on, str):
  117. report_metrics_on = [report_metrics_on]
  118. on = list(set(checkpoint_on + report_metrics_on))
  119. super().__init__(on=on)
  120. self._checkpoint_on: List[str] = checkpoint_on
  121. self._report_metrics_on: List[str] = report_metrics_on
  122. self._metrics = metrics
  123. def _handle(self, logs: Dict, when: str):
  124. assert when in self._checkpoint_on or when in self._report_metrics_on
  125. metrics = self._get_reported_metrics(logs)
  126. should_checkpoint = when in self._checkpoint_on
  127. if should_checkpoint:
  128. checkpoint = TensorflowCheckpoint.from_model(self.model)
  129. ray.train.report(metrics, checkpoint=checkpoint)
  130. # Clean up temporary checkpoint
  131. shutil.rmtree(checkpoint.path, ignore_errors=True)
  132. else:
  133. ray.train.report(metrics, checkpoint=None)
  134. def _get_reported_metrics(self, logs: Dict) -> Dict:
  135. assert isinstance(self._metrics, (type(None), str, list, dict))
  136. if self._metrics is None:
  137. reported_metrics = logs
  138. elif isinstance(self._metrics, str):
  139. reported_metrics = {self._metrics: logs[self._metrics]}
  140. elif isinstance(self._metrics, list):
  141. reported_metrics = {metric: logs[metric] for metric in self._metrics}
  142. elif isinstance(self._metrics, dict):
  143. reported_metrics = {
  144. key: logs[metric] for key, metric in self._metrics.items()
  145. }
  146. assert isinstance(reported_metrics, dict)
  147. return reported_metrics