| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849 |
- from typing import Dict
- from ray.train.v2._internal.exceptions import RayTrainError
- from ray.util.annotations import PublicAPI
@PublicAPI(stability="alpha")
class TrainingFailedError(RayTrainError):
    """Base exception for a failed ``trainer.fit()`` call.

    Concrete failures are reported through one of two subclasses:

    * :class:`ray.train.WorkerGroupError` -- raised from the worker group
      during training.
    * :class:`ray.train.ControllerError` -- raised when the training
      controller itself fails.

    Catch this type to handle either kind of training failure.
    """
@PublicAPI(stability="alpha")
class WorkerGroupError(TrainingFailedError):
    """Exception raised from the worker group during training.

    Args:
        error_message: A human-readable error message describing the training
            worker failures.
        worker_failures: A mapping from worker rank to the exception that
            occurred on that worker during training.
    """

    def __init__(self, error_message: str, worker_failures: Dict[int, Exception]):
        super().__init__("Training failed due to worker errors:\n" + error_message)
        # Keep the un-prefixed message so __reduce__ can rebuild the exception
        # without prepending the header a second time.
        self._error_message = error_message
        self.worker_failures = worker_failures

    def __reduce__(self):
        """Support pickling by re-invoking ``__init__`` with the original arguments.

        The inherited ``BaseException.__reduce__`` would pass ``self.args`` (the
        single, already-formatted message) to the constructor, which does not
        match this class's two-argument signature.

        The instance ``__dict__`` is returned as the pickle state so that extra
        attributes set after construction survive the round trip instead of
        being silently dropped. This includes PEP 678 notes added via
        ``add_note()``. ``BaseException.__setstate__`` restores them on
        unpickling.
        """
        return (
            self.__class__,
            (self._error_message, self.worker_failures),
            self.__dict__,
        )
@PublicAPI(stability="alpha")
class ControllerError(TrainingFailedError):
    """Exception raised when training fails due to a controller error.

    Args:
        controller_failure: The exception that occurred on the controller.
    """

    def __init__(self, controller_failure: Exception):
        super().__init__(
            "Training failed due to controller error:\n" + str(controller_failure)
        )
        self.controller_failure = controller_failure
        # Seed this exception's traceback with the controller failure's frames
        # so they appear when this wrapper is raised and printed.
        self.with_traceback(controller_failure.__traceback__)

    def __reduce__(self):
        """Support pickling by re-invoking ``__init__`` with the original failure.

        The inherited ``BaseException.__reduce__`` would pass ``self.args`` (the
        formatted message string) as ``controller_failure``. The rebuilt
        message would then be derived from that string rather than from the
        original exception object.

        The instance ``__dict__`` is returned as the pickle state so that extra
        attributes set after construction survive the round trip instead of
        being silently dropped. This includes PEP 678 notes added via
        ``add_note()``. ``BaseException.__setstate__`` restores them on
        unpickling.
        """
        return (self.__class__, (self.controller_failure,), self.__dict__)
|