trainer.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. import logging
  2. import traceback
  3. from pathlib import Path
  4. from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, Union
  5. from ray.air._internal.util import (
  6. StartTraceback,
  7. StartTracebackWithWorkerRank,
  8. skip_exceptions,
  9. )
  10. from ray.train import Checkpoint, DataConfig
  11. from ray.train._internal.backend_executor import (
  12. BackendExecutor,
  13. InactiveWorkerGroupError,
  14. TrainBackendError,
  15. TrainingWorkerError,
  16. )
  17. from ray.train._internal.session import _TrainingResult, _TrainSession, get_session
  18. from ray.train._internal.utils import ActorWrapper
  19. from ray.train.backend import BackendConfig
  20. from ray.train.base_trainer import ( # noqa: F401
  21. BaseTrainer,
  22. GenDataset,
  23. TrainingFailedError,
  24. )
  25. from ray.util.annotations import DeveloperAPI
  26. if TYPE_CHECKING:
  27. from ray.data import Dataset
  28. T = TypeVar("T")
  29. S = TypeVar("S")
  30. logger = logging.getLogger(__name__)
  31. @DeveloperAPI
  32. class TrainingIterator:
  33. """An iterator over Train results. Returned by ``trainer.run_iterator``."""
  34. def __init__(
  35. self,
  36. backend_executor: Union[BackendExecutor, ActorWrapper],
  37. backend_config: BackendConfig,
  38. train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
  39. datasets: Dict[str, "Dataset"],
  40. metadata: Dict[str, Any],
  41. data_config: DataConfig,
  42. checkpoint: Optional[Union[Dict, str, Path, Checkpoint]],
  43. ):
  44. self._backend_executor = backend_executor
  45. self._backend = backend_config.backend_cls()
  46. self._train_func = train_func
  47. self._datasets = datasets
  48. self._metadata = metadata
  49. self._data_config = data_config
  50. self._start_training(
  51. train_func=train_func,
  52. datasets=self._datasets,
  53. metadata=self._metadata,
  54. data_config=self._data_config,
  55. checkpoint=checkpoint,
  56. )
  57. self._finished_training = False
  58. def __iter__(self):
  59. return self
  60. def _start_training(
  61. self,
  62. train_func,
  63. datasets,
  64. metadata,
  65. data_config,
  66. checkpoint: Optional[Checkpoint] = None,
  67. ):
  68. tune_session: _TrainSession = get_session()
  69. assert tune_session, "`_start_training` should only be called from within Tune"
  70. storage = tune_session.storage
  71. self._run_with_error_handling(
  72. lambda: self._backend_executor.start_training(
  73. train_func=train_func,
  74. datasets=datasets,
  75. metadata=metadata,
  76. data_config=data_config,
  77. storage=storage,
  78. checkpoint=checkpoint,
  79. )
  80. )
  81. def _run_with_error_handling(self, func: Callable):
  82. try:
  83. return func()
  84. except TrainingWorkerError:
  85. # TODO(ml-team): This Train fault-tolerance code doesn't get used
  86. # since max_retries=0
  87. # Workers have already been restarted.
  88. logger.info(
  89. "Workers have been successfully restarted. Resuming "
  90. "training from latest checkpoint."
  91. )
  92. self._start_training(
  93. self._train_func,
  94. self._datasets,
  95. self._metadata,
  96. self._data_config,
  97. )
  98. return self._run_with_error_handling(func)
  99. except InactiveWorkerGroupError:
  100. raise RuntimeError(
  101. "This Trainer is not active. It is either shutdown "
  102. "already or never started in the first place. "
  103. "Either create a new Trainer or start this one."
  104. ) from None
  105. except TrainBackendError:
  106. raise RuntimeError(
  107. "Training failed. You should not be seeing "
  108. "this error and this is a bug. Please create "
  109. "a new issue at "
  110. "https://github.com/ray-project/ray."
  111. ) from None
  112. def __next__(self):
  113. if self.is_finished():
  114. self._backend_executor.report_final_run_status(errored=False)
  115. raise StopIteration
  116. try:
  117. next_results = self._run_with_error_handling(self._fetch_next_result)
  118. if next_results is None:
  119. self._backend_executor.report_final_run_status(errored=False)
  120. self._run_with_error_handling(self._finish_training)
  121. self._finished_training = True
  122. raise StopIteration
  123. else:
  124. return next_results
  125. except StartTraceback as e:
  126. # If this is a StartTraceback, then this is a user error.
  127. # We raise it directly
  128. if isinstance(e, StartTracebackWithWorkerRank):
  129. failed_rank = e.worker_rank
  130. else:
  131. failed_rank = None
  132. # Extract the stack trace from the exception
  133. e = skip_exceptions(e)
  134. stack_trace = "".join(
  135. traceback.format_exception(type(e), e, e.__traceback__)
  136. )
  137. self._backend_executor.report_final_run_status(
  138. errored=True, stack_trace=stack_trace, failed_rank=failed_rank
  139. )
  140. try:
  141. # Exception raised in at least one training worker. Immediately raise
  142. # this error to the user and do not attempt to terminate gracefully.
  143. self._backend_executor.shutdown(graceful_termination=False)
  144. self._finished_training = True
  145. except Exception:
  146. pass
  147. raise
  148. def _fetch_next_result(self) -> Optional[List[Dict]]:
  149. """Fetch next results produced by ``session.report()`` from each worker.
  150. Assumes ``start_training`` has already been called.
  151. Returns:
  152. A list of dictionaries of values passed to ``session.report()`` from
  153. each worker. Each item corresponds to an intermediate result
  154. a single worker. If there are no more items to fetch,
  155. returns None.
  156. """
  157. results = self._backend_executor.get_next_results()
  158. if results is None:
  159. return None
  160. assert all(isinstance(result, _TrainingResult) for result in results)
  161. return results
  162. def _finish_training(self):
  163. """Finish training and return final results. Propagate any exceptions.
  164. Blocks until training is finished on all workers.
  165. Assumes `start_training` has already been called.
  166. Returns:
  167. A list of return values from calling ``train_func`` on each worker.
  168. Each item corresponds to the return value from a single worker.
  169. """
  170. return self._backend_executor.finish_training()
  171. def is_finished(self) -> bool:
  172. return self._finished_training