# datapipe.py (17 KB)
  1. import functools
  2. import pickle
  3. from collections.abc import Callable, Iterable, Iterator
  4. from typing import TypeVar
  5. from torch.utils._import_utils import import_dill
  6. from torch.utils.data.datapipes._hook_iterator import _SnapshotState
  7. from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
  8. from torch.utils.data.datapipes.utils.common import (
  9. _deprecation_warning,
  10. _iter_deprecated_functional_names,
  11. _map_deprecated_functional_names,
  12. )
  13. from torch.utils.data.dataset import Dataset, IterableDataset
  14. dill = import_dill()
  15. HAS_DILL = dill is not None
  16. __all__ = [
  17. "DataChunk",
  18. "DFIterDataPipe",
  19. "IterDataPipe",
  20. "MapDataPipe",
  21. ]
  22. _T = TypeVar("_T")
  23. _T_co = TypeVar("_T_co", covariant=True)
  24. UNTRACABLE_DATAFRAME_PIPES = [
  25. "batch", # As it returns DataChunks
  26. "groupby", # As it returns DataChunks
  27. "_dataframes_as_tuples", # As it unpacks DF
  28. "trace_as_dataframe", # As it used to mark DF for tracing
  29. ]
  30. class DataChunk(list[_T]):
  31. def __init__(self, items: Iterable[_T]) -> None:
  32. items = list(items)
  33. super().__init__(items)
  34. self.items = items
  35. def as_str(self, indent: str = "") -> str:
  36. return indent + "[" + ", ".join(str(i) for i in iter(self)) + "]"
  37. def __iter__(self) -> Iterator[_T]:
  38. yield from super().__iter__()
  39. def raw_iterator(self) -> Iterator[_T]:
  40. yield from self.items
  41. class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
  42. r"""
  43. Iterable-style DataPipe.
  44. All DataPipes that represent an iterable of data samples should subclass this.
  45. This style of DataPipes is particularly useful when data come from a stream, or
  46. when the number of samples is too large to fit them all in memory. ``IterDataPipe`` is lazily initialized and its
  47. elements are computed only when ``next()`` is called on the iterator of an ``IterDataPipe``.
  48. All subclasses should overwrite :meth:`__iter__`, which would return an
  49. iterator of samples in this DataPipe. Calling ``__iter__`` of an ``IterDataPipe`` automatically invokes its
  50. method ``reset()``, which by default performs no operation. When writing a custom ``IterDataPipe``, users should
  51. override ``reset()`` if necessary. The common usages include resetting buffers, pointers,
  52. and various state variables within the custom ``IterDataPipe``.
  53. Note:
  54. Only `one` iterator can be valid for each ``IterDataPipe`` at a time,
  55. and the creation a second iterator will invalidate the first one. This constraint is necessary because
  56. some ``IterDataPipe`` have internal buffers, whose states can become invalid if there are multiple iterators.
  57. The code example below presents details on how this constraint looks in practice.
  58. If you have any feedback related to this constraint, please see `GitHub IterDataPipe Single Iterator Issue`_.
  59. These DataPipes can be invoked in two ways, using the class constructor or applying their
  60. functional form onto an existing ``IterDataPipe`` (recommended, available to most but not all DataPipes).
  61. You can chain multiple `IterDataPipe` together to form a pipeline that will perform multiple
  62. operations in succession.
  63. .. _GitHub IterDataPipe Single Iterator Issue:
  64. https://github.com/pytorch/data/issues/45
  65. Note:
  66. When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
  67. item in the DataPipe will be yielded from the :class:`~torch.utils.data.DataLoader`
  68. iterator. When :attr:`num_workers > 0`, each worker process will have a
  69. different copy of the DataPipe object, so it is often desired to configure
  70. each copy independently to avoid having duplicate data returned from the
  71. workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
  72. process, returns information about the worker. It can be used in either the
  73. dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
  74. :attr:`worker_init_fn` option to modify each copy's behavior.
  75. Examples:
  76. General Usage:
  77. >>> # xdoctest: +SKIP
  78. >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
  79. >>> dp = IterableWrapper(range(10))
  80. >>> map_dp_1 = Mapper(dp, lambda x: x + 1) # Using class constructor
  81. >>> map_dp_2 = dp.map(
  82. ... lambda x: x + 1
  83. ... ) # Using functional form (recommended)
  84. >>> list(map_dp_1)
  85. [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  86. >>> list(map_dp_2)
  87. [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  88. >>> filter_dp = map_dp_1.filter(lambda x: x % 2 == 0)
  89. >>> list(filter_dp)
  90. [2, 4, 6, 8, 10]
  91. Single Iterator Constraint Example:
  92. >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
  93. >>> source_dp = IterableWrapper(range(10))
  94. >>> it1 = iter(source_dp)
  95. >>> list(it1)
  96. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
  97. >>> it1 = iter(source_dp)
  98. >>> it2 = iter(
  99. ... source_dp
  100. ... ) # The creation of a new iterator invalidates `it1`
  101. >>> next(it2)
  102. 0
  103. >>> next(it1) # Further usage of `it1` will raise a `RunTimeError`
  104. """
  105. functions: dict[str, Callable] = {}
  106. reduce_ex_hook: Callable | None = None
  107. getstate_hook: Callable | None = None
  108. str_hook: Callable | None = None
  109. repr_hook: Callable | None = None
  110. _valid_iterator_id: int | None = None
  111. _number_of_samples_yielded: int = 0
  112. _snapshot_state: _SnapshotState = _SnapshotState.NotStarted
  113. _fast_forward_iterator: Iterator | None = None
  114. def __iter__(self) -> Iterator[_T_co]:
  115. # pyrefly: ignore [bad-return]
  116. return self
  117. def __getattr__(self, attribute_name):
  118. if attribute_name in IterDataPipe.functions:
  119. if attribute_name in _iter_deprecated_functional_names:
  120. kwargs = _iter_deprecated_functional_names[attribute_name]
  121. _deprecation_warning(**kwargs)
  122. f = IterDataPipe.functions[attribute_name]
  123. function = functools.partial(f, self)
  124. functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
  125. return function
  126. else:
  127. raise AttributeError(
  128. f"'{self.__class__.__name__}' object has no attribute '{attribute_name}"
  129. )
  130. @classmethod
  131. def register_function(cls, function_name, function) -> None:
  132. cls.functions[function_name] = function
  133. @classmethod
  134. def register_datapipe_as_function(
  135. cls, function_name, cls_to_register, enable_df_api_tracing=False
  136. ) -> None:
  137. if function_name in cls.functions:
  138. raise Exception( # noqa: TRY002
  139. f"Unable to add DataPipe function name {function_name} as it is already taken"
  140. )
  141. def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
  142. result_pipe = cls(source_dp, *args, **kwargs)
  143. if isinstance(result_pipe, IterDataPipe):
  144. if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
  145. if function_name not in UNTRACABLE_DATAFRAME_PIPES:
  146. result_pipe = result_pipe.trace_as_dataframe()
  147. return result_pipe
  148. function = functools.partial(
  149. class_function, cls_to_register, enable_df_api_tracing
  150. )
  151. functools.update_wrapper(
  152. wrapper=function, wrapped=cls_to_register, assigned=("__doc__",)
  153. )
  154. cls.functions[function_name] = function
  155. def __getstate__(self):
  156. """
  157. Serialize `lambda` functions when `dill` is available.
  158. If this doesn't cover your custom DataPipe's use case, consider writing custom methods for
  159. `__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization.
  160. """
  161. state = self.__dict__
  162. if IterDataPipe.getstate_hook is not None:
  163. return IterDataPipe.getstate_hook(state)
  164. return state
  165. def __reduce_ex__(self, *args, **kwargs):
  166. if IterDataPipe.reduce_ex_hook is not None:
  167. try:
  168. return IterDataPipe.reduce_ex_hook(self)
  169. except NotImplementedError:
  170. pass
  171. return super().__reduce_ex__(*args, **kwargs)
  172. @classmethod
  173. def set_getstate_hook(cls, hook_fn) -> None:
  174. if IterDataPipe.getstate_hook is not None and hook_fn is not None:
  175. raise RuntimeError("Attempt to override existing getstate_hook")
  176. IterDataPipe.getstate_hook = hook_fn
  177. @classmethod
  178. def set_reduce_ex_hook(cls, hook_fn) -> None:
  179. if IterDataPipe.reduce_ex_hook is not None and hook_fn is not None:
  180. raise RuntimeError("Attempt to override existing reduce_ex_hook")
  181. IterDataPipe.reduce_ex_hook = hook_fn
  182. def __repr__(self) -> str:
  183. if self.repr_hook is not None:
  184. return self.repr_hook(self)
  185. # Instead of showing <torch. ... .MapperIterDataPipe object at 0x.....>, return the class name
  186. return str(self.__class__.__qualname__)
  187. def __str__(self) -> str:
  188. if self.str_hook is not None:
  189. return self.str_hook(self)
  190. # Instead of showing <torch. ... .MapperIterDataPipe object at 0x.....>, return the class name
  191. return str(self.__class__.__qualname__)
  192. def __dir__(self):
  193. # for auto-completion in a REPL (e.g. Jupyter notebook)
  194. return list(super().__dir__()) + list(self.functions.keys())
  195. def reset(self) -> None:
  196. r"""
  197. Reset the `IterDataPipe` to the initial state.
  198. By default, no-op. For subclasses of `IterDataPipe`, depending on their functionalities,
  199. they may want to override this method with implementations that
  200. may clear the buffers and reset pointers of the DataPipe.
  201. The `reset` method is always called when `__iter__` is called as part of `hook_iterator`.
  202. """
  203. class DFIterDataPipe(IterDataPipe):
  204. def _is_dfpipe(self) -> bool:
  205. return True
  206. class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta):
  207. r"""
  208. Map-style DataPipe.
  209. All datasets that represent a map from keys to data samples should subclass this.
  210. Subclasses should overwrite :meth:`__getitem__`, supporting fetching a
  211. data sample for a given, unique key. Subclasses can also optionally overwrite
  212. :meth:`__len__`, which is expected to return the size of the dataset by many
  213. :class:`~torch.utils.data.Sampler` implementations and the default options
  214. of :class:`~torch.utils.data.DataLoader`.
  215. These DataPipes can be invoked in two ways, using the class constructor or applying their
  216. functional form onto an existing `MapDataPipe` (recommend, available to most but not all DataPipes).
  217. Note:
  218. :class:`~torch.utils.data.DataLoader` by default constructs an index
  219. sampler that yields integral indices. To make it work with a map-style
  220. DataPipe with non-integral indices/keys, a custom sampler must be provided.
  221. Example:
  222. >>> # xdoctest: +SKIP
  223. >>> from torchdata.datapipes.map import SequenceWrapper, Mapper
  224. >>> dp = SequenceWrapper(range(10))
  225. >>> map_dp_1 = dp.map(lambda x: x + 1) # Using functional form (recommended)
  226. >>> list(map_dp_1)
  227. [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  228. >>> map_dp_2 = Mapper(dp, lambda x: x + 1) # Using class constructor
  229. >>> list(map_dp_2)
  230. [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  231. >>> batch_dp = map_dp_1.batch(batch_size=2)
  232. >>> list(batch_dp)
  233. [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
  234. """
  235. functions: dict[str, Callable] = {}
  236. reduce_ex_hook: Callable | None = None
  237. getstate_hook: Callable | None = None
  238. str_hook: Callable | None = None
  239. repr_hook: Callable | None = None
  240. def __getattr__(self, attribute_name):
  241. if attribute_name in MapDataPipe.functions:
  242. if attribute_name in _map_deprecated_functional_names:
  243. kwargs = _map_deprecated_functional_names[attribute_name]
  244. _deprecation_warning(**kwargs)
  245. f = MapDataPipe.functions[attribute_name]
  246. function = functools.partial(f, self)
  247. functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
  248. return function
  249. else:
  250. raise AttributeError(
  251. f"'{self.__class__.__name__}' object has no attribute '{attribute_name}"
  252. )
  253. @classmethod
  254. def register_function(cls, function_name, function) -> None:
  255. cls.functions[function_name] = function
  256. @classmethod
  257. def register_datapipe_as_function(cls, function_name, cls_to_register) -> None:
  258. if function_name in cls.functions:
  259. raise Exception( # noqa: TRY002
  260. f"Unable to add DataPipe function name {function_name} as it is already taken"
  261. )
  262. def class_function(cls, source_dp, *args, **kwargs):
  263. result_pipe = cls(source_dp, *args, **kwargs)
  264. return result_pipe
  265. function = functools.partial(class_function, cls_to_register)
  266. functools.update_wrapper(
  267. wrapper=function, wrapped=cls_to_register, assigned=("__doc__",)
  268. )
  269. cls.functions[function_name] = function
  270. def __getstate__(self):
  271. """
  272. Serialize `lambda` functions when `dill` is available.
  273. If this doesn't cover your custom DataPipe's use case, consider writing custom methods for
  274. `__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization.
  275. """
  276. state = self.__dict__
  277. if MapDataPipe.getstate_hook is not None:
  278. return MapDataPipe.getstate_hook(state)
  279. return state
  280. def __reduce_ex__(self, *args, **kwargs):
  281. if MapDataPipe.reduce_ex_hook is not None:
  282. try:
  283. return MapDataPipe.reduce_ex_hook(self)
  284. except NotImplementedError:
  285. pass
  286. return super().__reduce_ex__(*args, **kwargs)
  287. @classmethod
  288. def set_getstate_hook(cls, hook_fn) -> None:
  289. if MapDataPipe.getstate_hook is not None and hook_fn is not None:
  290. raise RuntimeError("Attempt to override existing getstate_hook")
  291. MapDataPipe.getstate_hook = hook_fn
  292. @classmethod
  293. def set_reduce_ex_hook(cls, hook_fn) -> None:
  294. if MapDataPipe.reduce_ex_hook is not None and hook_fn is not None:
  295. raise RuntimeError("Attempt to override existing reduce_ex_hook")
  296. MapDataPipe.reduce_ex_hook = hook_fn
  297. def __repr__(self) -> str:
  298. if self.repr_hook is not None:
  299. return self.repr_hook(self)
  300. # Instead of showing <torch. ... .MapperMapDataPipe object at 0x.....>, return the class name
  301. return str(self.__class__.__qualname__)
  302. def __str__(self) -> str:
  303. if self.str_hook is not None:
  304. return self.str_hook(self)
  305. # Instead of showing <torch. ... .MapperMapDataPipe object at 0x.....>, return the class name
  306. return str(self.__class__.__qualname__)
  307. def __dir__(self):
  308. # for auto-completion in a REPL (e.g. Jupyter notebook)
  309. return list(super().__dir__()) + list(self.functions.keys())
  310. class _DataPipeSerializationWrapper:
  311. def __init__(self, datapipe) -> None:
  312. self._datapipe = datapipe
  313. def __getstate__(self):
  314. use_dill = False
  315. try:
  316. value = pickle.dumps(self._datapipe)
  317. except Exception:
  318. if HAS_DILL:
  319. # pyrefly: ignore [missing-attribute]
  320. value = dill.dumps(self._datapipe)
  321. use_dill = True
  322. else:
  323. raise
  324. return (value, use_dill)
  325. def __setstate__(self, state):
  326. value, use_dill = state
  327. if use_dill:
  328. # pyrefly: ignore [missing-attribute]
  329. self._datapipe = dill.loads(value)
  330. else:
  331. self._datapipe = pickle.loads(value)
  332. def __len__(self) -> int:
  333. try:
  334. return len(self._datapipe)
  335. except Exception as e:
  336. raise TypeError(
  337. f"{type(self).__name__} instance doesn't have valid length"
  338. ) from e
  339. class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
  340. def __init__(self, datapipe: IterDataPipe[_T_co]) -> None:
  341. super().__init__(datapipe)
  342. # pyrefly: ignore [invalid-type-var]
  343. self._datapipe_iter: Iterator[_T_co] | None = None
  344. def __iter__(self) -> "_IterDataPipeSerializationWrapper":
  345. self._datapipe_iter = iter(self._datapipe)
  346. return self
  347. def __next__(self) -> _T_co: # type: ignore[type-var]
  348. if self._datapipe_iter is None:
  349. raise AssertionError(
  350. "Iterator has not been initialized; call __iter__() before __next__()"
  351. )
  352. return next(self._datapipe_iter)
  353. class _MapDataPipeSerializationWrapper(_DataPipeSerializationWrapper, MapDataPipe):
  354. def __getitem__(self, idx):
  355. return self._datapipe[idx]