data_parallel.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. # mypy: allow-untyped-defs
  2. import operator
  3. import warnings
  4. from collections.abc import Sequence
  5. from itertools import chain
  6. from typing import Any, Generic, TypeVar
  7. import torch
  8. from torch._utils import (
  9. _get_all_device_indices,
  10. _get_available_device_type,
  11. _get_device_index,
  12. _get_devices_properties,
  13. )
  14. from torch.nn.modules import Module
  15. from torch.nn.parallel.parallel_apply import parallel_apply
  16. from torch.nn.parallel.replicate import replicate
  17. from torch.nn.parallel.scatter_gather import gather, scatter_kwargs
# Public names exported by ``from ... import *``: the ``DataParallel`` module
# wrapper and its functional counterpart ``data_parallel``, both defined below.
__all__ = ["DataParallel", "data_parallel"]
  19. def _check_balance(device_ids: Sequence[int | torch.device]) -> None:
  20. imbalance_warn = """
  21. There is an imbalance between your GPUs. You may want to exclude GPU {} which
  22. has less than 75% of the memory or cores of GPU {}. You can do so by setting
  23. the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
  24. environment variable."""
  25. device_ids = [_get_device_index(x, True) for x in device_ids]
  26. dev_props = _get_devices_properties(device_ids)
  27. def warn_imbalance(get_prop) -> bool:
  28. values = [get_prop(props) for props in dev_props]
  29. min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
  30. max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
  31. if min_val / max_val < 0.75:
  32. warnings.warn(
  33. imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]),
  34. stacklevel=2,
  35. )
  36. return True
  37. return False
  38. if warn_imbalance(lambda props: props.total_memory):
  39. return
  40. if warn_imbalance(lambda props: props.multi_processor_count):
  41. return
  42. T = TypeVar("T", bound=Module)
  43. class DataParallel(Module, Generic[T]):
  44. r"""Implements data parallelism at the module level.
  45. This container parallelizes the application of the given :attr:`module` by
  46. splitting the input across the specified devices by chunking in the batch
  47. dimension (other objects will be copied once per device). In the forward
  48. pass, the module is replicated on each device, and each replica handles a
  49. portion of the input. During the backwards pass, gradients from each replica
  50. are summed into the original module.
  51. The batch size should be larger than the number of GPUs used.
  52. .. warning::
  53. It is recommended to use :class:`~torch.nn.parallel.DistributedDataParallel`,
  54. instead of this class, to do multi-GPU training, even if there is only a single
  55. node. See: :ref:`cuda-nn-ddp-instead` and :ref:`ddp`.
  56. Arbitrary positional and keyword inputs are allowed to be passed into
  57. DataParallel but some types are specially handled. tensors will be
  58. **scattered** on dim specified (default 0). tuple, list and dict types will
  59. be shallow copied. The other types will be shared among different threads
  60. and can be corrupted if written to in the model's forward pass.
  61. The parallelized :attr:`module` must have its parameters and buffers on
  62. ``device_ids[0]`` before running this :class:`~torch.nn.DataParallel`
  63. module.
  64. .. warning::
  65. In each forward, :attr:`module` is **replicated** on each device, so any
  66. updates to the running module in ``forward`` will be lost. For example,
  67. if :attr:`module` has a counter attribute that is incremented in each
  68. ``forward``, it will always stay at the initial value because the update
  69. is done on the replicas which are destroyed after ``forward``. However,
  70. :class:`~torch.nn.DataParallel` guarantees that the replica on
  71. ``device[0]`` will have its parameters and buffers sharing storage with
  72. the base parallelized :attr:`module`. So **in-place** updates to the
  73. parameters or buffers on ``device[0]`` will be recorded. E.g.,
  74. :class:`~torch.nn.BatchNorm2d` and :func:`~torch.nn.utils.spectral_norm`
  75. rely on this behavior to update the buffers.
  76. .. warning::
  77. Forward and backward hooks defined on :attr:`module` and its submodules
  78. will be invoked ``len(device_ids)`` times, each with inputs located on
  79. a particular device. Particularly, the hooks are only guaranteed to be
  80. executed in correct order with respect to operations on corresponding
  81. devices. For example, it is not guaranteed that hooks set via
  82. :meth:`~torch.nn.Module.register_forward_pre_hook` be executed before
  83. `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but
  84. that each such hook be executed before the corresponding
  85. :meth:`~torch.nn.Module.forward` call of that device.
  86. .. warning::
  87. When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in
  88. :func:`forward`, this wrapper will return a vector of length equal to
  89. number of devices used in data parallelism, containing the result from
  90. each device.
  91. .. note::
  92. There is a subtlety in using the
  93. ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
  94. :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
  95. See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for
  96. details.
  97. Args:
  98. module (Module): module to be parallelized
  99. device_ids (list of int or torch.device): CUDA devices (default: all devices)
  100. output_device (int or torch.device): device location of output (default: device_ids[0])
  101. Attributes:
  102. module (Module): the module to be parallelized
  103. Example::
  104. >>> # xdoctest: +SKIP
  105. >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
  106. >>> output = net(input_var) # input_var can be on any device, including CPU
  107. """
  108. # TODO: update notes/cuda.rst when this class handles 8+ GPUs well
  109. def __init__(
  110. self,
  111. module: T,
  112. device_ids: Sequence[int | torch.device] | None = None,
  113. output_device: int | torch.device | None = None,
  114. dim: int = 0,
  115. ) -> None:
  116. super().__init__()
  117. torch._C._log_api_usage_once("torch.nn.parallel.DataParallel")
  118. device_type = _get_available_device_type()
  119. if device_type is None or device_type == "mps":
  120. self.module = module
  121. self.device_ids = []
  122. return
  123. if device_ids is None:
  124. device_ids = _get_all_device_indices()
  125. if device_ids is None:
  126. raise RuntimeError("no available devices were found")
  127. if output_device is None:
  128. output_device = device_ids[0]
  129. self.dim = dim
  130. self.module = module
  131. self.device_ids = [_get_device_index(x, True) for x in device_ids]
  132. self.output_device = _get_device_index(output_device, True)
  133. # pyrefly: ignore [read-only]
  134. self.src_device_obj = torch.device(device_type, self.device_ids[0])
  135. if device_type == "cuda":
  136. _check_balance(self.device_ids)
  137. if len(self.device_ids) == 1:
  138. self.module.to(self.src_device_obj)
  139. def forward(self, *inputs: Any, **kwargs: Any) -> Any:
  140. with torch.autograd.profiler.record_function("DataParallel.forward"):
  141. if not self.device_ids:
  142. return self.module(*inputs, **kwargs)
  143. # pyrefly: ignore [bad-argument-type]
  144. for t in chain(self.module.parameters(), self.module.buffers()):
  145. if t.device != self.src_device_obj:
  146. raise RuntimeError(
  147. "module must have its parameters and buffers "
  148. f"on device {self.src_device_obj} (device_ids[0]) but found one of "
  149. f"them on device: {t.device}"
  150. )
  151. inputs, module_kwargs = self.scatter(inputs, kwargs, self.device_ids)
  152. # for forward function without any inputs, empty list and dict will be created
  153. # so the module can be executed on one device which is the first one in device_ids
  154. if not inputs and not module_kwargs:
  155. inputs = ((),)
  156. module_kwargs = ({},)
  157. if len(self.device_ids) == 1:
  158. return self.module(*inputs[0], **module_kwargs[0])
  159. replicas = self.replicate(self.module, self.device_ids[: len(inputs)])
  160. outputs = self.parallel_apply(replicas, inputs, module_kwargs)
  161. return self.gather(outputs, self.output_device)
  162. def replicate(self, module: T, device_ids: Sequence[int | torch.device]) -> list[T]:
  163. return replicate(module, device_ids, not torch.is_grad_enabled())
  164. def scatter(
  165. self,
  166. inputs: tuple[Any, ...],
  167. kwargs: dict[str, Any] | None,
  168. device_ids: Sequence[int | torch.device],
  169. ) -> Any:
  170. return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
  171. def parallel_apply(
  172. self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any
  173. ) -> list[Any]:
  174. return parallel_apply(
  175. replicas, inputs, kwargs, self.device_ids[: len(replicas)]
  176. )
  177. def gather(self, outputs: Any, output_device: int | torch.device) -> Any:
  178. return gather(outputs, output_device, dim=self.dim)
  179. def data_parallel(
  180. module: Module,
  181. inputs: Any,
  182. device_ids: Sequence[int | torch.device] | None = None,
  183. output_device: int | torch.device | None = None,
  184. dim: int = 0,
  185. module_kwargs: Any | None = None,
  186. ) -> torch.Tensor:
  187. r"""Evaluate module(input) in parallel across the GPUs given in device_ids.
  188. This is the functional version of the DataParallel module.
  189. Args:
  190. module (Module): the module to evaluate in parallel
  191. inputs (Tensor): inputs to the module
  192. device_ids (list of int or torch.device): GPU ids on which to replicate module
  193. output_device (list of int or torch.device): GPU location of the output Use -1 to indicate the CPU.
  194. (default: device_ids[0])
  195. Returns:
  196. a Tensor containing the result of module(input) located on
  197. output_device
  198. """
  199. if not isinstance(inputs, tuple):
  200. inputs = (inputs,) if inputs is not None else ()
  201. device_type = _get_available_device_type()
  202. if device_type is None:
  203. raise RuntimeError("device type could not be determined")
  204. if device_ids is None:
  205. device_ids = _get_all_device_indices()
  206. if device_ids is None:
  207. raise RuntimeError("no available devices were found")
  208. if output_device is None:
  209. output_device = device_ids[0]
  210. device_ids = [_get_device_index(x, True) for x in device_ids]
  211. output_device = _get_device_index(output_device, True)
  212. # pyrefly: ignore [no-matching-overload]
  213. src_device_obj = torch.device(device_type, device_ids[0])
  214. # pyrefly: ignore [bad-argument-type]
  215. for t in chain(module.parameters(), module.buffers()):
  216. if t.device != src_device_obj:
  217. raise RuntimeError(
  218. "module must have its parameters and buffers "
  219. f"on device {src_device_obj} (device_ids[0]) but found one of "
  220. f"them on device: {t.device}"
  221. )
  222. inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
  223. # for module without any inputs, empty list and dict will be created
  224. # so the module can be executed on one device which is the first one in device_ids
  225. if not inputs and not module_kwargs:
  226. inputs = ((),)
  227. module_kwargs = ({},)
  228. if module_kwargs is None:
  229. raise AssertionError("module_kwargs should not be None after scatter_kwargs")
  230. if len(device_ids) == 1:
  231. return module(*inputs[0], **module_kwargs[0])
  232. used_device_ids = device_ids[: len(inputs)]
  233. replicas = replicate(module, used_device_ids)
  234. outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
  235. return gather(outputs, output_device, dim)