| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495 |
- # Copyright The Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from collections.abc import Sequence
- from copy import deepcopy
- from typing import Any, Optional, Union
- import torch
- from torch import Tensor
- from torch.nn import Module
- from torch.nn.functional import adaptive_avg_pool2d
- from torchmetrics.metric import Metric
- from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_FIDELITY_AVAILABLE
- from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
# The `plot` doctest is skipped when matplotlib is not available.
if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["FrechetInceptionDistance.plot"]

# torch-fidelity is an optional dependency: import the real Inception v3 feature extractor and its helpers
# when it is available, otherwise define placeholders so that this module can still be imported.
if _TORCH_FIDELITY_AVAILABLE:
    from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3 as _FeatureExtractorInceptionV3
    from torch_fidelity.helpers import vassert
    from torch_fidelity.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x
else:
    # Empty stand-in base class so that `NoTrainInceptionV3` can still be defined.
    class _FeatureExtractorInceptionV3(Module):  # type: ignore[no-redef]
        pass

    # Placeholders for the torch-fidelity helpers used in `NoTrainInceptionV3._torch_fidelity_forward`.
    vassert = None
    interpolate_bilinear_2d_like_tensorflow1x = None

    # Without torch-fidelity both the class-level and the `plot` doctests are skipped.
    __doctest_skip__ = ["FrechetInceptionDistance", "FrechetInceptionDistance.plot"]
class NoTrainInceptionV3(_FeatureExtractorInceptionV3):
    """Inception v3 feature extractor that never leaves evaluation mode.

    Args:
        name: name forwarded to the parent feature extractor constructor.
        features_list: names of the features to return from the forward pass. The forward pass handles
            ``"64"``, ``"192"``, ``"768"``, ``"2048"``, ``"logits_unbiased"`` and ``"logits"``.
        feature_extractor_weights_path: optional weights path forwarded to the parent constructor.
        antialias: if ``True`` the input is resized with ``torch.nn.functional.interpolate`` using bilinear
            interpolation with antialiasing, otherwise with ``interpolate_bilinear_2d_like_tensorflow1x``.

    Raises:
        ModuleNotFoundError:
            If ``torch-fidelity`` is not installed.

    """

    INPUT_IMAGE_SIZE: int

    def __init__(
        self,
        name: str,
        features_list: list[str],
        feature_extractor_weights_path: Optional[str] = None,
        antialias: bool = True,
    ) -> None:
        if not _TORCH_FIDELITY_AVAILABLE:
            raise ModuleNotFoundError(
                "NoTrainInceptionV3 module requires that `Torch-fidelity` is installed."
                " Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`."
            )

        super().__init__(name, features_list, feature_extractor_weights_path)
        self.use_antialias = antialias
        # put into evaluation mode
        self.eval()

    def train(self, mode: bool = True) -> "NoTrainInceptionV3":
        """Force network to always be in evaluation mode.

        ``mode`` defaults to ``True`` to stay call-compatible with ``torch.nn.Module.train`` (so a bare
        ``.train()`` call works), but its value is ignored: the network is always put in evaluation mode.

        """
        return super().train(False)

    def _torch_fidelity_forward(self, x: Tensor) -> tuple[Tensor, ...]:
        """Forward method of inception net.

        Copy of the forward method from this file:
        https://github.com/toshas/torch-fidelity/blob/master/torch_fidelity/feature_extractor_inceptionv3.py
        with a single line change regarding the casting of `x` in the beginning.

        Corresponding license file (Apache License, Version 2.0):
        https://github.com/toshas/torch-fidelity/blob/master/LICENSE.md

        Args:
            x: image tensor with dtype ``torch.uint8``.

        Returns:
            Tuple with one tensor per entry of ``self.features_list``, in the same order. The pass returns
            as soon as every requested feature has been collected.

        """
        vassert(torch.is_tensor(x) and x.dtype == torch.uint8, "Expecting image as torch.Tensor with dtype=torch.uint8")
        features = {}
        remaining_features = self.features_list.copy()

        # Cast to the dtype stored in `_dtype` when that attribute exists, otherwise to float32.
        x = x.to(self._dtype) if hasattr(self, "_dtype") else x.to(torch.float)
        if self.use_antialias:
            x = torch.nn.functional.interpolate(
                x,
                size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )
        else:
            x = interpolate_bilinear_2d_like_tensorflow1x(
                x,
                size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE),
                align_corners=False,
            )
        # Shift and scale the [0, 255] pixel values to roughly [-1, 1].
        x = (x - 128) / 128

        x = self.Conv2d_1a_3x3(x)
        x = self.Conv2d_2a_3x3(x)
        x = self.Conv2d_2b_3x3(x)
        x = self.MaxPool_1(x)

        if "64" in remaining_features:
            features["64"] = adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
            remaining_features.remove("64")
            if not remaining_features:
                return tuple(features[a] for a in self.features_list)

        x = self.Conv2d_3b_1x1(x)
        x = self.Conv2d_4a_3x3(x)
        x = self.MaxPool_2(x)

        if "192" in remaining_features:
            features["192"] = adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
            remaining_features.remove("192")
            if not remaining_features:
                return tuple(features[a] for a in self.features_list)

        x = self.Mixed_5b(x)
        x = self.Mixed_5c(x)
        x = self.Mixed_5d(x)
        x = self.Mixed_6a(x)
        x = self.Mixed_6b(x)
        x = self.Mixed_6c(x)
        x = self.Mixed_6d(x)
        x = self.Mixed_6e(x)

        if "768" in remaining_features:
            features["768"] = adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
            remaining_features.remove("768")
            if not remaining_features:
                return tuple(features[a] for a in self.features_list)

        x = self.Mixed_7a(x)
        x = self.Mixed_7b(x)
        x = self.Mixed_7c(x)
        x = self.AvgPool(x)
        x = torch.flatten(x, 1)

        if "2048" in remaining_features:
            features["2048"] = x
            remaining_features.remove("2048")
            if not remaining_features:
                return tuple(features[a] for a in self.features_list)

        if "logits_unbiased" in remaining_features:
            # Apply the final linear layer without its bias first, then add the bias for the full logits.
            x = x.mm(self.fc.weight.T)
            # N x 1008 (num_classes)
            features["logits_unbiased"] = x
            remaining_features.remove("logits_unbiased")
            if not remaining_features:
                return tuple(features[a] for a in self.features_list)

            x = x + self.fc.bias.unsqueeze(0)
        else:
            x = self.fc(x)

        features["logits"] = x
        return tuple(features[a] for a in self.features_list)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of neural network with reshaping of output.

        Only the first requested feature is returned, reshaped to ``(x.shape[0], -1)``.

        """
        out = self._torch_fidelity_forward(x)
        return out[0].reshape(x.shape[0], -1)
- def _compute_fid(mu1: Tensor, sigma1: Tensor, mu2: Tensor, sigma2: Tensor) -> Tensor:
- r"""Compute adjusted version of `Fid Score`_.
- The Frechet Inception Distance between two multivariate Gaussians X_x ~ N(mu_1, sigm_1)
- and X_y ~ N(mu_2, sigm_2) is d^2 = ||mu_1 - mu_2||^2 + Tr(sigm_1 + sigm_2 - 2*sqrt(sigm_1*sigm_2)).
- Args:
- mu1: mean of activations calculated on predicted (x) samples
- sigma1: covariance matrix over activations calculated on predicted (x) samples
- mu2: mean of activations calculated on target (y) samples
- sigma2: covariance matrix over activations calculated on target (y) samples
- Returns:
- Scalar value of the distance between sets.
- """
- a = (mu1 - mu2).square().sum(dim=-1)
- b = sigma1.trace() + sigma2.trace()
- c = torch.linalg.eigvals(sigma1 @ sigma2).sqrt().real.sum(dim=-1)
- return a + b - 2 * c
class FrechetInceptionDistance(Metric):
    r"""Calculate Fréchet inception distance (FID_) which is used to assess the quality of generated images.

    .. math::
        FID = \|\mu - \mu_w\|^2 + tr(\Sigma + \Sigma_w - 2(\Sigma \Sigma_w)^{\frac{1}{2}})

    where :math:`\mathcal{N}(\mu, \Sigma)` is the multivariate normal distribution estimated from Inception v3
    (`fid ref1`_) features calculated on real life images and :math:`\mathcal{N}(\mu_w, \Sigma_w)` is the
    multivariate normal distribution estimated from Inception v3 features calculated on generated (fake) images.
    The metric was originally proposed in `fid ref1`_.

    Using the default feature extraction (Inception v3 using the original weights from `fid ref2`_), the input is
    expected to be mini-batches of 3-channel RGB images of shape ``(3xHxW)``. If argument ``normalize``
    is ``True`` images are expected to be dtype ``float`` and have values in the ``[0,1]`` range, else if
    ``normalize`` is set to ``False`` images are expected to have dtype ``uint8`` and take values in the ``[0, 255]``
    range. All images will be resized to 299 x 299 which is the size of the original training data. The boolean
    flag ``real`` determines if the images should update the statistics of the real distribution or the
    fake distribution.

    Using custom feature extractor is also possible. One can give a torch.nn.Module as `feature` argument. This
    custom feature extractor is expected to have output shape of ``(1, num_features)``. This would change the
    used feature extractor from default (Inception v3) to the given network. In case network doesn't have
    ``num_features`` attribute, a random tensor will be given to the network to infer feature dimensionality.
    Size of this tensor can be controlled by ``input_img_size`` argument and type of the tensor can be controlled
    with ``normalize`` argument (``True`` uses float32 tensors and ``False`` uses uint8 tensors). In this case, update
    method expects to have the tensor given to `imgs` argument to be in the correct shape and type that is compatible
    to the custom feature extractor.

    This metric is known to be unstable in its calculations, and we recommend for the best results using this metric
    that you calculate using `torch.float64` (default is `torch.float32`) which can be set using the `.set_dtype`
    method of the metric.

    .. hint::
        Using this metric with the default feature extractor requires that ``torch-fidelity``
        is installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torch-fidelity``

    As input to ``forward`` and ``update`` the metric accepts the following input

    - ``imgs`` (:class:`~torch.Tensor`): tensor with images fed to the feature extractor
    - ``real`` (:class:`~bool`): bool indicating if ``imgs`` belong to the real or the fake distribution

    As output of `forward` and `compute` the metric returns the following output

    - ``fid`` (:class:`~torch.Tensor`): float scalar tensor with mean FID value over samples

    Args:
        feature:
            Either an integer or ``nn.Module``:

            - an integer will indicate the inceptionv3 feature layer to choose. Can be one of the following:
              64, 192, 768, 2048
            - an ``nn.Module`` for using a custom feature extractor. Expects that its forward method returns
              an ``(N,d)`` matrix where ``N`` is the batch size and ``d`` is the feature size.

        reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not
            change, the features can be cached to avoid recomputing them which is costly. Set this to ``False`` if
            your dataset does not change.
        normalize:
            Argument for controlling the input image dtype normalization:

            - If default feature extractor is used, controls whether input imgs have values in range [0, 1] or not:

              - True: if input imgs have values ranged in [0, 1]. They are cast to uint8/byte tensors.
              - False: if input imgs have values ranged in [0, 255]. No casting is done.

            - If custom feature extractor module is used, controls type of the input img tensors:

              - True: if input imgs are expected to be in the data type of torch.float32.
              - False: if input imgs are expected to be in the data type of torch.uint8.

        input_img_size: tuple of integers. Indicates input img size to the custom feature extractor network if provided.
        feature_extractor_weights_path: optional path to the weights of the default Inception v3 feature extractor.
            Does not apply to custom feature extractor networks.
        antialias: boolean flag to indicate whether to use antialiasing when resizing images. This will change the
            resize function to use bilinear interpolation with antialiasing, which is different from the original
            Inception v3 implementation. Does not apply to custom feature extractor networks.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``normalize`` is not a ``bool``
        ModuleNotFoundError:
            If ``feature`` is set to an ``int`` (default settings) and ``torch-fidelity`` is not installed
        ValueError:
            If ``feature`` is set to an ``int`` not in [64, 192, 768, 2048]
        TypeError:
            If ``feature`` is not an ``int`` or ``torch.nn.Module``, or if a custom feature extractor has a
            ``num_features`` attribute that is neither an ``int`` nor a ``Tensor``
        ValueError:
            If ``reset_real_features`` is not a ``bool``

    Example:
        >>> import torch
        >>> from torchmetrics.image.fid import FrechetInceptionDistance
        >>> fid = FrechetInceptionDistance(feature=64)
        >>> # generate two slightly overlapping image intensity distributions
        >>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
        >>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
        >>> fid.update(imgs_dist1, real=True)
        >>> fid.update(imgs_dist2, real=False)
        >>> fid.compute()
        tensor(12.6388)

    """

    higher_is_better: bool = False
    is_differentiable: bool = False
    full_state_update: bool = False
    plot_lower_bound: float = 0.0

    real_features_sum: Tensor
    real_features_cov_sum: Tensor
    real_features_num_samples: Tensor

    fake_features_sum: Tensor
    fake_features_cov_sum: Tensor
    fake_features_num_samples: Tensor

    inception: Module
    feature_network: str = "inception"

    def __init__(
        self,
        feature: Union[int, Module] = 2048,
        reset_real_features: bool = True,
        normalize: bool = False,
        input_img_size: tuple[int, int, int] = (3, 299, 299),
        feature_extractor_weights_path: Optional[str] = None,
        antialias: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        if not isinstance(normalize, bool):
            raise ValueError("Argument `normalize` expected to be a bool")
        self.normalize = normalize
        self.used_custom_model = False
        # Output dtype used by `compute`; `update` overwrites it with the feature extractor's output dtype.
        # Initialised here so that `compute` never depends on `update` having created the attribute.
        self.orig_dtype = torch.get_default_dtype()

        if isinstance(feature, int):
            num_features = feature
            if not _TORCH_FIDELITY_AVAILABLE:
                raise ModuleNotFoundError(
                    "FrechetInceptionDistance metric requires that `Torch-fidelity` is installed."
                    " Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`."
                )
            valid_int_input = (64, 192, 768, 2048)
            if feature not in valid_int_input:
                raise ValueError(
                    f"Integer input to argument `feature` must be one of {valid_int_input}, but got {feature}."
                )

            self.inception = NoTrainInceptionV3(
                name="inception-v3-compat",
                features_list=[str(feature)],
                feature_extractor_weights_path=feature_extractor_weights_path,
                antialias=antialias,
            )

        elif isinstance(feature, Module):
            self.inception = feature
            self.used_custom_model = True
            if hasattr(self.inception, "num_features"):
                if isinstance(self.inception.num_features, int):
                    num_features = self.inception.num_features
                elif isinstance(self.inception.num_features, Tensor):
                    num_features = int(self.inception.num_features.item())
                else:
                    raise TypeError("Expected `self.inception.num_features` to be of type int or Tensor.")
            else:
                # No `num_features` attribute: infer the feature size by running a single random image of the
                # configured size through the custom network.
                if self.normalize:
                    dummy_image = torch.rand(1, *input_img_size, dtype=torch.float32)
                else:
                    dummy_image = torch.randint(0, 255, (1, *input_img_size), dtype=torch.uint8)
                num_features = self.inception(dummy_image).shape[-1]
        else:
            raise TypeError("Got unknown input to argument `feature`")

        if not isinstance(reset_real_features, bool):
            raise ValueError("Argument `reset_real_features` expected to be a bool")
        self.reset_real_features = reset_real_features

        # Per distribution the states hold the sum of the features, the sum of their outer products and the
        # number of samples; `compute` derives mean and covariance from them.
        mx_num_feats = (num_features, num_features)
        self.add_state("real_features_sum", torch.zeros(num_features).double(), dist_reduce_fx="sum")
        self.add_state("real_features_cov_sum", torch.zeros(mx_num_feats).double(), dist_reduce_fx="sum")
        self.add_state("real_features_num_samples", torch.tensor(0).long(), dist_reduce_fx="sum")

        self.add_state("fake_features_sum", torch.zeros(num_features).double(), dist_reduce_fx="sum")
        self.add_state("fake_features_cov_sum", torch.zeros(mx_num_feats).double(), dist_reduce_fx="sum")
        self.add_state("fake_features_num_samples", torch.tensor(0).long(), dist_reduce_fx="sum")

    def update(self, imgs: Tensor, real: bool) -> None:
        """Update the state with extracted features.

        Args:
            imgs: Input img tensors to evaluate. If used custom feature extractor please
                make sure dtype and size is correct for the model.
            real: Whether given image is real or fake.

        """
        # With the default extractor and `normalize=True`, [0, 1] float images are rescaled to uint8.
        imgs = (imgs * 255).byte() if self.normalize and (not self.used_custom_model) else imgs
        features = self.inception(imgs)
        self.orig_dtype = features.dtype
        # Statistics are accumulated in double precision.
        features = features.double()

        if features.dim() == 1:
            features = features.unsqueeze(0)
        # Count the feature rows that are actually accumulated, so that the sample count always matches
        # the sums below.
        num_samples = features.shape[0]
        if real:
            self.real_features_sum += features.sum(dim=0)
            self.real_features_cov_sum += features.t().mm(features)
            self.real_features_num_samples += num_samples
        else:
            self.fake_features_sum += features.sum(dim=0)
            self.fake_features_cov_sum += features.t().mm(features)
            self.fake_features_num_samples += num_samples

    def compute(self) -> Tensor:
        """Calculate FID score based on accumulated extracted features from the two distributions.

        Raises:
            RuntimeError:
                If fewer than two samples were accumulated for the real or the fake distribution.

        """
        if self.real_features_num_samples < 2 or self.fake_features_num_samples < 2:
            raise RuntimeError("More than one sample is required for both the real and fake distributed to compute FID")
        mean_real = (self.real_features_sum / self.real_features_num_samples).unsqueeze(0)
        mean_fake = (self.fake_features_sum / self.fake_features_num_samples).unsqueeze(0)

        # Covariance from the accumulated sums: (sum of outer products - n * mean^T mean) / (n - 1).
        cov_real_num = self.real_features_cov_sum - self.real_features_num_samples * mean_real.t().mm(mean_real)
        cov_real = cov_real_num / (self.real_features_num_samples - 1)
        cov_fake_num = self.fake_features_cov_sum - self.fake_features_num_samples * mean_fake.t().mm(mean_fake)
        cov_fake = cov_fake_num / (self.fake_features_num_samples - 1)
        return _compute_fid(mean_real.squeeze(0), cov_real, mean_fake.squeeze(0), cov_fake).to(self.orig_dtype)

    def reset(self) -> None:
        """Reset metric states.

        When ``reset_real_features`` is ``False`` the states of the real distribution are copied before the
        reset and restored afterwards.

        """
        if self.reset_real_features:
            super().reset()
            return

        real_features_sum = deepcopy(self.real_features_sum)
        real_features_cov_sum = deepcopy(self.real_features_cov_sum)
        real_features_num_samples = deepcopy(self.real_features_num_samples)
        super().reset()
        self.real_features_sum = real_features_sum
        self.real_features_cov_sum = real_features_cov_sum
        self.real_features_num_samples = real_features_num_samples

    def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "Metric":
        """Transfer all metric state to specific dtype. Special version of standard `type` method.

        Arguments:
            dst_type: the desired type as ``torch.dtype`` or string

        """
        out = super().set_dtype(dst_type)
        # The default extractor casts its input to `_dtype` in its forward pass, so it must follow the metric.
        if isinstance(out.inception, NoTrainInceptionV3):
            out.inception._dtype = dst_type
        return out

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: An matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.image.fid import FrechetInceptionDistance
            >>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
            >>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
            >>> metric = FrechetInceptionDistance(feature=64)
            >>> metric.update(imgs_dist1, real=True)
            >>> metric.update(imgs_dist2, real=False)
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.image.fid import FrechetInceptionDistance
            >>> imgs_dist1 = lambda: torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
            >>> imgs_dist2 = lambda: torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
            >>> metric = FrechetInceptionDistance(feature=64)
            >>> values = [ ]
            >>> for _ in range(3):
            ...     metric.update(imgs_dist1(), real=True)
            ...     metric.update(imgs_dist2(), real=False)
            ...     values.append(metric.compute())
            ...     metric.reset()
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)
|