replicate.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. # mypy: allow-untyped-defs
  2. import weakref
  3. from collections.abc import Iterable
  4. from typing import Any, NoReturn
  5. import torch
  6. import torch.nn as nn
  7. from torch.distributed._composable_state import _State
  8. from torch.nn.parallel import DistributedDataParallel
  9. from .contract import _get_registry, contract
# Prefix that denotes the root module when dotted parameter names are built
# in ``_ReplicateState._collect_params``: the root contributes no prefix, so
# its direct parameters are recorded under their bare names.
_ROOT_MODULE_PREFIX = ""
  11. class _ReplicateState(_State):
  12. _ddp_weakref: weakref.ref
  13. def __init__(self) -> None:
  14. super().__init__()
  15. self.module: nn.Module = nn.ParameterList()
  16. self.has_initialized: bool = False
  17. self._param_list: nn.ParameterList = nn.ParameterList()
  18. # TODO(@fegin): this variable is originally create for testing, we
  19. # should remove this if possible.
  20. self._orig_module = self.module
  21. self._param_names: list[str] = []
  22. self._no_sync: bool = False
  23. self._init_args: tuple[Any, ...] | None = None
  24. self._init_kwargs: dict[str, Any] = {}
  25. self._comm_hook_args: list[Any] = []
  26. def _collect_params(
  27. self,
  28. module: nn.Module,
  29. ignored_modules: set[nn.Module],
  30. ignored_params: set[nn.Parameter],
  31. prefix: str = _ROOT_MODULE_PREFIX,
  32. ) -> None:
  33. # skip if managed by fully_sharded API
  34. if _is_fully_sharded(module):
  35. return
  36. # if a module is ignored, all descendants of the module are ignored.
  37. if module in ignored_modules:
  38. return
  39. recurse_prefix = (
  40. f"{prefix}." if prefix != _ROOT_MODULE_PREFIX else _ROOT_MODULE_PREFIX
  41. )
  42. for n, p in module.named_parameters(recurse=False):
  43. if p not in ignored_params:
  44. self._param_list.append(p)
  45. self._param_names.append(f"{recurse_prefix}{n}")
  46. for name, child_module in module.named_children():
  47. self._collect_params(
  48. child_module,
  49. ignored_modules,
  50. ignored_params,
  51. prefix=f"{recurse_prefix}{name}",
  52. )
  53. def lazy_init(self) -> None:
  54. @torch._disable_dynamo(recursive=True)
  55. def _lazy_init():
  56. assert self._init_args is not None
  57. self.init(*self._init_args, **self._init_kwargs)
  58. self.register_comm_hook()
  59. self._init_args = ()
  60. self._init_kwargs = {}
  61. _lazy_init()
  62. def init(
  63. self,
  64. module: nn.Module,
  65. ignored_modules: set[nn.Module],
  66. **kwargs,
  67. ) -> None:
  68. if self.has_initialized:
  69. return
  70. self.has_initialized = True
  71. self.module = module
  72. ignored_params = {p for m in ignored_modules for p in m.parameters()}
  73. for submodule in module.modules():
  74. if _is_fully_sharded(submodule):
  75. ignored_params.update(submodule.parameters())
  76. from torch.distributed.tensor.parallel.ddp import _localize_dtensor
  77. _localize_dtensor(module, ignored_params=ignored_params)
  78. self._collect_params(module, ignored_modules, ignored_params)
  79. if "device_id" in kwargs:
  80. # replicate() supports a small usability enhancement where
  81. # user can pass in device_id as a Union[int, torch.device] even for
  82. # CPU devices so users don't have to change code for CPU/GPU runs.
  83. # We derive the right device_ids to feed into DDP to support this.
  84. if kwargs["device_id"] is not None:
  85. device_id = kwargs["device_id"]
  86. # Convert to device_ids that DDP expects.
  87. if isinstance(device_id, torch.device) and device_id.type == "cpu":
  88. # CPU modules receive device_ids None
  89. kwargs["device_ids"] = None
  90. else:
  91. # GPU modules expect device_ids=[cuda_device]
  92. kwargs["device_ids"] = [device_id]
  93. else:
  94. kwargs["device_ids"] = None
  95. kwargs.pop("device_id")
  96. self._ddp = DistributedDataParallel(self._param_list, **kwargs)
  97. # Weakref to the DDP instance is currently only used for testing.
  98. replicate.state(self.module)._ddp_weakref = weakref.ref(self._ddp)
  99. def register_comm_hook(self) -> None:
  100. for comm_args, comm_kwargs in self._comm_hook_args:
  101. self._ddp.register_comm_hook(*comm_args, **comm_kwargs)
  102. self._comm_hook_args.clear()
  103. def record_init_args(self, *args, **kwargs) -> None:
  104. self._init_args = args
  105. self._init_kwargs = kwargs
  106. def forward_pre_hook(
  107. self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
  108. ) -> Any:
  109. if self._init_args or self._init_kwargs:
  110. self.lazy_init()
  111. self._ddp.require_backward_grad_sync = not self._no_sync
  112. DistributedDataParallel._active_ddp_module = self._ddp
  113. return self._ddp._pre_forward(*args, **kwargs)
  114. def forward_post_hook(
  115. self,
  116. module: nn.Module,
  117. input: tuple[torch.Tensor],
  118. output: torch.Tensor,
  119. ) -> torch.Tensor:
  120. DistributedDataParallel._active_ddp_module = None
  121. return self._ddp._post_forward(output)
  122. def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
  123. raise AssertionError(
  124. "DDP does not support deepcopy. Please use state dict for serialization."
  125. )
  126. # Follow the same pattern as FSDP/fully_shard
  127. class DDP:
  128. def __new__(cls, *args, **kwargs):
  129. """
  130. Override ``__new__`` to remove the DDP class and directly construct
  131. the original class for cases like indexing into a container module.
  132. """
  133. # Use index 2 since 0 is the dynamically constructed `DDP<...>` class
  134. # and index 1 is the `DDP` class itself
  135. orig_cls = cls.__mro__[2]
  136. return orig_cls.__new__(orig_cls, *args, **kwargs)
  137. def set_requires_gradient_sync(self, requires_gradient_sync: bool) -> None:
  138. """
  139. Sets if the module should sync gradients. This can be used to implement
  140. gradient accumulation without communication.
  141. Args:
  142. requires_gradient_sync (bool): Whether to reduce gradients for the
  143. module's parameters.
  144. """
  145. replicate.state(self)._no_sync = not requires_gradient_sync # type: ignore[arg-type]
  146. def register_comm_hook(self, *args, **kwargs) -> None:
  147. replicate.state(self)._comm_hook_args.append((args, kwargs)) # type: ignore[arg-type]
  148. @contract(state_cls=_ReplicateState)
  149. def replicate(
  150. module: nn.Module,
  151. ignored_modules: Iterable[torch.nn.Module] | None = None,
  152. **kwargs,
  153. ) -> nn.Module:
  154. r"""Replicates a module
  155. Args:
  156. module (torch.nn.Module): module to replicate
  157. Example::
  158. >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
  159. >>> module = nn.Linear(3, 3)
  160. >>> replicate(module)
  161. """
  162. torch._C._log_api_usage_once("torch.distributed.replicate")
  163. # TODO(fegin): using kwargs is not a good idea if we would like to make
  164. # replicate a formal API to replace DDP.
  165. if "device_id" in kwargs:
  166. if not isinstance(kwargs["device_id"], (int, torch.device)):
  167. raise RuntimeError(
  168. "Expected device_id to be int or torch.device, "
  169. f"but got {type(kwargs['device_id'])}"
  170. )
  171. if _is_fully_sharded(module):
  172. raise RuntimeError(
  173. "Cannot apply `replicate()` on a Module already managed by `fully_shard`"
  174. )
  175. if ignored_modules is None:
  176. ignored_modules = {}
  177. else:
  178. ignored_modules = set(ignored_modules)
  179. state = replicate.state(module)
  180. module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True)
  181. device_mesh = kwargs.get("device_mesh")
  182. if device_mesh is not None:
  183. root_mesh = device_mesh._get_root_mesh()
  184. # if a root mesh is not the same as device_mesh,
  185. # meaning the device_mesh is sliced out from the root mesh.
  186. if root_mesh != device_mesh:
  187. # TODO: This is a temporary work around to enable DDP + TP.
  188. # We should do the logic in DDP so that the 2D implementation is
  189. # sound and the state_dict works out of the box.
  190. #
  191. # This won't conflict with what is done in DDP class as the module
  192. # replicate is going to pass is NOT the original module.
  193. from torch.distributed.tensor.parallel.ddp import (
  194. _localize_dtensor,
  195. _reconstruct_dtensor,
  196. )
  197. module.register_forward_pre_hook(_reconstruct_dtensor)
  198. module.register_forward_hook(_localize_dtensor)
  199. module.register_forward_hook(state.forward_post_hook) # type: ignore[arg-type]
  200. state.record_init_args(module, ignored_modules, **kwargs)
  201. # Place DDP leftmost for highest priority in the method resolution order
  202. cls = module.__class__
  203. dct = {"__deepcopy__": unimplemented_deepcopy}
  204. new_cls = type(f"DDP{cls.__name__}", (DDP, cls), dct)
  205. module.__class__ = new_cls
  206. return module
  207. def _is_fully_sharded(module: nn.Module) -> bool:
  208. r"""Check if module is marked with fully_shard."""
  209. registry = _get_registry(module)
  210. if registry is None:
  211. return False
  212. return "fully_shard" in registry