voc.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. import collections
  2. import os
  3. from pathlib import Path
  4. from typing import Any, Callable, Optional, Union
  5. from xml.etree.ElementTree import Element as ET_Element
  6. try:
  7. from defusedxml.ElementTree import parse as ET_parse
  8. except ImportError:
  9. from xml.etree.ElementTree import parse as ET_parse
  10. from PIL import Image
  11. from .utils import download_and_extract_archive, verify_str_arg
  12. from .vision import VisionDataset
  13. DATASET_YEAR_DICT = {
  14. "2012": {
  15. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
  16. "filename": "VOCtrainval_11-May-2012.tar",
  17. "md5": "6cd6e144f989b92b3379bac3b3de84fd",
  18. "base_dir": os.path.join("VOCdevkit", "VOC2012"),
  19. },
  20. "2011": {
  21. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar",
  22. "filename": "VOCtrainval_25-May-2011.tar",
  23. "md5": "6c3384ef61512963050cb5d687e5bf1e",
  24. "base_dir": os.path.join("TrainVal", "VOCdevkit", "VOC2011"),
  25. },
  26. "2010": {
  27. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar",
  28. "filename": "VOCtrainval_03-May-2010.tar",
  29. "md5": "da459979d0c395079b5c75ee67908abb",
  30. "base_dir": os.path.join("VOCdevkit", "VOC2010"),
  31. },
  32. "2009": {
  33. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar",
  34. "filename": "VOCtrainval_11-May-2009.tar",
  35. "md5": "a3e00b113cfcfebf17e343f59da3caa1",
  36. "base_dir": os.path.join("VOCdevkit", "VOC2009"),
  37. },
  38. "2008": {
  39. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar",
  40. "filename": "VOCtrainval_11-May-2012.tar",
  41. "md5": "2629fa636546599198acfcfbfcf1904a",
  42. "base_dir": os.path.join("VOCdevkit", "VOC2008"),
  43. },
  44. "2007": {
  45. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",
  46. "filename": "VOCtrainval_06-Nov-2007.tar",
  47. "md5": "c52e279531787c972589f7e41ab4ae64",
  48. "base_dir": os.path.join("VOCdevkit", "VOC2007"),
  49. },
  50. "2007-test": {
  51. "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar",
  52. "filename": "VOCtest_06-Nov-2007.tar",
  53. "md5": "b6e924de25625d8de591ea690078ad9f",
  54. "base_dir": os.path.join("VOCdevkit", "VOC2007"),
  55. },
  56. }
  57. class _VOCBase(VisionDataset):
  58. _SPLITS_DIR: str
  59. _TARGET_DIR: str
  60. _TARGET_FILE_EXT: str
  61. def __init__(
  62. self,
  63. root: Union[str, Path],
  64. year: str = "2012",
  65. image_set: str = "train",
  66. download: bool = False,
  67. transform: Optional[Callable] = None,
  68. target_transform: Optional[Callable] = None,
  69. transforms: Optional[Callable] = None,
  70. ):
  71. super().__init__(root, transforms, transform, target_transform)
  72. self.year = verify_str_arg(year, "year", valid_values=[str(yr) for yr in range(2007, 2013)])
  73. valid_image_sets = ["train", "trainval", "val"]
  74. if year == "2007":
  75. valid_image_sets.append("test")
  76. self.image_set = verify_str_arg(image_set, "image_set", valid_image_sets)
  77. key = "2007-test" if year == "2007" and image_set == "test" else year
  78. dataset_year_dict = DATASET_YEAR_DICT[key]
  79. self.url = dataset_year_dict["url"]
  80. self.filename = dataset_year_dict["filename"]
  81. self.md5 = dataset_year_dict["md5"]
  82. base_dir = dataset_year_dict["base_dir"]
  83. voc_root = os.path.join(self.root, base_dir)
  84. if download:
  85. download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
  86. if not os.path.isdir(voc_root):
  87. raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
  88. splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR)
  89. split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
  90. with open(os.path.join(split_f)) as f:
  91. file_names = [x.strip() for x in f.readlines()]
  92. image_dir = os.path.join(voc_root, "JPEGImages")
  93. self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
  94. target_dir = os.path.join(voc_root, self._TARGET_DIR)
  95. self.targets = [os.path.join(target_dir, x + self._TARGET_FILE_EXT) for x in file_names]
  96. assert len(self.images) == len(self.targets)
  97. def __len__(self) -> int:
  98. return len(self.images)
  99. class VOCSegmentation(_VOCBase):
  100. """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
  101. Args:
  102. root (str or ``pathlib.Path``): Root directory of the VOC Dataset.
  103. year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
  104. image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
  105. ``year=="2007"``, can also be ``"test"``.
  106. download (bool, optional): If true, downloads the dataset from the internet and
  107. puts it in root directory. If dataset is already downloaded, it is not
  108. downloaded again.
  109. transform (callable, optional): A function/transform that takes in a PIL image
  110. and returns a transformed version. E.g, ``transforms.RandomCrop``
  111. target_transform (callable, optional): A function/transform that takes in the
  112. target and transforms it.
  113. transforms (callable, optional): A function/transform that takes input sample and its target as entry
  114. and returns a transformed version.
  115. """
  116. _SPLITS_DIR = "Segmentation"
  117. _TARGET_DIR = "SegmentationClass"
  118. _TARGET_FILE_EXT = ".png"
  119. @property
  120. def masks(self) -> list[str]:
  121. return self.targets
  122. def __getitem__(self, index: int) -> tuple[Any, Any]:
  123. """
  124. Args:
  125. index (int): Index
  126. Returns:
  127. tuple: (image, target) where target is the image segmentation.
  128. """
  129. img = Image.open(self.images[index]).convert("RGB")
  130. target = Image.open(self.masks[index])
  131. if self.transforms is not None:
  132. img, target = self.transforms(img, target)
  133. return img, target
  134. class VOCDetection(_VOCBase):
  135. """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.
  136. Args:
  137. root (str or ``pathlib.Path``): Root directory of the VOC Dataset.
  138. year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
  139. image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
  140. ``year=="2007"``, can also be ``"test"``.
  141. download (bool, optional): If true, downloads the dataset from the internet and
  142. puts it in root directory. If dataset is already downloaded, it is not
  143. downloaded again.
  144. (default: alphabetic indexing of VOC's 20 classes).
  145. transform (callable, optional): A function/transform that takes in a PIL image
  146. and returns a transformed version. E.g, ``transforms.RandomCrop``
  147. target_transform (callable, required): A function/transform that takes in the
  148. target and transforms it.
  149. transforms (callable, optional): A function/transform that takes input sample and its target as entry
  150. and returns a transformed version.
  151. """
  152. _SPLITS_DIR = "Main"
  153. _TARGET_DIR = "Annotations"
  154. _TARGET_FILE_EXT = ".xml"
  155. @property
  156. def annotations(self) -> list[str]:
  157. return self.targets
  158. def __getitem__(self, index: int) -> tuple[Any, Any]:
  159. """
  160. Args:
  161. index (int): Index
  162. Returns:
  163. tuple: (image, target) where target is a dictionary of the XML tree.
  164. """
  165. img = Image.open(self.images[index]).convert("RGB")
  166. target = self.parse_voc_xml(ET_parse(self.annotations[index]).getroot())
  167. if self.transforms is not None:
  168. img, target = self.transforms(img, target)
  169. return img, target
  170. @staticmethod
  171. def parse_voc_xml(node: ET_Element) -> dict[str, Any]:
  172. voc_dict: dict[str, Any] = {}
  173. children = list(node)
  174. if children:
  175. def_dic: dict[str, Any] = collections.defaultdict(list)
  176. for dc in map(VOCDetection.parse_voc_xml, children):
  177. for ind, v in dc.items():
  178. def_dic[ind].append(v)
  179. if node.tag == "annotation":
  180. def_dic["object"] = [def_dic["object"]]
  181. voc_dict = {node.tag: {ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()}}
  182. if node.text:
  183. text = node.text.strip()
  184. if not children:
  185. voc_dict[node.tag] = text
  186. return voc_dict