session.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203
  1. import functools
  2. import logging
  3. import os
  4. import platform
  5. import queue
  6. import sys
  7. import threading
  8. import time
  9. import warnings
  10. from dataclasses import dataclass
  11. from datetime import datetime
  12. from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type
  13. import ray
  14. from ray.air._internal.util import RunnerThread, StartTraceback
  15. from ray.air.constants import (
  16. _ERROR_FETCH_TIMEOUT,
  17. _RESULT_FETCH_TIMEOUT,
  18. SESSION_MISUSE_LOG_ONCE_KEY,
  19. TIME_THIS_ITER_S,
  20. TIMESTAMP,
  21. )
  22. from ray.train import Checkpoint
  23. from ray.train._internal.accelerator import Accelerator
  24. from ray.train._internal.storage import StorageContext
  25. from ray.train.constants import (
  26. CHECKPOINT_DIR_NAME,
  27. DETAILED_AUTOFILLED_KEYS,
  28. RAY_CHDIR_TO_TRIAL_DIR,
  29. TIME_TOTAL_S,
  30. WORKER_HOSTNAME,
  31. WORKER_NODE_IP,
  32. WORKER_PID,
  33. _v2_migration_warnings_enabled,
  34. )
  35. from ray.train.error import SessionMisuseError
  36. from ray.train.utils import _log_deprecation_warning
  37. from ray.util import queue as ray_queue
  38. from ray.util.annotations import DeveloperAPI, PublicAPI
  39. from ray.util.debug import log_once
  40. from ray.util.placement_group import _valid_resource_shape
  41. from ray.util.scheduling_strategies import (
  42. PlacementGroupSchedulingStrategy,
  43. SchedulingStrategyT,
  44. )
  45. if TYPE_CHECKING:
  46. from ray.data import DataIterator, Dataset
  47. from ray.tune.execution.placement_groups import PlacementGroupFactory
  48. logger = logging.getLogger(__name__)
  49. @dataclass
  50. class TrialInfo:
  51. """The trial information to propagate to TrainSession."""
  52. name: str
  53. id: str
  54. resources: Dict[str, float]
  55. logdir: str
  56. driver_ip: str
  57. driver_node_id: str
  58. experiment_name: Optional[str] = None
  59. run_id: Optional[str] = None
  60. class _FutureTrainingResult:
  61. """A future that will be resolved to a `_TrainingResult`.
  62. This is needed for specific schedulers such as PBT that schedule saves.
  63. This wrapper should be removed after refactoring PBT to not schedule saves anymore.
  64. """
  65. def __init__(self, future: ray.ObjectRef):
  66. self.future = future
  67. def resolve(self, block: bool = True) -> Optional["_TrainingResult"]:
  68. """Resolve into ``_TrainingResult``.
  69. This will return None for function trainables if no checkpoint has been
  70. saved before.
  71. """
  72. if block:
  73. timeout = None
  74. else:
  75. timeout = 1e-9
  76. try:
  77. return ray.get(self.future, timeout=timeout)
  78. except TimeoutError:
  79. # Not ready, yet
  80. pass
  81. except Exception as exc:
  82. logger.error(f"Error resolving result: {exc}")
  83. class _TrainingResult:
  84. """A (checkpoint, metrics) result reported by the user."""
  85. def __init__(self, checkpoint: Optional[Checkpoint], metrics: Dict[str, Any]):
  86. self.checkpoint = checkpoint
  87. self.metrics = metrics
  88. def __repr__(self) -> str:
  89. return f"TrainingResult(checkpoint={self.checkpoint}, metrics={self.metrics})"
  90. # TODO(xwjiang): This needs a better name.
  91. @DeveloperAPI
  92. class _TrainSession:
  93. """Holds information for training on each worker."""
  94. def __init__(
  95. self,
  96. training_func: Callable,
  97. world_rank: Optional[int],
  98. local_rank: Optional[int],
  99. node_rank: Optional[int],
  100. local_world_size: Optional[int],
  101. world_size: Optional[int],
  102. trial_info: Optional[TrialInfo] = None,
  103. dataset_shard: Optional[Dict[str, "Dataset"]] = None,
  104. metadata: Dict[str, Any] = None,
  105. checkpoint: Optional[Checkpoint] = None,
  106. detailed_autofilled_metrics: bool = False,
  107. storage: Optional[StorageContext] = None,
  108. synchronous_result_reporting: bool = False,
  109. ):
  110. # `synchronous_result_reporting` refers to whether or not the
  111. # training function is immediately unblocked to continue running
  112. # after the main thread receives its result.
  113. # Ex 1: For 2 Ray Train workers with synchronous_result_reporting=True,
  114. # the worker that produces a result first will immediately will continue
  115. # onto the next iteration.
  116. # Ex 2: For a Tune function Trainable with `synchronous_result_reporting=False`,
  117. # training will only continue with an explicit call to `session.get_next`.
  118. # Synchronous reporting in example 2 is needed for Tune schedulers to
  119. # be able to stop the execution of the training function at will,
  120. # for advanced pausing schedulers (PBT, BOHB) and actor reuse.
  121. self.synchronous_result_reporting = synchronous_result_reporting
  122. # Ray Train worker properties
  123. # Note: These are set to None for Tune function Trainables.
  124. self.dataset_shard = dataset_shard
  125. self.metadata = metadata
  126. self.world_rank = world_rank
  127. self.local_rank = local_rank
  128. self.node_rank = node_rank
  129. self.local_world_size = local_world_size
  130. self.world_size = world_size
  131. assert storage
  132. logger.debug(f"StorageContext on SESSION (rank={world_rank}):\n{storage}")
  133. # NOTE: `reset` will initialize many properties needed to start running the
  134. # training_func as a thread.
  135. self.reset(
  136. training_func=training_func,
  137. trial_info=trial_info,
  138. storage=storage,
  139. loaded_checkpoint=checkpoint,
  140. )
  141. # Autofilled metrics attributes.
  142. self.detailed_autofilled_metrics = detailed_autofilled_metrics
  143. self.last_report_time = time.time()
  144. self.iteration = 0
  145. self.time_total = 0.0
  146. self.local_ip = self.get_current_ip()
  147. self.accelerator = None
  148. self._state = {}
  149. def get_state(self, key: str) -> Any:
  150. return self._state.get(key)
  151. def set_state(self, key: str, value: Any):
  152. self._state[key] = value
  153. def get_current_ip(self):
  154. self.local_ip = ray.util.get_node_ip_address()
  155. return self.local_ip
  156. def start(self):
  157. """Starts the training thread."""
  158. self.training_started = True
  159. self.training_thread.start()
  160. def reset(
  161. self,
  162. training_func: Callable,
  163. trial_info: TrialInfo,
  164. storage: StorageContext,
  165. loaded_checkpoint=None,
  166. ):
  167. # This lock is used to control the execution of the training thread.
  168. self.continue_lock = threading.Semaphore(0)
  169. # This event is used to signal the training thread to stop.
  170. self.stop_event = threading.Event()
  171. # Queue for sending results across threads.
  172. self.result_queue = queue.Queue(1)
  173. # Queue for sending results from training actor to main thread.
  174. self._inter_actor_queue: Optional[ray_queue.Queue[Dict]] = None
  175. # Queue for raising exceptions from runner thread to main thread.
  176. # The error queue has a max size of one to prevent stacking error and force
  177. # error reporting to block until finished.
  178. self.error_queue = queue.Queue(1)
  179. # The Thread object that is running the training function.
  180. self.training_thread = RunnerThread(
  181. target=training_func, daemon=True, error_queue=self.error_queue
  182. )
  183. # Possibly override with new state
  184. self.trial_info = trial_info
  185. self.storage = storage
  186. self.loaded_checkpoint = loaded_checkpoint
  187. # Reset state
  188. self._state = {}
  189. self.ignore_report = False
  190. self.training_started = False
  191. self._first_report = True
  192. # Change the working directory to a special trial folder.
  193. # This is to ensure that all Ray Train workers have a common working directory.
  194. os.makedirs(storage.trial_working_directory, exist_ok=True)
  195. if bool(int(os.environ.get(RAY_CHDIR_TO_TRIAL_DIR, "1"))):
  196. logger.debug(
  197. f"Changing the working directory to: {storage.trial_working_directory}"
  198. )
  199. os.chdir(storage.trial_working_directory)
  200. def pause_reporting(self):
  201. """Ignore all future ``session.report()`` calls."""
  202. self.ignore_report = True
  203. def finish(self, timeout: Optional[float] = None) -> Optional[Any]:
  204. """Finishes the training thread.
  205. Raises any Exception from training.
  206. """
  207. # Set the stop event for the training thread to gracefully exit.
  208. self.stop_event.set()
  209. # Release the lock so that training thread can process this event.
  210. self.continue_lock.release()
  211. # Force a final (blocking) sync of artifacts in the trial path to storage.
  212. self.storage.persist_artifacts(force=True)
  213. # Wait for training to finish.
  214. # This will raise any errors that occur during training, including SystemError
  215. # This returns the result of the training function.
  216. output = None
  217. if self.training_started:
  218. output = self.training_thread.join(timeout=timeout)
  219. return output
  220. def get_next(self) -> Optional[_TrainingResult]:
  221. """Gets the next ``_TrainingResult`` from the result queue.
  222. If the result queue is empty, then this function returns ``None``.
  223. """
  224. if not self.training_started:
  225. raise RuntimeError("Please call start before calling get_next.")
  226. if self.synchronous_result_reporting:
  227. # There's no need to release the lock on the first report
  228. # since `start` already started the training thread.
  229. if not self._first_report:
  230. # Release the lock to trigger training to continue,
  231. # until the next call to report.
  232. self.continue_lock.release()
  233. self._first_report = False
  234. result = None
  235. # While training is still ongoing, attempt to get the result.
  236. while result is None and self.training_thread.is_alive():
  237. result = self._get_result_from_queues(block=True)
  238. # If no result was found, then the runner must no longer be alive.
  239. if result is None:
  240. # Try one last time to fetch results in case results were
  241. # reported in between the time of the last check and the
  242. # termination of the thread runner.
  243. result = self._get_result_from_queues(block=False)
  244. # check if error occurred inside the thread runner.
  245. if result is None:
  246. # only raise an error from the runner if all results are consumed
  247. self._report_thread_runner_error(block=True)
  248. else:
  249. if not self.error_queue.empty():
  250. logger.debug(
  251. (
  252. "Runner error waiting to be raised in main thread. "
  253. "Logging all available results first."
  254. )
  255. )
  256. if not self.synchronous_result_reporting:
  257. # At this point, the training thread has reached
  258. # the `train.report` and is blocked there.
  259. # If performing asynchronous result reporting,
  260. # release the lock to allow each worker to keep training
  261. # immediately after the coordinator fetches their result.
  262. self.continue_lock.release()
  263. # Return None if there are no more results to fetch.
  264. return result
  265. def _get_or_create_inter_actor_queue(self):
  266. """Get or create the inter-actor queue."""
  267. if self._inter_actor_queue is None:
  268. self._inter_actor_queue = ray_queue.Queue(1, actor_options={"num_cpus": 0})
  269. return self._inter_actor_queue
  270. def _get_result_from_queues(self, block: bool) -> Optional[_TrainingResult]:
  271. """Get result from result queue. Pass result from training actor result queue if needed."""
  272. result = None
  273. if self._inter_actor_queue is not None:
  274. try:
  275. inter_actor_item = self._inter_actor_queue.get(
  276. block=block, timeout=_RESULT_FETCH_TIMEOUT
  277. )
  278. if inter_actor_item:
  279. # Must release continue_lock to allow report to work.
  280. self.continue_lock.release()
  281. self.report(inter_actor_item)
  282. except ray_queue.Empty:
  283. pass
  284. try:
  285. result = self.result_queue.get(block=block, timeout=_RESULT_FETCH_TIMEOUT)
  286. except queue.Empty:
  287. pass
  288. return result
  289. def _auto_fill_metrics(self, result: dict) -> dict:
  290. """Add autofilled metrics and update attributes."""
  291. current_time = time.time()
  292. current_datetime = datetime.now()
  293. if TIME_THIS_ITER_S in result:
  294. time_this_iter = result[TIME_THIS_ITER_S]
  295. else:
  296. time_this_iter = current_time - self.last_report_time
  297. self.iteration += 1
  298. self.time_total += time_this_iter
  299. self.last_report_time = current_time
  300. auto_filled_metrics = {
  301. TIMESTAMP: int(time.mktime(current_datetime.timetuple())),
  302. TIME_TOTAL_S: self.time_total,
  303. WORKER_PID: os.getpid(),
  304. WORKER_HOSTNAME: platform.node(),
  305. WORKER_NODE_IP: self.local_ip,
  306. }
  307. if not self.detailed_autofilled_metrics:
  308. auto_filled_metrics = {
  309. k: v
  310. for k, v in auto_filled_metrics.items()
  311. if k not in DETAILED_AUTOFILLED_KEYS
  312. }
  313. result = result.copy()
  314. result.update(auto_filled_metrics)
  315. return result
  316. def _auto_fill_checkpoint_metrics(self, result: dict) -> dict:
  317. """Add autofilled metrics and update attributes."""
  318. current_datetime = datetime.now()
  319. auto_filled_metrics = {
  320. TIMESTAMP: int(time.mktime(current_datetime.timetuple()))
  321. }
  322. result = result.copy()
  323. result.update(auto_filled_metrics)
  324. return result
  325. def _report_thread_runner_error(self, block=False):
  326. try:
  327. e = self.error_queue.get(block=block, timeout=_ERROR_FETCH_TIMEOUT)
  328. raise StartTraceback from e
  329. except queue.Empty:
  330. pass
  331. def _report_training_result(self, training_result: _TrainingResult) -> None:
  332. """Place a training result on the result queue for the main thread to process,
  333. then block until the main thread signals that training should continue.
  334. NOTE: This is used internally to report results from Train to Tune
  335. without persisting checkpoints to storage 2 times.
  336. `report` is the public API that directly persists to storage, which
  337. should only be called by user code.
  338. """
  339. if training_result.checkpoint:
  340. # NOTE: This populates `train.get_checkpoint`
  341. self.loaded_checkpoint = training_result.checkpoint
  342. # Add result to a thread-safe queue.
  343. self.result_queue.put(training_result, block=True)
  344. # Acquire lock to stop the training thread until main thread
  345. # triggers resume.
  346. self.continue_lock.acquire()
  347. # If the trial should be terminated, exit gracefully.
  348. # NOTE: This is only really useful if `synchronous_result_reporting=True`.
  349. # Otherwise, the lock is immediately released on reporting, and this
  350. # check is skipped before the main thread decides to set the stop event.
  351. if self.stop_event.is_set():
  352. self.stop_event.clear()
  353. sys.exit(0)
  354. def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
  355. # Special case: early fail for Torch tensors
  356. if "torch" in sys.modules:
  357. from ray.air._internal.torch_utils import contains_tensor
  358. if contains_tensor(metrics):
  359. raise ValueError(
  360. "Passing objects containg Torch tensors as metrics "
  361. "is not supported as it will throw an exception on "
  362. "deserialization. You can either convert the tensors "
  363. "to Python objects or report a `train.Checkpoint` "
  364. "with `ray.train.report` to store your Torch objects."
  365. )
  366. if self.ignore_report:
  367. return
  368. metrics = self._auto_fill_metrics(metrics)
  369. persisted_checkpoint = None
  370. if checkpoint:
  371. self.storage._update_checkpoint_index(metrics)
  372. # Persist the reported checkpoint files to storage.
  373. persisted_checkpoint = self.storage.persist_current_checkpoint(checkpoint)
  374. metrics[CHECKPOINT_DIR_NAME] = self.storage.checkpoint_dir_name
  375. else:
  376. metrics[CHECKPOINT_DIR_NAME] = None
  377. # Persist trial artifacts to storage.
  378. force_artifact_sync = (
  379. persisted_checkpoint
  380. and self.storage.sync_config.sync_artifacts_on_checkpoint
  381. )
  382. self.storage.persist_artifacts(force=force_artifact_sync)
  383. # Set additional user metadata from the Trainer.
  384. if persisted_checkpoint and self.metadata:
  385. user_metadata = persisted_checkpoint.get_metadata()
  386. for k, v in self.metadata.items():
  387. # Update keys not already set by the user. This gives user-set keys
  388. # precedence over keys set at the Trainer level.
  389. if k not in user_metadata:
  390. user_metadata[k] = v
  391. persisted_checkpoint.set_metadata(user_metadata)
  392. result = _TrainingResult(checkpoint=persisted_checkpoint, metrics=metrics)
  393. self._report_training_result(result)
  394. @property
  395. def experiment_name(self) -> str:
  396. return self.trial_info.experiment_name
  397. @property
  398. def trial_name(self) -> str:
  399. return self.trial_info.name
  400. @property
  401. def trial_id(self) -> str:
  402. return self.trial_info.id
  403. @property
  404. def run_id(self) -> str:
  405. return self.trial_info.run_id
  406. @property
  407. def trial_resources(self) -> "PlacementGroupFactory":
  408. return self.trial_info.resources
  409. @property
  410. def trial_dir(self) -> str:
  411. return self.trial_info.logdir
  412. def get_dataset_shard(
  413. self,
  414. dataset_name: Optional[str] = None,
  415. ) -> Optional["DataIterator"]:
  416. shard = self.dataset_shard
  417. if shard is None:
  418. warnings.warn(
  419. "No dataset passed in. Returning None. Make sure to "
  420. "pass in a Dataset to Trainer.run to use this "
  421. "function."
  422. )
  423. elif isinstance(shard, dict):
  424. if not dataset_name:
  425. raise RuntimeError(
  426. "Multiple datasets were passed into ``Trainer``, "
  427. "but no ``dataset_name`` is passed into "
  428. "``get_dataset_shard``. Please specify which "
  429. "dataset shard to retrieve."
  430. )
  431. return shard.get(dataset_name)
  432. return shard
# Cache of resource dicts that have been checked by the launch hook already.
# Entries are frozensets of the positive ``(resource, amount)`` pairs, added by
# ``_tune_task_and_actor_launch_hook``.
_checked_resources: Set[frozenset] = set()

# Global _TrainSession object initialized by Ray Tune function trainables
# and Ray Train V1 workers.
# Set by ``init_session`` and cleared by ``shutdown_session``.
_session: Optional[_TrainSession] = None
  438. def _tune_task_and_actor_launch_hook(
  439. fn, resources: Dict[str, float], strategy: Optional[SchedulingStrategyT]
  440. ):
  441. """Launch hook to catch nested tasks that can't fit in the placement group.
  442. This gives users a nice warning in case they launch a nested task in a Tune trial
  443. without reserving resources in the trial placement group to fit it.
  444. """
  445. # Already checked, skip for performance reasons.
  446. key = frozenset({(k, v) for k, v in resources.items() if v > 0})
  447. if not key or key in _checked_resources:
  448. return
  449. # No need to check if placement group is None.
  450. if (
  451. not isinstance(strategy, PlacementGroupSchedulingStrategy)
  452. or strategy.placement_group is None
  453. ):
  454. return
  455. # Check if the resource request is targeting the current placement group.
  456. cur_pg = ray.util.get_current_placement_group()
  457. if not cur_pg or strategy.placement_group.id != cur_pg.id:
  458. return
  459. _checked_resources.add(key)
  460. # Check if the request can be fulfilled by the current placement group.
  461. pgf = get_trial_resources()
  462. if pgf.head_bundle_is_empty:
  463. available_bundles = cur_pg.bundle_specs[0:]
  464. else:
  465. available_bundles = cur_pg.bundle_specs[1:]
  466. # Check if the request can be fulfilled by the current placement group.
  467. if _valid_resource_shape(resources, available_bundles):
  468. return
  469. if fn.class_name:
  470. submitted = "actor"
  471. name = fn.module_name + "." + fn.class_name + "." + fn.function_name
  472. else:
  473. submitted = "task"
  474. name = fn.module_name + "." + fn.function_name
  475. # Normalize the resource spec so it looks the same as the placement group bundle.
  476. main_resources = cur_pg.bundle_specs[0]
  477. resources = {k: float(v) for k, v in resources.items() if v > 0}
  478. raise RuntimeError(
  479. f"No trial resources are available for launching the {submitted} `{name}`. "
  480. "To resolve this, specify the Tune option:\n\n"
  481. "> resources_per_trial=tune.PlacementGroupFactory(\n"
  482. f"> [{main_resources}] + [{resources}] * N\n"
  483. "> )\n\n"
  484. f"Where `N` is the number of slots to reserve for trial {submitted}s. "
  485. "If you are using a Ray training library, there might be a utility function "
  486. "to set this automatically for you. For more information, refer to "
  487. "https://docs.ray.io/en/latest/tune/tutorials/tune-resources.html"
  488. )
  489. def init_session(*args, **kwargs) -> None:
  490. global _session
  491. if _session:
  492. raise ValueError(
  493. "A Train session is already in use. Do not call "
  494. "`init_session()` manually."
  495. )
  496. # Setup hooks for generating placement group resource deadlock warnings.
  497. from ray import actor, remote_function
  498. if "TUNE_DISABLE_RESOURCE_CHECKS" not in os.environ:
  499. actor._actor_launch_hook = _tune_task_and_actor_launch_hook
  500. remote_function._task_launch_hook = _tune_task_and_actor_launch_hook
  501. _session = _TrainSession(*args, **kwargs)
def get_session() -> Optional[_TrainSession]:
    """Return the global ``_TrainSession``, or None if none is initialized."""
    return _session
def shutdown_session():
    """Shuts down the initialized session.

    This only resets the global session reference to None.
    """
    global _session
    _session = None
  508. def _raise_accelerator_session_misuse():
  509. """Raises a SessionMisuseError because a utility function was used improperly."""
  510. raise SessionMisuseError(
  511. "prepare/accelerate utility functions should be called inside a training "
  512. "function executed by `Trainer.run`"
  513. )
  514. def get_accelerator(default_accelerator_cls: Type[Accelerator]) -> Accelerator:
  515. """The accelerator for this training session.
  516. If an accelerator has not been set, then this method will construct an
  517. accelerator using the provided accelerator class.
  518. Raises:
  519. SessionMisuseError: if the session is uninitialized.
  520. """
  521. session = get_session()
  522. if session is None:
  523. _raise_accelerator_session_misuse()
  524. if session.accelerator is None:
  525. session.accelerator = default_accelerator_cls()
  526. return session.accelerator
  527. def set_accelerator(accelerator: Accelerator) -> None:
  528. """Sets the accelerator for this training session.
  529. Args:
  530. accelerator: The accelerator to use for training.
  531. Raises:
  532. SessionMisuseError: if the session is unitialized.
  533. RuntimeError: if the accelerator has already been set.
  534. """
  535. session = get_session()
  536. if session is None:
  537. _raise_accelerator_session_misuse()
  538. if session.accelerator is not None:
  539. raise RuntimeError("Cannot change accelerator once set.")
  540. session.accelerator = accelerator
  541. def _warn_session_misuse(default_value: Any = None):
  542. """Warns if fn is being used outside of session and returns ``default_value``."""
  543. def inner(fn: Callable):
  544. fn_name = fn.__name__
  545. @functools.wraps(fn)
  546. def wrapper(*args, **kwargs):
  547. session = get_session()
  548. if not session:
  549. if log_once(f"{SESSION_MISUSE_LOG_ONCE_KEY}-{fn_name}"):
  550. warnings.warn(
  551. f"`{fn_name}` is meant to only be "
  552. "called inside a function that is executed by a Tuner"
  553. f" or Trainer. Returning `{default_value}`."
  554. )
  555. return default_value
  556. return fn(*args, **kwargs)
  557. return wrapper
  558. return inner
  559. @PublicAPI(stability="stable")
  560. @_warn_session_misuse()
  561. def report(
  562. metrics: Dict,
  563. *,
  564. checkpoint: Optional[Checkpoint] = None,
  565. checkpoint_dir_name: Optional[str] = None,
  566. ) -> None:
  567. """Report metrics and optionally save a checkpoint.
  568. If a checkpoint is provided, it will be
  569. :ref:`persisted to storage <persistent-storage-guide>`.
  570. If this is called in multiple distributed training workers:
  571. - Only the metrics reported by the rank 0 worker will be tracked by Ray Train.
  572. See :ref:`the metrics logging guide <train-monitoring-and-logging>`.
  573. - A checkpoint will be registered as long as one or more workers reports
  574. checkpoint that is not None.
  575. See the :ref:`checkpointing guide <train-dl-saving-checkpoints>`.
  576. - Checkpoints from multiple workers will be merged into one directory
  577. in persistent storage.
  578. See :ref:`the distributed checkpointing guide <train-distributed-checkpointing>`.
  579. .. note::
  580. Each invocation of this method will automatically increment the underlying
  581. ``training_iteration`` number. The physical meaning of this "iteration" is
  582. defined by user depending on how often they call ``report``.
  583. It does not necessarily map to one epoch.
  584. .. warning::
  585. All workers must call `ray.train.report` the same number of times
  586. so that Ray Train can properly synchronize the training state across
  587. workers. Otherwise, your training will hang.
  588. .. warning::
  589. This method does NOT act as a barrier for distributed training workers.
  590. Workers will upload their checkpoint, then continue training immediately.
  591. If you need to synchronize workers, you can use a framework-native barrier
  592. such as `torch.distributed.barrier()`.
  593. Example:
  594. .. testcode::
  595. import tempfile
  596. from ray import train
  597. from ray.train import Checkpoint
  598. from ray.train.torch import TorchTrainer
  599. def train_func(config):
  600. start_epoch = 0
  601. checkpoint = train.get_checkpoint()
  602. if checkpoint:
  603. with checkpoint.as_directory() as checkpoint_dir:
  604. # Load back training state
  605. ...
  606. for epoch in range(start_epoch, config.get("num_epochs", 10)):
  607. # Do training...
  608. metrics = {"loss": ...}
  609. with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
  610. # Save the checkpoint...
  611. # torch.save(...)
  612. checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
  613. # Example: Only the rank 0 worker uploads the checkpoint.
  614. if ray.train.get_context().get_world_rank() == 0:
  615. train.report(metrics, checkpoint=checkpoint)
  616. else:
  617. train.report(metrics, checkpoint=None)
  618. trainer = TorchTrainer(
  619. train_func, scaling_config=train.ScalingConfig(num_workers=2)
  620. )
  621. Args:
  622. metrics: The metrics you want to report.
  623. checkpoint: The optional checkpoint you want to report.
  624. """
  625. if checkpoint_dir_name is not None:
  626. logger.warning(
  627. "`checkpoint_dir_name` is only supported in the new Ray Train "
  628. "implementation, which can be enabled with `RAY_TRAIN_V2_ENABLED=1`. "
  629. "This argument will be ignored."
  630. )
  631. # If we are running in a Tune function, switch to `ray.tune.report`.
  632. from ray.tune.trainable.trainable_fn_utils import _in_tune_session
  633. if _in_tune_session():
  634. import ray.tune
  635. if _v2_migration_warnings_enabled():
  636. _log_deprecation_warning(
  637. "`ray.train.report` should be switched to "
  638. "`ray.tune.report` when running in a function "
  639. "passed to Ray Tune. This will be an error in the future. "
  640. "See this issue for more context: "
  641. "https://github.com/ray-project/ray/issues/49454"
  642. )
  643. return ray.tune.report(metrics, checkpoint=checkpoint)
  644. get_session().report(metrics, checkpoint=checkpoint)
  645. @PublicAPI(stability="stable")
  646. @_warn_session_misuse()
  647. def get_checkpoint() -> Optional[Checkpoint]:
  648. """Access the latest reported checkpoint to resume from if one exists.
  649. Example:
  650. .. testcode::
  651. import tempfile
  652. from ray import train
  653. from ray.train import Checkpoint
  654. from ray.train.torch import TorchTrainer
  655. def train_func(config):
  656. start_epoch = 0
  657. checkpoint = train.get_checkpoint()
  658. if checkpoint:
  659. with checkpoint.as_directory() as checkpoint_dir:
  660. # Load back training state
  661. ...
  662. for epoch in range(start_epoch, config.get("num_epochs", 10)):
  663. # Do training...
  664. metrics = {"loss": ...}
  665. with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
  666. # Save the checkpoint...
  667. checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
  668. train.report(metrics, checkpoint=checkpoint)
  669. trainer = TorchTrainer(
  670. train_func, scaling_config=train.ScalingConfig(num_workers=2)
  671. )
  672. Returns:
  673. Checkpoint object if the session is currently being resumed.
  674. Otherwise, return None.
  675. """
  676. # If we are running in a Tune function, switch to `ray.tune.get_checkpoint`.
  677. from ray.tune.trainable.trainable_fn_utils import _in_tune_session
  678. if _in_tune_session():
  679. import ray.tune
  680. if _v2_migration_warnings_enabled():
  681. _log_deprecation_warning(
  682. "`ray.train.get_checkpoint` should be switched to "
  683. "`ray.tune.get_checkpoint` when running in a function "
  684. "passed to Ray Tune. This will be an error in the future. "
  685. "See this issue for more context: "
  686. "https://github.com/ray-project/ray/issues/49454"
  687. )
  688. return ray.tune.get_checkpoint()
  689. return get_session().loaded_checkpoint
  690. @PublicAPI(stability="beta")
  691. @_warn_session_misuse()
  692. def get_metadata() -> Dict[str, Any]:
  693. """User metadata dict passed to the Trainer constructor."""
  694. return get_session().metadata
  695. @PublicAPI(stability="beta")
  696. @_warn_session_misuse()
  697. def get_experiment_name() -> str:
  698. """Experiment name for the corresponding trial."""
  699. return get_session().experiment_name
  700. @PublicAPI(stability="beta")
  701. @_warn_session_misuse()
  702. def get_trial_name() -> str:
  703. """Trial name for the corresponding trial."""
  704. return get_session().trial_name
  705. @PublicAPI(stability="beta")
  706. @_warn_session_misuse()
  707. def get_trial_id() -> str:
  708. """Trial id for the corresponding trial."""
  709. return get_session().trial_id
  710. @PublicAPI(stability="alpha")
  711. @_warn_session_misuse()
  712. def get_run_id() -> str:
  713. """Unique Train Run id for the corresponding trial."""
  714. return get_session().run_id
  715. @PublicAPI(stability="beta")
  716. @_warn_session_misuse()
  717. def get_trial_resources() -> "PlacementGroupFactory":
  718. """Trial resources for the corresponding trial."""
  719. return get_session().trial_resources
  720. @PublicAPI(stability="beta")
  721. @_warn_session_misuse()
  722. def get_trial_dir() -> str:
  723. """Log directory corresponding to the trial directory for a Tune session.
  724. If calling from a Train session, this will give the trial directory of its parent
  725. Tune session.
  726. .. testcode::
  727. import ray.tune
  728. def train_func(config):
  729. print(ray.tune.get_context().get_trial_dir())
  730. tuner = ray.tune.Tuner(train_func)
  731. tuner.fit()
  732. .. testoutput::
  733. :options: +MOCK
  734. /Users/root/ray_results/train_func_2023-07-19_15-01-37/train_func_d620c_00000_0_2023-07-19_15-01-40
  735. """
  736. return get_session().trial_dir
  737. @PublicAPI(stability="beta")
  738. @_warn_session_misuse(default_value=1)
  739. def get_world_size() -> int:
  740. """Get the current world size (i.e. total number of workers) for this run.
  741. .. testcode::
  742. import ray
  743. from ray import train
  744. from ray.train import ScalingConfig
  745. from ray.train.tensorflow import TensorflowTrainer
  746. NUM_WORKERS = 2
  747. def train_loop_per_worker(config):
  748. assert train.get_context().get_world_size() == NUM_WORKERS
  749. train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
  750. trainer = TensorflowTrainer(
  751. train_loop_per_worker,
  752. scaling_config=ScalingConfig(num_workers=NUM_WORKERS),
  753. datasets={"train": train_dataset}
  754. )
  755. trainer.fit()
  756. .. testoutput::
  757. :hide:
  758. ...
  759. """
  760. session = get_session()
  761. if not hasattr(session, "world_size"):
  762. raise RuntimeError(
  763. "`get_world_size` can only be called for TrainSession! "
  764. "Make sure you only use that in `train_loop_per_worker` function"
  765. "that is passed into `DataParallelTrainer`."
  766. )
  767. return session.world_size
  768. @PublicAPI(stability="beta")
  769. @_warn_session_misuse(default_value=0)
  770. def get_world_rank() -> int:
  771. """Get the world rank of this worker.
  772. .. testcode::
  773. import ray
  774. from ray import train
  775. from ray.train import ScalingConfig
  776. from ray.train.tensorflow import TensorflowTrainer
  777. def train_loop_per_worker(config):
  778. if train.get_context().get_world_rank() == 0:
  779. print("Worker 0")
  780. train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
  781. trainer = TensorflowTrainer(
  782. train_loop_per_worker,
  783. scaling_config=ScalingConfig(num_workers=2),
  784. datasets={"train": train_dataset}
  785. )
  786. trainer.fit()
  787. .. testoutput::
  788. :hide:
  789. ...
  790. """
  791. session = get_session()
  792. if not hasattr(session, "world_rank"):
  793. raise RuntimeError(
  794. "`get_world_rank` can only be called for TrainSession! "
  795. "Make sure you only use that in `train_loop_per_worker` function"
  796. "that is passed into `DataParallelTrainer`."
  797. )
  798. return session.world_rank
  799. @PublicAPI(stability="beta")
  800. @_warn_session_misuse(default_value=0)
  801. def get_local_rank() -> int:
  802. """Get the local rank of this worker (rank of the worker on its node).
  803. .. testcode::
  804. import torch
  805. import ray
  806. from ray import train
  807. from ray.train import ScalingConfig
  808. from ray.train.torch import TorchTrainer
  809. def train_loop_per_worker(config):
  810. if torch.cuda.is_available():
  811. torch.cuda.set_device(train.get_context().get_local_rank())
  812. ...
  813. train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
  814. trainer = TorchTrainer(
  815. train_loop_per_worker,
  816. scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
  817. datasets={"train": train_dataset}
  818. )
  819. trainer.fit()
  820. .. testoutput::
  821. :hide:
  822. ...
  823. """
  824. session = get_session()
  825. if not hasattr(session, "local_rank"):
  826. raise RuntimeError(
  827. "`get_local_rank` can only be called for TrainSession! "
  828. "Make sure you only use that in `train_loop_per_worker` function"
  829. "that is passed into `DataParallelTrainer`."
  830. )
  831. return session.local_rank
  832. @PublicAPI(stability="beta")
  833. @_warn_session_misuse(default_value=0)
  834. def get_local_world_size() -> int:
  835. """Get the local world size of this node (i.e. number of workers on this node).
  836. Example:
  837. .. testcode::
  838. import ray
  839. from ray import train
  840. from ray.train import ScalingConfig
  841. from ray.train.torch import TorchTrainer
  842. def train_loop_per_worker():
  843. print(train.get_context().get_local_world_size())
  844. train_dataset = ray.data.from_items(
  845. [{"x": x, "y": x + 1} for x in range(32)])
  846. trainer = TorchTrainer(train_loop_per_worker,
  847. scaling_config=ScalingConfig(num_workers=1),
  848. datasets={"train": train_dataset})
  849. trainer.fit()
  850. .. testoutput::
  851. :hide:
  852. ...
  853. """
  854. session = get_session()
  855. if not hasattr(session, "local_world_size"):
  856. raise RuntimeError(
  857. "`get_local_world_size` can only be called for TrainSession! "
  858. "Make sure you only use that in `train_loop_per_worker` function"
  859. "that is passed into `DataParallelTrainer`."
  860. )
  861. return session.local_world_size
  862. @PublicAPI(stability="beta")
  863. @_warn_session_misuse(default_value=0)
  864. def get_node_rank() -> int:
  865. """Get the rank of this node.
  866. Example:
  867. .. testcode::
  868. import ray
  869. from ray import train
  870. from ray.train import ScalingConfig
  871. from ray.train.torch import TorchTrainer
  872. def train_loop_per_worker():
  873. print(train.get_context().get_node_rank())
  874. train_dataset = ray.data.from_items(
  875. [{"x": x, "y": x + 1} for x in range(32)])
  876. trainer = TorchTrainer(train_loop_per_worker,
  877. scaling_config=ScalingConfig(num_workers=1),
  878. datasets={"train": train_dataset})
  879. trainer.fit()
  880. .. testoutput::
  881. :hide:
  882. ...
  883. """
  884. session = get_session()
  885. if not hasattr(session, "node_rank"):
  886. raise RuntimeError(
  887. "`get_node_rank` can only be called for TrainSession! "
  888. "Make sure you only use that in `train_loop_per_worker` function"
  889. "that is passed into `DataParallelTrainer`."
  890. )
  891. return session.node_rank
  892. @PublicAPI(stability="stable")
  893. @_warn_session_misuse()
  894. def get_dataset_shard(
  895. dataset_name: Optional[str] = None,
  896. ) -> Optional["DataIterator"]:
  897. """Returns the :class:`ray.data.DataIterator` shard for this worker.
  898. Call :meth:`~ray.data.DataIterator.iter_torch_batches` or
  899. :meth:`~ray.data.DataIterator.to_tf` on this shard to convert it to the
  900. appropriate framework-specific data type.
  901. .. testcode::
  902. import ray
  903. from ray import train
  904. from ray.train import ScalingConfig
  905. from ray.train.torch import TorchTrainer
  906. def train_loop_per_worker(config):
  907. ...
  908. for epoch in range(2):
  909. # Trainer will automatically handle sharding.
  910. data_shard = train.get_dataset_shard("train")
  911. for batch in data_shard.iter_torch_batches():
  912. ...
  913. train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
  914. trainer = TorchTrainer(
  915. train_loop_per_worker,
  916. scaling_config=ScalingConfig(num_workers=2),
  917. datasets={"train": train_dataset}
  918. )
  919. trainer.fit()
  920. .. testoutput::
  921. :hide:
  922. ...
  923. Args:
  924. dataset_name: If a Dictionary of Datasets was passed to ``Trainer``, then
  925. specifies which dataset shard to return.
  926. Returns:
  927. The ``DataIterator`` shard to use for this worker.
  928. If no dataset is passed into Trainer, then return None.
  929. """
  930. session = get_session()
  931. if not hasattr(session, "get_dataset_shard"):
  932. raise RuntimeError(
  933. "`get_dataset_shard` can only be called for TrainSession! "
  934. "Make sure you only use that in `train_loop_per_worker` function"
  935. "that is passed into `DataParallelTrainer`."
  936. )
  937. return session.get_dataset_shard(dataset_name)
  938. @DeveloperAPI
  939. @_warn_session_misuse()
  940. def get_storage() -> StorageContext:
  941. """Returns the :class:`~ray.train._internal.storage.StorageContext` storage
  942. context which gives advanced access to the filesystem and paths
  943. configured through `RunConfig`.
  944. NOTE: This is a developer API, and the `StorageContext` interface may change
  945. without notice between minor versions.
  946. """
  947. return get_session().storage
  948. def _in_ray_train_worker() -> bool:
  949. """Check if the current process is a Ray Train V1 worker."""
  950. return bool(get_session()) and get_session().world_rank is not None