caltech.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. import os
  2. import os.path
  3. import shutil
  4. from pathlib import Path
  5. from typing import Any, Callable, Optional, Union
  6. from PIL import Image
  7. from .utils import download_and_extract_archive, extract_archive, verify_str_arg
  8. from .vision import VisionDataset
  9. class Caltech101(VisionDataset):
  10. """`Caltech 101 <https://data.caltech.edu/records/20086>`_ Dataset.
  11. .. warning::
  12. This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
  13. Args:
  14. root (str or ``pathlib.Path``): Root directory of dataset where directory
  15. ``caltech101`` exists or will be saved to if download is set to True.
  16. target_type (string or list, optional): Type of target to use, ``category`` or
  17. ``annotation``. Can also be a list to output a tuple with all specified
  18. target types. ``category`` represents the target class, and
  19. ``annotation`` is a list of points from a hand-generated outline.
  20. Defaults to ``category``.
  21. transform (callable, optional): A function/transform that takes in a PIL image
  22. and returns a transformed version. E.g, ``transforms.RandomCrop``
  23. target_transform (callable, optional): A function/transform that takes in the
  24. target and transforms it.
  25. download (bool, optional): If true, downloads the dataset from the internet and
  26. puts it in root directory. If dataset is already downloaded, it is not
  27. downloaded again.
  28. .. warning::
  29. To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
  30. """
  31. def __init__(
  32. self,
  33. root: Union[str, Path],
  34. target_type: Union[list[str], str] = "category",
  35. transform: Optional[Callable] = None,
  36. target_transform: Optional[Callable] = None,
  37. download: bool = False,
  38. ) -> None:
  39. super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
  40. os.makedirs(self.root, exist_ok=True)
  41. if isinstance(target_type, str):
  42. target_type = [target_type]
  43. self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]
  44. if download:
  45. self.download()
  46. if not self._check_integrity():
  47. raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
  48. self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
  49. self.categories.remove("BACKGROUND_Google") # this is not a real class
  50. # For some reason, the category names in "101_ObjectCategories" and
  51. # "Annotations" do not always match. This is a manual map between the
  52. # two. Defaults to using same name, since most names are fine.
  53. name_map = {
  54. "Faces": "Faces_2",
  55. "Faces_easy": "Faces_3",
  56. "Motorbikes": "Motorbikes_16",
  57. "airplanes": "Airplanes_Side_2",
  58. }
  59. self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories))
  60. self.index: list[int] = []
  61. self.y = []
  62. for i, c in enumerate(self.categories):
  63. n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", c)))
  64. self.index.extend(range(1, n + 1))
  65. self.y.extend(n * [i])
  66. def __getitem__(self, index: int) -> tuple[Any, Any]:
  67. """
  68. Args:
  69. index (int): Index
  70. Returns:
  71. tuple: (image, target) where the type of target specified by target_type.
  72. """
  73. import scipy.io
  74. img = Image.open(
  75. os.path.join(
  76. self.root,
  77. "101_ObjectCategories",
  78. self.categories[self.y[index]],
  79. f"image_{self.index[index]:04d}.jpg",
  80. )
  81. )
  82. target: Any = []
  83. for t in self.target_type:
  84. if t == "category":
  85. target.append(self.y[index])
  86. elif t == "annotation":
  87. data = scipy.io.loadmat(
  88. os.path.join(
  89. self.root,
  90. "Annotations",
  91. self.annotation_categories[self.y[index]],
  92. f"annotation_{self.index[index]:04d}.mat",
  93. )
  94. )
  95. target.append(data["obj_contour"])
  96. target = tuple(target) if len(target) > 1 else target[0]
  97. if self.transform is not None:
  98. img = self.transform(img)
  99. if self.target_transform is not None:
  100. target = self.target_transform(target)
  101. return img, target
  102. def _check_integrity(self) -> bool:
  103. # can be more robust and check hash of files
  104. return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))
  105. def __len__(self) -> int:
  106. return len(self.index)
  107. def download(self) -> None:
  108. if self._check_integrity():
  109. return
  110. download_and_extract_archive(
  111. "https://data.caltech.edu/records/mzrjq-6wc02/files/caltech-101.zip",
  112. download_root=self.root,
  113. filename="caltech-101.zip",
  114. md5="3138e1922a9193bfa496528edbbc45d0",
  115. )
  116. gzip_folder = os.path.join(self.root, "caltech-101")
  117. for gzip_file in os.listdir(gzip_folder):
  118. if gzip_file.endswith(".gz"):
  119. extract_archive(os.path.join(gzip_folder, gzip_file), self.root)
  120. shutil.rmtree(gzip_folder)
  121. os.remove(os.path.join(self.root, "caltech-101.zip"))
  122. def extra_repr(self) -> str:
  123. return "Target type: {target_type}".format(**self.__dict__)
  124. class Caltech256(VisionDataset):
  125. """`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset.
  126. Args:
  127. root (str or ``pathlib.Path``): Root directory of dataset where directory
  128. ``caltech256`` exists or will be saved to if download is set to True.
  129. transform (callable, optional): A function/transform that takes in a PIL image
  130. and returns a transformed version. E.g, ``transforms.RandomCrop``
  131. target_transform (callable, optional): A function/transform that takes in the
  132. target and transforms it.
  133. download (bool, optional): If true, downloads the dataset from the internet and
  134. puts it in root directory. If dataset is already downloaded, it is not
  135. downloaded again.
  136. """
  137. def __init__(
  138. self,
  139. root: str,
  140. transform: Optional[Callable] = None,
  141. target_transform: Optional[Callable] = None,
  142. download: bool = False,
  143. ) -> None:
  144. super().__init__(os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform)
  145. os.makedirs(self.root, exist_ok=True)
  146. if download:
  147. self.download()
  148. if not self._check_integrity():
  149. raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
  150. self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
  151. self.index: list[int] = []
  152. self.y = []
  153. for i, c in enumerate(self.categories):
  154. n = len(
  155. [
  156. item
  157. for item in os.listdir(os.path.join(self.root, "256_ObjectCategories", c))
  158. if item.endswith(".jpg")
  159. ]
  160. )
  161. self.index.extend(range(1, n + 1))
  162. self.y.extend(n * [i])
  163. def __getitem__(self, index: int) -> tuple[Any, Any]:
  164. """
  165. Args:
  166. index (int): Index
  167. Returns:
  168. tuple: (image, target) where target is index of the target class.
  169. """
  170. img = Image.open(
  171. os.path.join(
  172. self.root,
  173. "256_ObjectCategories",
  174. self.categories[self.y[index]],
  175. f"{self.y[index] + 1:03d}_{self.index[index]:04d}.jpg",
  176. )
  177. )
  178. target = self.y[index]
  179. if self.transform is not None:
  180. img = self.transform(img)
  181. if self.target_transform is not None:
  182. target = self.target_transform(target)
  183. return img, target
  184. def _check_integrity(self) -> bool:
  185. # can be more robust and check hash of files
  186. return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))
  187. def __len__(self) -> int:
  188. return len(self.index)
  189. def download(self) -> None:
  190. if self._check_integrity():
  191. return
  192. download_and_extract_archive(
  193. "https://data.caltech.edu/records/nyy15-4j048/files/256_ObjectCategories.tar",
  194. self.root,
  195. filename="256_ObjectCategories.tar",
  196. md5="67b4f42ca05d46448c6bb8ecd2220f6d",
  197. )