config.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. import logging
  2. from dataclasses import dataclass
  3. from functools import cached_property
  4. from pathlib import Path
  5. from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
  6. import pyarrow.fs
  7. from ray.air.config import (
  8. FailureConfig as FailureConfigV1,
  9. ScalingConfig as ScalingConfigV1,
  10. )
  11. from ray.runtime_env import RuntimeEnv
  12. from ray.train.v2._internal.constants import _DEPRECATED
  13. from ray.train.v2._internal.execution.storage import StorageContext
  14. from ray.train.v2._internal.migration_utils import (
  15. FAIL_FAST_DEPRECATION_MESSAGE,
  16. TRAINER_RESOURCES_DEPRECATION_MESSAGE,
  17. )
  18. from ray.train.v2._internal.util import date_str
  19. from ray.util.annotations import PublicAPI
  20. from ray.util.tpu import get_tpu_worker_resources
  21. if TYPE_CHECKING:
  22. from ray.train import UserCallback
  23. logger = logging.getLogger(__name__)
  @dataclass
  class ScalingConfig(ScalingConfigV1):
      """Configuration for scaling training.

      Args:
          num_workers: The number of workers (Ray actors) to launch.
              Each worker will reserve 1 CPU by default. The number of CPUs
              reserved by each worker can be overridden with the
              ``resources_per_worker`` argument. If the number of workers is 0,
              the training function will run in local mode, meaning the training
              function runs in the same process.
          use_gpu: If True, training will be done on GPUs (1 per worker).
              Defaults to False. The number of GPUs reserved by each
              worker can be overridden with the ``resources_per_worker``
              argument.
          resources_per_worker: If specified, the resources
              defined in this Dict are reserved for each worker.
              Define the ``"CPU"`` and ``"GPU"`` keys (case-sensitive) to
              override the number of CPU or GPUs used by each worker.
          placement_strategy: The placement strategy to use for the
              placement group of the Ray actors. See :ref:`Placement Group
              Strategies <pgroup-strategy>` for the possible options.
          label_selector: A list of label selectors for Ray Train worker placement.
              If a single label selector is provided, it will be applied to all
              Ray Train workers. If a list is provided, it must be the same length
              as the max number of Ray Train workers.
          accelerator_type: [Experimental] If specified, Ray Train will launch the
              training coordinator and workers on the nodes with the specified type
              of accelerators.
              See :ref:`the available accelerator types <accelerator_types>`.
              Ensure that your cluster has instances with the specified accelerator
              type or is able to autoscale to fulfill the request. This field is
              required when `use_tpu` is True and `num_workers` is greater than 1.
          use_tpu: [Experimental] If True, training will be done on TPUs (1 TPU VM
              per worker). Defaults to False. The number of TPUs reserved by each
              worker can be overridden with the ``resources_per_worker``
              argument. This arg enables SPMD execution of the training workload.
          topology: [Experimental] If specified, Ray Train will launch the training
              coordinator and workers on nodes with the specified topology. Topology
              is auto-detected for TPUs and added as Ray node labels. This arg
              enables SPMD execution of the training workload. This field is
              required when `use_tpu` is True and `num_workers` is greater than 1.
          trainer_resources: [Deprecated] Passing anything other than ``None``
              raises a ``DeprecationWarning``.
      """

      trainer_resources: Optional[dict] = None
      label_selector: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None

      # Accelerator specific fields.
      use_tpu: bool = False
      topology: Optional[str] = None

      def __post_init__(self):
          """Validate the configuration, then run the parent's post-init.

          Raises:
              DeprecationWarning: If ``trainer_resources`` is set.
              ValueError: If the TPU settings are inconsistent (see
                  ``_validate_tpu_config``), or if ``label_selector`` is a list
                  whose length differs from an integer ``num_workers``.
          """
          if self.trainer_resources is not None:
              raise DeprecationWarning(TRAINER_RESOURCES_DEPRECATION_MESSAGE)

          self._validate_tpu_config()

          # The length check only applies when `num_workers` is a plain int.
          if (
              isinstance(self.label_selector, list)
              and isinstance(self.num_workers, int)
              and len(self.label_selector) != self.num_workers
          ):
              raise ValueError(
                  "If `label_selector` is a list, it must be the same length as `num_workers`."
              )

          if self.num_workers == 0:
              logger.info(
                  "Running in local mode. The training function will run in the same process. "
                  "If you are using it and running into issues please file a report at "
                  "https://github.com/ray-project/ray/issues."
              )

          super().__post_init__()

      def _validate_tpu_config(self):
          """Validates configuration specifically for TPU usage.

          When ``use_tpu`` is True, both ``topology`` and ``accelerator_type`` are
          set, and ``resources_per_worker`` is None, this also fills in
          ``resources_per_worker`` with the resources returned by
          ``get_tpu_worker_resources``.

          Raises:
              ValueError: If GPU and TPU are both requested, if ``TPU`` resources
                  and ``use_tpu`` disagree, if required TPU fields are missing for
                  multi-worker runs, if ``label_selector`` is combined with
                  ``use_tpu``, if the topology cannot be resolved, or if
                  ``num_workers`` is not a multiple of the workers per slice.
          """
          if self.use_gpu and self.use_tpu:
              raise ValueError("Cannot specify both `use_gpu=True` and `use_tpu=True`.")

          if not self.use_tpu:
              if self.num_tpus_per_worker > 0:
                  raise ValueError(
                      "`use_tpu` is False but `TPU` was found in "
                      "`resources_per_worker`. Either set `use_tpu` to True or "
                      "remove `TPU` from `resources_per_worker`."
                  )
              # If not using TPU, we are done validating TPU-specific logic.
              return

          if self.num_tpus_per_worker == 0:
              raise ValueError(
                  "`use_tpu` is True but `TPU` is set to 0 in "
                  "`resources_per_worker`. Either set `use_tpu` to False or "
                  "request a positive number of `TPU` in "
                  "`resources_per_worker`."
              )

          if self.num_workers > 1:
              if not self.topology:
                  raise ValueError(
                      "`topology` must be specified in ScalingConfig when `use_tpu=True` "
                      "and `num_workers` > 1."
                  )
              if not self.accelerator_type:
                  raise ValueError(
                      "`accelerator_type` must be specified in ScalingConfig when "
                      "`use_tpu=True` and `num_workers` > 1."
                  )

          if self.label_selector:
              raise ValueError(
                  "Cannot set `label_selector` when `use_tpu=True` because "
                  "Ray Train automatically reserves a TPU slice with a predefined label."
              )

          # Validate TPU resources when both topology and accelerator type are specified.
          if self.topology and self.accelerator_type:
              try:
                  workers_per_slice, tpu_resources = get_tpu_worker_resources(
                      topology=self.topology,
                      accelerator_type=self.accelerator_type,
                      resources_per_unit=self.resources_per_worker,
                      num_slices=1,
                  )
              except Exception as e:
                  # Any failure from the helper is surfaced as a config error;
                  # chain it so the original traceback stays attached.
                  raise ValueError(
                      f"Could not parse TPU topology details for "
                      f"type={self.accelerator_type}, "
                      f"topology={self.topology}. Error: {e}"
                  ) from e

              if workers_per_slice > 0 and self.num_workers % workers_per_slice != 0:
                  raise ValueError(
                      f"The configured `num_workers` ({self.num_workers}) must be a "
                      f"multiple of {workers_per_slice} for the specified topology ({self.topology}). "
                      "TPU workloads typically require symmetric resource distribution "
                      "across all slices to function correctly."
                  )

              # Only fill in resources the user did not specify.
              if self.resources_per_worker is None:
                  self.resources_per_worker = tpu_resources

      @property
      def _resources_per_worker_not_none(self):
          """Per-worker resources, defaulting to ``{"TPU": 1}`` for TPU runs.

          Falls back to the parent implementation whenever
          ``resources_per_worker`` is set or ``use_tpu`` is False.
          """
          if self.resources_per_worker is None:
              if self.use_tpu:
                  return {"TPU": 1}

          return super()._resources_per_worker_not_none

      @property
      def _trainer_resources_not_none(self):
          """Always an empty dict; ``trainer_resources`` is rejected in ``__post_init__``."""
          return {}

      @property
      def num_tpus_per_worker(self):
          """The number of TPUs to set per worker."""
          return self._resources_per_worker_not_none.get("TPU", 0)
  @dataclass
  @PublicAPI(stability="stable")
  class CheckpointConfig:
      """Settings that control which reported checkpoints are retained.

      By default every checkpoint reported with :meth:`ray.train.report` is
      persisted to disk. When ``num_to_keep`` is set, the default retention
      policy keeps the most recent checkpoints.

      Args:
          num_to_keep: Upper bound on the number of checkpoints kept. Once more
              checkpoints than this have been reported, the oldest one (or the
              lowest-scoring one, if ``checkpoint_score_attribute`` is set) is
              deleted. ``None`` keeps every checkpoint. Must be >= 1.
          checkpoint_score_attribute: Key in the metrics dictionary attached to
              a checkpoint that is used to score it when deciding which
              checkpoints to keep. The value must be numerical.
          checkpoint_score_order: Either "max" or "min". With "max"/"min", the
              checkpoints with the highest/lowest values of
              ``checkpoint_score_attribute`` are kept. Defaults to "max".
          checkpoint_frequency: [Deprecated]
          checkpoint_at_end: [Deprecated]
      """

      num_to_keep: Optional[int] = None
      checkpoint_score_attribute: Optional[str] = None
      checkpoint_score_order: Literal["max", "min"] = "max"

      checkpoint_frequency: Union[Optional[int], Literal[_DEPRECATED]] = _DEPRECATED
      checkpoint_at_end: Union[Optional[bool], Literal[_DEPRECATED]] = _DEPRECATED

      def __post_init__(self):
          """Reject deprecated arguments and out-of-range values.

          Raises:
              DeprecationWarning: If ``checkpoint_frequency`` or
                  ``checkpoint_at_end`` was supplied.
              ValueError: If ``num_to_keep`` is not ``None`` and is <= 0, or if
                  ``checkpoint_score_order`` is neither "max" nor "min".
          """
          # Checked in declaration order so the first offending argument is reported.
          for removed_arg in ("checkpoint_frequency", "checkpoint_at_end"):
              if getattr(self, removed_arg) != _DEPRECATED:
                  raise DeprecationWarning(
                      f"`{removed_arg}` is deprecated since it does not "
                      "apply to user-defined training functions. "
                      "Please remove this argument from your CheckpointConfig."
                  )

          keep_limit = self.num_to_keep
          if keep_limit is not None and keep_limit <= 0:
              raise ValueError(
                  f"Received invalid num_to_keep: {keep_limit}. "
                  "Must be None or an integer >= 1."
              )

          allowed_orders = ("max", "min")
          score_order = self.checkpoint_score_order
          if score_order not in allowed_orders:
              raise ValueError(
                  f"Received invalid checkpoint_score_order: {score_order}. "
                  "Must be 'max' or 'min'."
              )
  @dataclass
  class FailureConfig(FailureConfigV1):
      """Failure-handling settings for a training run.

      Args:
          max_failures: Tries to recover a run from training worker errors at
              least this many times, resuming from the latest checkpoint if one
              is present. -1 means infinite recovery retries; 0 disables
              retries. Defaults to 0.
          controller_failure_limit: [DeveloperAPI] The maximum number of
              controller failures to tolerate. -1 means infinite controller
              retries; 0 disables controller retries. Defaults to -1.
      """

      fail_fast: Union[bool, str] = _DEPRECATED
      controller_failure_limit: int = -1

      def __post_init__(self):
          """Raise ``DeprecationWarning`` if ``fail_fast`` was supplied.

          The parent class's ``__post_init__`` is not invoked here.
          """
          fail_fast_supplied = self.fail_fast != _DEPRECATED
          if not fail_fast_supplied:
              return

          raise DeprecationWarning(FAIL_FAST_DEPRECATION_MESSAGE)
  @dataclass
  @PublicAPI(stability="stable")
  class RunConfig:
      """Runtime configuration for training runs.

      Args:
          name: Name of the run. If not provided, a name of the form
              ``ray_train_run-<date_str()>`` is generated.
          storage_path: Path where all results and checkpoints are persisted.
              Can be a local directory or a destination on cloud storage.
              For multi-node training/tuning runs, this must be set to a
              shared storage location (e.g., S3, NFS).
              When ``None``, ``ray.train.constants.DEFAULT_STORAGE_PATH`` is used
              (the local ``~/ray_results`` directory). A ``pathlib.Path`` is
              converted to its POSIX string form.
          storage_filesystem: A custom filesystem to use for storage.
              If this is provided, `storage_path` should be a path with its
              prefix stripped (e.g., `s3://bucket/path` -> `bucket/path`).
          failure_config: Failure mode configuration. A default
              ``FailureConfig()`` is created when omitted.
          checkpoint_config: Checkpointing configuration. A default
              ``CheckpointConfig()`` is created when omitted.
          callbacks: [DeveloperAPI] A list of callbacks that the Ray Train
              controller will invoke during training.
          worker_runtime_env: [DeveloperAPI] Runtime environment configuration
              for all Ray Train worker actors.
          sync_config: [Deprecated]
          verbose: [Deprecated]
          stop: [Deprecated]
          progress_reporter: [Deprecated]
          log_to_file: [Deprecated]
      """

      name: Optional[str] = None
      storage_path: Optional[str] = None
      storage_filesystem: Optional[pyarrow.fs.FileSystem] = None
      failure_config: Optional[FailureConfig] = None
      checkpoint_config: Optional[CheckpointConfig] = None
      callbacks: Optional[List["UserCallback"]] = None
      worker_runtime_env: Optional[Union[dict, RuntimeEnv]] = None

      # Former Ray Tune options; any value other than the sentinel is rejected.
      sync_config: str = _DEPRECATED
      verbose: str = _DEPRECATED
      stop: str = _DEPRECATED
      progress_reporter: str = _DEPRECATED
      log_to_file: str = _DEPRECATED

      def __post_init__(self):
          """Fill in defaults, reject deprecated options, and type-check inputs.

          Raises:
              DeprecationWarning: If any of the deprecated Ray Tune options
                  was supplied.
              ValueError: If a callback is not a ``RayTrainCallback`` instance,
                  or if ``checkpoint_config`` / ``failure_config`` is not an
                  instance of this module's ``CheckpointConfig`` /
                  ``FailureConfig``.
          """
          from ray.train.constants import DEFAULT_STORAGE_PATH

          if self.storage_path is None:
              self.storage_path = DEFAULT_STORAGE_PATH

          self.failure_config = self.failure_config or FailureConfig()
          self.checkpoint_config = self.checkpoint_config or CheckpointConfig()

          if isinstance(self.storage_path, Path):
              self.storage_path = self.storage_path.as_posix()

          run_config_deprecation_message = (
              "`RunConfig({})` is deprecated. This configuration was a "
              "Ray Tune API that did not support Ray Train usage well, "
              "so we are dropping support going forward. "
              "If you heavily rely on these configurations, "
              "you can run Ray Train as a single Ray Tune trial. "
              "See this issue for more context: "
              "https://github.com/ray-project/ray/issues/49454"
          )
          deprecated_options = (
              "sync_config",
              "verbose",
              "stop",
              "progress_reporter",
              "log_to_file",
          )
          # Lazily find the first deprecated option the caller actually set.
          first_offender = next(
              (
                  option
                  for option in deprecated_options
                  if getattr(self, option) != _DEPRECATED
              ),
              None,
          )
          if first_offender is not None:
              raise DeprecationWarning(
                  run_config_deprecation_message.format(first_offender)
              )

          self.name = self.name or f"ray_train_run-{date_str()}"

          if not self.callbacks:
              self.callbacks = []
          if not self.worker_runtime_env:
              self.worker_runtime_env = {}

          from ray.train.v2.api.callback import RayTrainCallback

          if any(not isinstance(cb, RayTrainCallback) for cb in self.callbacks):
              raise ValueError(
                  "All callbacks must be instances of `ray.train.UserCallback`. "
                  "Passing in a Ray Tune callback is no longer supported. "
                  "See this issue for more context: "
                  "https://github.com/ray-project/ray/issues/49454"
              )

          # Checkpoint config is checked before failure config.
          typed_configs = (
              ("CheckpointConfig", self.checkpoint_config, CheckpointConfig),
              ("FailureConfig", self.failure_config, FailureConfig),
          )
          for label, config_value, expected_cls in typed_configs:
              if isinstance(config_value, expected_cls):
                  continue
              raise ValueError(
                  f"Invalid `{label}` type: {config_value.__class__}. "
                  f"Use `ray.train.{label}` instead. "
                  "See this issue for more context: "
                  "https://github.com/ray-project/ray/issues/49454"
              )

      @cached_property
      def storage_context(self) -> StorageContext:
          """``StorageContext`` built from this config's path, name and filesystem.

          Computed on first access and cached on the instance.
          """
          context = StorageContext(
              storage_path=self.storage_path,
              experiment_dir_name=self.name,
              storage_filesystem=self.storage_filesystem,
          )
          return context