| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203 |
- import functools
- import logging
- import os
- import platform
- import queue
- import sys
- import threading
- import time
- import warnings
- from dataclasses import dataclass
- from datetime import datetime
- from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type
- import ray
- from ray.air._internal.util import RunnerThread, StartTraceback
- from ray.air.constants import (
- _ERROR_FETCH_TIMEOUT,
- _RESULT_FETCH_TIMEOUT,
- SESSION_MISUSE_LOG_ONCE_KEY,
- TIME_THIS_ITER_S,
- TIMESTAMP,
- )
- from ray.train import Checkpoint
- from ray.train._internal.accelerator import Accelerator
- from ray.train._internal.storage import StorageContext
- from ray.train.constants import (
- CHECKPOINT_DIR_NAME,
- DETAILED_AUTOFILLED_KEYS,
- RAY_CHDIR_TO_TRIAL_DIR,
- TIME_TOTAL_S,
- WORKER_HOSTNAME,
- WORKER_NODE_IP,
- WORKER_PID,
- _v2_migration_warnings_enabled,
- )
- from ray.train.error import SessionMisuseError
- from ray.train.utils import _log_deprecation_warning
- from ray.util import queue as ray_queue
- from ray.util.annotations import DeveloperAPI, PublicAPI
- from ray.util.debug import log_once
- from ray.util.placement_group import _valid_resource_shape
- from ray.util.scheduling_strategies import (
- PlacementGroupSchedulingStrategy,
- SchedulingStrategyT,
- )
- if TYPE_CHECKING:
- from ray.data import DataIterator, Dataset
- from ray.tune.execution.placement_groups import PlacementGroupFactory
- logger = logging.getLogger(__name__)
@dataclass
class TrialInfo:
    """Trial-level details that get propagated to the ``_TrainSession``.

    The field order is part of the generated constructor signature, so it
    must stay as declared here.
    """

    # Surfaced through `_TrainSession.trial_name`.
    name: str
    # Surfaced through `_TrainSession.trial_id`.
    id: str
    # Surfaced through `_TrainSession.trial_resources`.
    resources: Dict[str, float]
    # Surfaced through `_TrainSession.trial_dir`.
    logdir: str
    driver_ip: str
    driver_node_id: str
    # The two fields below are optional and default to `None`.
    experiment_name: Optional[str] = None
    run_id: Optional[str] = None
class _FutureTrainingResult:
    """Wrapper around an object ref that eventually yields a `_TrainingResult`.

    This is needed for specific schedulers such as PBT that schedule saves.
    This wrapper should be removed after refactoring PBT to not schedule saves
    anymore.
    """

    def __init__(self, future: ray.ObjectRef):
        self.future = future

    def resolve(self, block: bool = True) -> Optional["_TrainingResult"]:
        """Fetch the ``_TrainingResult`` behind the wrapped future.

        Args:
            block: If True, wait without a timeout. If False, call ``ray.get``
                with a timeout of ``1e-9`` seconds.

        Returns:
            Whatever ``ray.get`` returns for the future. ``None`` is returned
            when the fetch times out or fails with any other exception (the
            latter is logged). For function trainables the resolved value is
            also ``None`` if no checkpoint has been saved before.
        """
        wait_for = None if block else 1e-9
        try:
            outcome = ray.get(self.future, timeout=wait_for)
        except TimeoutError:
            # Not ready, yet
            return None
        except Exception as exc:
            logger.error(f"Error resolving result: {exc}")
            return None
        return outcome
class _TrainingResult:
    """One user report: the metrics plus an optional checkpoint."""

    def __init__(self, checkpoint: Optional[Checkpoint], metrics: Dict[str, Any]):
        self.checkpoint = checkpoint
        self.metrics = metrics

    def __repr__(self) -> str:
        # The repr intentionally uses the name without the leading underscore.
        fields = f"checkpoint={self.checkpoint}, metrics={self.metrics}"
        return f"TrainingResult({fields})"
- # TODO(xwjiang): This needs a better name.
- @DeveloperAPI
- class _TrainSession:
- """Holds information for training on each worker."""
def __init__(
    self,
    training_func: Callable,
    world_rank: Optional[int],
    local_rank: Optional[int],
    node_rank: Optional[int],
    local_world_size: Optional[int],
    world_size: Optional[int],
    trial_info: Optional[TrialInfo] = None,
    dataset_shard: Optional[Dict[str, "Dataset"]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    checkpoint: Optional[Checkpoint] = None,
    detailed_autofilled_metrics: bool = False,
    storage: Optional[StorageContext] = None,
    synchronous_result_reporting: bool = False,
):
    """Set up the session state for one worker.

    Args:
        training_func: Callable that is run in the training thread.
        world_rank: Stored as ``self.world_rank``.
        local_rank: Stored as ``self.local_rank``.
        node_rank: Stored as ``self.node_rank``.
        local_world_size: Stored as ``self.local_world_size``.
        world_size: Stored as ``self.world_size``.
        trial_info: Trial information forwarded to ``reset``.
        dataset_shard: Stored as ``self.dataset_shard``.
        metadata: Trainer-level metadata that ``report`` merges into
            persisted checkpoints.
        checkpoint: Initial value of ``self.loaded_checkpoint``.
        detailed_autofilled_metrics: If False, keys listed in
            ``DETAILED_AUTOFILLED_KEYS`` are dropped from reported metrics.
        storage: Storage context. Required despite the ``None`` default.
        synchronous_result_reporting: See the comment below.

    Raises:
        AssertionError: if ``storage`` is falsy.
    """
    # `synchronous_result_reporting` controls when the training function is
    # unblocked after it has reported a result (see `get_next`):
    # Ex 1: With `synchronous_result_reporting=False` (e.g. 2 Ray Train
    # workers), the lock is released right after the main thread fetches a
    # result, so the worker that produces a result first will immediately
    # continue onto the next iteration.
    # Ex 2: With `synchronous_result_reporting=True` (e.g. a Tune function
    # Trainable), training will only continue with an explicit call to
    # `session.get_next`.
    # Synchronous reporting in example 2 is needed for Tune schedulers to
    # be able to stop the execution of the training function at will,
    # for advanced pausing schedulers (PBT, BOHB) and actor reuse.
    self.synchronous_result_reporting = synchronous_result_reporting

    # Ray Train worker properties
    # Note: These are set to None for Tune function Trainables.
    self.dataset_shard = dataset_shard
    self.metadata = metadata

    self.world_rank = world_rank
    self.local_rank = local_rank
    self.node_rank = node_rank
    self.local_world_size = local_world_size
    self.world_size = world_size

    assert storage
    logger.debug(f"StorageContext on SESSION (rank={world_rank}):\n{storage}")

    # NOTE: `reset` will initialize many properties needed to start running the
    # training_func as a thread.
    self.reset(
        training_func=training_func,
        trial_info=trial_info,
        storage=storage,
        loaded_checkpoint=checkpoint,
    )

    # Autofilled metrics attributes.
    self.detailed_autofilled_metrics = detailed_autofilled_metrics
    self.last_report_time = time.time()
    self.iteration = 0
    self.time_total = 0.0
    self.local_ip = self.get_current_ip()

    self.accelerator = None
    self._state = {}
def get_state(self, key: str) -> Any:
    """Return the value stored under ``key``, or ``None`` if there is none."""
    store = self._state
    return store.get(key, None)
def set_state(self, key: str, value: Any):
    """Store ``value`` under ``key``, overwriting any previous entry."""
    store = self._state
    store[key] = value
def get_current_ip(self):
    """Look up this node's IP via Ray, cache it on ``self.local_ip``, return it."""
    node_ip = ray.util.get_node_ip_address()
    self.local_ip = node_ip
    return node_ip
def start(self):
    """Flag the session as started, then launch the training thread."""
    runner = self.training_thread
    # The flag is set before the thread starts running.
    self.training_started = True
    runner.start()
def reset(
    self,
    training_func: Callable,
    trial_info: TrialInfo,
    storage: StorageContext,
    loaded_checkpoint=None,
):
    """(Re)create the threading primitives and per-run state of the session.

    A fresh (not yet started) runner thread is built around ``training_func``.
    The trial working directory is created, and the process changes into it
    unless the ``RAY_CHDIR_TO_TRIAL_DIR`` environment variable is set to ``0``.
    """
    # Semaphore that gates the execution of the training thread.
    self.continue_lock = threading.Semaphore(0)

    # Event that tells the training thread to stop.
    self.stop_event = threading.Event()

    # Single-slot queue carrying results across threads.
    self.result_queue = queue.Queue(1)

    # Queue for sending results from training actor to main thread.
    # It is created lazily by `_get_or_create_inter_actor_queue`.
    self._inter_actor_queue: Optional[ray_queue.Queue[Dict]] = None

    # Single-slot queue carrying exceptions from the runner thread to the
    # main thread. The size limit prevents errors from stacking up and forces
    # error reporting to block until finished.
    self.error_queue = queue.Queue(1)

    # Thread object that will run the training function.
    self.training_thread = RunnerThread(
        target=training_func, daemon=True, error_queue=self.error_queue
    )

    # Possibly override with new state
    self.trial_info = trial_info
    self.storage = storage
    self.loaded_checkpoint = loaded_checkpoint

    # Reset state
    self._state = {}
    self.ignore_report = False
    self.training_started = False
    self._first_report = True

    # All Ray Train workers share a common working directory: the special
    # trial folder. Always create it; only `chdir` into it when enabled.
    working_dir = storage.trial_working_directory
    os.makedirs(working_dir, exist_ok=True)
    chdir_enabled = bool(int(os.environ.get(RAY_CHDIR_TO_TRIAL_DIR, "1")))
    if chdir_enabled:
        logger.debug(f"Changing the working directory to: {working_dir}")
        os.chdir(working_dir)
def pause_reporting(self):
    """Make every subsequent ``session.report()`` call a no-op.

    ``report`` returns early once ``ignore_report`` is set.
    """
    self.ignore_report = True
def finish(self, timeout: Optional[float] = None) -> Optional[Any]:
    """Stop the training thread and wait for it to end.

    Args:
        timeout: Passed to ``training_thread.join``.

    Returns:
        The value returned by joining the training thread (the output of the
        training function), or ``None`` if training was never started.

    Raises:
        Any exception raised by joining the training thread, i.e. errors that
        occurred during training (including SystemError).
    """
    # Ask the training thread to exit gracefully ...
    self.stop_event.set()
    # ... and unblock it so that it can observe the stop event.
    self.continue_lock.release()

    # Force a final (blocking) sync of artifacts in the trial path to storage.
    self.storage.persist_artifacts(force=True)

    if not self.training_started:
        return None
    return self.training_thread.join(timeout=timeout)
def get_next(self) -> Optional[_TrainingResult]:
    """Fetch the next ``_TrainingResult`` reported by the training thread.

    Returns:
        The next result, or ``None`` if there are no more results to fetch.

    Raises:
        RuntimeError: if ``start`` has not been called yet.
        StartTraceback: (via ``_report_thread_runner_error``) if no result is
            left and the runner thread queued an error.
    """
    if not self.training_started:
        raise RuntimeError("Please call start before calling get_next.")

    if self.synchronous_result_reporting:
        # The first report needs no release: `start` already set the training
        # thread running. Every later call releases the lock so that training
        # continues until the next call to report.
        already_reported = not self._first_report
        if already_reported:
            self.continue_lock.release()
        self._first_report = False

    fetched = None
    # Keep polling for a result for as long as training is still ongoing.
    while self.training_thread.is_alive():
        fetched = self._get_result_from_queues(block=True)
        if fetched is not None:
            break

    if fetched is None:
        # The runner is no longer alive. Check one last time, because a result
        # may have been reported between the previous poll and the
        # termination of the thread runner.
        fetched = self._get_result_from_queues(block=False)

    if fetched is None:
        # All results are consumed, so it is now safe to surface an error
        # that occurred inside the thread runner.
        self._report_thread_runner_error(block=True)
    elif not self.error_queue.empty():
        logger.debug(
            (
                "Runner error waiting to be raised in main thread. "
                "Logging all available results first."
            )
        )

    if not self.synchronous_result_reporting:
        # At this point, the training thread has reached
        # the `train.report` and is blocked there.
        # With asynchronous result reporting, release the lock so that the
        # worker keeps training immediately after its result was fetched.
        self.continue_lock.release()

    # `None` here means there are no more results to fetch.
    return fetched
def _get_or_create_inter_actor_queue(self):
    """Return the inter-actor queue, creating it on first use."""
    existing = self._inter_actor_queue
    if existing is not None:
        return existing
    # Single-slot queue whose backing actor reserves no CPUs.
    self._inter_actor_queue = ray_queue.Queue(1, actor_options={"num_cpus": 0})
    return self._inter_actor_queue
def _get_result_from_queues(self, block: bool) -> Optional[_TrainingResult]:
    """Poll the result queue once, forwarding an inter-actor item first if any.

    Args:
        block: Whether the queue reads block (for up to
            ``_RESULT_FETCH_TIMEOUT``) or return immediately.

    Returns:
        The next item of the result queue, or ``None`` if it is empty.
    """
    relay_queue = self._inter_actor_queue
    if relay_queue is not None:
        try:
            forwarded = relay_queue.get(block=block, timeout=_RESULT_FETCH_TIMEOUT)
            if forwarded:
                # `report` blocks on `continue_lock`, so the lock has to be
                # released beforehand for the call to return.
                self.continue_lock.release()
                self.report(forwarded)
        except ray_queue.Empty:
            pass

    try:
        return self.result_queue.get(block=block, timeout=_RESULT_FETCH_TIMEOUT)
    except queue.Empty:
        return None
def _auto_fill_metrics(self, result: dict) -> dict:
    """Return a copy of ``result`` extended with autofilled metrics.

    Also advances the session's bookkeeping: ``iteration`` is incremented,
    ``time_total`` grows by this iteration's duration and ``last_report_time``
    is set to now. The iteration duration is taken from ``result`` if it
    contains ``TIME_THIS_ITER_S``, else it is measured since the last report.
    """
    now = time.time()
    now_datetime = datetime.now()

    elapsed = (
        result[TIME_THIS_ITER_S]
        if TIME_THIS_ITER_S in result
        else now - self.last_report_time
    )

    self.iteration += 1
    self.time_total += elapsed
    self.last_report_time = now

    autofilled = {
        TIMESTAMP: int(time.mktime(now_datetime.timetuple())),
        TIME_TOTAL_S: self.time_total,
        WORKER_PID: os.getpid(),
        WORKER_HOSTNAME: platform.node(),
        WORKER_NODE_IP: self.local_ip,
    }

    if not self.detailed_autofilled_metrics:
        # Drop the keys that are only reported in detailed mode.
        autofilled = {
            name: value
            for name, value in autofilled.items()
            if name not in DETAILED_AUTOFILLED_KEYS
        }

    # Autofilled values win over user-provided values for the same key.
    merged = result.copy()
    merged.update(autofilled)
    return merged
def _auto_fill_checkpoint_metrics(self, result: dict) -> dict:
    """Return a copy of ``result`` with ``TIMESTAMP`` set to the current time.

    Unlike ``_auto_fill_metrics`` this does not touch any session attribute.
    """
    timestamp = int(time.mktime(datetime.now().timetuple()))
    stamped = result.copy()
    stamped.update({TIMESTAMP: timestamp})
    return stamped
def _report_thread_runner_error(self, block=False):
    """Re-raise an error queued by the runner thread, if there is one.

    Args:
        block: Whether to wait (for up to ``_ERROR_FETCH_TIMEOUT``) for an
            error to show up in the error queue.

    Raises:
        StartTraceback: chained from the queued error. Nothing is raised if
            the error queue is empty.
    """
    try:
        runner_error = self.error_queue.get(
            block=block, timeout=_ERROR_FETCH_TIMEOUT
        )
    except queue.Empty:
        return
    raise StartTraceback from runner_error
def _report_training_result(self, training_result: _TrainingResult) -> None:
    """Hand a result to the main thread and wait until training may continue.

    The result is put on the result queue, then the calling (training) thread
    blocks on ``continue_lock``. If the stop event is set once the lock has
    been acquired, the event is cleared and the thread exits via
    ``sys.exit(0)``.

    NOTE: This is used internally to report results from Train to Tune
    without persisting checkpoints to storage 2 times.
    `report` is the public API that directly persists to storage, which
    should only be called by user code.
    """
    reported_checkpoint = training_result.checkpoint
    if reported_checkpoint:
        # NOTE: This populates `train.get_checkpoint`
        self.loaded_checkpoint = reported_checkpoint

    # Publish the result through the thread-safe queue.
    self.result_queue.put(training_result, block=True)

    # Park the training thread until the main thread triggers a resume.
    self.continue_lock.acquire()

    # NOTE: The check below is only really useful if
    # `synchronous_result_reporting=True`. Otherwise, the lock is immediately
    # released on reporting, and this check is skipped before the main thread
    # decides to set the stop event.
    if not self.stop_event.is_set():
        return

    # The trial should be terminated: exit gracefully.
    self.stop_event.clear()
    sys.exit(0)
def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
    """Persist the checkpoint (if any) and pass the result to the main thread.

    Does nothing (after the tensor check) once ``pause_reporting`` was called.

    Args:
        metrics: User metrics. A copy extended with autofilled metrics and
            the ``CHECKPOINT_DIR_NAME`` entry is what gets reported.
        checkpoint: Optional checkpoint to persist to storage.

    Raises:
        ValueError: if ``torch`` has been imported and ``metrics`` contains
            Torch tensors.
    """
    # Special case: early fail for Torch tensors
    if "torch" in sys.modules:
        from ray.air._internal.torch_utils import contains_tensor

        if contains_tensor(metrics):
            raise ValueError(
                "Passing objects containing Torch tensors as metrics "
                "is not supported as it will throw an exception on "
                "deserialization. You can either convert the tensors "
                "to Python objects or report a `train.Checkpoint` "
                "with `ray.train.report` to store your Torch objects."
            )

    if self.ignore_report:
        return

    metrics = self._auto_fill_metrics(metrics)

    persisted_checkpoint = None
    if checkpoint:
        self.storage._update_checkpoint_index(metrics)

        # Persist the reported checkpoint files to storage.
        persisted_checkpoint = self.storage.persist_current_checkpoint(checkpoint)

        metrics[CHECKPOINT_DIR_NAME] = self.storage.checkpoint_dir_name
    else:
        metrics[CHECKPOINT_DIR_NAME] = None

    # Persist trial artifacts to storage.
    # Coerce to a real bool: the bare `and` expression would otherwise hand
    # `None` or the checkpoint object itself to `persist_artifacts`.
    force_artifact_sync = bool(
        persisted_checkpoint
        and self.storage.sync_config.sync_artifacts_on_checkpoint
    )
    self.storage.persist_artifacts(force=force_artifact_sync)

    # Set additional user metadata from the Trainer.
    if persisted_checkpoint and self.metadata:
        user_metadata = persisted_checkpoint.get_metadata()
        for k, v in self.metadata.items():
            # Update keys not already set by the user. This gives user-set keys
            # precedence over keys set at the Trainer level.
            if k not in user_metadata:
                user_metadata[k] = v
        persisted_checkpoint.set_metadata(user_metadata)

    result = _TrainingResult(checkpoint=persisted_checkpoint, metrics=metrics)

    self._report_training_result(result)
@property
def experiment_name(self) -> str:
    """The ``experiment_name`` field of ``self.trial_info``."""
    info = self.trial_info
    return info.experiment_name
@property
def trial_name(self) -> str:
    """The ``name`` field of ``self.trial_info``."""
    info = self.trial_info
    return info.name
@property
def trial_id(self) -> str:
    """The ``id`` field of ``self.trial_info``."""
    info = self.trial_info
    return info.id
@property
def run_id(self) -> str:
    """The ``run_id`` field of ``self.trial_info``."""
    info = self.trial_info
    return info.run_id
@property
def trial_resources(self) -> "PlacementGroupFactory":
    """The ``resources`` field of ``self.trial_info``."""
    info = self.trial_info
    return info.resources
@property
def trial_dir(self) -> str:
    """The ``logdir`` field of ``self.trial_info``."""
    info = self.trial_info
    return info.logdir
def get_dataset_shard(
    self,
    dataset_name: Optional[str] = None,
) -> Optional["DataIterator"]:
    """Return the dataset shard held by this session.

    Args:
        dataset_name: Key of the shard to return. Required when the session
            holds a dict of shards.

    Returns:
        ``None`` (with a warning) if no dataset was passed in; the entry for
        ``dataset_name`` (``None`` if absent) if the session holds a dict of
        shards; otherwise the stored shard itself.

    Raises:
        RuntimeError: if the session holds a dict of shards and no
            ``dataset_name`` is given.
    """
    shard = self.dataset_shard

    if isinstance(shard, dict):
        if dataset_name:
            return shard.get(dataset_name)
        raise RuntimeError(
            "Multiple datasets were passed into ``Trainer``, "
            "but no ``dataset_name`` is passed into "
            "``get_dataset_shard``. Please specify which "
            "dataset shard to retrieve."
        )

    if shard is None:
        warnings.warn(
            "No dataset passed in. Returning None. Make sure to "
            "pass in a Dataset to Trainer.run to use this "
            "function."
        )
    return shard
# The process-wide `_TrainSession`; `None` while no session is active.
# It is initialized by Ray Tune function trainables and Ray Train V1 workers
# (see `init_session`) and cleared again by `shutdown_session`.
_session: Optional[_TrainSession] = None

# Resource requests (as frozensets of `(name, amount)` pairs) that the launch
# hook has already checked, so that they are not checked again.
_checked_resources: Set[frozenset] = set()
def _tune_task_and_actor_launch_hook(
    fn, resources: Dict[str, float], strategy: Optional[SchedulingStrategyT]
):
    """Launch hook that rejects nested tasks/actors the placement group can't fit.

    This gives users a clear error in case they launch a nested task in a Tune
    trial without reserving resources in the trial placement group to fit it.
    Requests that do not target the current placement group are ignored, and
    every distinct resource request is only checked once.

    Raises:
        RuntimeError: if the requested resources do not fit the bundles that
            are available to the trial.
    """
    # Only positive requests matter. Skip requests that were checked before,
    # for performance reasons.
    positive_requests = [(k, v) for k, v in resources.items() if v > 0]
    key = frozenset(positive_requests)
    if not key or key in _checked_resources:
        return

    # Nothing to check unless the request is scheduled into a placement group.
    if not isinstance(strategy, PlacementGroupSchedulingStrategy):
        return
    if strategy.placement_group is None:
        return

    # Nothing to check unless that placement group is the current one.
    cur_pg = ray.util.get_current_placement_group()
    if not cur_pg or strategy.placement_group.id != cur_pg.id:
        return

    _checked_resources.add(key)

    # The first bundle is only available to nested requests if the head
    # bundle of the trial resources is empty.
    pgf = get_trial_resources()
    first_available = 0 if pgf.head_bundle_is_empty else 1
    available_bundles = cur_pg.bundle_specs[first_available:]

    # Check if the request can be fulfilled by the current placement group.
    if _valid_resource_shape(resources, available_bundles):
        return

    if fn.class_name:
        submitted = "actor"
        name_parts = [fn.module_name, fn.class_name, fn.function_name]
    else:
        submitted = "task"
        name_parts = [fn.module_name, fn.function_name]
    name = ".".join(name_parts)

    # Normalize the resource spec so it looks the same as the placement group bundle.
    main_resources = cur_pg.bundle_specs[0]
    resources = {k: float(v) for k, v in resources.items() if v > 0}

    raise RuntimeError(
        f"No trial resources are available for launching the {submitted} `{name}`. "
        "To resolve this, specify the Tune option:\n\n"
        ">  resources_per_trial=tune.PlacementGroupFactory(\n"
        f">    [{main_resources}] + [{resources}] * N\n"
        ">  )\n\n"
        f"Where `N` is the number of slots to reserve for trial {submitted}s. "
        "If you are using a Ray training library, there might be a utility function "
        "to set this automatically for you. For more information, refer to "
        "https://docs.ray.io/en/latest/tune/tutorials/tune-resources.html"
    )
def init_session(*args, **kwargs) -> None:
    """Create the global ``_TrainSession`` from the given arguments.

    Unless the ``TUNE_DISABLE_RESOURCE_CHECKS`` environment variable is set,
    this also installs ``_tune_task_and_actor_launch_hook`` as Ray's actor and
    task launch hook.

    Raises:
        ValueError: if a session is already in use.
    """
    global _session
    if _session:
        raise ValueError(
            "A Train session is already in use. Do not call "
            "`init_session()` manually."
        )

    # Setup hooks for generating placement group resource deadlock warnings.
    from ray import actor, remote_function

    resource_checks_enabled = "TUNE_DISABLE_RESOURCE_CHECKS" not in os.environ
    if resource_checks_enabled:
        launch_hook = _tune_task_and_actor_launch_hook
        actor._actor_launch_hook = launch_hook
        remote_function._task_launch_hook = launch_hook

    _session = _TrainSession(*args, **kwargs)
def get_session() -> Optional[_TrainSession]:
    """Return the global session, or ``None`` if none is initialized."""
    current = _session
    return current
def shutdown_session():
    """Drop the global session so that ``get_session`` returns ``None`` again."""
    global _session
    _session = None
def _raise_accelerator_session_misuse():
    """Raise the ``SessionMisuseError`` for accelerator utilities used outside a session."""
    message = (
        "prepare/accelerate utility functions should be called inside a training "
        "function executed by `Trainer.run`"
    )
    raise SessionMisuseError(message)
def get_accelerator(default_accelerator_cls: Type[Accelerator]) -> Accelerator:
    """Return the accelerator of the current training session.

    If no accelerator has been set yet, one is constructed from
    ``default_accelerator_cls`` (called without arguments) and stored on the
    session.

    Raises:
        SessionMisuseError: if the session is uninitialized.
    """
    session = get_session()
    if session is None:
        _raise_accelerator_session_misuse()

    accelerator = session.accelerator
    if accelerator is None:
        accelerator = default_accelerator_cls()
        session.accelerator = accelerator
    return accelerator
def set_accelerator(accelerator: Accelerator) -> None:
    """Store ``accelerator`` on the current training session.

    Args:
        accelerator: The accelerator to use for training.

    Raises:
        SessionMisuseError: if the session is uninitialized.
        RuntimeError: if the accelerator has already been set.
    """
    session = get_session()
    if session is None:
        _raise_accelerator_session_misuse()

    already_set = session.accelerator is not None
    if already_set:
        raise RuntimeError("Cannot change accelerator once set.")
    session.accelerator = accelerator
def _warn_session_misuse(default_value: Any = None):
    """Build a decorator that guards a function against use outside a session.

    When no session is active, the wrapped function is not called: a warning
    is emitted (once per function name, via ``log_once``) and
    ``default_value`` is returned instead.
    """

    def inner(fn: Callable):
        fn_name = fn.__name__

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if get_session():
                return fn(*args, **kwargs)

            if log_once(f"{SESSION_MISUSE_LOG_ONCE_KEY}-{fn_name}"):
                warnings.warn(
                    f"`{fn_name}` is meant to only be "
                    "called inside a function that is executed by a Tuner"
                    f" or Trainer. Returning `{default_value}`."
                )
            return default_value

        return wrapper

    return inner
@PublicAPI(stability="stable")
@_warn_session_misuse()
def report(
    metrics: Dict,
    *,
    checkpoint: Optional[Checkpoint] = None,
    checkpoint_dir_name: Optional[str] = None,
) -> None:
    """Report metrics and optionally save a checkpoint.

    If a checkpoint is provided, it will be
    :ref:`persisted to storage <persistent-storage-guide>`.

    If this is called in multiple distributed training workers:

    - Only the metrics reported by the rank 0 worker will be tracked by Ray Train.
      See :ref:`the metrics logging guide <train-monitoring-and-logging>`.
    - A checkpoint will be registered as long as one or more workers reports
      checkpoint that is not None.
      See the :ref:`checkpointing guide <train-dl-saving-checkpoints>`.
    - Checkpoints from multiple workers will be merged into one directory
      in persistent storage.
      See :ref:`the distributed checkpointing guide <train-distributed-checkpointing>`.

    .. note::

        Each invocation of this method will automatically increment the underlying
        ``training_iteration`` number. The physical meaning of this "iteration" is
        defined by user depending on how often they call ``report``.
        It does not necessarily map to one epoch.

    .. warning::

        All workers must call `ray.train.report` the same number of times
        so that Ray Train can properly synchronize the training state across
        workers. Otherwise, your training will hang.

    .. warning::

        This method does NOT act as a barrier for distributed training workers.
        Workers will upload their checkpoint, then continue training immediately.
        If you need to synchronize workers, you can use a framework-native barrier
        such as `torch.distributed.barrier()`.

    Example:

        .. testcode::

            import tempfile

            from ray import train
            from ray.train import Checkpoint
            from ray.train.torch import TorchTrainer


            def train_func(config):
                start_epoch = 0
                checkpoint = train.get_checkpoint()
                if checkpoint:
                    with checkpoint.as_directory() as checkpoint_dir:
                        # Load back training state
                        ...

                for epoch in range(start_epoch, config.get("num_epochs", 10)):
                    # Do training...

                    metrics = {"loss": ...}

                    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                        # Save the checkpoint...
                        # torch.save(...)

                        checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

                        # Example: Only the rank 0 worker uploads the checkpoint.
                        if train.get_context().get_world_rank() == 0:
                            train.report(metrics, checkpoint=checkpoint)
                        else:
                            train.report(metrics, checkpoint=None)

            trainer = TorchTrainer(
                train_func, scaling_config=train.ScalingConfig(num_workers=2)
            )

    Args:
        metrics: The metrics you want to report.
        checkpoint: The optional checkpoint you want to report.
        checkpoint_dir_name: Ignored by this implementation. It is only
            supported in the new Ray Train implementation, which can be
            enabled with ``RAY_TRAIN_V2_ENABLED=1``; a warning is logged
            here if it is not None.
    """
    if checkpoint_dir_name is not None:
        logger.warning(
            "`checkpoint_dir_name` is only supported in the new Ray Train "
            "implementation, which can be enabled with `RAY_TRAIN_V2_ENABLED=1`. "
            "This argument will be ignored."
        )

    # If we are running in a Tune function, switch to `ray.tune.report`.
    from ray.tune.trainable.trainable_fn_utils import _in_tune_session

    if _in_tune_session():
        import ray.tune

        if _v2_migration_warnings_enabled():
            _log_deprecation_warning(
                "`ray.train.report` should be switched to "
                "`ray.tune.report` when running in a function "
                "passed to Ray Tune. This will be an error in the future. "
                "See this issue for more context: "
                "https://github.com/ray-project/ray/issues/49454"
            )
        return ray.tune.report(metrics, checkpoint=checkpoint)

    get_session().report(metrics, checkpoint=checkpoint)
@PublicAPI(stability="stable")
@_warn_session_misuse()
def get_checkpoint() -> Optional[Checkpoint]:
    """Access the latest reported checkpoint to resume from if one exists.

    Inside a Tune function this delegates to ``ray.tune.get_checkpoint``.

    Example:

        .. testcode::

            import tempfile

            from ray import train
            from ray.train import Checkpoint
            from ray.train.torch import TorchTrainer


            def train_func(config):
                start_epoch = 0
                checkpoint = train.get_checkpoint()
                if checkpoint:
                    with checkpoint.as_directory() as checkpoint_dir:
                        # Load back training state
                        ...

                for epoch in range(start_epoch, config.get("num_epochs", 10)):
                    # Do training...

                    metrics = {"loss": ...}

                    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                        # Save the checkpoint...

                        checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
                        train.report(metrics, checkpoint=checkpoint)

            trainer = TorchTrainer(
                train_func, scaling_config=train.ScalingConfig(num_workers=2)
            )

    Returns:
        Checkpoint object if the session is currently being resumed.
            Otherwise, return None.
    """
    from ray.tune.trainable.trainable_fn_utils import _in_tune_session

    if not _in_tune_session():
        return get_session().loaded_checkpoint

    # We are running in a Tune function: switch to `ray.tune.get_checkpoint`.
    import ray.tune

    if _v2_migration_warnings_enabled():
        _log_deprecation_warning(
            "`ray.train.get_checkpoint` should be switched to "
            "`ray.tune.get_checkpoint` when running in a function "
            "passed to Ray Tune. This will be an error in the future. "
            "See this issue for more context: "
            "https://github.com/ray-project/ray/issues/49454"
        )
    return ray.tune.get_checkpoint()
@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_metadata() -> Dict[str, Any]:
    """Return the user metadata dict passed to the Trainer constructor."""
    session = get_session()
    return session.metadata
@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_experiment_name() -> str:
    """Return the experiment name for the corresponding trial."""
    session = get_session()
    return session.experiment_name
@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_name() -> str:
    """Return the trial name for the corresponding trial."""
    session = get_session()
    return session.trial_name
@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_id() -> str:
    """Return the trial id for the corresponding trial."""
    session = get_session()
    return session.trial_id
@PublicAPI(stability="alpha")
@_warn_session_misuse()
def get_run_id() -> str:
    """Return the unique Train Run id for the corresponding trial."""
    session = get_session()
    return session.run_id
@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_resources() -> "PlacementGroupFactory":
    """Return the trial resources for the corresponding trial."""
    session = get_session()
    return session.trial_resources
@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_dir() -> str:
    """Log directory corresponding to the trial directory for a Tune session.

    If calling from a Train session, this will give the trial directory of its
    parent Tune session.

    .. testcode::

        import ray.tune

        def train_func(config):
            print(ray.tune.get_context().get_trial_dir())

        tuner = ray.tune.Tuner(train_func)
        tuner.fit()

    .. testoutput::
        :options: +MOCK

        /Users/root/ray_results/train_func_2023-07-19_15-01-37/train_func_d620c_00000_0_2023-07-19_15-01-40
    """
    session = get_session()
    return session.trial_dir
@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=1)
def get_world_size() -> int:
    """Get the current world size (i.e. total number of workers) for this run.

    Outside of a session this warns and returns ``1``.

    .. testcode::

        import ray
        from ray import train
        from ray.train import ScalingConfig
        from ray.train.tensorflow import TensorflowTrainer

        NUM_WORKERS = 2

        def train_loop_per_worker(config):
            assert train.get_context().get_world_size() == NUM_WORKERS

        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
        trainer = TensorflowTrainer(
            train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=NUM_WORKERS),
            datasets={"train": train_dataset}
        )
        trainer.fit()

    .. testoutput::
        :hide:

        ...

    Raises:
        RuntimeError: if the active session has no ``world_size`` attribute.
    """
    session = get_session()
    if not hasattr(session, "world_size"):
        # The trailing space after "function" keeps the adjacent literals from
        # being fused into "functionthat".
        raise RuntimeError(
            "`get_world_size` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.world_size
@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_world_rank() -> int:
    """Get the world rank of this worker.

    .. testcode::

        import ray
        from ray import train
        from ray.train import ScalingConfig
        from ray.train.tensorflow import TensorflowTrainer

        def train_loop_per_worker(config):
            if train.get_context().get_world_rank() == 0:
                print("Worker 0")

        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
        trainer = TensorflowTrainer(
            train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=2),
            datasets={"train": train_dataset}
        )
        trainer.fit()

    .. testoutput::
        :hide:

        ...

    Returns:
        The ``world_rank`` attribute of the active session.

    Raises:
        RuntimeError: If the active session has no ``world_rank`` attribute.
    """
    session = get_session()
    if not hasattr(session, "world_rank"):
        # Adjacent string literals are joined verbatim, so each fragment
        # must carry its own trailing space.
        raise RuntimeError(
            "`get_world_rank` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.world_rank
@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_local_rank() -> int:
    """Get the local rank of this worker (rank of the worker on its node).

    .. testcode::

        import torch

        import ray
        from ray import train
        from ray.train import ScalingConfig
        from ray.train.torch import TorchTrainer

        def train_loop_per_worker(config):
            if torch.cuda.is_available():
                torch.cuda.set_device(train.get_context().get_local_rank())
            ...

        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
        trainer = TorchTrainer(
            train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
            datasets={"train": train_dataset}
        )
        trainer.fit()

    .. testoutput::
        :hide:

        ...

    Returns:
        The ``local_rank`` attribute of the active session.

    Raises:
        RuntimeError: If the active session has no ``local_rank`` attribute.
    """
    session = get_session()
    if not hasattr(session, "local_rank"):
        # Adjacent string literals are joined verbatim, so each fragment
        # must carry its own trailing space.
        raise RuntimeError(
            "`get_local_rank` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.local_rank
@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_local_world_size() -> int:
    """Get the local world size of this node (i.e. number of workers on this node).

    Example:

        .. testcode::

            import ray
            from ray import train
            from ray.train import ScalingConfig
            from ray.train.torch import TorchTrainer

            def train_loop_per_worker():
                print(train.get_context().get_local_world_size())

            train_dataset = ray.data.from_items(
                [{"x": x, "y": x + 1} for x in range(32)])
            trainer = TorchTrainer(train_loop_per_worker,
                scaling_config=ScalingConfig(num_workers=1),
                datasets={"train": train_dataset})
            trainer.fit()

        .. testoutput::
            :hide:

            ...

    Returns:
        The ``local_world_size`` attribute of the active session.

    Raises:
        RuntimeError: If the active session has no ``local_world_size``
            attribute.
    """
    session = get_session()
    if not hasattr(session, "local_world_size"):
        # Adjacent string literals are joined verbatim, so each fragment
        # must carry its own trailing space.
        raise RuntimeError(
            "`get_local_world_size` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.local_world_size
@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_node_rank() -> int:
    """Get the rank of this node.

    Example:

        .. testcode::

            import ray
            from ray import train
            from ray.train import ScalingConfig
            from ray.train.torch import TorchTrainer

            def train_loop_per_worker():
                print(train.get_context().get_node_rank())

            train_dataset = ray.data.from_items(
                [{"x": x, "y": x + 1} for x in range(32)])
            trainer = TorchTrainer(train_loop_per_worker,
                scaling_config=ScalingConfig(num_workers=1),
                datasets={"train": train_dataset})
            trainer.fit()

        .. testoutput::
            :hide:

            ...

    Returns:
        The ``node_rank`` attribute of the active session.

    Raises:
        RuntimeError: If the active session has no ``node_rank`` attribute.
    """
    session = get_session()
    if not hasattr(session, "node_rank"):
        # Adjacent string literals are joined verbatim, so each fragment
        # must carry its own trailing space.
        raise RuntimeError(
            "`get_node_rank` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.node_rank
@PublicAPI(stability="stable")
@_warn_session_misuse()
def get_dataset_shard(
    dataset_name: Optional[str] = None,
) -> Optional["DataIterator"]:
    """Returns the :class:`ray.data.DataIterator` shard for this worker.

    Call :meth:`~ray.data.DataIterator.iter_torch_batches` or
    :meth:`~ray.data.DataIterator.to_tf` on this shard to convert it to the
    appropriate framework-specific data type.

    .. testcode::

        import ray
        from ray import train
        from ray.train import ScalingConfig
        from ray.train.torch import TorchTrainer

        def train_loop_per_worker(config):
            ...
            for epoch in range(2):
                # Trainer will automatically handle sharding.
                data_shard = train.get_dataset_shard("train")
                for batch in data_shard.iter_torch_batches():
                    ...

        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
        trainer = TorchTrainer(
            train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=2),
            datasets={"train": train_dataset}
        )
        trainer.fit()

    .. testoutput::
        :hide:

        ...

    Args:
        dataset_name: If a Dictionary of Datasets was passed to ``Trainer``, then
            specifies which dataset shard to return.

    Returns:
        The ``DataIterator`` shard to use for this worker.
        If no dataset is passed into Trainer, then return None.

    Raises:
        RuntimeError: If the active session has no ``get_dataset_shard``
            attribute.
    """
    session = get_session()
    if not hasattr(session, "get_dataset_shard"):
        # Adjacent string literals are joined verbatim, so each fragment
        # must carry its own trailing space.
        raise RuntimeError(
            "`get_dataset_shard` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.get_dataset_shard(dataset_name)
@DeveloperAPI
@_warn_session_misuse()
def get_storage() -> StorageContext:
    """Return the session's :class:`~ray.train._internal.storage.StorageContext`.

    The storage context gives advanced access to the filesystem and paths
    configured through `RunConfig`.

    NOTE: This is a developer API, and the `StorageContext` interface may change
    without notice between minor versions.

    Returns:
        The ``storage`` attribute of the active session.
    """
    session = get_session()
    return session.storage
def _in_ray_train_worker() -> bool:
    """Report whether the current process is a Ray Train V1 worker.

    Returns:
        ``True`` when ``get_session()`` is truthy and its ``world_rank`` is
        not ``None``; ``False`` otherwise.
    """
    # Guard first: without a session there is no world rank to inspect.
    if not get_session():
        return False
    return get_session().world_rank is not None
|