reader_tfds.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. """ Dataset reader that wraps TFDS datasets
  2. Wraps many (most?) TFDS image-classification datasets
  3. from https://github.com/tensorflow/datasets
  4. https://www.tensorflow.org/datasets/catalog/overview#image_classification
  5. Hacked together by / Copyright 2020 Ross Wightman
  6. """
  7. import math
  8. import os
  9. import sys
  10. from typing import Optional
  11. import torch
  12. import torch.distributed as dist
  13. from PIL import Image
  14. try:
  15. import tensorflow as tf
  16. tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu)
  17. import tensorflow_datasets as tfds
  18. try:
  19. tfds.even_splits('', 1, drop_remainder=False) # non-buggy even_splits has drop_remainder arg
  20. has_buggy_even_splits = False
  21. except TypeError:
  22. print("Warning: This version of tfds doesn't have the latest even_splits impl. "
  23. "Please update or use tfds-nightly for better fine-grained split behaviour.")
  24. has_buggy_even_splits = True
  25. # NOTE uncomment below if having file limit issues on dataset build (or alter your OS defaults)
  26. # import resource
  27. # low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
  28. # resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))
  29. except ImportError as e:
  30. print(e)
  31. print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
  32. raise e
  33. from .class_map import load_class_map
  34. from .reader import Reader
  35. from .shared_count import SharedCount
  # Default tf.data pipeline tunables; each one can be overridden through an environment variable.
  MAX_TP_SIZE = int(os.getenv('TFDS_TP_SIZE', '8'))  # max TF threadpool size (jpeg decodes, queuing activities)
  SHUFFLE_SIZE = int(os.getenv('TFDS_SHUFFLE_SIZE', '8192'))  # number of samples held in the shuffle buffer
  PREFETCH_SIZE = int(os.getenv('TFDS_PREFETCH_SIZE', '2048'))  # number of samples to prefetch
  @tfds.decode.make_decoder()
  def decode_example(serialized_image, feature, dct_method='INTEGER_ACCURATE', channels=3):
      """ Custom TFDS decoder that turns serialized JPEG bytes into an image tensor.

      Registered through ``tfds.decode.make_decoder`` so it can be passed in the ``decoders`` mapping of
      ``builder.as_dataset``. The ``feature`` argument is supplied by TFDS and is not used here.

      Args:
          serialized_image: encoded JPEG bytes tensor.
          feature: the TFDS feature being decoded (unused).
          dct_method: DCT algorithm hint forwarded to ``tf.image.decode_jpeg``.
          channels: number of colour channels to decode to (e.g. 1 for grayscale, 3 for RGB).
      """
      image = tf.image.decode_jpeg(serialized_image, channels=channels, dct_method=dct_method)
      return image
  def even_split_indices(split, n, num_samples):
      """ Build ``n`` TFDS sub-split strings that divide ``split`` into near-equal contiguous index ranges.

      Used as a fallback when the installed tfds ``even_splits`` lacks the ``drop_remainder`` argument.

      Args:
          split: base TFDS split name (e.g. ``'train'``).
          n: number of sub-splits to produce.
          num_samples: number of examples in ``split``.

      Returns:
          list of ``n`` strings ``'{split}[start:end]'``; consecutive ranges share a boundary and together
          cover ``[0, num_samples)``.
      """
      bounds = [round(k * num_samples / n) for k in range(n + 1)]
      return [f"{split}[{start}:{end}]" for start, end in zip(bounds[:-1], bounds[1:])]
  def get_class_labels(info):
      """ Map class names to integer indices using the dataset's ``label`` feature.

      Args:
          info: TFDS dataset info (e.g. ``builder.info``); only ``info.features`` is read.

      Returns:
          dict of class name -> index (via ``label.str2int``), or an empty dict when there is no
          ``label`` feature.
      """
      features = info.features
      if 'label' in features:
          label_feature = features['label']
          return {name: label_feature.str2int(name) for name in label_feature.names}
      return {}
  class ReaderTfds(Reader):
      """ Wrap Tensorflow Datasets for use in PyTorch

      There are several things to be aware of:
        * To prevent excessive samples being dropped per epoch w/ distributed training or multiplicity of
          dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last
          https://github.com/pytorch/pytorch/issues/33413
        * With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch
          from each worker could be a different size. For training this is worked around by option above, for
          validation extra samples are inserted iff distributed mode is enabled so that the batches being reduced
          across replicas are of same size. This will slightly alter the results, distributed validation will not be
          100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse
          since there are up to N * J extra samples with IterableDatasets.
        * The sharding of the dataset into TFRecord files imposes limitations on the number of
          replicas and dataloader workers you can use. For really small datasets that only contain a few shards
          you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the
          benefit of distributed training or fast dataloading should be much less for small datasets.
        * This wrapper is currently configured to return individual, decompressed image samples from the TFDS
          dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible
          to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream
          components.
        * In training, or when ``repeats > 1``, the tf.data pipeline repeats indefinitely and iteration is
          stopped after ``_num_samples_per_worker()`` samples.
      """

      def __init__(
              self,
              name,
              root=None,
              split='train',
              class_map=None,
              is_training=False,
              batch_size=1,
              download=False,
              repeats=0,
              seed=42,
              input_key='image',
              input_img_mode='RGB',
              target_key='label',
              target_img_mode='',
              prefetch_size=None,
              shuffle_size=None,
              max_threadpool_size=None
      ):
          """ Tensorflow-datasets Wrapper

          Args:
              root: root data dir (ie your TFDS_DATA_DIR. not dataset specific sub-dir)
              name: tfds dataset name (eg `imagenet2012`)
              split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`)
              is_training: training mode, shuffle enabled, dataset len rounded by batch_size
              batch_size: batch_size to use to ensure total samples % batch_size == 0 in training across all dist nodes
              download: download and build TFDS dataset if set, otherwise must use tfds CLI
              repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
              seed: common seed for shard shuffle across all distributed/worker instances
              input_key: name of Feature to return as data (input)
              input_img_mode: image mode if input is an image (currently PIL mode string)
              target_key: name of Feature to return as target (label)
              target_img_mode: image mode if target is an image (currently PIL mode string)
              prefetch_size: override default tf.data prefetch buffer size
              shuffle_size: override default tf.data shuffle buffer size
              max_threadpool_size: override default threadpool size for tf.data
          """
          super().__init__()
          self.root = root
          self.split = split
          self.is_training = is_training
          self.batch_size = batch_size
          self.repeats = repeats
          self.common_seed = seed  # a seed that's fixed across all worker / distributed instances

          # performance settings
          self.prefetch_size = prefetch_size or PREFETCH_SIZE
          self.shuffle_size = shuffle_size or SHUFFLE_SIZE
          self.max_threadpool_size = max_threadpool_size or MAX_TP_SIZE

          # TFDS builder and split information
          self.input_key = input_key  # FIXME support tuples / lists of inputs and targets and full range of Feature
          self.input_img_mode = input_img_mode
          self.target_key = target_key
          self.target_img_mode = target_img_mode  # for dense pixel targets
          self.builder = tfds.builder(name, data_dir=root)
          # NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag
          if download:
              self.builder.download_and_prepare()
          self.remap_class = False
          if class_map:
              self.class_to_idx = load_class_map(class_map)
              self.remap_class = True
          else:
              self.class_to_idx = get_class_labels(self.builder.info) if self.target_key == 'label' else {}
          self.split_info = self.builder.info.splits[split]
          self.num_samples = self.split_info.num_examples

          # Distributed world state
          self.dist_rank = 0
          self.dist_num_replicas = 1
          if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
              self.dist_rank = dist.get_rank()
              self.dist_num_replicas = dist.get_world_size()

          # Attributes that are updated in _lazy_init, including the tf.data pipeline itself
          self.global_num_workers = 1
          self.num_workers = 1
          self.worker_info = None
          self.worker_seed = 0  # seed unique to each work instance
          self.subsplit = None  # set when data is distributed across workers using sub-splits
          self.ds = None  # initialized lazily on each dataloader worker process
          self.init_count = 0  # number of ds TF data pipeline initializations
          self.epoch_count = SharedCount()
          # FIXME need to determine if reinit_each_iter is necessary. I don't completely trust behaviour
          #  of `shuffle_reshuffle_each_iteration` when there are multiple workers / nodes across epochs
          self.reinit_each_iter = self.is_training

      def set_epoch(self, count):
          """ Record the current epoch; it is added to the file-shuffle seed on the next pipeline init. """
          self.epoch_count.value = count

      def set_loader_cfg(
              self,
              num_workers: Optional[int] = None,
      ):
          """ Set the dataloader worker count before the pipeline is built (no-op once ``self.ds`` exists). """
          if self.ds is not None:
              return
          if num_workers is not None:
              self.num_workers = num_workers
              self.global_num_workers = self.dist_num_replicas * self.num_workers

      def _lazy_init(self):
          """ Lazily initialize the dataset.

          This is necessary to init the Tensorflow dataset pipeline in the (dataloader) process that
          will be using the dataset instance. The __init__ method is called on the main process,
          this will be called in a dataloader worker process.

          NOTE: There will be problems if you try to re-use this dataset across different loader/worker
          instances once it has been initialized. Do not call any dataset methods that can call _lazy_init
          before it is passed to dataloader.
          """
          worker_info = torch.utils.data.get_worker_info()

          # setup input context to split dataset across distributed processes
          global_worker_id = 0
          if worker_info is not None:
              self.worker_info = worker_info
              self.worker_seed = worker_info.seed
              self.num_workers = worker_info.num_workers
              self.global_num_workers = self.dist_num_replicas * self.num_workers
              global_worker_id = self.dist_rank * self.num_workers + worker_info.id

          # Data sharding
          # InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
          # My understanding is that using split, the underlying TFRecord files will shuffle (shuffle_files=True)
          # between the splits each iteration, but that understanding could be wrong.
          #
          # I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing
          # the data across workers. For training InputContext is used to assign shards to nodes unless num_shards
          # in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or
          # for validation where we can't drop samples and need to minimize uneven splits to avoid padding.
          should_subsplit = self.global_num_workers > 1 and (
              self.split_info.num_shards < self.global_num_workers or not self.is_training)
          if should_subsplit:
              # split the dataset w/o using sharding for more even samples / worker, can result in less optimal
              # read patterns for distributed training (overlap across shards) so better to use InputContext there
              if has_buggy_even_splits:
                  # my even_split workaround doesn't work on subsplits, upgrade tfds!
                  if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
                      subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples)
                      self.subsplit = subsplits[global_worker_id]
              else:
                  subsplits = tfds.even_splits(self.split, self.global_num_workers)
                  self.subsplit = subsplits[global_worker_id]

          input_context = None
          if self.global_num_workers > 1 and self.subsplit is None:
              # set input context to divide shards among distributed replicas
              input_context = tf.distribute.InputContext(
                  num_input_pipelines=self.global_num_workers,
                  input_pipeline_id=global_worker_id,
                  num_replicas_in_sync=self.dist_num_replicas  # FIXME does this arg have any impact?
              )
          read_config = tfds.ReadConfig(
              shuffle_seed=self.common_seed + self.epoch_count.value,
              shuffle_reshuffle_each_iteration=True,
              input_context=input_context,
          )
          ds = self.builder.as_dataset(
              split=self.subsplit or self.split,
              shuffle_files=self.is_training,
              decoders=dict(image=decode_example(channels=1 if self.input_img_mode == 'L' else 3)),
              read_config=read_config,
          )

          # avoid overloading threading w/ combo of TF ds threads + PyTorch workers
          options = tf.data.Options()
          thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
          threading_options = getattr(options, thread_member)
          threading_options.private_threadpool_size = max(1, self.max_threadpool_size // self.num_workers)
          threading_options.max_intra_op_parallelism = 1
          ds = ds.with_options(options)

          if self.is_training or self.repeats > 1:
              # to prevent excessive drop_last batch behaviour w/ IterableDatasets
              # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
              ds = ds.repeat()  # allow wrap around and break iteration manually (see __iter__)
          if self.is_training:
              # tf.data needs a positive shuffle buffer; integer division can reach 0 when there are more
              # workers than samples, so clamp to 1
              shuffle_buffer = max(1, min(self.num_samples, self.shuffle_size) // self.global_num_workers)
              ds = ds.shuffle(shuffle_buffer, seed=self.worker_seed)
          prefetch_buffer = max(1, min(self.num_samples // self.global_num_workers, self.prefetch_size))
          ds = ds.prefetch(prefetch_buffer)
          self.ds = tfds.as_numpy(ds)
          self.init_count += 1

      def _num_samples_per_worker(self):
          """ Number of samples each worker yields per iteration.

          ``max(1, repeats) * num_samples`` divided by ``max(global_num_workers, dist_num_replicas)``; rounded up
          in training or distributed mode, and in training further rounded up to a multiple of ``batch_size``.
          """
          num_worker_samples = \
              max(1, self.repeats) * self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
          if self.is_training or self.dist_num_replicas > 1:
              num_worker_samples = math.ceil(num_worker_samples)
          if self.is_training:
              num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
          return int(num_worker_samples)

      def __iter__(self):
          """ Yield ``(input, target)`` pairs, building (or rebuilding) the tf.data pipeline as needed. """
          if self.ds is None or self.reinit_each_iter:
              self._lazy_init()

          # Compute a rounded up sample count that is used to:
          #   1. make batches even cross workers & replicas in distributed validation.
          #     This adds extra samples and will slightly alter validation results.
          #   2. determine loop ending condition when ds.repeat() is enabled so that iteration terminates and,
          #     in training, only full batch_size batches are produced (underlying tfds iter wraps around)
          target_sample_count = self._num_samples_per_worker()
          # _lazy_init wraps the pipeline with ds.repeat() under exactly this condition. The iterator then never
          # ends on its own, so it must be stopped here, including for non-training iteration with repeats > 1.
          repeat_enabled = self.is_training or self.repeats > 1

          # Iterate until exhausted or sample count hits target (ds.repeat enabled)
          sample_count = 0
          for sample in self.ds:
              input_data = sample[self.input_key]
              if self.input_img_mode:
                  if self.input_img_mode == 'L' and input_data.ndim == 3:
                      input_data = input_data[:, :, 0]
                  input_data = Image.fromarray(input_data, mode=self.input_img_mode)
              target_data = sample[self.target_key]
              if self.target_img_mode:
                  # dense pixel target
                  target_data = Image.fromarray(target_data, mode=self.target_img_mode)
              elif self.remap_class:
                  target_data = self.class_to_idx[target_data]
              yield input_data, target_data
              sample_count += 1
              if repeat_enabled and sample_count >= target_sample_count:
                  # Need to break out of loop when repeat() is enabled, e.g. training w/ oversampling.
                  # This results in extra samples per epoch but seems more desirable than dropping
                  # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
                  break

          # Pad across distributed nodes (make counts equal by adding samples)
          if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \
                  0 < sample_count < target_sample_count:
              # Validation batch padding only done for distributed training where results are reduced across nodes.
              # For single process case, it won't matter if workers return different batch sizes.
              # If using input_context or % based splits, sample count can vary significantly across workers and this
              # approach should not be used (hence disabled if self.subsplit isn't set).
              while sample_count < target_sample_count:
                  yield input_data, target_data  # yield prev sample again
                  sample_count += 1

      def __len__(self):
          """ Samples per iteration for this process: per-worker count times the number of dataloader workers. """
          num_samples = self._num_samples_per_worker() * self.num_workers
          return num_samples

      def _filename(self, index, basename=False, absolute=False):
          """ Random access by index is not possible with a streaming tf.data pipeline. """
          raise NotImplementedError("Not supported")  # no random access to samples

      def filenames(self, basename=False, absolute=False):
          """ Return all filenames in dataset, overrides base

          Names are read from the first present of the ``file_name``, ``filename`` or ``id`` features.
          At most ``num_samples`` names are collected, which bounds iteration when ds.repeat() is enabled.

          Raises:
              RuntimeError: if a sample has none of the supported name fields.
          """
          if self.ds is None:
              self._lazy_init()
          names = []
          for sample in self.ds:
              if len(names) >= self.num_samples:
                  break  # safety for ds.repeat() case
              for name_key in ('file_name', 'filename', 'id'):
                  if name_key in sample:
                      names.append(sample[name_key])
                      break
              else:
                  # explicit raise instead of `assert False`, which is stripped under `python -O`
                  raise RuntimeError("No supported name field present")
          return names