sampler.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. # mypy: allow-untyped-defs
  2. import itertools
  3. from collections.abc import Iterable, Iterator, Sequence, Sized
  4. from typing import Generic, TypeVar
  5. import torch
  6. # Note: For benchmarking changes to samplers, see:
  7. # /benchmarks/data/samplers_bench.py
  8. # This benchmark compares the performance of different sampler implementations
  9. # and can be used to evaluate the impact of optimizations.
  10. __all__ = [
  11. "BatchSampler",
  12. "RandomSampler",
  13. "Sampler",
  14. "SequentialSampler",
  15. "SubsetRandomSampler",
  16. "WeightedRandomSampler",
  17. ]
  18. _T_co = TypeVar("_T_co", covariant=True)
  19. class Sampler(Generic[_T_co]):
  20. r"""Base class for all Samplers.
  21. Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
  22. way to iterate over indices or lists of indices (batches) of dataset elements,
  23. and may provide a :meth:`__len__` method that returns the length of the returned iterators.
  24. Example:
  25. >>> # xdoctest: +SKIP
  26. >>> class AccedingSequenceLengthSampler(Sampler[int]):
  27. >>> def __init__(self, data: List[str]) -> None:
  28. >>> self.data = data
  29. >>>
  30. >>> def __len__(self) -> int:
  31. >>> return len(self.data)
  32. >>>
  33. >>> def __iter__(self) -> Iterator[int]:
  34. >>> sizes = torch.tensor([len(x) for x in self.data])
  35. >>> yield from torch.argsort(sizes).tolist()
  36. >>>
  37. >>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
  38. >>> def __init__(self, data: List[str], batch_size: int) -> None:
  39. >>> self.data = data
  40. >>> self.batch_size = batch_size
  41. >>>
  42. >>> def __len__(self) -> int:
  43. >>> return (len(self.data) + self.batch_size - 1) // self.batch_size
  44. >>>
  45. >>> def __iter__(self) -> Iterator[List[int]]:
  46. >>> sizes = torch.tensor([len(x) for x in self.data])
  47. >>> for batch in torch.chunk(torch.argsort(sizes), len(self)):
  48. >>> yield batch.tolist()
  49. .. note:: The :meth:`__len__` method isn't strictly required by
  50. :class:`~torch.utils.data.DataLoader`, but is expected in any
  51. calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
  52. """
  53. def __iter__(self) -> Iterator[_T_co]:
  54. raise NotImplementedError
  55. # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
  56. #
  57. # Many times we have an abstract class representing a collection/iterable of
  58. # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
  59. # implementing a `__len__` method. In such cases, we must make sure to not
  60. # provide a default implementation, because both straightforward default
  61. # implementations have their issues:
  62. #
  63. # + `return NotImplemented`:
  64. # Calling `len(subclass_instance)` raises:
  65. # TypeError: 'NotImplementedType' object cannot be interpreted as an integer
  66. #
  67. # + `raise NotImplementedError`:
  68. # This prevents triggering some fallback behavior. E.g., the built-in
  69. # `list(X)` tries to call `len(X)` first, and executes a different code
  70. # path if the method is not found or `NotImplemented` is returned, while
  71. # raising a `NotImplementedError` will propagate and make the call fail
  72. # where it could have used `__iter__` to complete the call.
  73. #
  74. # Thus, the only two sensible things to do are
  75. #
  76. # + **not** provide a default `__len__`.
  77. #
  78. # + raise a `TypeError` instead, which is what Python uses when users call
  79. # a method that is not defined on an object.
  80. # (@ssnl verifies that this works on at least Python 3.7.)
  81. class SequentialSampler(Sampler[int]):
  82. r"""Samples elements sequentially, always in the same order.
  83. Args:
  84. data_source (Sized): data source to sample from. Must implement __len__.
  85. """
  86. data_source: Sized
  87. def __init__(self, data_source: Sized) -> None:
  88. self.data_source = data_source
  89. def __iter__(self) -> Iterator[int]:
  90. return iter(range(len(self.data_source)))
  91. def __len__(self) -> int:
  92. return len(self.data_source)
  93. class RandomSampler(Sampler[int]):
  94. r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
  95. If with replacement, then user can specify :attr:`num_samples` to draw.
  96. Args:
  97. data_source (Sized): data source to sample from. Must implement __len__.
  98. replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
  99. num_samples (int): number of samples to draw, default=`len(dataset)`.
  100. generator (Generator): Generator used in sampling.
  101. """
  102. data_source: Sized
  103. replacement: bool
  104. def __init__(
  105. self,
  106. data_source: Sized,
  107. replacement: bool = False,
  108. num_samples: int | None = None,
  109. generator=None,
  110. ) -> None:
  111. self.data_source = data_source
  112. self.replacement = replacement
  113. self._num_samples = num_samples
  114. self.generator = generator
  115. if not isinstance(self.replacement, bool):
  116. raise TypeError(
  117. f"replacement should be a boolean value, but got replacement={self.replacement}"
  118. )
  119. if not isinstance(self.num_samples, int) or self.num_samples <= 0:
  120. raise ValueError(
  121. f"num_samples should be a positive integer value, but got num_samples={self.num_samples}"
  122. )
  123. @property
  124. def num_samples(self) -> int:
  125. # dataset size might change at runtime
  126. if self._num_samples is None:
  127. return len(self.data_source)
  128. return self._num_samples
  129. def __iter__(self) -> Iterator[int]:
  130. n = len(self.data_source)
  131. if self.generator is None:
  132. seed = int(torch.empty((), dtype=torch.int64).random_().item())
  133. generator = torch.Generator()
  134. generator.manual_seed(seed)
  135. else:
  136. generator = self.generator
  137. if self.replacement:
  138. for _ in range(self.num_samples // 32):
  139. yield from torch.randint(
  140. high=n, size=(32,), dtype=torch.int64, generator=generator
  141. ).tolist()
  142. yield from torch.randint(
  143. high=n,
  144. size=(self.num_samples % 32,),
  145. dtype=torch.int64,
  146. generator=generator,
  147. ).tolist()
  148. else:
  149. for _ in range(self.num_samples // n):
  150. yield from torch.randperm(n, generator=generator).tolist()
  151. yield from torch.randperm(n, generator=generator).tolist()[
  152. : self.num_samples % n
  153. ]
  154. def __len__(self) -> int:
  155. return self.num_samples
  156. class SubsetRandomSampler(Sampler[int]):
  157. r"""Samples elements randomly from a given list of indices, without replacement.
  158. Args:
  159. indices (sequence): a sequence of indices
  160. generator (Generator): Generator used in sampling.
  161. """
  162. indices: Sequence[int]
  163. def __init__(self, indices: Sequence[int], generator=None) -> None:
  164. self.indices = indices
  165. self.generator = generator
  166. def __iter__(self) -> Iterator[int]:
  167. for i in torch.randperm(len(self.indices), generator=self.generator).tolist():
  168. yield self.indices[i]
  169. def __len__(self) -> int:
  170. return len(self.indices)
  171. class WeightedRandomSampler(Sampler[int]):
  172. r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).
  173. Args:
  174. weights (sequence) : a sequence of weights, not necessary summing up to one
  175. num_samples (int): number of samples to draw
  176. replacement (bool): if ``True``, samples are drawn with replacement.
  177. If not, they are drawn without replacement, which means that when a
  178. sample index is drawn for a row, it cannot be drawn again for that row.
  179. generator (Generator): Generator used in sampling.
  180. Example:
  181. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  182. >>> list(
  183. ... WeightedRandomSampler(
  184. ... [0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True
  185. ... )
  186. ... )
  187. [4, 4, 1, 4, 5]
  188. >>> list(
  189. ... WeightedRandomSampler(
  190. ... [0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False
  191. ... )
  192. ... )
  193. [0, 1, 4, 3, 2]
  194. """
  195. weights: torch.Tensor
  196. num_samples: int
  197. replacement: bool
  198. def __init__(
  199. self,
  200. weights: Sequence[float],
  201. num_samples: int,
  202. replacement: bool = True,
  203. generator=None,
  204. ) -> None:
  205. if (
  206. not isinstance(num_samples, int)
  207. or isinstance(num_samples, bool)
  208. or num_samples <= 0
  209. ):
  210. raise ValueError(
  211. f"num_samples should be a positive integer value, but got num_samples={num_samples}"
  212. )
  213. if not isinstance(replacement, bool):
  214. raise ValueError(
  215. f"replacement should be a boolean value, but got replacement={replacement}"
  216. )
  217. weights_tensor = torch.as_tensor(weights, dtype=torch.double)
  218. if len(weights_tensor.shape) != 1:
  219. raise ValueError(
  220. "weights should be a 1d sequence but given "
  221. f"weights have shape {tuple(weights_tensor.shape)}"
  222. )
  223. self.weights = weights_tensor
  224. self.num_samples = num_samples
  225. self.replacement = replacement
  226. self.generator = generator
  227. def __iter__(self) -> Iterator[int]:
  228. rand_tensor = torch.multinomial(
  229. self.weights, self.num_samples, self.replacement, generator=self.generator
  230. )
  231. yield from iter(rand_tensor.tolist())
  232. def __len__(self) -> int:
  233. return self.num_samples
  234. class BatchSampler(Sampler[list[int]]):
  235. r"""Wraps another sampler to yield a mini-batch of indices.
  236. Args:
  237. sampler (Sampler or Iterable): Base sampler. Can be any iterable object
  238. batch_size (int): Size of mini-batch.
  239. drop_last (bool): If ``True``, the sampler will drop the last batch if
  240. its size would be less than ``batch_size``
  241. Example:
  242. >>> list(
  243. ... BatchSampler(
  244. ... SequentialSampler(range(10)), batch_size=3, drop_last=False
  245. ... )
  246. ... )
  247. [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
  248. >>> list(
  249. ... BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
  250. ... )
  251. [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
  252. """
  253. def __init__(
  254. self,
  255. sampler: Sampler[int] | Iterable[int],
  256. batch_size: int,
  257. drop_last: bool,
  258. ) -> None:
  259. # Since collections.abc.Iterable does not check for `__getitem__`, which
  260. # is one way for an object to be an iterable, we don't do an `isinstance`
  261. # check here.
  262. if (
  263. not isinstance(batch_size, int)
  264. or isinstance(batch_size, bool)
  265. or batch_size <= 0
  266. ):
  267. raise ValueError(
  268. f"batch_size should be a positive integer value, but got batch_size={batch_size}"
  269. )
  270. if not isinstance(drop_last, bool):
  271. raise ValueError(
  272. f"drop_last should be a boolean value, but got drop_last={drop_last}"
  273. )
  274. self.sampler = sampler
  275. self.batch_size = batch_size
  276. self.drop_last = drop_last
  277. def __iter__(self) -> Iterator[list[int]]:
  278. sampler_iter = iter(self.sampler)
  279. if self.drop_last:
  280. # Create multiple references to the same iterator
  281. args = [sampler_iter] * self.batch_size
  282. for batch_droplast in zip(*args, strict=False):
  283. yield [*batch_droplast]
  284. else:
  285. batch = [*itertools.islice(sampler_iter, self.batch_size)]
  286. while batch:
  287. yield batch
  288. batch = [*itertools.islice(sampler_iter, self.batch_size)]
  289. def __len__(self) -> int:
  290. # Can only be called if self.sampler has __len__ implemented
  291. # We cannot enforce this condition, so we turn off typechecking for the
  292. # implementation below.
  293. # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
  294. if self.drop_last:
  295. return len(self.sampler) // self.batch_size # type: ignore[arg-type]
  296. else:
  297. return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore[arg-type]