dataloader.py 79 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709
  1. # mypy: allow-untyped-defs
  2. r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter.
  3. To support these two classes, in `./_utils` we define many utility methods and
  4. functions to be run in multiprocessing. E.g., the data loading worker loop is
  5. in `./_utils/worker.py`.
  6. """
  7. from __future__ import annotations
  8. import contextlib
  9. import functools
  10. import itertools
  11. import logging
  12. import multiprocessing as python_multiprocessing
  13. import os
  14. import queue
  15. import threading
  16. import warnings
  17. from collections.abc import Callable
  18. from typing import Any, Generic, NoReturn, TYPE_CHECKING, TypeVar
  19. from typing_extensions import Self
  20. import torch
  21. import torch.distributed as dist
  22. import torch.utils.data.graph_settings
  23. from torch._utils import ExceptionWrapper
  24. from torch.utils.data import _utils
  25. from torch.utils.data.datapipes.datapipe import (
  26. _IterDataPipeSerializationWrapper,
  27. _MapDataPipeSerializationWrapper,
  28. IterDataPipe,
  29. MapDataPipe,
  30. )
  31. from torch.utils.data.dataset import Dataset, IterableDataset
  32. from torch.utils.data.sampler import (
  33. BatchSampler,
  34. RandomSampler,
  35. Sampler,
  36. SequentialSampler,
  37. )
  38. if TYPE_CHECKING:
  39. from collections.abc import Iterable
# Public names exported by ``from torch.utils.data.dataloader import *``.
__all__ = [
    "DataLoader",
    "get_worker_info",
    "default_collate",
    "default_convert",
]

# Generic element type, used only by the ``_collate_fn_t`` alias below.
_T = TypeVar("_T")
# Covariant sample type that parameterizes ``DataLoader`` and its ``Dataset``.
_T_co = TypeVar("_T_co", covariant=True)
# Signature of a user ``worker_init_fn``: receives the integer worker id.
_worker_init_fn_t = Callable[[int], None]
# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
# See https://github.com/python/mypy/issues/3737.
_collate_fn_t = Callable[[list[_T]], Any]

# These functions used to be defined in this file. However, they were moved to
# _utils/collate.py (and _utils/worker.py). Although it is rather hard to access
# them from user land (one has to explicitly `import torch.utils.data.dataloader`),
# there probably is user code out there using them. This aliasing maintains
# backward compatibility in this aspect.
default_collate: _collate_fn_t = _utils.collate.default_collate
default_convert = _utils.collate.default_convert
get_worker_info = _utils.worker.get_worker_info

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
  62. class _DatasetKind:
  63. Map = 0
  64. Iterable = 1
  65. @staticmethod
  66. def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
  67. if kind == _DatasetKind.Map:
  68. return _utils.fetch._MapDatasetFetcher(
  69. dataset, auto_collation, collate_fn, drop_last
  70. )
  71. else:
  72. return _utils.fetch._IterableDatasetFetcher(
  73. dataset, auto_collation, collate_fn, drop_last
  74. )
  75. class _InfiniteConstantSampler(Sampler):
  76. r"""Analogous to ``itertools.repeat(None, None)``.
  77. Used as sampler for :class:`~torch.utils.data.IterableDataset`.
  78. """
  79. def __iter__(self):
  80. while True:
  81. yield None
  82. def _get_distributed_settings():
  83. if dist.is_available() and dist.is_initialized():
  84. return dist.get_world_size(), dist.get_rank()
  85. else:
  86. return 1, 0
  87. def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id) -> None:
  88. global_worker_id = worker_id
  89. info = torch.utils.data.get_worker_info()
  90. if info is None:
  91. raise AssertionError("Worker info is None in sharding worker init function")
  92. total_workers = info.num_workers
  93. datapipe = info.dataset
  94. if not isinstance(datapipe, (IterDataPipe, MapDataPipe)):
  95. raise AssertionError(
  96. "datapipe must be an instance of IterDataPipe or MapDataPipe"
  97. )
  98. # To distribute elements across distributed process evenly, we should shard data on distributed
  99. # processes first then shard on worker processes
  100. total_workers *= world_size
  101. global_worker_id = global_worker_id * world_size + rank_id
  102. # For BC, use default SHARDING_PRIORITIES
  103. torch.utils.data.graph_settings.apply_sharding(
  104. datapipe, total_workers, global_worker_id
  105. )
  106. if worker_init_fn is not None:
  107. worker_init_fn(worker_id)
  108. def _share_dist_seed(generator, pg):
  109. _shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator)
  110. if isinstance(pg, dist.ProcessGroup):
  111. dist.broadcast(_shared_seed, src=0, group=pg)
  112. return _shared_seed.item()
  113. class DataLoader(Generic[_T_co]):
  114. r"""
  115. Data loader combines a dataset and a sampler, and provides an iterable over the given dataset.
  116. The :class:`~torch.utils.data.DataLoader` supports both map-style and
  117. iterable-style datasets with single- or multi-process loading, customizing
  118. loading order and optional automatic batching (collation) and memory pinning.
  119. See :py:mod:`torch.utils.data` documentation page for more details.
  120. Args:
  121. dataset (Dataset): dataset from which to load the data.
  122. batch_size (int, optional): how many samples per batch to load
  123. (default: ``1``).
  124. shuffle (bool, optional): set to ``True`` to have the data reshuffled
  125. at every epoch (default: ``False``).
  126. sampler (Sampler or Iterable, optional): defines the strategy to draw
  127. samples from the dataset. Can be any ``Iterable`` with ``__len__``
  128. implemented. If specified, :attr:`shuffle` must not be specified.
  129. batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
  130. returns a batch of indices at a time. Mutually exclusive with
  131. :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
  132. and :attr:`drop_last`.
  133. num_workers (int, optional): how many subprocesses to use for data
  134. loading. ``0`` means that the data will be loaded in the main process.
  135. (default: ``0``)
  136. collate_fn (Callable, optional): merges a list of samples to form a
  137. mini-batch of Tensor(s). Used when using batched loading from a
  138. map-style dataset.
  139. pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
  140. into device/CUDA pinned memory before returning them. If your data elements
  141. are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
  142. see the example below.
  143. drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
  144. if the dataset size is not divisible by the batch size. If ``False`` and
  145. the size of dataset is not divisible by the batch size, then the last batch
  146. will be smaller. (default: ``False``)
  147. timeout (numeric, optional): if positive, the timeout value for collecting a batch
  148. from workers. Should always be non-negative. (default: ``0``)
  149. worker_init_fn (Callable, optional): If not ``None``, this will be called on each
  150. worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
  151. input, after seeding and before data loading. (default: ``None``)
  152. multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If
  153. ``None``, the default
  154. `multiprocessing context <https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods>`_ # noqa: D401
  155. of your operating system will
  156. be used. (default: ``None``)
  157. generator (torch.Generator, optional): If not ``None``, this RNG will be used
  158. by RandomSampler to generate random indexes and multiprocessing to generate
  159. ``base_seed`` for workers. (default: ``None``)
  160. prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
  161. in advance by each worker. ``2`` means there will be a total of
  162. 2 * num_workers batches prefetched across all workers. (default value depends
  163. on the set value for num_workers. If value of num_workers=0 default is ``None``.
  164. Otherwise, if value of ``num_workers > 0`` default is ``2``).
  165. persistent_workers (bool, optional): If ``True``, the data loader will not shut down
  166. the worker processes after a dataset has been consumed once. This allows to
  167. maintain the workers `Dataset` instances alive. (default: ``False``)
  168. pin_memory_device (str, optional): Deprecated, the current :ref:`accelerator<accelerators>`
  169. will be used as the device if ``pin_memory=True``.
  170. in_order (bool, optional): If ``False``, the data loader will not enforce that batches
  171. are returned in a first-in, first-out order. Only applies when ``num_workers > 0``. (default: ``True``)
  172. .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
  173. cannot be an unpicklable object, e.g., a lambda function. See
  174. :ref:`multiprocessing-best-practices` on more details related
  175. to multiprocessing in PyTorch.
  176. .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
  177. When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
  178. it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
  179. rounding depending on :attr:`drop_last`, regardless of multi-process loading
  180. configurations. This represents the best guess PyTorch can make because PyTorch
  181. trusts user :attr:`dataset` code in correctly handling multi-process
  182. loading to avoid duplicate data.
  183. However, if sharding results in multiple workers having incomplete last batches,
  184. this estimate can still be inaccurate, because (1) an otherwise complete batch can
  185. be broken into multiple ones and (2) more than one batch worth of samples can be
  186. dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
  187. cases in general.
  188. See `Dataset Types`_ for more details on these two types of datasets and how
  189. :class:`~torch.utils.data.IterableDataset` interacts with
  190. `Multi-process data loading`_.
  191. .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
  192. :ref:`data-loading-randomness` notes for random seed related questions.
  193. .. warning:: Setting `in_order` to `False` can harm reproducibility and may lead to a skewed data
  194. distribution being fed to the trainer in cases with imbalanced data.
  195. """
  196. dataset: Dataset[_T_co]
  197. batch_size: int | None
  198. num_workers: int
  199. pin_memory: bool
  200. drop_last: bool
  201. timeout: float
  202. sampler: Sampler | Iterable
  203. pin_memory_device: str
  204. prefetch_factor: int | None
  205. _iterator: _BaseDataLoaderIter | None
  206. __initialized = False
  207. def __init__(
  208. self,
  209. dataset: Dataset[_T_co],
  210. batch_size: int | None = 1,
  211. shuffle: bool | None = None,
  212. sampler: Sampler | Iterable | None = None,
  213. batch_sampler: Sampler[list] | Iterable[list] | None = None,
  214. num_workers: int = 0,
  215. collate_fn: _collate_fn_t | None = None,
  216. pin_memory: bool = False,
  217. drop_last: bool = False,
  218. timeout: float = 0,
  219. worker_init_fn: _worker_init_fn_t | None = None,
  220. multiprocessing_context=None,
  221. generator=None,
  222. *,
  223. prefetch_factor: int | None = None,
  224. persistent_workers: bool = False,
  225. pin_memory_device: str = "",
  226. in_order: bool = True,
  227. ) -> None:
  228. torch._C._log_api_usage_once("python.data_loader")
  229. if num_workers < 0:
  230. raise ValueError(
  231. "num_workers option should be non-negative; "
  232. "use num_workers=0 to disable multiprocessing."
  233. )
  234. if timeout < 0:
  235. raise ValueError("timeout option should be non-negative")
  236. if num_workers == 0 and prefetch_factor is not None:
  237. raise ValueError(
  238. "prefetch_factor option could only be specified in multiprocessing."
  239. "let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None."
  240. )
  241. elif num_workers > 0 and prefetch_factor is None:
  242. prefetch_factor = 2
  243. elif prefetch_factor is not None and prefetch_factor < 0:
  244. raise ValueError("prefetch_factor option should be non-negative")
  245. if persistent_workers and num_workers == 0:
  246. raise ValueError("persistent_workers option needs num_workers > 0")
  247. self.dataset = dataset
  248. self.num_workers = num_workers
  249. self.prefetch_factor = prefetch_factor
  250. self.pin_memory = pin_memory
  251. self.pin_memory_device = pin_memory_device
  252. self.timeout = timeout
  253. self.worker_init_fn = worker_init_fn
  254. self.multiprocessing_context = multiprocessing_context
  255. self.in_order = in_order
  256. # Adds forward compatibilities so classic DataLoader can work with DataPipes:
  257. # _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
  258. if isinstance(self.dataset, IterDataPipe):
  259. self.dataset = _IterDataPipeSerializationWrapper(self.dataset)
  260. elif isinstance(self.dataset, MapDataPipe):
  261. self.dataset = _MapDataPipeSerializationWrapper(self.dataset)
  262. # Arg-check dataset related before checking samplers because we want to
  263. # tell users that iterable-style datasets are incompatible with custom
  264. # samplers first, so that they don't learn that this combo doesn't work
  265. # after spending time fixing the custom sampler errors.
  266. if isinstance(dataset, IterableDataset):
  267. self._dataset_kind = _DatasetKind.Iterable
  268. # NOTE [ Custom Samplers and IterableDataset ]
  269. #
  270. # `IterableDataset` does not support custom `batch_sampler` or
  271. # `sampler` since the key is irrelevant (unless we support
  272. # generator-style dataset one day...).
  273. #
  274. # For `sampler`, we always create a dummy sampler. This is an
  275. # infinite sampler even when the dataset may have an implemented
  276. # finite `__len__` because in multi-process data loading, naive
  277. # settings will return duplicated data (which may be desired), and
  278. # thus using a sampler with length matching that of dataset will
  279. # cause data lost (you may have duplicates of the first couple
  280. # batches, but never see anything afterwards). Therefore,
  281. # `Iterabledataset` always uses an infinite sampler, an instance of
  282. # `_InfiniteConstantSampler` defined above.
  283. #
  284. # A custom `batch_sampler` essentially only controls the batch size.
  285. # However, it is unclear how useful it would be since an iterable-style
  286. # dataset can handle that within itself. Moreover, it is pointless
  287. # in multi-process data loading as the assignment order of batches
  288. # to workers is an implementation detail so users can not control
  289. # how to batchify each worker's iterable. Thus, we disable this
  290. # option. If this turns out to be useful in future, we can re-enable
  291. # this, and support custom samplers that specify the assignments to
  292. # specific workers.
  293. if isinstance(dataset, IterDataPipe):
  294. if shuffle is not None:
  295. dataset = torch.utils.data.graph_settings.apply_shuffle_settings(
  296. dataset, shuffle=shuffle
  297. )
  298. # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
  299. elif shuffle not in {False, None}:
  300. raise ValueError(
  301. f"DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle={shuffle}"
  302. )
  303. if sampler is not None:
  304. # See NOTE [ Custom Samplers and IterableDataset ]
  305. raise ValueError(
  306. f"DataLoader with IterableDataset: expected unspecified sampler option, but got sampler={sampler}"
  307. )
  308. elif batch_sampler is not None:
  309. # See NOTE [ Custom Samplers and IterableDataset ]
  310. raise ValueError(
  311. "DataLoader with IterableDataset: expected unspecified "
  312. f"batch_sampler option, but got batch_sampler={batch_sampler}"
  313. )
  314. else:
  315. shuffle = bool(shuffle)
  316. self._dataset_kind = _DatasetKind.Map
  317. if sampler is not None and shuffle:
  318. raise ValueError("sampler option is mutually exclusive with shuffle")
  319. if batch_sampler is not None:
  320. # auto_collation with custom batch_sampler
  321. if batch_size != 1 or shuffle or sampler is not None or drop_last:
  322. raise ValueError(
  323. "batch_sampler option is mutually exclusive "
  324. "with batch_size, shuffle, sampler, and "
  325. "drop_last"
  326. )
  327. batch_size = None
  328. drop_last = False
  329. elif batch_size is None:
  330. # no auto_collation
  331. if drop_last:
  332. raise ValueError(
  333. "batch_size=None option disables auto-batching "
  334. "and is mutually exclusive with drop_last"
  335. )
  336. if sampler is None: # give default samplers
  337. if self._dataset_kind == _DatasetKind.Iterable:
  338. # See NOTE [ Custom Samplers and IterableDataset ]
  339. sampler = _InfiniteConstantSampler()
  340. else: # map-style
  341. if shuffle:
  342. sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type]
  343. else:
  344. sampler = SequentialSampler(dataset) # type: ignore[arg-type]
  345. if batch_size is not None and batch_sampler is None:
  346. # auto_collation without custom batch_sampler
  347. batch_sampler = BatchSampler(sampler, batch_size, drop_last)
  348. self.batch_size = batch_size
  349. self.drop_last = drop_last
  350. self.sampler = sampler
  351. self.batch_sampler = batch_sampler
  352. self.generator = generator
  353. if collate_fn is None:
  354. if self._auto_collation:
  355. collate_fn = _utils.collate.default_collate
  356. else:
  357. collate_fn = _utils.collate.default_convert
  358. self.collate_fn = collate_fn
  359. self.persistent_workers = persistent_workers
  360. self.__initialized = True
  361. self._IterableDataset_len_called = (
  362. None # See NOTE [ IterableDataset and __len__ ]
  363. )
  364. self._iterator = None
  365. self.check_worker_number_rationality()
  366. torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined]
  367. def _get_iterator(self) -> _BaseDataLoaderIter:
  368. if self.num_workers == 0:
  369. return _SingleProcessDataLoaderIter(self)
  370. else:
  371. self.check_worker_number_rationality()
  372. return _MultiProcessingDataLoaderIter(self)
  373. @property
  374. def multiprocessing_context(self):
  375. return self.__multiprocessing_context
  376. @multiprocessing_context.setter
  377. def multiprocessing_context(self, multiprocessing_context) -> None:
  378. if multiprocessing_context is not None:
  379. if self.num_workers > 0:
  380. if isinstance(multiprocessing_context, str):
  381. valid_start_methods = torch.multiprocessing.get_all_start_methods()
  382. if multiprocessing_context not in valid_start_methods:
  383. raise ValueError(
  384. "multiprocessing_context option "
  385. f"should specify a valid start method in {valid_start_methods!r}, but got "
  386. f"multiprocessing_context={multiprocessing_context!r}"
  387. )
  388. multiprocessing_context = torch.multiprocessing.get_context(
  389. multiprocessing_context
  390. )
  391. if not isinstance(
  392. multiprocessing_context, python_multiprocessing.context.BaseContext
  393. ):
  394. raise TypeError(
  395. "multiprocessing_context option should be a valid context "
  396. "object or a string specifying the start method, but got "
  397. f"multiprocessing_context={multiprocessing_context}"
  398. )
  399. else:
  400. raise ValueError(
  401. "multiprocessing_context can only be used with "
  402. "multi-process loading (num_workers > 0), but got "
  403. f"num_workers={self.num_workers}"
  404. )
  405. self.__multiprocessing_context = multiprocessing_context
  406. def __setattr__(self, attr, val) -> None:
  407. if self.__initialized and attr in (
  408. "batch_size",
  409. "batch_sampler",
  410. "sampler",
  411. "drop_last",
  412. "dataset",
  413. "persistent_workers",
  414. ):
  415. raise ValueError(
  416. f"{attr} attribute should not be set after {self.__class__.__name__} is initialized"
  417. )
  418. super().__setattr__(attr, val)
  419. def __iter__(self) -> _BaseDataLoaderIter:
  420. # When using a single worker the returned iterator should be
  421. # created every time to avoid resetting its state
  422. # However, in the case of a multiple workers iterator
  423. # the iterator is only created once in the lifetime of the
  424. # DataLoader object so that workers can be reused
  425. if self.persistent_workers and self.num_workers > 0:
  426. if self._iterator is None:
  427. self._iterator = self._get_iterator()
  428. else:
  429. self._iterator._reset(self)
  430. return self._iterator
  431. else:
  432. return self._get_iterator()
  433. @property
  434. def _auto_collation(self):
  435. return self.batch_sampler is not None
  436. @property
  437. def _index_sampler(self):
  438. # The actual sampler used for generating indices for `_DatasetFetcher`
  439. # (see _utils/fetch.py) to read data at each time. This would be
  440. # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
  441. # We can't change `.sampler` and `.batch_sampler` attributes for BC
  442. # reasons.
  443. if self._auto_collation:
  444. return self.batch_sampler
  445. else:
  446. return self.sampler
  447. def __len__(self) -> int:
  448. if self._dataset_kind == _DatasetKind.Iterable:
  449. # NOTE [ IterableDataset and __len__ ]
  450. #
  451. # For `IterableDataset`, `__len__` could be inaccurate when one naively
  452. # does multi-processing data loading, since the samples will be duplicated.
  453. # However, no real use case should be actually using that behavior, so
  454. # it should count as a user error. We should generally trust user
  455. # code to do the proper thing (e.g., configure each replica differently
  456. # in `__iter__`), and give us the correct `__len__` if they choose to
  457. # implement it (this will still throw if the dataset does not implement
  458. # a `__len__`).
  459. #
  460. # To provide a further warning, we track if `__len__` was called on the
  461. # `DataLoader`, save the returned value in `self._len_called`, and warn
  462. # if the iterator ends up yielding more than this number of samples.
  463. # Cannot statically verify that dataset is Sized
  464. length = self._IterableDataset_len_called = len(self.dataset) # type: ignore[assignment, arg-type]
  465. if (
  466. self.batch_size is not None
  467. ): # IterableDataset doesn't allow custom sampler or batch_sampler
  468. from math import ceil
  469. if self.drop_last:
  470. length = length // self.batch_size
  471. else:
  472. length = ceil(length / self.batch_size)
  473. return length
  474. else:
  475. return len(self._index_sampler)
  476. def check_worker_number_rationality(self) -> None:
  477. # This function check whether the dataloader's worker number is rational based on
  478. # current system's resource. Current rule is that if the number of workers this
  479. # Dataloader will create is bigger than the number of logical cpus that is allowed to
  480. # use, than we will pop up a warning to let user pay attention.
  481. #
  482. # eg. If current system has 2 physical CPUs with 16 cores each. And each core support 2
  483. # threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current
  484. # DataLoader process can use half of them which is 32, then the rational max number of
  485. # worker that initiated from this process is 32.
  486. # Now, let's say the created DataLoader has num_works = 40, which is bigger than 32.
  487. # So the warning message is triggered to notify the user to lower the worker number if
  488. # necessary.
  489. #
  490. #
  491. # [Note] Please note that this function respects `cpuset` only when os.sched_getaffinity is
  492. # available (available in most of Linux system, but not OSX and Windows).
  493. # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
  494. # it doesn't respect cpuset.
  495. # We don't take threading into account since each worker process is single threaded
  496. # at this time.
  497. #
  498. # We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc)
  499. # other than `torch.set_num_threads` to 1 in the worker process, if the passing
  500. # in functions use 3rd party modules that rely on those threading flags to determine
  501. # how many thread to create (eg. numpy, etc), then it is caller's responsibility to
  502. # set those flags correctly.
  503. def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
  504. suggested_max_worker_msg = (
  505. (
  506. (
  507. "Our suggested max number of worker in current system is {}{}, which is smaller "
  508. "than what this DataLoader is going to create."
  509. ).format(
  510. num_worker_suggest,
  511. (
  512. ""
  513. if cpuset_checked
  514. else " (`cpuset` is not taken into account)"
  515. ),
  516. )
  517. )
  518. if num_worker_suggest is not None
  519. else (
  520. "DataLoader is not able to compute a suggested max number of worker in current system."
  521. )
  522. )
  523. warn_msg = (
  524. f"This DataLoader will create {num_worker_created} worker processes in total. {suggested_max_worker_msg} "
  525. "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, "
  526. "lower the worker number to avoid potential slowness/freeze if necessary."
  527. )
  528. return warn_msg
  529. if not self.num_workers or self.num_workers == 0:
  530. return
  531. # try to compute a suggested max number of worker based on system's resource
  532. max_num_worker_suggest = None
  533. cpuset_checked = False
  534. if hasattr(os, "sched_getaffinity"):
  535. try:
  536. max_num_worker_suggest = len(os.sched_getaffinity(0))
  537. cpuset_checked = True
  538. except Exception:
  539. pass
  540. if max_num_worker_suggest is None:
  541. # os.cpu_count() could return Optional[int]
  542. # get cpu count first and check None in order to satisfy mypy check
  543. cpu_count = os.cpu_count()
  544. if cpu_count is not None:
  545. max_num_worker_suggest = cpu_count
  546. if max_num_worker_suggest is None:
  547. warnings.warn(
  548. _create_warning_msg(
  549. max_num_worker_suggest, self.num_workers, cpuset_checked
  550. ),
  551. stacklevel=2,
  552. )
  553. return
  554. if self.num_workers > max_num_worker_suggest:
  555. warnings.warn(
  556. _create_warning_msg(
  557. max_num_worker_suggest, self.num_workers, cpuset_checked
  558. ),
  559. stacklevel=2,
  560. )
  561. class _BaseDataLoaderIter:
  562. def __init__(self, loader: DataLoader) -> None:
  563. self._dataset = loader.dataset
  564. self._shared_seed = None
  565. self._pg = None
  566. if isinstance(self._dataset, IterDataPipe):
  567. if dist.is_available() and dist.is_initialized():
  568. self._pg = dist.new_group(backend="gloo")
  569. self._shared_seed = _share_dist_seed(loader.generator, self._pg)
  570. shared_rng = torch.Generator()
  571. shared_rng.manual_seed(self._shared_seed)
  572. self._dataset = torch.utils.data.graph_settings.apply_random_seed(
  573. self._dataset, shared_rng
  574. )
  575. self._dataset_kind = loader._dataset_kind
  576. self._IterableDataset_len_called = loader._IterableDataset_len_called
  577. self._auto_collation = loader._auto_collation
  578. self._drop_last = loader.drop_last
  579. self._index_sampler = loader._index_sampler
  580. self._num_workers = loader.num_workers
  581. ws, rank = _get_distributed_settings()
  582. self._world_size = ws
  583. self._rank = rank
  584. if loader.pin_memory and loader.pin_memory_device:
  585. warnings.warn(
  586. "pin_memory_device is deprecated, the current accelerator will be used as the device,"
  587. f"ignore pin_memory_device='{loader.pin_memory_device}'.",
  588. stacklevel=2,
  589. )
  590. if loader.pin_memory and not torch.accelerator.is_available():
  591. warn_msg = (
  592. "'pin_memory' argument is set as true but no accelerator is found, "
  593. "then device pinned memory won't be used."
  594. )
  595. warnings.warn(warn_msg, stacklevel=2)
  596. # Enabling pin_memory in _BaseDataLoaderIter to support identical
  597. # behavior in forked implementations using _BaseDataLoaderIter.
  598. self._pin_memory = loader.pin_memory and torch.accelerator.is_available()
  599. # Set pin memory device based on the current accelerator.
  600. self._pin_memory_device = (
  601. acc.type
  602. if self._pin_memory
  603. and (acc := torch.accelerator.current_accelerator()) is not None
  604. else None
  605. )
  606. # Currently, pin_memory would raise error on the MPS backend (see
  607. # https://github.com/pytorch/pytorch/issues/86060), so forcibly
  608. # disable pin_memory on MPS. Remove this restriction once pinned
  609. # memory allocation for MPS is fixed.
  610. if self._pin_memory_device == "mps":
  611. self._pin_memory = False
  612. warn_msg = (
  613. "'pin_memory' argument is set as true but not supported on MPS now, "
  614. "device pinned memory won't be used."
  615. )
  616. warnings.warn(warn_msg, stacklevel=2)
  617. self._timeout = loader.timeout
  618. self._collate_fn = loader.collate_fn
  619. self._sampler_iter = iter(self._index_sampler)
  620. self._base_seed = (
  621. torch.empty((), dtype=torch.int64)
  622. .random_(generator=loader.generator)
  623. .item()
  624. )
  625. self._persistent_workers = loader.persistent_workers
  626. self._num_yielded = 0
  627. self._profile_name = f"enumerate(DataLoader)#{self.__class__.__name__}.__next__"
  628. def __iter__(self) -> Self:
  629. return self
  630. def _reset(self, loader, first_iter=False) -> None:
  631. self._sampler_iter = iter(self._index_sampler)
  632. self._num_yielded = 0
  633. self._IterableDataset_len_called = loader._IterableDataset_len_called
  634. if isinstance(self._dataset, IterDataPipe):
  635. self._shared_seed = _share_dist_seed(loader.generator, self._pg)
  636. shared_rng = torch.Generator()
  637. shared_rng.manual_seed(self._shared_seed)
  638. self._dataset = torch.utils.data.graph_settings.apply_random_seed(
  639. self._dataset, shared_rng
  640. )
  641. def _next_index(self):
  642. return next(self._sampler_iter) # may raise StopIteration
  643. def _next_data(self) -> NoReturn:
  644. raise NotImplementedError
  645. def __next__(self) -> Any:
  646. with torch.autograd.profiler.record_function(self._profile_name):
  647. if self._sampler_iter is None:
  648. # TODO(https://github.com/pytorch/pytorch/issues/76750)
  649. self._reset() # type: ignore[call-arg]
  650. data = self._next_data()
  651. self._num_yielded += 1
  652. if (
  653. self._dataset_kind == _DatasetKind.Iterable
  654. and self._IterableDataset_len_called is not None
  655. and self._num_yielded > self._IterableDataset_len_called
  656. ):
  657. warn_msg = (
  658. f"Length of IterableDataset {self._dataset} was reported to be {self._IterableDataset_len_called}"
  659. f"(when accessing len(dataloader)), but {self._num_yielded} samples have been fetched. "
  660. )
  661. if self._num_workers > 0:
  662. warn_msg += (
  663. "For multiprocessing data-loading, this could be caused by not properly configuring the "
  664. "IterableDataset replica at each worker. Please see "
  665. "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples."
  666. )
  667. warnings.warn(warn_msg, stacklevel=2)
  668. return data
  669. def __len__(self) -> int:
  670. return len(self._index_sampler)
  671. def __getstate__(self):
  672. # TODO: add limited pickling support for sharing an iterator
  673. # across multiple threads for HOGWILD.
  674. # Probably the best way to do this is by moving the sample pushing
  675. # to a separate thread and then just sharing the data queue
  676. # but signalling the end is tricky without a non-blocking API
  677. raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
  678. class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
  679. def __init__(self, loader) -> None:
  680. super().__init__(loader)
  681. if self._timeout != 0:
  682. raise AssertionError("_SingleProcessDataLoaderIter requires timeout == 0")
  683. if self._num_workers != 0:
  684. raise AssertionError(
  685. "_SingleProcessDataLoaderIter requires num_workers == 0"
  686. )
  687. # Adds forward compatibilities so classic DataLoader can work with DataPipes:
  688. # Taking care of distributed sharding
  689. if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
  690. # For BC, use default SHARDING_PRIORITIES
  691. torch.utils.data.graph_settings.apply_sharding(
  692. self._dataset, self._world_size, self._rank
  693. )
  694. self._dataset_fetcher = _DatasetKind.create_fetcher(
  695. self._dataset_kind,
  696. self._dataset,
  697. self._auto_collation,
  698. self._collate_fn,
  699. self._drop_last,
  700. )
  701. def _next_data(self):
  702. index = self._next_index() # may raise StopIteration
  703. data = self._dataset_fetcher.fetch(index) # may raise StopIteration
  704. if self._pin_memory:
  705. data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
  706. return data
  707. class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
  708. r"""Iterates once over the DataLoader's dataset, as specified by the sampler."""
  709. # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
  710. #
  711. # Preliminary:
  712. #
  713. # Our data model looks like this (queues are indicated with curly brackets):
  714. #
  715. # main process ||
  716. # | ||
  717. # {index_queue} ||
  718. # | ||
  719. # worker processes || DATA
  720. # | ||
  721. # {worker_result_queue} || FLOW
  722. # | ||
  723. # pin_memory_thread of main process || DIRECTION
  724. # | ||
  725. # {data_queue} ||
  726. # | ||
  727. # data output \/
  728. #
  729. # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
  730. # `pin_memory=False`.
  731. #
  732. #
  733. # Terminating multiprocessing logic requires very careful design. In
  734. # particular, we need to make sure that
  735. #
  736. # 1. The iterator gracefully exits the workers when its last reference is
  737. # gone or it is depleted.
  738. #
  739. # In this case, the workers should be gracefully exited because the
  740. # main process may still need to continue to run, and we want cleaning
  741. # up code in the workers to be executed (e.g., releasing GPU memory).
  742. # Naturally, we implement the shutdown logic in `__del__` of
  743. # DataLoaderIterator.
  744. #
  745. # We delay the discussion on the logic in this case until later.
  746. #
  747. # 2. The iterator exits the workers when the loader process and/or worker
  748. # processes exits normally or with error.
  749. #
  750. # We set all workers and `pin_memory_thread` to have `daemon=True`.
  751. #
  752. # You may ask, why can't we make the workers non-daemonic, and
  753. # gracefully exit using the same logic as we have in `__del__` when the
  754. # iterator gets deleted (see 1 above)?
  755. #
  756. # First of all, `__del__` is **not** guaranteed to be called when
  757. # interpreter exits. Even if it is called, by the time it executes,
  758. # many Python core library resources may already be freed, and even
  759. # simple things like acquiring an internal lock of a queue may hang.
  760. # Therefore, in this case, we actually need to prevent `__del__` from
  761. # being executed, and rely on the automatic termination of daemonic
  762. # children.
  763. #
  764. # Thus, we register an `atexit` hook that sets a global flag
  765. # `_utils.python_exit_status`. Since `atexit` hooks are executed in the
  766. # reverse order of registration, we are guaranteed that this flag is
  767. # set before library resources we use are freed (which, at least in
  768. # CPython, is done via an `atexit` handler defined in
  769. # `multiprocessing/util.py`
  770. # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362
  771. # registered when an object requiring this mechanism is first
  772. # created, e.g., `mp.Queue`
  773. # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103
  774. # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29
  775. # )
  776. #
  777. # So in `__del__`, we check if `_utils.python_exit_status` is set or
  778. # `None` (freed), and perform no-op if so.
  779. #
  780. # However, simply letting library clean-up codes run can also be bad,
  781. # because such codes (i.e., `multiprocessing.util._exit_function()`)
# include joining the putting threads of `mp.Queue`, which can be blocking.
  783. # Hence, the main process putting threads are called with
  784. # `cancel_join_thread` at creation. See later section
  785. # [ 3b. A process won't hang when putting into a queue; ]
  786. # for more details.
  787. #
  788. # Here are two example cases where library clean-up codes can run
  789. # before `__del__` is called:
  790. #
  791. # 1. If we hold onto a reference to the iterator, it more often
  792. # than not tries to do `multiprocessing` library cleaning before
  793. # clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666)
  794. # and thus prevents our cleaning-up code to run first.
  795. #
# 2. A similar issue arises when a `DataLoader` is used in a subprocess.
# When a process ends, it shuts all its daemonic children
  798. # down with a SIGTERM (instead of joining them without a timeout).
  799. # Similarly for threads, but by a different mechanism. This fact,
  800. # together with a few implementation details of multiprocessing, forces
  801. # us to make workers daemonic. All of our problems arise when a
  802. # DataLoader is used in a subprocess, and are caused by multiprocessing
  803. # code which looks more or less like this:
  804. #
  805. # try:
  806. # your_function_using_a_dataloader()
  807. # finally:
  808. # multiprocessing.util._exit_function()
  809. #
  810. # The joining/termination mentioned above happens inside
  811. # `_exit_function()`. Now, if `your_function_using_a_dataloader()`
  812. # throws, the stack trace stored in the exception will prevent the
  813. # frame which uses `DataLoaderIter` to be freed. If the frame has any
  814. # reference to the `DataLoaderIter` (e.g., in a method of the iter),
  815. # its `__del__`, which starts the shutdown procedure, will not be
  816. # called. That, in turn, means that workers aren't notified. Attempting
  817. # to join in `_exit_function` will then result in a hang.
  818. #
  819. # For context, `_exit_function` is also registered as an `atexit` call.
  820. # So it is unclear to me (@ssnl) why this is needed in a finally block.
  821. # The code dates back to 2008 and there is no comment on the original
  822. # PEP 371 or patch https://bugs.python.org/issue3050 (containing both
  823. # the finally block and the `atexit` registration) that explains this.
  824. #
  825. #
  826. # Finally, another choice is to just shutdown workers with logic in 1
  827. # above whenever we see an error in `next`. This isn't ideal because
  828. # a. It prevents users from using try-catch to resume data loading.
  829. # b. It doesn't prevent hanging if users have references to the
  830. # iterator.
  831. #
  832. # 3. All processes exit if any of them die unexpectedly by fatal signals.
  833. #
  834. # As shown above, the workers are set as daemonic children of the main
  835. # process. However, automatic cleaning-up of such child processes only
  836. # happens if the parent process exits gracefully (e.g., not via fatal
  837. # signals like SIGKILL). So we must ensure that each process will exit
# even if the process that should send/receive data to/from it was
# killed, i.e.,
  840. #
  841. # a. A process won't hang when getting from a queue.
  842. #
  843. # Even with carefully designed data dependencies (i.e., a `put()`
  844. # always corresponding to a `get()`), hanging on `get()` can still
  845. # happen when data in queue is corrupted (e.g., due to
  846. # `cancel_join_thread` or unexpected exit).
  847. #
  848. # For child exit, we set a timeout whenever we try to get data
  849. # from `data_queue`, and check the workers' status on each timeout
  850. # and error.
# See `_MultiProcessingDataLoaderIter._get_data()` and
# `_MultiProcessingDataLoaderIter._try_get_data()` for details.
  853. #
  854. # Additionally, for child exit on non-Windows platforms, we also
# register a SIGCHLD handler (which is not supported on Windows) on
  856. # the main process, which checks if any of the workers fail in the
  857. # (Python) handler. This is more efficient and faster in detecting
  858. # worker failures, compared to only using the above mechanism.
  859. # See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
  860. #
  861. # For `.get()` calls where the sender(s) is not the workers, we
  862. # guard them with timeouts, and check the status of the sender
  863. # when timeout happens:
  864. # + in the workers, the `_utils.worker.ManagerWatchdog` class
  865. # checks the status of the main process.
  866. # + if `pin_memory=True`, when getting from `pin_memory_thread`,
  867. # check `pin_memory_thread` status periodically until `.get()`
  868. # returns or see that `pin_memory_thread` died.
  869. #
  870. # b. A process won't hang when putting into a queue;
  871. #
  872. # We use `mp.Queue` which has a separate background thread to put
  873. # objects from an unbounded buffer array. The background thread is
  874. # daemonic and usually automatically joined when the process
  875. # *exits*.
  876. #
  877. # In case that the receiver has ended abruptly while
  878. # reading from the pipe, the join will hang forever. The usual
  879. # solution for this in Python is calling `q.cancel_join_thread`,
  880. # which prevents automatically joining it when finalizing
  881. # (exiting).
  882. #
  883. # Nonetheless, `cancel_join_thread` must only be called when the
  884. # queue is **not** going to be read from or write into by another
  885. # process, because it may hold onto a lock or leave corrupted data
  886. # in the queue, leading other readers/writers to hang.
  887. #
  888. # Hence,
  889. # + For worker processes, we only do so (for their output
  890. # queues, i.e., `worker_result_queue`) before exiting.
  891. # + For `pin_memory_thread`, its output queue `data_queue` is a
  892. # `queue.Queue` that does blocking `put` if the queue is full.
  893. # So there is no above problem, but as a result, in
  894. # `_pin_memory_loop`, we do need to wrap the `put` in a loop
  895. # that breaks not only upon success, but also when the main
  896. # process stops reading, i.e., is shutting down.
  897. # + For loader process, we `cancel_join_thread()` for all
  898. # `_index_queues` because the whole purpose of workers and
  899. # `pin_memory_thread` is to serve the loader process. If
  900. # loader process is already exiting, we don't really care if
  901. # the queues are corrupted.
  902. #
  903. #
  904. # Now let's get back to 1:
  905. # how we gracefully exit the workers when the last reference to the
  906. # iterator is gone.
  907. #
  908. # To achieve this, we implement the following logic along with the design
  909. # choices mentioned above:
  910. #
  911. # `workers_done_event`:
  912. # A `multiprocessing.Event` shared among the main process and all worker
  913. # processes. This is used to signal the workers that the iterator is
  914. # shutting down. After it is set, they will not send processed data to
  915. # queues anymore, and only wait for the final `None` before exiting.
  916. # `done_event` isn't strictly needed. I.e., we can just check for `None`
  917. # from the input queue, but it allows us to skip wasting resources
  918. # processing data if we are already shutting down.
  919. #
  920. # `pin_memory_thread_done_event`:
  921. # A `threading.Event` for a similar purpose to that of
  922. # `workers_done_event`, but is for the `pin_memory_thread`. The reason
  923. # that separate events are needed is that `pin_memory_thread` reads from
  924. # the output queue of the workers. But the workers, upon seeing that
  925. # `workers_done_event` is set, only wants to see the final `None`, and is
  926. # not required to flush all data in the output queue (e.g., it may call
  927. # `cancel_join_thread` on that queue if its `IterableDataset` iterator
  928. # happens to exhaust coincidentally, which is out of the control of the
  929. # main process). Thus, since we will exit `pin_memory_thread` before the
  930. # workers (see below), two separate events are used.
  931. #
  932. # NOTE: In short, the protocol is that the main process will set these
# `done_event`s and then send the corresponding processes/threads a `None`,
  934. # and that they may exit at any time after receiving the `None`.
  935. #
  936. # NOTE: Using `None` as the final signal is valid, since normal data will
  937. # always be a 2-tuple with the 1st element being the index of the data
  938. # transferred (different from dataset index/key), and the 2nd being
  939. # either the dataset key or the data sample (depending on which part
  940. # of the data model the queue is at).
  941. #
  942. # [ worker processes ]
  943. # While loader process is alive:
  944. # Get from `index_queue`.
  945. # If get anything else,
  946. # Check `workers_done_event`.
  947. # If set, continue to next iteration
  948. # i.e., keep getting until see the `None`, then exit.
  949. # Otherwise, process data:
  950. # If is fetching from an `IterableDataset` and the iterator
  951. # is exhausted, send an `_IterableDatasetStopIteration`
  952. # object to signal iteration end. The main process, upon
  953. # receiving such an object, will send `None` to this
  954. # worker and not use the corresponding `index_queue`
  955. # anymore.
  956. # If timed out,
  957. # No matter `workers_done_event` is set (still need to see `None`)
  958. # or not, must continue to next iteration.
  959. # (outside loop)
  960. # If `workers_done_event` is set, (this can be False with `IterableDataset`)
  961. # `data_queue.cancel_join_thread()`. (Everything is ending here:
  962. # main process won't read from it;
  963. # other workers will also call
  964. # `cancel_join_thread`.)
  965. #
  966. # [ pin_memory_thread ]
  967. # # No need to check main thread. If this thread is alive, the main loader
  968. # # thread must be alive, because this thread is set as daemonic.
  969. # While `pin_memory_thread_done_event` is not set:
  970. # Get from `worker_result_queue`.
  971. # If timed out, continue to get in the next iteration.
  972. # Otherwise, process data.
  973. # While `pin_memory_thread_done_event` is not set:
  974. # Put processed data to `data_queue` (a `queue.Queue` with blocking put)
  975. # If timed out, continue to put in the next iteration.
  976. # Otherwise, break, i.e., continuing to the out loop.
  977. #
  978. # NOTE: we don't check the status of the main thread because
  979. # 1. if the process is killed by fatal signal, `pin_memory_thread`
  980. # ends.
  981. # 2. in other cases, either the cleaning-up in __del__ or the
  982. # automatic exit of daemonic thread will take care of it.
  983. # This won't busy-wait either because `.get(timeout)` does not
  984. # busy-wait.
  985. #
  986. # [ main process ]
  987. # In the DataLoader Iter's `__del__`
  988. # b. Exit `pin_memory_thread`
  989. # i. Set `pin_memory_thread_done_event`.
  990. # ii Put `None` in `worker_result_queue`.
  991. # iii. Join the `pin_memory_thread`.
  992. # iv. `worker_result_queue.cancel_join_thread()`.
  993. #
  994. # c. Exit the workers.
  995. # i. Set `workers_done_event`.
  996. # ii. Put `None` in each worker's `index_queue`.
  997. # iii. Join the workers.
  998. # iv. Call `.cancel_join_thread()` on each worker's `index_queue`.
  999. #
  1000. # NOTE: (c) is better placed after (b) because it may leave corrupted
  1001. # data in `worker_result_queue`, which `pin_memory_thread`
# reads from, in which case the `pin_memory_thread` can only
# exit upon timing out, which is slow. Nonetheless, the same thing
  1004. # happens if a worker is killed by signal at unfortunate times,
  1005. # but in other cases, we are better off having a non-corrupted
  1006. # `worker_result_queue` for `pin_memory_thread`.
  1007. #
  1008. # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
  1009. # can be omitted
  1010. #
# NB: `done_event`s aren't strictly needed. E.g., we can just check for
  1012. # `None` from `index_queue`, but it allows us to skip wasting resources
  1013. # processing indices already in `index_queue` if we are already shutting
  1014. # down.
def __init__(self, loader) -> None:
    """Validate settings, spawn the worker processes and start the first epoch.

    Steps, in order:
    1. Check ``num_workers`` / ``prefetch_factor``.
    2. Pick the multiprocessing context.
    3. Start one daemonic worker process per ``num_workers``, each with its own
       index queue.
    4. Start the pin-memory thread when pinning is enabled.
    5. Register the worker pids and the SIGCHLD handler via
       ``_utils.signal_handling``.
    6. Call ``_reset`` with ``first_iter=True`` to prime prefetching.

    Raises:
        AssertionError: if ``num_workers`` or ``prefetch_factor`` is not positive.
        TypeError, AttributeError, pickle.PicklingError: re-raised (after a
            warning) when a worker process fails to start, e.g. because its
            arguments cannot be pickled.
    """
    super().__init__(loader)

    self._prefetch_factor = loader.prefetch_factor
    self._in_order = loader.in_order

    if self._num_workers <= 0:
        raise AssertionError(
            "num_workers must be greater than 0 for MultiProcessingDataLoaderIter"
        )
    if self._prefetch_factor <= 0:
        raise AssertionError(
            "prefetch_factor must be greater than 0 for MultiProcessingDataLoaderIter"
        )

    # Default to torch.multiprocessing unless the loader supplies its own context.
    if loader.multiprocessing_context is None:
        multiprocessing_context = torch.multiprocessing
    else:
        multiprocessing_context = loader.multiprocessing_context

    self._worker_init_fn = loader.worker_init_fn

    # Adds forward compatibilities so classic DataLoader can work with DataPipes:
    #   Additional worker init function will take care of sharding in MP and Distributed
    if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
        self._worker_init_fn = functools.partial(
            _sharding_worker_init_fn,
            self._worker_init_fn,
            self._world_size,
            self._rank,
        )

    # No certainty which module multiprocessing_context is
    # Single queue shared by all workers for their results.
    self._worker_result_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
    # Set to True once worker pids are registered with signal handling (end of
    # this method); presumably consulted during shutdown -- verify.
    self._worker_pids_set = False
    self._shutdown = False
    self._workers_done_event = multiprocessing_context.Event()

    self._index_queues = []
    self._workers = []
    for i in range(self._num_workers):
        # No certainty which module multiprocessing_context is
        index_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
        # Need to `cancel_join_thread` here!
        # See sections (2) and (3b) above.
        index_queue.cancel_join_thread()
        w = multiprocessing_context.Process(
            target=_utils.worker._worker_loop,
            args=(
                self._dataset_kind,
                self._dataset,
                index_queue,
                self._worker_result_queue,
                self._workers_done_event,
                self._auto_collation,
                self._collate_fn,
                self._drop_last,
                self._base_seed,
                self._worker_init_fn,
                i,
                self._num_workers,
                self._persistent_workers,
                self._shared_seed,
            ),
        )
        w.daemon = True
        # NB: Process.start() actually take some time as it needs to
        #     start a process and pass the arguments over via a pipe.
        #     Therefore, we only add a worker to self._workers list after
        #     it started, so that we do not call .join() if program dies
        #     before it starts, and __del__ tries to join but will get:
        #     AssertionError: can only join a started process.
        from pickle import PicklingError

        try:
            w.start()
        except (TypeError, AttributeError, PicklingError):
            # Add a hint about picklability/start methods, then propagate the
            # original error unchanged.
            warnings.warn(
                "Got pickle error when attempting to start a worker Process. "
                "This might be because the worker Process arguments are not picklable. "
                "Python 3.14+ changed the multiprocessing start method in non-Mac POSIX platforms "
                "to 'forkserver', which requires the worker Process arguments to be picklable. "
                "You can also try multiprocessing.set_start_method('fork').",
                stacklevel=2,
            )
            raise
        self._index_queues.append(index_queue)
        self._workers.append(w)

    if self._pin_memory:
        self._pin_memory_thread_done_event = threading.Event()

        # Queue is not type-annotated
        # The pin-memory thread reads from `_worker_result_queue` and writes
        # pinned results into this in-process queue.
        self._data_queue = queue.Queue()  # type: ignore[var-annotated]
        current_device_id = torch.accelerator.current_device_index()
        pin_memory_thread = threading.Thread(
            target=_utils.pin_memory._pin_memory_loop,
            args=(
                self._worker_result_queue,
                self._data_queue,
                current_device_id,
                self._pin_memory_thread_done_event,
                self._pin_memory_device,
            ),
        )
        pin_memory_thread.daemon = True
        pin_memory_thread.start()
        # Similar to workers (see comment above), we only register
        # pin_memory_thread once it is started.
        self._pin_memory_thread = pin_memory_thread
    else:
        # Without pinning, the main process reads worker results directly.
        self._data_queue = self._worker_result_queue  # type: ignore[assignment]

    # In some rare cases, persistent workers (daemonic processes)
    # would be terminated before `__del__` of iterator is invoked
    # when main process exits
    # It would cause failure when pin_memory_thread tries to read
    # corrupted data from worker_result_queue
    # atexit is used to shutdown thread and child processes in the
    # right sequence before main process exits
    if self._persistent_workers and self._pin_memory:
        import atexit

        for w in self._workers:
            atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)

    # .pid can be None only before process is spawned (not the case, so ignore)
    _utils.signal_handling._set_worker_pids(
        id(self),
        tuple(w.pid for w in self._workers),  # type: ignore[misc]
    )
    _utils.signal_handling._set_SIGCHLD_handler()
    self._worker_pids_set = True
    self._reset(loader, first_iter=True)
  1136. def _reset(self, loader, first_iter=False) -> None:
  1137. super()._reset(loader, first_iter)
  1138. self._send_idx = 0 # idx of the next task to be sent to workers
  1139. self._rcvd_idx = 0 # idx of the next task to be returned in __next__
  1140. # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
  1141. # map: task idx => - (worker_id,) if data isn't fetched (outstanding)
  1142. # \ (worker_id, data) if data is already fetched (out-of-order)
  1143. self._task_info = {}
  1144. self._tasks_outstanding = (
  1145. 0 # always equal to count(v for v in task_info.values() if len(v) == 1)
  1146. )
  1147. # A list of booleans representing whether each worker still has work to
  1148. # do, i.e., not having exhausted its iterable dataset object. It always
  1149. # contains all `True`s if not using an iterable-style dataset
  1150. # (i.e., if kind != Iterable).
  1151. # Not that this indicates that a worker still has work to do *for this epoch*.
  1152. # It does not mean that a worker is dead. In case of `_persistent_workers`,
  1153. # the worker will be reset to available in the next epoch.
  1154. self._workers_status = [True for i in range(self._num_workers)]
  1155. # A list of integers representing how many tasks are outstanding for each worker
  1156. # Incremented when a task is dispatched to the worker
  1157. # Decremented when that data has been given to the main thread
  1158. # Each worker should have at most self._prefetch_factor tasks outstanding
  1159. self._workers_num_tasks = [0 for i in range(self._num_workers)]
  1160. # Reset the worker queue cycle so it resumes next epoch at worker 0
  1161. self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
  1162. # We resume the prefetching in case it was enabled
  1163. if not first_iter:
  1164. for idx in range(self._num_workers):
  1165. self._index_queues[idx].put(
  1166. _utils.worker._ResumeIteration(self._shared_seed)
  1167. )
  1168. resume_iteration_cnt = self._num_workers
  1169. while resume_iteration_cnt > 0:
  1170. return_idx, return_data = self._get_data()
  1171. if isinstance(return_idx, _utils.worker._ResumeIteration):
  1172. if return_data is not None:
  1173. raise AssertionError(
  1174. "Expected return_data to be None when resuming iteration"
  1175. )
  1176. resume_iteration_cnt -= 1
  1177. # prime the prefetch loop
  1178. for _ in range(self._prefetch_factor * self._num_workers):
  1179. self._try_put_index()
  1180. def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
  1181. # Tries to fetch data from `self._data_queue` once for a given timeout.
  1182. # This can also be used as inner loop of fetching without timeout, with
  1183. # the sender status as the loop condition.
  1184. #
  1185. # This raises a `RuntimeError` if any worker died expectedly. This error
  1186. # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
  1187. # (only for non-Windows platforms), or the manual check below on errors
  1188. # and timeouts.
  1189. #
  1190. # Returns a 2-tuple:
  1191. # (bool: whether successfully get data, any: data if successful else None)
  1192. try:
  1193. data = self._data_queue.get(timeout=timeout)
  1194. return (True, data)
  1195. except Exception as e:
  1196. # At timeout and error, we manually check whether any worker has
  1197. # failed. Note that this is the only mechanism for Windows to detect
  1198. # worker failures.
  1199. failed_workers = []
  1200. for worker_id, w in enumerate(self._workers):
  1201. if self._workers_status[worker_id] and not w.is_alive():
  1202. failed_workers.append(w)
  1203. self._mark_worker_as_unavailable(worker_id)
  1204. if len(failed_workers) > 0:
  1205. pids_str = ", ".join(str(w.pid) for w in failed_workers)
  1206. raise RuntimeError(
  1207. f"DataLoader worker (pid(s) {pids_str}) exited unexpectedly"
  1208. ) from e
  1209. if isinstance(e, queue.Empty):
  1210. return (False, None)
  1211. import errno
  1212. import tempfile
  1213. try:
  1214. # Raise an exception if we are this close to the FDs limit.
  1215. # Apparently, trying to open only one file is not a sufficient
  1216. # test.
  1217. # See NOTE [ DataLoader on Linux and open files limit ]
  1218. fds_limit_margin = 10
  1219. with contextlib.ExitStack() as stack:
  1220. for _ in range(fds_limit_margin):
  1221. stack.enter_context(
  1222. tempfile.NamedTemporaryFile() # pyrefly: ignore [bad-argument-type]
  1223. )
  1224. except OSError as e:
  1225. if e.errno == errno.EMFILE:
  1226. raise RuntimeError(
  1227. "Too many open files. Communication with the"
  1228. " workers is no longer possible. Please increase the"
  1229. " limit using `ulimit -n` in the shell or change the"
  1230. " sharing strategy by calling"
  1231. " `torch.multiprocessing.set_sharing_strategy('file_system')`"
  1232. " at the beginning of your code"
  1233. ) from None
  1234. raise
  1235. # NOTE [ DataLoader on Linux and open files limit ]
  1236. #
  1237. # On Linux when DataLoader is used with multiprocessing we pass the data between
  1238. # the root process and the workers through SHM files. We remove those files from
  1239. # the filesystem as soon as they are created and keep them alive by
  1240. # passing around their file descriptors through AF_UNIX sockets. (See
# docs/source/multiprocessing.rst and `Multiprocessing Technical Notes` in
  1242. # the wiki (https://github.com/pytorch/pytorch/wiki).)
  1243. #
# This can cause us to exceed the open files limit. When that happens
# and the offending file descriptor arrives over a socket, the `socket` Python
  1246. # package silently strips the file descriptor from the message, setting only the
  1247. # `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that
  1248. # it _indicates that some control data were discarded due to lack of space in
  1249. # the buffer for ancillary data_). This might reflect the C implementation of
  1250. # AF_UNIX sockets.
  1251. #
  1252. # This behaviour can be reproduced with the script and instructions at the
  1253. # bottom of this note.
  1254. #
  1255. # When that happens, the standard Python `multiprocessing` (and not
  1256. # `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata`
  1257. #
  1258. # Sometimes, instead of the FD being stripped, you may get an `OSError:
  1259. # Too many open files`, both in the script below and in DataLoader. However,
  1260. # this is rare and seems to be nondeterministic.
  1261. #
  1262. #
# #!/usr/bin/env python3
# import sys
# import socket
# import os
# import array
# import shutil
#
#
# if len(sys.argv) != 4:
#     print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)")
#     sys.exit(1)
#
# if __name__ == '__main__':
#     dirname = sys.argv[1]
#     sock_path = dirname + "/sock"
#     iterations = int(sys.argv[2])
#     def dummy_path(i):
#         return dirname + "/" + str(i) + ".dummy"
#
#
#     if sys.argv[3] == 'send':
#         while not os.path.exists(sock_path):
#             pass
#         client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
#         client.connect(sock_path)
#         for i in range(iterations):
#             fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT)
#             ancdata = array.array('i', [fd])
#             msg = bytes([i % 256])
#             print("Sending fd ", fd, " (iteration #", i, ")")
#             client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)])
#
#
#     else:
#         assert sys.argv[3] == 'recv'
#
#         if os.path.exists(dirname):
#             raise Exception("Directory exists")
#
#         os.mkdir(dirname)
#
#         print("Opening socket...")
#         server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
#         server.bind(sock_path)
#
#         print("Listening...")
#         for i in range(iterations):
#             a = array.array('i')
#             msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize))
#             assert(len(ancdata) == 1)
#             cmsg_level, cmsg_type, cmsg_data = ancdata[0]
#             a.frombytes(cmsg_data)
#             print("Received fd ", a[0], " (iteration #", i, ")")
#
#         shutil.rmtree(dirname)
  1319. #
  1320. # Steps to reproduce:
  1321. #
  1322. # 1. Run two shells and set lower file descriptor limit in the receiving one:
  1323. # (shell1) ulimit -n 1020
  1324. # (shell2) ulimit -n 1022
  1325. #
  1326. # 2. Run the script above with the `recv` option in the first shell
  1327. # (shell1) ./test_socket.py sock_tmp 1017 recv
  1328. #
  1329. # 3. Run the script with the `send` option in the second shell:
  1330. # (shell2) ./test_socket.py sock_tmp 1017 send
  1331. def _get_data(self):
  1332. # Fetches data from `self._data_queue`.
  1333. #
  1334. # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
  1335. # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
  1336. # in a loop. This is the only mechanism to detect worker failures for
  1337. # Windows. For other platforms, a SIGCHLD handler is also used for
  1338. # worker failure detection.
  1339. #
  1340. # If `pin_memory=True`, we also need check if `pin_memory_thread` had
  1341. # died at timeouts.
  1342. if self._timeout > 0:
  1343. success, data = self._try_get_data(self._timeout)
  1344. if success:
  1345. return data
  1346. else:
  1347. raise RuntimeError(
  1348. f"DataLoader timed out after {self._timeout} seconds"
  1349. )
  1350. elif self._pin_memory:
  1351. while self._pin_memory_thread.is_alive():
  1352. success, data = self._try_get_data()
  1353. if success:
  1354. return data
  1355. else:
  1356. # while condition is false, i.e., pin_memory_thread died.
  1357. raise RuntimeError("Pin memory thread exited unexpectedly")
  1358. # In this case, `self._data_queue` is a `queue.Queue`,. But we don't
  1359. # need to call `.task_done()` because we don't use `.join()`.
  1360. else:
  1361. while True:
  1362. success, data = self._try_get_data()
  1363. if success:
  1364. return data
def _next_data(self):
    """Return the next processed batch from the worker pipeline.

    Loops until a batch can be returned or the epoch is over:

    * skips task indices whose worker has become inactive without producing
      data for them (e.g. an exhausted ``IterableDataset`` replica), and
      indices whose ``self._task_info`` entry was already removed;
    * returns a batch that was received earlier and stored in
      ``self._task_info`` as ``(worker_id, data)``;
    * otherwise waits on ``self._get_data()`` for the next result.

    Entries of ``self._task_info`` are ``(worker_id,)`` while a task is in
    flight and ``(worker_id, data)`` once its result has been stored.

    Raises:
        StopIteration: when no index in ``[self._rcvd_idx, self._send_idx)``
            can still produce data. Non-persistent workers are shut down
            first.
        AssertionError: if the iterator is already shut down, or has no
            outstanding tasks, when it needs to wait for more data.
        Any exception re-raised by ``self._process_data`` for a result that
        is an ``ExceptionWrapper``.
    """
    while True:
        # If the worker responsible for `self._rcvd_idx` has already ended
        # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
        # we try to advance `self._rcvd_idx` to find the next valid index.
        #
        # This part needs to run in the loop because both the `self._get_data()`
        # call and `_IterableDatasetStopIteration` check below can mark
        # extra worker(s) as dead.
        while self._rcvd_idx < self._send_idx:
            # `info` is missing when the out-of-order path below already
            # popped this index; such indices are simply skipped.
            info = self._task_info.get(self._rcvd_idx, None)
            if info:
                worker_id = info[0]
                if (
                    len(info) == 2 or self._workers_status[worker_id]
                ):  # has data or is still active
                    break
                # Inactive worker and no stored data: this task will never
                # complete, so forget it.
                del self._task_info[self._rcvd_idx]
            self._rcvd_idx += 1
        else:
            # no valid `self._rcvd_idx` is found (i.e., didn't break)
            if not self._persistent_workers:
                self._shutdown_workers()
            raise StopIteration

        # Now `self._rcvd_idx` is the batch index we want to fetch

        # Check if the next sample has already been generated
        if len(self._task_info[self._rcvd_idx]) == 2:
            worker_id, data = self._task_info.pop(self._rcvd_idx)
            self._rcvd_idx += 1
            return self._process_data(data, worker_id)

        if self._shutdown or self._tasks_outstanding <= 0:
            raise AssertionError(
                "Invalid iterator state: shutdown or no outstanding tasks when fetching next data"
            )
        idx, data = self._get_data()
        # One in-flight task has produced a result (a batch or a stop marker).
        self._tasks_outstanding -= 1
        if self._dataset_kind == _DatasetKind.Iterable:
            # Check for _IterableDatasetStopIteration
            if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                # Persistent workers are only flagged inactive here; other
                # workers are also sent the termination signal by
                # `_mark_worker_as_unavailable`.
                if self._persistent_workers:
                    self._workers_status[data.worker_id] = False
                else:
                    self._mark_worker_as_unavailable(data.worker_id)
                # Reuse the freed slot for another worker, then rescan from
                # the top (which skips this worker's remaining indices).
                self._try_put_index()
                continue

        if idx != self._rcvd_idx:
            if not self._in_order:
                # don't store it for later, process now
                # delete from self._task_info immediately
                # this keeps the object size manageable
                worker_id = self._task_info.pop(idx)[0]
                return self._process_data(data, worker_id)
            # store out-of-order samples
            self._task_info[idx] += (data,)
        else:
            worker_id = self._task_info.pop(idx)[0]
            self._rcvd_idx += 1
            return self._process_data(data, worker_id)
  1423. def _try_put_index(self) -> None:
  1424. max_tasks = self._prefetch_factor * self._num_workers
  1425. if self._tasks_outstanding >= max_tasks:
  1426. raise AssertionError(
  1427. "Number of outstanding tasks exceeded maximum allowed tasks"
  1428. )
  1429. try:
  1430. index = self._next_index()
  1431. except StopIteration:
  1432. return
  1433. for _ in range(self._num_workers): # find the next active worker, if any
  1434. worker_queue_idx = next(self._worker_queue_idx_cycle)
  1435. if self._workers_status[worker_queue_idx]:
  1436. if self._in_order:
  1437. break
  1438. elif self._workers_num_tasks[worker_queue_idx] < max_tasks // sum(
  1439. self._workers_status
  1440. ):
  1441. # when self._in_order is False, distribute work to a worker if it has capacity
  1442. # _workers_status is updated only in this thread, so the sum is guaranteed > 0
  1443. break
  1444. else:
  1445. # not found (i.e., didn't break)
  1446. return
  1447. self._index_queues[worker_queue_idx].put((self._send_idx, index)) # type: ignore[possibly-undefined]
  1448. self._task_info[self._send_idx] = (worker_queue_idx,)
  1449. self._workers_num_tasks[worker_queue_idx] += 1
  1450. self._tasks_outstanding += 1
  1451. self._send_idx += 1
  1452. def _process_data(self, data, worker_idx):
  1453. self._workers_num_tasks[worker_idx] -= 1
  1454. self._try_put_index()
  1455. if isinstance(data, ExceptionWrapper):
  1456. data.reraise()
  1457. return data
  1458. def _mark_worker_as_unavailable(self, worker_id, shutdown=False) -> None:
  1459. # Mark a worker as having finished its work e.g., due to
  1460. # exhausting an `IterableDataset`. This should be used only when this
  1461. # `_MultiProcessingDataLoaderIter` is going to continue running.
  1462. if (
  1463. not self._workers_status[worker_id]
  1464. and not self._persistent_workers
  1465. and not shutdown
  1466. ):
  1467. raise AssertionError(
  1468. "Worker status inconsistent when marking worker as unavailable"
  1469. )
  1470. # Signal termination to that specific worker.
  1471. q = self._index_queues[worker_id]
  1472. # Indicate that no more data will be put on this queue by the current
  1473. # process.
  1474. q.put(None)
  1475. # Note that we don't actually join the worker here, nor do we remove the
  1476. # worker's pid from C side struct because (1) joining may be slow, and
  1477. # (2) since we don't join, the worker may still raise error, and we
  1478. # prefer capturing those, rather than ignoring them, even though they
  1479. # are raised after the worker has finished its job.
  1480. # Joining is deferred to `_shutdown_workers`, which it is called when
  1481. # all workers finish their jobs (e.g., `IterableDataset` replicas) or
  1482. # when this iterator is garbage collected.
  1483. self._workers_status[worker_id] = False
  1484. if self._workers_done_event.is_set() != shutdown:
  1485. raise AssertionError(
  1486. "_workers_done_event state does not match shutdown flag"
  1487. )
def _shutdown_workers(self) -> None:
    """Stop the pin-memory thread and the worker processes of this iterator.

    Runs its teardown at most once: the work is guarded by ``self._shutdown``,
    which is set before anything else happens. It does nothing when
    ``_utils`` is ``None``, or when ``_utils.python_exit_status`` is ``True``
    or ``None``.

    Order of operations:
      1. if a pin-memory thread was created, set its done event, wake it
         with a ``(None, None)`` item, join it, and close
         ``self._worker_result_queue``;
      2. set ``self._workers_done_event``. Then call
         ``_mark_worker_as_unavailable(..., shutdown=True)`` for every worker
         that is still active, or for every worker when persistent workers
         are enabled;
      3. join each worker with a ``_utils.MP_STATUS_CHECK_INTERVAL`` timeout
         and close every index queue.

    The ``finally`` clause always removes this iterator's worker pids via
    ``_utils.signal_handling._remove_worker_pids`` (if they were registered),
    and terminates any worker that is still alive.
    """
    # Called when shutting down this `_MultiProcessingDataLoaderIter`.
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
    # the logic of this function.
    if (
        _utils is None
        # pyrefly: ignore [unnecessary-comparison]
        or _utils.python_exit_status is True
        # pyrefly: ignore [unnecessary-comparison]
        or _utils.python_exit_status is None
    ):
        # See (2) of the note. If Python is shutting down, do no-op.
        return
    # Normal exit when last reference is gone / iterator is depleted.
    # See (1) and the second half of the note.
    if not self._shutdown:
        # Set the flag first so a re-entrant or repeated call is a no-op.
        self._shutdown = True
        try:
            # Exit `pin_memory_thread` first because exiting workers may leave
            # corrupted data in `worker_result_queue` which `pin_memory_thread`
            # reads from.
            if hasattr(self, "_pin_memory_thread"):
                # Use hasattr in case error happens before we set the attribute.
                self._pin_memory_thread_done_event.set()
                # Send something to pin_memory_thread in case it is waiting
                # so that it can wake up and check `pin_memory_thread_done_event`
                self._worker_result_queue.put((None, None))
                self._pin_memory_thread.join()
                self._worker_result_queue.cancel_join_thread()
                self._worker_result_queue.close()

            # Exit workers now.
            self._workers_done_event.set()
            for worker_id in range(len(self._workers)):
                # Get number of workers from `len(self._workers)` instead of
                # `self._num_workers` in case we error before starting all
                # workers.
                # If we are using workers_status with persistent_workers
                # we have to shut it down because the worker is paused
                if self._persistent_workers or self._workers_status[worker_id]:
                    self._mark_worker_as_unavailable(worker_id, shutdown=True)
            for w in self._workers:
                # We should be able to join here, but in case anything went
                # wrong, we set a timeout and if the workers fail to join,
                # they are killed in the `finally` block.
                w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
            for q in self._index_queues:
                q.cancel_join_thread()
                q.close()
        finally:
            # Even though all this function does is putting into queues that
            # we have called `cancel_join_thread` on, weird things can
            # happen when a worker is killed by a signal, e.g., hanging in
            # `Event.set()`. So we need to guard this with SIGCHLD handler,
            # and remove pids from the C side data structure only at the
            # end.
            #
            # FIXME: Unfortunately, for Windows, we are missing a worker
            # error detection mechanism here in this function, as it
            # doesn't provide a SIGCHLD handler.
            if self._worker_pids_set:
                _utils.signal_handling._remove_worker_pids(id(self))
                self._worker_pids_set = False
            for w in self._workers:
                if w.is_alive():
                    # Existing mechanisms try to make the workers exit
                    # peacefully, but in case that we unfortunately reach
                    # here, which we shouldn't, (e.g., pytorch/pytorch#39570),
                    # we kill the worker.
                    w.terminate()
  1559. # staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter`
  1560. @staticmethod
  1561. def _clean_up_worker(w) -> None:
  1562. try:
  1563. w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  1564. finally:
  1565. if w.is_alive():
  1566. w.terminate()
  1567. def __del__(self) -> None:
  1568. self._shutdown_workers()