lfw.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. import os
  2. from pathlib import Path
  3. from typing import Any, Callable, Optional, Union
  4. from .folder import default_loader
  5. from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
  6. from .vision import VisionDataset
  7. class _LFW(VisionDataset):
  8. base_folder = "lfw-py"
  9. download_url_prefix = "http://vis-www.cs.umass.edu/lfw/"
  10. file_dict = {
  11. "original": ("lfw", "lfw.tgz", "a17d05bd522c52d84eca14327a23d494"),
  12. "funneled": ("lfw_funneled", "lfw-funneled.tgz", "1b42dfed7d15c9b2dd63d5e5840c86ad"),
  13. "deepfunneled": ("lfw-deepfunneled", "lfw-deepfunneled.tgz", "68331da3eb755a505a502b5aacb3c201"),
  14. }
  15. checksums = {
  16. "pairs.txt": "9f1ba174e4e1c508ff7cdf10ac338a7d",
  17. "pairsDevTest.txt": "5132f7440eb68cf58910c8a45a2ac10b",
  18. "pairsDevTrain.txt": "4f27cbf15b2da4a85c1907eb4181ad21",
  19. "people.txt": "450f0863dd89e85e73936a6d71a3474b",
  20. "peopleDevTest.txt": "e4bf5be0a43b5dcd9dc5ccfcb8fb19c5",
  21. "peopleDevTrain.txt": "54eaac34beb6d042ed3a7d883e247a21",
  22. "lfw-names.txt": "a6d0a479bd074669f656265a6e693f6d",
  23. }
  24. annot_file = {"10fold": "", "train": "DevTrain", "test": "DevTest"}
  25. names = "lfw-names.txt"
  26. def __init__(
  27. self,
  28. root: Union[str, Path],
  29. split: str,
  30. image_set: str,
  31. view: str,
  32. transform: Optional[Callable] = None,
  33. target_transform: Optional[Callable] = None,
  34. download: bool = False,
  35. loader: Callable[[str], Any] = default_loader,
  36. ) -> None:
  37. super().__init__(os.path.join(root, self.base_folder), transform=transform, target_transform=target_transform)
  38. self.image_set = verify_str_arg(image_set.lower(), "image_set", self.file_dict.keys())
  39. images_dir, self.filename, self.md5 = self.file_dict[self.image_set]
  40. self.view = verify_str_arg(view.lower(), "view", ["people", "pairs"])
  41. self.split = verify_str_arg(split.lower(), "split", ["10fold", "train", "test"])
  42. self.labels_file = f"{self.view}{self.annot_file[self.split]}.txt"
  43. self.data: list[Any] = []
  44. if download:
  45. raise ValueError(
  46. "LFW dataset is no longer available for download."
  47. "Please download the dataset manually and place it in the specified directory"
  48. )
  49. self.download()
  50. if not self._check_integrity():
  51. raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
  52. self.images_dir = os.path.join(self.root, images_dir)
  53. self._loader = loader
  54. def _check_integrity(self) -> bool:
  55. st1 = check_integrity(os.path.join(self.root, self.filename), self.md5)
  56. st2 = check_integrity(os.path.join(self.root, self.labels_file), self.checksums[self.labels_file])
  57. if not st1 or not st2:
  58. return False
  59. if self.view == "people":
  60. return check_integrity(os.path.join(self.root, self.names), self.checksums[self.names])
  61. return True
  62. def download(self) -> None:
  63. if self._check_integrity():
  64. return
  65. url = f"{self.download_url_prefix}{self.filename}"
  66. download_and_extract_archive(url, self.root, filename=self.filename, md5=self.md5)
  67. download_url(f"{self.download_url_prefix}{self.labels_file}", self.root)
  68. if self.view == "people":
  69. download_url(f"{self.download_url_prefix}{self.names}", self.root)
  70. def _get_path(self, identity: str, no: Union[int, str]) -> str:
  71. return os.path.join(self.images_dir, identity, f"{identity}_{int(no):04d}.jpg")
  72. def extra_repr(self) -> str:
  73. return f"Alignment: {self.image_set}\nSplit: {self.split}"
  74. def __len__(self) -> int:
  75. return len(self.data)
  76. class LFWPeople(_LFW):
  77. """`LFW <http://vis-www.cs.umass.edu/lfw/>`_ Dataset.
  78. .. warning:
  79. The LFW dataset is no longer available for automatic download. Please
  80. download it manually and place it in the specified directory.
  81. Args:
  82. root (str or ``pathlib.Path``): Root directory of dataset where directory
  83. ``lfw-py`` exists or will be saved to if download is set to True.
  84. split (string, optional): The image split to use. Can be one of ``train``, ``test``,
  85. ``10fold`` (default).
  86. image_set (str, optional): Type of image funneling to use, ``original``, ``funneled`` or
  87. ``deepfunneled``. Defaults to ``funneled``.
  88. transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
  89. and returns a transformed version. E.g, ``transforms.RandomCrop``
  90. target_transform (callable, optional): A function/transform that takes in the
  91. target and transforms it.
  92. download (bool, optional): NOT SUPPORTED ANYMORE, leave to False.
  93. loader (callable, optional): A function to load an image given its path.
  94. By default, it uses PIL as its image loader, but users could also pass in
  95. ``torchvision.io.decode_image`` for decoding image data into tensors directly.
  96. """
  97. def __init__(
  98. self,
  99. root: str,
  100. split: str = "10fold",
  101. image_set: str = "funneled",
  102. transform: Optional[Callable] = None,
  103. target_transform: Optional[Callable] = None,
  104. download: bool = False,
  105. loader: Callable[[str], Any] = default_loader,
  106. ) -> None:
  107. super().__init__(root, split, image_set, "people", transform, target_transform, download, loader=loader)
  108. self.class_to_idx = self._get_classes()
  109. self.data, self.targets = self._get_people()
  110. def _get_people(self) -> tuple[list[str], list[int]]:
  111. data, targets = [], []
  112. with open(os.path.join(self.root, self.labels_file)) as f:
  113. lines = f.readlines()
  114. n_folds, s = (int(lines[0]), 1) if self.split == "10fold" else (1, 0)
  115. for fold in range(n_folds):
  116. n_lines = int(lines[s])
  117. people = [line.strip().split("\t") for line in lines[s + 1 : s + n_lines + 1]]
  118. s += n_lines + 1
  119. for i, (identity, num_imgs) in enumerate(people):
  120. for num in range(1, int(num_imgs) + 1):
  121. img = self._get_path(identity, num)
  122. data.append(img)
  123. targets.append(self.class_to_idx[identity])
  124. return data, targets
  125. def _get_classes(self) -> dict[str, int]:
  126. with open(os.path.join(self.root, self.names)) as f:
  127. lines = f.readlines()
  128. names = [line.strip().split()[0] for line in lines]
  129. class_to_idx = {name: i for i, name in enumerate(names)}
  130. return class_to_idx
  131. def __getitem__(self, index: int) -> tuple[Any, Any]:
  132. """
  133. Args:
  134. index (int): Index
  135. Returns:
  136. tuple: Tuple (image, target) where target is the identity of the person.
  137. """
  138. img = self._loader(self.data[index])
  139. target = self.targets[index]
  140. if self.transform is not None:
  141. img = self.transform(img)
  142. if self.target_transform is not None:
  143. target = self.target_transform(target)
  144. return img, target
  145. def extra_repr(self) -> str:
  146. return super().extra_repr() + f"\nClasses (identities): {len(self.class_to_idx)}"
  147. class LFWPairs(_LFW):
  148. """`LFW <http://vis-www.cs.umass.edu/lfw/>`_ Dataset.
  149. .. warning:
  150. The LFW dataset is no longer available for automatic download. Please
  151. download it manually and place it in the specified directory.
  152. Args:
  153. root (str or ``pathlib.Path``): Root directory of dataset where directory
  154. ``lfw-py`` exists or will be saved to if download is set to True.
  155. split (string, optional): The image split to use. Can be one of ``train``, ``test``,
  156. ``10fold``. Defaults to ``10fold``.
  157. image_set (str, optional): Type of image funneling to use, ``original``, ``funneled`` or
  158. ``deepfunneled``. Defaults to ``funneled``.
  159. transform (callable, optional): A function/transform that takes in a PIL image
  160. and returns a transformed version. E.g, ``transforms.RandomRotation``
  161. target_transform (callable, optional): A function/transform that takes in the
  162. target and transforms it.
  163. download (bool, optional): NOT SUPPORTED ANYMORE, leave to False.
  164. loader (callable, optional): A function to load an image given its path.
  165. By default, it uses PIL as its image loader, but users could also pass in
  166. ``torchvision.io.decode_image`` for decoding image data into tensors directly.
  167. """
  168. def __init__(
  169. self,
  170. root: str,
  171. split: str = "10fold",
  172. image_set: str = "funneled",
  173. transform: Optional[Callable] = None,
  174. target_transform: Optional[Callable] = None,
  175. download: bool = False,
  176. loader: Callable[[str], Any] = default_loader,
  177. ) -> None:
  178. super().__init__(root, split, image_set, "pairs", transform, target_transform, download, loader=loader)
  179. self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir)
  180. def _get_pairs(self, images_dir: str) -> tuple[list[tuple[str, str]], list[tuple[str, str]], list[int]]:
  181. pair_names, data, targets = [], [], []
  182. with open(os.path.join(self.root, self.labels_file)) as f:
  183. lines = f.readlines()
  184. if self.split == "10fold":
  185. n_folds, n_pairs = lines[0].split("\t")
  186. n_folds, n_pairs = int(n_folds), int(n_pairs)
  187. else:
  188. n_folds, n_pairs = 1, int(lines[0])
  189. s = 1
  190. for fold in range(n_folds):
  191. matched_pairs = [line.strip().split("\t") for line in lines[s : s + n_pairs]]
  192. unmatched_pairs = [line.strip().split("\t") for line in lines[s + n_pairs : s + (2 * n_pairs)]]
  193. s += 2 * n_pairs
  194. for pair in matched_pairs:
  195. img1, img2, same = self._get_path(pair[0], pair[1]), self._get_path(pair[0], pair[2]), 1
  196. pair_names.append((pair[0], pair[0]))
  197. data.append((img1, img2))
  198. targets.append(same)
  199. for pair in unmatched_pairs:
  200. img1, img2, same = self._get_path(pair[0], pair[1]), self._get_path(pair[2], pair[3]), 0
  201. pair_names.append((pair[0], pair[2]))
  202. data.append((img1, img2))
  203. targets.append(same)
  204. return pair_names, data, targets
  205. def __getitem__(self, index: int) -> tuple[Any, Any, int]:
  206. """
  207. Args:
  208. index (int): Index
  209. Returns:
  210. tuple: (image1, image2, target) where target is `0` for different indentities and `1` for same identities.
  211. """
  212. img1, img2 = self.data[index]
  213. img1, img2 = self._loader(img1), self._loader(img2)
  214. target = self.targets[index]
  215. if self.transform is not None:
  216. img1, img2 = self.transform(img1), self.transform(img2)
  217. if self.target_transform is not None:
  218. target = self.target_transform(target)
  219. return img1, img2, target