# data_parallel_trainer.py (22 KB)
  1. import logging
  2. import uuid
  3. from typing import Any, Callable, Dict, List, Optional, Type, Union
  4. import ray
  5. from ray._private.ray_constants import env_integer
  6. from ray._private.thirdparty.tabulate.tabulate import tabulate
  7. from ray.air.config import RunConfig, ScalingConfig
  8. from ray.train import BackendConfig, Checkpoint
  9. from ray.train._internal import session
  10. from ray.train._internal.backend_executor import BackendExecutor, TrialInfo
  11. from ray.train._internal.data_config import DataConfig
  12. from ray.train._internal.session import _TrainingResult, get_session
  13. from ray.train._internal.utils import construct_train_func, count_required_parameters
  14. from ray.train.base_trainer import _TRAINER_RESTORE_DEPRECATION_WARNING
  15. from ray.train.constants import RAY_TRAIN_ENABLE_STATE_TRACKING
  16. from ray.train.trainer import BaseTrainer, GenDataset, TrainingIterator
  17. from ray.util.annotations import Deprecated, DeveloperAPI
  18. from ray.widgets import Template
  19. from ray.widgets.util import repr_with_fallback
  20. logger = logging.getLogger(__name__)
  21. @DeveloperAPI
  22. class DataParallelTrainer(BaseTrainer):
  23. """A Trainer for data parallel training.
  24. You should subclass this Trainer if your Trainer follows SPMD (single program,
  25. multiple data) programming paradigm - you want multiple processes to run the same
  26. function, but on different data.
  27. This Trainer runs the function ``train_loop_per_worker`` on multiple Ray
  28. Actors.
  29. The ``train_loop_per_worker`` function is expected to take in either 0 or 1
  30. arguments:
  31. .. testcode::
  32. def train_loop_per_worker():
  33. ...
  34. .. testcode::
  35. def train_loop_per_worker(config: Dict):
  36. ...
  37. If ``train_loop_per_worker`` accepts an argument, then
  38. ``train_loop_config`` will be passed in as the argument. This is useful if you
  39. want to tune the values in ``train_loop_config`` as hyperparameters.
  40. If the ``datasets`` dict contains a training dataset (denoted by
  41. the "train" key), then it will be split into multiple dataset
  42. shards that can then be accessed by ``train.get_dataset_shard("train")`` inside
  43. ``train_loop_per_worker``. All the other datasets will not be split and
  44. ``train.get_dataset_shard(...)`` will return the entire Dataset.
  45. Inside the ``train_loop_per_worker`` function, you can use any of the
  46. :ref:`Ray Train loop methods <train-loop-api>`.
  47. .. testcode::
  48. from ray import train
  49. def train_loop_per_worker():
  50. # Report intermediate results for callbacks or logging and
  51. # checkpoint data.
  52. train.report(...)
  53. # Returns dict of last saved checkpoint.
  54. train.get_checkpoint()
  55. # Returns the Dataset shard for the given key.
  56. train.get_dataset_shard("my_dataset")
  57. # Returns the total number of workers executing training.
  58. train.get_context().get_world_size()
  59. # Returns the rank of this worker.
  60. train.get_context().get_world_rank()
  61. # Returns the rank of the worker on the current node.
  62. train.get_context().get_local_rank()
  63. Any returns from the ``train_loop_per_worker`` will be discarded and not
  64. used or persisted anywhere.
  65. **How do I use DataParallelTrainer or any of its subclasses?**
  66. Example:
  67. .. testcode::
  68. :skipif: True
  69. import ray
  70. from ray import train
  71. from ray.train import ScalingConfig
  72. from ray.train.data_parallel_trainer import DataParallelTrainer
  73. def train_loop_for_worker():
  74. dataset_shard_for_this_worker = train.get_dataset_shard("train")
  75. # 3 items for 3 workers, each worker gets 1 item
  76. batches = list(dataset_shard_for_this_worker.iter_batches(batch_size=1))
  77. assert len(batches) == 1
  78. train_dataset = ray.data.from_items([1, 2, 3])
  79. assert train_dataset.count() == 3
  80. trainer = DataParallelTrainer(
  81. train_loop_for_worker,
  82. scaling_config=ScalingConfig(num_workers=3),
  83. datasets={"train": train_dataset},
  84. )
  85. result = trainer.fit()
  86. **How do I develop on top of DataParallelTrainer?**
  87. In many cases, using DataParallelTrainer directly is sufficient to execute
  88. functions on multiple actors.
  89. However, you may want to subclass ``DataParallelTrainer`` and create a custom
  90. Trainer for the following 2 use cases:
  91. - **Use Case 1:** You want to do data parallel training, but want to have
  92. a predefined ``training_loop_per_worker``.
  93. - **Use Case 2:** You want to implement a custom
  94. :py:class:`~ray.train.backend.Backend` that automatically handles
  95. additional setup or teardown logic on each actor, so that the users of this
  96. new trainer do not have to implement this logic. For example, a
  97. ``TensorflowTrainer`` can be built on top of ``DataParallelTrainer``
  98. that automatically handles setting the proper environment variables for
  99. distributed Tensorflow on each actor.
  100. For 1, you can set a predefined training loop in __init__
  101. .. testcode::
  102. from ray.train.data_parallel_trainer import DataParallelTrainer
  103. class MyDataParallelTrainer(DataParallelTrainer):
  104. def __init__(self, *args, **kwargs):
  105. predefined_train_loop_per_worker = lambda: 1
  106. super().__init__(predefined_train_loop_per_worker, *args, **kwargs)
  107. For 2, you can implement the ``ray.train.Backend`` and ``ray.train.BackendConfig``
  108. interfaces.
  109. .. testcode::
  110. from dataclasses import dataclass
  111. from ray.train.backend import Backend, BackendConfig
  112. class MyBackend(Backend):
  113. def on_start(self, worker_group, backend_config):
  114. def set_env_var(env_var_value):
  115. import os
  116. os.environ["MY_ENV_VAR"] = env_var_value
  117. worker_group.execute(set_env_var, backend_config.env_var)
  118. @dataclass
  119. class MyBackendConfig(BackendConfig):
  120. env_var: str = "default_value"
  121. def backend_cls(self):
  122. return MyBackend
  123. class MyTrainer(DataParallelTrainer):
  124. def __init__(self, train_loop_per_worker, my_backend_config:
  125. MyBackendConfig, **kwargs):
  126. super().__init__(
  127. train_loop_per_worker,
  128. backend_config=my_backend_config, **kwargs)
  129. Args:
  130. train_loop_per_worker: The training function to execute.
  131. This can either take in no arguments or a ``config`` dict.
  132. train_loop_config: Configurations to pass into
  133. ``train_loop_per_worker`` if it accepts an argument.
  134. backend_config: Configuration for setting up a Backend (e.g. Torch,
  135. Tensorflow, Horovod) on each worker to enable distributed
  136. communication. If no Backend should be set up, then set this to None.
  137. scaling_config: Configuration for how to scale data parallel training.
  138. dataset_config: Configuration for dataset ingest. This is merged with the
  139. default dataset config for the given trainer (`cls._dataset_config`).
  140. run_config: Configuration for the execution of the training run.
  141. datasets: Ray Datasets to use for training and evaluation.
  142. This is a dict where the key is the name of the dataset, which
  143. can be accessed from within the ``train_loop_per_worker`` by calling
  144. ``train.get_dataset_shard(dataset_key)``.
  145. By default, all datasets are sharded equally across workers.
  146. This can be configured via ``dataset_config``.
  147. metadata: Dict that should be made available via
  148. `train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
  149. for checkpoints saved from this Trainer. Must be JSON-serializable.
  150. resume_from_checkpoint: A checkpoint to resume training from.
  151. """
  152. # Exposed here for testing purposes. Should never need
  153. # to be overridden.
  154. _backend_executor_cls: Type[BackendExecutor] = BackendExecutor
  155. _training_iterator_cls: Type[TrainingIterator] = TrainingIterator
  156. _scaling_config_allowed_keys = BaseTrainer._scaling_config_allowed_keys + [
  157. "num_workers",
  158. "resources_per_worker",
  159. "use_gpu",
  160. "placement_strategy",
  161. "accelerator_type",
  162. ]
  163. # For backwards compatibility with the legacy dataset config API.
  164. _dataset_config = None
  165. _fields_for_tuner_param_space = BaseTrainer._fields_for_tuner_param_space + [
  166. "train_loop_config"
  167. ]
  168. def __init__(
  169. self,
  170. train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
  171. *,
  172. train_loop_config: Optional[Dict] = None,
  173. backend_config: Optional[BackendConfig] = None,
  174. scaling_config: Optional[ScalingConfig] = None,
  175. dataset_config: Optional[DataConfig] = None,
  176. run_config: Optional[RunConfig] = None,
  177. datasets: Optional[Dict[str, GenDataset]] = None,
  178. metadata: Optional[Dict[str, Any]] = None,
  179. resume_from_checkpoint: Optional[Checkpoint] = None,
  180. ):
  181. self._train_loop_per_worker = train_loop_per_worker
  182. self._train_loop_config = train_loop_config
  183. if dataset_config is None:
  184. dataset_config = DataConfig()
  185. if not isinstance(dataset_config, DataConfig):
  186. raise ValueError(
  187. "`dataset_config` must be an instance of ray.train.DataConfig, "
  188. f"was: {dataset_config}"
  189. )
  190. self._data_config = dataset_config
  191. backend_config = (
  192. backend_config if backend_config is not None else BackendConfig()
  193. )
  194. self._backend_config = backend_config
  195. super(DataParallelTrainer, self).__init__(
  196. scaling_config=scaling_config,
  197. run_config=run_config,
  198. datasets=datasets,
  199. metadata=metadata,
  200. resume_from_checkpoint=resume_from_checkpoint,
  201. )
  202. train_total_resources = self.scaling_config.total_resources
  203. self._data_config.set_train_total_resources(
  204. train_total_resources.get("CPU", 0),
  205. train_total_resources.get("GPU", 0),
  206. )
  207. if env_integer(RAY_TRAIN_ENABLE_STATE_TRACKING, 0):
  208. from ray.train._internal.state.state_actor import get_or_create_state_actor
  209. get_or_create_state_actor()
  210. @classmethod
  211. @Deprecated(message=_TRAINER_RESTORE_DEPRECATION_WARNING)
  212. def restore(
  213. cls,
  214. path: str,
  215. train_loop_per_worker: Optional[
  216. Union[Callable[[], None], Callable[[Dict], None]]
  217. ] = None,
  218. train_loop_config: Optional[Dict] = None,
  219. **kwargs,
  220. ):
  221. """Restores a DataParallelTrainer from a previously interrupted/failed run.
  222. Args:
  223. train_loop_per_worker: Optionally re-specified train loop function.
  224. This should be used to re-specify a function that is not
  225. restorable in a new Ray cluster (e.g., it holds onto outdated
  226. object references). This should be the same training loop
  227. that was passed to the original trainer constructor.
  228. train_loop_config: Optionally re-specified train config.
  229. This should similarly be used if the original `train_loop_config`
  230. contained outdated object references, and it should not be modified
  231. from what was originally passed in.
  232. See :meth:`BaseTrainer.restore() <ray.train.trainer.BaseTrainer.restore>`
  233. for descriptions of the other arguments.
  234. Returns a restored instance of the `DataParallelTrainer`.
  235. """
  236. return super(DataParallelTrainer, cls).restore(
  237. path=path,
  238. train_loop_per_worker=train_loop_per_worker,
  239. train_loop_config=train_loop_config,
  240. **kwargs,
  241. )
  242. def _validate_attributes(self):
  243. super()._validate_attributes()
  244. self._validate_train_loop_per_worker(
  245. self._train_loop_per_worker, "train_loop_per_worker"
  246. )
  247. def _validate_train_loop_per_worker(
  248. self, train_loop_per_worker: Callable, fn_name: str
  249. ) -> None:
  250. num_required_params = count_required_parameters(train_loop_per_worker)
  251. if num_required_params > 1:
  252. raise ValueError(
  253. f"{fn_name} should take in 0 or 1 arguments, "
  254. f"but it accepts {num_required_params} arguments instead."
  255. )
  256. @classmethod
  257. def _validate_scaling_config(cls, scaling_config: ScalingConfig) -> ScalingConfig:
  258. scaling_config = super(DataParallelTrainer, cls)._validate_scaling_config(
  259. scaling_config
  260. )
  261. # This validation happens after the scaling config is updated from
  262. # its specification in the Tuner `param_space`
  263. if not scaling_config.use_gpu and "GPU" in ray.available_resources():
  264. logger.info(
  265. "GPUs are detected in your Ray cluster, but GPU "
  266. "training is not enabled for this trainer. To enable "
  267. "GPU training, make sure to set `use_gpu` to True "
  268. "in your scaling config."
  269. )
  270. if scaling_config.num_workers is None:
  271. raise ValueError(
  272. "You must specify the 'num_workers' in `scaling_config` as either an "
  273. f"argument of `{cls.__name__}` or through the `param_space` of a "
  274. "`Tuner` (if performing hyperparameter tuning)."
  275. )
  276. if scaling_config.num_workers <= 0:
  277. raise ValueError(
  278. "'num_workers' in `scaling_config` must be a positive "
  279. f"integer. Received {scaling_config.num_workers}"
  280. )
  281. return scaling_config
  282. def _run_training(self, training_iterator: TrainingIterator) -> None:
  283. """This method loops over the `TrainingIterator`:
  284. The actual iteration (for ... in ...) waits for the training function
  285. on each worker to report a result and supplies it as a list of results.
  286. Afterwards (in the body of the loop), it will report the result
  287. to the Tune session.
  288. The iterator ends after the training function on each worker has finished.
  289. """
  290. for training_results in training_iterator:
  291. # TODO(ml-team): add ability to report results from multiple workers.
  292. self._propagate_results(training_results)
  293. def _propagate_results(self, training_results: List[_TrainingResult]):
  294. first_worker_result = training_results[0]
  295. assert all(isinstance(result, _TrainingResult) for result in training_results)
  296. tune_session = get_session()
  297. # Check if any workers reported a checkpoint.
  298. # If so, report a checkpoint pointing to the persisted location
  299. # to Tune for book-keeping.
  300. # NOTE: This removes the restriction for any individual worker
  301. # (ex: global rank 0 worker) from needing to report a checkpoint.
  302. # All workers reported a checkpoint to the same fs path, so there's
  303. # no need to report multiple checkpoints to Tune.
  304. worker_checkpoints = [
  305. result.checkpoint
  306. for result in training_results
  307. if result.checkpoint is not None
  308. ]
  309. at_least_one_reported_checkpoint = len(worker_checkpoints) > 0
  310. if at_least_one_reported_checkpoint:
  311. # Update the coordinator's checkpoint index to the latest.
  312. # This is what keeps the checkpoint index in line with the workers.
  313. tune_session.storage._update_checkpoint_index(first_worker_result.metrics)
  314. # Make sure that all workers uploaded to the same location.
  315. assert all(
  316. checkpoint.path == tune_session.storage.checkpoint_fs_path
  317. for checkpoint in worker_checkpoints
  318. )
  319. checkpoint = (
  320. Checkpoint(
  321. filesystem=tune_session.storage.storage_filesystem,
  322. path=tune_session.storage.checkpoint_fs_path,
  323. )
  324. if at_least_one_reported_checkpoint
  325. else None
  326. )
  327. tracked_training_result = _TrainingResult(
  328. checkpoint=checkpoint,
  329. metrics=first_worker_result.metrics,
  330. )
  331. logger.debug(
  332. "Report (metrics, checkpoint) to the Tune session:\n"
  333. f" metrics={tracked_training_result.metrics}\n"
  334. f" checkpoint={tracked_training_result.checkpoint}"
  335. )
  336. # Report the metrics and checkpoint to Tune.
  337. tune_session._report_training_result(tracked_training_result)
  338. def training_loop(self) -> None:
  339. scaling_config = self._validate_scaling_config(self.scaling_config)
  340. train_loop_per_worker = construct_train_func(
  341. self._train_loop_per_worker,
  342. self._train_loop_config,
  343. train_func_context=self._backend_config.train_func_context,
  344. fn_arg_name="train_loop_per_worker",
  345. discard_returns=True,
  346. )
  347. trial_info = TrialInfo(
  348. name=session.get_trial_name(),
  349. id=session.get_trial_id(),
  350. resources=session.get_trial_resources(),
  351. logdir=session.get_trial_dir(),
  352. driver_ip=ray.util.get_node_ip_address(),
  353. driver_node_id=ray.get_runtime_context().get_node_id(),
  354. experiment_name=session.get_experiment_name(),
  355. run_id=uuid.uuid4().hex,
  356. )
  357. backend_executor = self._backend_executor_cls(
  358. backend_config=self._backend_config,
  359. trial_info=trial_info,
  360. num_workers=scaling_config.num_workers,
  361. resources_per_worker=scaling_config._resources_per_worker_not_none,
  362. max_retries=0,
  363. )
  364. # Start the remote actors.
  365. backend_executor.start()
  366. training_iterator = self._training_iterator_cls(
  367. backend_executor=backend_executor,
  368. backend_config=self._backend_config,
  369. train_func=train_loop_per_worker,
  370. datasets=self.datasets,
  371. metadata=self.metadata,
  372. data_config=self._data_config,
  373. checkpoint=self.starting_checkpoint,
  374. )
  375. self._run_training(training_iterator)
  376. # Shutdown workers.
  377. backend_executor.shutdown()
  378. def get_dataset_config(self) -> DataConfig:
  379. """Returns a copy of this Trainer's final dataset configs.
  380. Returns:
  381. The merged default + user-supplied dataset config.
  382. """
  383. return self._data_config
  384. @repr_with_fallback(["ipywidgets", "8"])
  385. def _repr_mimebundle_(self, **kwargs):
  386. """Returns a mimebundle with an ipywidget repr and a simple text repr.
  387. Depending on the frontend where the data is being displayed,
  388. different mimetypes will be used from this bundle.
  389. See https://ipython.readthedocs.io/en/stable/config/integrating.html
  390. for information about this method, and
  391. https://ipywidgets.readthedocs.io/en/latest/embedding.html
  392. for more information about the jupyter widget mimetype.
  393. Returns:
  394. A mimebundle containing an ipywidget repr and a simple text repr.
  395. """
  396. from ipywidgets import HTML, Layout, Tab, VBox
  397. title = HTML(f"<h2>{self.__class__.__name__}</h2>")
  398. children = []
  399. titles = []
  400. if self.datasets:
  401. children.append(self._datasets_repr_())
  402. titles.append("Datasets")
  403. children.append(HTML(self._data_config_repr_html_()))
  404. titles.append("Data Config")
  405. if self._train_loop_config:
  406. children.append(HTML(self._train_loop_config_repr_html_()))
  407. titles.append("Train Loop Config")
  408. if self.scaling_config:
  409. children.append(HTML(self.scaling_config._repr_html_()))
  410. titles.append("Scaling Config")
  411. if self.run_config:
  412. children.append(HTML(self.run_config._repr_html_()))
  413. titles.append("Run Config")
  414. if self._backend_config:
  415. children.append(HTML(self._backend_config._repr_html_()))
  416. titles.append("Backend Config")
  417. tab = Tab(children, titles=titles)
  418. widget = VBox([title, tab], layout=Layout(width="100%"))
  419. bundle = widget._repr_mimebundle_(**kwargs)
  420. bundle.update(
  421. {
  422. "text/plain": repr(self),
  423. }
  424. )
  425. return bundle
  426. def _train_loop_config_repr_html_(self) -> str:
  427. if self._train_loop_config:
  428. table_data = {}
  429. for k, v in self._train_loop_config.items():
  430. if isinstance(v, str) or str(v).isnumeric():
  431. table_data[k] = v
  432. elif hasattr(v, "_repr_html_"):
  433. table_data[k] = v._repr_html_()
  434. else:
  435. table_data[k] = str(v)
  436. return Template("title_data.html.j2").render(
  437. title="Train Loop Config",
  438. data=Template("scrollableTable.html.j2").render(
  439. table=tabulate(
  440. table_data.items(),
  441. headers=["Setting", "Value"],
  442. showindex=False,
  443. tablefmt="unsafehtml",
  444. ),
  445. max_height="none",
  446. ),
  447. )
  448. else:
  449. return ""
  450. def _data_config_repr_html_(self) -> str:
  451. # TODO make this rendering nicer.
  452. content = [str(self._data_config)]
  453. return Template("rendered_html_common.html.j2").render(content=content)
  454. def _datasets_repr_(self) -> str:
  455. from ipywidgets import HTML, Layout, VBox
  456. content = []
  457. if self.datasets:
  458. for name, config in self.datasets.items():
  459. tab = config._tab_repr_()
  460. if tab:
  461. content.append(
  462. HTML(
  463. Template("title_data.html.j2").render(
  464. title=f"Dataset - <code>{name}</code>", data=None
  465. )
  466. )
  467. )
  468. content.append(config._tab_repr_())
  469. return VBox(content, layout=Layout(width="100%"))