| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210 |
- import shutil
- from abc import abstractmethod
- from typing import Dict, List, Optional, Union
- from tensorflow.keras.callbacks import Callback as KerasCallback
- import ray
- from ray.train.tensorflow import TensorflowCheckpoint
- from ray.util.annotations import PublicAPI
class _Callback(KerasCallback):
    """Base class for Ray Train's Keras callbacks.

    Every Keras ``on_<event>`` hook defined here forwards to :meth:`_handle`
    when ``<event>`` was selected through the ``on`` constructor argument,
    and does nothing otherwise. Subclasses implement :meth:`_handle`.

    Args:
        on: Event name, or list of event names, that should trigger
            :meth:`_handle`. Each name must be listed in ``_allowed``
            (a Keras hook name without the ``on_`` prefix).
            Defaults to "epoch_end".

    Raises:
        ValueError: If any requested event name is not in ``_allowed``.
    """

    # Event names accepted for ``on``.
    _allowed = [
        "epoch_begin",
        "epoch_end",
        "train_batch_begin",
        "train_batch_end",
        "test_batch_begin",
        "test_batch_end",
        "predict_batch_begin",
        "predict_batch_end",
        "train_begin",
        "train_end",
        "test_begin",
        "test_end",
        "predict_begin",
        "predict_end",
    ]

    # The previous default ("validation_end") is not in ``_allowed``, so
    # constructing the callback without arguments always raised ValueError.
    def __init__(self, on: Union[str, List[str]] = "epoch_end"):
        super().__init__()

        if not isinstance(on, list):
            on = [on]
        if any(w not in self._allowed for w in on):
            raise ValueError(
                "Invalid trigger time selected: {}. Must be one of {}".format(
                    on, self._allowed
                )
            )
        self._on = on

    def _handle(self, logs: Dict, when: str):
        """Process the event named ``when``; must be overridden by subclasses.

        Args:
            logs: The ``logs`` value received by the Keras hook.
            when: Name of the event that fired (one of ``_allowed``).
        """
        raise NotImplementedError

    def _dispatch(self, when: str, logs: Optional[Dict]) -> None:
        """Call :meth:`_handle` if ``when`` is one of the selected events."""
        if when in self._on:
            self._handle(logs, when)

    def on_epoch_begin(self, epoch, logs=None):
        self._dispatch("epoch_begin", logs)

    def on_epoch_end(self, epoch, logs=None):
        self._dispatch("epoch_end", logs)

    def on_train_batch_begin(self, batch, logs=None):
        self._dispatch("train_batch_begin", logs)

    def on_train_batch_end(self, batch, logs=None):
        self._dispatch("train_batch_end", logs)

    def on_test_batch_begin(self, batch, logs=None):
        self._dispatch("test_batch_begin", logs)

    def on_test_batch_end(self, batch, logs=None):
        self._dispatch("test_batch_end", logs)

    def on_predict_batch_begin(self, batch, logs=None):
        self._dispatch("predict_batch_begin", logs)

    def on_predict_batch_end(self, batch, logs=None):
        self._dispatch("predict_batch_end", logs)

    def on_train_begin(self, logs=None):
        self._dispatch("train_begin", logs)

    def on_train_end(self, logs=None):
        self._dispatch("train_end", logs)

    def on_test_begin(self, logs=None):
        self._dispatch("test_begin", logs)

    def on_test_end(self, logs=None):
        self._dispatch("test_end", logs)

    def on_predict_begin(self, logs=None):
        self._dispatch("predict_begin", logs)

    def on_predict_end(self, logs=None):
        self._dispatch("predict_end", logs)
class RayReportCallback(_Callback):
    """Base callback that reports metrics and, optionally, checkpoints.

    On every selected event the Keras logs are filtered/renamed according to
    ``metrics``. If the event is listed in ``checkpoint_on`` a
    ``TensorflowCheckpoint`` is created from the model and passed, together
    with the metrics, to :meth:`_save_and_report_checkpoint`; otherwise only
    :meth:`_report_metrics` is called.

    Args:
        checkpoint_on: Event name(s) on which to checkpoint.
            Defaults to "epoch_end".
        report_metrics_on: Event name(s) on which to report metrics.
            Defaults to "epoch_end".
        metrics: Which log entries to report. ``None`` reports the logs
            unchanged; a string or list selects those keys; a dict maps
            reported name -> Keras log key.
    """

    def __init__(
        self,
        checkpoint_on: Union[str, List[str]] = "epoch_end",
        report_metrics_on: Union[str, List[str]] = "epoch_end",
        metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
    ):
        if isinstance(checkpoint_on, str):
            checkpoint_on = [checkpoint_on]
        if isinstance(report_metrics_on, str):
            report_metrics_on = [report_metrics_on]

        # The base class must fire on the union of both event lists.
        on = list(set(checkpoint_on + report_metrics_on))
        super().__init__(on=on)

        self._checkpoint_on: List[str] = checkpoint_on
        self._report_metrics_on: List[str] = report_metrics_on
        self._metrics = metrics

    def _get_reported_metrics(self, logs: Optional[Dict]) -> Dict:
        """Select the metrics to report from the Keras ``logs``.

        Args:
            logs: Logs received by the Keras hook. ``None`` (the hook's
                default value) is treated as an empty dict.

        Returns:
            Dict of metric name -> value, shaped according to ``metrics``.

        Raises:
            KeyError: If a requested metric key is missing from ``logs``.
        """
        assert isinstance(self._metrics, (type(None), str, list, dict))

        # Keras hooks default ``logs`` to None; normalise so that the
        # "report everything" case still yields a dict.
        if logs is None:
            logs = {}

        if self._metrics is None:
            reported_metrics = logs
        elif isinstance(self._metrics, str):
            reported_metrics = {self._metrics: logs[self._metrics]}
        elif isinstance(self._metrics, list):
            reported_metrics = {metric: logs[metric] for metric in self._metrics}
        elif isinstance(self._metrics, dict):
            reported_metrics = {
                key: logs[metric] for key, metric in self._metrics.items()
            }

        assert isinstance(reported_metrics, dict)
        return reported_metrics

    @abstractmethod
    def _save_and_report_checkpoint(
        self, metrics: Dict, checkpoint: TensorflowCheckpoint
    ):
        """Save checkpoint and report metrics corresponding to this checkpoint."""
        raise NotImplementedError

    @abstractmethod
    def _report_metrics(self, metrics: Dict):
        """Report metrics."""
        raise NotImplementedError

    def _handle(self, logs: Optional[Dict], when: str):
        """Report metrics, with a checkpoint if ``when`` is a checkpoint event.

        Args:
            logs: Logs received by the Keras hook (may be ``None``).
            when: Name of the event that fired.
        """
        assert when in self._checkpoint_on or when in self._report_metrics_on

        metrics = self._get_reported_metrics(logs)

        if when not in self._checkpoint_on:
            self._report_metrics(metrics)
            return

        checkpoint = TensorflowCheckpoint.from_model(self.model)
        try:
            self._save_and_report_checkpoint(metrics, checkpoint)
        finally:
            # Clean up the temporary checkpoint even when reporting fails,
            # so a raised exception does not leak the directory.
            shutil.rmtree(checkpoint.path, ignore_errors=True)
@PublicAPI(stability="alpha")
class ReportCheckpointCallback(RayReportCallback):
    """Keras callback for Ray Train reporting and checkpointing.

    .. note::
        Metrics are always reported with checkpoints, even if the event isn't specified
        in ``report_metrics_on``.

    Example:
        .. testcode:: python

            ############# Using it in TrainSession ###############
            from ray.air.integrations.keras import ReportCheckpointCallback
            def train_loop_per_worker():
                strategy = tf.distribute.MultiWorkerMirroredStrategy()
                with strategy.scope():
                    model = build_model()

                model.fit(dataset_shard, callbacks=[ReportCheckpointCallback()])

    Args:
        metrics: Metrics to report. If this is a list, each item describes
            the metric key reported to Keras, and it's reported under the
            same name. If this is a dict, each key is the name reported
            and the respective value is the metric key reported to Keras.
            If this is None, all Keras logs are reported.
        report_metrics_on: When to report metrics. Must be one of
            the Keras event hooks (less the ``on_``), e.g.
            "train_begin" or "predict_end". Defaults to "epoch_end".
        checkpoint_on: When to save checkpoints. Must be one of the Keras event hooks
            (less the ``on_``), e.g. "train_begin" or "predict_end". Defaults to
            "epoch_end".
    """

    def _save_and_report_checkpoint(
        self, metrics: Dict, checkpoint: TensorflowCheckpoint
    ) -> None:
        """Save checkpoint and report metrics corresponding to this checkpoint.

        Forwards both arguments to ``ray.train.report``.

        Args:
            metrics: Metrics to report.
            checkpoint: Checkpoint to report alongside the metrics.
        """
        ray.train.report(metrics, checkpoint=checkpoint)

    def _report_metrics(self, metrics: Dict) -> None:
        """Report metrics.

        Forwards ``metrics`` to ``ray.train.report`` with ``checkpoint=None``.

        Args:
            metrics: Metrics to report.
        """
        ray.train.report(metrics, checkpoint=None)
|