scatter_gather.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. # mypy: allow-untyped-defs
  2. from collections.abc import Sequence
  3. from typing import Any, overload, TypeVar
  4. from typing_extensions import deprecated
  5. import torch
  6. from torch.nn.parallel._functions import Gather, Scatter
  7. __all__ = ["scatter", "scatter_kwargs", "gather"]
  8. @deprecated(
  9. "`is_namedtuple` is deprecated, please use the python checks instead",
  10. category=FutureWarning,
  11. )
  12. def is_namedtuple(obj: Any) -> bool:
  13. # Check if type was created from collections.namedtuple or a typing.NamedTuple.
  14. return _is_namedtuple(obj)
  15. def _is_namedtuple(obj: Any) -> bool:
  16. # Check if type was created from collections.namedtuple or a typing.NamedTuple.
  17. return (
  18. isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
  19. )
  20. T = TypeVar("T", dict, list, tuple)
  21. # For some reason, 'scatter' returns a tuple when given a single Tensor input but a list otherwise.
  22. @overload
  23. def scatter(
  24. inputs: torch.Tensor,
  25. target_gpus: Sequence[int | torch.device],
  26. dim: int = ...,
  27. ) -> tuple[torch.Tensor, ...]: ...
  28. @overload
  29. def scatter(
  30. inputs: T,
  31. target_gpus: Sequence[int | torch.device],
  32. dim: int = ...,
  33. ) -> list[T]: ...
  34. def scatter(inputs, target_gpus, dim=0):
  35. r"""Slice tensors into approximately equal chunks and distributes them across given GPUs.
  36. Duplicates references to objects that are not tensors.
  37. """
  38. def scatter_map(obj):
  39. if isinstance(obj, torch.Tensor):
  40. return Scatter.apply(target_gpus, None, dim, obj)
  41. if _is_namedtuple(obj):
  42. return [
  43. type(obj)(*args)
  44. # pyrefly: ignore [no-matching-overload]
  45. for args in zip(*map(scatter_map, obj), strict=False)
  46. ]
  47. if isinstance(obj, tuple) and len(obj) > 0:
  48. # pyrefly: ignore [no-matching-overload]
  49. return list(zip(*map(scatter_map, obj), strict=False))
  50. if isinstance(obj, list) and len(obj) > 0:
  51. # pyrefly: ignore [no-matching-overload]
  52. return [list(i) for i in zip(*map(scatter_map, obj), strict=False)]
  53. if isinstance(obj, dict) and len(obj) > 0:
  54. return [
  55. type(obj)(i)
  56. # pyrefly: ignore [no-matching-overload]
  57. for i in zip(*map(scatter_map, obj.items()), strict=False)
  58. ]
  59. return [obj for _ in target_gpus]
  60. # After scatter_map is called, a scatter_map cell will exist. This cell
  61. # has a reference to the actual function scatter_map, which has references
  62. # to a closure that has a reference to the scatter_map cell (because the
  63. # fn is recursive). To avoid this reference cycle, we set the function to
  64. # None, clearing the cell
  65. try:
  66. res = scatter_map(inputs)
  67. finally:
  68. scatter_map = None # type: ignore[assignment]
  69. return res
  70. def scatter_kwargs(
  71. inputs: tuple[Any, ...],
  72. kwargs: dict[str, Any] | None,
  73. target_gpus: Sequence[int | torch.device],
  74. dim: int = 0,
  75. ) -> tuple[tuple[Any, ...], tuple[dict[str, Any], ...]]:
  76. r"""Scatter with support for kwargs dictionary."""
  77. scattered_inputs = scatter(inputs, target_gpus, dim) if inputs else []
  78. scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
  79. if len(scattered_inputs) < len(scattered_kwargs):
  80. scattered_inputs.extend(
  81. () for _ in range(len(scattered_kwargs) - len(scattered_inputs))
  82. )
  83. elif len(scattered_kwargs) < len(inputs):
  84. scattered_kwargs.extend(
  85. {} for _ in range(len(scattered_inputs) - len(scattered_kwargs))
  86. )
  87. return tuple(scattered_inputs), tuple(scattered_kwargs)
  88. def gather(outputs: Any, target_device: int | torch.device, dim: int = 0) -> Any:
  89. r"""Gather tensors from different GPUs on a specified device.
  90. This function is useful for gathering the results of a distributed computation.
  91. It takes a sequence of objects, one for each GPU, and returns a single object
  92. on the specified device.
  93. Args:
  94. outputs (Any): A sequence of objects (potentially tensors) to gather.
  95. target_device (Union[int, torch.device]): The device to gather the tensors to.
  96. Use 'cpu' for CPU to avoid a deprecation warning.
  97. dim (int, optional): The dimension along which to gather. Default: 0.
  98. Returns:
  99. Any: A gathered object (potentially tensor) on the specified device.
  100. """
  101. def gather_map(outputs):
  102. out = outputs[0]
  103. if isinstance(out, torch.Tensor):
  104. return Gather.apply(target_device, dim, *outputs)
  105. if out is None:
  106. return None
  107. if isinstance(out, dict):
  108. if not all(len(out) == len(d) for d in outputs):
  109. raise ValueError("All dicts must have the same number of keys")
  110. # pyrefly: ignore [not-callable]
  111. return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
  112. if _is_namedtuple(out):
  113. # pyrefly: ignore [no-matching-overload]
  114. return type(out)._make(map(gather_map, zip(*outputs, strict=True)))
  115. # pyrefly: ignore [no-matching-overload]
  116. return type(out)(map(gather_map, zip(*outputs, strict=True)))
  117. # Recursive function calls like this create reference cycles.
  118. # Setting the function to None clears the refcycle.
  119. try:
  120. res = gather_map(outputs)
  121. finally:
  122. gather_map = None # type: ignore[assignment]
  123. return res