dists.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # Below is a derivative work based on the original work:
  16. # https://github.com/dingkeyan93/DISTS
  17. # with the following license:
  18. #
  19. # MIT License
  20. # Copyright (c) 2020 Keyan Ding
  21. # Permission is hereby granted, free of charge, to any person obtaining a copy
  22. # of this software and associated documentation files (the "Software"), to deal
  23. # in the Software without restriction, including without limitation the rights
  24. # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  25. # copies of the Software, and to permit persons to whom the Software is
  26. # furnished to do so, subject to the following conditions:
  27. # The above copyright notice and this permission notice shall be included in all
  28. # copies or substantial portions of the Software.
  29. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  30. # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  31. # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  32. # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  33. # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  34. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  35. # SOFTWARE.
  36. from pathlib import Path
  37. from typing import List, Optional
  38. import numpy as np
  39. import torch
  40. import torch.nn as nn
  41. from torch import Tensor
  42. from torch.nn.functional import conv2d
  43. from typing_extensions import Literal
  44. from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE
  45. if not _TORCHVISION_AVAILABLE:
  46. __doctest_skip__ = ["deep_image_structure_and_texture_similarity"]
  47. else:
  48. from torchvision.models import VGG16_Weights, vgg16
  49. _PATH_WEIGHT_DISTS = Path(__file__).resolve().parent / "dists_models" / "weights.pt"
  50. class L2pooling(nn.Module):
  51. """L2 pooling layer."""
  52. filter: Tensor
  53. def __init__(self, filter_size: int = 5, stride: int = 2, channels: int = 3) -> None:
  54. super().__init__()
  55. self.padding = (filter_size - 2) // 2
  56. self.stride = stride
  57. self.channels = channels
  58. a = np.hanning(filter_size)[1:-1]
  59. g = torch.Tensor(a[:, None] * a[None, :])
  60. g = g / torch.sum(g)
  61. self.register_buffer("filter", g[None, None, :, :].repeat(self.channels, 1, 1, 1))
  62. def forward(self, tensor: Tensor) -> Tensor:
  63. """Forward pass of the layer."""
  64. tensor = tensor**2
  65. out = conv2d(tensor, self.filter, stride=self.stride, padding=self.padding, groups=tensor.shape[1])
  66. return (out + 1e-12).sqrt()
  67. class DISTSNetwork(torch.nn.Module):
  68. """DISTS network."""
  69. alpha: Tensor
  70. beta: Tensor
  71. mean: Tensor
  72. std: Tensor
  73. def __init__(self, load_weights: bool = True) -> None:
  74. super().__init__()
  75. if not _TORCHVISION_AVAILABLE:
  76. raise ModuleNotFoundError(
  77. "DISTS requires torchvision to be installed. Please install it with `pip install torchvision`."
  78. )
  79. vgg_pretrained_features = vgg16(weights=VGG16_Weights.DEFAULT).features
  80. self.stage1 = torch.nn.Sequential()
  81. self.stage2 = torch.nn.Sequential()
  82. self.stage3 = torch.nn.Sequential()
  83. self.stage4 = torch.nn.Sequential()
  84. self.stage5 = torch.nn.Sequential()
  85. for x in range(4):
  86. self.stage1.add_module(str(x), vgg_pretrained_features[x])
  87. self.stage2.add_module(str(4), L2pooling(channels=64))
  88. for x in range(5, 9):
  89. self.stage2.add_module(str(x), vgg_pretrained_features[x])
  90. self.stage3.add_module(str(9), L2pooling(channels=128))
  91. for x in range(10, 16):
  92. self.stage3.add_module(str(x), vgg_pretrained_features[x])
  93. self.stage4.add_module(str(16), L2pooling(channels=256))
  94. for x in range(17, 23):
  95. self.stage4.add_module(str(x), vgg_pretrained_features[x])
  96. self.stage5.add_module(str(23), L2pooling(channels=512))
  97. for x in range(24, 30):
  98. self.stage5.add_module(str(x), vgg_pretrained_features[x])
  99. for param in self.parameters():
  100. param.requires_grad = False
  101. self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1))
  102. self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1))
  103. self.chns = [3, 64, 128, 256, 512, 512]
  104. self.register_parameter("alpha", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))
  105. self.register_parameter("beta", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))
  106. self.alpha.data.normal_(0.1, 0.01)
  107. self.beta.data.normal_(0.1, 0.01)
  108. if load_weights:
  109. if not _PATH_WEIGHT_DISTS.exists():
  110. raise FileNotFoundError(f"The weights file is not found in {_PATH_WEIGHT_DISTS}")
  111. weights = torch.load(str(_PATH_WEIGHT_DISTS))
  112. self.alpha.data = weights["alpha"]
  113. self.beta.data = weights["beta"]
  114. def forward_once(self, x: Tensor) -> List[Tensor]:
  115. """Forward pass of the network."""
  116. h = (x - self.mean) / self.std
  117. h = self.stage1(h)
  118. h_relu1_2 = h
  119. h = self.stage2(h)
  120. h_relu2_2 = h
  121. h = self.stage3(h)
  122. h_relu3_3 = h
  123. h = self.stage4(h)
  124. h_relu4_3 = h
  125. h = self.stage5(h)
  126. h_relu5_3 = h
  127. return [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]
  128. def forward(self, x: Tensor, y: Tensor, require_grad: bool = False) -> Tensor:
  129. """Computes DISTS score between two images."""
  130. if require_grad:
  131. feats0 = self.forward_once(x)
  132. feats1 = self.forward_once(y)
  133. else:
  134. with torch.inference_mode():
  135. feats0 = self.forward_once(x)
  136. feats1 = self.forward_once(y)
  137. dist1: Tensor = torch.tensor(0.0, device=x.device)
  138. dist2: Tensor = torch.tensor(0.0, device=x.device)
  139. c1, c2 = 1e-6, 1e-6
  140. w_sum = self.alpha.sum() + self.beta.sum()
  141. alpha = torch.split(self.alpha / w_sum, self.chns, dim=1)
  142. beta = torch.split(self.beta / w_sum, self.chns, dim=1)
  143. for k in range(len(self.chns)):
  144. x_mean = feats0[k].mean([2, 3], keepdim=True)
  145. y_mean = feats1[k].mean([2, 3], keepdim=True)
  146. s1 = (2 * x_mean * y_mean + c1) / (x_mean**2 + y_mean**2 + c1)
  147. dist1 = dist1 + (alpha[k] * s1).sum(1, keepdim=True)
  148. x_var = ((feats0[k] - x_mean) ** 2).mean([2, 3], keepdim=True)
  149. y_var = ((feats1[k] - y_mean) ** 2).mean([2, 3], keepdim=True)
  150. xy_cov = (feats0[k] * feats1[k]).mean([2, 3], keepdim=True) - x_mean * y_mean
  151. s2 = (2 * xy_cov + c2) / (x_var + y_var + c2)
  152. dist2 = dist2 + (beta[k] * s2).sum(1, keepdim=True)
  153. return 1 - (dist1 + dist2).squeeze()
  154. def _dists_update(preds: Tensor, target: Tensor) -> Tensor:
  155. dists = DISTSNetwork().to(preds.device)
  156. return dists(preds, target, require_grad=preds.requires_grad)
  157. def _dists_compute(scores: Tensor, reduction: Optional[Literal["sum", "mean", "none"]]) -> Tensor:
  158. if reduction == "sum":
  159. return scores.sum()
  160. if reduction == "mean":
  161. return scores.mean()
  162. if reduction is None or reduction == "none":
  163. return scores
  164. raise ValueError(f"Argument {reduction} is not valid. Choose 'sum', 'mean' or 'none'., but got {reduction}")
  165. def deep_image_structure_and_texture_similarity(
  166. preds: Tensor, target: Tensor, reduction: Optional[Literal["sum", "mean", "none"]] = None
  167. ) -> Tensor:
  168. """Calculates `Deep Image Structure and Texture Similarity`_ (DISTS) score.
  169. Args:
  170. preds: Predicted image tensor.
  171. target: Target image tensor.
  172. reduction: Reduction method for the output.
  173. Returns:
  174. DISTS Similarity score between the two images.
  175. Example:
  176. >>> from torch import rand
  177. >>> preds = rand(5, 3, 256, 256)
  178. >>> target = rand(5, 3, 256, 256)
  179. >>> deep_image_structure_and_texture_similarity(preds, target)
  180. tensor([0.1285, 0.1344, 0.1356, 0.1277, 0.1276], grad_fn=<RsubBackward1>)
  181. >>> deep_image_structure_and_texture_similarity(preds, target, reduction='mean')
  182. tensor(0.1308, grad_fn=<MeanBackward0>)
  183. """
  184. scores = _dists_update(preds, target)
  185. return _dists_compute(scores, reduction)