welsch.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from torch import Tensor
  19. from kornia.core import Module
  20. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SAME_DEVICE, KORNIA_CHECK_SAME_SHAPE
  21. def welsch_loss(img1: Tensor, img2: Tensor, reduction: str = "none") -> Tensor:
  22. r"""Criterion that computes the Welsch [2] (aka. Leclerc [3]) loss.
  23. According to [1], we compute the Welsch loss as follows:
  24. .. math::
  25. \text{WL}(x, y) = 1 - exp(-\frac{1}{2} (x - y)^{2})
  26. Where:
  27. - :math:`x` is the prediction.
  28. - :math:`y` is the target to be regressed to.
  29. Reference:
  30. [1] https://arxiv.org/pdf/1701.03077.pdf
  31. [2] https://www.tandfonline.com/doi/abs/10.1080/03610917808812083
  32. [3] https://link.springer.com/article/10.1007/BF00054839
  33. Args:
  34. img1: the predicted tensor with shape :math:`(*)`.
  35. img2: the target tensor with the same shape as img1.
  36. reduction: Specifies the reduction to apply to the
  37. output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
  38. will be applied (default), ``'mean'``: the sum of the output will be divided
  39. by the number of elements in the output, ``'sum'``: the output will be
  40. summed.
  41. Return:
  42. a scalar with the computed loss.
  43. Example:
  44. >>> img1 = torch.randn(2, 3, 32, 32, requires_grad=True)
  45. >>> img2 = torch.randn(2, 3, 32, 32)
  46. >>> output = welsch_loss(img1, img2, reduction="mean")
  47. >>> output.backward()
  48. """
  49. KORNIA_CHECK_IS_TENSOR(img1)
  50. KORNIA_CHECK_IS_TENSOR(img2)
  51. KORNIA_CHECK_SAME_SHAPE(img1, img2)
  52. KORNIA_CHECK_SAME_DEVICE(img1, img2)
  53. KORNIA_CHECK(reduction in ("mean", "sum", "none"), f"Given type of reduction is not supported. Got: {reduction}")
  54. # compute loss
  55. loss = 1.0 - (-0.5 * (img1 - img2) ** 2).exp()
  56. # perform reduction
  57. if reduction == "mean":
  58. loss = loss.mean()
  59. elif reduction == "sum":
  60. loss = loss.sum()
  61. elif reduction == "none":
  62. pass
  63. else:
  64. raise NotImplementedError("Invalid reduction option.")
  65. return loss
  66. class WelschLoss(Module):
  67. r"""Criterion that computes the Welsch [2] (aka. Leclerc [3]) loss.
  68. According to [1], we compute the Welsch loss as follows:
  69. .. math::
  70. \text{WL}(x, y) = 1 - exp(-\frac{1}{2} (x - y)^{2})
  71. Where:
  72. - :math:`x` is the prediction.
  73. - :math:`y` is the target to be regressed to.
  74. Reference:
  75. [1] https://arxiv.org/pdf/1701.03077.pdf
  76. [2] https://www.tandfonline.com/doi/abs/10.1080/03610917808812083
  77. [3] https://link.springer.com/article/10.1007/BF00054839
  78. Args:
  79. reduction: Specifies the reduction to apply to the
  80. output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
  81. will be applied (default), ``'mean'``: the sum of the output will be divided
  82. by the number of elements in the output, ``'sum'``: the output will be
  83. summed.
  84. Shape:
  85. - img1: the predicted tensor with shape :math:`(*)`.
  86. - img2: the target tensor with the same shape as img1.
  87. Example:
  88. >>> criterion = WelschLoss(reduction="mean")
  89. >>> img1 = torch.randn(2, 3, 32, 1904, requires_grad=True)
  90. >>> img2 = torch.randn(2, 3, 32, 1904)
  91. >>> output = criterion(img1, img2)
  92. >>> output.backward()
  93. """
  94. def __init__(self, reduction: str = "none") -> None:
  95. super().__init__()
  96. self.reduction = reduction
  97. def forward(self, img1: Tensor, img2: Tensor) -> Tensor:
  98. return welsch_loss(img1=img1, img2=img2, reduction=self.reduction)