distributed.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import math
  2. from collections.abc import Iterator
  3. from typing import TypeVar
  4. import torch
  5. import torch.distributed as dist
  6. from torch.utils.data.dataset import Dataset
  7. from torch.utils.data.sampler import Sampler
  8. __all__ = ["DistributedSampler"]
  9. _T_co = TypeVar("_T_co", covariant=True)
  10. class DistributedSampler(Sampler[_T_co]):
  11. r"""Sampler that restricts data loading to a subset of the dataset.
  12. It is especially useful in conjunction with
  13. :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
  14. process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a
  15. :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
  16. original dataset that is exclusive to it.
  17. .. note::
  18. Dataset is assumed to be of constant size and that any instance of it always
  19. returns the same elements in the same order.
  20. Args:
  21. dataset: Dataset used for sampling.
  22. num_replicas (int, optional): Number of processes participating in
  23. distributed training. By default, :attr:`world_size` is retrieved from the
  24. current distributed group.
  25. rank (int, optional): Rank of the current process within :attr:`num_replicas`.
  26. By default, :attr:`rank` is retrieved from the current distributed
  27. group.
  28. shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
  29. indices.
  30. seed (int, optional): random seed used to shuffle the sampler if
  31. :attr:`shuffle=True`. This number should be identical across all
  32. processes in the distributed group. Default: ``0``.
  33. drop_last (bool, optional): if ``True``, then the sampler will drop the
  34. tail of the data to make it evenly divisible across the number of
  35. replicas. If ``False``, the sampler will add extra indices to make
  36. the data evenly divisible across the replicas. Default: ``False``.
  37. .. warning::
  38. In distributed mode, calling the :meth:`set_epoch` method at
  39. the beginning of each epoch **before** creating the :class:`DataLoader` iterator
  40. is necessary to make shuffling work properly across multiple epochs. Otherwise,
  41. the same ordering will be always used.
  42. Example::
  43. >>> # xdoctest: +SKIP
  44. >>> sampler = DistributedSampler(dataset) if is_distributed else None
  45. >>> loader = DataLoader(dataset, shuffle=(sampler is None),
  46. ... sampler=sampler)
  47. >>> for epoch in range(start_epoch, n_epochs):
  48. ... if is_distributed:
  49. ... sampler.set_epoch(epoch)
  50. ... train(loader)
  51. """
  52. def __init__(
  53. self,
  54. dataset: Dataset,
  55. num_replicas: int | None = None,
  56. rank: int | None = None,
  57. shuffle: bool = True,
  58. seed: int = 0,
  59. drop_last: bool = False,
  60. ) -> None:
  61. if num_replicas is None:
  62. if not dist.is_available():
  63. raise RuntimeError("Requires distributed package to be available")
  64. num_replicas = dist.get_world_size()
  65. if rank is None:
  66. if not dist.is_available():
  67. raise RuntimeError("Requires distributed package to be available")
  68. rank = dist.get_rank()
  69. if rank >= num_replicas or rank < 0:
  70. raise ValueError(
  71. f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
  72. )
  73. self.dataset = dataset
  74. self.num_replicas = num_replicas
  75. self.rank = rank
  76. self.epoch = 0
  77. self.drop_last = drop_last
  78. # If the dataset length is evenly divisible by # of replicas, then there
  79. # is no need to drop any data, since the dataset will be split equally.
  80. if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
  81. # Split to nearest available length that is evenly divisible.
  82. # This is to ensure each rank receives the same amount of data when
  83. # using this Sampler.
  84. self.num_samples = math.ceil(
  85. (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
  86. )
  87. else:
  88. self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
  89. self.total_size = self.num_samples * self.num_replicas
  90. self.shuffle = shuffle
  91. self.seed = seed
  92. def __iter__(self) -> Iterator[_T_co]:
  93. if self.shuffle:
  94. # deterministically shuffle based on epoch and seed
  95. g = torch.Generator()
  96. g.manual_seed(self.seed + self.epoch)
  97. indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
  98. else:
  99. indices = list(range(len(self.dataset))) # type: ignore[arg-type]
  100. if not self.drop_last:
  101. # add extra samples to make it evenly divisible
  102. padding_size = self.total_size - len(indices)
  103. if padding_size <= len(indices):
  104. indices += indices[:padding_size]
  105. else:
  106. indices += (indices * math.ceil(padding_size / len(indices)))[
  107. :padding_size
  108. ]
  109. else:
  110. # remove tail of data to make it evenly divisible.
  111. indices = indices[: self.total_size]
  112. if len(indices) != self.total_size:
  113. raise AssertionError(
  114. f"Number of indices ({len(indices)}) does not match total_size ({self.total_size})"
  115. )
  116. # subsample
  117. indices = indices[self.rank : self.total_size : self.num_replicas]
  118. if len(indices) != self.num_samples:
  119. raise AssertionError(
  120. f"Number of subsampled indices ({len(indices)}) does not match num_samples ({self.num_samples})"
  121. )
  122. # pyrefly: ignore [bad-return]
  123. return iter(indices)
  124. def __len__(self) -> int:
  125. return self.num_samples
  126. def set_epoch(self, epoch: int) -> None:
  127. r"""
  128. Set the epoch for this sampler.
  129. When :attr:`shuffle=True`, this ensures all replicas
  130. use a different random ordering for each epoch. Otherwise, the next iteration of this
  131. sampler will yield the same ordering.
  132. Args:
  133. epoch (int): Epoch number.
  134. """
  135. self.epoch = epoch