sbu.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. import os
  2. from pathlib import Path
  3. from typing import Any, Callable, Optional, Union
  4. from .folder import default_loader
  5. from .utils import check_integrity, download_and_extract_archive, download_url
  6. from .vision import VisionDataset
  7. class SBU(VisionDataset):
  8. """`SBU Captioned Photo <http://www.cs.virginia.edu/~vicente/sbucaptions/>`_ Dataset.
  9. Args:
  10. root (str or ``pathlib.Path``): Root directory of dataset where tarball
  11. ``SBUCaptionedPhotoDataset.tar.gz`` exists.
  12. transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
  13. and returns a transformed version. E.g, ``transforms.RandomCrop``
  14. target_transform (callable, optional): A function/transform that takes in the
  15. target and transforms it.
  16. download (bool, optional): If True, downloads the dataset from the internet and
  17. puts it in root directory. If dataset is already downloaded, it is not
  18. downloaded again.
  19. loader (callable, optional): A function to load an image given its path.
  20. By default, it uses PIL as its image loader, but users could also pass in
  21. ``torchvision.io.decode_image`` for decoding image data into tensors directly.
  22. """
  23. url = "https://www.cs.rice.edu/~vo9/sbucaptions/SBUCaptionedPhotoDataset.tar.gz"
  24. filename = "SBUCaptionedPhotoDataset.tar.gz"
  25. md5_checksum = "9aec147b3488753cf758b4d493422285"
  26. def __init__(
  27. self,
  28. root: Union[str, Path],
  29. transform: Optional[Callable] = None,
  30. target_transform: Optional[Callable] = None,
  31. download: bool = True,
  32. loader: Callable[[str], Any] = default_loader,
  33. ) -> None:
  34. super().__init__(root, transform=transform, target_transform=target_transform)
  35. self.loader = loader
  36. if download:
  37. self.download()
  38. if not self._check_integrity():
  39. raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
  40. # Read the caption for each photo
  41. self.photos = []
  42. self.captions = []
  43. file1 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")
  44. file2 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_captions.txt")
  45. for line1, line2 in zip(open(file1), open(file2)):
  46. url = line1.rstrip()
  47. photo = os.path.basename(url)
  48. filename = os.path.join(self.root, "dataset", photo)
  49. if os.path.exists(filename):
  50. caption = line2.rstrip()
  51. self.photos.append(photo)
  52. self.captions.append(caption)
  53. def __getitem__(self, index: int) -> tuple[Any, Any]:
  54. """
  55. Args:
  56. index (int): Index
  57. Returns:
  58. tuple: (image, target) where target is a caption for the photo.
  59. """
  60. filename = os.path.join(self.root, "dataset", self.photos[index])
  61. img = self.loader(filename)
  62. if self.transform is not None:
  63. img = self.transform(img)
  64. target = self.captions[index]
  65. if self.target_transform is not None:
  66. target = self.target_transform(target)
  67. return img, target
  68. def __len__(self) -> int:
  69. """The number of photos in the dataset."""
  70. return len(self.photos)
  71. def _check_integrity(self) -> bool:
  72. """Check the md5 checksum of the downloaded tarball."""
  73. root = self.root
  74. fpath = os.path.join(root, self.filename)
  75. if not check_integrity(fpath, self.md5_checksum):
  76. return False
  77. return True
  78. def download(self) -> None:
  79. """Download and extract the tarball, and download each individual photo."""
  80. if self._check_integrity():
  81. return
  82. download_and_extract_archive(self.url, self.root, self.root, self.filename, self.md5_checksum)
  83. # Download individual photos
  84. with open(os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")) as fh:
  85. for line in fh:
  86. url = line.rstrip()
  87. try:
  88. download_url(url, os.path.join(self.root, "dataset"))
  89. except OSError:
  90. # The images point to public images on Flickr.
  91. # Note: Images might be removed by users at anytime.
  92. pass