| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182 |
- # Copyright The Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import Any, Literal, Optional, Union
- from torch import Tensor, nn
- from torchmetrics.functional.image.lpips import _LPIPS
- from torchmetrics.functional.image.perceptual_path_length import (
- GeneratorType,
- _perceptual_path_length_validate_arguments,
- _validate_generator_model,
- perceptual_path_length,
- )
- from torchmetrics.metric import Metric
- from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE
# The class below raises ``ModuleNotFoundError`` on construction when torchvision is missing, so its
# docstring example cannot run in that environment; list it for skipping (presumably consumed by the
# project's doctest runner -- confirm against the test configuration).
if not _TORCHVISION_AVAILABLE:
    __doctest_skip__ = ["PerceptualPathLength"]
class PerceptualPathLength(Metric):
    r"""Computes the perceptual path length (`PPL`_) of a generator model.

    The perceptual path length can be used to measure the consistency of interpolation in latent-space models. It is
    defined as

    .. math::
        PPL = \mathbb{E}\left[\frac{1}{\epsilon^2} D(G(I(z_1, z_2, t)), G(I(z_1, z_2, t+\epsilon)))\right]

    where :math:`G` is the generator, :math:`I` is the interpolation function, :math:`D` is a similarity metric,
    :math:`z_1` and :math:`z_2` are two sets of latent points, and :math:`t` is a parameter between 0 and 1. The metric
    thus works by interpolating between two sets of latent points, and measuring the similarity between the generated
    images. The expectation is approximated by sampling :math:`z_1` and :math:`z_2` from the generator, and averaging
    the calculated distances. The similarity metric :math:`D` is by default the `LPIPS`_ metric, but can be changed by
    setting the `sim_net` argument.

    The provided generator model must have a `sample` method with signature `sample(num_samples: int) -> Tensor` where
    the returned tensor has shape `(num_samples, z_size)`. If the generator is conditional, it must also have a
    `num_classes` attribute. The `forward` method of the generator must have signature `forward(z: Tensor) -> Tensor`
    if `conditional=False`, and `forward(z: Tensor, labels: Tensor) -> Tensor` if `conditional=True`. The returned
    tensor should have shape `(num_samples, C, H, W)` and be scaled to the range [0, 255].

    .. hint::
        Using this metric with the default feature extractor requires that ``torchvision`` is installed.
        Either install as ``pip install torchmetrics[image]`` or ``pip install torchvision``

    As input to ``forward`` and ``update`` the metric accepts the following input

    - ``generator`` (:class:`~torch.nn.Module`): Generator model, with specific requirements. See above.

    As output of `forward` and `compute` the metric returns the following output

    - ``ppl_mean`` (:class:`~torch.Tensor`): float scalar tensor with mean PPL value over distances
    - ``ppl_std`` (:class:`~torch.Tensor`): float scalar tensor with std PPL value over distances
    - ``ppl_raw`` (:class:`~torch.Tensor`): float tensor with raw PPL distances

    Args:
        num_samples: Number of samples to use for the PPL computation.
        conditional: Whether the generator is conditional or not (i.e. whether it takes labels as input).
        batch_size: Batch size to use for the PPL computation.
        interpolation_method: Interpolation method to use. Choose from 'lerp', 'slerp_any', 'slerp_unit'.
        epsilon: Spacing between the points on the path between latent points.
        resize: Resize images to this size before computing the similarity between generated images.
        lower_discard: Lower quantile to discard from the distances, before computing the mean and standard deviation.
        upper_discard: Upper quantile to discard from the distances, before computing the mean and standard deviation.
        sim_net: Similarity network to use. Can be a `nn.Module` or one of 'alex', 'vgg', 'squeeze', where the three
            latter options correspond to the pretrained networks from the `LPIPS`_ paper.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ModuleNotFoundError:
            If ``torchvision`` is not installed.
        ValueError:
            If ``num_samples`` is not a positive integer.
        ValueError:
            If `conditional` is not a boolean.
        ValueError:
            If ``batch_size`` is not a positive integer.
        ValueError:
            If ``interpolation_method`` is not one of 'lerp', 'slerp_any', 'slerp_unit'.
        ValueError:
            If ``epsilon`` is not a positive float.
        ValueError:
            If ``resize`` is not a positive integer.
        ValueError:
            If ``lower_discard`` is not a float between 0 and 1 or None.
        ValueError:
            If ``upper_discard`` is not a float between 0 and 1 or None.
        ValueError:
            If ``sim_net`` is neither a ``nn.Module`` nor one of 'alex', 'vgg', 'squeeze'.

    Example::
        >>> import torch
        >>> class DummyGenerator(torch.nn.Module):
        ...    def __init__(self, z_size) -> None:
        ...       super().__init__()
        ...       self.z_size = z_size
        ...       self.model = torch.nn.Sequential(torch.nn.Linear(z_size, 3*128*128), torch.nn.Sigmoid())
        ...    def forward(self, z):
        ...       return 255 * (self.model(z).reshape(-1, 3, 128, 128) + 1)
        ...    def sample(self, num_samples):
        ...      return torch.randn(num_samples, self.z_size)
        >>> generator = DummyGenerator(2)
        >>> ppl = PerceptualPathLength(num_samples=10)
        >>> ppl(generator)
        (tensor(...), tensor(...), tensor([...]))

    """

    is_differentiable: bool = False
    # NOTE(review): PPL is usually reported as "lower is better" -- confirm that ``True`` is intended here.
    higher_is_better: Optional[bool] = True
    full_state_update: bool = True

    # Similarity network used to compare generated images; ``feature_network`` names the attribute holding it.
    net: nn.Module
    feature_network: str = "net"

    def __init__(
        self,
        num_samples: int = 10_000,
        conditional: bool = False,
        batch_size: int = 128,
        interpolation_method: Literal["lerp", "slerp_any", "slerp_unit"] = "lerp",
        epsilon: float = 1e-4,
        resize: Optional[int] = 64,
        lower_discard: Optional[float] = 0.01,
        upper_discard: Optional[float] = 0.99,
        sim_net: Union[nn.Module, Literal["alex", "vgg", "squeeze"]] = "vgg",
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if not _TORCHVISION_AVAILABLE:
            raise ModuleNotFoundError(
                "Metric `PerceptualPathLength` requires torchvision which is not installed."
                " Install with `pip install torchvision` or `pip install torchmetrics[image]`"
            )
        # Validate everything up front so a misconfigured metric fails at construction, not at ``compute``.
        _perceptual_path_length_validate_arguments(
            num_samples, conditional, batch_size, interpolation_method, epsilon, resize, lower_discard, upper_discard
        )
        self.num_samples = num_samples
        self.conditional = conditional
        self.batch_size = batch_size
        self.interpolation_method = interpolation_method
        self.epsilon = epsilon
        self.resize = resize
        self.lower_discard = lower_discard
        self.upper_discard = upper_discard

        if isinstance(sim_net, nn.Module):
            # User-supplied similarity network is used as-is.
            self.net = sim_net
        elif sim_net in ["alex", "vgg", "squeeze"]:
            self.net = _LPIPS(pretrained=True, net=sim_net, resize=resize)
        else:
            raise ValueError(f"sim_net must be a nn.Module or one of 'alex', 'vgg', 'squeeze', got {sim_net}")

    def update(self, generator: GeneratorType) -> None:
        """Validate and store the generator model to evaluate.

        Args:
            generator: Generator model; checked with ``_validate_generator_model`` against ``self.conditional``.
                Each call replaces the previously stored generator.

        """
        _validate_generator_model(generator, self.conditional)
        self.generator = generator

    def compute(self) -> tuple[Tensor, Tensor, Tensor]:
        """Compute the perceptual path length of the stored generator.

        Returns:
            The result of the functional ``perceptual_path_length`` called with the options given at construction:
            a tuple of three tensors (documented on the class as mean, standard deviation and raw distances).

        """
        return perceptual_path_length(
            generator=self.generator,
            num_samples=self.num_samples,
            conditional=self.conditional,
            # Forward the configured batch size; previously it was validated and stored but never used.
            batch_size=self.batch_size,
            interpolation_method=self.interpolation_method,
            epsilon=self.epsilon,
            resize=self.resize,
            lower_discard=self.lower_discard,
            upper_discard=self.upper_discard,
            sim_net=self.net,
            device=self.device,
        )
|