import functools
import logging
import os
import platform
import queue
import sys
import threading
import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type

import ray
from ray.air._internal.util import RunnerThread, StartTraceback
from ray.air.constants import (
    _ERROR_FETCH_TIMEOUT,
    _RESULT_FETCH_TIMEOUT,
    SESSION_MISUSE_LOG_ONCE_KEY,
    TIME_THIS_ITER_S,
    TIMESTAMP,
)
from ray.train import Checkpoint
from ray.train._internal.accelerator import Accelerator
from ray.train._internal.storage import StorageContext
from ray.train.constants import (
    CHECKPOINT_DIR_NAME,
    DETAILED_AUTOFILLED_KEYS,
    RAY_CHDIR_TO_TRIAL_DIR,
    TIME_TOTAL_S,
    WORKER_HOSTNAME,
    WORKER_NODE_IP,
    WORKER_PID,
    _v2_migration_warnings_enabled,
)
from ray.train.error import SessionMisuseError
from ray.train.utils import _log_deprecation_warning
from ray.util import queue as ray_queue
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.util.debug import log_once
from ray.util.placement_group import _valid_resource_shape
from ray.util.scheduling_strategies import (
    PlacementGroupSchedulingStrategy,
    SchedulingStrategyT,
)

if TYPE_CHECKING:
    from ray.data import DataIterator, Dataset
    from ray.tune.execution.placement_groups import PlacementGroupFactory


logger = logging.getLogger(__name__)


@dataclass
class TrialInfo:
    """The trial information to propagate to TrainSession."""

    name: str
    id: str
    resources: Dict[str, float]
    logdir: str
    driver_ip: str
    driver_node_id: str
    experiment_name: Optional[str] = None
    run_id: Optional[str] = None


class _FutureTrainingResult:
    """A future that will be resolved to a `_TrainingResult`.

    This is needed for specific schedulers such as PBT that schedule saves.

    This wrapper should be removed after refactoring PBT to not schedule saves
    anymore.
    """

    def __init__(self, future: ray.ObjectRef):
        self.future = future

    def resolve(self, block: bool = True) -> Optional["_TrainingResult"]:
        """Resolve into ``_TrainingResult``.

        This will return None for function trainables if no checkpoint has been
        saved before.

        Args:
            block: If True, wait for the future without a timeout. If False,
                poll with a tiny timeout and return None if it is not ready.

        Returns:
            The resolved result, or None if the future is not ready yet or
            resolving it raised an error (the error is logged).
        """
        # A tiny positive timeout makes `ray.get` behave like a non-blocking poll.
        timeout = None if block else 1e-9
        try:
            return ray.get(self.future, timeout=timeout)
        except TimeoutError:
            # Not ready, yet
            return None
        except Exception as exc:
            # Best effort: failures are logged and reported as "no result".
            logger.error(f"Error resolving result: {exc}")
            return None


class _TrainingResult:
    """A (checkpoint, metrics) result reported by the user."""

    def __init__(self, checkpoint: Optional[Checkpoint], metrics: Dict[str, Any]):
        self.checkpoint = checkpoint
        self.metrics = metrics

    def __repr__(self) -> str:
        return f"TrainingResult(checkpoint={self.checkpoint}, metrics={self.metrics})"


# TODO(xwjiang): This needs a better name.
@DeveloperAPI
class _TrainSession:
    """Holds information for training on each worker."""

    def __init__(
        self,
        training_func: Callable,
        world_rank: Optional[int],
        local_rank: Optional[int],
        node_rank: Optional[int],
        local_world_size: Optional[int],
        world_size: Optional[int],
        trial_info: Optional[TrialInfo] = None,
        dataset_shard: Optional[Dict[str, "Dataset"]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        checkpoint: Optional[Checkpoint] = None,
        detailed_autofilled_metrics: bool = False,
        storage: Optional[StorageContext] = None,
        synchronous_result_reporting: bool = False,
    ):
        # `synchronous_result_reporting` refers to whether or not the
        # training function stays blocked after the main thread receives its
        # result (see `get_next` for where the lock is released).
        # Ex 1: For 2 Ray Train workers with synchronous_result_reporting=False,
        # the worker that produces a result first will immediately continue
        # onto the next iteration.
        # Ex 2: For a Tune function Trainable with
        # `synchronous_result_reporting=True`, training will only continue with
        # an explicit call to `session.get_next`.
        # Synchronous reporting in example 2 is needed for Tune schedulers to
        # be able to stop the execution of the training function at will,
        # for advanced pausing schedulers (PBT, BOHB) and actor reuse.
        self.synchronous_result_reporting = synchronous_result_reporting

        # Ray Train worker properties
        # Note: These are set to None for Tune function Trainables.
        self.dataset_shard = dataset_shard
        self.metadata = metadata

        self.world_rank = world_rank
        self.local_rank = local_rank
        self.node_rank = node_rank
        self.local_world_size = local_world_size
        self.world_size = world_size

        # `storage` defaults to None in the signature but is required in practice.
        assert storage
        logger.debug(f"StorageContext on SESSION (rank={world_rank}):\n{storage}")

        # NOTE: `reset` will initialize many properties needed to start running the
        # training_func as a thread.
        self.reset(
            training_func=training_func,
            trial_info=trial_info,
            storage=storage,
            loaded_checkpoint=checkpoint,
        )

        # Autofilled metrics attributes.
        self.detailed_autofilled_metrics = detailed_autofilled_metrics
        self.last_report_time = time.time()
        self.iteration = 0
        self.time_total = 0.0
        self.local_ip = self.get_current_ip()

        self.accelerator = None
        self._state = {}

    def get_state(self, key: str) -> Any:
        """Return the value stored under ``key``, or None if it is not set."""
        return self._state.get(key)

    def set_state(self, key: str, value: Any):
        """Store ``value`` under ``key`` in the session state dict."""
        self._state[key] = value

    def get_current_ip(self):
        """Look up this node's IP address, cache it on the session, and return it."""
        self.local_ip = ray.util.get_node_ip_address()
        return self.local_ip

    def start(self):
        """Starts the training thread."""
        self.training_started = True
        self.training_thread.start()

    def reset(
        self,
        training_func: Callable,
        trial_info: Optional[TrialInfo],
        storage: StorageContext,
        loaded_checkpoint=None,
    ):
        """(Re)initialize the threading primitives and per-run state.

        Creates a fresh (not yet started) training thread for ``training_func``
        and, unless disabled through the ``RAY_CHDIR_TO_TRIAL_DIR`` environment
        variable, changes the working directory to the trial working directory.

        Args:
            training_func: The function to run in the training thread.
            trial_info: The trial information for this session.
            storage: The storage context for this session.
            loaded_checkpoint: The checkpoint exposed to the training function.
        """
        # This lock is used to control the execution of the training thread.
        self.continue_lock = threading.Semaphore(0)

        # This event is used to signal the training thread to stop.
        self.stop_event = threading.Event()

        # Queue for sending results across threads.
        self.result_queue = queue.Queue(1)

        # Queue for sending results from training actor to main thread.
        self._inter_actor_queue: Optional[ray_queue.Queue[Dict]] = None

        # Queue for raising exceptions from runner thread to main thread.
        # The error queue has a max size of one to prevent stacking error and force
        # error reporting to block until finished.
        self.error_queue = queue.Queue(1)

        # The Thread object that is running the training function.
        self.training_thread = RunnerThread(
            target=training_func, daemon=True, error_queue=self.error_queue
        )

        # Possibly override with new state
        self.trial_info = trial_info
        self.storage = storage
        self.loaded_checkpoint = loaded_checkpoint

        # Reset state
        self._state = {}
        self.ignore_report = False
        self.training_started = False
        self._first_report = True

        # Change the working directory to a special trial folder.
        # This is to ensure that all Ray Train workers have a common working directory.
        os.makedirs(storage.trial_working_directory, exist_ok=True)
        if bool(int(os.environ.get(RAY_CHDIR_TO_TRIAL_DIR, "1"))):
            logger.debug(
                f"Changing the working directory to: {storage.trial_working_directory}"
            )
            os.chdir(storage.trial_working_directory)

    def pause_reporting(self):
        """Ignore all future ``session.report()`` calls."""
        self.ignore_report = True

    def finish(self, timeout: Optional[float] = None) -> Optional[Any]:
        """Finishes the training thread.

        Raises any Exception from training.

        Args:
            timeout: Passed to the training thread's ``join``.

        Returns:
            The output of joining the training thread, or None if training
            was never started.
        """
        # Set the stop event for the training thread to gracefully exit.
        self.stop_event.set()

        # Release the lock so that training thread can process this event.
        self.continue_lock.release()

        # Force a final (blocking) sync of artifacts in the trial path to storage.
        self.storage.persist_artifacts(force=True)

        # Wait for training to finish.
        # This will raise any errors that occur during training, including SystemError
        # This returns the result of the training function.
        output = None
        if self.training_started:
            output = self.training_thread.join(timeout=timeout)

        return output

    def get_next(self) -> Optional[_TrainingResult]:
        """Gets the next ``_TrainingResult`` from the result queue.

        If the result queue is empty, then this function returns ``None``.

        Raises:
            RuntimeError: If ``start`` has not been called yet.
        """
        if not self.training_started:
            raise RuntimeError("Please call start before calling get_next.")

        if self.synchronous_result_reporting:
            # There's no need to release the lock on the first report
            # since `start` already started the training thread.
            if not self._first_report:
                # Release the lock to trigger training to continue,
                # until the next call to report.
                self.continue_lock.release()
            self._first_report = False

        result = None
        # While training is still ongoing, attempt to get the result.
        while result is None and self.training_thread.is_alive():
            result = self._get_result_from_queues(block=True)

        # If no result was found, then the runner must no longer be alive.
        if result is None:
            # Try one last time to fetch results in case results were
            # reported in between the time of the last check and the
            # termination of the thread runner.
            result = self._get_result_from_queues(block=False)

        # check if error occurred inside the thread runner.
        if result is None:
            # only raise an error from the runner if all results are consumed
            self._report_thread_runner_error(block=True)
        else:
            if not self.error_queue.empty():
                logger.debug(
                    (
                        "Runner error waiting to be raised in main thread. "
                        "Logging all available results first."
                    )
                )

        if not self.synchronous_result_reporting:
            # At this point, the training thread has reached
            # the `train.report` and is blocked there.
            # If performing asynchronous result reporting,
            # release the lock to allow each worker to keep training
            # immediately after the coordinator fetches their result.
            self.continue_lock.release()

        # Return None if there are no more results to fetch.
        return result

    def _get_or_create_inter_actor_queue(self):
        """Get or create the inter-actor queue."""
        if self._inter_actor_queue is None:
            self._inter_actor_queue = ray_queue.Queue(1, actor_options={"num_cpus": 0})
        return self._inter_actor_queue

    def _get_result_from_queues(self, block: bool) -> Optional[_TrainingResult]:
        """Get result from result queue. Pass result from training actor result queue if needed."""
        result = None
        if self._inter_actor_queue is not None:
            try:
                inter_actor_item = self._inter_actor_queue.get(
                    block=block, timeout=_RESULT_FETCH_TIMEOUT
                )
                if inter_actor_item:
                    # Must release continue_lock to allow report to work.
                    self.continue_lock.release()
                    self.report(inter_actor_item)
            except ray_queue.Empty:
                pass
        try:
            result = self.result_queue.get(block=block, timeout=_RESULT_FETCH_TIMEOUT)
        except queue.Empty:
            pass
        return result

    def _auto_fill_metrics(self, result: dict) -> dict:
        """Add autofilled metrics and update attributes.

        Increments the iteration counter and the accumulated total time, then
        returns a copy of ``result`` with the autofilled keys added. The input
        dict is not modified.
        """
        current_time = time.time()
        current_datetime = datetime.now()
        # Prefer a user-reported iteration time over the measured wall time.
        if TIME_THIS_ITER_S in result:
            time_this_iter = result[TIME_THIS_ITER_S]
        else:
            time_this_iter = current_time - self.last_report_time
        self.iteration += 1
        self.time_total += time_this_iter
        self.last_report_time = current_time

        auto_filled_metrics = {
            TIMESTAMP: int(time.mktime(current_datetime.timetuple())),
            TIME_TOTAL_S: self.time_total,
            WORKER_PID: os.getpid(),
            WORKER_HOSTNAME: platform.node(),
            WORKER_NODE_IP: self.local_ip,
        }

        if not self.detailed_autofilled_metrics:
            auto_filled_metrics = {
                k: v
                for k, v in auto_filled_metrics.items()
                if k not in DETAILED_AUTOFILLED_KEYS
            }

        result = result.copy()
        result.update(auto_filled_metrics)
        return result

    def _auto_fill_checkpoint_metrics(self, result: dict) -> dict:
        """Return a copy of ``result`` with the current timestamp added."""
        current_datetime = datetime.now()

        auto_filled_metrics = {
            TIMESTAMP: int(time.mktime(current_datetime.timetuple()))
        }
        result = result.copy()
        result.update(auto_filled_metrics)
        return result

    def _report_thread_runner_error(self, block=False):
        """Raise an error taken from the error queue, if there is one.

        Does nothing if the error queue is empty.
        """
        try:
            e = self.error_queue.get(block=block, timeout=_ERROR_FETCH_TIMEOUT)
            raise StartTraceback from e
        except queue.Empty:
            pass

    def _report_training_result(self, training_result: _TrainingResult) -> None:
        """Place a training result on the result queue for the main thread to process,
        then block until the main thread signals that training should continue.

        NOTE: This is used internally to report results from Train to Tune
        without persisting checkpoints to storage 2 times.
        `report` is the public API that directly persists to storage, which
        should only be called by user code.
        """
        if training_result.checkpoint:
            # NOTE: This populates `train.get_checkpoint`
            self.loaded_checkpoint = training_result.checkpoint

        # Add result to a thread-safe queue.
        self.result_queue.put(training_result, block=True)

        # Acquire lock to stop the training thread until main thread
        # triggers resume.
        self.continue_lock.acquire()

        # If the trial should be terminated, exit gracefully.
        # NOTE: This is only really useful if `synchronous_result_reporting=True`.
        # Otherwise, the lock is immediately released on reporting, and this
        # check is skipped before the main thread decides to set the stop event.
        if self.stop_event.is_set():
            self.stop_event.clear()
            sys.exit(0)

    def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
        """Persist the checkpoint (if any) and hand the result to the main thread.

        Does nothing if reporting has been paused with ``pause_reporting``.

        Args:
            metrics: The metrics to report. Autofilled metrics are added to a copy.
            checkpoint: The optional checkpoint to persist and report.

        Raises:
            ValueError: If ``torch`` has been imported and ``metrics`` contains
                Torch tensors.
        """
        # Special case: early fail for Torch tensors
        if "torch" in sys.modules:
            from ray.air._internal.torch_utils import contains_tensor

            if contains_tensor(metrics):
                raise ValueError(
                    "Passing objects containg Torch tensors as metrics "
                    "is not supported as it will throw an exception on "
                    "deserialization. You can either convert the tensors "
                    "to Python objects or report a `train.Checkpoint` "
                    "with `ray.train.report` to store your Torch objects."
                )

        if self.ignore_report:
            return

        metrics = self._auto_fill_metrics(metrics)

        persisted_checkpoint = None
        if checkpoint:
            self.storage._update_checkpoint_index(metrics)

            # Persist the reported checkpoint files to storage.
            persisted_checkpoint = self.storage.persist_current_checkpoint(checkpoint)

            metrics[CHECKPOINT_DIR_NAME] = self.storage.checkpoint_dir_name
        else:
            metrics[CHECKPOINT_DIR_NAME] = None

        # Persist trial artifacts to storage.
        force_artifact_sync = (
            persisted_checkpoint
            and self.storage.sync_config.sync_artifacts_on_checkpoint
        )
        self.storage.persist_artifacts(force=force_artifact_sync)

        # Set additional user metadata from the Trainer.
        if persisted_checkpoint and self.metadata:
            user_metadata = persisted_checkpoint.get_metadata()
            for k, v in self.metadata.items():
                # Update keys not already set by the user. This gives user-set keys
                # precedence over keys set at the Trainer level.
                if k not in user_metadata:
                    user_metadata[k] = v
            persisted_checkpoint.set_metadata(user_metadata)

        result = _TrainingResult(checkpoint=persisted_checkpoint, metrics=metrics)

        self._report_training_result(result)

    @property
    def experiment_name(self) -> str:
        return self.trial_info.experiment_name

    @property
    def trial_name(self) -> str:
        return self.trial_info.name

    @property
    def trial_id(self) -> str:
        return self.trial_info.id

    @property
    def run_id(self) -> str:
        return self.trial_info.run_id

    @property
    def trial_resources(self) -> "PlacementGroupFactory":
        return self.trial_info.resources

    @property
    def trial_dir(self) -> str:
        return self.trial_info.logdir

    def get_dataset_shard(
        self,
        dataset_name: Optional[str] = None,
    ) -> Optional["DataIterator"]:
        """Return the dataset shard for this worker.

        Args:
            dataset_name: The key of the shard to return when the shards are
                held in a dict.

        Returns:
            The requested shard, or None (with a warning) if no dataset was
            passed in.

        Raises:
            RuntimeError: If the shards are held in a dict but no
                ``dataset_name`` is given.
        """
        shard = self.dataset_shard
        if shard is None:
            warnings.warn(
                "No dataset passed in. Returning None. Make sure to "
                "pass in a Dataset to Trainer.run to use this "
                "function."
            )
        elif isinstance(shard, dict):
            if not dataset_name:
                raise RuntimeError(
                    "Multiple datasets were passed into ``Trainer``, "
                    "but no ``dataset_name`` is passed into "
                    "``get_dataset_shard``. Please specify which "
                    "dataset shard to retrieve."
                )
            return shard.get(dataset_name)
        return shard


# Cache of resource dicts that have been checked by the launch hook already.
_checked_resources: Set[frozenset] = set()

# Global _TrainSession object initialized by Ray Tune function trainables
# and Ray Train V1 workers.
_session: Optional[_TrainSession] = None


def _tune_task_and_actor_launch_hook(
    fn, resources: Dict[str, float], strategy: Optional[SchedulingStrategyT]
):
    """Launch hook to catch nested tasks that can't fit in the placement group.

    This gives users a nice warning in case they launch a nested task in a Tune trial
    without reserving resources in the trial placement group to fit it.

    Raises:
        RuntimeError: If the request targets the current placement group but
            does not fit in the bundles available to nested tasks and actors.
    """

    # Already checked, skip for performance reasons.
    key = frozenset((k, v) for k, v in resources.items() if v > 0)
    if not key or key in _checked_resources:
        return

    # No need to check if placement group is None.
    if (
        not isinstance(strategy, PlacementGroupSchedulingStrategy)
        or strategy.placement_group is None
    ):
        return

    # Check if the resource request is targeting the current placement group.
    cur_pg = ray.util.get_current_placement_group()
    if not cur_pg or strategy.placement_group.id != cur_pg.id:
        return

    _checked_resources.add(key)

    # Work out which bundles are available to nested tasks and actors.
    pgf = get_trial_resources()

    if pgf.head_bundle_is_empty:
        available_bundles = cur_pg.bundle_specs[0:]
    else:
        # The first bundle is skipped when the head bundle is not empty.
        available_bundles = cur_pg.bundle_specs[1:]

    # Check if the request can be fulfilled by the current placement group.
    if _valid_resource_shape(resources, available_bundles):
        return

    if fn.class_name:
        submitted = "actor"
        name = fn.module_name + "." + fn.class_name + "." + fn.function_name
    else:
        submitted = "task"
        name = fn.module_name + "." + fn.function_name

    # Normalize the resource spec so it looks the same as the placement group bundle.
    main_resources = cur_pg.bundle_specs[0]
    resources = {k: float(v) for k, v in resources.items() if v > 0}

    raise RuntimeError(
        f"No trial resources are available for launching the {submitted} `{name}`. "
        "To resolve this, specify the Tune option:\n\n"
        ">  resources_per_trial=tune.PlacementGroupFactory(\n"
        f">    [{main_resources}] + [{resources}] * N\n"
        ">  )\n\n"
        f"Where `N` is the number of slots to reserve for trial {submitted}s. "
        "If you are using a Ray training library, there might be a utility function "
        "to set this automatically for you. For more information, refer to "
        "https://docs.ray.io/en/latest/tune/tutorials/tune-resources.html"
    )


def init_session(*args, **kwargs) -> None:
    """Create the global ``_TrainSession`` and install the launch hooks.

    All arguments are forwarded to the ``_TrainSession`` constructor.

    Raises:
        ValueError: If a session is already in use.
    """
    global _session
    if _session:
        raise ValueError(
            "A Train session is already in use. Do not call "
            "`init_session()` manually."
        )

    # Setup hooks for generating placement group resource deadlock warnings.
    from ray import actor, remote_function

    if "TUNE_DISABLE_RESOURCE_CHECKS" not in os.environ:
        actor._actor_launch_hook = _tune_task_and_actor_launch_hook
        remote_function._task_launch_hook = _tune_task_and_actor_launch_hook

    _session = _TrainSession(*args, **kwargs)


def get_session() -> Optional[_TrainSession]:
    """Return the global session, or None if none has been initialized."""
    return _session


def shutdown_session():
    """Shuts down the initialized session."""
    global _session
    _session = None


def _raise_accelerator_session_misuse():
    """Raises a SessionMisuseError because a utility function was used improperly."""
    raise SessionMisuseError(
        "prepare/accelerate utility functions should be called inside a training "
        "function executed by `Trainer.run`"
    )


def get_accelerator(default_accelerator_cls: Type[Accelerator]) -> Accelerator:
    """The accelerator for this training session.

    If an accelerator has not been set, then this method will construct an
    accelerator using the provided accelerator class.

    Raises:
        SessionMisuseError: if the session is uninitialized.
    """
    session = get_session()
    if session is None:
        _raise_accelerator_session_misuse()
    if session.accelerator is None:
        session.accelerator = default_accelerator_cls()
    return session.accelerator


def set_accelerator(accelerator: Accelerator) -> None:
    """Sets the accelerator for this training session.

    Args:
        accelerator: The accelerator to use for training.

    Raises:
        SessionMisuseError: if the session is uninitialized.
        RuntimeError: if the accelerator has already been set.
    """
    session = get_session()
    if session is None:
        _raise_accelerator_session_misuse()
    if session.accelerator is not None:
        raise RuntimeError("Cannot change accelerator once set.")
    session.accelerator = accelerator


def _warn_session_misuse(default_value: Any = None):
    """Warns if fn is being used outside of session and returns ``default_value``."""

    def inner(fn: Callable):
        fn_name = fn.__name__

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            session = get_session()
            if not session:
                # Only warn once per wrapped function.
                if log_once(f"{SESSION_MISUSE_LOG_ONCE_KEY}-{fn_name}"):
                    warnings.warn(
                        f"`{fn_name}` is meant to only be "
                        "called inside a function that is executed by a Tuner"
                        f" or Trainer. Returning `{default_value}`."
                    )
                return default_value
            return fn(*args, **kwargs)

        return wrapper

    return inner


def _get_train_session_attr(attr_name: str, api_name: str) -> Any:
    """Return attribute ``attr_name`` of the global session.

    Shared by the public getters that are only valid on a TrainSession.

    Args:
        attr_name: Name of the session attribute to read.
        api_name: Name of the public function, used in the error message.

    Raises:
        RuntimeError: If the session does not have the attribute.
    """
    session = get_session()
    if not hasattr(session, attr_name):
        raise RuntimeError(
            f"`{api_name}` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return getattr(session, attr_name)


@PublicAPI(stability="stable")
@_warn_session_misuse()
def report(
    metrics: Dict,
    *,
    checkpoint: Optional[Checkpoint] = None,
    checkpoint_dir_name: Optional[str] = None,
) -> None:
    """Report metrics and optionally save a checkpoint.

    If a checkpoint is provided, it will be persisted to storage.

    If this is called in multiple distributed training workers:

    - Only the metrics reported by the rank 0 worker will be tracked by Ray Train.
      See the metrics logging guide.
    - A checkpoint will be registered as long as one or more workers reports
      checkpoint that is not None.
      See the checkpointing guide.
    - Checkpoints from multiple workers will be merged into one directory
      in persistent storage.
      See the distributed checkpointing guide.

    .. note::

        Each invocation of this method will automatically increment the underlying
        ``training_iteration`` number. The physical meaning of this "iteration" is
        defined by user depending on how often they call ``report``.
        It does not necessarily map to one epoch.

    .. warning::

        All workers must call `ray.train.report` the same number of times
        so that Ray Train can properly synchronize the training state across
        workers. Otherwise, your training will hang.

    .. warning::

        This method does NOT act as a barrier for distributed training workers.
        Workers will upload their checkpoint, then continue training immediately.
        If you need to synchronize workers, you can use a framework-native barrier
        such as `torch.distributed.barrier()`.

    Example:

        .. testcode::

            import tempfile

            from ray import train
            from ray.train import Checkpoint
            from ray.train.torch import TorchTrainer


            def train_func(config):
                start_epoch = 0
                checkpoint = train.get_checkpoint()
                if checkpoint:
                    with checkpoint.as_directory() as checkpoint_dir:
                        # Load back training state
                        ...

                for epoch in range(start_epoch, config.get("num_epochs", 10)):
                    # Do training...

                    metrics = {"loss": ...}

                    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                        # Save the checkpoint...
                        # torch.save(...)

                        checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

                        # Example: Only the rank 0 worker uploads the checkpoint.
                        if train.get_context().get_world_rank() == 0:
                            train.report(metrics, checkpoint=checkpoint)
                        else:
                            train.report(metrics, checkpoint=None)

            trainer = TorchTrainer(
                train_func, scaling_config=train.ScalingConfig(num_workers=2)
            )

    Args:
        metrics: The metrics you want to report.
        checkpoint: The optional checkpoint you want to report.
        checkpoint_dir_name: Ignored by this implementation (a warning is
            logged if it is set). It is only supported in the new Ray Train
            implementation, which can be enabled with `RAY_TRAIN_V2_ENABLED=1`.
    """
    if checkpoint_dir_name is not None:
        logger.warning(
            "`checkpoint_dir_name` is only supported in the new Ray Train "
            "implementation, which can be enabled with `RAY_TRAIN_V2_ENABLED=1`. "
            "This argument will be ignored."
        )

    # If we are running in a Tune function, switch to `ray.tune.report`.
    from ray.tune.trainable.trainable_fn_utils import _in_tune_session

    if _in_tune_session():
        import ray.tune

        if _v2_migration_warnings_enabled():
            _log_deprecation_warning(
                "`ray.train.report` should be switched to "
                "`ray.tune.report` when running in a function "
                "passed to Ray Tune. This will be an error in the future. "
                "See this issue for more context: "
                "https://github.com/ray-project/ray/issues/49454"
            )
        return ray.tune.report(metrics, checkpoint=checkpoint)

    get_session().report(metrics, checkpoint=checkpoint)


@PublicAPI(stability="stable")
@_warn_session_misuse()
def get_checkpoint() -> Optional[Checkpoint]:
    """Access the latest reported checkpoint to resume from if one exists.

    Example:

        .. testcode::

            import tempfile

            from ray import train
            from ray.train import Checkpoint
            from ray.train.torch import TorchTrainer


            def train_func(config):
                start_epoch = 0
                checkpoint = train.get_checkpoint()
                if checkpoint:
                    with checkpoint.as_directory() as checkpoint_dir:
                        # Load back training state
                        ...

                for epoch in range(start_epoch, config.get("num_epochs", 10)):
                    # Do training...

                    metrics = {"loss": ...}

                    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                        # Save the checkpoint...

                        checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
                        train.report(metrics, checkpoint=checkpoint)

            trainer = TorchTrainer(
                train_func, scaling_config=train.ScalingConfig(num_workers=2)
            )

    Returns:
        Checkpoint object if the session is currently being resumed.
            Otherwise, return None.
    """
    # If we are running in a Tune function, switch to `ray.tune.get_checkpoint`.
    from ray.tune.trainable.trainable_fn_utils import _in_tune_session

    if _in_tune_session():
        import ray.tune

        if _v2_migration_warnings_enabled():
            _log_deprecation_warning(
                "`ray.train.get_checkpoint` should be switched to "
                "`ray.tune.get_checkpoint` when running in a function "
                "passed to Ray Tune. This will be an error in the future. "
                "See this issue for more context: "
                "https://github.com/ray-project/ray/issues/49454"
            )
        return ray.tune.get_checkpoint()

    return get_session().loaded_checkpoint


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_metadata() -> Dict[str, Any]:
    """User metadata dict passed to the Trainer constructor."""
    return get_session().metadata


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_experiment_name() -> str:
    """Experiment name for the corresponding trial."""
    return get_session().experiment_name


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_name() -> str:
    """Trial name for the corresponding trial."""
    return get_session().trial_name


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_id() -> str:
    """Trial id for the corresponding trial."""
    return get_session().trial_id


@PublicAPI(stability="alpha")
@_warn_session_misuse()
def get_run_id() -> str:
    """Unique Train Run id for the corresponding trial."""
    return get_session().run_id


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_resources() -> "PlacementGroupFactory":
    """Trial resources for the corresponding trial."""
    return get_session().trial_resources


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_dir() -> str:
    """Log directory corresponding to the trial directory for a Tune session.
    If calling from a Train session, this will give the trial directory of its parent
    Tune session.

    .. testcode::

        import ray.tune

        def train_func(config):
            print(ray.tune.get_context().get_trial_dir())

        tuner = ray.tune.Tuner(train_func)
        tuner.fit()

    .. testoutput::
        :options: +MOCK

        /Users/root/ray_results/train_func_2023-07-19_15-01-37/train_func_d620c_00000_0_2023-07-19_15-01-40
    """
    return get_session().trial_dir


@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=1)
def get_world_size() -> int:
    """Get the current world size (i.e. total number of workers) for this run.

    .. testcode::

        import ray
        from ray import train
        from ray.train import ScalingConfig
        from ray.train.tensorflow import TensorflowTrainer

        NUM_WORKERS = 2

        def train_loop_per_worker(config):
            assert train.get_context().get_world_size() == NUM_WORKERS

        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
        trainer = TensorflowTrainer(
            train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=NUM_WORKERS),
            datasets={"train": train_dataset}
        )
        trainer.fit()

    .. testoutput::
        :hide:

        ...
    """
    return _get_train_session_attr("world_size", "get_world_size")


@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_world_rank() -> int:
    """Get the world rank of this worker.

    .. testcode::

        import ray
        from ray import train
        from ray.train import ScalingConfig
        from ray.train.tensorflow import TensorflowTrainer

        def train_loop_per_worker(config):
            if train.get_context().get_world_rank() == 0:
                print("Worker 0")

        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
        trainer = TensorflowTrainer(
            train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=2),
            datasets={"train": train_dataset}
        )
        trainer.fit()

    .. testoutput::
        :hide:

        ...
    """
    return _get_train_session_attr("world_rank", "get_world_rank")


@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_local_rank() -> int:
    """Get the local rank of this worker (rank of the worker on its node).

    .. testcode::

        import torch

        import ray
        from ray import train
        from ray.train import ScalingConfig
        from ray.train.torch import TorchTrainer

        def train_loop_per_worker(config):
            if torch.cuda.is_available():
                torch.cuda.set_device(train.get_context().get_local_rank())
            ...

        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
        trainer = TorchTrainer(
            train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
            datasets={"train": train_dataset}
        )
        trainer.fit()

    .. testoutput::
        :hide:

        ...
    """
    return _get_train_session_attr("local_rank", "get_local_rank")


@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_local_world_size() -> int:
    """Get the local world size of this node (i.e. number of workers on this node).

    Example:

        .. testcode::

            import ray
            from ray import train
            from ray.train import ScalingConfig
            from ray.train.torch import TorchTrainer

            def train_loop_per_worker():
                print(train.get_context().get_local_world_size())

            train_dataset = ray.data.from_items(
                [{"x": x, "y": x + 1} for x in range(32)])
            trainer = TorchTrainer(train_loop_per_worker,
                scaling_config=ScalingConfig(num_workers=1),
                datasets={"train": train_dataset})
            trainer.fit()

        .. testoutput::
            :hide:

            ...
    """
    return _get_train_session_attr("local_world_size", "get_local_world_size")


@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_node_rank() -> int:
    """Get the rank of this node.

    Example:

        .. testcode::

            import ray
            from ray import train
            from ray.train import ScalingConfig
            from ray.train.torch import TorchTrainer

            def train_loop_per_worker():
                print(train.get_context().get_node_rank())

            train_dataset = ray.data.from_items(
                [{"x": x, "y": x + 1} for x in range(32)])
            trainer = TorchTrainer(train_loop_per_worker,
                scaling_config=ScalingConfig(num_workers=1),
                datasets={"train": train_dataset})
            trainer.fit()

        .. testoutput::
            :hide:

            ...
    """
    return _get_train_session_attr("node_rank", "get_node_rank")


@PublicAPI(stability="stable")
@_warn_session_misuse()
def get_dataset_shard(
    dataset_name: Optional[str] = None,
) -> Optional["DataIterator"]:
    """Returns the :class:`ray.data.DataIterator` shard for this worker.

    Call :meth:`~ray.data.DataIterator.iter_torch_batches` or
    :meth:`~ray.data.DataIterator.to_tf` on this shard to convert it to the
    appropriate framework-specific data type.

    .. testcode::

        import ray
        from ray import train
        from ray.train import ScalingConfig
        from ray.train.torch import TorchTrainer

        def train_loop_per_worker(config):
            ...
            for epoch in range(2):
                # Trainer will automatically handle sharding.
                data_shard = train.get_dataset_shard("train")
                for batch in data_shard.iter_torch_batches():
                    ...

        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
        trainer = TorchTrainer(
            train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=2),
            datasets={"train": train_dataset}
        )
        trainer.fit()

    .. testoutput::
        :hide:

        ...

    Args:
        dataset_name: If a Dictionary of Datasets was passed to ``Trainer``, then
            specifies which dataset shard to return.

    Returns:
        The ``DataIterator`` shard to use for this worker.
        If no dataset is passed into Trainer, then return None.
    """
    shard_getter = _get_train_session_attr("get_dataset_shard", "get_dataset_shard")
    return shard_getter(dataset_name)


@DeveloperAPI
@_warn_session_misuse()
def get_storage() -> StorageContext:
    """Returns the :class:`~ray.train._internal.storage.StorageContext` storage
    context which gives advanced access to the filesystem and paths
    configured through `RunConfig`.

    NOTE: This is a developer API, and the `StorageContext` interface may change
    without notice between minor versions.
    """
    return get_session().storage


def _in_ray_train_worker() -> bool:
    """Check if the current process is a Ray Train V1 worker."""
    session = get_session()
    return bool(session) and session.world_rank is not None