psnr.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. import torch
  19. from torch import nn
  20. from kornia import metrics
  21. def psnr_loss(image: torch.Tensor, target: torch.Tensor, max_val: float) -> torch.Tensor:
  22. r"""Compute the PSNR loss.
  23. The loss is computed as follows:
  24. .. math::
  25. \text{loss} = -\text{psnr(x, y)}
  26. See :meth:`~kornia.losses.psnr` for details abut PSNR.
  27. Args:
  28. image: the input image with shape :math:`(*)`.
  29. target : the labels image with shape :math:`(*)`.
  30. max_val: The maximum value in the image tensor.
  31. Return:
  32. the computed loss as a scalar.
  33. Examples:
  34. >>> ones = torch.ones(1)
  35. >>> psnr_loss(ones, 1.2 * ones, 2.) # 10 * log(4/((1.2-1)**2)) / log(10)
  36. tensor(-20.0000)
  37. """
  38. return -1.0 * metrics.psnr(image, target, max_val)
  39. class PSNRLoss(nn.Module):
  40. r"""Create a criterion that calculates the PSNR loss.
  41. The loss is computed as follows:
  42. .. math::
  43. \text{loss} = -\text{psnr(x, y)}
  44. See :meth:`~kornia.losses.psnr` for details abut PSNR.
  45. Args:
  46. max_val: The maximum value in the image tensor.
  47. Shape:
  48. - Image: arbitrary dimensional tensor :math:`(*)`.
  49. - Target: arbitrary dimensional tensor :math:`(*)` same shape as image.
  50. - Output: a scalar.
  51. Examples:
  52. >>> ones = torch.ones(1)
  53. >>> criterion = PSNRLoss(2.)
  54. >>> criterion(ones, 1.2 * ones) # 10 * log(4/((1.2-1)**2)) / log(10)
  55. tensor(-20.0000)
  56. """
  57. def __init__(self, max_val: float) -> None:
  58. super().__init__()
  59. self.max_val: float = max_val
  60. def forward(self, image: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
  61. return psnr_loss(image, target, self.max_val)