arniqa.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # Content inspired by the ARNIQA official repository:
  15. # https://github.com/miccunifi/ARNIQA
  16. # Copyright (c) 2024, Lorenzo Agnolucci, Leonardo Galteri, Marco Bertini, Alberto Del Bimbo
  17. # All rights reserved.
  18. # License under Apache-2.0 License
  19. import warnings
  20. from typing import Union
  21. import torch
  22. from torch import Tensor, nn
  23. from torch.nn.functional import normalize as normalize_fn
  24. from typing_extensions import Literal
  25. from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCHVISION_AVAILABLE
  26. if _TORCHVISION_AVAILABLE:
  27. from torchvision import transforms
  28. from torchvision.models import resnet50
  29. _AVAILABLE_REGRESSOR_DATASETS = {
  30. "kadid10k": (1, 5),
  31. "koniq10k": (1, 100),
  32. }
  33. _TYPE_REGRESSOR_DATASET = Literal["kadid10k", "koniq10k"]
  34. _base_url = "https://github.com/miccunifi/ARNIQA/releases/download/weights"
  35. if not (_TORCH_GREATER_EQUAL_2_2 and _TORCHVISION_AVAILABLE):
  36. __doctest_skip__ = ["arniqa"]
  37. class _ARNIQA(nn.Module):
  38. """Initializes a No-Reference Image Quality Assessment ARNIQA torch.nn.Module.
  39. Args:
  40. regressor_dataset: dataset used for training the regressor, choose between [``koniq10k``, ``kadid10k``]
  41. """
  42. def __init__(self, regressor_dataset: _TYPE_REGRESSOR_DATASET = "koniq10k") -> None:
  43. super().__init__()
  44. if not _TORCH_GREATER_EQUAL_2_2: # ToDo: RuntimeError: "slow_conv2d_cpu" not implemented for 'Half'
  45. raise RuntimeError("ARNIQA metric requires PyTorch >= 2.2.0")
  46. if not _TORCHVISION_AVAILABLE:
  47. raise ModuleNotFoundError(
  48. "ARNIQA metric requires that torchvision is installed."
  49. " Either install as `pip install torchmetrics[image]` or `pip install torchvision`."
  50. )
  51. valid_regressor_datasets = _AVAILABLE_REGRESSOR_DATASETS.keys()
  52. if regressor_dataset not in valid_regressor_datasets:
  53. raise ValueError(
  54. f"Argument `regressor_dataset` must be one of {valid_regressor_datasets}, but got {regressor_dataset}."
  55. )
  56. self.regressor_dataset = regressor_dataset
  57. self.imagenet_norm_mean = [0.485, 0.456, 0.406]
  58. self.imagenet_norm_std = [0.229, 0.224, 0.225]
  59. encoder = resnet50()
  60. self.feat_dim = encoder.fc.in_features # get dimensions of the last layer of the encoder
  61. encoder = nn.Sequential(*list(encoder.children())[:-1]) # remove the fully connected layer
  62. self.encoder = encoder
  63. self.regressor = nn.Linear(self.feat_dim * 2, 1)
  64. self._load_weights()
  65. def _freeze(module: nn.Module) -> None:
  66. module.eval()
  67. for p in module.parameters():
  68. p.requires_grad = False
  69. _freeze(self.encoder)
  70. _freeze(self.regressor)
  71. def _load_weights(self) -> None:
  72. """Loads the weights of the encoder and regressor."""
  73. encoder_state_dict = torch.hub.load_state_dict_from_url(
  74. f"{_base_url}/ARNIQA.pth", progress=True, map_location="cpu"
  75. )
  76. filtered_encoder_state_dict = {
  77. k.replace("model.", ""): v for k, v in encoder_state_dict.items() if "projector" not in k
  78. }
  79. self.encoder.load_state_dict(filtered_encoder_state_dict, strict=True)
  80. with warnings.catch_warnings():
  81. warnings.filterwarnings("ignore", category=UserWarning, module="torch.serialization")
  82. regressor_state_dict = torch.hub.load_state_dict_from_url(
  83. f"{_base_url}/regressor_{self.regressor_dataset}.pth", progress=True, map_location="cpu"
  84. ).state_dict()
  85. # Rename the keys to match the regressor's state_dict
  86. regressor_state_dict["weight"] = regressor_state_dict.pop("weights")
  87. regressor_state_dict["bias"] = regressor_state_dict.pop("biases").unsqueeze(0)
  88. self.regressor.load_state_dict(regressor_state_dict, strict=True)
  89. def _preprocess_input(self, img: Tensor, normalize: bool = False) -> tuple[Tensor, Tensor]:
  90. """Preprocesses the input to the model.
  91. Obtains the half-scale version of the input image and applies normalization if needed.
  92. """
  93. h, w = img.shape[-2:]
  94. img_ds = transforms.Resize((h // 2, w // 2))(img) # get the half-scale version of the image
  95. if normalize:
  96. img = transforms.Normalize(mean=self.imagenet_norm_mean, std=self.imagenet_norm_std)(img)
  97. img_ds = transforms.Normalize(mean=self.imagenet_norm_mean, std=self.imagenet_norm_std)(img_ds)
  98. return img, img_ds
  99. def _scale_score(self, score: Tensor) -> Tensor:
  100. """Scales the quality score to be in the [0, 1] range, where higher is better."""
  101. min_score, max_score = _AVAILABLE_REGRESSOR_DATASETS[self.regressor_dataset]
  102. return (score - min_score) / (max_score - min_score)
  103. def forward(self, img: Tensor, normalize: bool = False) -> Tensor:
  104. # Preprocessing
  105. img, img_ds = self._preprocess_input(img, normalize)
  106. # Extract features from full- and half-scale images
  107. img_f = self.encoder(img)
  108. img_f = img_f.view(-1, self.feat_dim)
  109. img_f = normalize_fn(img_f, dim=1)
  110. img_ds_f = self.encoder(img_ds)
  111. img_ds_f = img_ds_f.view(-1, self.feat_dim)
  112. img_ds_f = normalize_fn(img_ds_f, dim=1)
  113. f = torch.hstack((img_f, img_ds_f))
  114. # Get the quality score
  115. score = self.regressor(f)
  116. return self._scale_score(score)
  117. class _NoTrainArniqa(_ARNIQA):
  118. """Wrapper to make sure ARNIQA never leaves evaluation mode."""
  119. def train(self, mode: bool) -> "_NoTrainArniqa": # type: ignore[override]
  120. """Force network to always be in evaluation mode."""
  121. return super().train(False)
  122. def _arniqa_update(
  123. img: Tensor, model: nn.Module, normalize: bool, autocast: bool = False
  124. ) -> tuple[Tensor, Union[int, Tensor]]:
  125. """Update step for ARNIQA metric.
  126. Args:
  127. img: the input image
  128. model: the pre-trained model
  129. normalize: boolean indicating whether the input image is normalized
  130. autocast: boolean indicating whether to use automatic mixed precision
  131. """
  132. # Check that the input image is valid
  133. if not (img.ndim == 4 and img.shape[1] == 3):
  134. raise ValueError(f"Input image must have shape [N, 3, H, W]. Got input with shape {img.shape}.")
  135. if not (img.max() <= 1.0 and img.min() >= 0.0) and normalize:
  136. raise ValueError(
  137. f"Input image values must be in the [0, 1] range when normalize==True. Got input with values"
  138. f" in range {img.min()} and {img.max()}."
  139. )
  140. if autocast:
  141. with torch.amp.autocast(device_type=img.device.type, dtype=img.dtype):
  142. loss = model(img, normalize=normalize)
  143. else:
  144. loss = model.to(dtype=img.dtype)(img, normalize=normalize)
  145. return loss.squeeze(), img.shape[0]
  146. def _arniqa_compute(
  147. scores: Tensor, num_scores: Union[Tensor, int], reduction: Literal["sum", "mean", "none"] = "mean"
  148. ) -> Tensor:
  149. """Compute step for ARNIQA metric."""
  150. sum_scores = scores.sum()
  151. if reduction == "none":
  152. return scores
  153. if reduction == "mean":
  154. return sum_scores / num_scores
  155. return sum_scores
  156. def arniqa(
  157. img: Tensor,
  158. regressor_dataset: _TYPE_REGRESSOR_DATASET = "koniq10k",
  159. reduction: Literal["sum", "mean", "none"] = "mean",
  160. normalize: bool = True,
  161. autocast: bool = False,
  162. ) -> Tensor:
  163. """ARNIQA: leArning distoRtion maNifold for Image Quality Assessment metric.
  164. `ARNIQA`_ is a No-Reference Image Quality Assessment metric that predicts the technical quality of an image with
  165. a high correlation with human opinions. ARNIQA consists of an encoder and a regressor. The encoder is a ResNet-50
  166. model trained in a self-supervised way to model the image distortion manifold to generate similar representation for
  167. images with similar distortions, regardless of the image content. The regressor is a linear model trained on IQA
  168. datasets using the ground-truth quality scores. ARNIQA extracts the features from the full- and half-scale versions
  169. of the input image and then outputs a quality score in the [0, 1] range, where higher is better.
  170. The input image is expected to have shape ``(N, 3, H, W)``. The image should be in the [0, 1] range if `normalize`
  171. is set to ``True``, otherwise it should be normalized with the ImageNet mean and standard deviation.
  172. .. note::
  173. Using this metric requires you to have ``torchvision`` package installed. Either install as
  174. ``pip install torchmetrics[image]`` or ``pip install torchvision``.
  175. Args:
  176. img: the input image
  177. regressor_dataset: dataset used for training the regressor. Choose between [``koniq10k``, ``kadid10k``].
  178. ``koniq10k`` corresponds to the `KonIQ-10k`_ dataset, which consists of real-world images with authentic
  179. distortions. ``kadid10k`` corresponds to the `KADID-10k`_ dataset, which consists of images with
  180. synthetically generated distortions.
  181. reduction: indicates how to reduce over the batch dimension. Choose between [``sum``, ``mean``, ``none``].
  182. normalize: by default this is ``True`` meaning that the input is expected to be in the [0, 1] range. If set
  183. to ``False`` will instead expect input to be already normalized with the ImageNet mean and standard
  184. deviation.
  185. autocast: boolean indicating whether to use automatic mixed precision
  186. Returns:
  187. A tensor in the [0, 1] range, where higher is better, representing the ARNIQA score of the input image. If
  188. `reduction` is set to ``none``, the output will have shape ``(N,)``, otherwise it will be a scalar tensor.
  189. Raises:
  190. ModuleNotFoundError:
  191. If ``torchvision`` package is not installed
  192. ValueError:
  193. If ``regressor_dataset`` is not in [``"kadid10k"``, ``"koniq10k"``]
  194. ValueError:
  195. If ``reduction`` is not in [``"sum"``, ``"mean"``, ``"none"``]
  196. ValueError:
  197. If ``normalize`` is not a bool
  198. ValueError:
  199. If the input image is not a valid image tensor with shape [N, 3, H, W].
  200. ValueError:
  201. If the input image values are not in the [0, 1] range when ``normalize`` is set to ``True``
  202. Examples:
  203. >>> from torch import rand
  204. >>> from torchmetrics.functional.image.arniqa import arniqa
  205. >>> img = rand(8, 3, 224, 224)
  206. >>> # Non-normalized input
  207. >>> arniqa(img, regressor_dataset='koniq10k', normalize=True)
  208. tensor(0.5308)
  209. >>> from torch import rand
  210. >>> from torchmetrics.functional.image.arniqa import arniqa
  211. >>> from torchvision.transforms import Normalize
  212. >>> img = rand(8, 3, 224, 224)
  213. >>> img = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
  214. >>> # Normalized input
  215. >>> arniqa(img, regressor_dataset='koniq10k', normalize=False)
  216. tensor(0.5065)
  217. """
  218. valid_reduction = ("mean", "sum", "none")
  219. if reduction not in valid_reduction:
  220. raise ValueError(f"Argument `reduction` must be one of {valid_reduction}, but got {reduction}")
  221. if not isinstance(normalize, bool):
  222. raise ValueError(f"Argument `normalize` should be a bool but got {normalize}")
  223. model = _NoTrainArniqa(regressor_dataset=regressor_dataset).to(device=img.device, dtype=img.dtype)
  224. loss, num_scores = _arniqa_update(img, model, normalize=normalize, autocast=autocast)
  225. return _arniqa_compute(loss, num_scores, reduction)