callback.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. from typing import Any, Dict, List, Optional
  2. from ray.train import Checkpoint
  3. from ray.train.v2._internal.execution.context import TrainRunContext
  4. from ray.util.annotations import DeveloperAPI
  5. @DeveloperAPI
  6. class RayTrainCallback:
  7. """Base Ray Train callback interface."""
  8. pass
  9. @DeveloperAPI
  10. class UserCallback(RayTrainCallback):
  11. """Callback interface for custom user-defined callbacks to handling events
  12. during training.
  13. This callback is called on the Ray Train controller process, not on the
  14. worker processes.
  15. """
  16. def after_report(
  17. self,
  18. run_context: TrainRunContext,
  19. metrics: List[Dict[str, Any]],
  20. checkpoint: Optional[Checkpoint],
  21. ):
  22. """Called after all workers have reported a metric + checkpoint
  23. via `ray.train.report`.
  24. Args:
  25. run_context: The `TrainRunContext` for the current training run.
  26. metrics: A list of metric dictionaries reported by workers,
  27. where metrics[i] is the metrics dict reported by worker i.
  28. checkpoint: A Checkpoint object that has been persisted to
  29. storage. This is None if no workers reported a checkpoint
  30. (e.g. `ray.train.report(metrics, checkpoint=None)`).
  31. """
  32. pass
  33. def after_exception(
  34. self, run_context: TrainRunContext, worker_exceptions: Dict[int, Exception]
  35. ):
  36. """Called after one or more workers have raised an exception.
  37. Args:
  38. run_context: The `TrainRunContext` for the current training run.
  39. worker_exceptions: A dict from worker world rank to the exception
  40. raised by that worker.
  41. """
  42. pass