_lightning_utils.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. import logging
  2. import os
  3. import shutil
  4. import tempfile
  5. from pathlib import Path
  6. from typing import Any, Dict
  7. import torch
  8. from packaging.version import Version
  9. import ray
  10. import ray.train
  11. from ray._common.usage.usage_lib import TagKey, record_extra_usage_tag
  12. from ray.train import Checkpoint
  13. from ray.util import PublicAPI
  def import_lightning():  # noqa: F402
      """Import and return the installed PyTorch Lightning package.

      The unified ``lightning`` distribution (``lightning.pytorch``) is tried
      first. If that import raises ``ModuleNotFoundError``, the standalone
      ``pytorch_lightning`` distribution is imported instead; if that one is
      missing too, its ``ModuleNotFoundError`` propagates to the caller.

      Returns:
          The imported Lightning module object.
      """
      try:
          import lightning.pytorch as unified_package
      except ModuleNotFoundError:
          # Fall back to the standalone distribution.
          import pytorch_lightning as standalone_package

          return standalone_package
      else:
          return unified_package
  pl = import_lightning()

  # Parse each version string once; every feature flag below derives from these.
  _lightning_version = Version(pl.__version__)
  _torch_version = Version(torch.__version__)

  _LIGHTNING_GREATER_EQUAL_2_0 = _lightning_version >= Version("2.0.0")
  _LIGHTNING_LESS_THAN_2_1 = _lightning_version < Version("2.1.0")
  _TORCH_GREATER_EQUAL_1_12 = _torch_version >= Version("1.12.0")
  # FSDP needs both a new enough torch and a build with distributed support.
  _TORCH_FSDP_AVAILABLE = _TORCH_GREATER_EQUAL_1_12 and torch.distributed.is_available()

  # Same package fallback as import_lightning(): unified package first, then the
  # standalone one.
  try:
      from lightning.pytorch.plugins.environments import LightningEnvironment
  except ModuleNotFoundError:
      from pytorch_lightning.plugins.environments import LightningEnvironment

  # The FSDP strategy class is looked up under a different name before 2.0.
  FSDPStrategy = (
      pl.strategies.FSDPStrategy
      if _LIGHTNING_GREATER_EQUAL_2_0
      else pl.strategies.DDPFullyShardedStrategy
  )

  if _TORCH_FSDP_AVAILABLE:
      from torch.distributed.fsdp import (
          FullStateDictConfig,
          FullyShardedDataParallel,
          StateDictType,
      )

  logger = logging.getLogger(__name__)

  LIGHTNING_REPORT_STAGE_KEY = "_report_on"
  @PublicAPI(stability="beta")
  class RayDDPStrategy(pl.strategies.DDPStrategy):
      """DDPStrategy subclass that takes its device from Ray Train.

      All constructor arguments are forwarded unchanged to ``DDPStrategy``; for
      the full list, please refer to:
      https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.strategies.DDPStrategy.html

      Note that `process_group_backend`, `timeout`, and `start_method` are disabled
      here, please specify these arguments in :class:`~ray.train.torch.TorchConfig`
      instead.
      """

      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
          # Usage telemetry: note that this strategy class was instantiated.
          record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYDDPSTRATEGY, "1")

      @property
      def root_device(self) -> torch.device:
          """Device returned by ``ray.train.torch.get_device()``."""
          return ray.train.torch.get_device()

      @property
      def distributed_sampler_kwargs(self) -> Dict[str, Any]:
          """Sampler keyword arguments built from this strategy's world size and rank."""
          return {"num_replicas": self.world_size, "rank": self.global_rank}
  @PublicAPI(stability="beta")
  class RayFSDPStrategy(FSDPStrategy):  # noqa: F821
      """FSDPStrategy subclass that takes its device from Ray Train.

      All constructor arguments are forwarded unchanged to the parent strategy;
      for the full list, please refer to:
      https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.strategies.FSDPStrategy.html

      .. note::
          It is recommended to upgrade to `lightning>=2.1` when using FSDP with
          Lightning, since Lightning starts to natively support `state_dict_type`,
          `sharding_strategy`, `auto_wrap_policy` and other FSDP configurations
          from 2.1.
      """

      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
          # Usage telemetry: note that this strategy class was instantiated.
          record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYFSDPSTRATEGY, "1")

      @property
      def root_device(self) -> torch.device:
          """Device returned by ``ray.train.torch.get_device()``."""
          return ray.train.torch.get_device()

      @property
      def distributed_sampler_kwargs(self) -> Dict[str, Any]:
          """Sampler keyword arguments built from this strategy's world size and rank."""
          return {"num_replicas": self.world_size, "rank": self.global_rank}

      def lightning_module_state_dict(self) -> Dict[str, Any]:
          """Gathers the full state dict to rank 0 on CPU.

          FSDP checkpointing is broken in Lightning 2.0.x. On those versions (and
          when torch FSDP is available) this override requests a full state dict
          with CPU offload and ``rank0_only=True``, then strips the
          ``_forward_module.`` prefix from the keys. Upgrade to `lightning>=2.1`
          to do sharded state dict checkpointing.

          On every other version combination the parent implementation is used
          unchanged.
          """
          assert self.model is not None, "Failed to get the state dict for a None model!"

          needs_full_gather = (
              _TORCH_FSDP_AVAILABLE
              and _LIGHTNING_GREATER_EQUAL_2_0
              and _LIGHTNING_LESS_THAN_2_1
          )
          if not needs_full_gather:
              # Otherwise Lightning uses Fairscale FSDP, no need to unshard by ourself.
              return super().lightning_module_state_dict()

          wrapper_prefix = "_forward_module."
          strip_from = len(wrapper_prefix)
          gather_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)

          with FullyShardedDataParallel.state_dict_type(
              module=self.model,
              state_dict_type=StateDictType.FULL_STATE_DICT,
              state_dict_config=gather_config,
          ):
              gathered = self.model.state_dict()
              # Drop the wrapper prefix where present; other keys pass through.
              ckpt_state_dict = {
                  (name[strip_from:] if name.startswith(wrapper_prefix) else name): value
                  for name, value in gathered.items()
              }
          return ckpt_state_dict
  @PublicAPI(stability="beta")
  class RayDeepSpeedStrategy(pl.strategies.DeepSpeedStrategy):
      """DeepSpeedStrategy subclass that takes its device from Ray Train.

      All constructor arguments are forwarded unchanged to ``DeepSpeedStrategy``;
      for the full list, please refer to:
      https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.strategies.DeepSpeedStrategy.html
      """

      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
          # Usage telemetry: note that this strategy class was instantiated.
          record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYDEEPSPEEDSTRATEGY, "1")

      @property
      def root_device(self) -> torch.device:
          """Device returned by ``ray.train.torch.get_device()``."""
          return ray.train.torch.get_device()

      @property
      def distributed_sampler_kwargs(self) -> Dict[str, Any]:
          """Sampler keyword arguments built from this strategy's world size and rank."""
          return {"num_replicas": self.world_size, "rank": self.global_rank}
  @PublicAPI(stability="beta")
  class RayLightningEnvironment(LightningEnvironment):  # noqa: F821
      """Setup Lightning DDP training environment for Ray cluster.

      World size and all rank queries are answered from the Ray Train context
      on every call; the corresponding setters are no-ops.
      """

      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
          # Usage telemetry: note that this environment class was instantiated.
          record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYLIGHTNINGENVIRONMENT, "1")

      def world_size(self) -> int:
          """Return the world size reported by the Ray Train context."""
          train_context = ray.train.get_context()
          return train_context.get_world_size()

      def global_rank(self) -> int:
          """Return the world rank reported by the Ray Train context."""
          train_context = ray.train.get_context()
          return train_context.get_world_rank()

      def local_rank(self) -> int:
          """Return the local rank reported by the Ray Train context."""
          train_context = ray.train.get_context()
          return train_context.get_local_rank()

      def node_rank(self) -> int:
          """Return the node rank reported by the Ray Train context."""
          train_context = ray.train.get_context()
          return train_context.get_node_rank()

      def set_world_size(self, size: int) -> None:
          """No-op: `world_size()` directly returns data from Train context."""
          return None

      def set_global_rank(self, rank: int) -> None:
          """No-op: `global_rank()` directly returns data from Train context."""
          return None

      def teardown(self):
          """Intentionally does nothing."""
          return None
  @PublicAPI(stability="beta")
  def prepare_trainer(trainer: pl.Trainer) -> pl.Trainer:
      """Prepare the PyTorch Lightning Trainer for distributed execution.

      The trainer is only validated, never modified: its strategy must be one of
      the Ray strategy classes defined in this module, and its cluster
      environment plugin, if one is set, must be a ``RayLightningEnvironment``.

      Args:
          trainer: The Lightning trainer to validate.

      Returns:
          The same ``trainer`` object that was passed in.

      Raises:
          RuntimeError: If ``trainer.strategy`` is not an instance of
              ``RayDDPStrategy``, ``RayFSDPStrategy`` or ``RayDeepSpeedStrategy``
              (or a subclass), or if the strategy has a cluster environment that
              is not a ``RayLightningEnvironment``.
      """
      # Check strategy class. A tuple lets a single isinstance() cover all of them.
      valid_strategy_classes = (RayDDPStrategy, RayFSDPStrategy, RayDeepSpeedStrategy)
      if not isinstance(trainer.strategy, valid_strategy_classes):
          raise RuntimeError(
              f"Invalid strategy class: {type(trainer.strategy)}. To use "
              "PyTorch Lightning with Ray, the strategy object should be one of "
              f"{[cls.__name__ for cls in valid_strategy_classes]} class "
              "or its subclass."
          )

      # Check cluster environment. Compare against None explicitly so that a
      # falsy-but-present plugin object is still validated.
      cluster_environment = getattr(trainer.strategy, "cluster_environment", None)
      if cluster_environment is not None and not isinstance(
          cluster_environment, RayLightningEnvironment
      ):
          raise RuntimeError(
              "Invalid cluster environment plugin. The expected class is "
              "`ray.train.lightning.RayLightningEnvironment` "
              f"but got {type(cluster_environment)}!"
          )

      record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_PREPARE_TRAINER, "1")

      return trainer
  @PublicAPI(stability="beta")
  class RayTrainReportCallback(pl.callbacks.Callback):
      """A simple callback that reports checkpoints to Ray on train epoch end.

      This callback is a subclass of `lightning.pytorch.callbacks.Callback
      <https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.Callback.html#lightning.pytorch.callbacks.Callback>`_.

      At the end of every training epoch it reads `trainer.callback_metrics`,
      saves a Lightning checkpoint to a local temporary directory and reports
      both to Ray Train.

      Checkpoints will be saved in the following structure::

          checkpoint_00000*/      Ray Train Checkpoint
          └─ checkpoint.ckpt      PyTorch Lightning Checkpoint

      For customized reporting and checkpointing logic, implement your own
      `lightning.pytorch.callbacks.Callback` following this user
      guide: :ref:`Saving and Loading Checkpoints <train-dl-saving-checkpoints>`.
      """

      CHECKPOINT_NAME = "checkpoint.ckpt"

      def __init__(self) -> None:
          super().__init__()
          job_id = ray.get_runtime_context().get_job_id()
          experiment_name = ray.train.get_context().get_experiment_name()
          self.local_rank = ray.train.get_context().get_local_rank()

          # Per-job, per-experiment staging directory under the system temp dir.
          staging_root = (
              Path(tempfile.gettempdir())
              / f"lightning_checkpoints-job_id={job_id}-name={experiment_name}"
          )
          self.tmpdir_prefix = staging_root.as_posix()

          # Only local rank 0 removes a directory that already exists.
          already_exists = os.path.isdir(self.tmpdir_prefix)
          if already_exists and self.local_rank == 0:
              shutil.rmtree(self.tmpdir_prefix)

          record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYTRAINREPORTCALLBACK, "1")

      def on_train_epoch_end(self, trainer, pl_module) -> None:
          """Save a checkpoint for the finished epoch and report it with metrics."""
          # Checkpoint directory with a fixed, epoch-derived name.
          epoch_dir = (Path(self.tmpdir_prefix) / str(trainer.current_epoch)).as_posix()
          os.makedirs(epoch_dir, exist_ok=True)

          # Convert each logged metric with .item(), then add epoch and step.
          reported = {
              name: value.item() for name, value in trainer.callback_metrics.items()
          }
          reported.update(epoch=trainer.current_epoch, step=trainer.global_step)

          # Save checkpoint to local
          ckpt_file = (Path(epoch_dir) / self.CHECKPOINT_NAME).as_posix()
          trainer.save_checkpoint(ckpt_file, weights_only=False)

          # Report to train session
          ray.train.report(
              metrics=reported,
              checkpoint=Checkpoint.from_directory(epoch_dir),
          )

          # Add a barrier to ensure all workers finished reporting here
          trainer.strategy.barrier()

          # Only local rank 0 deletes the staging directory afterwards.
          if self.local_rank == 0:
              shutil.rmtree(epoch_dir)