# coco.py (4.2 KB)
import operator
import os.path
from pathlib import Path
from typing import Any, Callable, Optional, Union

from PIL import Image

from .vision import VisionDataset
  6. class CocoDetection(VisionDataset):
  7. """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.
  8. It requires `pycocotools <https://github.com/ppwwyyxx/cocoapi>`_ to be installed,
  9. which could be installed via ``pip install pycocotools`` or ``conda install conda-forge::pycocotools``.
  10. Args:
  11. root (str or ``pathlib.Path``): Root directory where images are downloaded to.
  12. annFile (string): Path to json annotation file.
  13. transform (callable, optional): A function/transform that takes in a PIL image
  14. and returns a transformed version. E.g, ``transforms.PILToTensor``
  15. target_transform (callable, optional): A function/transform that takes in the
  16. target and transforms it.
  17. transforms (callable, optional): A function/transform that takes input sample and its target as entry
  18. and returns a transformed version.
  19. """
  20. def __init__(
  21. self,
  22. root: Union[str, Path],
  23. annFile: str,
  24. transform: Optional[Callable] = None,
  25. target_transform: Optional[Callable] = None,
  26. transforms: Optional[Callable] = None,
  27. ) -> None:
  28. super().__init__(root, transforms, transform, target_transform)
  29. from pycocotools.coco import COCO
  30. self.coco = COCO(annFile)
  31. self.ids = list(sorted(self.coco.imgs.keys()))
  32. def _load_image(self, id: int) -> Image.Image:
  33. path = self.coco.loadImgs(id)[0]["file_name"]
  34. return Image.open(os.path.join(self.root, path)).convert("RGB")
  35. def _load_target(self, id: int) -> list[Any]:
  36. return self.coco.loadAnns(self.coco.getAnnIds(id))
  37. def __getitem__(self, index: int) -> tuple[Any, Any]:
  38. if not isinstance(index, int):
  39. raise ValueError(f"Index must be of type integer, got {type(index)} instead.")
  40. id = self.ids[index]
  41. image = self._load_image(id)
  42. target = self._load_target(id)
  43. if self.transforms is not None:
  44. image, target = self.transforms(image, target)
  45. return image, target
  46. def __len__(self) -> int:
  47. return len(self.ids)
  48. class CocoCaptions(CocoDetection):
  49. """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.
  50. It requires `pycocotools <https://github.com/ppwwyyxx/cocoapi>`_ to be installed,
  51. which could be installed via ``pip install pycocotools`` or ``conda install conda-forge::pycocotools``.
  52. Args:
  53. root (str or ``pathlib.Path``): Root directory where images are downloaded to.
  54. annFile (string): Path to json annotation file.
  55. transform (callable, optional): A function/transform that takes in a PIL image
  56. and returns a transformed version. E.g, ``transforms.PILToTensor``
  57. target_transform (callable, optional): A function/transform that takes in the
  58. target and transforms it.
  59. transforms (callable, optional): A function/transform that takes input sample and its target as entry
  60. and returns a transformed version.
  61. Example:
  62. .. code:: python
  63. import torchvision.datasets as dset
  64. import torchvision.transforms as transforms
  65. cap = dset.CocoCaptions(root = 'dir where images are',
  66. annFile = 'json annotation file',
  67. transform=transforms.PILToTensor())
  68. print('Number of samples: ', len(cap))
  69. img, target = cap[3] # load 4th sample
  70. print("Image Size: ", img.size())
  71. print(target)
  72. Output: ::
  73. Number of samples: 82783
  74. Image Size: (3L, 427L, 640L)
  75. [u'A plane emitting smoke stream flying over a mountain.',
  76. u'A plane darts across a bright blue sky behind a mountain covered in snow',
  77. u'A plane leaves a contrail above the snowy mountain top.',
  78. u'A mountain that has a plane flying overheard in the distance.',
  79. u'A mountain view with a plume of smoke in the background']
  80. """
  81. def _load_target(self, id: int) -> list[str]:
  82. return [ann["caption"] for ann in super()._load_target(id)]