keras.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. from typing import Dict
  2. import ray.tune
  3. from ray.train.tensorflow import TensorflowCheckpoint
  4. from ray.train.tensorflow.keras import RayReportCallback
  5. from ray.util.annotations import PublicAPI
  6. _DEPRECATION_MESSAGE = (
  7. "The `ray.tune.integration.keras` module is deprecated in favor of "
  8. "`ray.train.tensorflow.keras.ReportCheckpointCallback`."
  9. )
  10. class TuneReportCallback:
  11. """Deprecated.
  12. Use :class:`ray.train.tensorflow.keras.ReportCheckpointCallback` instead."""
  13. def __new__(cls, *args, **kwargs):
  14. raise DeprecationWarning(_DEPRECATION_MESSAGE)
  15. class _TuneCheckpointCallback:
  16. """Deprecated.
  17. Use :class:`ray.train.tensorflow.keras.ReportCheckpointCallback` instead."""
  18. def __new__(cls, *args, **kwargs):
  19. raise DeprecationWarning(_DEPRECATION_MESSAGE)
  20. @PublicAPI(stability="alpha")
  21. class TuneReportCheckpointCallback(RayReportCallback):
  22. """Keras callback for Ray Tune reporting and checkpointing.
  23. .. note::
  24. Metrics are always reported with checkpoints, even if the event isn't specified
  25. in ``report_metrics_on``.
  26. Example:
  27. .. code-block:: python
  28. ############# Using it in Ray Tune ###############
  29. from ray.tune.integrations.keras import TuneReportCheckpointCallback
  30. def train_fn():
  31. model = build_model()
  32. model.fit(dataset_shard, callbacks=[TuneReportCheckpointCallback()])
  33. tuner = tune.Tuner(train_fn)
  34. results = tuner.fit()
  35. Args:
  36. metrics: Metrics to report. If this is a list, each item describes
  37. the metric key reported to Keras, and it's reported under the
  38. same name. If this is a dict, each key is the name reported
  39. and the respective value is the metric key reported to Keras.
  40. If this is None, all Keras logs are reported.
  41. report_metrics_on: When to report metrics. Must be one of
  42. the Keras event hooks (less the ``on_``), e.g.
  43. "train_start" or "predict_end". Defaults to "epoch_end".
  44. checkpoint_on: When to save checkpoints. Must be one of the Keras event hooks
  45. (less the ``on_``), e.g. "train_start" or "predict_end". Defaults to
  46. "epoch_end".
  47. """
  48. def _save_and_report_checkpoint(
  49. self, metrics: Dict, checkpoint: TensorflowCheckpoint
  50. ):
  51. ray.tune.report(metrics, checkpoint=checkpoint)
  52. def _report_metrics(self, metrics: Dict):
  53. ray.tune.report(metrics)