nccl.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. # mypy: allow-untyped-defs
  2. import collections
  3. import warnings
  4. from collections.abc import Sequence
  5. import torch.cuda
# Names exported by `from <this module> import *`: only the collective
# operations are listed here; the other helpers defined in this module must be
# accessed by attribute.
__all__ = ["all_reduce", "reduce", "broadcast", "all_gather", "reduce_scatter"]
# Default value for the `op` argument of the reduction wrappers in this module.
# NOTE(review): 0 is presumably the "sum" member of NCCL's ncclRedOp_t enum --
# confirm against the NCCL headers.
SUM = 0  # ncclRedOp_t
  8. def is_available(tensors):
  9. if not hasattr(torch._C, "_nccl_all_reduce"):
  10. warnings.warn("PyTorch is not compiled with NCCL support", stacklevel=2)
  11. return False
  12. devices = set()
  13. for tensor in tensors:
  14. if tensor.is_sparse:
  15. return False
  16. if not tensor.is_contiguous():
  17. return False
  18. if not tensor.is_cuda:
  19. return False
  20. device = tensor.get_device()
  21. if device in devices:
  22. return False
  23. devices.add(device)
  24. return True
  25. def version():
  26. """
  27. Returns the version of the NCCL.
  28. This function returns a tuple containing the major, minor, and patch version numbers of the NCCL.
  29. The suffix is also included in the tuple if a version suffix exists.
  30. Returns:
  31. tuple: The version information of the NCCL.
  32. """
  33. ver = torch._C._nccl_version()
  34. major = ver >> 32
  35. minor = (ver >> 16) & 65535
  36. patch = ver & 65535
  37. suffix = torch._C._nccl_version_suffix().decode("utf-8")
  38. if suffix == "":
  39. return (major, minor, patch)
  40. else:
  41. return (major, minor, patch, suffix)
  42. def unique_id():
  43. return torch._C._nccl_unique_id()
  44. def init_rank(num_ranks, uid, rank):
  45. return torch._C._nccl_init_rank(num_ranks, uid, rank)
  46. def _check_sequence_type(inputs: torch.Tensor | Sequence[torch.Tensor]) -> None:
  47. if not isinstance(inputs, collections.abc.Container) or isinstance(
  48. inputs, torch.Tensor
  49. ):
  50. raise TypeError("Inputs should be a collection of tensors")
  51. def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
  52. _check_sequence_type(inputs)
  53. if outputs is None:
  54. outputs = inputs
  55. _check_sequence_type(outputs)
  56. torch._C._nccl_all_reduce(inputs, outputs, op, streams, comms)
  57. # `output` used to be `outputs`, taking in a list of tensors. So we have two
  58. # arguments for BC reasons.
  59. def reduce(
  60. inputs: Sequence[torch.Tensor],
  61. output: torch.Tensor | Sequence[torch.Tensor] | None = None,
  62. root: int = 0,
  63. op: int = SUM,
  64. streams: Sequence[torch.cuda.Stream] | None = None,
  65. comms=None,
  66. *,
  67. outputs: Sequence[torch.Tensor] | None = None,
  68. ) -> None:
  69. _check_sequence_type(inputs)
  70. _output: torch.Tensor
  71. if outputs is not None:
  72. if output is not None:
  73. raise ValueError(
  74. "'output' and 'outputs' can not be both specified. 'outputs' is deprecated in "
  75. "favor of 'output', taking in a single output tensor. The signature of reduce is: "
  76. "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None)."
  77. )
  78. else:
  79. warnings.warn(
  80. "`nccl.reduce` with an output tensor list is deprecated. "
  81. "Please specify a single output tensor with argument 'output' instead instead.",
  82. FutureWarning,
  83. stacklevel=2,
  84. )
  85. _output = outputs[root]
  86. elif not isinstance(output, torch.Tensor) and isinstance(
  87. output, collections.abc.Sequence
  88. ):
  89. # User called old API with positional arguments of list of output tensors.
  90. warnings.warn(
  91. "nccl.reduce with an output tensor list is deprecated. "
  92. "Please specify a single output tensor.",
  93. FutureWarning,
  94. stacklevel=2,
  95. )
  96. _output = output[root]
  97. else:
  98. _output = inputs[root] if output is None else output
  99. torch._C._nccl_reduce(inputs, _output, root, op, streams, comms)
  100. def broadcast(
  101. inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None
  102. ) -> None:
  103. _check_sequence_type(inputs)
  104. torch._C._nccl_broadcast(inputs, root, streams, comms)
  105. def all_gather(
  106. inputs: Sequence[torch.Tensor],
  107. outputs: Sequence[torch.Tensor],
  108. streams=None,
  109. comms=None,
  110. ) -> None:
  111. _check_sequence_type(inputs)
  112. _check_sequence_type(outputs)
  113. torch._C._nccl_all_gather(inputs, outputs, streams, comms)
  114. def reduce_scatter(
  115. inputs: Sequence[torch.Tensor],
  116. outputs: Sequence[torch.Tensor],
  117. op: int = SUM,
  118. streams=None,
  119. comms=None,
  120. ) -> None:
  121. _check_sequence_type(inputs)
  122. _check_sequence_type(outputs)
  123. torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)