| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196 |
- import logging
- import traceback
- from pathlib import Path
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, Union
- from ray.air._internal.util import (
- StartTraceback,
- StartTracebackWithWorkerRank,
- skip_exceptions,
- )
- from ray.train import Checkpoint, DataConfig
- from ray.train._internal.backend_executor import (
- BackendExecutor,
- InactiveWorkerGroupError,
- TrainBackendError,
- TrainingWorkerError,
- )
- from ray.train._internal.session import _TrainingResult, _TrainSession, get_session
- from ray.train._internal.utils import ActorWrapper
- from ray.train.backend import BackendConfig
- from ray.train.base_trainer import ( # noqa: F401
- BaseTrainer,
- GenDataset,
- TrainingFailedError,
- )
- from ray.util.annotations import DeveloperAPI
- if TYPE_CHECKING:
- from ray.data import Dataset
# Generic placeholders used in the type annotations of this module
# (``T`` appears in the ``train_func`` signature of TrainingIterator).
T, S = TypeVar("T"), TypeVar("S")

# Logger named after this module.
logger = logging.getLogger(name=__name__)
@DeveloperAPI
class TrainingIterator:
    """An iterator over Train results. Returned by ``trainer.run_iterator``.

    Training is launched on the backend executor as soon as the iterator is
    constructed. Each ``next()`` call returns the list of ``_TrainingResult``
    objects fetched from the executor's ``get_next_results()``. Iteration
    stops once the executor returns ``None``.
    """

    def __init__(
        self,
        backend_executor: Union[BackendExecutor, ActorWrapper],
        backend_config: BackendConfig,
        train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
        datasets: Dict[str, "Dataset"],
        metadata: Dict[str, Any],
        data_config: DataConfig,
        checkpoint: Optional[Union[Dict, str, Path, Checkpoint]],
    ):
        """Store the training configuration and immediately start training.

        Args:
            backend_executor: Executor (or an ``ActorWrapper`` around one)
                that runs ``train_func`` on the workers.
            backend_config: Config whose ``backend_cls`` is instantiated and
                stored on ``self._backend``.
            train_func: The per-worker training function.
            datasets: Datasets forwarded to ``start_training``.
            metadata: Metadata forwarded to ``start_training``.
            data_config: Data configuration forwarded to ``start_training``.
            checkpoint: Optional checkpoint forwarded to ``start_training``
                for this initial launch only.
        """
        self._backend_executor = backend_executor
        self._backend = backend_config.backend_cls()
        self._train_func = train_func
        self._datasets = datasets
        self._metadata = metadata
        self._data_config = data_config
        # Define the flag before launching training so the attribute always
        # exists once any method of this object has run.
        self._finished_training = False

        self._start_training(
            train_func=train_func,
            datasets=self._datasets,
            metadata=self._metadata,
            data_config=self._data_config,
            checkpoint=checkpoint,
        )

    def __iter__(self):
        """Return ``self``; this object is its own iterator."""
        return self

    def _start_training(
        self,
        train_func,
        datasets,
        metadata,
        data_config,
        checkpoint: Optional[Union[Dict, str, Path, Checkpoint]] = None,
    ):
        """Launch ``train_func`` on the backend executor.

        Must be called while a session is active: the session's ``storage``
        is forwarded to the executor's ``start_training``.

        Args:
            train_func: The per-worker training function.
            datasets: Datasets forwarded to ``start_training``.
            metadata: Metadata forwarded to ``start_training``.
            data_config: Data configuration forwarded to ``start_training``.
            checkpoint: Checkpoint forwarded to ``start_training``. Accepts
                the same types as ``__init__``; defaults to ``None``.
        """
        tune_session: _TrainSession = get_session()
        assert tune_session, "`_start_training` should only be called from within Tune"
        storage = tune_session.storage

        self._run_with_error_handling(
            lambda: self._backend_executor.start_training(
                train_func=train_func,
                datasets=datasets,
                metadata=metadata,
                data_config=data_config,
                storage=storage,
                checkpoint=checkpoint,
            )
        )

    def _run_with_error_handling(self, func: Callable):
        """Call ``func`` and translate backend errors into user-facing ones.

        Args:
            func: Zero-argument callable to run.

        Returns:
            Whatever ``func`` returns.

        Raises:
            RuntimeError: If the worker group is inactive
                (``InactiveWorkerGroupError``) or the backend reports an
                internal failure (``TrainBackendError``).
        """
        try:
            return func()
        except TrainingWorkerError:
            # TODO(ml-team): This Train fault-tolerance code doesn't get used
            # since max_retries=0
            # Workers have already been restarted.
            # NOTE(review): despite the log message, ``_start_training`` is
            # called here without a checkpoint (so ``checkpoint=None`` is
            # forwarded). When ``func`` is itself the start-training lambda,
            # retrying ``func`` below launches training a second time. There
            # is also no bound on the retry recursion. Revisit before enabling
            # retries.
            logger.info(
                "Workers have been successfully restarted. Resuming "
                "training from latest checkpoint."
            )
            self._start_training(
                self._train_func,
                self._datasets,
                self._metadata,
                self._data_config,
            )
            return self._run_with_error_handling(func)
        except InactiveWorkerGroupError:
            raise RuntimeError(
                "This Trainer is not active. It is either shutdown "
                "already or never started in the first place. "
                "Either create a new Trainer or start this one."
            ) from None
        except TrainBackendError:
            raise RuntimeError(
                "Training failed. You should not be seeing "
                "this error and this is a bug. Please create "
                "a new issue at "
                "https://github.com/ray-project/ray."
            ) from None

    def __next__(self) -> List[_TrainingResult]:
        """Return the next batch of per-worker results.

        Returns:
            One ``_TrainingResult`` per worker for the latest report.

        Raises:
            StopIteration: When training has finished. The executor is told
                the run ended without error before this is raised.
            StartTraceback: When a training worker raised (a user error). The
                failure is reported to the executor, the worker group is shut
                down non-gracefully (best effort), and the error is re-raised.
        """
        if self.is_finished():
            # NOTE(review): this also runs after a previous call already
            # reported ``errored=True``. Confirm that
            # ``report_final_run_status`` tolerates being called again.
            self._backend_executor.report_final_run_status(errored=False)
            raise StopIteration
        try:
            next_results = self._run_with_error_handling(self._fetch_next_result)
            if next_results is None:
                self._backend_executor.report_final_run_status(errored=False)
                self._run_with_error_handling(self._finish_training)
                self._finished_training = True
                raise StopIteration
            return next_results
        except StartTraceback as e:
            # If this is a StartTraceback, then this is a user error.
            # We raise it directly
            if isinstance(e, StartTracebackWithWorkerRank):
                failed_rank = e.worker_rank
            else:
                failed_rank = None

            # Extract the stack trace from the exception
            e = skip_exceptions(e)
            stack_trace = "".join(
                traceback.format_exception(type(e), e, e.__traceback__)
            )

            self._backend_executor.report_final_run_status(
                errored=True, stack_trace=stack_trace, failed_rank=failed_rank
            )
            # Mark finished unconditionally. Previously a failing shutdown left
            # the flag False, so a later next() would poll the (possibly
            # already shut down) executor again.
            self._finished_training = True
            try:
                # Exception raised in at least one training worker. Immediately
                # raise this error to the user and do not attempt to terminate
                # gracefully.
                self._backend_executor.shutdown(graceful_termination=False)
            except Exception as shutdown_error:
                # Best effort: the user's training error re-raised below is the
                # one that matters, so only log the secondary failure.
                logger.warning(
                    "Failed to shut down the worker group after a training "
                    "error: %r",
                    shutdown_error,
                )
            # Re-raises the original StartTraceback being handled.
            raise

    def _fetch_next_result(self) -> Optional[List[_TrainingResult]]:
        """Fetch the next results produced by ``session.report()`` on each worker.

        Assumes ``start_training`` has already been called.

        Returns:
            A list of ``_TrainingResult`` objects, one per worker, each
            holding the intermediate result reported by that worker. Returns
            ``None`` if there are no more results to fetch.
        """
        results = self._backend_executor.get_next_results()
        if results is None:
            return None
        assert all(isinstance(result, _TrainingResult) for result in results)
        return results

    def _finish_training(self):
        """Finish training and return the final results.

        Delegates to the executor's ``finish_training()``; exceptions it
        raises propagate to the caller. Assumes ``start_training`` has
        already been called.

        Returns:
            Whatever ``finish_training()`` returns. Presumably one return
            value of ``train_func`` per worker (confirm against the
            executor).
        """
        return self._backend_executor.finish_training()

    def is_finished(self) -> bool:
        """Return ``True`` once training has ended, either normally or after an error."""
        return self._finished_training
|