sun397.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from pathlib import Path
  2. from typing import Any, Callable, Optional, Union
  3. from .folder import default_loader
  4. from .utils import download_and_extract_archive
  5. from .vision import VisionDataset
  6. class SUN397(VisionDataset):
  7. """`The SUN397 Data Set <https://vision.princeton.edu/projects/2010/SUN/>`_.
  8. The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of
  9. 397 categories with 108'754 images.
  10. Args:
  11. root (str or ``pathlib.Path``): Root directory of the dataset.
  12. transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
  13. and returns a transformed version. E.g, ``transforms.RandomCrop``
  14. target_transform (callable, optional): A function/transform that takes in the target and transforms it.
  15. download (bool, optional): If true, downloads the dataset from the internet and
  16. puts it in root directory. If dataset is already downloaded, it is not
  17. downloaded again.
  18. loader (callable, optional): A function to load an image given its path.
  19. By default, it uses PIL as its image loader, but users could also pass in
  20. ``torchvision.io.decode_image`` for decoding image data into tensors directly.
  21. """
  22. _DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz"
  23. _DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a"
  24. def __init__(
  25. self,
  26. root: Union[str, Path],
  27. transform: Optional[Callable] = None,
  28. target_transform: Optional[Callable] = None,
  29. download: bool = False,
  30. loader: Callable[[Union[str, Path]], Any] = default_loader,
  31. ) -> None:
  32. super().__init__(root, transform=transform, target_transform=target_transform)
  33. self._data_dir = Path(self.root) / "SUN397"
  34. if download:
  35. self._download()
  36. if not self._check_exists():
  37. raise RuntimeError("Dataset not found. You can use download=True to download it")
  38. with open(self._data_dir / "ClassName.txt") as f:
  39. self.classes = [c[3:].strip() for c in f]
  40. self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
  41. self._image_files = list(self._data_dir.rglob("sun_*.jpg"))
  42. self._labels = [
  43. self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files
  44. ]
  45. self.loader = loader
  46. def __len__(self) -> int:
  47. return len(self._image_files)
  48. def __getitem__(self, idx: int) -> tuple[Any, Any]:
  49. image_file, label = self._image_files[idx], self._labels[idx]
  50. image = self.loader(image_file)
  51. if self.transform:
  52. image = self.transform(image)
  53. if self.target_transform:
  54. label = self.target_transform(label)
  55. return image, label
  56. def _check_exists(self) -> bool:
  57. return self._data_dir.is_dir()
  58. def _download(self) -> None:
  59. if self._check_exists():
  60. return
  61. download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5)