| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297 |
- # Copyright The Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from collections.abc import Sequence
- from typing import Any, List, Optional, Union
- import torch
- from torch import Tensor
- from torch.nn import Module
- from torchmetrics.image.fid import NoTrainInceptionV3, _compute_fid
- from torchmetrics.metric import Metric
- from torchmetrics.utilities.data import dim_zero_cat
- from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_FIDELITY_AVAILABLE
- from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
# Both doctests below build the default Inception feature extractor (``feature=64``), which needs
# ``torch_fidelity``; the mapping marks them as requiring that package (presumably consumed by
# pytest-doctestplus — confirm against the test configuration).
__doctest_requires__ = {
    (
        "MemorizationInformedFrechetInceptionDistance",
        "MemorizationInformedFrechetInceptionDistance.plot",
    ): ["torch_fidelity"],
}

# The ``plot`` doctest draws a figure, so it is skipped when matplotlib is not importable.
if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["MemorizationInformedFrechetInceptionDistance.plot"]
- def _compute_cosine_distance(features1: Tensor, features2: Tensor, cosine_distance_eps: float = 0.1) -> Tensor:
- """Compute the cosine distance between two sets of features."""
- features1_nozero = features1[torch.sum(features1, dim=1) != 0]
- features2_nozero = features2[torch.sum(features2, dim=1) != 0]
- # normalize
- norm_f1 = features1_nozero / torch.norm(features1_nozero, dim=1, keepdim=True)
- norm_f2 = features2_nozero / torch.norm(features2_nozero, dim=1, keepdim=True)
- d = 1.0 - torch.abs(torch.matmul(norm_f1, norm_f2.t()))
- mean_min_d = torch.mean(d.min(dim=1).values)
- return mean_min_d if mean_min_d < cosine_distance_eps else torch.ones_like(mean_min_d)
def _mifid_compute(
    mu1: Tensor,
    sigma1: Tensor,
    features1: Tensor,
    mu2: Tensor,
    sigma2: Tensor,
    features2: Tensor,
    cosine_distance_eps: float = 0.1,
) -> Tensor:
    """Combine the FID and the memorization distance into a single MIFID score.

    Args:
        mu1: Mean of the first feature set.
        sigma1: Covariance of the first feature set.
        features1: The first feature set itself, used for the memorization distance.
        mu2: Mean of the second feature set.
        sigma2: Covariance of the second feature set.
        features2: The second feature set itself, used for the memorization distance.
        cosine_distance_eps: Threshold forwarded to ``_compute_cosine_distance``.

    Returns:
        ``FID / (distance + 10e-15)``, or exactly zero when the FID is not larger than ``1e-8``.

    """
    fid = _compute_fid(mu1, sigma1, mu2, sigma2)
    memorization = _compute_cosine_distance(features1, features2, cosine_distance_eps)
    if fid > 1e-8:
        # the tiny additive constant keeps a zero memorization distance from causing a division by zero
        return fid / (memorization + 10e-15)
    # a (near-)zero FID is reported as exactly zero rather than divided by the distance
    return torch.zeros_like(fid)
class MemorizationInformedFrechetInceptionDistance(Metric):
    r"""Calculate Memorization-Informed Frechet Inception Distance (MIFID_).

    MIFID is an improved variation of the Frechet Inception Distance (FID_) that penalizes memorization of the
    training set by the generator. It is calculated as

    .. math::
        MIFID = \frac{FID(F_{real}, F_{fake})}{M(F_{real}, F_{fake})}

    where :math:`FID` is the normal FID score and :math:`M` is the memorization penalty. The memorization penalty
    essentially corresponds to the average minimum cosine distance between the features of the real and fake
    distribution.

    Using the default feature extraction (Inception v3 using the original weights from `fid ref2`_), the input is
    expected to be mini-batches of 3-channel RGB images of shape ``(3 x H x W)``. If argument ``normalize``
    is ``True`` images are expected to be dtype ``float`` and have values in the ``[0, 1]`` range, else if
    ``normalize`` is set to ``False`` images are expected to have dtype ``uint8`` and take values in the ``[0, 255]``
    range. All images will be resized to 299 x 299 which is the size of the original training data. The boolean
    flag ``real`` determines if the images should update the statistics of the real distribution or the
    fake distribution.

    .. hint::
        Using this metric requires you to have ``scipy`` installed. Either install as ``pip install
        torchmetrics[image]`` or ``pip install scipy``

    .. hint::
        Using this metric with the default feature extractor requires that ``torch-fidelity``
        is installed. Either install as ``pip install torchmetrics[image]`` or
        ``pip install torch-fidelity``

    As input to ``forward`` and ``update`` the metric accepts the following input

    - ``imgs`` (:class:`~torch.Tensor`): tensor with images fed to the feature extractor
    - ``real`` (:class:`~bool`): bool indicating if ``imgs`` belong to the real or the fake distribution

    As output of ``forward`` and ``compute`` the metric returns the following output

    - ``mifid`` (:class:`~torch.Tensor`): float scalar tensor with the MIFID value between the accumulated real
      and fake distributions

    Args:
        feature:
            Either an integer or ``nn.Module``:

            - an integer will indicate the inceptionv3 feature layer to choose. Can be one of the following:
              64, 192, 768, 2048
            - an ``nn.Module`` for using a custom feature extractor. Expects that its forward method returns
              an ``(N,d)`` matrix where ``N`` is the batch size and ``d`` is the feature size.

        reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not
            change, the features can be cached to avoid recomputing them, which is costly. Set this to ``False`` if
            your dataset does not change.
        normalize: Whether to normalize the input images. If ``True`` the input is expected to be in the range
            [0, 1] and is converted to ``uint8`` by multiplying with 255. If ``False`` the input is expected to
            already be in the range [0, 255] and of type ``uint8``. If a custom feature extractor is used, this
            argument is ignored.
        cosine_distance_eps: Epsilon value for the cosine distance; must be a ``float`` satisfying
            ``0 < cosine_distance_eps <= 1``. If the mean minimum cosine distance is not smaller than this value
            it is set to 1 and thus ignored in the MIFID calculation.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ModuleNotFoundError:
            If ``feature`` is set to an ``int`` and ``torch-fidelity`` is not installed
        ValueError:
            If ``feature`` is set to an ``int`` not in [64, 192, 768, 2048]
        TypeError:
            If ``feature`` is not an ``int`` or ``torch.nn.Module``
        ValueError:
            If ``reset_real_features`` is not a ``bool``
        ValueError:
            If ``normalize`` is not a ``bool``
        ValueError:
            If ``cosine_distance_eps`` is not a ``float`` in the range ``(0, 1]``

    Example::

        >>> from torch import randint
        >>> from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance
        >>> mifid = MemorizationInformedFrechetInceptionDistance(feature=64)
        >>> # generate two slightly overlapping image intensity distributions
        >>> imgs_dist1 = randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
        >>> imgs_dist2 = randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
        >>> mifid.update(imgs_dist1, real=True)
        >>> mifid.update(imgs_dist2, real=False)
        >>> mifid.compute()
        tensor(3003.3691)

    """

    higher_is_better: bool = False
    is_differentiable: bool = False
    full_state_update: bool = False

    # per-batch feature matrices accumulated by ``update``
    real_features: List[Tensor]
    fake_features: List[Tensor]

    # the feature extractor (default Inception v3 or a user-supplied module)
    inception: Module
    # name of the feature-extractor attribute; presumably read by the ``Metric`` base class — confirm
    feature_network: str = "inception"

    def __init__(
        self,
        feature: Union[int, Module] = 2048,
        reset_real_features: bool = True,
        normalize: bool = False,
        cosine_distance_eps: float = 0.1,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        # tracks whether ``normalize`` applies: it is ignored for custom feature extractors
        self.used_custom_model = False
        if isinstance(feature, int):
            if not _TORCH_FIDELITY_AVAILABLE:
                raise ModuleNotFoundError(
                    "MemorizationInformedFrechetInceptionDistance metric requires that `Torch-fidelity` is installed."
                    " Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`."
                )
            valid_int_input = [64, 192, 768, 2048]
            if feature not in valid_int_input:
                raise ValueError(
                    f"Integer input to argument `feature` must be one of {valid_int_input}, but got {feature}."
                )

            self.inception = NoTrainInceptionV3(name="inception-v3-compat", features_list=[str(feature)])
        elif isinstance(feature, Module):
            self.inception = feature
            self.used_custom_model = True
        else:
            raise TypeError("Got unknown input to argument `feature`")

        if not isinstance(reset_real_features, bool):
            raise ValueError("Argument `reset_real_features` expected to be a bool")
        self.reset_real_features = reset_real_features

        if not isinstance(normalize, bool):
            raise ValueError("Argument `normalize` expected to be a bool")
        self.normalize = normalize

        # NOTE: the accepted range is (0, 1] — the upper bound 1 is allowed despite the message wording
        if not (isinstance(cosine_distance_eps, float) and 1 >= cosine_distance_eps > 0):
            raise ValueError("Argument `cosine_distance_eps` expected to be a float greater than 0 and less than 1")
        self.cosine_distance_eps = cosine_distance_eps

        # states for extracted features
        self.add_state("real_features", [], dist_reduce_fx=None)
        self.add_state("fake_features", [], dist_reduce_fx=None)

    def update(self, imgs: Tensor, real: bool) -> None:
        """Update the state with extracted features.

        Args:
            imgs: Batch of images passed to the feature extractor.
            real: If ``True`` the features are added to the real distribution, otherwise to the fake one.

        """
        # with the default extractor, float images in [0, 1] are mapped back to uint8 values in [0, 255]
        imgs = (imgs * 255).byte() if self.normalize and not self.used_custom_model else imgs
        features = self.inception(imgs)
        # remember the extractor's output dtype so ``compute`` can cast the final score back to it
        self.orig_dtype = features.dtype
        # accumulate in float64 for the mean / covariance statistics
        features = features.double()

        if real:
            self.real_features.append(features)
        else:
            self.fake_features.append(features)

    def compute(self) -> Tensor:
        """Calculate the MIFID score based on accumulated extracted features from the two distributions.

        Returns:
            Scalar tensor with the MIFID score, cast to the dtype of the extracted features.

        """
        real_features = dim_zero_cat(self.real_features)
        fake_features = dim_zero_cat(self.fake_features)

        mean_real, mean_fake = torch.mean(real_features, dim=0), torch.mean(fake_features, dim=0)
        # torch.cov expects variables in rows and observations in columns, hence the transpose
        cov_real, cov_fake = torch.cov(real_features.t()), torch.cov(fake_features.t())

        return _mifid_compute(
            mean_real,
            cov_real,
            real_features,
            mean_fake,
            cov_fake,
            fake_features,
            cosine_distance_eps=self.cosine_distance_eps,
        ).to(self.orig_dtype)

    def reset(self) -> None:
        """Reset metric states, keeping the cached real features when ``reset_real_features`` is ``False``."""
        if not self.reset_real_features:
            # remove the default temporarily so ``super().reset()`` does not reset the real features,
            # then restore it so future resets keep working
            value = self._defaults.pop("real_features")
            super().reset()
            self._defaults["real_features"] = value
        else:
            super().reset()

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided, the plot is added to that axis.

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance
            >>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
            >>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
            >>> metric = MemorizationInformedFrechetInceptionDistance(feature=64)
            >>> metric.update(imgs_dist1, real=True)
            >>> metric.update(imgs_dist2, real=False)
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance
            >>> imgs_dist1 = lambda: torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
            >>> imgs_dist2 = lambda: torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
            >>> metric = MemorizationInformedFrechetInceptionDistance(feature=64)
            >>> values = [ ]
            >>> for _ in range(3):
            ...     metric.update(imgs_dist1(), real=True)
            ...     metric.update(imgs_dist2(), real=False)
            ...     values.append(metric.compute())
            ...     metric.reset()
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)
|