lpip.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, ClassVar, Optional, Union
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.functional.image.lpips import _LPIPS, _lpips_compute, _lpips_update, _NoTrainLpips
  19. from torchmetrics.metric import Metric
  20. from torchmetrics.utilities import dim_zero_cat
  21. from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout
  22. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE
  23. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  # The ``plot`` doctest needs matplotlib; skip it when the package is missing.
  if not _MATPLOTLIB_AVAILABLE:
      __doctest_skip__ = ["LearnedPerceptualImagePatchSimilarity.plot"]

  if not _TORCHVISION_AVAILABLE:
      # Without torchvision the metric cannot be constructed, so skip every doctest of the class.
      __doctest_skip__ = [
          "LearnedPerceptualImagePatchSimilarity",
          "LearnedPerceptualImagePatchSimilarity.plot",
      ]
  else:

      def _download_lpips() -> None:
          """Instantiate the pretrained VGG-based ``_LPIPS`` network (presumably fetching its weights)."""
          _LPIPS(pretrained=True, net="vgg")

      # Only probe the model when slow doctests should be skipped; if the probe does not
      # succeed (per ``_try_proceed_with_timeout``), skip all doctests of the class.
      if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_lpips):
          __doctest_skip__ = [
              "LearnedPerceptualImagePatchSimilarity",
              "LearnedPerceptualImagePatchSimilarity.plot",
          ]
  class LearnedPerceptualImagePatchSimilarity(Metric):
      """The Learned Perceptual Image Patch Similarity (`LPIPS_`) calculates perceptual similarity between two images.

      LPIPS essentially computes the similarity between the activations of two image patches for some pre-defined network.
      This measure has been shown to match human perception well. A low LPIPS score means that image patches are
      perceptually similar.

      Both input image patches are expected to have shape ``(N, 3, H, W)``. The minimum size of `H, W` depends on the
      chosen backbone (see `net_type` arg).

      .. hint::
          Using this metric requires you to have ``torchvision`` package installed. Either install as
          ``pip install torchmetrics[image]`` or ``pip install torchvision``.

      As input to ``forward`` and ``update`` the metric accepts the following input

      - ``img1`` (:class:`~torch.Tensor`): tensor with images of shape ``(N, 3, H, W)``
      - ``img2`` (:class:`~torch.Tensor`): tensor with images of shape ``(N, 3, H, W)``

      As output of `forward` and `compute` the metric returns the following output

      - ``lpips`` (:class:`~torch.Tensor`): LPIPS score reduced over samples according to ``reduction``
        (a float scalar tensor with the average LPIPS value for the default ``'mean'``)

      Args:
          net_type: str indicating backbone network type to use. Choose between `'alex'`, `'vgg'` or `'squeeze'`
          reduction: str indicating how to reduce over the batch dimension. Choose between `'sum'`, `'mean'`, `'none'`
              or `None`.
          normalize: by default this is ``False`` meaning that the input is expected to be in the [-1,1] range. If set
              to ``True`` will instead expect input to be in the ``[0,1]`` range.
          kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

      Raises:
          ModuleNotFoundError:
              If ``torchvision`` package is not installed
          ValueError:
              If ``net_type`` is not one of ``"vgg"``, ``"alex"`` or ``"squeeze"``
          ValueError:
              If ``reduction`` is not one of ``"mean"``, ``"sum"``, ``"none"`` or ``None``
          ValueError:
              If ``normalize`` is not a bool

      Example:
          >>> from torch import rand
          >>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
          >>> lpips = LearnedPerceptualImagePatchSimilarity(net_type='squeeze')
          >>> # LPIPS needs the images to be in the [-1, 1] range.
          >>> img1 = (rand(10, 3, 100, 100) * 2) - 1
          >>> img2 = (rand(10, 3, 100, 100) * 2) - 1
          >>> lpips(img1, img2)
          tensor(0.1024)

          >>> from torch import rand, Generator
          >>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
          >>> gen = Generator().manual_seed(42)
          >>> lpips = LearnedPerceptualImagePatchSimilarity(net_type='squeeze', reduction='none')
          >>> # LPIPS needs the images to be in the [-1, 1] range.
          >>> img1 = (rand(2, 3, 100, 100, generator=gen) * 2) - 1
          >>> img2 = (rand(2, 3, 100, 100, generator=gen) * 2) - 1
          >>> lpips(img1, img2)
          tensor([0.1024, 0.0938])

      """

      is_differentiable: bool = True
      higher_is_better: bool = False
      full_state_update: bool = False
      plot_lower_bound: float = 0.0
      plot_upper_bound: float = 1.0

      # list state holding the scores appended by every ``update`` call
      all_scores: list[Tensor]
      # name of the attribute that stores the backbone network
      feature_network: str = "net"

      # due to the use of named tuple in the backbone the net variable cannot be scripted
      __jit_ignored_attributes__: ClassVar[list[str]] = ["net"]

      def __init__(
          self,
          net_type: Literal["vgg", "alex", "squeeze"] = "alex",
          reduction: Optional[Literal["sum", "mean", "none"]] = "mean",
          normalize: bool = False,
          **kwargs: Any,
      ) -> None:
          super().__init__(**kwargs)

          if not _TORCHVISION_AVAILABLE:
              raise ModuleNotFoundError(
                  "LPIPS metric requires that torchvision is installed."
                  " Either install as `pip install torchmetrics[image]` or `pip install torchvision`."
              )

          # Validate every cheap argument first so that an invalid configuration is rejected
          # before the backbone network is instantiated.
          valid_net_type = ("vgg", "alex", "squeeze")
          if net_type not in valid_net_type:
              raise ValueError(f"Argument `net_type` must be one of {valid_net_type}, but got {net_type}.")

          valid_reduction = ("mean", "sum", "none", None)
          if reduction not in valid_reduction:
              raise ValueError(f"Argument `reduction` must be one of {valid_reduction}, but got {reduction}")

          if not isinstance(normalize, bool):
              raise ValueError(f"Argument `normalize` should be an bool but got {normalize}")

          self.net = _NoTrainLpips(net=net_type)
          self.reduction = reduction
          self.normalize = normalize

          # scores are kept as a list and concatenated in ``compute``
          self.add_state("all_scores", default=[], dist_reduce_fx=None)

      def update(self, img1: Tensor, img2: Tensor) -> None:
          """Update internal states with lpips score.

          The result of ``_lpips_update`` for the given image pair is appended to ``all_scores``.
          """
          loss = _lpips_update(img1, img2, net=self.net, normalize=self.normalize)
          self.all_scores.append(loss)

      def compute(self) -> Tensor:
          """Compute final perceptual similarity metric.

          All stored scores are concatenated and reduced by ``_lpips_compute`` using ``self.reduction``.
          """
          scores = dim_zero_cat(self.all_scores)
          return _lpips_compute(scores, reduction=self.reduction)

      def plot(
          self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
      ) -> _PLOT_OUT_TYPE:
          """Plot a single or multiple values from the metric.

          Args:
              val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                  If no value is provided, will automatically call `metric.compute` and plot that result.
              ax: A matplotlib axis object. If provided will add plot to that axis

          Returns:
              Figure and Axes object

          Raises:
              ModuleNotFoundError:
                  If `matplotlib` is not installed

          .. plot::
              :scale: 75

              >>> # Example plotting a single value
              >>> import torch
              >>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
              >>> metric = LearnedPerceptualImagePatchSimilarity(net_type='squeeze')
              >>> metric.update(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100))
              >>> fig_, ax_ = metric.plot()

          .. plot::
              :scale: 75

              >>> # Example plotting multiple values
              >>> import torch
              >>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
              >>> metric = LearnedPerceptualImagePatchSimilarity(net_type='squeeze')
              >>> values = [ ]
              >>> for _ in range(3):
              ...     values.append(metric(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100)))
              >>> fig_, ax_ = metric.plot(values)

          """
          return self._plot(val, ax)