loader.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. """ Loader Factory, Fast Collate, CUDA Prefetcher
  2. Prefetcher and Fast Collate inspired by NVIDIA APEX example at
  3. https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
  4. Hacked together by / Copyright 2019, Ross Wightman
  5. """
  6. import logging
  7. import random
  8. from contextlib import suppress
  9. from functools import partial
  10. from itertools import repeat
  11. from typing import Callable, Optional, Tuple, Union
  12. import torch
  13. import torch.utils.data
  14. import numpy as np
  15. from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  16. from .dataset import IterableImageDataset, ImageDataset
  17. from .distributed_sampler import OrderedDistributedSampler, RepeatAugSampler
  18. from .random_erasing import RandomErasing
  19. from .mixup import FastCollateMixup
  20. from .transforms_factory import create_transform
  21. _logger = logging.getLogger(__name__)
  22. def fast_collate(batch):
  23. """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
  24. assert isinstance(batch[0], tuple)
  25. batch_size = len(batch)
  26. if isinstance(batch[0][0], tuple):
  27. # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
  28. # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
  29. is_np = isinstance(batch[0][0], np.ndarray)
  30. inner_tuple_size = len(batch[0][0])
  31. flattened_batch_size = batch_size * inner_tuple_size
  32. targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
  33. tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
  34. for i in range(batch_size):
  35. assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length
  36. for j in range(inner_tuple_size):
  37. targets[i + j * batch_size] = batch[i][1]
  38. if is_np:
  39. tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
  40. else:
  41. tensor[i + j * batch_size] += batch[i][0][j]
  42. return tensor, targets
  43. elif isinstance(batch[0][0], np.ndarray):
  44. targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
  45. assert len(targets) == batch_size
  46. tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
  47. for i in range(batch_size):
  48. tensor[i] += torch.from_numpy(batch[i][0])
  49. return tensor, targets
  50. elif isinstance(batch[0][0], torch.Tensor):
  51. targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
  52. assert len(targets) == batch_size
  53. tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
  54. for i in range(batch_size):
  55. tensor[i].copy_(batch[i][0])
  56. return tensor, targets
  57. else:
  58. assert False
  59. def adapt_to_chs(x, n):
  60. if not isinstance(x, (tuple, list)):
  61. x = tuple(repeat(x, n))
  62. elif len(x) != n:
  63. x_mean = np.mean(x).item()
  64. x = (x_mean,) * n
  65. _logger.warning(f'Pretrained mean/std different shape than model, using avg value {x}.')
  66. else:
  67. assert len(x) == n, 'normalization stats must match image channels'
  68. return x
  69. class PrefetchLoader:
  70. def __init__(
  71. self,
  72. loader: torch.utils.data.DataLoader,
  73. mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
  74. std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
  75. channels: int = 3,
  76. device: torch.device = torch.device('cuda'),
  77. img_dtype: Optional[torch.dtype] = None,
  78. fp16: bool = False,
  79. re_prob: float = 0.,
  80. re_mode: str = 'const',
  81. re_count: int = 1,
  82. re_num_splits: int = 0,
  83. ):
  84. mean = adapt_to_chs(mean, channels)
  85. std = adapt_to_chs(std, channels)
  86. normalization_shape = (1, channels, 1, 1)
  87. self.loader = loader
  88. self.device = device
  89. if fp16:
  90. # fp16 arg is deprecated, but will override dtype arg if set for bwd compat
  91. img_dtype = torch.float16
  92. self.img_dtype = img_dtype or torch.float32
  93. self.mean = torch.tensor(
  94. [x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
  95. self.std = torch.tensor(
  96. [x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)
  97. if re_prob > 0.:
  98. self.random_erasing = RandomErasing(
  99. probability=re_prob,
  100. mode=re_mode,
  101. max_count=re_count,
  102. num_splits=re_num_splits,
  103. device=device,
  104. )
  105. else:
  106. self.random_erasing = None
  107. self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
  108. self.is_npu = device.type == 'npu' and torch.npu.is_available()
  109. def __iter__(self):
  110. first = True
  111. if self.is_cuda:
  112. stream = torch.cuda.Stream(device=self.device)
  113. stream_context = partial(torch.cuda.stream, stream=stream)
  114. elif self.is_npu:
  115. stream = torch.npu.Stream(device=self.device)
  116. stream_context = partial(torch.npu.stream, stream=stream)
  117. else:
  118. stream = None
  119. stream_context = suppress
  120. for next_input, next_target in self.loader:
  121. with stream_context():
  122. next_input = next_input.to(device=self.device, non_blocking=True)
  123. next_target = next_target.to(device=self.device, non_blocking=True)
  124. next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)
  125. if self.random_erasing is not None:
  126. next_input = self.random_erasing(next_input)
  127. if not first:
  128. yield input, target
  129. else:
  130. first = False
  131. if stream is not None:
  132. if self.is_cuda:
  133. torch.cuda.current_stream(device=self.device).wait_stream(stream)
  134. elif self.is_npu:
  135. torch.npu.current_stream(device=self.device).wait_stream(stream)
  136. input = next_input
  137. target = next_target
  138. yield input, target
  139. def __len__(self):
  140. return len(self.loader)
  141. @property
  142. def sampler(self):
  143. return self.loader.sampler
  144. @property
  145. def dataset(self):
  146. return self.loader.dataset
  147. @property
  148. def mixup_enabled(self):
  149. if isinstance(self.loader.collate_fn, FastCollateMixup):
  150. return self.loader.collate_fn.mixup_enabled
  151. else:
  152. return False
  153. @mixup_enabled.setter
  154. def mixup_enabled(self, x):
  155. if isinstance(self.loader.collate_fn, FastCollateMixup):
  156. self.loader.collate_fn.mixup_enabled = x
  157. def _worker_init(worker_id, worker_seeding='all'):
  158. worker_info = torch.utils.data.get_worker_info()
  159. assert worker_info.id == worker_id
  160. if isinstance(worker_seeding, Callable):
  161. seed = worker_seeding(worker_info)
  162. random.seed(seed)
  163. torch.manual_seed(seed)
  164. np.random.seed(seed % (2 ** 32 - 1))
  165. else:
  166. assert worker_seeding in ('all', 'part')
  167. # random / torch seed already called in dataloader iter class w/ worker_info.seed
  168. # to reproduce some old results (same seed + hparam combo), partial seeding is required (skip numpy re-seed)
  169. if worker_seeding == 'all':
  170. np.random.seed(worker_info.seed % (2 ** 32 - 1))
  171. def create_loader(
  172. dataset: Union[ImageDataset, IterableImageDataset],
  173. input_size: Union[int, Tuple[int, int], Tuple[int, int, int]],
  174. batch_size: int,
  175. is_training: bool = False,
  176. no_aug: bool = False,
  177. re_prob: float = 0.,
  178. re_mode: str = 'const',
  179. re_count: int = 1,
  180. re_split: bool = False,
  181. train_crop_mode: Optional[str] = None,
  182. scale: Optional[Tuple[float, float]] = None,
  183. ratio: Optional[Tuple[float, float]] = None,
  184. hflip: float = 0.5,
  185. vflip: float = 0.,
  186. color_jitter: float = 0.4,
  187. color_jitter_prob: Optional[float] = None,
  188. grayscale_prob: float = 0.,
  189. gaussian_blur_prob: float = 0.,
  190. auto_augment: Optional[str] = None,
  191. num_aug_repeats: int = 0,
  192. num_aug_splits: int = 0,
  193. interpolation: str = 'bilinear',
  194. mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
  195. std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
  196. num_workers: int = 1,
  197. distributed: bool = False,
  198. crop_pct: Optional[float] = None,
  199. crop_mode: Optional[str] = None,
  200. crop_border_pixels: Optional[int] = None,
  201. collate_fn: Optional[Callable] = None,
  202. pin_memory: bool = False,
  203. fp16: bool = False, # deprecated, use img_dtype
  204. img_dtype: torch.dtype = torch.float32,
  205. device: torch.device = torch.device('cuda'),
  206. use_prefetcher: bool = True,
  207. use_multi_epochs_loader: bool = False,
  208. persistent_workers: bool = True,
  209. worker_seeding: str = 'all',
  210. tf_preprocessing: bool = False,
  211. ):
  212. """
  213. Args:
  214. dataset: The image dataset to load.
  215. input_size: Target input size (channels, height, width) tuple or size scalar.
  216. batch_size: Number of samples in a batch.
  217. is_training: Return training (random) transforms.
  218. no_aug: Disable augmentation for training (useful for debug).
  219. re_prob: Random erasing probability.
  220. re_mode: Random erasing fill mode.
  221. re_count: Number of random erasing regions.
  222. re_split: Control split of random erasing across batch size.
  223. scale: Random resize scale range (crop area, < 1.0 => zoom in).
  224. ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
  225. hflip: Horizontal flip probability.
  226. vflip: Vertical flip probability.
  227. color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue).
  228. Scalar is applied as (scalar,) * 3 (no hue).
  229. color_jitter_prob: Apply color jitter with this probability if not None (for SimlCLR-like aug
  230. grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug).
  231. gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug).
  232. auto_augment: Auto augment configuration string (see auto_augment.py).
  233. num_aug_repeats: Enable special sampler to repeat same augmentation across distributed GPUs.
  234. num_aug_splits: Enable mode where augmentations can be split across the batch.
  235. interpolation: Image interpolation mode.
  236. mean: Image normalization mean.
  237. std: Image normalization standard deviation.
  238. num_workers: Num worker processes per DataLoader.
  239. distributed: Enable dataloading for distributed training.
  240. crop_pct: Inference crop percentage (output size / resize size).
  241. crop_mode: Inference crop mode. One of ['squash', 'border', 'center']. Defaults to 'center' when None.
  242. crop_border_pixels: Inference crop border of specified # pixels around edge of original image.
  243. collate_fn: Override default collate_fn.
  244. pin_memory: Pin memory for device transfer.
  245. fp16: Deprecated argument for half-precision input dtype. Use img_dtype.
  246. img_dtype: Data type for input image.
  247. device: Device to transfer inputs and targets to.
  248. use_prefetcher: Use efficient pre-fetcher to load samples onto device.
  249. use_multi_epochs_loader:
  250. persistent_workers: Enable persistent worker processes.
  251. worker_seeding: Control worker random seeding at init.
  252. tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports.
  253. Returns:
  254. DataLoader
  255. """
  256. re_num_splits = 0
  257. if re_split:
  258. # apply RE to second half of batch if no aug split otherwise line up with aug split
  259. re_num_splits = num_aug_splits or 2
  260. dataset.transform = create_transform(
  261. input_size,
  262. is_training=is_training,
  263. no_aug=no_aug,
  264. train_crop_mode=train_crop_mode,
  265. scale=scale,
  266. ratio=ratio,
  267. hflip=hflip,
  268. vflip=vflip,
  269. color_jitter=color_jitter,
  270. color_jitter_prob=color_jitter_prob,
  271. grayscale_prob=grayscale_prob,
  272. gaussian_blur_prob=gaussian_blur_prob,
  273. auto_augment=auto_augment,
  274. interpolation=interpolation,
  275. mean=mean,
  276. std=std,
  277. crop_pct=crop_pct,
  278. crop_mode=crop_mode,
  279. crop_border_pixels=crop_border_pixels,
  280. re_prob=re_prob,
  281. re_mode=re_mode,
  282. re_count=re_count,
  283. re_num_splits=re_num_splits,
  284. tf_preprocessing=tf_preprocessing,
  285. use_prefetcher=use_prefetcher,
  286. separate=num_aug_splits > 0,
  287. )
  288. if isinstance(dataset, IterableImageDataset):
  289. # give Iterable datasets early knowledge of num_workers so that sample estimates
  290. # are correct before worker processes are launched
  291. dataset.set_loader_cfg(num_workers=num_workers)
  292. sampler = None
  293. if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
  294. if is_training:
  295. if num_aug_repeats:
  296. sampler = RepeatAugSampler(dataset, num_repeats=num_aug_repeats)
  297. else:
  298. sampler = torch.utils.data.distributed.DistributedSampler(dataset)
  299. else:
  300. # This will add extra duplicate entries to result in equal num
  301. # of samples per-process, will slightly alter validation results
  302. sampler = OrderedDistributedSampler(dataset)
  303. else:
  304. assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use"
  305. if collate_fn is None:
  306. collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
  307. loader_class = torch.utils.data.DataLoader
  308. if use_multi_epochs_loader:
  309. loader_class = MultiEpochsDataLoader
  310. loader_args = dict(
  311. batch_size=batch_size,
  312. shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,
  313. num_workers=num_workers,
  314. sampler=sampler,
  315. collate_fn=collate_fn,
  316. pin_memory=pin_memory,
  317. drop_last=is_training,
  318. worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
  319. persistent_workers=persistent_workers
  320. )
  321. try:
  322. loader = loader_class(dataset, **loader_args)
  323. except TypeError as e:
  324. loader_args.pop('persistent_workers') # only in Pytorch 1.7+
  325. loader = loader_class(dataset, **loader_args)
  326. if use_prefetcher:
  327. prefetch_re_prob = re_prob if is_training and not no_aug else 0.
  328. loader = PrefetchLoader(
  329. loader,
  330. mean=mean,
  331. std=std,
  332. channels=input_size[0],
  333. device=device,
  334. fp16=fp16, # deprecated, use img_dtype
  335. img_dtype=img_dtype,
  336. re_prob=prefetch_re_prob,
  337. re_mode=re_mode,
  338. re_count=re_count,
  339. re_num_splits=re_num_splits
  340. )
  341. return loader
  342. class MultiEpochsDataLoader(torch.utils.data.DataLoader):
  343. def __init__(self, *args, **kwargs):
  344. super().__init__(*args, **kwargs)
  345. self._DataLoader__initialized = False
  346. if self.batch_sampler is None:
  347. self.sampler = _RepeatSampler(self.sampler)
  348. else:
  349. self.batch_sampler = _RepeatSampler(self.batch_sampler)
  350. self._DataLoader__initialized = True
  351. self.iterator = super().__iter__()
  352. def __len__(self):
  353. return len(self.sampler) if self.batch_sampler is None else len(self.batch_sampler.sampler)
  354. def __iter__(self):
  355. for i in range(len(self)):
  356. yield next(self.iterator)
  357. class _RepeatSampler(object):
  358. """ Sampler that repeats forever.
  359. Args:
  360. sampler (Sampler)
  361. """
  362. def __init__(self, sampler):
  363. self.sampler = sampler
  364. def __iter__(self):
  365. while True:
  366. yield from iter(self.sampler)