stl10.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. import os.path
  2. from pathlib import Path
  3. from typing import Any, Callable, cast, Optional, Union
  4. import numpy as np
  5. from PIL import Image
  6. from .utils import check_integrity, download_and_extract_archive, verify_str_arg
  7. from .vision import VisionDataset
  8. class STL10(VisionDataset):
  9. """`STL10 <https://cs.stanford.edu/~acoates/stl10/>`_ Dataset.
  10. Args:
  11. root (str or ``pathlib.Path``): Root directory of dataset where directory
  12. ``stl10_binary`` exists.
  13. split (string): One of {'train', 'test', 'unlabeled', 'train+unlabeled'}.
  14. Accordingly, dataset is selected.
  15. folds (int, optional): One of {0-9} or None.
  16. For training, loads one of the 10 pre-defined folds of 1k samples for the
  17. standard evaluation procedure. If no value is passed, loads the 5k samples.
  18. transform (callable, optional): A function/transform that takes in a PIL image
  19. and returns a transformed version. E.g, ``transforms.RandomCrop``
  20. target_transform (callable, optional): A function/transform that takes in the
  21. target and transforms it.
  22. download (bool, optional): If true, downloads the dataset from the internet and
  23. puts it in root directory. If dataset is already downloaded, it is not
  24. downloaded again.
  25. """
  26. base_folder = "stl10_binary"
  27. url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
  28. filename = "stl10_binary.tar.gz"
  29. tgz_md5 = "91f7769df0f17e558f3565bffb0c7dfb"
  30. class_names_file = "class_names.txt"
  31. folds_list_file = "fold_indices.txt"
  32. train_list = [
  33. ["train_X.bin", "918c2871b30a85fa023e0c44e0bee87f"],
  34. ["train_y.bin", "5a34089d4802c674881badbb80307741"],
  35. ["unlabeled_X.bin", "5242ba1fed5e4be9e1e742405eb56ca4"],
  36. ]
  37. test_list = [["test_X.bin", "7f263ba9f9e0b06b93213547f721ac82"], ["test_y.bin", "36f9794fa4beb8a2c72628de14fa638e"]]
  38. splits = ("train", "train+unlabeled", "unlabeled", "test")
  39. def __init__(
  40. self,
  41. root: Union[str, Path],
  42. split: str = "train",
  43. folds: Optional[int] = None,
  44. transform: Optional[Callable] = None,
  45. target_transform: Optional[Callable] = None,
  46. download: bool = False,
  47. ) -> None:
  48. super().__init__(root, transform=transform, target_transform=target_transform)
  49. self.split = verify_str_arg(split, "split", self.splits)
  50. self.folds = self._verify_folds(folds)
  51. if download:
  52. self.download()
  53. elif not self._check_integrity():
  54. raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
  55. # now load the picked numpy arrays
  56. self.labels: Optional[np.ndarray]
  57. if self.split == "train":
  58. self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
  59. self.labels = cast(np.ndarray, self.labels)
  60. self.__load_folds(folds)
  61. elif self.split == "train+unlabeled":
  62. self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
  63. self.labels = cast(np.ndarray, self.labels)
  64. self.__load_folds(folds)
  65. unlabeled_data, _ = self.__loadfile(self.train_list[2][0])
  66. self.data = np.concatenate((self.data, unlabeled_data))
  67. self.labels = np.concatenate((self.labels, np.asarray([-1] * unlabeled_data.shape[0])))
  68. elif self.split == "unlabeled":
  69. self.data, _ = self.__loadfile(self.train_list[2][0])
  70. self.labels = np.asarray([-1] * self.data.shape[0])
  71. else: # self.split == 'test':
  72. self.data, self.labels = self.__loadfile(self.test_list[0][0], self.test_list[1][0])
  73. class_file = os.path.join(self.root, self.base_folder, self.class_names_file)
  74. if os.path.isfile(class_file):
  75. with open(class_file) as f:
  76. self.classes = f.read().splitlines()
  77. def _verify_folds(self, folds: Optional[int]) -> Optional[int]:
  78. if folds is None:
  79. return folds
  80. elif isinstance(folds, int):
  81. if folds in range(10):
  82. return folds
  83. msg = "Value for argument folds should be in the range [0, 10), but got {}."
  84. raise ValueError(msg.format(folds))
  85. else:
  86. msg = "Expected type None or int for argument folds, but got type {}."
  87. raise ValueError(msg.format(type(folds)))
  88. def __getitem__(self, index: int) -> tuple[Any, Any]:
  89. """
  90. Args:
  91. index (int): Index
  92. Returns:
  93. tuple: (image, target) where target is index of the target class.
  94. """
  95. target: Optional[int]
  96. if self.labels is not None:
  97. img, target = self.data[index], int(self.labels[index])
  98. else:
  99. img, target = self.data[index], None
  100. # doing this so that it is consistent with all other datasets
  101. # to return a PIL Image
  102. img = Image.fromarray(np.transpose(img, (1, 2, 0)))
  103. if self.transform is not None:
  104. img = self.transform(img)
  105. if self.target_transform is not None:
  106. target = self.target_transform(target)
  107. return img, target
  108. def __len__(self) -> int:
  109. return self.data.shape[0]
  110. def __loadfile(self, data_file: str, labels_file: Optional[str] = None) -> tuple[np.ndarray, Optional[np.ndarray]]:
  111. labels = None
  112. if labels_file:
  113. path_to_labels = os.path.join(self.root, self.base_folder, labels_file)
  114. with open(path_to_labels, "rb") as f:
  115. labels = np.fromfile(f, dtype=np.uint8) - 1 # 0-based
  116. path_to_data = os.path.join(self.root, self.base_folder, data_file)
  117. with open(path_to_data, "rb") as f:
  118. # read whole file in uint8 chunks
  119. everything = np.fromfile(f, dtype=np.uint8)
  120. images = np.reshape(everything, (-1, 3, 96, 96))
  121. images = np.transpose(images, (0, 1, 3, 2))
  122. return images, labels
  123. def _check_integrity(self) -> bool:
  124. for filename, md5 in self.train_list + self.test_list:
  125. fpath = os.path.join(self.root, self.base_folder, filename)
  126. if not check_integrity(fpath, md5):
  127. return False
  128. return True
  129. def download(self) -> None:
  130. if self._check_integrity():
  131. return
  132. download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
  133. self._check_integrity()
  134. def extra_repr(self) -> str:
  135. return "Split: {split}".format(**self.__dict__)
  136. def __load_folds(self, folds: Optional[int]) -> None:
  137. # loads one of the folds if specified
  138. if folds is None:
  139. return
  140. path_to_folds = os.path.join(self.root, self.base_folder, self.folds_list_file)
  141. with open(path_to_folds) as f:
  142. str_idx = f.read().splitlines()[folds]
  143. list_idx = np.fromstring(str_idx, dtype=np.int64, sep=" ")
  144. self.data = self.data[list_idx, :, :, :]
  145. if self.labels is not None:
  146. self.labels = self.labels[list_idx]