inaturalist.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. import os
  2. import os.path
  3. from pathlib import Path
  4. from typing import Any, Callable, Optional, Union
  5. from PIL import Image
  6. from .utils import download_and_extract_archive, verify_str_arg
  7. from .vision import VisionDataset
  8. CATEGORIES_2021 = ["kingdom", "phylum", "class", "order", "family", "genus"]
  9. DATASET_URLS = {
  10. "2017": "https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz",
  11. "2018": "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz",
  12. "2019": "https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz",
  13. "2021_train": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz",
  14. "2021_train_mini": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz",
  15. "2021_valid": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz",
  16. }
  17. DATASET_MD5 = {
  18. "2017": "7c784ea5e424efaec655bd392f87301f",
  19. "2018": "b1c6952ce38f31868cc50ea72d066cc3",
  20. "2019": "c60a6e2962c9b8ccbd458d12c8582644",
  21. "2021_train": "e0526d53c7f7b2e3167b2b43bb2690ed",
  22. "2021_train_mini": "db6ed8330e634445efc8fec83ae81442",
  23. "2021_valid": "f6f6e0e242e3d4c9569ba56400938afc",
  24. }
  25. class INaturalist(VisionDataset):
  26. """`iNaturalist <https://github.com/visipedia/inat_comp>`_ Dataset.
  27. Args:
  28. root (str or ``pathlib.Path``): Root directory of dataset where the image files are stored.
  29. This class does not require/use annotation files.
  30. version (string, optional): Which version of the dataset to download/use. One of
  31. '2017', '2018', '2019', '2021_train', '2021_train_mini', '2021_valid'.
  32. Default: `2021_train`.
  33. target_type (string or list, optional): Type of target to use, for 2021 versions, one of:
  34. - ``full``: the full category (species)
  35. - ``kingdom``: e.g. "Animalia"
  36. - ``phylum``: e.g. "Arthropoda"
  37. - ``class``: e.g. "Insecta"
  38. - ``order``: e.g. "Coleoptera"
  39. - ``family``: e.g. "Cleridae"
  40. - ``genus``: e.g. "Trichodes"
  41. for 2017-2019 versions, one of:
  42. - ``full``: the full (numeric) category
  43. - ``super``: the super category, e.g. "Amphibians"
  44. Can also be a list to output a tuple with all specified target types.
  45. Defaults to ``full``.
  46. transform (callable, optional): A function/transform that takes in a PIL image
  47. and returns a transformed version. E.g, ``transforms.RandomCrop``
  48. target_transform (callable, optional): A function/transform that takes in the
  49. target and transforms it.
  50. download (bool, optional): If true, downloads the dataset from the internet and
  51. puts it in root directory. If dataset is already downloaded, it is not
  52. downloaded again.
  53. loader (callable, optional): A function to load an image given its path.
  54. By default, it uses PIL as its image loader, but users could also pass in
  55. ``torchvision.io.decode_image`` for decoding image data into tensors directly.
  56. """
  57. def __init__(
  58. self,
  59. root: Union[str, Path],
  60. version: str = "2021_train",
  61. target_type: Union[list[str], str] = "full",
  62. transform: Optional[Callable] = None,
  63. target_transform: Optional[Callable] = None,
  64. download: bool = False,
  65. loader: Optional[Callable[[Union[str, Path]], Any]] = None,
  66. ) -> None:
  67. self.version = verify_str_arg(version, "version", DATASET_URLS.keys())
  68. super().__init__(os.path.join(root, version), transform=transform, target_transform=target_transform)
  69. os.makedirs(root, exist_ok=True)
  70. if download:
  71. self.download()
  72. if not self._check_exists():
  73. raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
  74. self.all_categories: list[str] = []
  75. # map: category type -> name of category -> index
  76. self.categories_index: dict[str, dict[str, int]] = {}
  77. # list indexed by category id, containing mapping from category type -> index
  78. self.categories_map: list[dict[str, int]] = []
  79. if not isinstance(target_type, list):
  80. target_type = [target_type]
  81. if self.version[:4] == "2021":
  82. self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021)) for t in target_type]
  83. self._init_2021()
  84. else:
  85. self.target_type = [verify_str_arg(t, "target_type", ("full", "super")) for t in target_type]
  86. self._init_pre2021()
  87. # index of all files: (full category id, filename)
  88. self.index: list[tuple[int, str]] = []
  89. for dir_index, dir_name in enumerate(self.all_categories):
  90. files = os.listdir(os.path.join(self.root, dir_name))
  91. for fname in files:
  92. self.index.append((dir_index, fname))
  93. self.loader = loader
  94. def _init_2021(self) -> None:
  95. """Initialize based on 2021 layout"""
  96. self.all_categories = sorted(os.listdir(self.root))
  97. # map: category type -> name of category -> index
  98. self.categories_index = {k: {} for k in CATEGORIES_2021}
  99. for dir_index, dir_name in enumerate(self.all_categories):
  100. pieces = dir_name.split("_")
  101. if len(pieces) != 8:
  102. raise RuntimeError(f"Unexpected category name {dir_name}, wrong number of pieces")
  103. if pieces[0] != f"{dir_index:05d}":
  104. raise RuntimeError(f"Unexpected category id {pieces[0]}, expecting {dir_index:05d}")
  105. cat_map = {}
  106. for cat, name in zip(CATEGORIES_2021, pieces[1:7]):
  107. if name in self.categories_index[cat]:
  108. cat_id = self.categories_index[cat][name]
  109. else:
  110. cat_id = len(self.categories_index[cat])
  111. self.categories_index[cat][name] = cat_id
  112. cat_map[cat] = cat_id
  113. self.categories_map.append(cat_map)
  114. def _init_pre2021(self) -> None:
  115. """Initialize based on 2017-2019 layout"""
  116. # map: category type -> name of category -> index
  117. self.categories_index = {"super": {}}
  118. cat_index = 0
  119. super_categories = sorted(os.listdir(self.root))
  120. for sindex, scat in enumerate(super_categories):
  121. self.categories_index["super"][scat] = sindex
  122. subcategories = sorted(os.listdir(os.path.join(self.root, scat)))
  123. for subcat in subcategories:
  124. if self.version == "2017":
  125. # this version does not use ids as directory names
  126. subcat_i = cat_index
  127. cat_index += 1
  128. else:
  129. try:
  130. subcat_i = int(subcat)
  131. except ValueError:
  132. raise RuntimeError(f"Unexpected non-numeric dir name: {subcat}")
  133. if subcat_i >= len(self.categories_map):
  134. old_len = len(self.categories_map)
  135. self.categories_map.extend([{}] * (subcat_i - old_len + 1))
  136. self.all_categories.extend([""] * (subcat_i - old_len + 1))
  137. if self.categories_map[subcat_i]:
  138. raise RuntimeError(f"Duplicate category {subcat}")
  139. self.categories_map[subcat_i] = {"super": sindex}
  140. self.all_categories[subcat_i] = os.path.join(scat, subcat)
  141. # validate the dictionary
  142. for cindex, c in enumerate(self.categories_map):
  143. if not c:
  144. raise RuntimeError(f"Missing category {cindex}")
  145. def __getitem__(self, index: int) -> tuple[Any, Any]:
  146. """
  147. Args:
  148. index (int): Index
  149. Returns:
  150. tuple: (image, target) where the type of target specified by target_type.
  151. """
  152. cat_id, fname = self.index[index]
  153. image_path = os.path.join(self.root, self.all_categories[cat_id], fname)
  154. img = self.loader(image_path) if self.loader is not None else Image.open(image_path)
  155. target: Any = []
  156. for t in self.target_type:
  157. if t == "full":
  158. target.append(cat_id)
  159. else:
  160. target.append(self.categories_map[cat_id][t])
  161. target = tuple(target) if len(target) > 1 else target[0]
  162. if self.transform is not None:
  163. img = self.transform(img)
  164. if self.target_transform is not None:
  165. target = self.target_transform(target)
  166. return img, target
  167. def __len__(self) -> int:
  168. return len(self.index)
  169. def category_name(self, category_type: str, category_id: int) -> str:
  170. """
  171. Args:
  172. category_type(str): one of "full", "kingdom", "phylum", "class", "order", "family", "genus" or "super"
  173. category_id(int): an index (class id) from this category
  174. Returns:
  175. the name of the category
  176. """
  177. if category_type == "full":
  178. return self.all_categories[category_id]
  179. else:
  180. if category_type not in self.categories_index:
  181. raise ValueError(f"Invalid category type '{category_type}'")
  182. else:
  183. for name, id in self.categories_index[category_type].items():
  184. if id == category_id:
  185. return name
  186. raise ValueError(f"Invalid category id {category_id} for {category_type}")
  187. def _check_exists(self) -> bool:
  188. return os.path.exists(self.root) and len(os.listdir(self.root)) > 0
  189. def download(self) -> None:
  190. if self._check_exists():
  191. return
  192. base_root = os.path.dirname(self.root)
  193. download_and_extract_archive(
  194. DATASET_URLS[self.version], base_root, filename=f"{self.version}.tgz", md5=DATASET_MD5[self.version]
  195. )
  196. orig_dir_name = os.path.join(base_root, os.path.basename(DATASET_URLS[self.version]).rstrip(".tar.gz"))
  197. if not os.path.exists(orig_dir_name):
  198. raise RuntimeError(f"Unable to find downloaded files at {orig_dir_name}")
  199. os.rename(orig_dir_name, self.root)