| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217 |
- # Copyright The Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- # Below is a derivative work based on the original work:
- # https://github.com/dingkeyan93/DISTS
- # with the following license:
- #
- # MIT License
- # Copyright (c) 2020 Keyan Ding
- # Permission is hereby granted, free of charge, to any person obtaining a copy
- # of this software and associated documentation files (the "Software"), to deal
- # in the Software without restriction, including without limitation the rights
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- # copies of the Software, and to permit persons to whom the Software is
- # furnished to do so, subject to the following conditions:
- # The above copyright notice and this permission notice shall be included in all
- # copies or substantial portions of the Software.
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- # SOFTWARE.
- from pathlib import Path
- from typing import List, Optional
- import numpy as np
- import torch
- import torch.nn as nn
- from torch import Tensor
- from torch.nn.functional import conv2d
- from typing_extensions import Literal
- from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE
if _TORCHVISION_AVAILABLE:
    from torchvision.models import VGG16_Weights, vgg16
else:
    # The VGG16 backbone comes from torchvision; without it the doctest below cannot run.
    __doctest_skip__ = ["deep_image_structure_and_texture_similarity"]

# File holding the learned DISTS ``alpha``/``beta`` weights, stored next to this module.
_PATH_WEIGHT_DISTS = Path(__file__).resolve().parent / "dists_models" / "weights.pt"
class L2pooling(nn.Module):
    """L2 pooling layer.

    Squares the input, blurs every channel independently with a normalized 2D Hann kernel
    (a depthwise convolution), and takes the square root of the result.

    Args:
        filter_size: Length of the Hann window before its two end taps are dropped; the kernel
            actually applied is ``filter_size - 2`` wide.
        stride: Stride of the pooling convolution.
        channels: Number of channel copies stored in the ``filter`` buffer.

    """

    filter: Tensor

    def __init__(self, filter_size: int = 5, stride: int = 2, channels: int = 3) -> None:
        super().__init__()
        self.padding = (filter_size - 2) // 2
        self.stride = stride
        self.channels = channels
        # np.hanning's first and last taps are zero, so trim them off.
        window = np.hanning(filter_size)[1:-1]
        kernel = torch.Tensor(np.outer(window, window))
        kernel = kernel / kernel.sum()
        # One identical (1, k, k) kernel per channel, as required by a grouped convolution.
        self.register_buffer("filter", kernel.expand(channels, 1, -1, -1).contiguous())

    def forward(self, tensor: Tensor) -> Tensor:
        """Return ``sqrt(blur(tensor ** 2) + 1e-12)``, filtering each input channel on its own."""
        squared = tensor.pow(2)
        pooled = conv2d(
            squared,
            self.filter,
            stride=self.stride,
            padding=self.padding,
            groups=squared.shape[1],
        )
        # The epsilon keeps sqrt (and its gradient) finite where the pooled energy is zero.
        return torch.sqrt(pooled + 1e-12)
class DISTSNetwork(torch.nn.Module):
    """DISTS network.

    Wraps the ImageNet-pretrained torchvision VGG16 ``features`` stack, split into five stages. The
    layers at indices 4, 9, 16 and 23 (VGG16's max-pooling layers) are replaced by
    :class:`L2pooling`. Learned per-channel weights ``alpha`` (structure term) and ``beta``
    (texture term) combine the per-stage similarities into one score.

    Args:
        load_weights: If ``True``, load the learned ``alpha``/``beta`` from ``_PATH_WEIGHT_DISTS``.
            Otherwise they keep their random ``N(0.1, 0.01)`` initialization.

    Raises:
        ModuleNotFoundError: If torchvision is not installed.
        FileNotFoundError: If ``load_weights`` is ``True`` and the weights file does not exist.

    """

    alpha: Tensor
    beta: Tensor
    mean: Tensor
    std: Tensor

    def __init__(self, load_weights: bool = True) -> None:
        super().__init__()
        if not _TORCHVISION_AVAILABLE:
            raise ModuleNotFoundError(
                "DISTS requires torchvision to be installed. Please install it with `pip install torchvision`."
            )
        vgg_pretrained_features = vgg16(weights=VGG16_Weights.DEFAULT).features
        self.stage1 = torch.nn.Sequential()
        self.stage2 = torch.nn.Sequential()
        self.stage3 = torch.nn.Sequential()
        self.stage4 = torch.nn.Sequential()
        self.stage5 = torch.nn.Sequential()
        # Module names reuse the original VGG indices so the state_dict keys stay stable.
        for x in range(4):
            self.stage1.add_module(str(x), vgg_pretrained_features[x])
        self.stage2.add_module(str(4), L2pooling(channels=64))
        for x in range(5, 9):
            self.stage2.add_module(str(x), vgg_pretrained_features[x])
        self.stage3.add_module(str(9), L2pooling(channels=128))
        for x in range(10, 16):
            self.stage3.add_module(str(x), vgg_pretrained_features[x])
        self.stage4.add_module(str(16), L2pooling(channels=256))
        for x in range(17, 23):
            self.stage4.add_module(str(x), vgg_pretrained_features[x])
        self.stage5.add_module(str(23), L2pooling(channels=512))
        for x in range(24, 30):
            self.stage5.add_module(str(x), vgg_pretrained_features[x])

        # Freeze the backbone. alpha/beta are registered *after* this loop, so they keep
        # requires_grad=True (and the score therefore carries a grad_fn).
        for param in self.parameters():
            param.requires_grad = False

        # ImageNet normalization statistics applied to the input in ``forward_once``.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1))

        # Channel count of each feature map returned by ``forward_once`` (raw input + 5 stages).
        self.chns = [3, 64, 128, 256, 512, 512]
        self.register_parameter("alpha", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))
        self.register_parameter("beta", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))
        self.alpha.data.normal_(0.1, 0.01)
        self.beta.data.normal_(0.1, 0.01)
        if load_weights:
            if not _PATH_WEIGHT_DISTS.exists():
                raise FileNotFoundError(f"The weights file is not found in {_PATH_WEIGHT_DISTS}")
            # Load onto CPU (where this freshly built module lives) so it works on CPU-only
            # machines regardless of the device the weights were saved from. Callers can move
            # the whole module afterwards with ``.to(...)``.
            weights = torch.load(str(_PATH_WEIGHT_DISTS), map_location="cpu")
            self.alpha.data = weights["alpha"]
            self.beta.data = weights["beta"]

    def forward_once(self, x: Tensor) -> List[Tensor]:
        """Extract the feature pyramid used by DISTS.

        Args:
            x: Image batch. Channel-wise ``mean``/``std`` normalization is applied internally.

        Returns:
            ``[x, stage1, stage2, stage3, stage4, stage5]`` outputs. The first entry is the raw,
            unnormalized input; channel counts follow ``self.chns``.

        """
        h = (x - self.mean) / self.std
        h = self.stage1(h)
        h_relu1_2 = h
        h = self.stage2(h)
        h_relu2_2 = h
        h = self.stage3(h)
        h_relu3_3 = h
        h = self.stage4(h)
        h_relu4_3 = h
        h = self.stage5(h)
        h_relu5_3 = h
        return [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]

    def forward(self, x: Tensor, y: Tensor, require_grad: bool = False) -> Tensor:
        """Computes DISTS score between two images.

        Args:
            x: First image batch.
            y: Second image batch.
            require_grad: If ``False``, feature extraction runs under ``torch.inference_mode``.

        Returns:
            ``1 - (structure + texture similarity)``, squeezed. This is one value per sample, or a
            0-dim tensor for a batch of one. Since the normalized weights sum to 1, identical inputs
            score approximately 0.

        """
        if require_grad:
            feats0 = self.forward_once(x)
            feats1 = self.forward_once(y)
        else:
            with torch.inference_mode():
                feats0 = self.forward_once(x)
                feats1 = self.forward_once(y)
        dist1: Tensor = torch.tensor(0.0, device=x.device)
        dist2: Tensor = torch.tensor(0.0, device=x.device)
        # Small stabilizing constants for the divisions below.
        c1, c2 = 1e-6, 1e-6
        # Normalize alpha and beta jointly so all weights sum to 1, then split them per feature map.
        w_sum = self.alpha.sum() + self.beta.sum()
        alpha = torch.split(self.alpha / w_sum, self.chns, dim=1)
        beta = torch.split(self.beta / w_sum, self.chns, dim=1)
        for k in range(len(self.chns)):
            # Structure term: compares the spatial means of each channel.
            x_mean = feats0[k].mean([2, 3], keepdim=True)
            y_mean = feats1[k].mean([2, 3], keepdim=True)
            s1 = (2 * x_mean * y_mean + c1) / (x_mean**2 + y_mean**2 + c1)
            dist1 = dist1 + (alpha[k] * s1).sum(1, keepdim=True)
            # Texture term: compares spatial variances and the cross-covariance of each channel.
            x_var = ((feats0[k] - x_mean) ** 2).mean([2, 3], keepdim=True)
            y_var = ((feats1[k] - y_mean) ** 2).mean([2, 3], keepdim=True)
            xy_cov = (feats0[k] * feats1[k]).mean([2, 3], keepdim=True) - x_mean * y_mean
            s2 = (2 * xy_cov + c2) / (x_var + y_var + c2)
            dist2 = dist2 + (beta[k] * s2).sum(1, keepdim=True)
        return 1 - (dist1 + dist2).squeeze()
def _dists_update(preds: Tensor, target: Tensor) -> Tensor:
    """Score ``preds`` against ``target`` with a DISTS network placed on ``preds``' device.

    Gradients through the feature extractor are tracked only when ``preds`` itself requires grad.
    NOTE(review): a new ``DISTSNetwork`` (including its VGG16 backbone) is built on every call.
    """
    track_grad = preds.requires_grad
    network = DISTSNetwork().to(preds.device)
    return network(preds, target, require_grad=track_grad)
def _dists_compute(scores: Tensor, reduction: Optional[Literal["sum", "mean", "none"]]) -> Tensor:
    """Apply the requested reduction to per-sample DISTS scores.

    Args:
        scores: Per-sample DISTS scores.
        reduction: ``"sum"`` or ``"mean"`` to aggregate the scores; ``"none"`` or ``None`` to
            return them unchanged.

    Returns:
        The reduced scores, or ``scores`` itself when no reduction is requested.

    Raises:
        ValueError: If ``reduction`` is not one of ``"sum"``, ``"mean"``, ``"none"`` or ``None``.

    """
    if reduction == "sum":
        return scores.sum()
    if reduction == "mean":
        return scores.mean()
    if reduction is None or reduction == "none":
        return scores
    raise ValueError(f"Argument `reduction` must be one of 'sum', 'mean', 'none' or None, but got {reduction!r}.")
def deep_image_structure_and_texture_similarity(
    preds: Tensor, target: Tensor, reduction: Optional[Literal["sum", "mean", "none"]] = None
) -> Tensor:
    """Compute the `Deep Image Structure and Texture Similarity`_ (DISTS) score.

    Args:
        preds: Predicted images. Presumably a ``(N, 3, H, W)`` batch, as in the example below;
            TODO confirm the expected value range.
        target: Target images, shaped like ``preds``.
        reduction: ``"sum"`` or ``"mean"`` aggregates over the batch. ``"none"`` or ``None``
            (the default) returns one score per sample.

    Returns:
        DISTS score(s) between ``preds`` and ``target``.

    Example:
        >>> from torch import rand
        >>> preds = rand(5, 3, 256, 256)
        >>> target = rand(5, 3, 256, 256)
        >>> deep_image_structure_and_texture_similarity(preds, target)
        tensor([0.1285, 0.1344, 0.1356, 0.1277, 0.1276], grad_fn=<RsubBackward1>)
        >>> deep_image_structure_and_texture_similarity(preds, target, reduction='mean')
        tensor(0.1308, grad_fn=<MeanBackward0>)
    """
    return _dists_compute(_dists_update(preds, target), reduction)
|