| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151 |
- # mypy: allow-untyped-defs
- import collections
- import warnings
- from collections.abc import Sequence
- import torch.cuda
# Public collective operations exported by this module.
__all__ = [
    "all_reduce",
    "reduce",
    "broadcast",
    "all_gather",
    "reduce_scatter",
]

# Reduction op code forwarded unchanged to the native NCCL bindings. The value
# is an ``ncclRedOp_t`` enumerator (0 is ``ncclSum`` in NCCL's enum).
SUM = 0
def is_available(tensors):
    """Report whether the NCCL collectives in this module can run on ``tensors``.

    If PyTorch was built without the NCCL bindings, a warning is emitted and
    ``False`` is returned. Otherwise each tensor must satisfy all of these:

    * it is not sparse,
    * it is contiguous,
    * it lives on a CUDA device,
    * no other tensor in ``tensors`` is on the same device.

    Args:
        tensors: iterable of tensors to check.

    Returns:
        bool: ``True`` when every check passes (vacuously for an empty
        iterable), ``False`` otherwise.
    """
    if not hasattr(torch._C, "_nccl_all_reduce"):
        warnings.warn("PyTorch is not compiled with NCCL support", stacklevel=2)
        return False

    seen_devices = set()
    for t in tensors:
        # Short-circuit order matters: a sparse tensor is rejected before
        # is_contiguous() is ever consulted on it.
        if t.is_sparse or not t.is_contiguous() or not t.is_cuda:
            return False
        dev = t.get_device()
        if dev in seen_devices:
            # At most one tensor per device is allowed.
            return False
        seen_devices.add(dev)
    return True
- def version():
- """
- Returns the version of the NCCL.
- This function returns a tuple containing the major, minor, and patch version numbers of the NCCL.
- The suffix is also included in the tuple if a version suffix exists.
- Returns:
- tuple: The version information of the NCCL.
- """
- ver = torch._C._nccl_version()
- major = ver >> 32
- minor = (ver >> 16) & 65535
- patch = ver & 65535
- suffix = torch._C._nccl_version_suffix().decode("utf-8")
- if suffix == "":
- return (major, minor, patch)
- else:
- return (major, minor, patch, suffix)
def unique_id():
    """Return the value produced by the native ``torch._C._nccl_unique_id`` binding.

    NOTE(review): presumably this is the ``uid`` later passed to ``init_rank``
    by every participating rank. Confirm against callers.
    """
    make_id = torch._C._nccl_unique_id
    return make_id()
def init_rank(num_ranks, uid, rank):
    """Forward to the native ``torch._C._nccl_init_rank`` binding and return its result.

    Args:
        num_ranks: total number of ranks (forwarded unchanged).
        uid: identifier shared by the ranks, presumably obtained from
            ``unique_id()``. TODO confirm.
        rank: this process's rank (forwarded unchanged).

    Returns:
        Whatever the native binding returns. Presumably an NCCL communicator
        handle; verify against the C++ binding.
    """
    native_init = torch._C._nccl_init_rank
    communicator = native_init(num_ranks, uid, rank)
    return communicator
- def _check_sequence_type(inputs: torch.Tensor | Sequence[torch.Tensor]) -> None:
- if not isinstance(inputs, collections.abc.Container) or isinstance(
- inputs, torch.Tensor
- ):
- raise TypeError("Inputs should be a collection of tensors")
def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
    """Run an NCCL all-reduce over ``inputs``.

    Args:
        inputs: collection of tensors (a bare tensor is rejected).
        outputs: collection of tensors handed to the native call as the
            destination. When ``None``, ``inputs`` itself is used, so the
            results presumably land in the input tensors.
        op: reduction op code forwarded to the native binding (``SUM`` by default).
        streams: forwarded unchanged to the native binding.
        comms: forwarded unchanged to the native binding.

    Raises:
        TypeError: if ``inputs`` or the effective ``outputs`` is not a
            collection of tensors.
    """
    _check_sequence_type(inputs)
    destinations = inputs if outputs is None else outputs
    _check_sequence_type(destinations)
    torch._C._nccl_all_reduce(inputs, destinations, op, streams, comms)
# `output` used to be `outputs`, taking in a list of tensors. Both arguments
# are kept for backward compatibility.
def reduce(
    inputs: Sequence[torch.Tensor],
    output: torch.Tensor | Sequence[torch.Tensor] | None = None,
    root: int = 0,
    op: int = SUM,
    streams: Sequence[torch.cuda.Stream] | None = None,
    comms=None,
    *,
    outputs: Sequence[torch.Tensor] | None = None,
) -> None:
    """Run an NCCL reduce of ``inputs`` into a single destination tensor.

    Args:
        inputs: collection of tensors to reduce (a bare tensor is rejected).
        output: destination tensor. Defaults to ``inputs[root]``. For backward
            compatibility a sequence of tensors is also accepted (deprecated);
            in that case ``output[root]`` is used.
        root: index selecting the destination. It is also forwarded to the
            native binding.
        op: reduction op code forwarded to the native binding (``SUM`` by default).
        streams: forwarded unchanged to the native binding.
        comms: forwarded unchanged to the native binding.
        outputs: keyword-only, deprecated list-of-tensors form of ``output``;
            ``outputs[root]`` is used.

    Raises:
        TypeError: if ``inputs`` is not a collection of tensors.
        ValueError: if both ``output`` and ``outputs`` are given.

    Warns:
        FutureWarning: when a list of output tensors is passed, via either
            ``outputs=`` or a sequence given as ``output``.
    """
    _check_sequence_type(inputs)
    _output: torch.Tensor
    if outputs is not None:
        if output is not None:
            raise ValueError(
                "'output' and 'outputs' can not be both specified. 'outputs' is deprecated in "
                "favor of 'output', taking in a single output tensor. The signature of reduce is: "
                "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None)."
            )
        # Fixed: the message previously ended with a duplicated "instead instead.".
        warnings.warn(
            "`nccl.reduce` with an output tensor list is deprecated. "
            "Please specify a single output tensor with argument 'output' instead.",
            FutureWarning,
            stacklevel=2,
        )
        _output = outputs[root]
    elif not isinstance(output, torch.Tensor) and isinstance(
        output, collections.abc.Sequence
    ):
        # User called the old API positionally with a list of output tensors.
        warnings.warn(
            "nccl.reduce with an output tensor list is deprecated. "
            "Please specify a single output tensor.",
            FutureWarning,
            stacklevel=2,
        )
        _output = output[root]
    else:
        # No explicit destination: default to the root's own input tensor.
        _output = inputs[root] if output is None else output
    torch._C._nccl_reduce(inputs, _output, root, op, streams, comms)
def broadcast(
    inputs: Sequence[torch.Tensor],
    root: int = 0,
    streams=None,
    comms=None,
) -> None:
    """Run an NCCL broadcast across the tensors in ``inputs``.

    Args:
        inputs: collection of tensors (a bare tensor is rejected).
        root: index forwarded to the native binding. It presumably selects the
            source tensor; confirm against the C++ binding.
        streams: forwarded unchanged to the native binding.
        comms: forwarded unchanged to the native binding.

    Raises:
        TypeError: if ``inputs`` is not a collection of tensors.
    """
    _check_sequence_type(inputs)
    torch._C._nccl_broadcast(
        inputs,
        root,
        streams,
        comms,
    )
def all_gather(
    inputs: Sequence[torch.Tensor],
    outputs: Sequence[torch.Tensor],
    streams=None,
    comms=None,
) -> None:
    """Run an NCCL all-gather from ``inputs`` into ``outputs``.

    Args:
        inputs: collection of source tensors (a bare tensor is rejected).
        outputs: collection of destination tensors (a bare tensor is rejected).
        streams: forwarded unchanged to the native binding.
        comms: forwarded unchanged to the native binding.

    Raises:
        TypeError: if ``inputs`` or ``outputs`` is not a collection of tensors
            (``inputs`` is checked first).
    """
    for tensor_group in (inputs, outputs):
        _check_sequence_type(tensor_group)
    torch._C._nccl_all_gather(inputs, outputs, streams, comms)
def reduce_scatter(
    inputs: Sequence[torch.Tensor],
    outputs: Sequence[torch.Tensor],
    op: int = SUM,
    streams=None,
    comms=None,
) -> None:
    """Run an NCCL reduce-scatter from ``inputs`` into ``outputs``.

    Args:
        inputs: collection of source tensors (a bare tensor is rejected).
        outputs: collection of destination tensors (a bare tensor is rejected).
        op: reduction op code forwarded to the native binding (``SUM`` by default).
        streams: forwarded unchanged to the native binding.
        comms: forwarded unchanged to the native binding.

    Raises:
        TypeError: if ``inputs`` or ``outputs`` is not a collection of tensors
            (``inputs`` is checked first).
    """
    for tensor_group in (inputs, outputs):
        _check_sequence_type(tensor_group)
    torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)
|