vision.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import os
  2. from pathlib import Path
  3. from typing import Any, Callable, Optional, Union
  4. import torch.utils.data as data
  5. from ..utils import _log_api_usage_once
  6. class VisionDataset(data.Dataset):
  7. """
  8. Base Class For making datasets which are compatible with torchvision.
  9. It is necessary to override the ``__getitem__`` and ``__len__`` method.
  10. Args:
  11. root (string, optional): Root directory of dataset. Only used for `__repr__`.
  12. transforms (callable, optional): A function/transforms that takes in
  13. an image and a label and returns the transformed versions of both.
  14. transform (callable, optional): A function/transform that takes in a PIL image
  15. and returns a transformed version. E.g, ``transforms.RandomCrop``
  16. target_transform (callable, optional): A function/transform that takes in the
  17. target and transforms it.
  18. .. note::
  19. :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
  20. """
  21. _repr_indent = 4
  22. def __init__(
  23. self,
  24. root: Union[str, Path] = None, # type: ignore[assignment]
  25. transforms: Optional[Callable] = None,
  26. transform: Optional[Callable] = None,
  27. target_transform: Optional[Callable] = None,
  28. ) -> None:
  29. _log_api_usage_once(self)
  30. if isinstance(root, str):
  31. root = os.path.expanduser(root)
  32. self.root = root
  33. has_transforms = transforms is not None
  34. has_separate_transform = transform is not None or target_transform is not None
  35. if has_transforms and has_separate_transform:
  36. raise ValueError("Only transforms or transform/target_transform can be passed as argument")
  37. # for backwards-compatibility
  38. self.transform = transform
  39. self.target_transform = target_transform
  40. if has_separate_transform:
  41. transforms = StandardTransform(transform, target_transform)
  42. self.transforms = transforms
  43. def __getitem__(self, index: int) -> Any:
  44. """
  45. Args:
  46. index (int): Index
  47. Returns:
  48. (Any): Sample and meta data, optionally transformed by the respective transforms.
  49. """
  50. raise NotImplementedError
  51. def __len__(self) -> int:
  52. raise NotImplementedError
  53. def __repr__(self) -> str:
  54. head = "Dataset " + self.__class__.__name__
  55. body = [f"Number of datapoints: {self.__len__()}"]
  56. if self.root is not None:
  57. body.append(f"Root location: {self.root}")
  58. body += self.extra_repr().splitlines()
  59. if hasattr(self, "transforms") and self.transforms is not None:
  60. body += [repr(self.transforms)]
  61. lines = [head] + [" " * self._repr_indent + line for line in body]
  62. return "\n".join(lines)
  63. def _format_transform_repr(self, transform: Callable, head: str) -> list[str]:
  64. lines = transform.__repr__().splitlines()
  65. return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
  66. def extra_repr(self) -> str:
  67. return ""
  68. class StandardTransform:
  69. def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
  70. self.transform = transform
  71. self.target_transform = target_transform
  72. def __call__(self, input: Any, target: Any) -> tuple[Any, Any]:
  73. if self.transform is not None:
  74. input = self.transform(input)
  75. if self.target_transform is not None:
  76. target = self.target_transform(target)
  77. return input, target
  78. def _format_transform_repr(self, transform: Callable, head: str) -> list[str]:
  79. lines = transform.__repr__().splitlines()
  80. return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
  81. def __repr__(self) -> str:
  82. body = [self.__class__.__name__]
  83. if self.transform is not None:
  84. body += self._format_transform_repr(self.transform, "Transform: ")
  85. if self.target_transform is not None:
  86. body += self._format_transform_repr(self.target_transform, "Target transform: ")
  87. return "\n".join(body)