arniqa.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. import torch
  17. from torch import Tensor
  18. from typing_extensions import Literal
  19. from torchmetrics.functional.image.arniqa import (
  20. _ARNIQA,
  21. _TYPE_REGRESSOR_DATASET,
  22. _arniqa_compute,
  23. _arniqa_update,
  24. _NoTrainArniqa,
  25. )
  26. from torchmetrics.metric import Metric
  27. from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout
  28. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_GREATER_EQUAL_2_2, _TORCHVISION_AVAILABLE
  29. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  30. if not _MATPLOTLIB_AVAILABLE:
  31. __doctest_skip__ = ["ARNIQA.plot"]
  32. if _TORCH_GREATER_EQUAL_2_2 and _TORCHVISION_AVAILABLE:
  33. def _download_arniqa() -> None:
  34. _ARNIQA(regressor_dataset="koniq10k")
  35. if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_arniqa):
  36. __doctest_skip__ = ["ARNIQA", "ARNIQA.plot"]
  37. else:
  38. __doctest_skip__ = ["ARNIQA", "ARNIQA.plot"]
  39. class ARNIQA(Metric):
  40. """ARNIQA: leArning distoRtion maNifold for Image Quality Assessment metric.
  41. `ARNIQA`_ is a No-Reference Image Quality Assessment metric that predicts the technical quality of an image with
  42. a high correlation with human opinions. ARNIQA consists of an encoder and a regressor. The encoder is a ResNet-50
  43. model trained in a self-supervised way to model the image distortion manifold to generate similar representation for
  44. images with similar distortions, regardless of the image content. The regressor is a linear model trained on IQA
  45. datasets using the ground-truth quality scores. ARNIQA extracts the features from the full- and half-scale versions
  46. of the input image and then outputs a quality score in the [0, 1] range, where higher is better.
  47. The input image is expected to have shape ``(N, 3, H, W)``. The image should be in the [0, 1] range if `normalize`
  48. is set to ``True``, otherwise it should be normalized with the ImageNet mean and standard deviation.
  49. .. note::
  50. Using this metric requires you to have ``torchvision`` package installed. Either install as
  51. ``pip install torchmetrics[image]`` or ``pip install torchvision``.
  52. As input to ``forward`` and ``update`` the metric accepts the following input
  53. - ``img`` (:class:`~torch.Tensor`): tensor with images of shape ``(N, 3, H, W)``
  54. As output of `forward` and `compute` the metric returns the following output
  55. - ``arniqa`` (:class:`~torch.Tensor`): tensor with ARNIQA score. If `reduction` is set to ``none``, the output will
  56. have shape ``(N,)``, otherwise it will be a scalar tensor. Tensor values are in the [0, 1] range, where higher
  57. is better.
  58. Args:
  59. img: the input image
  60. regressor_dataset: dataset used for training the regressor. Choose between [``koniq10k``, ``kadid10k``].
  61. ``koniq10k`` corresponds to the `KonIQ-10k`_ dataset, which consists of real-world images with authentic
  62. distortions. ``kadid10k`` corresponds to the `KADID-10k`_ dataset, which consists of images with
  63. synthetically generated distortions.
  64. reduction: indicates how to reduce over the batch dimension. Choose between [``sum``, ``mean``, ``none``].
  65. normalize: by default this is ``True`` meaning that the input is expected to be in the [0, 1] range. If set
  66. to ``False`` will instead expect input to be already normalized with the ImageNet mean and standard
  67. deviation.
  68. autocast: if ``True``, metric will convert model to mixed precision before running forward pass.
  69. kwargs: additional keyword arguments, see :ref:`Metric kwargs` for more info.
  70. Raises:
  71. ModuleNotFoundError:
  72. If ``torchvision`` package is not installed
  73. ValueError:
  74. If ``regressor_dataset`` is not in [``"kadid10k"``, ``"koniq10k"``]
  75. ValueError:
  76. If ``reduction`` is not in [``"sum"``, ``"mean"``, ``"none"``]
  77. ValueError:
  78. If ``normalize`` is not a bool
  79. ValueError:
  80. If the input image is not a valid image tensor with shape [N, 3, H, W].
  81. ValueError:
  82. If the input image values are not in the [0, 1] range when ``normalize`` is set to ``True``
  83. Examples:
  84. >>> from torch import rand
  85. >>> from torchmetrics.image.arniqa import ARNIQA
  86. >>> img = rand(8, 3, 224, 224)
  87. >>> # Non-normalized input
  88. >>> metric = ARNIQA(regressor_dataset='koniq10k', normalize=True)
  89. >>> metric(img)
  90. tensor(0.5308)
  91. >>> from torch import rand
  92. >>> from torchmetrics.image.arniqa import ARNIQA
  93. >>> from torchvision.transforms import Normalize
  94. >>> img = rand(8, 3, 224, 224)
  95. >>> img = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
  96. >>> # Normalized input
  97. >>> metric = ARNIQA(regressor_dataset='koniq10k', normalize=False)
  98. >>> metric(img)
  99. tensor(0.5065)
  100. """
  101. is_differentiable: bool = True
  102. higher_is_better: bool = True
  103. full_state_update: bool = False
  104. plot_lower_bound: float = 0.0
  105. plot_upper_bound: float = 1.0
  106. sum_scores: Tensor
  107. num_scores: Tensor
  108. feature_network: str = "model"
  109. def __init__(
  110. self,
  111. regressor_dataset: _TYPE_REGRESSOR_DATASET = "koniq10k",
  112. reduction: Literal["sum", "mean", "none"] = "mean",
  113. normalize: bool = True,
  114. autocast: bool = False,
  115. **kwargs: Any,
  116. ) -> None:
  117. super().__init__(**kwargs)
  118. if not _TORCH_GREATER_EQUAL_2_2: # ToDo: RuntimeError: "slow_conv2d_cpu" not implemented for 'Half'
  119. raise RuntimeError("ARNIQA metric requires PyTorch >= 2.2.0")
  120. if not _TORCHVISION_AVAILABLE:
  121. raise ModuleNotFoundError(
  122. "ARNIQA metric requires that torchvision is installed."
  123. " Either install as `pip install torchmetrics[image]` or `pip install torchvision`."
  124. )
  125. self.model = _NoTrainArniqa(regressor_dataset=regressor_dataset)
  126. valid_reduction = ("mean", "sum", "none")
  127. if reduction not in valid_reduction:
  128. raise ValueError(f"Argument `reduction` must be one of {valid_reduction}, but got {reduction}")
  129. self.reduction = reduction
  130. if not isinstance(normalize, bool):
  131. raise ValueError(f"Argument `normalize` should be a bool but got {normalize}")
  132. self.normalize = normalize
  133. self.autocast = autocast
  134. self.add_state("sum_scores", torch.tensor(0.0), dist_reduce_fx="sum")
  135. self.add_state("num_scores", torch.tensor(0.0), dist_reduce_fx="sum")
  136. def update(self, img: Tensor) -> None:
  137. """Update internal states with arniqa score."""
  138. loss, num_scores = _arniqa_update(img, model=self.model, normalize=self.normalize, autocast=self.autocast)
  139. self.sum_scores += loss.sum()
  140. self.num_scores += num_scores
  141. def compute(self) -> Tensor:
  142. """Compute final arniqa metric."""
  143. return _arniqa_compute(self.sum_scores, self.num_scores, self.reduction)
  144. def plot(
  145. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  146. ) -> _PLOT_OUT_TYPE:
  147. """Plot a single or multiple values from the metric.
  148. Args:
  149. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  150. If no value is provided, will automatically call `metric.compute` and plot that result.
  151. ax: An matplotlib axis object. If provided will add plot to that axis
  152. Returns:
  153. Figure and Axes object
  154. Raises:
  155. ModuleNotFoundError:
  156. If `matplotlib` is not installed
  157. .. plot::
  158. :scale: 75
  159. >>> # Example plotting a single value
  160. >>> import torch
  161. >>> from torchmetrics.image.arniqa import ARNIQA
  162. >>> metric = ARNIQA(regressor_dataset='koniq10k')
  163. >>> metric.update(torch.rand(8, 3, 224, 224))
  164. >>> fig_, ax_ = metric.plot()
  165. .. plot::
  166. :scale: 75
  167. >>> # Example plotting multiple values
  168. >>> import torch
  169. >>> from torchmetrics.image.arniqa import ARNIQA
  170. >>> metric = ARNIQA(regressor_dataset='koniq10k')
  171. >>> values = [ ]
  172. >>> for _ in range(3):
  173. ... values.append(metric(torch.rand(8, 3, 224, 224)))
  174. >>> fig_, ax_ = metric.plot(values)
  175. """
  176. return self._plot(val, ax)