# Copyright The Lightning team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Content inspired by the ARNIQA official repository: # https://github.com/miccunifi/ARNIQA # Copyright (c) 2024, Lorenzo Agnolucci, Leonardo Galteri, Marco Bertini, Alberto Del Bimbo # All rights reserved. # License under Apache-2.0 License import warnings from typing import Union import torch from torch import Tensor, nn from torch.nn.functional import normalize as normalize_fn from typing_extensions import Literal from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCHVISION_AVAILABLE if _TORCHVISION_AVAILABLE: from torchvision import transforms from torchvision.models import resnet50 _AVAILABLE_REGRESSOR_DATASETS = { "kadid10k": (1, 5), "koniq10k": (1, 100), } _TYPE_REGRESSOR_DATASET = Literal["kadid10k", "koniq10k"] _base_url = "https://github.com/miccunifi/ARNIQA/releases/download/weights" if not (_TORCH_GREATER_EQUAL_2_2 and _TORCHVISION_AVAILABLE): __doctest_skip__ = ["arniqa"] class _ARNIQA(nn.Module): """Initializes a No-Reference Image Quality Assessment ARNIQA torch.nn.Module. Args: regressor_dataset: dataset used for training the regressor, choose between [``koniq10k``, ``kadid10k``] """ def __init__(self, regressor_dataset: _TYPE_REGRESSOR_DATASET = "koniq10k") -> None: super().__init__() if not _TORCH_GREATER_EQUAL_2_2: # ToDo: RuntimeError: "slow_conv2d_cpu" not implemented for 'Half' raise RuntimeError("ARNIQA metric requires PyTorch >= 2.2.0") if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( "ARNIQA metric requires that torchvision is installed." " Either install as `pip install torchmetrics[image]` or `pip install torchvision`." ) valid_regressor_datasets = _AVAILABLE_REGRESSOR_DATASETS.keys() if regressor_dataset not in valid_regressor_datasets: raise ValueError( f"Argument `regressor_dataset` must be one of {valid_regressor_datasets}, but got {regressor_dataset}." ) self.regressor_dataset = regressor_dataset self.imagenet_norm_mean = [0.485, 0.456, 0.406] self.imagenet_norm_std = [0.229, 0.224, 0.225] encoder = resnet50() self.feat_dim = encoder.fc.in_features # get dimensions of the last layer of the encoder encoder = nn.Sequential(*list(encoder.children())[:-1]) # remove the fully connected layer self.encoder = encoder self.regressor = nn.Linear(self.feat_dim * 2, 1) self._load_weights() def _freeze(module: nn.Module) -> None: module.eval() for p in module.parameters(): p.requires_grad = False _freeze(self.encoder) _freeze(self.regressor) def _load_weights(self) -> None: """Loads the weights of the encoder and regressor.""" encoder_state_dict = torch.hub.load_state_dict_from_url( f"{_base_url}/ARNIQA.pth", progress=True, map_location="cpu" ) filtered_encoder_state_dict = { k.replace("model.", ""): v for k, v in encoder_state_dict.items() if "projector" not in k } self.encoder.load_state_dict(filtered_encoder_state_dict, strict=True) with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning, module="torch.serialization") regressor_state_dict = torch.hub.load_state_dict_from_url( f"{_base_url}/regressor_{self.regressor_dataset}.pth", progress=True, map_location="cpu" ).state_dict() # Rename the keys to match the regressor's state_dict regressor_state_dict["weight"] = regressor_state_dict.pop("weights") regressor_state_dict["bias"] = regressor_state_dict.pop("biases").unsqueeze(0) self.regressor.load_state_dict(regressor_state_dict, strict=True) def _preprocess_input(self, img: Tensor, normalize: bool = False) -> tuple[Tensor, Tensor]: """Preprocesses the input to the model. Obtains the half-scale version of the input image and applies normalization if needed. """ h, w = img.shape[-2:] img_ds = transforms.Resize((h // 2, w // 2))(img) # get the half-scale version of the image if normalize: img = transforms.Normalize(mean=self.imagenet_norm_mean, std=self.imagenet_norm_std)(img) img_ds = transforms.Normalize(mean=self.imagenet_norm_mean, std=self.imagenet_norm_std)(img_ds) return img, img_ds def _scale_score(self, score: Tensor) -> Tensor: """Scales the quality score to be in the [0, 1] range, where higher is better.""" min_score, max_score = _AVAILABLE_REGRESSOR_DATASETS[self.regressor_dataset] return (score - min_score) / (max_score - min_score) def forward(self, img: Tensor, normalize: bool = False) -> Tensor: # Preprocessing img, img_ds = self._preprocess_input(img, normalize) # Extract features from full- and half-scale images img_f = self.encoder(img) img_f = img_f.view(-1, self.feat_dim) img_f = normalize_fn(img_f, dim=1) img_ds_f = self.encoder(img_ds) img_ds_f = img_ds_f.view(-1, self.feat_dim) img_ds_f = normalize_fn(img_ds_f, dim=1) f = torch.hstack((img_f, img_ds_f)) # Get the quality score score = self.regressor(f) return self._scale_score(score) class _NoTrainArniqa(_ARNIQA): """Wrapper to make sure ARNIQA never leaves evaluation mode.""" def train(self, mode: bool) -> "_NoTrainArniqa": # type: ignore[override] """Force network to always be in evaluation mode.""" return super().train(False) def _arniqa_update( img: Tensor, model: nn.Module, normalize: bool, autocast: bool = False ) -> tuple[Tensor, Union[int, Tensor]]: """Update step for ARNIQA metric. Args: img: the input image model: the pre-trained model normalize: boolean indicating whether the input image is normalized autocast: boolean indicating whether to use automatic mixed precision """ # Check that the input image is valid if not (img.ndim == 4 and img.shape[1] == 3): raise ValueError(f"Input image must have shape [N, 3, H, W]. Got input with shape {img.shape}.") if not (img.max() <= 1.0 and img.min() >= 0.0) and normalize: raise ValueError( f"Input image values must be in the [0, 1] range when normalize==True. Got input with values" f" in range {img.min()} and {img.max()}." ) if autocast: with torch.amp.autocast(device_type=img.device.type, dtype=img.dtype): loss = model(img, normalize=normalize) else: loss = model.to(dtype=img.dtype)(img, normalize=normalize) return loss.squeeze(), img.shape[0] def _arniqa_compute( scores: Tensor, num_scores: Union[Tensor, int], reduction: Literal["sum", "mean", "none"] = "mean" ) -> Tensor: """Compute step for ARNIQA metric.""" sum_scores = scores.sum() if reduction == "none": return scores if reduction == "mean": return sum_scores / num_scores return sum_scores def arniqa( img: Tensor, regressor_dataset: _TYPE_REGRESSOR_DATASET = "koniq10k", reduction: Literal["sum", "mean", "none"] = "mean", normalize: bool = True, autocast: bool = False, ) -> Tensor: """ARNIQA: leArning distoRtion maNifold for Image Quality Assessment metric. `ARNIQA`_ is a No-Reference Image Quality Assessment metric that predicts the technical quality of an image with a high correlation with human opinions. ARNIQA consists of an encoder and a regressor. The encoder is a ResNet-50 model trained in a self-supervised way to model the image distortion manifold to generate similar representation for images with similar distortions, regardless of the image content. The regressor is a linear model trained on IQA datasets using the ground-truth quality scores. ARNIQA extracts the features from the full- and half-scale versions of the input image and then outputs a quality score in the [0, 1] range, where higher is better. The input image is expected to have shape ``(N, 3, H, W)``. The image should be in the [0, 1] range if `normalize` is set to ``True``, otherwise it should be normalized with the ImageNet mean and standard deviation. .. note:: Using this metric requires you to have ``torchvision`` package installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torchvision``. Args: img: the input image regressor_dataset: dataset used for training the regressor. Choose between [``koniq10k``, ``kadid10k``]. ``koniq10k`` corresponds to the `KonIQ-10k`_ dataset, which consists of real-world images with authentic distortions. ``kadid10k`` corresponds to the `KADID-10k`_ dataset, which consists of images with synthetically generated distortions. reduction: indicates how to reduce over the batch dimension. Choose between [``sum``, ``mean``, ``none``]. normalize: by default this is ``True`` meaning that the input is expected to be in the [0, 1] range. If set to ``False`` will instead expect input to be already normalized with the ImageNet mean and standard deviation. autocast: boolean indicating whether to use automatic mixed precision Returns: A tensor in the [0, 1] range, where higher is better, representing the ARNIQA score of the input image. If `reduction` is set to ``none``, the output will have shape ``(N,)``, otherwise it will be a scalar tensor. Raises: ModuleNotFoundError: If ``torchvision`` package is not installed ValueError: If ``regressor_dataset`` is not in [``"kadid10k"``, ``"koniq10k"``] ValueError: If ``reduction`` is not in [``"sum"``, ``"mean"``, ``"none"``] ValueError: If ``normalize`` is not a bool ValueError: If the input image is not a valid image tensor with shape [N, 3, H, W]. ValueError: If the input image values are not in the [0, 1] range when ``normalize`` is set to ``True`` Examples: >>> from torch import rand >>> from torchmetrics.functional.image.arniqa import arniqa >>> img = rand(8, 3, 224, 224) >>> # Non-normalized input >>> arniqa(img, regressor_dataset='koniq10k', normalize=True) tensor(0.5308) >>> from torch import rand >>> from torchmetrics.functional.image.arniqa import arniqa >>> from torchvision.transforms import Normalize >>> img = rand(8, 3, 224, 224) >>> img = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img) >>> # Normalized input >>> arniqa(img, regressor_dataset='koniq10k', normalize=False) tensor(0.5065) """ valid_reduction = ("mean", "sum", "none") if reduction not in valid_reduction: raise ValueError(f"Argument `reduction` must be one of {valid_reduction}, but got {reduction}") if not isinstance(normalize, bool): raise ValueError(f"Argument `normalize` should be a bool but got {normalize}") model = _NoTrainArniqa(regressor_dataset=regressor_dataset).to(device=img.device, dtype=img.dtype) loss, num_scores = _arniqa_update(img, model, normalize=normalize, autocast=autocast) return _arniqa_compute(loss, num_scores, reduction)