jax_trainer.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. import logging
  2. from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
  3. from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated
  4. from ray.train import DataConfig
  5. from ray.train.trainer import GenDataset
  6. from ray.train.v2.api.config import RunConfig, ScalingConfig
  7. from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer
  8. from ray.train.v2.api.validation_config import ValidationConfig
  9. from ray.train.v2.jax.config import JaxConfig
  10. from ray.util import PublicAPI
if TYPE_CHECKING:
    # Placeholder: no type-checking-only imports are needed at present.
    pass

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  14. @PublicAPI(stability="alpha")
  15. class JaxTrainer(DataParallelTrainer):
  16. """A Trainer for Single-Program Multi-Data (SPMD) JAX training.
  17. At a high level, this Trainer does the following:
  18. 1. Launches multiple workers as defined by the ``scaling_config``.
  19. 2. Sets up a distributed JAX environment for TPUs or GPUs
  20. on these workers as defined by the ``jax_config``.
  21. 3. Ingests the input ``datasets`` based on the ``dataset_config``.
  22. 4. Runs the input ``train_loop_per_worker(train_loop_config)``
  23. on all workers.
  24. For more details, see:
  25. * :ref:`Jax Guide <train-jax>`
  26. .. testcode::
  27. :skipif: True
  28. import os
  29. from absl import app
  30. import logging
  31. from typing import Sequence
  32. import ray
  33. from ray.train import ScalingConfig, RunConfig
  34. from ray.train.v2.jax import JaxTrainer
  35. from MaxText.train import main as maxtext_main
  36. def train_loop_per_worker(config):
  37. argv = config["argv"]
  38. maxtext_main(argv)
  39. def main(argv: Sequence[str]):
  40. ray.init()
  41. # If you want to use TPUs, specify the TPU topology and accelerator type.
  42. tpu_scaling_config = ScalingConfig(
  43. use_tpu=True,
  44. num_workers=4,
  45. topology="4x4",
  46. accelerator_type="TPU-V6E",
  47. placement_strategy="SPREAD",
  48. resources_per_worker={"TPU": 4},
  49. )
  50. # If you want to use GPUs, specify the GPU scaling config like below.
  51. # gpu_scaling_config = ScalingConfig(
  52. # use_gpu=True,
  53. # num_workers=4,
  54. # resources_per_worker={"GPU": 1},
  55. # )
  56. trainer = JaxTrainer(
  57. train_loop_per_worker=train_loop_per_worker,
  58. train_loop_config={"argv": absolute_argv},
  59. scaling_config=tpu_scaling_config,
  60. run_config=RunConfig(
  61. name="maxtext_jaxtrainer",
  62. worker_runtime_env={
  63. "env_vars": {
  64. "JAX_PLATFORMS": "tpu",
  65. # If you want to use GPUs, set the JAX_PLATFORMS to "cuda".
  66. # "JAX_PLATFORMS": "cuda",
  67. }
  68. },
  69. ),
  70. )
  71. result = trainer.fit()
  72. If the ``datasets`` dict contains datasets (e.g. "train" and "val"), then it will be split into multiple dataset
  73. shards that can then be accessed by ``ray.train.get_dataset_shard("train")`` and ``ray.train.get_dataset_shard("val")``.
  74. Note:
  75. * If you are using TPUs, importing `jax` should occur within `train_loop_per_worker` to
  76. avoid driver-side TPU lock issues.
  77. Args:
  78. train_loop_per_worker: The training function to execute on each worker.
  79. This function can either take in zero arguments or a single ``Dict``
  80. argument which is set by defining ``train_loop_config``.
  81. Within this function you can use any of the
  82. :ref:`Ray Train Loop utilities <train-loop-api>`.
  83. train_loop_config: A configuration ``Dict`` to pass in as an argument to
  84. ``train_loop_per_worker``.
  85. This is typically used for specifying hyperparameters. Passing large
  86. datasets via `train_loop_config` is not recommended and may introduce
  87. large overhead and unknown issues with serialization and deserialization.
  88. jax_config: The configuration for setting up the JAX backend.
  89. If set to None, a default configuration will be used based on the ``scaling_config`` and ``JAX_PLATFORMS`` environment variable.
  90. scaling_config: Configuration for how to scale data parallel training
  91. with SPMD. ``num_workers`` should be set to the number of TPU hosts or GPU workers.
  92. If using TPUs, ``topology`` should be set to the TPU topology.
  93. See :class:`~ray.train.ScalingConfig` for more info.
  94. dataset_config: The configuration for ingesting the input ``datasets``.
  95. By default, all the Ray Dataset are split equally across workers.
  96. See :class:`~ray.train.DataConfig` for more details.
  97. run_config: The configuration for the execution of the training run.
  98. See :class:`~ray.train.RunConfig` for more info.
  99. datasets: The Ray Datasets to ingest for training.
  100. Datasets are keyed by name (``{name: dataset}``).
  101. Each dataset can be accessed from within the ``train_loop_per_worker``
  102. by calling ``ray.train.get_dataset_shard(name)``.
  103. Sharding and additional configuration can be done by
  104. passing in a ``dataset_config``.
  105. validation_config: [Alpha] Configuration for checkpoint validation.
  106. If provided and ``ray.train.report`` is called with the ``validation``
  107. argument, Ray Train will validate the reported checkpoint using
  108. the validation function specified in this config.
  109. """
  110. def __init__(
  111. self,
  112. train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
  113. *,
  114. train_loop_config: Optional[Dict] = None,
  115. jax_config: Optional[JaxConfig] = None,
  116. scaling_config: Optional[ScalingConfig] = None,
  117. dataset_config: Optional[Dict[str, DataConfig]] = None,
  118. run_config: Optional[RunConfig] = None,
  119. datasets: Optional[Dict[str, GenDataset]] = None,
  120. validation_config: Optional[ValidationConfig] = None,
  121. ):
  122. if not jax_config:
  123. jax_config = JaxConfig(
  124. use_tpu=scaling_config.use_tpu,
  125. use_gpu=scaling_config.use_gpu,
  126. )
  127. super(JaxTrainer, self).__init__(
  128. train_loop_per_worker=train_loop_per_worker,
  129. train_loop_config=train_loop_config,
  130. backend_config=jax_config,
  131. scaling_config=scaling_config,
  132. dataset_config=dataset_config,
  133. run_config=run_config,
  134. datasets=datasets,
  135. validation_config=validation_config,
  136. )
  137. @classmethod
  138. def _validate_scaling_config(cls, scaling_config: ScalingConfig) -> ScalingConfig:
  139. """Return scaling config dataclass after validating updated keys."""
  140. ensure_only_allowed_dataclass_keys_updated(
  141. dataclass=scaling_config,
  142. allowed_keys=cls._scaling_config_allowed_keys,
  143. )
  144. return scaling_config