reader_factory.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. import os
  2. from typing import Optional
  3. from .reader_image_folder import ReaderImageFolder
  4. from .reader_image_in_tar import ReaderImageInTar
  5. def create_reader(
  6. name: str,
  7. root: Optional[str] = None,
  8. split: str = 'train',
  9. **kwargs,
  10. ):
  11. kwargs = {k: v for k, v in kwargs.items() if v is not None}
  12. name = name.lower()
  13. name = name.split('/', 1)
  14. prefix = ''
  15. if len(name) > 1:
  16. prefix = name[0]
  17. name = name[-1]
  18. # FIXME the additional features are only supported by ReaderHfds for now.
  19. additional_features = kwargs.pop("additional_features", None)
  20. # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
  21. # explicitly select other options shortly
  22. if prefix == 'hfds':
  23. from .reader_hfds import ReaderHfds # defer Hf datasets import
  24. reader = ReaderHfds(name=name, root=root, split=split, additional_features=additional_features, **kwargs)
  25. elif prefix == 'hfids':
  26. from .reader_hfids import ReaderHfids # defer HF datasets import
  27. reader = ReaderHfids(name=name, root=root, split=split, **kwargs)
  28. elif prefix == 'tfds':
  29. from .reader_tfds import ReaderTfds # defer tensorflow import
  30. reader = ReaderTfds(name=name, root=root, split=split, **kwargs)
  31. elif prefix == 'wds':
  32. from .reader_wds import ReaderWds
  33. kwargs.pop('download', False)
  34. reader = ReaderWds(root=root, name=name, split=split, **kwargs)
  35. else:
  36. assert os.path.exists(root)
  37. # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
  38. # FIXME support split here or in reader?
  39. if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
  40. reader = ReaderImageInTar(root, **kwargs)
  41. else:
  42. reader = ReaderImageFolder(root, **kwargs)
  43. return reader