join.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. # mypy: allow-untyped-defs
  2. import warnings
  3. from abc import ABC, abstractmethod
  4. from types import TracebackType
  5. from typing import Any, NamedTuple
  6. import torch
  7. import torch.distributed as dist
  8. __all__ = ["JoinHook", "Joinable", "Join"]
  9. class JoinHook:
  10. r"""
  11. This defines a join hook, which provides two entry points in the join context manager.
  12. Entry points : a main hook, which is called repeatedly while there exists a non-joined
  13. process, and a post-hook, which is called once all processes have joined.
  14. To implement a join hook for the generic join context manager, define a
  15. class that inherits from :class:`JoinHook` and override ``main_hook()`` and
  16. ``post_hook()`` as appropriate.
  17. """
  18. def main_hook(self) -> None:
  19. r"""Call this hook while there exists a non-joined process to shadow collective communications in a training iteration.
  20. Training iteration i.e., in one forward pass, backward pass, and optimizer step.
  21. """
  22. def post_hook(self, is_last_joiner: bool) -> None:
  23. r"""
  24. Call hook after all processes have joined.
  25. It is passed an additional ``bool`` argument ``is_last_joiner``, which indicates if the rank is one of the last to join.
  26. Arguments:
  27. is_last_joiner (bool): ``True`` if the rank is one of the last to
  28. join; ``False`` otherwise.
  29. """
  30. class Joinable(ABC):
  31. r"""
  32. This defines an abstract base class for joinable classes.
  33. A joinable class
  34. (inheriting from :class:`Joinable`) should implement :meth:`join_hook`,
  35. which returns a :class:`JoinHook` instance, in addition to
  36. :meth:`join_device` and :meth:`join_process_group` that return device and
  37. process group information, respectively.
  38. """
  39. @abstractmethod
  40. def __init__(self) -> None:
  41. super().__init__()
  42. self._join_config = _JoinConfig.construct_disabled_join_config()
  43. @abstractmethod
  44. def join_hook(self, **kwargs) -> JoinHook:
  45. r"""
  46. Return a :class:`JoinHook` instance for the given :class:`Joinable`.
  47. Arguments:
  48. kwargs (dict): a :class:`dict` containing any keyword arguments
  49. to modify the behavior of the join hook at run time; all
  50. :class:`Joinable` instances sharing the same join context
  51. manager are forwarded the same value for ``kwargs``.
  52. """
  53. ...
  54. @property
  55. @abstractmethod
  56. def join_device(self) -> torch.device:
  57. r"""Return the device from which to perform collective communications needed by the join context manager."""
  58. ...
  59. @property
  60. @abstractmethod
  61. def join_process_group(self) -> Any:
  62. r"""Returns the process group for the collective communications needed by the join context manager itself."""
  63. ...
  64. class _JoinConfig(NamedTuple):
  65. r"""This includes all fields needed from a :class:`Joinable` instance for the join context manager side."""
  66. enable: bool
  67. throw_on_early_termination: bool
  68. is_first_joinable: bool
  69. @staticmethod
  70. def construct_disabled_join_config():
  71. r"""Return a :class:`_JoinConfig` instance indicating that join-related logic should be disabled.
  72. e.g. if the caller is not in a join context manager.
  73. """
  74. return _JoinConfig(
  75. enable=False, throw_on_early_termination=False, is_first_joinable=False
  76. )
  77. class Join:
  78. r"""
  79. This class defines the generic join context manager, which allows custom hooks to be called after a process joins.
  80. These hooks should shadow the
  81. collective communications of non-joined processes to prevent hanging and
  82. erroring and to ensure algorithmic correctness. Refer to :class:`JoinHook`
  83. for details about the hook definition.
  84. .. warning::
  85. The context manager requires each participating :class:`Joinable` to
  86. call the method :meth:`notify_join_context()` before its own per-
  87. iteration collective communications to ensure correctness.
  88. .. warning::
  89. The context manager requires that all ``process_group`` attributes in
  90. the :class:`JoinHook` objects are the same. If there are multiple
  91. :class:`JoinHook` objects, then the ``device`` of the first is used.
  92. The process group and device information is used for checking for non-
  93. joined processes and for notifying processes to throw an exception if
  94. ``throw_on_early_termination`` is enabled, both of which using an all-
  95. reduce.
  96. Arguments:
  97. joinables (List[Joinable]): a list of the participating
  98. :class:`Joinable` s; their hooks are iterated over in the given
  99. order.
  100. enable (bool): a flag enabling uneven input detection; setting to
  101. ``False`` disables the context manager's functionality and should
  102. only be set when the user knows the inputs will not be uneven
  103. (default: ``True``).
  104. throw_on_early_termination (bool): a flag controlling whether to throw an
  105. exception upon detecting uneven inputs (default: ``False``).
  106. Example::
  107. >>> import os
  108. >>> import torch
  109. >>> import torch.distributed as dist
  110. >>> import torch.multiprocessing as mp
  111. >>> # xdoctest: +SKIP
  112. >>> import torch.nn.parallel.DistributedDataParallel as DDP
  113. >>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO
  114. >>> from torch.distributed.algorithms.join import Join
  115. >>>
  116. >>> # On each spawned worker
  117. >>> def worker(rank):
  118. >>> dist.init_process_group("nccl", rank=rank, world_size=2)
  119. >>> model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
  120. >>> optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
  121. >>> # Rank 1 gets one more input than rank 0
  122. >>> inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
  123. >>> with Join([model, optim]):
  124. >>> for input in inputs:
  125. >>> loss = model(input).sum()
  126. >>> loss.backward()
  127. >>> optim.step()
  128. >>> # All ranks reach here without hanging/erroring
  129. """
  130. def __init__(
  131. self,
  132. joinables: list[Joinable],
  133. enable: bool = True,
  134. throw_on_early_termination: bool = False,
  135. **kwargs,
  136. ):
  137. if len(joinables) == 0:
  138. raise ValueError("The join context manager requires at least one joinable")
  139. self._joinables = joinables
  140. self._join_hooks = [
  141. joinable.join_hook(**kwargs) for joinable in self._joinables
  142. ]
  143. self._enable = enable
  144. self._throw_on_early_termination = throw_on_early_termination
  145. self._set_joinable_configs()
  146. self._extract_dist_info()
  147. def _set_joinable_configs(self) -> None:
  148. r"""Set the :class:`_JoinConfig` of each participating :class:`Joinable`."""
  149. assert len(self._joinables) > 0
  150. is_first_joinable = True
  151. for joinable in self._joinables:
  152. joinable._join_config = _JoinConfig(
  153. enable=self._enable,
  154. throw_on_early_termination=self._throw_on_early_termination,
  155. is_first_joinable=is_first_joinable,
  156. )
  157. is_first_joinable = False
  158. def _extract_dist_info(self) -> None:
  159. r"""
  160. Extract the process group and device information from the joinables.
  161. If there are multiple joinables, then the context manager uses the
  162. first specified device.
  163. Preconditions:
  164. ``self._joinables`` is not ``None`` and is non-empty.
  165. Raises:
  166. ValueError
  167. If there are multiple conflicting ``process_group`` attributes
  168. among the ``Joinable`` objects.
  169. """
  170. process_group = None
  171. device = None
  172. # pyrefly: ignore [bad-assignment]
  173. for joinable in self._joinables:
  174. if process_group is None:
  175. process_group = joinable.join_process_group
  176. elif process_group != joinable.join_process_group:
  177. raise ValueError(
  178. "Using join context manager with multiple process groups"
  179. )
  180. if device is None:
  181. device = joinable.join_device
  182. self._process_group = process_group
  183. self._rank = dist.get_rank(self._process_group)
  184. self._device = device
  185. def __enter__(self): ...
  186. def __exit__(
  187. self,
  188. type: type[BaseException] | None,
  189. value: BaseException | None,
  190. traceback: TracebackType | None,
  191. ):
  192. r"""
  193. Repeatedly runs the main hooks until all processes join; then, runs the post-hooks.
  194. Raises:
  195. RuntimeError
  196. If ``throw_on_early_termination=True``.
  197. """
  198. if not self._enable or type:
  199. return # propagate the exception directly if one was raised
  200. all_procs_joined = False
  201. is_last_joiner = True
  202. i = 0
  203. WARN_THRESHOLD = 1000
  204. warnings.simplefilter("once")
  205. while not all_procs_joined:
  206. if i > WARN_THRESHOLD:
  207. warnings.warn(
  208. "Detected uneven input skew of greater than "
  209. f"{WARN_THRESHOLD}. This means that rank "
  210. f"{self._rank} has at least {WARN_THRESHOLD} "
  211. f"fewer inputs than other currently-active ranks. "
  212. "This level of skew could lead to performance "
  213. "degradation during training.",
  214. stacklevel=2,
  215. )
  216. # Shadow the all-reduce in non-joined processes
  217. num_nonjoined_procs = self._get_num_nonjoined_procs()
  218. if num_nonjoined_procs == 0:
  219. all_procs_joined = True
  220. else:
  221. if self._throw_on_early_termination:
  222. self._notify_procs_to_terminate()
  223. # Run main hooks
  224. for join_hook in self._join_hooks:
  225. join_hook.main_hook()
  226. is_last_joiner = False
  227. i += 1
  228. # Run post-hooks
  229. for join_hook in self._join_hooks:
  230. join_hook.post_hook(is_last_joiner)
  231. def _get_num_nonjoined_procs(self):
  232. r"""Return the number of non-joined processes by shadowing an all-reduce in the non-joined processes."""
  233. num_nonjoined_procs = torch.zeros(1, device=self._device)
  234. dist.all_reduce(num_nonjoined_procs, group=self._process_group)
  235. return num_nonjoined_procs.item()
  236. def _notify_procs_to_terminate(self):
  237. r"""Schedule an all-reduce to notify non-joined processes to terminate.
  238. Also raise a ``RuntimeError`` indicating that the current process has exhausted its inputs.
  239. """
  240. ones = torch.ones(1, device=self._device)
  241. dist.all_reduce(ones, group=self._process_group)
  242. raise RuntimeError(f"Rank {self._rank} exhausted all inputs.")
  243. @staticmethod
  244. def notify_join_context(joinable: Joinable):
  245. r"""
  246. Notifies the join context manager that the calling process has not yet joined.
  247. Then, if ``throw_on_early_termination=True``, checks if uneven inputs have been detected
  248. (i.e. if one process has already joined) and throws an exception if so.
  249. This method should be called from a :class:`Joinable` object before
  250. its per-iteration collective communications. For example, this should
  251. be called at the beginning of the forward pass in
  252. :class:`DistributedDataParallel`.
  253. Only the first :class:`Joinable` object passed into the context
  254. manager performs the collective communications in this method, and
  255. for the others, this method is vacuous.
  256. Arguments:
  257. joinable (Joinable): the :class:`Joinable` object calling this
  258. method.
  259. Returns:
  260. An async work handle for the all-reduce meant to notify the context
  261. manager that the process has not yet joined if ``joinable`` is the
  262. first one passed into the context manager; ``None`` otherwise.
  263. """
  264. assert hasattr(joinable, "_join_config"), (
  265. f"Check that the {type(joinable)} constructor calls the "
  266. "``Joinable`` constructor"
  267. )
  268. join_config = joinable._join_config
  269. # First joinable is responsible for the collective communications
  270. if not join_config.is_first_joinable or not join_config.enable:
  271. return None
  272. device = joinable.join_device
  273. process_group = joinable.join_process_group
  274. # Schedule an all-reduce to indicate that the caller has not yet joined
  275. ones = torch.ones(1, device=device)
  276. work = dist.all_reduce(ones, group=process_group, async_op=True)
  277. if join_config.throw_on_early_termination:
  278. # Check if uneven inputs have been detected
  279. zeros = torch.zeros(1, device=device)
  280. dist.all_reduce(zeros, group=process_group)
  281. should_throw = zeros.item()
  282. if should_throw:
  283. raise RuntimeError(
  284. "Detected at least one rank that exhausted inputs. "
  285. "Throwing across all ranks."
  286. )
  287. return work