vif.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. # Copyright The PyTorch Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. from torch import Tensor
  16. from torch.nn.functional import conv2d
  17. from typing_extensions import Literal
  18. from torchmetrics.utilities.data import dim_zero_cat
  19. def _filter(win_size: float, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor:
  20. # This code is inspired by
  21. # https://github.com/andrewekhalel/sewar/blob/ac76e7bc75732fde40bb0d3908f4b6863400cc27/sewar/utils.py#L45
  22. # https://github.com/photosynthesis-team/piq/blob/01e16b7d8c76bc8765fb6a69560d806148b8046a/piq/functional/filters.py#L38
  23. # Both links do the same, but the second one is cleaner
  24. coords = torch.arange(win_size, dtype=dtype, device=device) - (win_size - 1) / 2
  25. g = coords**2
  26. g = torch.exp(-(g.unsqueeze(0) + g.unsqueeze(1)) / (2.0 * sigma**2))
  27. g /= torch.sum(g)
  28. return g
  29. def _vif_per_channel(preds: Tensor, target: Tensor, sigma_n_sq: float) -> Tensor:
  30. dtype = preds.dtype
  31. device = preds.device
  32. preds = preds.unsqueeze(1) # Add channel dimension
  33. target = target.unsqueeze(1)
  34. # Constant for numerical stability
  35. eps = torch.tensor(1e-10, dtype=dtype, device=device)
  36. sigma_n_sq = torch.tensor(sigma_n_sq, dtype=dtype, device=device)
  37. preds_vif = torch.zeros(preds.size(0), dtype=dtype, device=device)
  38. target_vif = torch.zeros(preds.size(0), dtype=dtype, device=device)
  39. for scale in range(4):
  40. n = 2.0 ** (4 - scale) + 1
  41. kernel = _filter(n, n / 5, dtype=dtype, device=device)[None, None, :]
  42. if scale > 0:
  43. target = conv2d(target, kernel)[:, :, ::2, ::2]
  44. preds = conv2d(preds, kernel)[:, :, ::2, ::2]
  45. mu_target = conv2d(target, kernel)
  46. mu_preds = conv2d(preds, kernel)
  47. mu_target_sq = mu_target**2
  48. mu_preds_sq = mu_preds**2
  49. mu_target_preds = mu_target * mu_preds
  50. sigma_target_sq = torch.clamp(conv2d(target**2, kernel) - mu_target_sq, min=0.0)
  51. sigma_preds_sq = torch.clamp(conv2d(preds**2, kernel) - mu_preds_sq, min=0.0)
  52. sigma_target_preds = conv2d(target * preds, kernel) - mu_target_preds
  53. g = sigma_target_preds / (sigma_target_sq + eps)
  54. sigma_v_sq = sigma_preds_sq - g * sigma_target_preds
  55. mask = sigma_target_sq < eps
  56. g[mask] = 0
  57. sigma_v_sq[mask] = sigma_preds_sq[mask]
  58. sigma_target_sq[mask] = 0
  59. mask = sigma_preds_sq < eps
  60. g[mask] = 0
  61. sigma_v_sq[mask] = 0
  62. mask = g < 0
  63. sigma_v_sq[mask] = sigma_preds_sq[mask]
  64. g[mask] = 0
  65. sigma_v_sq = torch.clamp(sigma_v_sq, min=eps)
  66. preds_vif += torch.sum(torch.log10(1.0 + (g**2.0) * sigma_target_sq / (sigma_v_sq + sigma_n_sq)), dim=[1, 2, 3])
  67. target_vif += torch.sum(torch.log10(1.0 + sigma_target_sq / sigma_n_sq), dim=[1, 2, 3])
  68. return preds_vif / target_vif
  69. def visual_information_fidelity(
  70. preds: Tensor,
  71. target: Tensor,
  72. sigma_n_sq: float = 2.0,
  73. reduction: Literal["mean", "none"] = "mean",
  74. ) -> Tensor:
  75. """Compute Pixel-Based Visual Information Fidelity (VIF-P).
  76. VIF is a full-reference metric that measures the amount of visual information
  77. preserved in a distorted image compared to the reference image.
  78. Args:
  79. preds: Predicted images of shape (N, C, H, W). Height and width must be at least 41.
  80. target: Ground truth images of shape (N, C, H, W). Must match preds in shape.
  81. sigma_n_sq: Variance of the visual noise. Default: 2.0.
  82. reduction: Method for reducing the metric across the batch.
  83. - "mean": Return a tensor average over the batch.
  84. - "none": Return a VIF score for each sample as a 1D tensor of shape (N,).
  85. Returns:
  86. torch.Tensor: VIF score(s). The shape depends on the `reduction` argument:
  87. - If ``reduction="mean"``, returns a scalar tensor.
  88. - If ``reduction="none"``, returns a tensor of shape ``(N,)``.
  89. Raises:
  90. ValueError: If input dimensions are smaller than ``41x41``.
  91. ValueError: If ``preds`` and ``target`` shapes don't match.
  92. ValueError: If ``reduction`` is not ``"mean"`` or ``"none"``.
  93. Example:
  94. >>> from torchmetrics.functional.image import visual_information_fidelity
  95. >>> preds = torch.randn(4, 3, 41, 41, generator=torch.Generator().manual_seed(42))
  96. >>> target = torch.randn(4, 3, 41, 41, generator=torch.Generator().manual_seed(43))
  97. >>> visual_information_fidelity(preds, target, reduction="none")
  98. tensor([0.0040, 0.0049, 0.0017, 0.0039])
  99. """
  100. # This code is inspired by
  101. # https://github.com/photosynthesis-team/piq/blob/01e16b7d8c76bc8765fb6a69560d806148b8046a/piq/vif.py and
  102. # https://github.com/andrewekhalel/sewar/blob/ac76e7bc75732fde40bb0d3908f4b6863400cc27/sewar/full_ref.py#L357.
  103. if preds.size(-1) < 41 or preds.size(-2) < 41:
  104. raise ValueError(f"Invalid size of preds. Expected at least 41x41, but got {preds.size(-1)}x{preds.size(-2)}!")
  105. if target.size(-1) < 41 or target.size(-2) < 41:
  106. raise ValueError(
  107. f"Invalid size of target. Expected at least 41x41, but got {target.size(-1)}x{target.size(-2)}!"
  108. )
  109. if preds.shape != target.shape:
  110. raise ValueError(f"`preds` and `target` must have the same shape, but got {preds.shape} vs {target.shape}.")
  111. if reduction not in ("mean", "none"):
  112. raise ValueError(f"Argument `reduction` must be 'mean' or 'none', but got {reduction}")
  113. per_channel_scores = [
  114. _vif_per_channel(preds[:, i, :, :], target[:, i, :, :], sigma_n_sq) for i in range(preds.size(1))
  115. ]
  116. vif_per_sample = dim_zero_cat(
  117. torch.stack(per_channel_scores, dim=0).mean(0) if preds.size(1) > 1 else per_channel_scores[0]
  118. )
  119. if reduction == "mean":
  120. return vif_per_sample.mean()
  121. return vif_per_sample