reader_wds.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. """ Dataset reader for webdataset
  2. Hacked together by / Copyright 2022 Ross Wightman
  3. """
  4. import io
  5. import json
  6. import logging
  7. import math
  8. import os
  9. import random
  10. import sys
  11. from dataclasses import dataclass
  12. from functools import partial
  13. from itertools import islice
  14. from typing import Any, Callable, Dict, List, Optional, Tuple
  15. import torch
  16. import torch.distributed as dist
  17. import yaml
  18. from PIL import Image
  19. from torch.utils.data import Dataset, IterableDataset, get_worker_info
  20. try:
  21. import webdataset as wds
  22. from webdataset.filters import _shuffle, getfirst
  23. from webdataset.shardlists import expand_urls
  24. from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
  25. except ImportError:
  26. wds = None
  27. expand_urls = None
  28. from .class_map import load_class_map
  29. from .reader import Reader
  30. from .shared_count import SharedCount
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
# Fallback sample shuffle buffer size, used when the reader is not given an explicit value.
# Overridable through the WDS_SHUFFLE_SIZE environment variable.
SAMPLE_SHUFFLE_SIZE = int(os.environ.get('WDS_SHUFFLE_SIZE', 8192))
# Fallback `initial` value for the sample shuffle, used when the reader is not given one.
# Overridable through the WDS_INITIAL_SIZE environment variable.
SAMPLE_INITIAL_SIZE = int(os.environ.get('WDS_INITIAL_SIZE', 2048))
  34. def _load_info(root, names=('_info.json', 'info.json')):
  35. if isinstance(names, str):
  36. names = (names,)
  37. tried = []
  38. err_str = ''
  39. for n in names:
  40. full_path = os.path.join(root, n)
  41. try:
  42. tried.append(full_path)
  43. with wds.gopen(full_path) as f:
  44. if n.endswith('.json'):
  45. info_dict = json.load(f)
  46. else:
  47. info_dict = yaml.safe_load(f)
  48. return info_dict
  49. except Exception as e:
  50. err_str = str(e)
  51. _logger.warning(
  52. f'Dataset info file not found at {tried}. Error: {err_str}. '
  53. 'Falling back to provided split and size arg.')
  54. return {}
  55. @dataclass
  56. class SplitInfo:
  57. num_samples: int
  58. filenames: Tuple[str]
  59. shard_lengths: Tuple[int] = ()
  60. alt_label: str = ''
  61. name: str = ''
  62. def _parse_split_info(split: str, info: Dict):
  63. def _info_convert(dict_info):
  64. return SplitInfo(
  65. num_samples=dict_info['num_samples'],
  66. filenames=tuple(dict_info['filenames']),
  67. shard_lengths=tuple(dict_info['shard_lengths']),
  68. alt_label=dict_info.get('alt_label', ''),
  69. name=dict_info['name'],
  70. )
  71. if 'tar' in split or '..' in split:
  72. # split in WDS string braceexpand format, sample count can be included with a | separator
  73. # ex: `dataset-split-{0000..9999}.tar|100000` for 9999 shards, covering 100,000 samples
  74. split = split.split('|')
  75. num_samples = 0
  76. split_name = ''
  77. if len(split) > 1:
  78. num_samples = int(split[1])
  79. split = split[0]
  80. if '::' not in split:
  81. split_parts = split.split('-', 3)
  82. split_idx = len(split_parts) - 1
  83. if split_idx and 'splits' in info and split_parts[split_idx] in info['splits']:
  84. split_name = split_parts[split_idx]
  85. split_filenames = expand_urls(split)
  86. if split_name:
  87. split_info = info['splits'][split_name]
  88. if not num_samples:
  89. _fc = {f: c for f, c in zip(split_info['filenames'], split_info['shard_lengths'])}
  90. num_samples = sum(_fc[f] for f in split_filenames)
  91. split_info['filenames'] = tuple(_fc.keys())
  92. split_info['shard_lengths'] = tuple(_fc.values())
  93. split_info['num_samples'] = num_samples
  94. split_info = _info_convert(split_info)
  95. else:
  96. split_info = SplitInfo(
  97. name=split_name,
  98. num_samples=num_samples,
  99. filenames=split_filenames,
  100. )
  101. else:
  102. if 'splits' not in info or split not in info['splits']:
  103. raise RuntimeError(f"split {split} not found in info ({info.get('splits', {}).keys()})")
  104. split = split
  105. split_info = info['splits'][split]
  106. split_info = _info_convert(split_info)
  107. return split_info
  108. def log_and_continue(exn):
  109. """Call in an exception handler to ignore exceptions, issue a warning, and continue."""
  110. _logger.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
  111. # NOTE: try force an exit on errors that are clearly code / config and not transient
  112. if isinstance(exn, TypeError):
  113. raise exn
  114. return True
  115. def _decode(
  116. sample,
  117. image_key='jpg',
  118. image_mode='RGB',
  119. target_key='cls',
  120. alt_label=''
  121. ):
  122. """ Custom sample decode
  123. * decode and convert PIL Image
  124. * cls byte string label to int
  125. * pass through JSON byte string (if it exists) without parse
  126. """
  127. # decode class label, skip if alternate label not valid
  128. if alt_label:
  129. # alternative labels are encoded in json metadata
  130. meta = json.loads(sample['json'])
  131. class_label = int(meta[alt_label])
  132. if class_label < 0:
  133. # skipped labels currently encoded as -1, may change to a null/None value
  134. return None
  135. else:
  136. class_label = int(sample[target_key])
  137. # decode image
  138. img = getfirst(sample, image_key)
  139. with io.BytesIO(img) as b:
  140. img = Image.open(b)
  141. img.load()
  142. if image_mode:
  143. img = img.convert(image_mode)
  144. # json passed through in undecoded state
  145. decoded = dict(jpg=img, cls=class_label, json=sample.get('json', None))
  146. return decoded
  147. def pytorch_worker_seed():
  148. """get dataloader worker seed from pytorch"""
  149. worker_info = get_worker_info()
  150. if worker_info is not None:
  151. # favour the seed already created for pytorch dataloader workers if it exists
  152. return worker_info.seed
  153. # fallback to wds rank based seed
  154. return wds.utils.pytorch_worker_seed()
  155. if wds is not None:
  156. # conditional to avoid mandatory wds import (via inheritance of wds.PipelineStage)
  157. class detshuffle2(wds.PipelineStage):
  158. def __init__(
  159. self,
  160. bufsize=1000,
  161. initial=100,
  162. seed=0,
  163. epoch=-1,
  164. ):
  165. self.bufsize = bufsize
  166. self.initial = initial
  167. self.seed = seed
  168. self.epoch = epoch
  169. def run(self, src):
  170. if isinstance(self.epoch, SharedCount):
  171. epoch = self.epoch.value
  172. else:
  173. # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)
  174. # situation as different workers may wrap at different times (or not at all).
  175. self.epoch += 1
  176. epoch = self.epoch
  177. if self.seed < 0:
  178. seed = pytorch_worker_seed() + epoch
  179. else:
  180. seed = self.seed + epoch
  181. # _logger.info(f'shuffle seed: {self.seed}, {seed}, epoch: {epoch}') # FIXME temporary
  182. rng = random.Random(seed)
  183. return _shuffle(src, self.bufsize, self.initial, rng)
  184. else:
  185. detshuffle2 = None
  186. class ResampledShards2(IterableDataset):
  187. """An iterable dataset yielding a list of urls."""
  188. def __init__(
  189. self,
  190. urls,
  191. nshards=sys.maxsize,
  192. worker_seed=None,
  193. deterministic=True,
  194. epoch=-1,
  195. ):
  196. """Sample shards from the shard list with replacement.
  197. :param urls: a list of URLs as a Python list or brace notation string
  198. """
  199. super().__init__()
  200. urls = wds.shardlists.expand_urls(urls)
  201. self.urls = urls
  202. assert isinstance(self.urls[0], str)
  203. self.nshards = nshards
  204. self.rng = random.Random()
  205. self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed
  206. self.deterministic = deterministic
  207. self.epoch = epoch
  208. def __iter__(self):
  209. """Return an iterator over the shards."""
  210. if isinstance(self.epoch, SharedCount):
  211. epoch = self.epoch.value
  212. else:
  213. # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)
  214. # situation as different workers may wrap at different times (or not at all).
  215. self.epoch += 1
  216. epoch = self.epoch
  217. if self.deterministic:
  218. # reset seed w/ epoch if deterministic, worker seed should be deterministic due to arg.seed
  219. self.rng = random.Random(self.worker_seed() + epoch)
  220. for _ in range(self.nshards):
  221. index = self.rng.randint(0, len(self.urls) - 1)
  222. yield dict(url=self.urls[index])
  223. class ReaderWds(Reader):
  224. def __init__(
  225. self,
  226. root: str,
  227. name: Optional[str] = None,
  228. split: str = 'train',
  229. is_training: bool = False,
  230. num_samples: Optional[int] = None,
  231. batch_size: int = 1,
  232. repeats: int = 0,
  233. seed: int = 42,
  234. class_map: Optional[dict] = None,
  235. input_key: str = 'jpg;png;webp',
  236. input_img_mode: str = 'RGB',
  237. target_key: str = 'cls',
  238. target_img_mode: str = '',
  239. filename_key: str = 'filename',
  240. sample_shuffle_size: Optional[int] = None,
  241. sample_initial_size: Optional[int] = None,
  242. ):
  243. super().__init__()
  244. if wds is None:
  245. raise RuntimeError(
  246. 'Please install webdataset 0.2.x package `pip install git+https://github.com/webdataset/webdataset`.')
  247. self.root = root
  248. self.is_training = is_training
  249. self.batch_size = batch_size
  250. self.repeats = repeats
  251. self.common_seed = seed # a seed that's fixed across all worker / distributed instances
  252. self.shard_shuffle_size = 500
  253. self.sample_shuffle_size = sample_shuffle_size or SAMPLE_SHUFFLE_SIZE
  254. self.sample_initial_size = sample_initial_size or SAMPLE_INITIAL_SIZE
  255. self.input_key = input_key
  256. self.input_img_mode = input_img_mode
  257. self.target_key = target_key
  258. self.filename_key = filename_key
  259. self.key_ext = '.JPEG' # extension to add to key for original filenames (DS specific, default ImageNet)
  260. self.info = _load_info(self.root)
  261. self.split_info = _parse_split_info(split, self.info)
  262. if num_samples is not None:
  263. self.num_samples = num_samples
  264. else:
  265. self.num_samples = self.split_info.num_samples
  266. if is_training and not self.num_samples:
  267. raise RuntimeError(f'Invalid split definition, num_samples not specified in train mode.')
  268. self.remap_class = False
  269. if class_map:
  270. self.class_to_idx = load_class_map(class_map)
  271. self.remap_class = True
  272. else:
  273. self.class_to_idx = {}
  274. # Distributed world state
  275. self.dist_rank = 0
  276. self.dist_num_replicas = 1
  277. if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
  278. self.dist_rank = dist.get_rank()
  279. self.dist_num_replicas = dist.get_world_size()
  280. # Attributes that are updated in _lazy_init
  281. self.worker_info = None
  282. self.worker_id = 0
  283. self.worker_seed = seed # seed unique to each worker instance
  284. self.num_workers = 1
  285. self.global_worker_id = 0
  286. self.global_num_workers = 1
  287. self.init_count = 0
  288. self.epoch_count = SharedCount()
  289. # DataPipeline is lazy init, the majority of WDS DataPipeline could be init here, BUT, shuffle seed
  290. # is not handled in manner where it can be deterministic for each worker AND initialized up front
  291. self.ds = None
  292. def set_epoch(self, count):
  293. self.epoch_count.value = count
  294. def set_loader_cfg(
  295. self,
  296. num_workers: Optional[int] = None,
  297. ):
  298. if self.ds is not None:
  299. return
  300. if num_workers is not None:
  301. self.num_workers = num_workers
  302. self.global_num_workers = self.dist_num_replicas * self.num_workers
  303. def _lazy_init(self):
  304. """ Lazily initialize worker (in worker processes)
  305. """
  306. if self.worker_info is None:
  307. worker_info = torch.utils.data.get_worker_info()
  308. if worker_info is not None:
  309. self.worker_info = worker_info
  310. self.worker_id = worker_info.id
  311. self.worker_seed = worker_info.seed
  312. self.num_workers = worker_info.num_workers
  313. self.global_num_workers = self.dist_num_replicas * self.num_workers
  314. self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id
  315. # init data pipeline
  316. abs_shard_filenames = [os.path.join(self.root, f) for f in self.split_info.filenames]
  317. pipeline = [wds.SimpleShardList(abs_shard_filenames)]
  318. # at this point we have an iterator over all the shards
  319. if self.is_training:
  320. pipeline.extend([
  321. detshuffle2(
  322. self.shard_shuffle_size,
  323. seed=self.common_seed,
  324. epoch=self.epoch_count,
  325. ),
  326. self._split_by_node_and_worker,
  327. # at this point, we have an iterator over the shards assigned to each worker
  328. wds.tarfile_to_samples(handler=log_and_continue),
  329. wds.shuffle(
  330. bufsize=self.sample_shuffle_size,
  331. initial=self.sample_initial_size,
  332. rng=random.Random(self.worker_seed) # this is why we lazy-init whole DataPipeline
  333. ),
  334. ])
  335. else:
  336. pipeline.extend([
  337. self._split_by_node_and_worker,
  338. # at this point, we have an iterator over the shards assigned to each worker
  339. wds.tarfile_to_samples(handler=log_and_continue),
  340. ])
  341. pipeline.extend([
  342. wds.map(
  343. partial(
  344. _decode,
  345. image_key=self.input_key,
  346. image_mode=self.input_img_mode,
  347. alt_label=self.split_info.alt_label,
  348. ),
  349. handler=log_and_continue,
  350. ),
  351. wds.rename(image=self.input_key, target=self.target_key)
  352. ])
  353. self.ds = wds.DataPipeline(*pipeline)
  354. def _split_by_node_and_worker(self, src):
  355. if self.global_num_workers > 1:
  356. for s in islice(src, self.global_worker_id, None, self.global_num_workers):
  357. yield s
  358. else:
  359. for s in src:
  360. yield s
  361. def _num_samples_per_worker(self):
  362. num_worker_samples = self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
  363. if self.is_training or self.dist_num_replicas > 1:
  364. num_worker_samples = math.ceil(num_worker_samples)
  365. if self.is_training:
  366. num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
  367. return int(num_worker_samples)
  368. def __iter__(self):
  369. if self.ds is None:
  370. self._lazy_init()
  371. num_worker_samples = self._num_samples_per_worker()
  372. if self.is_training or self.dist_num_replicas > 1:
  373. # NOTE: doing distributed validation w/ WDS is messy, hard to meet constraints that
  374. # same # of batches needed across all replicas w/ seeing each sample once.
  375. # with_epoch() is simple but could miss a shard's worth of samples in some workers,
  376. # and duplicate in others. Best to keep num DL workers low and a divisor of #val shards.
  377. ds = self.ds.with_epoch(num_worker_samples)
  378. else:
  379. ds = self.ds
  380. i = 0
  381. # _logger.info(f'start {i}, {self.worker_id}') # FIXME temporary debug
  382. for sample in ds:
  383. target = sample['target']
  384. if self.remap_class:
  385. target = self.class_to_idx[target]
  386. yield sample['image'], target
  387. i += 1
  388. # _logger.info(f'end {i}, {self.worker_id}') # FIXME temporary debug
  389. def __len__(self):
  390. num_samples = self._num_samples_per_worker() * self.num_workers
  391. return num_samples
  392. def _filename(self, index, basename=False, absolute=False):
  393. assert False, "Not supported" # no random access to examples
  394. def filenames(self, basename=False, absolute=False):
  395. """ Return all filenames in dataset, overrides base"""
  396. if self.ds is None:
  397. self._lazy_init()
  398. names = []
  399. for sample in self.ds:
  400. if self.filename_key in sample:
  401. name = sample[self.filename_key]
  402. elif '__key__' in sample:
  403. name = sample['__key__'] + self.key_ext
  404. else:
  405. assert False, "No supported name field present"
  406. names.append(name)
  407. if len(names) >= self.num_samples:
  408. break # safety for ds.repeat() case
  409. return names