geman_mcclure.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. import torch
  19. from torch import Tensor
  20. from kornia.core import Module
  21. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SAME_DEVICE, KORNIA_CHECK_SAME_SHAPE
  22. def geman_mcclure_loss(img1: Tensor, img2: Tensor, reduction: str = "none") -> Tensor:
  23. r"""Criterion that computes the Geman-McClure loss [2].
  24. According to [1], we compute the Geman-McClure loss as follows:
  25. .. math::
  26. \text{WL}(x, y) = \frac{2 (x - y)^{2}}{(x - y)^{2} + 4}
  27. Where:
  28. - :math:`x` is the prediction.
  29. - :math:`y` is the target to be regressed to.
  30. Reference:
  31. [1] https://arxiv.org/pdf/1701.03077.pdf
  32. [2] Bayesian image analysis: An application to single photon emission tomography, Geman and McClure, 1985
  33. Args:
  34. img1: the predicted tensor with shape :math:`(*)`.
  35. img2: the target tensor with the same shape as img1.
  36. reduction: Specifies the reduction to apply to the
  37. output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
  38. will be applied (default), ``'mean'``: the sum of the output will be divided
  39. by the number of elements in the output, ``'sum'``: the output will be
  40. summed.
  41. Return:
  42. a scalar with the computed loss.
  43. Example:
  44. >>> img1 = torch.randn(2, 3, 32, 32, requires_grad=True)
  45. >>> img2 = torch.randn(2, 3, 32, 32)
  46. >>> output = geman_mcclure_loss(img1, img2, reduction="mean")
  47. >>> output.backward()
  48. """
  49. KORNIA_CHECK_IS_TENSOR(img1)
  50. KORNIA_CHECK_IS_TENSOR(img2)
  51. KORNIA_CHECK_SAME_SHAPE(img1, img2)
  52. KORNIA_CHECK_SAME_DEVICE(img1, img2)
  53. KORNIA_CHECK(
  54. reduction in ("mean", "sum", "none", None), f"Given type of reduction is not supported. Got: {reduction}"
  55. )
  56. # compute loss
  57. diff = img1 - img2
  58. diff2 = torch.square(diff)
  59. loss = 2.0 * diff2 / (diff2 + 4.0)
  60. # perform reduction
  61. if reduction == "mean":
  62. loss = loss.mean()
  63. elif reduction == "sum":
  64. loss = loss.sum()
  65. elif reduction == "none" or reduction is None:
  66. pass
  67. else:
  68. raise NotImplementedError("Invalid reduction option.")
  69. return loss
  70. class GemanMcclureLoss(Module):
  71. r"""Criterion that computes the Geman-McClure loss [2].
  72. According to [1], we compute the Geman-McClure loss as follows:
  73. .. math::
  74. \text{WL}(x, y) = \frac{2 (x - y)^{2}}{(x - y)^{2} + 4}
  75. Where:
  76. - :math:`x` is the prediction.
  77. - :math:`y` is the target to be regressed to.
  78. Reference:
  79. [1] https://arxiv.org/pdf/1701.03077.pdf
  80. [2] Bayesian image analysis: An application to single photon emission tomography, Geman and McClure, 1985
  81. Args:
  82. reduction: Specifies the reduction to apply to the
  83. output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
  84. will be applied (default), ``'mean'``: the sum of the output will be divided
  85. by the number of elements in the output, ``'sum'``: the output will be
  86. summed.
  87. Shape:
  88. - img1: the predicted tensor with shape :math:`(*)`.
  89. - img2: the target tensor with the same shape as img1.
  90. Example:
  91. >>> criterion = GemanMcclureLoss(reduction="mean")
  92. >>> img1 = torch.randn(2, 3, 32, 2107, requires_grad=True)
  93. >>> img2 = torch.randn(2, 3, 32, 2107)
  94. >>> output = criterion(img1, img2)
  95. >>> output.backward()
  96. """
  97. def __init__(self, reduction: str = "none") -> None:
  98. super().__init__()
  99. self.reduction = reduction
  100. def forward(self, img1: Tensor, img2: Tensor) -> Tensor:
  101. return geman_mcclure_loss(img1=img1, img2=img2, reduction=self.reduction)