lpips.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # Content copied from
  15. # https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py
  16. # and
  17. # https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/pretrained_networks.py
  18. # and with adjustments from
  19. # https://github.com/richzhang/PerceptualSimilarity/pull/114/files
  20. # due to package no longer being maintained
  21. # Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
  22. # All rights reserved.
  23. # License under BSD 2-clause
  24. import inspect
  25. import os
  26. from typing import List, NamedTuple, Optional, Union
  27. import torch
  28. from torch import Tensor, nn
  29. from typing_extensions import Literal
  30. from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE
  31. _weight_map = {
  32. "squeezenet1_1": "SqueezeNet1_1_Weights",
  33. "alexnet": "AlexNet_Weights",
  34. "vgg16": "VGG16_Weights",
  35. }
  36. if not _TORCHVISION_AVAILABLE:
  37. __doctest_skip__ = ["learned_perceptual_image_patch_similarity", "_get_tv_model_features"]
  38. def _get_tv_model_features(net: str, pretrained: bool = False) -> nn.modules.container.Sequential:
  39. """Get torchvision network.
  40. Args:
  41. net: Name of network
  42. pretrained: If pretrained weights should be used
  43. >>> _ = _get_tv_model_features("alexnet", pretrained=True)
  44. >>> _ = _get_tv_model_features("squeezenet1_1", pretrained=True)
  45. >>> _ = _get_tv_model_features("vgg16", pretrained=True)
  46. """
  47. if not _TORCHVISION_AVAILABLE:
  48. raise ModuleNotFoundError("Torchvision is not installed. Please install torchvision to use this functionality.")
  49. import torchvision
  50. if pretrained:
  51. model_weights = getattr(torchvision.models, _weight_map[net])
  52. model = getattr(torchvision.models, net)(weights=model_weights.DEFAULT)
  53. else:
  54. model = getattr(torchvision.models, net)(weights=None)
  55. return model.features
  56. class SqueezeNet(torch.nn.Module):
  57. """SqueezeNet implementation."""
  58. def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None:
  59. super().__init__()
  60. pretrained_features = _get_tv_model_features("squeezenet1_1", pretrained)
  61. self.N_slices = 7
  62. slices = []
  63. feature_ranges = [range(2), range(2, 5), range(5, 8), range(8, 10), range(10, 11), range(11, 12), range(12, 13)]
  64. for feature_range in feature_ranges:
  65. seq = torch.nn.Sequential()
  66. for i in feature_range:
  67. seq.add_module(str(i), pretrained_features[i])
  68. slices.append(seq)
  69. self.slices = nn.ModuleList(slices)
  70. if not requires_grad:
  71. for param in self.parameters():
  72. param.requires_grad = False
  73. def forward(self, x: Tensor) -> NamedTuple:
  74. """Process input."""
  75. class _SqueezeOutput(NamedTuple):
  76. relu1: Tensor
  77. relu2: Tensor
  78. relu3: Tensor
  79. relu4: Tensor
  80. relu5: Tensor
  81. relu6: Tensor
  82. relu7: Tensor
  83. relus = []
  84. for slice_ in self.slices:
  85. x = slice_(x)
  86. relus.append(x)
  87. return _SqueezeOutput(*relus)
  88. class Alexnet(torch.nn.Module):
  89. """Alexnet implementation."""
  90. def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None:
  91. super().__init__()
  92. alexnet_pretrained_features = _get_tv_model_features("alexnet", pretrained)
  93. self.slice1 = torch.nn.Sequential()
  94. self.slice2 = torch.nn.Sequential()
  95. self.slice3 = torch.nn.Sequential()
  96. self.slice4 = torch.nn.Sequential()
  97. self.slice5 = torch.nn.Sequential()
  98. self.N_slices = 5
  99. for x in range(2):
  100. self.slice1.add_module(str(x), alexnet_pretrained_features[x])
  101. for x in range(2, 5):
  102. self.slice2.add_module(str(x), alexnet_pretrained_features[x])
  103. for x in range(5, 8):
  104. self.slice3.add_module(str(x), alexnet_pretrained_features[x])
  105. for x in range(8, 10):
  106. self.slice4.add_module(str(x), alexnet_pretrained_features[x])
  107. for x in range(10, 12):
  108. self.slice5.add_module(str(x), alexnet_pretrained_features[x])
  109. if not requires_grad:
  110. for param in self.parameters():
  111. param.requires_grad = False
  112. def forward(self, x: Tensor) -> NamedTuple:
  113. """Process input."""
  114. h = self.slice1(x)
  115. h_relu1 = h
  116. h = self.slice2(h)
  117. h_relu2 = h
  118. h = self.slice3(h)
  119. h_relu3 = h
  120. h = self.slice4(h)
  121. h_relu4 = h
  122. h = self.slice5(h)
  123. h_relu5 = h
  124. class _AlexnetOutputs(NamedTuple):
  125. relu1: Tensor
  126. relu2: Tensor
  127. relu3: Tensor
  128. relu4: Tensor
  129. relu5: Tensor
  130. return _AlexnetOutputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
  131. class Vgg16(torch.nn.Module):
  132. """Vgg16 implementation."""
  133. def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None:
  134. super().__init__()
  135. vgg_pretrained_features = _get_tv_model_features("vgg16", pretrained)
  136. self.slice1 = torch.nn.Sequential()
  137. self.slice2 = torch.nn.Sequential()
  138. self.slice3 = torch.nn.Sequential()
  139. self.slice4 = torch.nn.Sequential()
  140. self.slice5 = torch.nn.Sequential()
  141. self.N_slices = 5
  142. for x in range(4):
  143. self.slice1.add_module(str(x), vgg_pretrained_features[x])
  144. for x in range(4, 9):
  145. self.slice2.add_module(str(x), vgg_pretrained_features[x])
  146. for x in range(9, 16):
  147. self.slice3.add_module(str(x), vgg_pretrained_features[x])
  148. for x in range(16, 23):
  149. self.slice4.add_module(str(x), vgg_pretrained_features[x])
  150. for x in range(23, 30):
  151. self.slice5.add_module(str(x), vgg_pretrained_features[x])
  152. if not requires_grad:
  153. for param in self.parameters():
  154. param.requires_grad = False
  155. def forward(self, x: Tensor) -> NamedTuple:
  156. """Process input."""
  157. h = self.slice1(x)
  158. h_relu1_2 = h
  159. h = self.slice2(h)
  160. h_relu2_2 = h
  161. h = self.slice3(h)
  162. h_relu3_3 = h
  163. h = self.slice4(h)
  164. h_relu4_3 = h
  165. h = self.slice5(h)
  166. h_relu5_3 = h
  167. class _VGGOutputs(NamedTuple):
  168. relu1_2: Tensor
  169. relu2_2: Tensor
  170. relu3_3: Tensor
  171. relu4_3: Tensor
  172. relu5_3: Tensor
  173. return _VGGOutputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
  174. def _spatial_average(in_tens: Tensor, keep_dim: bool = True) -> Tensor:
  175. """Spatial averaging over height and width of images."""
  176. return in_tens.mean([2, 3], keepdim=keep_dim)
  177. def _upsample(in_tens: Tensor, out_hw: tuple[int, ...] = (64, 64)) -> Tensor:
  178. """Upsample input with bilinear interpolation."""
  179. return nn.Upsample(size=out_hw, mode="bilinear", align_corners=False)(in_tens)
  180. def _normalize_tensor(in_feat: Tensor, eps: float = 1e-8) -> Tensor:
  181. """Normalize input tensor."""
  182. norm_factor = torch.sqrt(eps + torch.sum(in_feat**2, dim=1, keepdim=True))
  183. return in_feat / norm_factor
  184. def _resize_tensor(x: Tensor, size: int = 64) -> Tensor:
  185. """https://github.com/toshas/torch-fidelity/blob/master/torch_fidelity/sample_similarity_lpips.py#L127C22-L132."""
  186. if x.shape[-1] > size and x.shape[-2] > size:
  187. return torch.nn.functional.interpolate(x, (size, size), mode="area")
  188. return torch.nn.functional.interpolate(x, (size, size), mode="bilinear", align_corners=False)
  189. class ScalingLayer(nn.Module):
  190. """Scaling layer."""
  191. shift: Tensor
  192. scale: Tensor
  193. def __init__(self) -> None:
  194. super().__init__()
  195. self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None], persistent=False)
  196. self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None], persistent=False)
  197. def forward(self, inp: Tensor) -> Tensor:
  198. """Process input."""
  199. return (inp - self.shift) / self.scale
  200. class NetLinLayer(nn.Module):
  201. """A single linear layer which does a 1x1 conv."""
  202. def __init__(self, chn_in: int, chn_out: int = 1, use_dropout: bool = False) -> None:
  203. super().__init__()
  204. layers = [nn.Dropout()] if use_dropout else []
  205. layers += [
  206. nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), # type: ignore[list-item]
  207. ]
  208. self.model = nn.Sequential(*layers)
  209. def forward(self, x: Tensor) -> Tensor:
  210. """Process input."""
  211. return self.model(x)
  212. class _LPIPS(nn.Module):
  213. def __init__(
  214. self,
  215. pretrained: bool = True,
  216. net: Literal["alex", "vgg", "squeeze"] = "alex",
  217. spatial: bool = False,
  218. pnet_rand: bool = False,
  219. pnet_tune: bool = False,
  220. use_dropout: bool = True,
  221. model_path: Optional[str] = None,
  222. eval_mode: bool = True,
  223. resize: Optional[int] = None,
  224. ) -> None:
  225. """Initializes a perceptual loss torch.nn.Module.
  226. Args:
  227. pretrained: This flag controls the linear layers should be pretrained version or random
  228. net: Indicate backbone to use, choose between ['alex','vgg','squeeze']
  229. spatial: If input should be spatial averaged
  230. pnet_rand: If backbone should be random or use imagenet pre-trained weights
  231. pnet_tune: If backprop should be enabled for both backbone and linear layers
  232. use_dropout: If dropout layers should be added
  233. model_path: Model path to load pretained models from
  234. eval_mode: If network should be in evaluation mode
  235. resize: If input should be resized to this size
  236. """
  237. super().__init__()
  238. self.pnet_type = net
  239. self.pnet_tune = pnet_tune
  240. self.pnet_rand = pnet_rand
  241. self.spatial = spatial
  242. self.resize = resize
  243. self.scaling_layer = ScalingLayer()
  244. if self.pnet_type in ["vgg", "vgg16"]:
  245. net_type = Vgg16
  246. self.chns = [64, 128, 256, 512, 512]
  247. elif self.pnet_type == "alex":
  248. net_type = Alexnet # type: ignore[assignment]
  249. self.chns = [64, 192, 384, 256, 256]
  250. elif self.pnet_type == "squeeze":
  251. net_type = SqueezeNet # type: ignore[assignment]
  252. self.chns = [64, 128, 256, 384, 384, 512, 512]
  253. self.L = len(self.chns)
  254. self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
  255. self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
  256. self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
  257. self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
  258. self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
  259. self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
  260. self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
  261. if self.pnet_type == "squeeze": # 7 layers for squeezenet
  262. self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
  263. self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
  264. self.lins += [self.lin5, self.lin6]
  265. self.lins = nn.ModuleList(self.lins) # type: ignore[assignment]
  266. if pretrained:
  267. if model_path is None:
  268. model_path = os.path.abspath(
  269. os.path.join(inspect.getfile(self.__init__), "..", f"lpips_models/{net}.pth") # type: ignore[misc]
  270. )
  271. self.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False)
  272. if eval_mode:
  273. self.eval()
  274. if not self.pnet_tune:
  275. for param in self.parameters():
  276. param.requires_grad = False
  277. def forward(
  278. self, in0: Tensor, in1: Tensor, retperlayer: bool = False, normalize: bool = False
  279. ) -> Union[Tensor, tuple[Tensor, List[Tensor]]]:
  280. if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
  281. in0 = 2 * in0 - 1
  282. in1 = 2 * in1 - 1
  283. # normalize input
  284. in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
  285. # resize input if needed
  286. if self.resize is not None:
  287. in0_input = _resize_tensor(in0_input, size=self.resize)
  288. in1_input = _resize_tensor(in1_input, size=self.resize)
  289. outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
  290. feats0, feats1, diffs = {}, {}, {}
  291. for kk in range(self.L):
  292. feats0[kk], feats1[kk] = _normalize_tensor(outs0[kk]), _normalize_tensor(outs1[kk])
  293. diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
  294. res = []
  295. for kk in range(self.L):
  296. if self.spatial:
  297. res.append(_upsample(self.lins[kk](diffs[kk]), out_hw=tuple(in0.shape[2:])))
  298. else:
  299. res.append(_spatial_average(self.lins[kk](diffs[kk]), keep_dim=True))
  300. val: Tensor = sum(res) # type: ignore[assignment]
  301. if retperlayer:
  302. return (val, res)
  303. return val
  304. class _NoTrainLpips(_LPIPS):
  305. """Wrapper to make sure LPIPS never leaves evaluation mode."""
  306. def train(self, mode: bool) -> "_NoTrainLpips": # type: ignore[override]
  307. """Force network to always be in evaluation mode."""
  308. return super().train(False)
  309. def _valid_img(img: Tensor, normalize: bool) -> bool:
  310. """Check that input is a valid image to the network."""
  311. value_check = img.max() <= 1.0 and img.min() >= 0.0 if normalize else img.min() >= -1
  312. return img.ndim == 4 and img.shape[1] == 3 and value_check # type: ignore[return-value]
  313. def _lpips_update(img1: Tensor, img2: Tensor, net: nn.Module, normalize: bool) -> Tensor:
  314. if not (_valid_img(img1, normalize) and _valid_img(img2, normalize)):
  315. raise ValueError(
  316. "Expected both input arguments to be normalized tensors with shape [N, 3, H, W]."
  317. f" Got input with shape {img1.shape} and {img2.shape} and values in range"
  318. f" {[img1.min(), img1.max()]} and {[img2.min(), img2.max()]} when all values are"
  319. f" expected to be in the {[0, 1] if normalize else [-1, 1]} range."
  320. )
  321. return net(img1, img2, normalize=normalize).squeeze()
  322. def _lpips_compute(scores: Tensor, reduction: Optional[Literal["sum", "mean", "none"]] = "mean") -> Tensor:
  323. if reduction == "mean":
  324. return scores.mean()
  325. if reduction == "sum":
  326. return scores.sum()
  327. if reduction == "none" or reduction is None:
  328. return scores
  329. raise ValueError(f"Invalid reduction type: {reduction}")
  330. def learned_perceptual_image_patch_similarity(
  331. img1: Tensor,
  332. img2: Tensor,
  333. net_type: Literal["alex", "vgg", "squeeze"] = "alex",
  334. reduction: Optional[Literal["sum", "mean", "none"]] = "mean",
  335. normalize: bool = False,
  336. ) -> Tensor:
  337. """The Learned Perceptual Image Patch Similarity (`LPIPS_`) calculates perceptual similarity between two images.
  338. LPIPS essentially computes the similarity between the activations of two image patches for some pre-defined network.
  339. This measure has been shown to match human perception well. A low LPIPS score means that image patches are
  340. perceptual similar.
  341. Both input image patches are expected to have shape ``(N, 3, H, W)``. The minimum size of `H, W` depends on the
  342. chosen backbone (see `net_type` arg).
  343. Args:
  344. img1: first set of images
  345. img2: second set of images
  346. net_type: str indicating backbone network type to use. Choose between `'alex'`, `'vgg'` or `'squeeze'`
  347. reduction: str indicating how to reduce over the batch dimension. Choose between `'sum'`, `'mean'`, `'none'`
  348. or `None`.
  349. normalize: by default this is ``False`` meaning that the input is expected to be in the [-1,1] range. If set
  350. to ``True`` will instead expect input to be in the ``[0,1]`` range.
  351. Example:
  352. >>> from torch import rand
  353. >>> from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity
  354. >>> img1 = (rand(10, 3, 100, 100) * 2) - 1
  355. >>> img2 = (rand(10, 3, 100, 100) * 2) - 1
  356. >>> learned_perceptual_image_patch_similarity(img1, img2, net_type='squeeze')
  357. tensor(0.1005)
  358. >>> from torch import rand, Generator
  359. >>> from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity
  360. >>> gen = Generator().manual_seed(42)
  361. >>> img1 = (rand(2, 3, 100, 100, generator=gen) * 2) - 1
  362. >>> img2 = (rand(2, 3, 100, 100, generator=gen) * 2) - 1
  363. >>> learned_perceptual_image_patch_similarity(img1, img2, net_type='squeeze', reduction='none')
  364. tensor([0.1024, 0.0938])
  365. """
  366. net = _NoTrainLpips(net=net_type).to(device=img1.device, dtype=img1.dtype)
  367. loss = _lpips_update(img1, img2, net, normalize)
  368. return _lpips_compute(loss, reduction)