exceptions.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. import os
  2. from typing import List, Optional
  3. from ray.train.v2._internal.constants import (
  4. COLLECTIVE_TIMEOUT_S_ENV_VAR,
  5. DEFAULT_WORKER_GROUP_START_TIMEOUT_S,
  6. DEFAULT_WORKER_HEALTH_CHECK_TIMEOUT_S,
  7. WORKER_GROUP_START_TIMEOUT_S_ENV_VAR,
  8. WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR,
  9. )
  10. # TODO: Distinguish between user and system exceptions.
  11. class RayTrainError(Exception):
  12. """Base class for all Ray Train exceptions."""
  13. class WorkerHealthCheckTimeoutError(RayTrainError):
  14. """Exception raised when a worker health check hangs for long enough."""
  15. def __init__(self, message):
  16. timeout = os.getenv(
  17. WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR, DEFAULT_WORKER_HEALTH_CHECK_TIMEOUT_S
  18. )
  19. message += (
  20. f"\nSet the {WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR} "
  21. "environment variable to increase the timeout "
  22. f"(current value: {timeout} seconds)."
  23. )
  24. super().__init__(message)
  25. class WorkerHealthCheckFailedError(RayTrainError):
  26. """Exception raised when a worker health check fails."""
  27. def __init__(self, message, failure: Exception):
  28. super().__init__(message)
  29. self._message = message
  30. self.health_check_failure = failure
  31. def __reduce__(self):
  32. return (self.__class__, (self._message, self.health_check_failure))
  33. def __str__(self):
  34. return self._message + "\n" + str(self.health_check_failure)
  35. class WorkerGroupStartupTimeoutError(RayTrainError):
  36. """Exception raised when the worker group startup times out.
  37. Example scenario: 4 GPUs are detected in the cluster, but when the worker
  38. are actually scheduled, one of the nodes goes down and only 3 GPUs are
  39. available. One of the worker tasks may be stuck pending, until a timeout is reached.
  40. """
  41. def __init__(self, num_workers: int):
  42. timeout = float(
  43. os.environ.get(
  44. WORKER_GROUP_START_TIMEOUT_S_ENV_VAR,
  45. DEFAULT_WORKER_GROUP_START_TIMEOUT_S,
  46. )
  47. )
  48. self.num_workers = num_workers
  49. super().__init__(
  50. f"The worker group startup timed out after {timeout} seconds waiting "
  51. f"for {num_workers} workers. "
  52. "Potential causes include: "
  53. "(1) temporary insufficient cluster resources while waiting for "
  54. "autoscaling (ignore this warning in this case), "
  55. "(2) infeasible resource request where the provided `ScalingConfig` "
  56. "cannot be satisfied), "
  57. "and (3) transient network issues. "
  58. f"Set the {WORKER_GROUP_START_TIMEOUT_S_ENV_VAR} "
  59. "environment variable to increase the timeout."
  60. )
  61. def __reduce__(self):
  62. return (self.__class__, (self.num_workers,))
  63. class WorkerGroupStartupFailedError(RayTrainError):
  64. """Exception raised when the worker group fails to start.
  65. Example scenario: A worker is scheduled onto a node that dies while
  66. the worker actor is initializing.
  67. """
  68. class InsufficientClusterResourcesError(RayTrainError):
  69. """Exception raised when the cluster has insufficient resources.
  70. Example scenario: A worker that requires 1 GPU is scheduled onto a cluster
  71. that only has CPU worker node types.
  72. """
  73. class CheckpointManagerInitializationError(RayTrainError):
  74. """Exception raised when the checkpoint manager fails to initialize from a snapshot.
  75. Example scenarios:
  76. 1. The checkpoint manager snapshot version is old and
  77. incompatible with the current version of Ray Train.
  78. 2. The checkpoint manager snapshot JSON file is corrupted.
  79. 3. The checkpoint manager snapshot references checkpoints that cannot be found
  80. in the run storage path.
  81. """
  82. class CollectiveTimeoutError(RayTrainError):
  83. """Exception raised when an internal Ray Train collective operation of
  84. the worker group times out.
  85. """
  86. class BroadcastCollectiveTimeoutError(CollectiveTimeoutError):
  87. """Exception raised when the broadcast operation times out.
  88. There are two main timeout examples:
  89. 1. If not all workers call `ray.train.report`, the entire worker group will
  90. hang until the timeout before raising. This prevents indefinite worker
  91. group hangs.
  92. 2. If a worker is slow in the training loop and fails to reach the broadcast
  93. time, the collective will time out.
  94. """
  95. def __init__(
  96. self, time_elapsed: Optional[float], missing_ranks: List[int], timeout_s: float
  97. ):
  98. self._time_elapsed = time_elapsed
  99. self._missing_ranks = missing_ranks
  100. self._timeout_s = timeout_s
  101. message = (
  102. f"The collective operation timed out after {time_elapsed:.2f} seconds. "
  103. f"The following ranks have not joined the collective operation: {missing_ranks}\n"
  104. f"You can set the timeout with the {COLLECTIVE_TIMEOUT_S_ENV_VAR} "
  105. f"environment variable (current value: {timeout_s:.2f} seconds). "
  106. "Disable the timeout by setting the environment variable to -1."
  107. )
  108. super().__init__(message)
  109. def __reduce__(self):
  110. return (
  111. self.__class__,
  112. (self._time_elapsed, self._missing_ranks, self._timeout_s),
  113. )
  114. class UserExceptionWithTraceback(RayTrainError):
  115. """This class wraps a user code exception raised on the worker
  116. with its original traceback string, for logging and debugging purposes.
  117. This is needed because the original exception traceback is not serialized
  118. with the exception when it is *returned* back to the main process.
  119. """
  120. def __init__(self, exc: BaseException, traceback_str: str):
  121. self._base_exc = exc
  122. self._traceback_str = traceback_str
  123. def __reduce__(self):
  124. return (self.__class__, (self._base_exc, self._traceback_str))
  125. def __str__(self):
  126. return self._traceback_str