exceptions.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. from typing import Dict
  2. from ray.train.v2._internal.exceptions import RayTrainError
  3. from ray.util.annotations import PublicAPI
  4. @PublicAPI(stability="alpha")
  5. class TrainingFailedError(RayTrainError):
  6. """Exception raised when training fails from a `trainer.fit()` call.
  7. This is either :class:`ray.train.WorkerGroupError` or :class:`ray.train.ControllerError`.
  8. """
  9. @PublicAPI(stability="alpha")
  10. class WorkerGroupError(TrainingFailedError):
  11. """Exception raised from the worker group during training.
  12. Args:
  13. error_message: A human-readable error message describing the training worker failures.
  14. worker_failures: A mapping from worker rank to the exception that
  15. occurred on that worker during training.
  16. """
  17. def __init__(self, error_message: str, worker_failures: Dict[int, Exception]):
  18. super().__init__("Training failed due to worker errors:\n" + error_message)
  19. self._error_message = error_message
  20. self.worker_failures = worker_failures
  21. def __reduce__(self):
  22. return (self.__class__, (self._error_message, self.worker_failures))
  23. @PublicAPI(stability="alpha")
  24. class ControllerError(TrainingFailedError):
  25. """Exception raised when training fails due to a controller error.
  26. Args:
  27. controller_failure: The exception that occurred on the controller.
  28. """
  29. def __init__(self, controller_failure: Exception):
  30. super().__init__(
  31. "Training failed due to controller error:\n" + str(controller_failure)
  32. )
  33. self.controller_failure = controller_failure
  34. self.with_traceback(controller_failure.__traceback__)
  35. def __reduce__(self):
  36. return (self.__class__, (self.controller_failure,))