config.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695
  1. import logging
  2. import os
  3. import warnings
  4. from collections import Counter, defaultdict
  5. from dataclasses import _MISSING_TYPE, dataclass, fields
  6. from pathlib import Path
  7. from typing import (
  8. TYPE_CHECKING,
  9. Any,
  10. Callable,
  11. Dict,
  12. List,
  13. Mapping,
  14. Optional,
  15. Tuple,
  16. Union,
  17. )
  18. import pyarrow.fs
  19. import ray
  20. from ray._common.utils import RESOURCE_CONSTRAINT_PREFIX
  21. from ray._private.thirdparty.tabulate.tabulate import tabulate
  22. from ray.util.annotations import PublicAPI, RayDeprecationWarning
  23. from ray.widgets import Template, make_table_html_repr
  24. if TYPE_CHECKING:
  25. import ray.tune.progress_reporter
  26. from ray.tune.callback import Callback
  27. from ray.tune.execution.placement_groups import PlacementGroupFactory
  28. from ray.tune.experimental.output import AirVerbosity
  29. from ray.tune.search.sample import Domain
  30. from ray.tune.stopper import Stopper
  31. from ray.tune.utils.log import Verbosity
  32. # Dict[str, List] is to support `tune.grid_search`:
  33. # TODO(sumanthratna/matt): Upstream this to Tune.
  34. SampleRange = Union["Domain", Dict[str, List]]
  35. MAX = "max"
  36. MIN = "min"
  37. _DEPRECATED_VALUE = "DEPRECATED"
  38. logger = logging.getLogger(__name__)
  39. def _repr_dataclass(obj, *, default_values: Optional[Dict[str, Any]] = None) -> str:
  40. """A utility function to elegantly represent dataclasses.
  41. In contrast to the default dataclass `__repr__`, which shows all parameters, this
  42. function only shows parameters with non-default values.
  43. Args:
  44. obj: The dataclass to represent.
  45. default_values: An optional dictionary that maps field names to default values.
  46. Use this parameter to specify default values that are generated dynamically
  47. (e.g., in `__post_init__` or by a `default_factory`). If a default value
  48. isn't specified in `default_values`, then the default value is inferred from
  49. the `dataclass`.
  50. Returns:
  51. A representation of the dataclass.
  52. """
  53. if default_values is None:
  54. default_values = {}
  55. non_default_values = {} # Maps field name to value.
  56. def equals(value, default_value):
  57. # We need to special case None because of a bug in pyarrow:
  58. # https://github.com/apache/arrow/issues/38535
  59. if value is None and default_value is None:
  60. return True
  61. if value is None or default_value is None:
  62. return False
  63. return value == default_value
  64. for field in fields(obj):
  65. value = getattr(obj, field.name)
  66. default_value = default_values.get(field.name, field.default)
  67. is_required = isinstance(field.default, _MISSING_TYPE)
  68. if is_required or not equals(value, default_value):
  69. non_default_values[field.name] = value
  70. string = f"{obj.__class__.__name__}("
  71. string += ", ".join(
  72. f"{name}={value!r}" for name, value in non_default_values.items()
  73. )
  74. string += ")"
  75. return string
  76. @dataclass
  77. @PublicAPI(stability="stable")
  78. class ScalingConfig:
  79. """Configuration for scaling training.
  80. For more details, see :ref:`train_scaling_config`.
  81. Args:
  82. trainer_resources: Resources to allocate for the training coordinator.
  83. The training coordinator launches the worker group and executes
  84. the training function per worker, and this process does NOT require
  85. GPUs. The coordinator is always scheduled on the same node as the
  86. rank 0 worker, so one example use case is to set a minimum amount
  87. of resources (e.g. CPU memory) required by the rank 0 node.
  88. By default, this assigns 1 CPU to the training coordinator.
  89. num_workers: The number of workers (Ray actors) to launch.
  90. Each worker will reserve 1 CPU by default. The number of CPUs
  91. reserved by each worker can be overridden with the
  92. ``resources_per_worker`` argument.
  93. use_gpu: If True, training will be done on GPUs (1 per worker).
  94. Defaults to False. The number of GPUs reserved by each
  95. worker can be overridden with the ``resources_per_worker``
  96. argument.
  97. resources_per_worker: If specified, the resources
  98. defined in this Dict is reserved for each worker.
  99. Define the ``"CPU"`` key (case-sensitive) to
  100. override the number of CPUs used by each worker.
  101. This can also be used to request :ref:`custom resources <custom-resources>`.
  102. placement_strategy: The placement strategy to use for the
  103. placement group of the Ray actors. See :ref:`Placement Group
  104. Strategies <pgroup-strategy>` for the possible options.
  105. accelerator_type: [Experimental] If specified, Ray Train will launch the
  106. training coordinator and workers on the nodes with the specified type
  107. of accelerators.
  108. See :ref:`the available accelerator types <accelerator_types>`.
  109. Ensure that your cluster has instances with the specified accelerator type
  110. or is able to autoscale to fulfill the request.
  111. Example:
  112. .. code-block:: python
  113. from ray.train import ScalingConfig
  114. scaling_config = ScalingConfig(
  115. # Number of distributed workers.
  116. num_workers=2,
  117. # Turn on/off GPU.
  118. use_gpu=True,
  119. # Assign extra CPU/GPU/custom resources per worker.
  120. resources_per_worker={"GPU": 1, "CPU": 1, "memory": 1e9, "custom": 1.0},
  121. # Try to schedule workers on different nodes.
  122. placement_strategy="SPREAD",
  123. )
  124. """
  125. trainer_resources: Optional[Union[Dict, SampleRange]] = None
  126. num_workers: Union[int, SampleRange] = 1
  127. use_gpu: Union[bool, SampleRange] = False
  128. resources_per_worker: Optional[Union[Dict, SampleRange]] = None
  129. placement_strategy: Union[str, SampleRange] = "PACK"
  130. accelerator_type: Optional[str] = None
  131. def __post_init__(self):
  132. if self.resources_per_worker:
  133. if not self.use_gpu and self.num_gpus_per_worker > 0:
  134. raise ValueError(
  135. "`use_gpu` is False but `GPU` was found in "
  136. "`resources_per_worker`. Either set `use_gpu` to True or "
  137. "remove `GPU` from `resources_per_worker."
  138. )
  139. if self.use_gpu and self.num_gpus_per_worker == 0:
  140. raise ValueError(
  141. "`use_gpu` is True but `GPU` is set to 0 in "
  142. "`resources_per_worker`. Either set `use_gpu` to False or "
  143. "request a positive number of `GPU` in "
  144. "`resources_per_worker."
  145. )
  146. def __repr__(self):
  147. return _repr_dataclass(self)
  148. def _repr_html_(self) -> str:
  149. return make_table_html_repr(obj=self, title=type(self).__name__)
  150. def __eq__(self, o: "ScalingConfig") -> bool:
  151. if not isinstance(o, type(self)):
  152. return False
  153. return self.as_placement_group_factory() == o.as_placement_group_factory()
  154. @property
  155. def _resources_per_worker_not_none(self):
  156. if self.resources_per_worker is None:
  157. if self.use_gpu:
  158. # Note that we don't request any CPUs, which avoids possible
  159. # scheduling contention. Generally nodes have many more CPUs than
  160. # GPUs, so not requesting a CPU does not lead to oversubscription.
  161. resources_per_worker = {"GPU": 1}
  162. else:
  163. resources_per_worker = {"CPU": 1}
  164. else:
  165. resources_per_worker = {
  166. k: v for k, v in self.resources_per_worker.items() if v != 0
  167. }
  168. if self.use_gpu:
  169. resources_per_worker.setdefault("GPU", 1)
  170. if self.accelerator_type:
  171. accelerator = f"{RESOURCE_CONSTRAINT_PREFIX}{self.accelerator_type}"
  172. resources_per_worker.setdefault(accelerator, 0.001)
  173. return resources_per_worker
  174. @property
  175. def _trainer_resources_not_none(self):
  176. if self.trainer_resources is None:
  177. if self.num_workers:
  178. # For Google Colab, don't allocate resources to the base Trainer.
  179. # Colab only has 2 CPUs, and because of this resource scarcity,
  180. # we have to be careful on where we allocate resources. Since Colab
  181. # is not distributed, the concern about many parallel Ray Tune trials
  182. # leading to all Trainers being scheduled on the head node if we set
  183. # `trainer_resources` to 0 is no longer applicable.
  184. try:
  185. import google.colab # noqa: F401
  186. trainer_num_cpus = 0
  187. except ImportError:
  188. trainer_num_cpus = 1
  189. else:
  190. # If there are no additional workers, then always reserve 1 CPU for
  191. # the Trainer.
  192. trainer_num_cpus = 1
  193. trainer_resources = {"CPU": trainer_num_cpus}
  194. else:
  195. trainer_resources = {
  196. k: v for k, v in self.trainer_resources.items() if v != 0
  197. }
  198. return trainer_resources
  199. @property
  200. def total_resources(self):
  201. """Map of total resources required for the trainer."""
  202. total_resource_map = defaultdict(float, self._trainer_resources_not_none)
  203. for k, value in self._resources_per_worker_not_none.items():
  204. total_resource_map[k] += value * self.num_workers
  205. return dict(total_resource_map)
  206. @property
  207. def num_cpus_per_worker(self):
  208. """The number of CPUs to set per worker."""
  209. return self._resources_per_worker_not_none.get("CPU", 0)
  210. @property
  211. def num_gpus_per_worker(self):
  212. """The number of GPUs to set per worker."""
  213. return self._resources_per_worker_not_none.get("GPU", 0)
  214. @property
  215. def additional_resources_per_worker(self):
  216. """Resources per worker, not including CPU or GPU resources."""
  217. return {
  218. k: v
  219. for k, v in self._resources_per_worker_not_none.items()
  220. if k not in ["CPU", "GPU"]
  221. }
  222. def as_placement_group_factory(self) -> "PlacementGroupFactory":
  223. """Returns a PlacementGroupFactory to specify resources for Tune."""
  224. from ray.tune.execution.placement_groups import PlacementGroupFactory
  225. trainer_bundle = self._trainer_resources_not_none
  226. worker_bundle = self._resources_per_worker_not_none
  227. # Colocate Trainer and rank0 worker by merging their bundles
  228. # Note: This empty bundle is required so that the Tune actor manager schedules
  229. # the Trainable onto the combined bundle while taking none of its resources,
  230. # rather than a non-empty head bundle.
  231. combined_bundle = dict(Counter(trainer_bundle) + Counter(worker_bundle))
  232. bundles = [{}, combined_bundle] + [worker_bundle] * (self.num_workers - 1)
  233. return PlacementGroupFactory(bundles, strategy=self.placement_strategy)
  234. @classmethod
  235. def from_placement_group_factory(
  236. cls, pgf: "PlacementGroupFactory"
  237. ) -> "ScalingConfig":
  238. """Create a ScalingConfig from a Tune's PlacementGroupFactory
  239. Note that this is only needed for ResourceChangingScheduler, which
  240. modifies a trial's PlacementGroupFactory but doesn't propagate
  241. the changes to ScalingConfig. TrainTrainable needs to reconstruct
  242. a ScalingConfig from on the trial's PlacementGroupFactory.
  243. """
  244. # pgf.bundles = [{trainer + worker}, {worker}, ..., {worker}]
  245. num_workers = len(pgf.bundles)
  246. combined_resources = pgf.bundles[0]
  247. resources_per_worker = pgf.bundles[-1]
  248. use_gpu = bool(resources_per_worker.get("GPU", False))
  249. placement_strategy = pgf.strategy
  250. # In `as_placement_group_factory`, we merged the trainer resource into the
  251. # first worker resources bundle. We need to calculate the resources diff to
  252. # get the trainer resources.
  253. # Note: If there's only one worker, we won't be able to calculate the diff.
  254. # We'll have empty trainer bundle and assign all resources to the worker.
  255. trainer_resources = dict(
  256. Counter(combined_resources) - Counter(resources_per_worker)
  257. )
  258. return ScalingConfig(
  259. trainer_resources=trainer_resources,
  260. num_workers=num_workers,
  261. use_gpu=use_gpu,
  262. resources_per_worker=resources_per_worker,
  263. placement_strategy=placement_strategy,
  264. )
  265. @dataclass
  266. @PublicAPI(stability="stable")
  267. class FailureConfig:
  268. """Configuration related to failure handling of each training/tuning run.
  269. Args:
  270. max_failures: Tries to recover a run at least this many times.
  271. Will recover from the latest checkpoint if present.
  272. Setting to -1 will lead to infinite recovery retries.
  273. Setting to 0 will disable retries. Defaults to 0.
  274. fail_fast: Whether to fail upon the first error.
  275. If fail_fast='raise' provided, the original error during training will be
  276. immediately raised. fail_fast='raise' can easily leak resources and
  277. should be used with caution.
  278. """
  279. max_failures: int = 0
  280. fail_fast: Union[bool, str] = False
  281. def __post_init__(self):
  282. # Same check as in TuneController
  283. if not (isinstance(self.fail_fast, bool) or self.fail_fast.upper() == "RAISE"):
  284. raise ValueError(
  285. "fail_fast must be one of {bool, 'raise'}. " f"Got {self.fail_fast}."
  286. )
  287. # Same check as in tune.run
  288. if self.fail_fast and self.max_failures != 0:
  289. raise ValueError(
  290. f"max_failures must be 0 if fail_fast={repr(self.fail_fast)}."
  291. )
  292. def __repr__(self):
  293. return _repr_dataclass(self)
  294. def _repr_html_(self):
  295. return Template("scrollableTable.html.j2").render(
  296. table=tabulate(
  297. {
  298. "Setting": ["Max failures", "Fail fast"],
  299. "Value": [self.max_failures, self.fail_fast],
  300. },
  301. tablefmt="html",
  302. showindex=False,
  303. headers="keys",
  304. ),
  305. max_height="none",
  306. )
  307. @dataclass
  308. @PublicAPI(stability="stable")
  309. class CheckpointConfig:
  310. """Configurable parameters for defining the checkpointing strategy.
  311. Default behavior is to persist all checkpoints to disk. If
  312. ``num_to_keep`` is set, the default retention policy is to keep the
  313. checkpoints with maximum timestamp, i.e. the most recent checkpoints.
  314. Args:
  315. num_to_keep: The number of checkpoints to keep
  316. on disk for this run. If a checkpoint is persisted to disk after
  317. there are already this many checkpoints, then an existing
  318. checkpoint will be deleted. If this is ``None`` then checkpoints
  319. will not be deleted. Must be >= 1.
  320. checkpoint_score_attribute: The attribute that will be used to
  321. score checkpoints to determine which checkpoints should be kept
  322. on disk when there are greater than ``num_to_keep`` checkpoints.
  323. This attribute must be a key from the checkpoint
  324. dictionary which has a numerical value. Per default, the last
  325. checkpoints will be kept.
  326. checkpoint_score_order: Either "max" or "min".
  327. If "max", then checkpoints with highest values of
  328. ``checkpoint_score_attribute`` will be kept.
  329. If "min", then checkpoints with lowest values of
  330. ``checkpoint_score_attribute`` will be kept.
  331. checkpoint_frequency: Number of iterations between checkpoints. If 0
  332. this will disable checkpointing.
  333. Please note that most trainers will still save one checkpoint at
  334. the end of training.
  335. This attribute is only supported
  336. by trainers that don't take in custom training loops.
  337. checkpoint_at_end: If True, will save a checkpoint at the end of training.
  338. This attribute is only supported by trainers that don't take in
  339. custom training loops. Defaults to True for trainers that support it
  340. and False for generic function trainables.
  341. _checkpoint_keep_all_ranks: This experimental config is deprecated.
  342. This behavior is now controlled by reporting `checkpoint=None`
  343. in the workers that shouldn't persist a checkpoint.
  344. For example, if you only want the rank 0 worker to persist a checkpoint
  345. (e.g., in standard data parallel training), then you should save and
  346. report a checkpoint if `ray.train.get_context().get_world_rank() == 0`
  347. and `None` otherwise.
  348. _checkpoint_upload_from_workers: This experimental config is deprecated.
  349. Uploading checkpoint directly from the worker is now the default behavior.
  350. """
  351. num_to_keep: Optional[int] = None
  352. checkpoint_score_attribute: Optional[str] = None
  353. checkpoint_score_order: Optional[str] = MAX
  354. checkpoint_frequency: Optional[int] = 0
  355. checkpoint_at_end: Optional[bool] = None
  356. _checkpoint_keep_all_ranks: Optional[bool] = _DEPRECATED_VALUE
  357. _checkpoint_upload_from_workers: Optional[bool] = _DEPRECATED_VALUE
  358. def __post_init__(self):
  359. if self._checkpoint_keep_all_ranks != _DEPRECATED_VALUE:
  360. raise DeprecationWarning(
  361. "The experimental `_checkpoint_keep_all_ranks` config is deprecated. "
  362. "This behavior is now controlled by reporting `checkpoint=None` "
  363. "in the workers that shouldn't persist a checkpoint. "
  364. "For example, if you only want the rank 0 worker to persist a "
  365. "checkpoint (e.g., in standard data parallel training), "
  366. "then you should save and report a checkpoint if "
  367. "`ray.train.get_context().get_world_rank() == 0` "
  368. "and `None` otherwise."
  369. )
  370. if self._checkpoint_upload_from_workers != _DEPRECATED_VALUE:
  371. raise DeprecationWarning(
  372. "The experimental `_checkpoint_upload_from_workers` config is "
  373. "deprecated. Uploading checkpoint directly from the worker is "
  374. "now the default behavior."
  375. )
  376. if self.num_to_keep is not None and self.num_to_keep <= 0:
  377. raise ValueError(
  378. f"Received invalid num_to_keep: "
  379. f"{self.num_to_keep}. "
  380. f"Must be None or an integer >= 1."
  381. )
  382. if self.checkpoint_score_order not in (MAX, MIN):
  383. raise ValueError(
  384. f"checkpoint_score_order must be either " f'"{MAX}" or "{MIN}".'
  385. )
  386. if self.checkpoint_frequency < 0:
  387. raise ValueError(
  388. f"checkpoint_frequency must be >=0, got {self.checkpoint_frequency}"
  389. )
  390. def __repr__(self):
  391. return _repr_dataclass(self)
  392. def _repr_html_(self) -> str:
  393. if self.num_to_keep is None:
  394. num_to_keep_repr = "All"
  395. else:
  396. num_to_keep_repr = self.num_to_keep
  397. if self.checkpoint_score_attribute is None:
  398. checkpoint_score_attribute_repr = "Most recent"
  399. else:
  400. checkpoint_score_attribute_repr = self.checkpoint_score_attribute
  401. if self.checkpoint_at_end is None:
  402. checkpoint_at_end_repr = ""
  403. else:
  404. checkpoint_at_end_repr = self.checkpoint_at_end
  405. return Template("scrollableTable.html.j2").render(
  406. table=tabulate(
  407. {
  408. "Setting": [
  409. "Number of checkpoints to keep",
  410. "Checkpoint score attribute",
  411. "Checkpoint score order",
  412. "Checkpoint frequency",
  413. "Checkpoint at end",
  414. ],
  415. "Value": [
  416. num_to_keep_repr,
  417. checkpoint_score_attribute_repr,
  418. self.checkpoint_score_order,
  419. self.checkpoint_frequency,
  420. checkpoint_at_end_repr,
  421. ],
  422. },
  423. tablefmt="html",
  424. showindex=False,
  425. headers="keys",
  426. ),
  427. max_height="none",
  428. )
  429. @property
  430. def _tune_legacy_checkpoint_score_attr(self) -> Optional[str]:
  431. """Same as ``checkpoint_score_attr`` in ``tune.run``.
  432. Only used for Legacy API compatibility.
  433. """
  434. if self.checkpoint_score_attribute is None:
  435. return self.checkpoint_score_attribute
  436. prefix = ""
  437. if self.checkpoint_score_order == MIN:
  438. prefix = "min-"
  439. return f"{prefix}{self.checkpoint_score_attribute}"
  440. @dataclass
  441. @PublicAPI(stability="stable")
  442. class RunConfig:
  443. """Runtime configuration for training and tuning runs.
  444. Upon resuming from a training or tuning run checkpoint,
  445. Ray Train/Tune will automatically apply the RunConfig from
  446. the previously checkpointed run.
  447. Args:
  448. name: Name of the trial or experiment. If not provided, will be deduced
  449. from the Trainable.
  450. storage_path: [Beta] Path where all results and checkpoints are persisted.
  451. Can be a local directory or a destination on cloud storage.
  452. For multi-node training/tuning runs, this must be set to a
  453. shared storage location (e.g., S3, NFS).
  454. This defaults to the local ``~/ray_results`` directory.
  455. storage_filesystem: [Beta] A custom filesystem to use for storage.
  456. If this is provided, `storage_path` should be a path with its
  457. prefix stripped (e.g., `s3://bucket/path` -> `bucket/path`).
  458. failure_config: Failure mode configuration.
  459. checkpoint_config: Checkpointing configuration.
  460. sync_config: Configuration object for syncing. See train.SyncConfig.
  461. verbose: 0, 1, or 2. Verbosity mode.
  462. 0 = silent, 1 = default, 2 = verbose. Defaults to 1.
  463. If the ``RAY_AIR_NEW_OUTPUT=1`` environment variable is set,
  464. uses the old verbosity settings:
  465. 0 = silent, 1 = only status updates, 2 = status and brief
  466. results, 3 = status and detailed results.
  467. stop: Stop conditions to consider. Refer to ray.tune.stopper.Stopper
  468. for more info. Stoppers should be serializable.
  469. callbacks: [DeveloperAPI] Callbacks to invoke.
  470. Refer to ray.tune.callback.Callback for more info.
  471. Callbacks should be serializable.
  472. Currently only stateless callbacks are supported for resumed runs.
  473. (any state of the callback will not be checkpointed by Tune
  474. and thus will not take effect in resumed runs).
  475. progress_reporter: [DeveloperAPI] Progress reporter for reporting
  476. intermediate experiment progress. Defaults to CLIReporter if
  477. running in command-line, or JupyterNotebookReporter if running in
  478. a Jupyter notebook.
  479. log_to_file: [DeveloperAPI] Log stdout and stderr to files in
  480. trial directories. If this is `False` (default), no files
  481. are written. If `true`, outputs are written to `trialdir/stdout`
  482. and `trialdir/stderr`, respectively. If this is a single string,
  483. this is interpreted as a file relative to the trialdir, to which
  484. both streams are written. If this is a Sequence (e.g. a Tuple),
  485. it has to have length 2 and the elements indicate the files to
  486. which stdout and stderr are written, respectively.
  487. """
  488. name: Optional[str] = None
  489. storage_path: Optional[str] = None
  490. storage_filesystem: Optional[pyarrow.fs.FileSystem] = None
  491. failure_config: Optional[FailureConfig] = None
  492. checkpoint_config: Optional[CheckpointConfig] = None
  493. sync_config: Optional["ray.train.SyncConfig"] = None
  494. verbose: Optional[Union[int, "AirVerbosity", "Verbosity"]] = None
  495. stop: Optional[Union[Mapping, "Stopper", Callable[[str, Mapping], bool]]] = None
  496. callbacks: Optional[List["Callback"]] = None
  497. progress_reporter: Optional["ray.tune.progress_reporter.ProgressReporter"] = None
  498. log_to_file: Union[bool, str, Tuple[str, str]] = False
  499. # Deprecated
  500. local_dir: Optional[str] = None
  501. def __post_init__(self):
  502. from ray.train import SyncConfig
  503. from ray.train.constants import DEFAULT_STORAGE_PATH
  504. from ray.tune.experimental.output import AirVerbosity, get_air_verbosity
  505. if self.local_dir is not None:
  506. raise DeprecationWarning(
  507. "The `RunConfig(local_dir)` argument is deprecated. "
  508. "You should set the `RunConfig(storage_path)` instead."
  509. "See the docs: https://docs.ray.io/en/latest/train/user-guides/"
  510. "persistent-storage.html#setting-the-local-staging-directory"
  511. )
  512. if self.storage_path is None:
  513. self.storage_path = DEFAULT_STORAGE_PATH
  514. # TODO(justinvyu): [Deprecated]
  515. ray_storage_uri: Optional[str] = os.environ.get("RAY_STORAGE")
  516. if ray_storage_uri is not None:
  517. logger.info(
  518. "Using configured Ray Storage URI as the `storage_path`: "
  519. f"{ray_storage_uri}"
  520. )
  521. warnings.warn(
  522. "The `RAY_STORAGE` environment variable is deprecated. "
  523. "Please use `RunConfig(storage_path)` instead.",
  524. RayDeprecationWarning,
  525. stacklevel=2,
  526. )
  527. self.storage_path = ray_storage_uri
  528. if not self.failure_config:
  529. self.failure_config = FailureConfig()
  530. if not self.sync_config:
  531. self.sync_config = SyncConfig()
  532. if not self.checkpoint_config:
  533. self.checkpoint_config = CheckpointConfig()
  534. # Save the original verbose value to check for deprecations
  535. self._verbose = self.verbose
  536. if self.verbose is None:
  537. # Default `verbose` value. For new output engine,
  538. # this is AirVerbosity.DEFAULT.
  539. # For old output engine, this is Verbosity.V3_TRIAL_DETAILS
  540. # Todo (krfricke): Currently uses number to pass test_configs::test_repr
  541. self.verbose = get_air_verbosity(AirVerbosity.DEFAULT) or 3
  542. if isinstance(self.storage_path, Path):
  543. self.storage_path = self.storage_path.as_posix()
  544. def __repr__(self):
  545. from ray.train import SyncConfig
  546. return _repr_dataclass(
  547. self,
  548. default_values={
  549. "failure_config": FailureConfig(),
  550. "sync_config": SyncConfig(),
  551. "checkpoint_config": CheckpointConfig(),
  552. },
  553. )
  554. def _repr_html_(self) -> str:
  555. reprs = []
  556. if self.failure_config is not None:
  557. reprs.append(
  558. Template("title_data_mini.html.j2").render(
  559. title="Failure Config", data=self.failure_config._repr_html_()
  560. )
  561. )
  562. if self.sync_config is not None:
  563. reprs.append(
  564. Template("title_data_mini.html.j2").render(
  565. title="Sync Config", data=self.sync_config._repr_html_()
  566. )
  567. )
  568. if self.checkpoint_config is not None:
  569. reprs.append(
  570. Template("title_data_mini.html.j2").render(
  571. title="Checkpoint Config", data=self.checkpoint_config._repr_html_()
  572. )
  573. )
  574. # Create a divider between each displayed repr
  575. subconfigs = [Template("divider.html.j2").render()] * (2 * len(reprs) - 1)
  576. subconfigs[::2] = reprs
  577. settings = Template("scrollableTable.html.j2").render(
  578. table=tabulate(
  579. {
  580. "Name": self.name,
  581. "Local results directory": self.local_dir,
  582. "Verbosity": self.verbose,
  583. "Log to file": self.log_to_file,
  584. }.items(),
  585. tablefmt="html",
  586. headers=["Setting", "Value"],
  587. showindex=False,
  588. ),
  589. max_height="300px",
  590. )
  591. return Template("title_data.html.j2").render(
  592. title="RunConfig",
  593. data=Template("run_config.html.j2").render(
  594. subconfigs=subconfigs,
  595. settings=settings,
  596. ),
  597. )