perceptual_path_length.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Literal, Optional, Union
  15. from torch import Tensor, nn
  16. from torchmetrics.functional.image.lpips import _LPIPS
  17. from torchmetrics.functional.image.perceptual_path_length import (
  18. GeneratorType,
  19. _perceptual_path_length_validate_arguments,
  20. _validate_generator_model,
  21. perceptual_path_length,
  22. )
  23. from torchmetrics.metric import Metric
  24. from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE
if not _TORCHVISION_AVAILABLE:
    # PerceptualPathLength.__init__ raises ModuleNotFoundError without torchvision, so its doctest cannot run.
    __doctest_skip__ = ["PerceptualPathLength"]
  27. class PerceptualPathLength(Metric):
  28. r"""Computes the perceptual path length (`PPL`_) of a generator model.
  29. The perceptual path length can be used to measure the consistency of interpolation in latent-space models. It is
  30. defined as
  31. .. math::
  32. PPL = \mathbb{E}\left[\frac{1}{\epsilon^2} D(G(I(z_1, z_2, t)), G(I(z_1, z_2, t+\epsilon)))\right]
  33. where :math:`G` is the generator, :math:`I` is the interpolation function, :math:`D` is a similarity metric,
  34. :math:`z_1` and :math:`z_2` are two sets of latent points, and :math:`t` is a parameter between 0 and 1. The metric
  35. thus works by interpolating between two sets of latent points, and measuring the similarity between the generated
  36. images. The expectation is approximated by sampling :math:`z_1` and :math:`z_2` from the generator, and averaging
  37. the calculated distanced. The similarity metric :math:`D` is by default the `LPIPS`_ metric, but can be changed by
  38. setting the `sim_net` argument.
  39. The provided generator model must have a `sample` method with signature `sample(num_samples: int) -> Tensor` where
  40. the returned tensor has shape `(num_samples, z_size)`. If the generator is conditional, it must also have a
  41. `num_classes` attribute. The `forward` method of the generator must have signature `forward(z: Tensor) -> Tensor`
  42. if `conditional=False`, and `forward(z: Tensor, labels: Tensor) -> Tensor` if `conditional=True`. The returned
  43. tensor should have shape `(num_samples, C, H, W)` and be scaled to the range [0, 255].
  44. .. hint::
  45. Using this metric with the default feature extractor requires that ``torchvision`` is installed.
  46. Either install as ``pip install torchmetrics[image]`` or ``pip install torchvision``
  47. As input to ``forward`` and ``update`` the metric accepts the following input
  48. - ``generator`` (:class:`~torch.nn.Module`): Generator model, with specific requirements. See above.
  49. As output of `forward` and `compute` the metric returns the following output
  50. - ``ppl_mean`` (:class:`~torch.Tensor`): float scalar tensor with mean PPL value over distances
  51. - ``ppl_std`` (:class:`~torch.Tensor`): float scalar tensor with std PPL value over distances
  52. - ``ppl_raw`` (:class:`~torch.Tensor`): float scalar tensor with raw PPL distances
  53. Args:
  54. num_samples: Number of samples to use for the PPL computation.
  55. conditional: Whether the generator is conditional or not (i.e. whether it takes labels as input).
  56. batch_size: Batch size to use for the PPL computation.
  57. interpolation_method: Interpolation method to use. Choose from 'lerp', 'slerp_any', 'slerp_unit'.
  58. epsilon: Spacing between the points on the path between latent points.
  59. resize: Resize images to this size before computing the similarity between generated images.
  60. lower_discard: Lower quantile to discard from the distances, before computing the mean and standard deviation.
  61. upper_discard: Upper quantile to discard from the distances, before computing the mean and standard deviation.
  62. sim_net: Similarity network to use. Can be a `nn.Module` or one of 'alex', 'vgg', 'squeeze', where the three
  63. latter options correspond to the pretrained networks from the `LPIPS`_ paper.
  64. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  65. Raises:
  66. ModuleNotFoundError:
  67. If ``torch-fidelity`` is not installed.
  68. ValueError:
  69. If ``num_samples`` is not a positive integer.
  70. ValueError:
  71. If `conditional` is not a boolean.
  72. ValueError:
  73. If ``batch_size`` is not a positive integer.
  74. ValueError:
  75. If ``interpolation_method`` is not one of 'lerp', 'slerp_any', 'slerp_unit'.
  76. ValueError:
  77. If ``epsilon`` is not a positive float.
  78. ValueError:
  79. If ``resize`` is not a positive integer.
  80. ValueError:
  81. If ``lower_discard`` is not a float between 0 and 1 or None.
  82. ValueError:
  83. If ``upper_discard`` is not a float between 0 and 1 or None.
  84. Example::
  85. >>> import torch
  86. >>> class DummyGenerator(torch.nn.Module):
  87. ... def __init__(self, z_size) -> None:
  88. ... super().__init__()
  89. ... self.z_size = z_size
  90. ... self.model = torch.nn.Sequential(torch.nn.Linear(z_size, 3*128*128), torch.nn.Sigmoid())
  91. ... def forward(self, z):
  92. ... return 255 * (self.model(z).reshape(-1, 3, 128, 128) + 1)
  93. ... def sample(self, num_samples):
  94. ... return torch.randn(num_samples, self.z_size)
  95. >>> generator = DummyGenerator(2)
  96. >>> ppl = PerceptualPathLength(num_samples=10)
  97. >>> ppl(generator)
  98. (tensor(...), tensor(...), tensor([...]))
  99. """
  100. is_differentiable: bool = False
  101. higher_is_better: Optional[bool] = True
  102. full_state_update: bool = True
  103. net: nn.Module
  104. feature_network: str = "net"
  105. def __init__(
  106. self,
  107. num_samples: int = 10_000,
  108. conditional: bool = False,
  109. batch_size: int = 128,
  110. interpolation_method: Literal["lerp", "slerp_any", "slerp_unit"] = "lerp",
  111. epsilon: float = 1e-4,
  112. resize: Optional[int] = 64,
  113. lower_discard: Optional[float] = 0.01,
  114. upper_discard: Optional[float] = 0.99,
  115. sim_net: Union[nn.Module, Literal["alex", "vgg", "squeeze"]] = "vgg",
  116. **kwargs: Any,
  117. ) -> None:
  118. super().__init__(**kwargs)
  119. if not _TORCHVISION_AVAILABLE:
  120. raise ModuleNotFoundError(
  121. "Metric `PerceptualPathLength` requires torchvision which is not installed."
  122. "Install with `pip install torchvision` or `pip install torchmetrics[image]`"
  123. )
  124. _perceptual_path_length_validate_arguments(
  125. num_samples, conditional, batch_size, interpolation_method, epsilon, resize, lower_discard, upper_discard
  126. )
  127. self.num_samples = num_samples
  128. self.conditional = conditional
  129. self.batch_size = batch_size
  130. self.interpolation_method = interpolation_method
  131. self.epsilon = epsilon
  132. self.resize = resize
  133. self.lower_discard = lower_discard
  134. self.upper_discard = upper_discard
  135. if isinstance(sim_net, nn.Module):
  136. self.net = sim_net
  137. elif sim_net in ["alex", "vgg", "squeeze"]:
  138. self.net = _LPIPS(pretrained=True, net=sim_net, resize=resize)
  139. else:
  140. raise ValueError(f"sim_net must be a nn.Module or one of 'alex', 'vgg', 'squeeze', got {sim_net}")
  141. def update(self, generator: GeneratorType) -> None:
  142. """Update the generator model."""
  143. _validate_generator_model(generator, self.conditional)
  144. self.generator = generator
  145. def compute(self) -> tuple[Tensor, Tensor, Tensor]:
  146. """Compute the perceptual path length."""
  147. return perceptual_path_length(
  148. generator=self.generator,
  149. num_samples=self.num_samples,
  150. conditional=self.conditional,
  151. interpolation_method=self.interpolation_method,
  152. epsilon=self.epsilon,
  153. resize=self.resize,
  154. lower_discard=self.lower_discard,
  155. upper_discard=self.upper_discard,
  156. sim_net=self.net,
  157. device=self.device,
  158. )