reader_hfds.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. """ Dataset reader that wraps Hugging Face datasets
  2. Hacked together by / Copyright 2022 Ross Wightman
  3. """
  4. import io
  5. import math
  6. from typing import Optional
  7. import torch
  8. import torch.distributed as dist
  9. from PIL import Image
  10. try:
  11. import datasets
  12. except ImportError as e:
  13. print("Please install Hugging Face datasets package `pip install datasets`.")
  14. raise e
  15. from .class_map import load_class_map
  16. from .reader import Reader
  17. def get_class_labels(info, label_key='label'):
  18. if 'label' not in info.features:
  19. return {}
  20. class_label = info.features[label_key]
  21. class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
  22. return class_to_idx
  23. class ReaderHfds(Reader):
  24. def __init__(
  25. self,
  26. name: str,
  27. root: Optional[str] = None,
  28. split: str = 'train',
  29. class_map: dict = None,
  30. input_key: str = 'image',
  31. target_key: str = 'label',
  32. additional_features: Optional[list[str]] = None,
  33. download: bool = False,
  34. trust_remote_code: bool = False
  35. ):
  36. """
  37. """
  38. super().__init__()
  39. self.root = root
  40. self.split = split
  41. self.dataset = datasets.load_dataset(
  42. name, # 'name' maps to path arg in hf datasets
  43. split=split,
  44. cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path if root set
  45. trust_remote_code=trust_remote_code
  46. )
  47. # leave decode for caller, plus we want easy access to original path names...
  48. self.dataset = self.dataset.cast_column(input_key, datasets.Image(decode=False))
  49. self.image_key = input_key
  50. self.label_key = target_key
  51. self.remap_class = False
  52. if class_map:
  53. self.class_to_idx = load_class_map(class_map)
  54. self.remap_class = True
  55. else:
  56. self.class_to_idx = get_class_labels(self.dataset.info, self.label_key)
  57. self.split_info = self.dataset.info.splits[split]
  58. self.num_samples = self.split_info.num_examples
  59. if additional_features is not None:
  60. if isinstance(additional_features, list):
  61. self.additional_features = additional_features
  62. else:
  63. self.additional_features = [additional_features]
  64. else:
  65. self.additional_features = None
  66. def __getitem__(self, index):
  67. item = self.dataset[index]
  68. image = item[self.image_key]
  69. if 'bytes' in image and image['bytes']:
  70. image = io.BytesIO(image['bytes'])
  71. else:
  72. assert 'path' in image and image['path']
  73. image = open(image['path'], 'rb')
  74. label = item[self.label_key]
  75. if self.remap_class:
  76. label = self.class_to_idx[label]
  77. if self.additional_features is not None:
  78. features = [item[feat] for feat in self.additional_features]
  79. return image, label, *features
  80. else:
  81. return image, label
  82. def __len__(self):
  83. return len(self.dataset)
  84. def _filename(self, index, basename=False, absolute=False):
  85. item = self.dataset[index]
  86. return item[self.image_key]['path']