| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413 |
- """ Loader Factory, Fast Collate, CUDA Prefetcher
- Prefetcher and Fast Collate inspired by NVIDIA APEX example at
- https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
- Hacked together by / Copyright 2019, Ross Wightman
- """
- import logging
- import random
- from contextlib import suppress
- from functools import partial
- from itertools import repeat
- from typing import Callable, Optional, Tuple, Union
- import torch
- import torch.utils.data
- import numpy as np
- from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from .dataset import IterableImageDataset, ImageDataset
- from .distributed_sampler import OrderedDistributedSampler, RepeatAugSampler
- from .random_erasing import RandomErasing
- from .mixup import FastCollateMixup
- from .transforms_factory import create_transform
- _logger = logging.getLogger(__name__)
def fast_collate(batch):
    """Collate ``(image, target)`` samples into one uint8 image tensor and one int64 target tensor.

    The layout of the batch is detected from the image of the first sample:

    * ``tuple`` of images: every sample carries ``k`` images. The output is flattened to
      ``batch_size * k`` entries ordered by tuple position first, i.e. image ``j`` of sample
      ``i`` lands at index ``i + j * batch_size``, so ``torch.split(tensor, batch_size)``
      yields one chunk per tuple position. The sample target is repeated for each image.
    * ``np.ndarray``: one array per sample.
    * ``torch.Tensor``: one tensor per sample.

    Args:
        batch: Non-empty sequence of ``(image, target)`` tuples. Images are expected to be
            uint8 and identically shaped; targets must be convertible to int64.

    Returns:
        Tuple ``(tensor, targets)`` where ``tensor`` is a ``torch.uint8`` tensor with a leading
        batch dimension and ``targets`` is a ``torch.int64`` tensor.

    Raises:
        AssertionError: If the first sample is not a tuple, if tuple inputs differ in length,
            or if the image type is not supported.
    """
    assert isinstance(batch[0], tuple)
    batch_size = len(batch)
    first_input = batch[0][0]

    if isinstance(first_input, tuple):
        # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered
        # by position such that all tuples of position n end up in the nth chunk of
        # torch.split(tensor, batch_size).
        # The sample input is a tuple here, so the array type has to be probed on its first element
        # (probing the tuple itself can never match np.ndarray).
        is_np = isinstance(first_input[0], np.ndarray)
        inner_tuple_size = len(first_input)
        flattened_batch_size = batch_size * inner_tuple_size
        targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
        tensor = torch.zeros((flattened_batch_size, *first_input[0].shape), dtype=torch.uint8)
        for i, (inputs, target) in enumerate(batch):
            assert len(inputs) == inner_tuple_size  # all input tensor tuples must be same length
            for j, image in enumerate(inputs):
                flat_index = i + j * batch_size
                targets[flat_index] = target
                if is_np:
                    tensor[flat_index] += torch.from_numpy(image)
                else:
                    tensor[flat_index] += image
        return tensor, targets

    if isinstance(first_input, np.ndarray):
        targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *first_input.shape), dtype=torch.uint8)
        for i, sample in enumerate(batch):
            tensor[i] += torch.from_numpy(sample[0])
        return tensor, targets

    if isinstance(first_input, torch.Tensor):
        targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *first_input.shape), dtype=torch.uint8)
        for i, sample in enumerate(batch):
            tensor[i].copy_(sample[0])
        return tensor, targets

    # Raised explicitly (instead of `assert False`) so that the failure is not stripped under -O.
    raise AssertionError(f'Unsupported input type for fast_collate: {type(first_input)}')
def adapt_to_chs(x, n):
    """Expand or adapt normalization stat(s) ``x`` to ``n`` channels.

    * A non-sequence value is repeated ``n`` times and returned as a tuple.
    * A tuple/list that already has ``n`` entries is returned unchanged.
    * A tuple/list of any other length is collapsed to its mean, which is repeated ``n``
      times; a warning is logged in that case.
    """
    if isinstance(x, (tuple, list)):
        if len(x) == n:
            return x
        # Channel count mismatch: fall back to the average value for every channel.
        avg_value = np.mean(x).item()
        x = (avg_value,) * n
        _logger.warning(f'Pretrained mean/std different shape than model, using avg value {x}.')
        return x
    return (x,) * n
class PrefetchLoader:
    """Wrap a loader to move batches to a device, normalize them and optionally random-erase them.

    Each ``(input, target)`` batch from the wrapped loader is transferred to ``device``, cast to
    ``img_dtype`` and normalized in place with ``(x - mean * 255) / (std * 255)``. On CUDA / NPU
    devices the work for the next batch is issued on a separate stream before the previous batch
    is handed to the consumer.
    """

    def __init__(
            self,
            loader: torch.utils.data.DataLoader,
            mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
            std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
            channels: int = 3,
            device: torch.device = torch.device('cuda'),
            img_dtype: Optional[torch.dtype] = None,
            fp16: bool = False,
            re_prob: float = 0.,
            re_mode: str = 'const',
            re_count: int = 1,
            re_num_splits: int = 0,
    ):
        """
        Args:
            loader: Loader yielding ``(input, target)`` batches.
            mean: Per-channel normalization mean (adapted to ``channels`` via ``adapt_to_chs``).
            std: Per-channel normalization std (adapted to ``channels`` via ``adapt_to_chs``).
            channels: Number of image channels.
            device: Device batches are moved to.
            img_dtype: Dtype the input is cast to before normalization (float32 when None).
            fp16: Deprecated; when True forces ``img_dtype`` to ``torch.float16``.
            re_prob: Random erasing probability; erasing is disabled when not > 0.
            re_mode: Random erasing fill mode.
            re_count: Max number of random erasing regions.
            re_num_splits: Number of batch splits passed to ``RandomErasing``.
        """
        mean = adapt_to_chs(mean, channels)
        std = adapt_to_chs(std, channels)
        normalization_shape = (1, channels, 1, 1)

        self.loader = loader
        self.device = device
        if fp16:
            # fp16 arg is deprecated, but will override dtype arg if set for bwd compat
            img_dtype = torch.float16
        self.img_dtype = img_dtype or torch.float32
        # Stats are scaled by 255 -- presumably because batches arrive as 0-255 uint8 values
        # (as produced by fast_collate); confirm when using a custom collate_fn.
        self.mean = torch.tensor(
            [x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
        self.std = torch.tensor(
            [x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)
        if re_prob > 0.:
            self.random_erasing = RandomErasing(
                probability=re_prob,
                mode=re_mode,
                max_count=re_count,
                num_splits=re_num_splits,
                device=device,
            )
        else:
            self.random_erasing = None
        self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
        self.is_npu = device.type == 'npu' and torch.npu.is_available()

    def __iter__(self):
        """Yield ``(input, target)`` batches that are already on the device and normalized.

        The batch yielded is always one step behind the batch being prepared. Nothing is yielded
        for an empty wrapped loader.
        """
        if self.is_cuda:
            stream = torch.cuda.Stream(device=self.device)
            stream_context = partial(torch.cuda.stream, stream=stream)
        elif self.is_npu:
            stream = torch.npu.Stream(device=self.device)
            stream_context = partial(torch.npu.stream, stream=stream)
        else:
            stream = None
            # suppress() without exception types acts as a do-nothing context manager
            stream_context = suppress

        # Batch that has been prepared but not handed out yet; None until the first batch is ready.
        # Tracking it explicitly avoids an UnboundLocalError when the wrapped loader is empty.
        pending = None
        for next_input, next_target in self.loader:
            with stream_context():
                next_input = next_input.to(device=self.device, non_blocking=True)
                next_target = next_target.to(device=self.device, non_blocking=True)
                next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)
                if self.random_erasing is not None:
                    next_input = self.random_erasing(next_input)

            if pending is not None:
                yield pending

            if stream is not None:
                # Make the current stream wait for the side stream before the new batch is used
                if self.is_cuda:
                    torch.cuda.current_stream(device=self.device).wait_stream(stream)
                elif self.is_npu:
                    torch.npu.current_stream(device=self.device).wait_stream(stream)

            pending = (next_input, next_target)

        if pending is not None:
            yield pending

    def __len__(self):
        """Number of batches reported by the wrapped loader."""
        return len(self.loader)

    @property
    def sampler(self):
        """Sampler of the wrapped loader."""
        return self.loader.sampler

    @property
    def dataset(self):
        """Dataset of the wrapped loader."""
        return self.loader.dataset

    @property
    def mixup_enabled(self):
        """Mixup state of the collate function; False unless it is a ``FastCollateMixup``."""
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            return self.loader.collate_fn.mixup_enabled
        else:
            return False

    @mixup_enabled.setter
    def mixup_enabled(self, x):
        # Silently ignored when the collate function is not a FastCollateMixup
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            self.loader.collate_fn.mixup_enabled = x
def _worker_init(worker_id, worker_seeding='all'):
    """Seed the RNGs of a DataLoader worker process.

    Args:
        worker_id: Id of the worker; must equal the id reported by ``get_worker_info()``.
        worker_seeding: Either a callable that maps the worker info to a seed (used for
            ``random``, ``torch`` and ``numpy``), or one of ``'all'`` / ``'part'``. With
            ``'all'`` numpy is re-seeded from the worker seed, with ``'part'`` nothing is done.
    """
    info = torch.utils.data.get_worker_info()
    assert info.id == worker_id
    numpy_seed_modulus = 2 ** 32 - 1

    if callable(worker_seeding):
        custom_seed = worker_seeding(info)
        random.seed(custom_seed)
        torch.manual_seed(custom_seed)
        np.random.seed(custom_seed % numpy_seed_modulus)
        return

    assert worker_seeding in ('all', 'part')
    # random / torch seed already called in dataloader iter class w/ worker_info.seed
    # to reproduce some old results (same seed + hparam combo), partial seeding is required
    # (skip numpy re-seed)
    if worker_seeding == 'all':
        np.random.seed(info.seed % numpy_seed_modulus)
def create_loader(
        dataset: Union[ImageDataset, IterableImageDataset],
        input_size: Union[int, Tuple[int, int], Tuple[int, int, int]],
        batch_size: int,
        is_training: bool = False,
        no_aug: bool = False,
        re_prob: float = 0.,
        re_mode: str = 'const',
        re_count: int = 1,
        re_split: bool = False,
        train_crop_mode: Optional[str] = None,
        scale: Optional[Tuple[float, float]] = None,
        ratio: Optional[Tuple[float, float]] = None,
        hflip: float = 0.5,
        vflip: float = 0.,
        color_jitter: float = 0.4,
        color_jitter_prob: Optional[float] = None,
        grayscale_prob: float = 0.,
        gaussian_blur_prob: float = 0.,
        auto_augment: Optional[str] = None,
        num_aug_repeats: int = 0,
        num_aug_splits: int = 0,
        interpolation: str = 'bilinear',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        num_workers: int = 1,
        distributed: bool = False,
        crop_pct: Optional[float] = None,
        crop_mode: Optional[str] = None,
        crop_border_pixels: Optional[int] = None,
        collate_fn: Optional[Callable] = None,
        pin_memory: bool = False,
        fp16: bool = False,  # deprecated, use img_dtype
        img_dtype: torch.dtype = torch.float32,
        device: torch.device = torch.device('cuda'),
        use_prefetcher: bool = True,
        use_multi_epochs_loader: bool = False,
        persistent_workers: bool = True,
        worker_seeding: str = 'all',
        tf_preprocessing: bool = False,
):
    """Build a data loader (optionally wrapped in a PrefetchLoader) for an image dataset.

    Note that ``dataset.transform`` is overwritten with the transform built from the arguments.

    Args:
        dataset: The image dataset to load.
        input_size: Target input size (channels, height, width) tuple or size scalar.
        batch_size: Number of samples in a batch.
        is_training: Return training (random) transforms.
        no_aug: Disable augmentation for training (useful for debug).
        re_prob: Random erasing probability.
        re_mode: Random erasing fill mode.
        re_count: Number of random erasing regions.
        re_split: Control split of random erasing across batch size.
        train_crop_mode: Training crop mode, forwarded to create_transform.
        scale: Random resize scale range (crop area, < 1.0 => zoom in).
        ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
        hflip: Horizontal flip probability.
        vflip: Vertical flip probability.
        color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue).
            Scalar is applied as (scalar,) * 3 (no hue).
        color_jitter_prob: Apply color jitter with this probability if not None (for SimCLR-like aug).
        grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug).
        gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug).
        auto_augment: Auto augment configuration string (see auto_augment.py).
        num_aug_repeats: Enable special sampler to repeat same augmentation across distributed GPUs.
        num_aug_splits: Enable mode where augmentations can be split across the batch.
        interpolation: Image interpolation mode.
        mean: Image normalization mean.
        std: Image normalization standard deviation.
        num_workers: Num worker processes per DataLoader.
        distributed: Enable dataloading for distributed training.
        crop_pct: Inference crop percentage (output size / resize size).
        crop_mode: Inference crop mode. One of ['squash', 'border', 'center']. Defaults to 'center' when None.
        crop_border_pixels: Inference crop border of specified # pixels around edge of original image.
        collate_fn: Override default collate_fn.
        pin_memory: Pin memory for device transfer.
        fp16: Deprecated argument for half-precision input dtype. Use img_dtype.
        img_dtype: Data type for input image.
        device: Device to transfer inputs and targets to.
        use_prefetcher: Use efficient pre-fetcher to load samples onto device.
        use_multi_epochs_loader: Use MultiEpochsDataLoader instead of the torch DataLoader.
        persistent_workers: Enable persistent worker processes (ignored when num_workers is 0).
        worker_seeding: Control worker random seeding at init.
        tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports.

    Returns:
        DataLoader (or MultiEpochsDataLoader), wrapped in a PrefetchLoader when use_prefetcher is set.

    Raises:
        AssertionError: If num_aug_repeats is set for non-distributed or IterableDataset use.
    """
    re_num_splits = 0
    if re_split:
        # apply RE to second half of batch if no aug split otherwise line up with aug split
        re_num_splits = num_aug_splits or 2
    dataset.transform = create_transform(
        input_size,
        is_training=is_training,
        no_aug=no_aug,
        train_crop_mode=train_crop_mode,
        scale=scale,
        ratio=ratio,
        hflip=hflip,
        vflip=vflip,
        color_jitter=color_jitter,
        color_jitter_prob=color_jitter_prob,
        grayscale_prob=grayscale_prob,
        gaussian_blur_prob=gaussian_blur_prob,
        auto_augment=auto_augment,
        interpolation=interpolation,
        mean=mean,
        std=std,
        crop_pct=crop_pct,
        crop_mode=crop_mode,
        crop_border_pixels=crop_border_pixels,
        re_prob=re_prob,
        re_mode=re_mode,
        re_count=re_count,
        re_num_splits=re_num_splits,
        tf_preprocessing=tf_preprocessing,
        use_prefetcher=use_prefetcher,
        separate=num_aug_splits > 0,
    )

    if isinstance(dataset, IterableImageDataset):
        # give Iterable datasets early knowledge of num_workers so that sample estimates
        # are correct before worker processes are launched
        dataset.set_loader_cfg(num_workers=num_workers)

    is_iterable = isinstance(dataset, torch.utils.data.IterableDataset)
    sampler = None
    if distributed and not is_iterable:
        if is_training:
            if num_aug_repeats:
                sampler = RepeatAugSampler(dataset, num_repeats=num_aug_repeats)
            else:
                sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            # This will add extra duplicate entries to result in equal num
            # of samples per-process, will slightly alter validation results
            sampler = OrderedDistributedSampler(dataset)
    else:
        assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use"

    if collate_fn is None:
        collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate

    loader_class = torch.utils.data.DataLoader
    if use_multi_epochs_loader:
        loader_class = MultiEpochsDataLoader

    loader_args = dict(
        batch_size=batch_size,
        shuffle=not is_iterable and sampler is None and is_training,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=is_training,
        worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
        # DataLoader rejects persistent_workers=True when there are no worker processes,
        # so only request it when workers will actually be spawned.
        persistent_workers=persistent_workers and num_workers > 0,
    )
    try:
        loader = loader_class(dataset, **loader_args)
    except TypeError:
        loader_args.pop('persistent_workers')  # only in Pytorch 1.7+
        loader = loader_class(dataset, **loader_args)

    if use_prefetcher:
        prefetch_re_prob = re_prob if is_training and not no_aug else 0.
        # Only a 3-element input_size carries the channel count in position 0. For a scalar or
        # (height, width) input_size fall back to the number of normalization stats (or 3).
        try:
            has_channel_dim = len(input_size) == 3
        except TypeError:
            has_channel_dim = False
        if has_channel_dim:
            channels = input_size[0]
        else:
            channels = len(mean) if isinstance(mean, (tuple, list)) else 3
        loader = PrefetchLoader(
            loader,
            mean=mean,
            std=std,
            channels=channels,
            device=device,
            fp16=fp16,  # deprecated, use img_dtype
            img_dtype=img_dtype,
            re_prob=prefetch_re_prob,
            re_mode=re_mode,
            re_count=re_count,
            re_num_splits=re_num_splits
        )

    return loader
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader that creates its underlying iterator once and reuses it for every epoch.

    The (batch) sampler is wrapped in a ``_RepeatSampler`` so the single iterator never runs
    out; each call to ``__iter__`` then draws exactly ``len(self)`` batches from it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The name-mangled initialized flag is lowered while the sampler is swapped --
        # presumably because DataLoader guards these attributes after init (see PyTorch source).
        self._DataLoader__initialized = False
        if self.batch_sampler is None:
            self.sampler = _RepeatSampler(self.sampler)
        else:
            self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        self.iterator = super().__iter__()

    def __len__(self):
        """Length of the original sampler wrapped by the ``_RepeatSampler`` (one epoch)."""
        # Both attributes hold a _RepeatSampler (which defines no __len__), so the length
        # always has to be read from the sampler it wraps.
        repeating = self.sampler if self.batch_sampler is None else self.batch_sampler
        return len(repeating.sampler)

    def __iter__(self):
        """Yield one epoch worth of batches from the persistent iterator."""
        for _ in range(len(self)):
            yield next(self.iterator)
class _RepeatSampler(object):
    """ Endlessly cycling wrapper around a sampler.

    Iterating restarts the wrapped sampler every time it is exhausted, so iteration
    never terminates on its own.

    Args:
        sampler (Sampler): the sampler (or batch sampler) to cycle through
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            # A fresh pass over the wrapped sampler on every cycle
            for item in self.sampler:
                yield item
|