reader_hfids.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. """ Dataset reader for HF IterableDataset
  2. """
  3. import math
  4. import os
  5. from itertools import repeat, chain
  6. from typing import Optional
  7. import torch
  8. import torch.distributed as dist
  9. from PIL import Image
  10. try:
  11. import datasets
  12. from datasets.distributed import split_dataset_by_node
  13. from datasets.splits import SplitInfo
  14. except ImportError as e:
  15. print("Please install Hugging Face datasets package `pip install datasets`.")
  16. raise e
  17. from .class_map import load_class_map
  18. from .reader import Reader
  19. from .shared_count import SharedCount
  20. SHUFFLE_SIZE = int(os.environ.get('HFIDS_SHUFFLE_SIZE', 4096))
  21. class ReaderHfids(Reader):
  22. def __init__(
  23. self,
  24. name: str,
  25. root: Optional[str] = None,
  26. split: str = 'train',
  27. is_training: bool = False,
  28. batch_size: int = 1,
  29. download: bool = False,
  30. repeats: int = 0,
  31. seed: int = 42,
  32. class_map: Optional[dict] = None,
  33. input_key: str = 'image',
  34. input_img_mode: str = 'RGB',
  35. target_key: str = 'label',
  36. target_img_mode: str = '',
  37. shuffle_size: Optional[int] = None,
  38. num_samples: Optional[int] = None,
  39. trust_remote_code: bool = False
  40. ):
  41. super().__init__()
  42. self.root = root
  43. self.split = split
  44. self.is_training = is_training
  45. self.batch_size = batch_size
  46. self.download = download
  47. self.repeats = repeats
  48. self.common_seed = seed # a seed that's fixed across all worker / distributed instances
  49. self.shuffle_size = shuffle_size or SHUFFLE_SIZE
  50. self.input_key = input_key
  51. self.input_img_mode = input_img_mode
  52. self.target_key = target_key
  53. self.target_img_mode = target_img_mode
  54. self.builder = datasets.load_dataset_builder(
  55. name,
  56. cache_dir=root,
  57. trust_remote_code=trust_remote_code,
  58. )
  59. if download:
  60. self.builder.download_and_prepare()
  61. split_info: Optional[SplitInfo] = None
  62. if self.builder.info.splits and split in self.builder.info.splits:
  63. if isinstance(self.builder.info.splits[split], SplitInfo):
  64. split_info: Optional[SplitInfo] = self.builder.info.splits[split]
  65. if num_samples:
  66. self.num_samples = num_samples
  67. elif split_info and split_info.num_examples:
  68. self.num_samples = split_info.num_examples
  69. else:
  70. raise ValueError(
  71. "Dataset length is unknown, please pass `num_samples` explicitly. "
  72. "The number of steps needs to be known in advance for the learning rate scheduler."
  73. )
  74. self.remap_class = False
  75. if class_map:
  76. self.class_to_idx = load_class_map(class_map)
  77. self.remap_class = True
  78. else:
  79. self.class_to_idx = {}
  80. # Distributed world state
  81. self.dist_rank = 0
  82. self.dist_num_replicas = 1
  83. if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
  84. self.dist_rank = dist.get_rank()
  85. self.dist_num_replicas = dist.get_world_size()
  86. # Attributes that are updated in _lazy_init
  87. self.worker_info = None
  88. self.worker_id = 0
  89. self.num_workers = 1
  90. self.global_worker_id = 0
  91. self.global_num_workers = 1
  92. # Initialized lazily on each dataloader worker process
  93. self.ds: Optional[datasets.IterableDataset] = None
  94. self.epoch = SharedCount()
  95. def set_epoch(self, count):
  96. # to update the shuffling effective_seed = seed + epoch
  97. self.epoch.value = count
  98. def set_loader_cfg(
  99. self,
  100. num_workers: Optional[int] = None,
  101. ):
  102. if self.ds is not None:
  103. return
  104. if num_workers is not None:
  105. self.num_workers = num_workers
  106. self.global_num_workers = self.dist_num_replicas * self.num_workers
  107. def _lazy_init(self):
  108. """ Lazily initialize worker (in worker processes)
  109. """
  110. if self.worker_info is None:
  111. worker_info = torch.utils.data.get_worker_info()
  112. if worker_info is not None:
  113. self.worker_info = worker_info
  114. self.worker_id = worker_info.id
  115. self.num_workers = worker_info.num_workers
  116. self.global_num_workers = self.dist_num_replicas * self.num_workers
  117. self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id
  118. if self.download:
  119. dataset = self.builder.as_dataset(split=self.split)
  120. # to distribute evenly to workers
  121. ds = dataset.to_iterable_dataset(num_shards=self.global_num_workers)
  122. else:
  123. # in this case the number of shard is determined by the number of remote files
  124. ds = self.builder.as_streaming_dataset(split=self.split)
  125. if self.is_training:
  126. # will shuffle the list of shards and use a shuffle buffer
  127. ds = ds.shuffle(seed=self.common_seed, buffer_size=self.shuffle_size)
  128. # Distributed:
  129. # The dataset has a number of shards that is a factor of `dist_num_replicas` (i.e. if `ds.n_shards % dist_num_replicas == 0`),
  130. # so the shards are evenly assigned across the nodes.
  131. # If it's not the case for dataset streaming, each node keeps 1 example out of `dist_num_replicas`, skipping the other examples.
  132. # Workers:
  133. # In a node, datasets.IterableDataset assigns the shards assigned to the node as evenly as possible to workers.
  134. self.ds = split_dataset_by_node(ds, rank=self.dist_rank, world_size=self.dist_num_replicas)
  135. def _num_samples_per_worker(self):
  136. num_worker_samples = \
  137. max(1, self.repeats) * self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
  138. if self.is_training or self.dist_num_replicas > 1:
  139. num_worker_samples = math.ceil(num_worker_samples)
  140. if self.is_training and self.batch_size is not None:
  141. num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
  142. return int(num_worker_samples)
  143. def __iter__(self):
  144. if self.ds is None:
  145. self._lazy_init()
  146. self.ds.set_epoch(self.epoch.value)
  147. target_sample_count = self._num_samples_per_worker()
  148. sample_count = 0
  149. if self.is_training:
  150. ds_iter = chain.from_iterable(repeat(self.ds))
  151. else:
  152. ds_iter = iter(self.ds)
  153. for sample in ds_iter:
  154. input_data: Image.Image = sample[self.input_key]
  155. if self.input_img_mode and input_data.mode != self.input_img_mode:
  156. input_data = input_data.convert(self.input_img_mode)
  157. target_data = sample[self.target_key]
  158. if self.target_img_mode:
  159. assert isinstance(target_data, Image.Image), "target_img_mode is specified but target is not an image"
  160. if target_data.mode != self.target_img_mode:
  161. target_data = target_data.convert(self.target_img_mode)
  162. elif self.remap_class:
  163. target_data = self.class_to_idx[target_data]
  164. yield input_data, target_data
  165. sample_count += 1
  166. if self.is_training and sample_count >= target_sample_count:
  167. break
  168. def __len__(self):
  169. num_samples = self._num_samples_per_worker() * self.num_workers
  170. return num_samples
  171. def _filename(self, index, basename=False, absolute=False):
  172. assert False, "Not supported" # no random access to examples
  173. def filenames(self, basename=False, absolute=False):
  174. """ Return all filenames in dataset, overrides base"""
  175. if self.ds is None:
  176. self._lazy_init()
  177. names = []
  178. for sample in self.ds:
  179. if 'file_name' in sample:
  180. name = sample['file_name']
  181. elif 'filename' in sample:
  182. name = sample['filename']
  183. elif 'id' in sample:
  184. name = sample['id']
  185. elif 'image_id' in sample:
  186. name = sample['image_id']
  187. else:
  188. assert False, "No supported name field present"
  189. names.append(name)
  190. return names