dataset_factory.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. """ Dataset Factory
  2. Hacked together by / Copyright 2021, Ross Wightman
  3. """
  4. import os
  5. from typing import Optional
  6. from torchvision.datasets import CIFAR100, CIFAR10, MNIST, KMNIST, FashionMNIST, ImageFolder
  7. try:
  8. from torchvision.datasets import Places365
  9. has_places365 = True
  10. except ImportError:
  11. has_places365 = False
  12. try:
  13. from torchvision.datasets import INaturalist
  14. has_inaturalist = True
  15. except ImportError:
  16. has_inaturalist = False
  17. try:
  18. from torchvision.datasets import QMNIST
  19. has_qmnist = True
  20. except ImportError:
  21. has_qmnist = False
  22. try:
  23. from torchvision.datasets import ImageNet
  24. has_imagenet = True
  25. except ImportError:
  26. has_imagenet = False
  27. from .dataset import IterableImageDataset, ImageDataset
  28. _TORCH_BASIC_DS = dict(
  29. cifar10=CIFAR10,
  30. cifar100=CIFAR100,
  31. mnist=MNIST,
  32. kmnist=KMNIST,
  33. fashion_mnist=FashionMNIST,
  34. )
  35. _TRAIN_SYNONYM = dict(train=None, training=None)
  36. _EVAL_SYNONYM = dict(val=None, valid=None, validation=None, eval=None, evaluation=None)
  37. def _search_split(root, split):
  38. # look for sub-folder with name of split in root and use that if it exists
  39. split_name = split.split('[')[0]
  40. try_root = os.path.join(root, split_name)
  41. if os.path.exists(try_root):
  42. return try_root
  43. def _try(syn):
  44. for s in syn:
  45. try_root = os.path.join(root, s)
  46. if os.path.exists(try_root):
  47. return try_root
  48. return root
  49. if split_name in _TRAIN_SYNONYM:
  50. root = _try(_TRAIN_SYNONYM)
  51. elif split_name in _EVAL_SYNONYM:
  52. root = _try(_EVAL_SYNONYM)
  53. return root
  54. def create_dataset(
  55. name: str,
  56. root: Optional[str] = None,
  57. split: str = 'validation',
  58. search_split: bool = True,
  59. class_map: dict = None,
  60. load_bytes: bool = False,
  61. is_training: bool = False,
  62. download: bool = False,
  63. batch_size: int = 1,
  64. num_samples: Optional[int] = None,
  65. seed: int = 42,
  66. repeats: int = 0,
  67. input_img_mode: str = 'RGB',
  68. trust_remote_code: bool = False,
  69. **kwargs,
  70. ):
  71. """ Dataset factory method
  72. In parentheses after each arg are the type of dataset supported for each arg, one of:
  73. * Folder - default, timm folder (or tar) based ImageDataset
  74. * Torch - torchvision based datasets
  75. * HFDS - Hugging Face Datasets
  76. * HFIDS - Hugging Face Datasets Iterable (streaming mode, with IterableDataset)
  77. * TFDS - Tensorflow-datasets wrapper in IterabeDataset interface via IterableImageDataset
  78. * WDS - Webdataset
  79. * All - any of the above
  80. Args:
  81. name: Dataset name, empty is okay for folder based datasets
  82. root: Root folder of dataset (All)
  83. split: Dataset split (All)
  84. search_split: Search for split specific child fold from root so one can specify
  85. `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (Folder, Torch)
  86. class_map: Specify class -> index mapping via text file or dict (Folder)
  87. load_bytes: Load data, return images as undecoded bytes (Folder)
  88. download: Download dataset if not present and supported (HFIDS, TFDS, Torch)
  89. is_training: Create dataset in train mode, this is different from the split.
  90. For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS, WDS, HFIDS)
  91. batch_size: Batch size hint for iterable datasets (TFDS, WDS, HFIDS)
  92. seed: Seed for iterable datasets (TFDS, WDS, HFIDS)
  93. repeats: Dataset repeats per iteration i.e. epoch (TFDS, WDS, HFIDS)
  94. input_img_mode: Input image color conversion mode e.g. 'RGB', 'L' (folder, TFDS, WDS, HFDS, HFIDS)
  95. trust_remote_code: Trust remote code in Hugging Face Datasets if True (HFDS, HFIDS)
  96. **kwargs: Other args to pass through to underlying Dataset and/or Reader classes
  97. Returns:
  98. Dataset object
  99. """
  100. kwargs = {k: v for k, v in kwargs.items() if v is not None}
  101. name = name.lower()
  102. if name.startswith('torch/'):
  103. name = name.split('/', 2)[-1]
  104. torch_kwargs = dict(root=root, download=download, **kwargs)
  105. if name in _TORCH_BASIC_DS:
  106. ds_class = _TORCH_BASIC_DS[name]
  107. use_train = split in _TRAIN_SYNONYM
  108. ds = ds_class(train=use_train, **torch_kwargs)
  109. elif name == 'inaturalist' or name == 'inat':
  110. assert has_inaturalist, 'Please update to PyTorch 1.10, torchvision 0.11+ for Inaturalist'
  111. target_type = 'full'
  112. split_split = split.split('/')
  113. if len(split_split) > 1:
  114. target_type = split_split[0].split('_')
  115. if len(target_type) == 1:
  116. target_type = target_type[0]
  117. split = split_split[-1]
  118. if split in _TRAIN_SYNONYM:
  119. split = '2021_train'
  120. elif split in _EVAL_SYNONYM:
  121. split = '2021_valid'
  122. ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
  123. elif name == 'places365':
  124. assert has_places365, 'Please update to a newer PyTorch and torchvision for Places365 dataset.'
  125. if split in _TRAIN_SYNONYM:
  126. split = 'train-standard'
  127. elif split in _EVAL_SYNONYM:
  128. split = 'val'
  129. ds = Places365(split=split, **torch_kwargs)
  130. elif name == 'qmnist':
  131. assert has_qmnist, 'Please update to a newer PyTorch and torchvision for QMNIST dataset.'
  132. use_train = split in _TRAIN_SYNONYM
  133. ds = QMNIST(train=use_train, **torch_kwargs)
  134. elif name == 'imagenet':
  135. torch_kwargs.pop('download')
  136. assert has_imagenet, 'Please update to a newer PyTorch and torchvision for ImageNet dataset.'
  137. if split in _EVAL_SYNONYM:
  138. split = 'val'
  139. ds = ImageNet(split=split, **torch_kwargs)
  140. elif name == 'image_folder' or name == 'folder':
  141. # in case torchvision ImageFolder is preferred over timm ImageDataset for some reason
  142. if search_split and os.path.isdir(root):
  143. # look for split specific sub-folder in root
  144. root = _search_split(root, split)
  145. ds = ImageFolder(root, **kwargs)
  146. else:
  147. assert False, f"Unknown torchvision dataset {name}"
  148. elif name.startswith('hfds/'):
  149. # NOTE right now, HF datasets default arrow format is a random-access Dataset,
  150. # There will be a IterableDataset variant too, TBD
  151. ds = ImageDataset(
  152. root,
  153. reader=name,
  154. split=split,
  155. class_map=class_map,
  156. input_img_mode=input_img_mode,
  157. trust_remote_code=trust_remote_code,
  158. **kwargs,
  159. )
  160. elif name.startswith('hfids/'):
  161. ds = IterableImageDataset(
  162. root,
  163. reader=name,
  164. split=split,
  165. class_map=class_map,
  166. is_training=is_training,
  167. download=download,
  168. batch_size=batch_size,
  169. num_samples=num_samples,
  170. repeats=repeats,
  171. seed=seed,
  172. input_img_mode=input_img_mode,
  173. trust_remote_code=trust_remote_code,
  174. **kwargs,
  175. )
  176. elif name.startswith('tfds/'):
  177. ds = IterableImageDataset(
  178. root,
  179. reader=name,
  180. split=split,
  181. class_map=class_map,
  182. is_training=is_training,
  183. download=download,
  184. batch_size=batch_size,
  185. num_samples=num_samples,
  186. repeats=repeats,
  187. seed=seed,
  188. input_img_mode=input_img_mode,
  189. **kwargs
  190. )
  191. elif name.startswith('wds/'):
  192. ds = IterableImageDataset(
  193. root,
  194. reader=name,
  195. split=split,
  196. class_map=class_map,
  197. is_training=is_training,
  198. batch_size=batch_size,
  199. num_samples=num_samples,
  200. repeats=repeats,
  201. seed=seed,
  202. input_img_mode=input_img_mode,
  203. **kwargs
  204. )
  205. else:
  206. # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
  207. if search_split and os.path.isdir(root):
  208. # look for split specific sub-folder in root
  209. root = _search_split(root, split)
  210. ds = ImageDataset(
  211. root,
  212. reader=name,
  213. class_map=class_map,
  214. load_bytes=load_bytes,
  215. input_img_mode=input_img_mode,
  216. **kwargs,
  217. )
  218. return ds