dataset.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. """ Quick n Simple Image Folder, Tarfile based DataSet
  2. Hacked together by / Copyright 2019, Ross Wightman
  3. """
  4. import io
  5. import logging
  6. from typing import Optional
  7. import torch
  8. import torch.utils.data as data
  9. from PIL import Image
  10. from .readers import create_reader
  11. _logger = logging.getLogger(__name__)
  12. _ERROR_RETRY = 20
  13. class ImageDataset(data.Dataset):
  14. def __init__(
  15. self,
  16. root,
  17. reader=None,
  18. split='train',
  19. class_map=None,
  20. load_bytes=False,
  21. input_img_mode='RGB',
  22. transform=None,
  23. target_transform=None,
  24. additional_features=None,
  25. **kwargs,
  26. ):
  27. if reader is None or isinstance(reader, str):
  28. reader = create_reader(
  29. reader or '',
  30. root=root,
  31. split=split,
  32. class_map=class_map,
  33. additional_features=additional_features,
  34. **kwargs,
  35. )
  36. self.reader = reader
  37. self.load_bytes = load_bytes
  38. self.input_img_mode = input_img_mode
  39. self.transform = transform
  40. self.target_transform = target_transform
  41. self.additional_features = additional_features
  42. self._max_retries = _ERROR_RETRY
  43. def __getitem__(self, index):
  44. for attempt in range(self._max_retries):
  45. try:
  46. img, target, *features = self.reader[index]
  47. img = img.read() if self.load_bytes else Image.open(img)
  48. break
  49. except (IOError, OSError) as e: # be specific
  50. _logger.warning(f'Skipped sample (index {index}). {e}')
  51. index = (index + 1) % len(self.reader)
  52. else:
  53. raise RuntimeError(f"Failed to load {self._max_retries} consecutive samples")
  54. if self.input_img_mode and not self.load_bytes:
  55. img = img.convert(self.input_img_mode)
  56. if self.transform is not None:
  57. img = self.transform(img)
  58. if target is None:
  59. target = -1
  60. elif self.target_transform is not None:
  61. target = self.target_transform(target)
  62. if self.additional_features is None:
  63. return img, target
  64. else:
  65. return img, target, *features
  66. def __len__(self):
  67. return len(self.reader)
  68. def filename(self, index, basename=False, absolute=False):
  69. return self.reader.filename(index, basename, absolute)
  70. def filenames(self, basename=False, absolute=False):
  71. return self.reader.filenames(basename, absolute)
  72. class IterableImageDataset(data.IterableDataset):
  73. def __init__(
  74. self,
  75. root,
  76. reader=None,
  77. split='train',
  78. class_map=None,
  79. is_training=False,
  80. batch_size=1,
  81. num_samples=None,
  82. seed=42,
  83. repeats=0,
  84. download=False,
  85. input_img_mode='RGB',
  86. input_key=None,
  87. target_key=None,
  88. transform=None,
  89. target_transform=None,
  90. max_steps=None,
  91. **kwargs,
  92. ):
  93. assert reader is not None
  94. if isinstance(reader, str):
  95. self.reader = create_reader(
  96. reader,
  97. root=root,
  98. split=split,
  99. class_map=class_map,
  100. is_training=is_training,
  101. batch_size=batch_size,
  102. num_samples=num_samples,
  103. seed=seed,
  104. repeats=repeats,
  105. download=download,
  106. input_img_mode=input_img_mode,
  107. input_key=input_key,
  108. target_key=target_key,
  109. max_steps=max_steps,
  110. **kwargs,
  111. )
  112. else:
  113. self.reader = reader
  114. self.transform = transform
  115. self.target_transform = target_transform
  116. def __iter__(self):
  117. for img, target in self.reader:
  118. if self.transform is not None:
  119. img = self.transform(img)
  120. if self.target_transform is not None:
  121. target = self.target_transform(target)
  122. yield img, target
  123. def __len__(self):
  124. if hasattr(self.reader, '__len__'):
  125. return len(self.reader)
  126. else:
  127. return 0
  128. def set_epoch(self, count):
  129. # TFDS and WDS need external epoch count for deterministic cross process shuffle
  130. if hasattr(self.reader, 'set_epoch'):
  131. self.reader.set_epoch(count)
  132. def set_loader_cfg(
  133. self,
  134. num_workers: Optional[int] = None,
  135. ):
  136. # TFDS and WDS readers need # workers for correct # samples estimate before loader processes created
  137. if hasattr(self.reader, 'set_loader_cfg'):
  138. self.reader.set_loader_cfg(num_workers=num_workers)
  139. def filename(self, index, basename=False, absolute=False):
  140. assert False, 'Filename lookup by index not supported, use filenames().'
  141. def filenames(self, basename=False, absolute=False):
  142. return self.reader.filenames(basename, absolute)
  143. class AugMixDataset(torch.utils.data.Dataset):
  144. """Dataset wrapper to perform AugMix or other clean/augmentation mixes"""
  145. def __init__(self, dataset, num_splits=2):
  146. self.augmentation = None
  147. self.normalize = None
  148. self.dataset = dataset
  149. if self.dataset.transform is not None:
  150. self._set_transforms(self.dataset.transform)
  151. self.num_splits = num_splits
  152. def _set_transforms(self, x):
  153. assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
  154. self.dataset.transform = x[0]
  155. self.augmentation = x[1]
  156. self.normalize = x[2]
  157. @property
  158. def transform(self):
  159. return self.dataset.transform
  160. @transform.setter
  161. def transform(self, x):
  162. self._set_transforms(x)
  163. def _normalize(self, x):
  164. return x if self.normalize is None else self.normalize(x)
  165. def __getitem__(self, i):
  166. x, y = self.dataset[i] # all splits share the same dataset base transform
  167. x_list = [self._normalize(x)] # first split only normalizes (this is the 'clean' split)
  168. # run the full augmentation on the remaining splits
  169. for _ in range(self.num_splits - 1):
  170. x_list.append(self._normalize(self.augmentation(x)))
  171. return tuple(x_list), y
  172. def __len__(self):
  173. return len(self.dataset)