| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172 |
- from typing import Dict
- import ray.tune
- from ray.train.tensorflow import TensorflowCheckpoint
- from ray.train.tensorflow.keras import RayReportCallback
- from ray.util.annotations import PublicAPI
# Single message raised by the deprecated callback shims defined in this
# module, kept in one place so the wording stays identical across them.
_DEPRECATION_MESSAGE = (
    "The `ray.tune.integration.keras` module is deprecated in favor of"
    " `ray.train.tensorflow.keras.ReportCheckpointCallback`."
)
class TuneReportCallback:
    """Deprecated.

    Use :class:`ray.train.tensorflow.keras.ReportCheckpointCallback` instead.

    This class can no longer be instantiated: any attempt to construct it
    raises ``DeprecationWarning`` carrying the module's deprecation message.
    """

    def __new__(cls, *args, **kwargs):
        # Fail during construction (before any ``__init__`` runs) so that
        # leftover usages surface immediately, whatever arguments are passed.
        deprecation_error = DeprecationWarning(_DEPRECATION_MESSAGE)
        raise deprecation_error
class _TuneCheckpointCallback:
    """Deprecated.

    Use :class:`ray.train.tensorflow.keras.ReportCheckpointCallback` instead.

    Constructing this class is not possible anymore: ``__new__`` always raises
    ``DeprecationWarning`` with the module's deprecation message.
    """

    def __new__(cls, *args, **kwargs):
        # Arguments are accepted only so that every historical call signature
        # reaches the same error instead of a ``TypeError``.
        deprecation_error = DeprecationWarning(_DEPRECATION_MESSAGE)
        raise deprecation_error
@PublicAPI(stability="alpha")
class TuneReportCheckpointCallback(RayReportCallback):
    """Keras callback for Ray Tune reporting and checkpointing.

    .. note::
        Metrics are always reported with checkpoints, even if the event isn't
        specified in ``report_metrics_on``.

    Example:
        .. code-block:: python

            ############# Using it in Ray Tune ###############
            from ray import tune
            from ray.tune.integration.keras import TuneReportCheckpointCallback

            def train_fn():
                model = build_model()
                model.fit(dataset_shard, callbacks=[TuneReportCheckpointCallback()])

            tuner = tune.Tuner(train_fn)
            results = tuner.fit()

    Args:
        metrics: Metrics to report. If this is a list, each item describes
            the metric key reported to Keras, and it's reported under the
            same name. If this is a dict, each key is the name reported
            and the respective value is the metric key reported to Keras.
            If this is None, all Keras logs are reported.
        report_metrics_on: When to report metrics. Must be one of
            the Keras event hooks (less the ``on_``), e.g.
            "train_start" or "predict_end". Defaults to "epoch_end".
        checkpoint_on: When to save checkpoints. Must be one of the Keras event
            hooks (less the ``on_``), e.g. "train_start" or "predict_end".
            Defaults to "epoch_end".
    """

    def _save_and_report_checkpoint(
        self, metrics: Dict, checkpoint: TensorflowCheckpoint
    ) -> None:
        """Report ``metrics`` to Ray Tune together with ``checkpoint``.

        NOTE(review): presumably invoked by the ``RayReportCallback`` base
        class on ``checkpoint_on`` events -- the base class is not visible
        here, confirm against its implementation.

        Args:
            metrics: Metrics dictionary forwarded to ``ray.tune.report``.
            checkpoint: Checkpoint attached to the report.
        """
        ray.tune.report(metrics, checkpoint=checkpoint)

    def _report_metrics(self, metrics: Dict) -> None:
        """Report ``metrics`` to Ray Tune without attaching a checkpoint.

        NOTE(review): presumably invoked by the ``RayReportCallback`` base
        class on ``report_metrics_on`` events -- confirm against the base
        class.

        Args:
            metrics: Metrics dictionary forwarded to ``ray.tune.report``.
        """
        ray.tune.report(metrics)
|