- # Copyright The Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from collections.abc import Sequence
- from typing import Any, ClassVar, Optional, Union
- from torch import Tensor
- from typing_extensions import Literal
- from torchmetrics.functional.image.lpips import _LPIPS, _lpips_compute, _lpips_update, _NoTrainLpips
- from torchmetrics.metric import Metric
- from torchmetrics.utilities import dim_zero_cat
- from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout
- from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE
- from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
- if not _MATPLOTLIB_AVAILABLE:
- __doctest_skip__ = ["LearnedPerceptualImagePatchSimilarity.plot"]
- if _TORCHVISION_AVAILABLE:
- def _download_lpips() -> None:
- _LPIPS(pretrained=True, net="vgg")
- if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_lpips):
- __doctest_skip__ = ["LearnedPerceptualImagePatchSimilarity", "LearnedPerceptualImagePatchSimilarity.plot"]
- else:
- __doctest_skip__ = ["LearnedPerceptualImagePatchSimilarity", "LearnedPerceptualImagePatchSimilarity.plot"]
- class LearnedPerceptualImagePatchSimilarity(Metric):
- """The Learned Perceptual Image Patch Similarity (`LPIPS_`) calculates perceptual similarity between two images.
- LPIPS essentially computes the similarity between the activations of two image patches for some pre-defined network.
- This measure has been shown to match human perception well. A low LPIPS score means that image patches are
- perceptually similar.
- Both input image patches are expected to have shape ``(N, 3, H, W)``. The minimum size of `H, W` depends on the
- chosen backbone (see `net_type` arg).
- .. hint::
- Using this metric requires the ``torchvision`` package to be installed. Either install as
- ``pip install torchmetrics[image]`` or ``pip install torchvision``.
- As input to ``forward`` and ``update`` the metric accepts the following input:
- - ``img1`` (:class:`~torch.Tensor`): tensor with images of shape ``(N, 3, H, W)``
- - ``img2`` (:class:`~torch.Tensor`): tensor with images of shape ``(N, 3, H, W)``
- As output of ``forward`` and ``compute`` the metric returns the following output:
- - ``lpips`` (:class:`~torch.Tensor`): float scalar tensor with the LPIPS value reduced over samples according to ``reduction``, or a tensor with one value per sample if ``reduction`` is ``'none'`` or ``None``
- Args:
- net_type: str indicating backbone network type to use. Choose between `'alex'`, `'vgg'` or `'squeeze'`
- reduction: str indicating how to reduce over the batch dimension. Choose between `'sum'`, `'mean'`, `'none'`
- or `None`.
- normalize: by default this is ``False``, meaning that the input is expected to be in the ``[-1,1]`` range. If set
- to ``True``, the input is instead expected to be in the ``[0,1]`` range.
- kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
- Raises:
- ModuleNotFoundError:
- If ``torchvision`` package is not installed
- ValueError:
- If ``net_type`` is not one of ``"vgg"``, ``"alex"`` or ``"squeeze"``
- ValueError:
- If ``reduction`` is not one of ``"mean"``, ``"sum"``, ``"none"`` or ``None``
- ValueError:
- If ``normalize`` is not a bool
- Example:
- >>> from torch import rand
- >>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
- >>> lpips = LearnedPerceptualImagePatchSimilarity(net_type='squeeze')
- >>> # LPIPS needs the images to be in the [-1, 1] range.
- >>> img1 = (rand(10, 3, 100, 100) * 2) - 1
- >>> img2 = (rand(10, 3, 100, 100) * 2) - 1
- >>> lpips(img1, img2)
- tensor(0.1024)
- >>> from torch import rand, Generator
- >>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
- >>> gen = Generator().manual_seed(42)
- >>> lpips = LearnedPerceptualImagePatchSimilarity(net_type='squeeze', reduction='none')
- >>> # LPIPS needs the images to be in the [-1, 1] range.
- >>> img1 = (rand(2, 3, 100, 100, generator=gen) * 2) - 1
- >>> img2 = (rand(2, 3, 100, 100, generator=gen) * 2) - 1
- >>> lpips(img1, img2)
- tensor([0.1024, 0.0938])
- """
- is_differentiable: bool = True
- higher_is_better: bool = False
- full_state_update: bool = False
- plot_lower_bound: float = 0.0
- plot_upper_bound: float = 1.0
- all_scores: list[Tensor]
- feature_network: str = "net"
- # due to the use of named tuple in the backbone the net variable cannot be scripted
- __jit_ignored_attributes__: ClassVar[list[str]] = ["net"]
- def __init__(
- self,
- net_type: Literal["vgg", "alex", "squeeze"] = "alex",
- reduction: Optional[Literal["sum", "mean", "none"]] = "mean",
- normalize: bool = False,
- **kwargs: Any,
- ) -> None:
- super().__init__(**kwargs)
- if not _TORCHVISION_AVAILABLE:
- raise ModuleNotFoundError(
- "LPIPS metric requires that torchvision is installed."
- " Either install as `pip install torchmetrics[image]` or `pip install torchvision`."
- )
- valid_net_type = ("vgg", "alex", "squeeze")
- if net_type not in valid_net_type:
- raise ValueError(f"Argument `net_type` must be one of {valid_net_type}, but got {net_type}.")
- self.net = _NoTrainLpips(net=net_type)
- valid_reduction = ("mean", "sum", "none", None)
- if reduction not in valid_reduction:
- raise ValueError(f"Argument `reduction` must be one of {valid_reduction}, but got {reduction}")
- self.reduction = reduction
- if not isinstance(normalize, bool):
- raise ValueError(f"Argument `normalize` should be a bool but got {normalize}")
- self.normalize = normalize
- self.add_state("all_scores", default=[], dist_reduce_fx=None)
- def update(self, img1: Tensor, img2: Tensor) -> None:
- """Update internal states with lpips score."""
- loss = _lpips_update(img1, img2, net=self.net, normalize=self.normalize)
- self.all_scores.append(loss)
- def compute(self) -> Tensor:
- """Compute final perceptual similarity metric."""
- scores = dim_zero_cat(self.all_scores)
- return _lpips_compute(scores, reduction=self.reduction)
- def plot(
- self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
- ) -> _PLOT_OUT_TYPE:
- """Plot a single or multiple values from the metric.
- Args:
- val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
- If no value is provided, will automatically call `metric.compute` and plot that result.
- ax: A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
- Figure and Axes object
- Raises:
- ModuleNotFoundError:
- If `matplotlib` is not installed
- .. plot::
- :scale: 75
- >>> # Example plotting a single value
- >>> import torch
- >>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
- >>> metric = LearnedPerceptualImagePatchSimilarity(net_type='squeeze')
- >>> metric.update(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100))
- >>> fig_, ax_ = metric.plot()
- .. plot::
- :scale: 75
- >>> # Example plotting multiple values
- >>> import torch
- >>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
- >>> metric = LearnedPerceptualImagePatchSimilarity(net_type='squeeze')
- >>> values = []
- >>> for _ in range(3):
- ... values.append(metric(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100)))
- >>> fig_, ax_ = metric.plot(values)
- """
- return self._plot(val, ax)
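
The `update` and `compute` methods above follow an accumulate-then-reduce pattern: each call to `update` appends a batch of scores to the `all_scores` list state, and `compute` concatenates the batches before applying `reduction`. The following is a minimal pure-Python sketch of that pattern, not the library's implementation. The class name `ScoreAccumulator` is hypothetical, plain lists of floats stand in for tensors, and it assumes that each batch contributes one score per sample.

```python
from typing import List, Optional, Union


class ScoreAccumulator:
    """Hypothetical stand-in for the metric's list state and reduction step."""

    def __init__(self, reduction: Optional[str] = "mean") -> None:
        # Same set of accepted values as the `reduction` check in the class above.
        if reduction not in ("mean", "sum", "none", None):
            raise ValueError(f"Unsupported reduction: {reduction}")
        self.reduction = reduction
        self.all_scores: List[List[float]] = []

    def update(self, batch_scores: List[float]) -> None:
        # Store the scores of one batch (assumed: one score per sample).
        self.all_scores.append(list(batch_scores))

    def compute(self) -> Union[float, List[float]]:
        # Concatenate all batches first, then reduce over every sample seen.
        flat = [score for batch in self.all_scores for score in batch]
        if self.reduction == "mean":
            return sum(flat) / len(flat)
        if self.reduction == "sum":
            return sum(flat)
        return flat  # "none" or None: keep one score per sample


acc = ScoreAccumulator(reduction="mean")
acc.update([0.1, 0.3])  # batch of two samples
acc.update([0.8])       # batch of one sample
overall = acc.compute()

per_sample = ScoreAccumulator(reduction="none")
per_sample.update([0.1, 0.3])
per_sample.update([0.8])
kept = per_sample.compute()
```

Because the reduction runs over the concatenated scores, `overall` is the mean over all three samples rather than the mean of the two batch means, so batches of different sizes are weighted by their sample count.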