base_trainer.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944
  1. import abc
  2. import copy
  3. import inspect
  4. import json
  5. import logging
  6. import os
  7. import warnings
  8. from functools import partial
  9. from pathlib import Path
  10. from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
  11. import pyarrow.fs
  12. import ray
  13. import ray.cloudpickle as pickle
  14. from ray._common.usage import usage_lib
  15. from ray._private.dict import deep_update
  16. from ray.air._internal import usage as air_usage
  17. from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated
  18. from ray.air._internal.usage import AirEntrypoint
  19. from ray.air.config import RunConfig, ScalingConfig
  20. from ray.air.result import Result
  21. from ray.train import Checkpoint
  22. from ray.train._internal.session import get_session
  23. from ray.train._internal.storage import (
  24. StorageContext,
  25. _exists_at_fs_path,
  26. get_fs_and_path,
  27. )
  28. from ray.train.constants import (
  29. V2_MIGRATION_GUIDE_MESSAGE,
  30. _v2_migration_warnings_enabled,
  31. )
  32. from ray.train.context import _GET_METADATA_DEPRECATION_MESSAGE
  33. from ray.train.utils import _log_deprecation_warning
  34. from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI
  35. if TYPE_CHECKING:
  36. from ray.data import Dataset
  37. from ray.tune import Trainable
  38. _TRAINER_PKL = "trainer.pkl"
  39. # A type representing either a ray.data.Dataset or a function that returns a
  40. # ray.data.Dataset and accepts no arguments.
  41. GenDataset = Union["Dataset", Callable[[], "Dataset"]]
  42. logger = logging.getLogger(__name__)
  43. PREPROCESSOR_DEPRECATION_MESSAGE = (
  44. "The `preprocessor` argument to Trainers is deprecated as of Ray 2.7. "
  45. "Instead, use the Preprocessor `fit` and `transform` APIs directly on the Ray "
  46. "Dataset. For any state that needs to be saved to the trained checkpoint, pass it "
  47. "in using the `metadata` argument of the `Trainer`. "
  48. "For a full example, see "
  49. "https://docs.ray.io/en/master/train/user-guides/data-loading-preprocessing.html#preprocessing-structured-data " # noqa:E501
  50. )
  51. _TRAINER_RESTORE_DEPRECATION_WARNING = (
  52. "The `restore` and `can_restore` APIs are deprecated and "
  53. f"will be removed in a future release. {V2_MIGRATION_GUIDE_MESSAGE}"
  54. )
  55. _RESUME_FROM_CHECKPOINT_DEPRECATION_WARNING = (
  56. "`resume_from_checkpoint` is deprecated and will be removed in an upcoming "
  57. f"release. {V2_MIGRATION_GUIDE_MESSAGE}"
  58. )
  59. @PublicAPI(stability="beta")
  60. class TrainingFailedError(RuntimeError):
  61. """An error indicating that training has failed."""
  62. _RESTORE_MSG = (
  63. "The Ray Train run failed. Please inspect the previous error messages for a "
  64. "cause. After fixing the issue (assuming that the error is not caused by "
  65. "your own application logic, but rather an error such as OOM), you can restart "
  66. "the run from scratch or continue this run.\n"
  67. "To continue this run, you can use: "
  68. '`trainer = {trainer_cls_name}.restore("{path}")`.'
  69. )
  70. _FAILURE_CONFIG_MSG = (
  71. "To start a new run that will retry on training failures, set "
  72. "`train.RunConfig(failure_config=train.FailureConfig(max_failures))` "
  73. "in the Trainer's `run_config` with `max_failures > 0`, or `max_failures = -1` "
  74. "for unlimited retries."
  75. )
  76. def _train_coordinator_fn(
  77. config: dict, trainer_cls: Type["BaseTrainer"], metadata: dict
  78. ):
  79. """This is the function that defines the logic of the Ray Train coordinator.
  80. This is responsible for setting up a remote instance of the `trainer_cls`
  81. (a different instance than the one calling `trainer.fit` on the driver!)
  82. and running the training loop.
  83. """
  84. assert metadata is not None, metadata
  85. # Propagate user metadata from the Trainer constructor.
  86. get_session().metadata = metadata
  87. # config already contains merged values.
  88. # Instantiate new Trainer in Trainable.
  89. trainer = trainer_cls(**config)
  90. # Get the checkpoint from Tune and pass it to workers later on.
  91. checkpoint = ray.tune.get_checkpoint()
  92. if checkpoint:
  93. # Set `starting_checkpoint` for auto-recovery fault-tolerance
  94. # as well as manual restoration.
  95. trainer.starting_checkpoint = checkpoint
  96. # else: Train will restore from the user-provided
  97. # `resume_from_checkpoint` == `starting_checkpoint`.
  98. # Evaluate datasets if they are wrapped in a factory.
  99. trainer.datasets = {
  100. k: d() if callable(d) else d for k, d in trainer.datasets.items()
  101. }
  102. trainer.setup()
  103. trainer.training_loop()
  104. @DeveloperAPI
  105. class BaseTrainer(abc.ABC):
  106. """Defines interface for distributed training on Ray.
  107. Note: The base ``BaseTrainer`` class cannot be instantiated directly. Only
  108. one of its subclasses can be used.
  109. Note to developers: If a new trainer is added, please update
  110. `air/_internal/usage.py`.
  111. **How does a trainer work?**
  112. - First, initialize the Trainer. The initialization runs locally,
  113. so heavyweight setup should not be done in ``__init__``.
  114. - Then, when you call ``trainer.fit()``, the Trainer is serialized
  115. and copied to a remote Ray actor. The following methods are then
  116. called in sequence on the remote actor.
  117. - ``trainer.setup()``: Any heavyweight Trainer setup should be
  118. specified here.
  119. - ``trainer.training_loop()``: Executes the main training logic.
  120. - Calling ``trainer.fit()`` will return a ``ray.result.Result``
  121. object where you can access metrics from your training run, as well
  122. as any checkpoints that may have been saved.
  123. **How do I create a new Trainer?**
  124. Subclass ``ray.train.trainer.BaseTrainer``, and override the ``training_loop``
  125. method, and optionally ``setup``.
  126. .. testcode::
  127. :skipif: True
  128. import torch
  129. from ray.train.trainer import BaseTrainer
  130. from ray import train, tune
  131. class MyPytorchTrainer(BaseTrainer):
  132. def setup(self):
  133. self.model = torch.nn.Linear(1, 1)
  134. self.optimizer = torch.optim.SGD(
  135. self.model.parameters(), lr=0.1)
  136. def training_loop(self):
  137. # You can access any Trainer attributes directly in this method.
  138. # self.datasets["train"] has already been
  139. dataset = self.datasets["train"]
  140. torch_ds = dataset.iter_torch_batches(dtypes=torch.float)
  141. loss_fn = torch.nn.MSELoss()
  142. for epoch_idx in range(10):
  143. loss = 0
  144. num_batches = 0
  145. torch_ds = dataset.iter_torch_batches(
  146. dtypes=torch.float, batch_size=2
  147. )
  148. for batch in torch_ds:
  149. X = torch.unsqueeze(batch["x"], 1)
  150. y = torch.unsqueeze(batch["y"], 1)
  151. # Compute prediction error
  152. pred = self.model(X)
  153. batch_loss = loss_fn(pred, y)
  154. # Backpropagation
  155. self.optimizer.zero_grad()
  156. batch_loss.backward()
  157. self.optimizer.step()
  158. loss += batch_loss.item()
  159. num_batches += 1
  160. loss /= num_batches
  161. # Use Tune functions to report intermediate
  162. # results.
  163. train.report({"loss": loss, "epoch": epoch_idx})
  164. # Initialize the Trainer, and call Trainer.fit()
  165. import ray
  166. train_dataset = ray.data.from_items(
  167. [{"x": i, "y": i} for i in range(10)])
  168. my_trainer = MyPytorchTrainer(datasets={"train": train_dataset})
  169. result = my_trainer.fit()
  170. Args:
  171. scaling_config: Configuration for how to scale training.
  172. run_config: Configuration for the execution of the training run.
  173. datasets: Any Datasets to use for training. Use the key "train"
  174. to denote which dataset is the training dataset.
  175. metadata: Dict that should be made available via
  176. `train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
  177. for checkpoints saved from this Trainer. Must be JSON-serializable.
  178. resume_from_checkpoint: A checkpoint to resume training from.
  179. """
  180. _scaling_config_allowed_keys: List[str] = [
  181. "trainer_resources",
  182. ]
  183. _handles_checkpoint_freq: bool = False
  184. _handles_checkpoint_at_end: bool = False
  185. # fields to propagate to Tuner param_space.
  186. # See `BaseTrainer._extract_fields_for_tuner_param_space` for more details.
  187. _fields_for_tuner_param_space = []
  188. def __init__(
  189. self,
  190. *,
  191. scaling_config: Optional[ScalingConfig] = None,
  192. run_config: Optional[RunConfig] = None,
  193. datasets: Optional[Dict[str, GenDataset]] = None,
  194. metadata: Optional[Dict[str, Any]] = None,
  195. resume_from_checkpoint: Optional[Checkpoint] = None,
  196. ):
  197. self.scaling_config = (
  198. scaling_config if scaling_config is not None else ScalingConfig()
  199. )
  200. self.run_config = (
  201. copy.copy(run_config) if run_config is not None else RunConfig()
  202. )
  203. self.metadata = metadata
  204. self.datasets = datasets if datasets is not None else {}
  205. self.starting_checkpoint = resume_from_checkpoint
  206. if _v2_migration_warnings_enabled():
  207. if metadata is not None:
  208. _log_deprecation_warning(_GET_METADATA_DEPRECATION_MESSAGE)
  209. if resume_from_checkpoint is not None:
  210. _log_deprecation_warning(_RESUME_FROM_CHECKPOINT_DEPRECATION_WARNING)
  211. # These attributes should only be set through `BaseTrainer.restore`
  212. self._restore_path = None
  213. self._restore_storage_filesystem = None
  214. self._validate_attributes()
  215. usage_lib.record_library_usage("train")
  216. air_usage.tag_air_trainer(self)
  217. @classmethod
  218. @Deprecated(message=_TRAINER_RESTORE_DEPRECATION_WARNING)
  219. def restore(
  220. cls: Type["BaseTrainer"],
  221. path: Union[str, os.PathLike],
  222. storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
  223. datasets: Optional[Dict[str, GenDataset]] = None,
  224. scaling_config: Optional[ScalingConfig] = None,
  225. **kwargs,
  226. ) -> "BaseTrainer":
  227. """Restores a Train experiment from a previously interrupted/failed run.
  228. Restore should be used for experiment-level fault tolerance in the event
  229. that the head node crashes (e.g., OOM or some other runtime error) or the
  230. entire cluster goes down (e.g., network error affecting all nodes).
  231. A run that has already completed successfully will not be resumed from this API.
  232. To continue training from a successful run, launch a new run with the
  233. ``<Framework>Trainer(resume_from_checkpoint)`` API instead, passing in a
  234. checkpoint from the previous run to start with.
  235. .. note::
  236. Restoring an experiment from a path that's pointing to a *different*
  237. location than the original experiment path is supported. However, Ray Train
  238. assumes that the full experiment directory is available
  239. (including checkpoints) so that it's possible to resume trials from their
  240. latest state.
  241. For example, if the original experiment path was run locally, then the
  242. results are uploaded to cloud storage, Ray Train expects the full contents
  243. to be available in cloud storage if attempting to resume
  244. via ``<Framework>Trainer.restore("s3://...")``. The restored run will
  245. continue writing results to the same cloud storage location.
  246. The following example can be paired with implementing job retry using
  247. :ref:`Ray Jobs <jobs-overview>` to produce a Train experiment that will
  248. attempt to resume on both experiment-level and trial-level failures:
  249. .. testcode::
  250. :skipif: True
  251. import os
  252. import ray
  253. from ray import train
  254. from ray.train.trainer import BaseTrainer
  255. experiment_name = "unique_experiment_name"
  256. storage_path = os.path.expanduser("~/ray_results")
  257. experiment_dir = os.path.join(storage_path, experiment_name)
  258. # Define some dummy inputs for demonstration purposes
  259. datasets = {"train": ray.data.from_items([{"a": i} for i in range(10)])}
  260. class CustomTrainer(BaseTrainer):
  261. def training_loop(self):
  262. pass
  263. if CustomTrainer.can_restore(experiment_dir):
  264. trainer = CustomTrainer.restore(
  265. experiment_dir, datasets=datasets
  266. )
  267. else:
  268. trainer = CustomTrainer(
  269. datasets=datasets,
  270. run_config=train.RunConfig(
  271. name=experiment_name,
  272. storage_path=storage_path,
  273. # Tip: You can also enable retries on failure for
  274. # worker-level fault tolerance
  275. failure_config=train.FailureConfig(max_failures=3),
  276. ),
  277. )
  278. result = trainer.fit()
  279. Args:
  280. path: The path to the experiment directory of the training run to restore.
  281. This can be a local path or a remote URI if the experiment was
  282. uploaded to the cloud.
  283. storage_filesystem: Custom ``pyarrow.fs.FileSystem``
  284. corresponding to the ``path``. This may be necessary if the original
  285. experiment passed in a custom filesystem.
  286. datasets: Re-specified datasets used in the original training run.
  287. This must include all the datasets that were passed in the
  288. original trainer constructor.
  289. scaling_config: Optionally re-specified scaling config. This can be
  290. modified to be different from the original spec.
  291. **kwargs: Other optionally re-specified arguments, passed in by subclasses.
  292. Raises:
  293. ValueError: If all datasets were not re-supplied on restore.
  294. Returns:
  295. BaseTrainer: A restored instance of the class that is calling this method.
  296. """
  297. if _v2_migration_warnings_enabled():
  298. _log_deprecation_warning(_TRAINER_RESTORE_DEPRECATION_WARNING)
  299. if not cls.can_restore(path, storage_filesystem):
  300. raise ValueError(
  301. f"Invalid restore path: {path}. Make sure that this path exists and "
  302. "is the experiment directory that results from a call to "
  303. "`trainer.fit()`."
  304. )
  305. fs, fs_path = get_fs_and_path(path, storage_filesystem)
  306. trainer_pkl_path = Path(fs_path, _TRAINER_PKL).as_posix()
  307. with fs.open_input_file(trainer_pkl_path) as f:
  308. trainer_cls, param_dict = pickle.loads(f.readall())
  309. if trainer_cls is not cls:
  310. warnings.warn(
  311. f"Invalid trainer type. You are attempting to restore a trainer of type"
  312. f" {trainer_cls} with `{cls.__name__}.restore`, "
  313. "which will most likely fail. "
  314. f"Use `{trainer_cls.__name__}.restore` instead."
  315. )
  316. original_datasets = param_dict.pop("datasets", {})
  317. if original_datasets and not datasets:
  318. raise ValueError(
  319. "The following datasets need to be provided again on restore: "
  320. f"{list(original_datasets.keys())}\n"
  321. f"Use {cls.__name__}.restore(..., datasets=datasets) "
  322. "with the datasets that were provided to the original trainer."
  323. )
  324. datasets = datasets or {}
  325. if set(original_datasets) != set(datasets):
  326. raise ValueError(
  327. "The provided datasets don't match the original dataset keys.\n"
  328. f" Expected datasets for the keys: {list(original_datasets.keys())}\n"
  329. f" Actual datasets provided: {list(datasets.keys())}"
  330. )
  331. param_dict["datasets"] = datasets
  332. if scaling_config:
  333. param_dict["scaling_config"] = scaling_config
  334. for param_name, val in kwargs.items():
  335. # Overwrite the old value if something is passed into restore
  336. if val is not None:
  337. param_dict[param_name] = val
  338. try:
  339. trainer = cls(**param_dict)
  340. except Exception as e:
  341. raise ValueError(
  342. "Trainer restoration failed (see above for the stack trace). "
  343. "Make sure that you use the right trainer class to restore: "
  344. f"`{cls.__name__}.restore`\n"
  345. ) from e
  346. trainer._restore_path = path
  347. trainer._restore_storage_filesystem = storage_filesystem
  348. return trainer
  349. @classmethod
  350. @Deprecated(
  351. message=_TRAINER_RESTORE_DEPRECATION_WARNING,
  352. warning=_v2_migration_warnings_enabled(),
  353. )
  354. def can_restore(
  355. cls: Type["BaseTrainer"],
  356. path: Union[str, os.PathLike],
  357. storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
  358. ) -> bool:
  359. """Checks whether a given directory contains a restorable Train experiment.
  360. Args:
  361. path: The path to the experiment directory of the Train experiment.
  362. This can be either a local directory (e.g., ~/ray_results/exp_name)
  363. or a remote URI (e.g., s3://bucket/exp_name).
  364. Returns:
  365. bool: Whether this path exists and contains the trainer state to resume from
  366. """
  367. if _v2_migration_warnings_enabled():
  368. _log_deprecation_warning(_TRAINER_RESTORE_DEPRECATION_WARNING)
  369. fs, fs_path = get_fs_and_path(path, storage_filesystem)
  370. trainer_pkl_path = Path(fs_path, _TRAINER_PKL).as_posix()
  371. return _exists_at_fs_path(fs, trainer_pkl_path)
  372. def __repr__(self):
  373. # A dictionary that maps parameters to their default values.
  374. default_values: Dict[str, Any] = {
  375. "scaling_config": ScalingConfig(),
  376. "run_config": RunConfig(),
  377. "datasets": {},
  378. "starting_checkpoint": None,
  379. }
  380. non_default_arguments = []
  381. for parameter, default_value in default_values.items():
  382. value = getattr(self, parameter)
  383. if value != default_value:
  384. # 'Dataset.__repr__' returns a table rather than a regular Python object
  385. # representation. So, we need to special case the 'datasets' parameter.
  386. if parameter == "datasets":
  387. value_repr = format_datasets_for_repr(value)
  388. else:
  389. value_repr = repr(value)
  390. non_default_arguments.append(f"{parameter}={value_repr}")
  391. if non_default_arguments:
  392. return f"<{self.__class__.__name__} {' '.join(non_default_arguments)}>"
  393. return f"<{self.__class__.__name__}>"
  394. def __new__(cls, *args, **kwargs):
  395. # Store the init args as attributes so this can be merged with Tune hparams.
  396. trainer = super(BaseTrainer, cls).__new__(cls)
  397. parameters = inspect.signature(cls.__init__).parameters
  398. parameters = list(parameters.keys())
  399. # Remove self.
  400. parameters = parameters[1:]
  401. arg_dict = dict(zip(parameters, args))
  402. trainer._param_dict = {**arg_dict, **kwargs}
  403. return trainer
  404. def _validate_attributes(self):
  405. """Called on __init()__ to validate trainer attributes."""
  406. # Run config
  407. if not isinstance(self.run_config, RunConfig):
  408. raise ValueError(
  409. f"`run_config` should be an instance of `ray.train.RunConfig`, "
  410. f"found {type(self.run_config)} with value `{self.run_config}`."
  411. )
  412. # Scaling config
  413. if not isinstance(self.scaling_config, ScalingConfig):
  414. raise ValueError(
  415. "`scaling_config` should be an instance of `ScalingConfig`, "
  416. f"found {type(self.scaling_config)} with value `{self.scaling_config}`."
  417. )
  418. # Datasets
  419. if not isinstance(self.datasets, dict):
  420. raise ValueError(
  421. f"`datasets` should be a dict mapping from a string to "
  422. f"`ray.data.Dataset` objects, "
  423. f"found {type(self.datasets)} with value `{self.datasets}`."
  424. )
  425. else:
  426. for key, dataset in self.datasets.items():
  427. if not isinstance(dataset, ray.data.Dataset) and not callable(dataset):
  428. raise ValueError(
  429. f"The Dataset under '{key}' key is not a "
  430. "`ray.data.Dataset`. "
  431. f"Received {dataset} instead."
  432. )
  433. # Metadata.
  434. self.metadata = self.metadata or {}
  435. if not isinstance(self.metadata, dict):
  436. raise TypeError(
  437. f"The provided metadata must be a dict, was {type(self.metadata)}."
  438. )
  439. try:
  440. self.metadata = json.loads(json.dumps(self.metadata))
  441. except Exception as e:
  442. raise ValueError(
  443. "The provided metadata must be JSON-serializable: "
  444. f"{self.metadata}: {e}"
  445. )
  446. if self.starting_checkpoint is not None and not isinstance(
  447. self.starting_checkpoint, Checkpoint
  448. ):
  449. raise ValueError(
  450. f"`resume_from_checkpoint` should be an instance of "
  451. f"`ray.train.Checkpoint`, found {type(self.starting_checkpoint)} "
  452. f"with value `{self.starting_checkpoint}`."
  453. )
  454. self._log_v2_deprecation_warnings()
  455. def _log_v2_deprecation_warnings(self):
  456. """Logs deprecation warnings for v2 migration.
  457. Log them here in the Ray Train case rather than in the configuration
  458. constructors to avoid logging incorrect deprecation warnings when
  459. `ray.train.RunConfig` is passed to Ray Tune.
  460. """
  461. from ray.train.v2._internal.constants import V2_ENABLED_ENV_VAR, is_v2_enabled
  462. if is_v2_enabled():
  463. raise DeprecationWarning(
  464. f"Detected use of a deprecated Trainer import from `{self.__class__.__module__}`. "
  465. "This Trainer class is not compatible with Ray Train V2.\n"
  466. "To fix this:\n"
  467. " - Update to use the new import path. For example, "
  468. "`from ray.train.torch.torch_trainer import TorchTrainer` -> "
  469. "`from ray.train.torch import TorchTrainer`\n"
  470. f" - Or, explicitly disable V2 by setting: {V2_ENABLED_ENV_VAR}=0\n"
  471. "See this issue for more context: "
  472. "https://github.com/ray-project/ray/issues/49454"
  473. )
  474. if not _v2_migration_warnings_enabled():
  475. return
  476. from ray.train.v2._internal.migration_utils import (
  477. CALLBACKS_DEPRECATION_MESSAGE,
  478. FAIL_FAST_DEPRECATION_MESSAGE,
  479. LOG_TO_FILE_DEPRECATION_MESSAGE,
  480. PROGRESS_REPORTER_DEPRECATION_MESSAGE,
  481. STOP_DEPRECATION_MESSAGE,
  482. SYNC_CONFIG_DEPRECATION_MESSAGE,
  483. TRAINER_RESOURCES_DEPRECATION_MESSAGE,
  484. VERBOSE_DEPRECATION_MESSAGE,
  485. )
  486. # ScalingConfig deprecations
  487. if self.scaling_config.trainer_resources is not None:
  488. _log_deprecation_warning(TRAINER_RESOURCES_DEPRECATION_MESSAGE)
  489. # FailureConfig deprecations
  490. if self.run_config.failure_config.fail_fast:
  491. _log_deprecation_warning(FAIL_FAST_DEPRECATION_MESSAGE)
  492. # RunConfig deprecations
  493. # NOTE: _verbose is the original verbose value passed by the user
  494. if self.run_config._verbose is not None:
  495. _log_deprecation_warning(VERBOSE_DEPRECATION_MESSAGE)
  496. if self.run_config.log_to_file:
  497. _log_deprecation_warning(LOG_TO_FILE_DEPRECATION_MESSAGE)
  498. if self.run_config.stop is not None:
  499. _log_deprecation_warning(STOP_DEPRECATION_MESSAGE)
  500. if self.run_config.callbacks is not None:
  501. _log_deprecation_warning(CALLBACKS_DEPRECATION_MESSAGE)
  502. if self.run_config.progress_reporter is not None:
  503. _log_deprecation_warning(PROGRESS_REPORTER_DEPRECATION_MESSAGE)
  504. if self.run_config.sync_config != ray.train.SyncConfig():
  505. _log_deprecation_warning(SYNC_CONFIG_DEPRECATION_MESSAGE)
  506. @classmethod
  507. def _validate_scaling_config(cls, scaling_config: ScalingConfig) -> ScalingConfig:
  508. """Returns scaling config dataclass after validating updated keys."""
  509. ensure_only_allowed_dataclass_keys_updated(
  510. dataclass=scaling_config,
  511. allowed_keys=cls._scaling_config_allowed_keys,
  512. )
  513. return scaling_config
  514. def setup(self) -> None:
  515. """Called during fit() to perform initial setup on the Trainer.
  516. .. note:: This method is run on a remote process.
  517. This method will not be called on the driver, so any expensive setup
  518. operations should be placed here and not in ``__init__``.
  519. This method is called prior to ``preprocess_datasets`` and
  520. ``training_loop``.
  521. """
  522. pass
  523. def preprocess_datasets(self) -> None:
  524. """Deprecated."""
  525. raise DeprecationWarning(
  526. "`preprocess_datasets` is no longer used, since preprocessors "
  527. f"are no longer accepted by Trainers.\n{PREPROCESSOR_DEPRECATION_MESSAGE}"
  528. )
  529. @abc.abstractmethod
  530. def training_loop(self) -> None:
  531. """Loop called by fit() to run training and report results to Tune.
  532. .. note:: This method runs on a remote process.
  533. ``self.datasets`` have already been evaluated if they were wrapped in a factory.
  534. You can use the :ref:`Ray Train utilities <train-loop-api>`
  535. (:func:`train.report() <ray.train.report>` and
  536. :func:`train.get_checkpoint() <ray.train.get_checkpoint>`) inside
  537. this training loop.
  538. Example:
  539. .. testcode::
  540. from ray.train.trainer import BaseTrainer
  541. from ray import train
  542. class MyTrainer(BaseTrainer):
  543. def training_loop(self):
  544. for epoch_idx in range(5):
  545. ...
  546. train.report({"epoch": epoch_idx})
  547. """
  548. raise NotImplementedError
  549. @PublicAPI(stability="beta")
  550. def fit(self) -> Result:
  551. """Runs training.
  552. Returns:
  553. A Result object containing the training result.
  554. Raises:
  555. ray.train.base_trainer.TrainingFailedError: If any failures during the execution
  556. of ``self.as_trainable()``, or during the Tune execution loop.
  557. """
  558. from ray.tune import ResumeConfig, TuneError
  559. from ray.tune.tuner import Tuner
  560. trainable = self.as_trainable()
  561. param_space = self._extract_fields_for_tuner_param_space()
  562. self.run_config.name = (
  563. self.run_config.name or StorageContext.get_experiment_dir_name(trainable)
  564. )
  565. # The storage context here is only used to access the resolved
  566. # storage fs and experiment path, in order to avoid duplicating that logic.
  567. # This is NOT the storage context object that gets passed to remote workers.
  568. storage = StorageContext(
  569. storage_path=self.run_config.storage_path,
  570. experiment_dir_name=self.run_config.name,
  571. storage_filesystem=self.run_config.storage_filesystem,
  572. )
  573. if self._restore_path:
  574. tuner = Tuner.restore(
  575. path=self._restore_path,
  576. trainable=trainable,
  577. param_space=param_space,
  578. _resume_config=ResumeConfig(
  579. finished=ResumeConfig.ResumeType.RESUME,
  580. unfinished=ResumeConfig.ResumeType.RESUME,
  581. errored=ResumeConfig.ResumeType.RESUME,
  582. ),
  583. storage_filesystem=self._restore_storage_filesystem,
  584. )
  585. else:
  586. tuner = Tuner(
  587. trainable=trainable,
  588. param_space=param_space,
  589. run_config=self.run_config,
  590. _entrypoint=AirEntrypoint.TRAINER,
  591. )
  592. self._save(storage.storage_filesystem, storage.experiment_fs_path)
  593. restore_msg = TrainingFailedError._RESTORE_MSG.format(
  594. trainer_cls_name=self.__class__.__name__,
  595. path=str(storage.experiment_fs_path),
  596. )
  597. try:
  598. result_grid = tuner.fit()
  599. except TuneError as e:
  600. # Catch any `TuneError`s raised by the `Tuner.fit` call.
  601. # Unwrap the `TuneError` if needed.
  602. parent_error = e.__cause__ or e
  603. # Raise it to the user as a `TrainingFailedError` with a message to restore.
  604. raise TrainingFailedError(restore_msg) from parent_error
  605. # Other exceptions get passed through directly (ex: on `fail_fast='raise'`)
  606. assert len(result_grid) == 1
  607. result = result_grid[0]
  608. if result.error:
  609. # Raise trainable errors to the user with a message to restore
  610. # or configure `FailureConfig` in a new run.
  611. raise TrainingFailedError(
  612. "\n".join([restore_msg, TrainingFailedError._FAILURE_CONFIG_MSG])
  613. ) from result.error
  614. return result
  615. def _save(self, fs: pyarrow.fs.FileSystem, experiment_path: str):
  616. """Saves the current trainer's class along with the `param_dict` of
  617. parameters passed to this trainer's constructor.
  618. This is used to recreate the trainer on restore.
  619. Unless a parameter is re-specified during restoration (only a subset
  620. of parameters can be passed in again), that parameter will be loaded
  621. from the saved copy.
  622. Datasets should not be saved as part of the state. Instead, we save the
  623. keys and replace the dataset values with dummy functions that will
  624. raise an error if invoked. The error only serves as a guardrail for
  625. misuse (e.g., manually unpickling and constructing the Trainer again)
  626. and is not typically surfaced, since datasets must be re-specified
  627. upon restoration.
  628. """
  629. param_dict = self._param_dict.copy()
  630. datasets = param_dict.pop("datasets", {})
  631. def raise_fn():
  632. raise RuntimeError
  633. if datasets:
  634. param_dict["datasets"] = dict.fromkeys(datasets, raise_fn)
  635. cls_and_param_dict = (self.__class__, param_dict)
  636. fs.create_dir(experiment_path)
  637. with fs.open_output_stream(Path(experiment_path, _TRAINER_PKL).as_posix()) as f:
  638. f.write(pickle.dumps(cls_and_param_dict))
  639. def _extract_fields_for_tuner_param_space(self) -> Dict:
  640. """Extracts fields to be included in `Tuner.param_space`.
  641. This is needed to leverage the full logging/integration offerings from Tune.
  642. For example, `param_space` is logged automatically to wandb integration.
  643. Currently only done for `train_loop_config`.
  644. Returns:
  645. A dictionary that should be passed to Tuner.param_space.
  646. """
  647. result = {}
  648. for key in self._fields_for_tuner_param_space:
  649. if key in self._param_dict.keys():
  650. result[key] = copy.deepcopy(self._param_dict[key])
  651. return result
    def _generate_trainable_cls(self) -> Type["Trainable"]:
        """Generates the base Trainable class.

        The function-based Trainable is built by ``wrap_function`` from
        ``_train_coordinator_fn``, which is partially applied with this
        trainer's class and metadata. It is then subclassed as
        ``TrainTrainable``, which adds the following:

        - merges Tuner ``param_space`` values (``self.config``) into the
          Trainer constructor arguments passed as ``setup`` kwargs;
        - converts and reconciles the scaling config with the trial's
          resources (the ResourceChangingScheduler workaround);
        - re-applies the driver's ``DataContext`` when the Trainer was given
          datasets;
        - derives ``default_resource_request`` from the scaling config.

        Returns:
            A Trainable class to use for training.
        """
        from ray.tune.execution.placement_groups import PlacementGroupFactory
        from ray.tune.trainable import wrap_function

        # Bind trainer-level values to locals; the nested TrainTrainable class
        # below closes over these names.
        trainer_cls = self.__class__
        scaling_config = self.scaling_config
        metadata = self.metadata

        train_coordinator_fn = partial(
            _train_coordinator_fn, trainer_cls=trainer_cls, metadata=metadata
        )
        # Change the name of the training function to match the name of the Trainer
        # class. This will mean the Tune trial name will match the name of Trainer on
        # stdout messages and the results directory.
        train_coordinator_fn.__name__ = trainer_cls.__name__

        trainable_cls = wrap_function(train_coordinator_fn)
        has_base_dataset = bool(self.datasets)
        if has_base_dataset:
            from ray.data.context import DataContext

            # Snapshot the driver-side DataContext so `setup` can re-apply it.
            dataset_context = DataContext.get_current()
        else:
            dataset_context = None

        class TrainTrainable(trainable_cls):
            """Adds default resources to the Trainable."""

            # Checkpointing capability flags are forwarded from the Trainer class.
            _handles_checkpoint_freq = trainer_cls._handles_checkpoint_freq
            _handles_checkpoint_at_end = trainer_cls._handles_checkpoint_at_end

            @classmethod
            def has_base_dataset(cls) -> bool:
                """Whether a dataset is provided through the Trainer."""
                return has_base_dataset

            @classmethod
            def base_scaling_config(cls) -> ScalingConfig:
                """Returns the unchanged scaling config provided through the Trainer."""
                return scaling_config

            def setup(self, config, **kwargs):
                """Build the merged Trainer config, then run the parent setup.

                ``kwargs`` carries the Trainer constructor arguments (see the
                comment below). Tune values from ``self.config`` are deep-merged
                on top of them. The result is stored in ``self._merged_config``
                with a normalized ``ScalingConfig`` under ``"scaling_config"``.
                """
                base_config = dict(kwargs)
                # Merge Tuner param space hyperparameters in `config` into the
                # base config passed to the Trainer constructor, which is `base_config`.
                # `base_config` is pulled from the object store from the usage of
                # tune.with_parameters in `BaseTrainer.as_trainable`.

                # run_config is not a tunable hyperparameter so it does not need to be
                # merged.
                # NOTE(review): the merge reads `self.config` rather than the
                # `config` argument; presumably the parent Trainable has already
                # populated it at this point — confirm.
                run_config = base_config.pop("run_config", None)
                self._merged_config = deep_update(
                    base_config, self.config, new_keys_allowed=True
                )
                self._merged_config["run_config"] = run_config
                # The scaling config may arrive as a dict (from `param_space`);
                # normalize it to a ScalingConfig before reconciling.
                merged_scaling_config = self._merged_config.get(
                    "scaling_config", ScalingConfig()
                )
                if isinstance(merged_scaling_config, dict):
                    merged_scaling_config = ScalingConfig(**merged_scaling_config)
                self._merged_config[
                    "scaling_config"
                ] = self._reconcile_scaling_config_with_trial_resources(
                    merged_scaling_config
                )
                if self.has_base_dataset():
                    # Set the DataContext on the Trainer actor to the DataContext
                    # specified on the driver.
                    # `DataContext` is only imported in the enclosing scope when
                    # `has_base_dataset` is True, which this guard mirrors.
                    DataContext._set_current(dataset_context)
                super(TrainTrainable, self).setup(config)

            def _reconcile_scaling_config_with_trial_resources(
                self, scaling_config: ScalingConfig
            ) -> ScalingConfig:
                """
                ResourceChangingScheduler workaround.

                Ensures that the scaling config matches trial resources.

                This should be replaced with RCS returning a ScalingConfig
                in the future.

                Returns ``scaling_config`` unchanged when the trial resources
                are not a ``PlacementGroupFactory`` or already equal its
                placement group factory. Otherwise ``scaling_config`` is
                validated, and a ScalingConfig rebuilt from the trial
                resources is returned.
                """
                trial_resources = self.trial_resources
                # This will be false if the resources are default
                if not isinstance(trial_resources, PlacementGroupFactory):
                    return scaling_config

                # Ignore ResourceChangingScheduler workaround when resource bundles
                # are unchanged
                if self.trial_resources == scaling_config.as_placement_group_factory():
                    return scaling_config

                # Called for its validation effect only; the return value is
                # discarded (presumably it raises on an invalid config — confirm).
                trainer_cls._validate_scaling_config(scaling_config)

                return ScalingConfig.from_placement_group_factory(trial_resources)

            def _trainable_func(self, config):
                # We ignore the config passed by Tune and instead use the merged
                # config which includes the initial Trainer args.
                # NOTE(review): the parent's return value is not propagated.
                super()._trainable_func(self._merged_config)

            @classmethod
            def default_resource_request(cls, config):
                # `config["scaling_config"]` is a dataclass when passed via the
                # `scaling_config` argument in `Trainer` and is a dict when passed
                # via the `scaling_config` key of `param_space`.
                # Conversion logic must be duplicated in `TrainTrainable.setup`
                # because this is a class method.
                # Falls back to the Trainer's own scaling config (closed over
                # above) when `config` has no "scaling_config" key.
                updated_scaling_config = config.get("scaling_config", scaling_config)
                if isinstance(updated_scaling_config, dict):
                    updated_scaling_config = ScalingConfig(**updated_scaling_config)
                validated_scaling_config = trainer_cls._validate_scaling_config(
                    updated_scaling_config
                )
                return validated_scaling_config.as_placement_group_factory()

        return TrainTrainable
  754. def as_trainable(self) -> Type["Trainable"]:
  755. """Converts self to a ``tune.Trainable`` class."""
  756. from ray import tune
  757. base_config = self._param_dict
  758. trainable_cls = self._generate_trainable_cls()
  759. # Wrap with `tune.with_parameters` to handle very large values in base_config
  760. return tune.with_parameters(trainable_cls, **base_config)
  761. @DeveloperAPI
  762. def format_datasets_for_repr(datasets: Optional[Dict[str, GenDataset]]) -> str:
  763. """Format datasets for BaseTrainer repr using plan strings.
  764. The Dataset.__repr__ returns a table rather than a conventional Python object
  765. reprentation. To ensure the BaseTrainer representation still looks reasonable, we
  766. need to special-case datasets.
  767. """
  768. from ray.data import Dataset
  769. assert datasets is not None, "Expected caller to pass in non-None argument"
  770. formatted = {}
  771. for key, dataset in datasets.items():
  772. if isinstance(dataset, Dataset):
  773. formatted[key] = dataset._plan.get_plan_as_string(type(dataset))
  774. else:
  775. formatted[key] = dataset
  776. return "{" + ", ".join(f"'{key}': {formatted[key]}" for key in datasets) + "}"