dataset.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. # mypy: allow-untyped-defs
  2. import bisect
  3. import itertools
  4. import math
  5. import warnings
  6. from collections.abc import Sequence
  7. # UP006 wants 'Iterable' to be imported from collections.abc but it needs to
  8. # stay from typing for now due to BC concerns. In particular several internal
  9. # targets fail to typecheck with:
  10. # TypeError: Cannot create a consistent method resolution order (MRO) for
  11. # bases Iterable, Generic
  12. from typing import cast, Generic, Iterable, TypeVar # noqa: UP035
  13. from typing_extensions import deprecated
  14. # No 'default_generator' in torch/__init__.pyi
  15. from torch import default_generator, Generator, randperm, Tensor
# Public API of this module (controls what ``from <module> import *`` exports).
__all__ = [
    "Dataset",
    "IterableDataset",
    "TensorDataset",
    "StackDataset",
    "ConcatDataset",
    "ChainDataset",
    "Subset",
    "random_split",
]

# Unconstrained type variable (used by ``random_split``'s signature).
_T = TypeVar("_T")
# Covariant sample type, so ``Dataset[Derived]`` is accepted where
# ``Dataset[Base]`` is expected.
_T_co = TypeVar("_T_co", covariant=True)
# The two sample shapes ``StackDataset`` can produce: a dict keyed by the
# keyword names it was given, or a tuple in positional order.
_T_dict = dict[str, _T_co]
_T_tuple = tuple[_T_co, ...]
# Constrained to exactly one of the two stacked-sample shapes above.
_T_stack = TypeVar("_T_stack", _T_tuple, _T_dict)
  31. class Dataset(Generic[_T_co]):
  32. r"""An abstract class representing a :class:`Dataset`.
  33. All datasets that represent a map from keys to data samples should subclass
  34. it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
  35. data sample for a given key. Subclasses could also optionally overwrite
  36. :meth:`__len__`, which is expected to return the size of the dataset by many
  37. :class:`~torch.utils.data.Sampler` implementations and the default options
  38. of :class:`~torch.utils.data.DataLoader`. Subclasses could also
  39. optionally implement :meth:`__getitems__`, for speedup batched samples
  40. loading. This method accepts list of indices of samples of batch and returns
  41. list of samples.
  42. .. note::
  43. :class:`~torch.utils.data.DataLoader` by default constructs an index
  44. sampler that yields integral indices. To make it work with a map-style
  45. dataset with non-integral indices/keys, a custom sampler must be provided.
  46. """
  47. def __getitem__(self, index) -> _T_co:
  48. raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")
  49. # def __getitems__(self, indices: List) -> List[_T_co]:
  50. # Not implemented to prevent false-positives in fetcher check in
  51. # torch.utils.data._utils.fetch._MapDatasetFetcher
  52. def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]":
  53. return ConcatDataset([self, other])
  54. # No `def __len__(self)` default?
  55. # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
  56. # in pytorch/torch/utils/data/sampler.py
  57. class IterableDataset(Dataset[_T_co], Iterable[_T_co]):
  58. r"""An iterable Dataset.
  59. All datasets that represent an iterable of data samples should subclass it.
  60. Such form of datasets is particularly useful when data come from a stream.
  61. All subclasses should overwrite :meth:`__iter__`, which would return an
  62. iterator of samples in this dataset.
  63. When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
  64. item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
  65. iterator. When :attr:`num_workers > 0`, each worker process will have a
  66. different copy of the dataset object, so it is often desired to configure
  67. each copy independently to avoid having duplicate data returned from the
  68. workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
  69. process, returns information about the worker. It can be used in either the
  70. dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
  71. :attr:`worker_init_fn` option to modify each copy's behavior.
  72. Example 1: splitting workload across all workers in :meth:`__iter__`::
  73. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
  74. >>> # xdoctest: +SKIP("Fails on MacOS12")
  75. >>> class MyIterableDataset(torch.utils.data.IterableDataset):
  76. ... def __init__(self, start, end):
  77. ... super(MyIterableDataset).__init__()
  78. ... assert end > start, "this example only works with end >= start"
  79. ... self.start = start
  80. ... self.end = end
  81. ...
  82. ... def __iter__(self):
  83. ... worker_info = torch.utils.data.get_worker_info()
  84. ... if worker_info is None: # single-process data loading, return the full iterator
  85. ... iter_start = self.start
  86. ... iter_end = self.end
  87. ... else: # in a worker process
  88. ... # split workload
  89. ... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
  90. ... worker_id = worker_info.id
  91. ... iter_start = self.start + worker_id * per_worker
  92. ... iter_end = min(iter_start + per_worker, self.end)
  93. ... return iter(range(iter_start, iter_end))
  94. ...
  95. >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
  96. >>> ds = MyIterableDataset(start=3, end=7)
  97. >>> # Single-process loading
  98. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
  99. [tensor([3]), tensor([4]), tensor([5]), tensor([6])]
  100. >>> # xdoctest: +REQUIRES(POSIX)
  101. >>> # Multi-process loading with two worker processes
  102. >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
  103. >>> # xdoctest: +IGNORE_WANT("non deterministic")
  104. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
  105. [tensor([3]), tensor([5]), tensor([4]), tensor([6])]
  106. >>> # With even more workers
  107. >>> # xdoctest: +IGNORE_WANT("non deterministic")
  108. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
  109. [tensor([3]), tensor([5]), tensor([4]), tensor([6])]
  110. Example 2: splitting workload across all workers using :attr:`worker_init_fn`::
  111. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
  112. >>> class MyIterableDataset(torch.utils.data.IterableDataset):
  113. ... def __init__(self, start, end):
  114. ... super(MyIterableDataset).__init__()
  115. ... assert end > start, "this example only works with end >= start"
  116. ... self.start = start
  117. ... self.end = end
  118. ...
  119. ... def __iter__(self):
  120. ... return iter(range(self.start, self.end))
  121. ...
  122. >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
  123. >>> ds = MyIterableDataset(start=3, end=7)
  124. >>> # Single-process loading
  125. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
  126. [3, 4, 5, 6]
  127. >>>
  128. >>> # Directly doing multi-process loading yields duplicate data
  129. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
  130. [3, 3, 4, 4, 5, 5, 6, 6]
  131. >>> # Define a `worker_init_fn` that configures each dataset copy differently
  132. >>> def worker_init_fn(worker_id):
  133. ... worker_info = torch.utils.data.get_worker_info()
  134. ... dataset = worker_info.dataset # the dataset copy in this worker process
  135. ... overall_start = dataset.start
  136. ... overall_end = dataset.end
  137. ... # configure the dataset to only process the split workload
  138. ... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
  139. ... worker_id = worker_info.id
  140. ... dataset.start = overall_start + worker_id * per_worker
  141. ... dataset.end = min(dataset.start + per_worker, overall_end)
  142. ...
  143. >>> # Mult-process loading with the custom `worker_init_fn`
  144. >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
  145. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
  146. [3, 5, 4, 6]
  147. >>> # With even more workers
  148. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
  149. [3, 4, 5, 6]
  150. """
  151. def __add__(self, other: Dataset[_T_co]):
  152. return ChainDataset([self, other])
  153. # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
  154. # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
  155. class TensorDataset(Dataset[tuple[Tensor, ...]]):
  156. r"""Dataset wrapping tensors.
  157. Each sample will be retrieved by indexing tensors along the first dimension.
  158. Args:
  159. *tensors (Tensor): tensors that have the same size of the first dimension.
  160. """
  161. tensors: tuple[Tensor, ...]
  162. def __init__(self, *tensors: Tensor) -> None:
  163. if any(tensors[0].size(0) != tensor.size(0) for tensor in tensors):
  164. raise AssertionError("Size mismatch between tensors")
  165. self.tensors = tensors
  166. def __getitem__(self, index):
  167. return tuple(tensor[index] for tensor in self.tensors)
  168. def __len__(self) -> int:
  169. return self.tensors[0].size(0)
  170. class StackDataset(Dataset[_T_stack]):
  171. r"""Dataset as a stacking of multiple datasets.
  172. This class is useful to assemble different parts of complex input data, given as datasets.
  173. Example:
  174. >>> # xdoctest: +SKIP
  175. >>> images = ImageDataset()
  176. >>> texts = TextDataset()
  177. >>> tuple_stack = StackDataset(images, texts)
  178. >>> tuple_stack[0] == (images[0], texts[0])
  179. >>> dict_stack = StackDataset(image=images, text=texts)
  180. >>> dict_stack[0] == {"image": images[0], "text": texts[0]}
  181. Args:
  182. *args (Dataset): Datasets for stacking returned as tuple.
  183. **kwargs (Dataset): Datasets for stacking returned as dict.
  184. """
  185. datasets: tuple | dict
  186. def __init__(self, *args: Dataset[_T_co], **kwargs: Dataset[_T_co]) -> None:
  187. if args:
  188. if kwargs:
  189. raise ValueError(
  190. "Supported either ``tuple``- (via ``args``) or"
  191. "``dict``- (via ``kwargs``) like input/output, but both types are given."
  192. )
  193. self._length = len(args[0]) # type: ignore[arg-type]
  194. if any(self._length != len(dataset) for dataset in args): # type: ignore[arg-type]
  195. raise ValueError("Size mismatch between datasets")
  196. self.datasets = args
  197. elif kwargs:
  198. tmp = list(kwargs.values())
  199. self._length = len(tmp[0]) # type: ignore[arg-type]
  200. if any(self._length != len(dataset) for dataset in tmp): # type: ignore[arg-type]
  201. raise ValueError("Size mismatch between datasets")
  202. self.datasets = kwargs
  203. else:
  204. raise ValueError("At least one dataset should be passed")
  205. def __getitem__(self, index):
  206. if isinstance(self.datasets, dict):
  207. return {k: dataset[index] for k, dataset in self.datasets.items()}
  208. return tuple(dataset[index] for dataset in self.datasets)
  209. def __getitems__(self, indices: list):
  210. # add batched sampling support when parent datasets supports it.
  211. if isinstance(self.datasets, dict):
  212. dict_batch: list[_T_dict] = [{} for _ in indices]
  213. for k, dataset in self.datasets.items():
  214. if callable(getattr(dataset, "__getitems__", None)):
  215. items = dataset.__getitems__(indices) # type: ignore[attr-defined]
  216. if len(items) != len(indices):
  217. raise ValueError(
  218. "Nested dataset's output size mismatch."
  219. f" Expected {len(indices)}, got {len(items)}"
  220. )
  221. for data, d_sample in zip(items, dict_batch, strict=True):
  222. d_sample[k] = data
  223. else:
  224. for idx, d_sample in zip(indices, dict_batch, strict=True):
  225. d_sample[k] = dataset[idx]
  226. return dict_batch
  227. # tuple data
  228. list_batch: list[list] = [[] for _ in indices]
  229. for dataset in self.datasets:
  230. if callable(getattr(dataset, "__getitems__", None)):
  231. items = dataset.__getitems__(indices) # type: ignore[attr-defined]
  232. if len(items) != len(indices):
  233. raise ValueError(
  234. "Nested dataset's output size mismatch."
  235. f" Expected {len(indices)}, got {len(items)}"
  236. )
  237. for data, t_sample in zip(items, list_batch, strict=True):
  238. t_sample.append(data)
  239. else:
  240. for idx, t_sample in zip(indices, list_batch, strict=True):
  241. t_sample.append(dataset[idx])
  242. tuple_batch: list[_T_tuple] = [tuple(sample) for sample in list_batch]
  243. return tuple_batch
  244. def __len__(self) -> int:
  245. return self._length
  246. class ConcatDataset(Dataset[_T_co]):
  247. r"""Dataset as a concatenation of multiple datasets.
  248. This class is useful to assemble different existing datasets.
  249. Args:
  250. datasets (sequence): List of datasets to be concatenated
  251. """
  252. datasets: list[Dataset[_T_co]]
  253. cumulative_sizes: list[int]
  254. @staticmethod
  255. def cumsum(sequence):
  256. r, s = [], 0
  257. for e in sequence:
  258. l = len(e)
  259. r.append(l + s)
  260. s += l
  261. return r
  262. def __init__(self, datasets: Iterable[Dataset]) -> None:
  263. super().__init__()
  264. self.datasets = list(datasets)
  265. if len(self.datasets) == 0:
  266. raise AssertionError("datasets should not be an empty iterable")
  267. for d in self.datasets:
  268. if isinstance(d, IterableDataset):
  269. raise AssertionError("ConcatDataset does not support IterableDataset")
  270. self.cumulative_sizes = self.cumsum(self.datasets)
  271. def __len__(self) -> int:
  272. return self.cumulative_sizes[-1]
  273. def __getitem__(self, idx):
  274. if idx < 0:
  275. if -idx > len(self):
  276. raise ValueError(
  277. "absolute value of index should not exceed dataset length"
  278. )
  279. idx = len(self) + idx
  280. dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
  281. if dataset_idx == 0:
  282. sample_idx = idx
  283. else:
  284. sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
  285. return self.datasets[dataset_idx][sample_idx]
  286. @property
  287. @deprecated(
  288. "`cummulative_sizes` attribute is renamed to `cumulative_sizes`",
  289. category=FutureWarning,
  290. )
  291. def cummulative_sizes(self):
  292. return self.cumulative_sizes
  293. class ChainDataset(IterableDataset):
  294. r"""Dataset for chaining multiple :class:`IterableDataset` s.
  295. This class is useful to assemble different existing dataset streams. The
  296. chaining operation is done on-the-fly, so concatenating large-scale
  297. datasets with this class will be efficient.
  298. Args:
  299. datasets (iterable of IterableDataset): datasets to be chained together
  300. """
  301. def __init__(self, datasets: Iterable[Dataset]) -> None:
  302. super().__init__()
  303. self.datasets = datasets
  304. def __iter__(self):
  305. for d in self.datasets:
  306. if not isinstance(d, IterableDataset):
  307. raise AssertionError("ChainDataset only supports IterableDataset")
  308. yield from d
  309. def __len__(self) -> int:
  310. total = 0
  311. for d in self.datasets:
  312. if not isinstance(d, IterableDataset):
  313. raise AssertionError("ChainDataset only supports IterableDataset")
  314. total += len(d) # type: ignore[arg-type]
  315. return total
  316. class Subset(Dataset[_T_co]):
  317. r"""
  318. Subset of a dataset at specified indices.
  319. .. note::
  320. When subclassing `Subset` and overriding `__getitem__`, you **must** also
  321. override `__getitems__` to ensure `DataLoader` works correctly with your
  322. custom logic. If you override only `__getitem__`, a `NotImplementedError`
  323. will be raised when using `DataLoader`.
  324. A simple implementation of `__getitems__` can delegate to `__getitem__`:
  325. .. code-block:: python
  326. def __getitems__(self, indices):
  327. return [self.__getitem__(idx) for idx in indices]
  328. For better performance, consider implementing batch-aware logic in
  329. `__getitems__` instead of calling `__getitem__` multiple times.
  330. Args:
  331. dataset (Dataset): The whole Dataset
  332. indices (sequence): Indices in the whole set selected for subset
  333. """
  334. dataset: Dataset[_T_co]
  335. indices: Sequence[int]
  336. def __init__(self, dataset: Dataset[_T_co], indices: Sequence[int]) -> None:
  337. self.dataset = dataset
  338. self.indices = indices
  339. # Check if __getitem__ is overridden but __getitems__ is not
  340. if (
  341. type(self).__getitem__ is not Subset.__getitem__
  342. and type(self).__getitems__ is Subset.__getitems__
  343. ):
  344. raise NotImplementedError(
  345. f"{type(self).__name__} overrides __getitem__ but not __getitems__. "
  346. "When subclassing Subset and overriding __getitem__, you must also override "
  347. "__getitems__ to ensure DataLoader works correctly with your custom logic. "
  348. "A simple implementation:\n\n"
  349. "def __getitems__(self, indices):\n"
  350. " return [self.__getitem__(idx) for idx in indices]"
  351. )
  352. def __getitem__(self, idx):
  353. if isinstance(idx, list):
  354. return self.dataset[[self.indices[i] for i in idx]]
  355. return self.dataset[self.indices[idx]]
  356. def __getitems__(self, indices: list[int]) -> list[_T_co]:
  357. # add batched sampling support when parent dataset supports it.
  358. # see torch.utils.data._utils.fetch._MapDatasetFetcher
  359. if callable(getattr(self.dataset, "__getitems__", None)):
  360. return self.dataset.__getitems__([self.indices[idx] for idx in indices]) # type: ignore[attr-defined]
  361. else:
  362. return [self.dataset[self.indices[idx]] for idx in indices]
  363. def __len__(self) -> int:
  364. return len(self.indices)
  365. def random_split(
  366. dataset: Dataset[_T],
  367. lengths: Sequence[int | float],
  368. generator: Generator | None = default_generator,
  369. ) -> list[Subset[_T]]:
  370. r"""
  371. Randomly split a dataset into non-overlapping new datasets of given lengths.
  372. If a list of fractions that sum up to 1 is given,
  373. the lengths will be computed automatically as
  374. floor(frac * len(dataset)) for each fraction provided.
  375. After computing the lengths, if there are any remainders, 1 count will be
  376. distributed in round-robin fashion to the lengths
  377. until there are no remainders left.
  378. Optionally fix the generator for reproducible results, e.g.:
  379. Example:
  380. >>> # xdoctest: +SKIP
  381. >>> generator1 = torch.Generator().manual_seed(42)
  382. >>> generator2 = torch.Generator().manual_seed(42)
  383. >>> random_split(range(10), [3, 7], generator=generator1)
  384. >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
  385. Args:
  386. dataset (Dataset): Dataset to be split
  387. lengths (sequence): lengths or fractions of splits to be produced
  388. generator (Generator): Generator used for the random permutation.
  389. """
  390. if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
  391. subset_lengths: list[int] = []
  392. for i, frac in enumerate(lengths):
  393. if frac < 0 or frac > 1:
  394. raise ValueError(f"Fraction at index {i} is not between 0 and 1")
  395. n_items_in_split = math.floor(len(dataset) * frac) # type: ignore[arg-type]
  396. subset_lengths.append(n_items_in_split)
  397. remainder = len(dataset) - sum(subset_lengths) # type: ignore[arg-type]
  398. # add 1 to all the lengths in round-robin fashion until the remainder is 0
  399. for i in range(remainder):
  400. idx_to_add_at = i % len(subset_lengths)
  401. subset_lengths[idx_to_add_at] += 1
  402. lengths = subset_lengths
  403. for i, length in enumerate(lengths):
  404. if length == 0:
  405. warnings.warn(
  406. f"Length of split at index {i} is 0. "
  407. f"This might result in an empty dataset.",
  408. stacklevel=2,
  409. )
  410. # Cannot verify that dataset is Sized
  411. if sum(lengths) != len(dataset): # type: ignore[arg-type]
  412. raise ValueError(
  413. "Sum of input lengths does not equal the length of the input dataset!"
  414. )
  415. indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[arg-type, call-overload]
  416. lengths = cast(Sequence[int], lengths)
  417. return [
  418. Subset(dataset, indices[offset - length : offset])
  419. for offset, length in zip(itertools.accumulate(lengths), lengths, strict=True)
  420. ]