| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207 |
- """ Quick n Simple Image Folder, Tarfile based DataSet
- Hacked together by / Copyright 2019, Ross Wightman
- """
- import io
- import logging
- from typing import Optional
- import torch
- import torch.utils.data as data
- from PIL import Image
- from .readers import create_reader
_logger = logging.getLogger(__name__)

# Maximum number of attempts ImageDataset.__getitem__ makes (moving on to the
# next index after each failure) before giving up with a RuntimeError.
_ERROR_RETRY = 20


class ImageDataset(data.Dataset):
    """Map-style image dataset backed by a reader.

    ``reader`` may be a ready reader object, or ``None`` / a string, in which
    case one is built with ``create_reader``. Indexing the reader must return
    ``(image_source, target, *features)``; ``image_source`` must support
    ``.read()`` when ``load_bytes`` is set, and otherwise be something
    ``PIL.Image.open`` accepts.
    """

    def __init__(
            self,
            root,
            reader=None,
            split='train',
            class_map=None,
            load_bytes=False,
            input_img_mode='RGB',
            transform=None,
            target_transform=None,
            additional_features=None,
            **kwargs,
    ):
        """
        Args:
            root: Dataset root, forwarded to ``create_reader``.
            reader: Reader instance, or a reader name / ``None`` to create one.
            split: Split name forwarded to ``create_reader``.
            class_map: Class mapping forwarded to ``create_reader``.
            load_bytes: If True, return the raw result of ``.read()`` instead
                of a decoded PIL image.
            input_img_mode: PIL mode decoded images are converted to; a falsy
                value skips conversion. Ignored when ``load_bytes`` is True.
            transform: Optional callable applied to the image (or bytes).
            target_transform: Optional callable applied to non-``None`` targets.
            additional_features: When not ``None``, extra per-sample values
                returned by the reader are appended to each returned tuple.
                Also forwarded to ``create_reader``.
            **kwargs: Extra arguments forwarded to ``create_reader``.
        """
        if reader is None or isinstance(reader, str):
            reader = create_reader(
                reader or '',
                root=root,
                split=split,
                class_map=class_map,
                additional_features=additional_features,
                **kwargs,
            )
        self.reader = reader
        self.load_bytes = load_bytes
        self.input_img_mode = input_img_mode
        self.transform = transform
        self.target_transform = target_transform
        self.additional_features = additional_features
        self._max_retries = _ERROR_RETRY

    def _load_sample(self, index):
        """Fetch and decode one sample; returns ``(img, target, features)``."""
        img, target, *features = self.reader[index]
        if self.load_bytes:
            img = img.read()
        else:
            img = Image.open(img)
            if self.input_img_mode:
                # Image.open only parses the header; convert() forces the full
                # decode. Doing it here keeps corrupt/truncated files inside
                # the caller's retry loop instead of failing after it.
                img = img.convert(self.input_img_mode)
        return img, target, features

    def __getitem__(self, index):
        """Return ``(img, target)``, or ``(img, target, *features)`` when
        ``additional_features`` is set.

        A sample that raises ``OSError`` while being read or decoded is skipped
        and the next index (wrapping around) is tried, for at most
        ``_max_retries`` attempts. A ``None`` target is returned as ``-1``
        (``target_transform`` is not applied to it).

        Raises:
            RuntimeError: if every attempt failed; the last ``OSError`` is
                attached as ``__cause__``.
        """
        last_error = None
        for _ in range(self._max_retries):
            try:
                img, target, features = self._load_sample(index)
                break
            except OSError as e:  # IOError is an alias of OSError on Python 3
                last_error = e
                _logger.warning('Skipped sample (index %s). %s', index, e)
                index = (index + 1) % len(self.reader)
        else:
            raise RuntimeError(
                f"Failed to load {self._max_retries} consecutive samples"
            ) from last_error

        if self.transform is not None:
            img = self.transform(img)

        if target is None:
            target = -1
        elif self.target_transform is not None:
            target = self.target_transform(target)

        if self.additional_features is None:
            return img, target
        return img, target, *features

    def __len__(self):
        return len(self.reader)

    def filename(self, index, basename=False, absolute=False):
        """Delegate per-index filename lookup to the reader."""
        return self.reader.filename(index, basename, absolute)

    def filenames(self, basename=False, absolute=False):
        """Delegate the full filename listing to the reader."""
        return self.reader.filenames(basename, absolute)
class IterableImageDataset(data.IterableDataset):
    """Iterable-style dataset that streams ``(img, target)`` pairs from a reader.

    ``reader`` is either a ready reader object (anything iterable over
    ``(img, target)`` pairs) or a reader name string, in which case it is
    built with ``create_reader`` using the remaining arguments.
    """

    def __init__(
            self,
            root,
            reader=None,
            split='train',
            class_map=None,
            is_training=False,
            batch_size=1,
            num_samples=None,
            seed=42,
            repeats=0,
            download=False,
            input_img_mode='RGB',
            input_key=None,
            target_key=None,
            transform=None,
            target_transform=None,
            max_steps=None,
            **kwargs,
    ):
        """
        Args:
            root: Dataset root, forwarded to ``create_reader``.
            reader: Reader instance or reader name string (required).
            transform: Optional callable applied to each image.
            target_transform: Optional callable applied to each target.
            Remaining arguments are forwarded to ``create_reader`` only when
            ``reader`` is a string.

        Raises:
            ValueError: if ``reader`` is ``None``.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if reader is None:
            raise ValueError('IterableImageDataset requires a reader instance or reader name.')
        if isinstance(reader, str):
            reader = create_reader(
                reader,
                root=root,
                split=split,
                class_map=class_map,
                is_training=is_training,
                batch_size=batch_size,
                num_samples=num_samples,
                seed=seed,
                repeats=repeats,
                download=download,
                input_img_mode=input_img_mode,
                input_key=input_key,
                target_key=target_key,
                max_steps=max_steps,
                **kwargs,
            )
        self.reader = reader
        self.transform = transform
        self.target_transform = target_transform

    def __iter__(self):
        """Yield ``(img, target)`` with the optional transforms applied."""
        for img, target in self.reader:
            if self.transform is not None:
                img = self.transform(img)
            if self.target_transform is not None:
                target = self.target_transform(target)
            yield img, target

    def __len__(self):
        """Reader length when it defines ``__len__``, else ``0``."""
        return len(self.reader) if hasattr(self.reader, '__len__') else 0

    def set_epoch(self, count):
        # TFDS and WDS need external epoch count for deterministic cross process shuffle
        if hasattr(self.reader, 'set_epoch'):
            self.reader.set_epoch(count)

    def set_loader_cfg(
            self,
            num_workers: Optional[int] = None,
    ):
        # TFDS and WDS readers need # workers for correct # samples estimate before loader processes created
        if hasattr(self.reader, 'set_loader_cfg'):
            self.reader.set_loader_cfg(num_workers=num_workers)

    def filename(self, index, basename=False, absolute=False):
        """Not supported for streaming readers.

        Raises:
            NotImplementedError: always. The previous `assert False` silently
                returned ``None`` under ``python -O``.
        """
        raise NotImplementedError('Filename lookup by index not supported, use filenames().')

    def filenames(self, basename=False, absolute=False):
        """Delegate the full filename listing to the reader."""
        return self.reader.filenames(basename, absolute)
class AugMixDataset(torch.utils.data.Dataset):
    """Dataset wrapper to perform AugMix or other clean/augmentation mixes.

    The wrapped dataset's ``transform`` is split into a 3-element
    ``(base, augmentation, normalize)`` list/tuple. ``base`` stays on the
    wrapped dataset and is shared by all splits. Each item becomes
    ``((clean, aug_1, ..., aug_{num_splits-1}), target, *extra)``, where
    ``extra`` holds any trailing values the wrapped dataset returns beyond
    ``(x, target)``, passed through unchanged.
    """

    def __init__(self, dataset, num_splits=2):
        self.augmentation = None
        self.normalize = None
        self.dataset = dataset
        if self.dataset.transform is not None:
            self._set_transforms(self.dataset.transform)
        self.num_splits = num_splits

    def _set_transforms(self, x):
        """Split a ``(base, augmentation, normalize)`` triple across the wrapper."""
        assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
        self.dataset.transform = x[0]
        self.augmentation = x[1]
        self.normalize = x[2]

    @property
    def transform(self):
        return self.dataset.transform

    @transform.setter
    def transform(self, x):
        self._set_transforms(x)

    def _normalize(self, x):
        return x if self.normalize is None else self.normalize(x)

    def __getitem__(self, i):
        # All splits share the same dataset base transform; extra trailing
        # values (e.g. ImageDataset additional features) are passed through.
        x, y, *extra = self.dataset[i]
        # The first split only normalizes (this is the 'clean' split).
        x_list = [self._normalize(x)]
        if self.num_splits > 1 and self.augmentation is None:
            raise TypeError(
                'AugMixDataset has no augmentation transform; set `transform` to a '
                '(base, augmentation, normalize) tuple.')
        # Run the full augmentation on the remaining splits.
        for _ in range(self.num_splits - 1):
            x_list.append(self._normalize(self.augmentation(x)))
        return (tuple(x_list), y, *extra)

    def __len__(self):
        return len(self.dataset)
|