lovasz_hinge.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. import torch
  19. from torch import Tensor, nn
  20. from kornia.core.check import KORNIA_CHECK_SHAPE
  21. # based on:
  22. # https://github.com/bermanmaxim/LovaszSoftmax
  23. def lovasz_hinge_loss(pred: Tensor, target: Tensor) -> Tensor:
  24. r"""Criterion that computes a surrogate binary intersection-over-union (IoU) loss.
  25. According to [2], we compute the IoU as follows:
  26. .. math::
  27. \text{IoU}(x, class) = \frac{|X \cap Y|}{|X \cup Y|}
  28. [1] approximates this fomular with a surrogate, which is fully differentable.
  29. Where:
  30. - :math:`X` expects to be the scores of each class.
  31. - :math:`Y` expects to be the binary tensor with the class labels.
  32. the loss, is finally computed as:
  33. .. math::
  34. \text{loss}(x, class) = 1 - \text{IoU}(x, class)
  35. Reference:
  36. [1] http://proceedings.mlr.press/v37/yub15.pdf
  37. [2] https://arxiv.org/pdf/1705.08790.pdf
  38. .. note::
  39. This loss function only supports binary labels. For multi-class labels please
  40. use the Lovasz-Softmax loss.
  41. Args:
  42. pred: logits tensor with shape :math:`(N, 1, H, W)`.
  43. target: labels tensor with shape :math:`(N, H, W)` with binary values.
  44. Return:
  45. a scalar with the computed loss.
  46. Example:
  47. >>> N = 1 # num_classes
  48. >>> pred = torch.randn(1, N, 3, 5, requires_grad=True)
  49. >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
  50. >>> output = lovasz_hinge_loss(pred, target)
  51. >>> output.backward()
  52. """
  53. KORNIA_CHECK_SHAPE(pred, ["B", "1", "H", "W"])
  54. KORNIA_CHECK_SHAPE(target, ["B", "H", "W"])
  55. if not pred.shape[-2:] == target.shape[-2:]:
  56. raise ValueError(f"pred and target shapes must be the same. Got: {pred.shape} and {target.shape}")
  57. if not pred.device == target.device:
  58. raise ValueError(f"pred and target must be in the same device. Got: {pred.device} and {target.device}")
  59. # flatten pred and target [B, -1] and to float
  60. pred_flatten: Tensor = pred.reshape(pred.shape[0], -1)
  61. target_flatten: Tensor = target.reshape(target.shape[0], -1)
  62. # get shapes
  63. B, N = pred_flatten.shape
  64. # compute actual loss
  65. signs = 2.0 * target_flatten - 1.0
  66. errors = 1.0 - pred_flatten * signs
  67. errors_sorted, permutation = errors.sort(dim=1, descending=True)
  68. batch_index: Tensor = torch.arange(B, device=pred.device).reshape(-1, 1).repeat(1, N).reshape(-1)
  69. target_sorted: Tensor = target_flatten[batch_index, permutation.view(-1)]
  70. target_sorted = target_sorted.view(B, N)
  71. target_sorted_sum: Tensor = target_sorted.sum(1, keepdim=True)
  72. intersection: Tensor = target_sorted_sum - target_sorted.cumsum(1)
  73. union: Tensor = target_sorted_sum + (1.0 - target_sorted).cumsum(1)
  74. gradient: Tensor = 1.0 - intersection / union
  75. if N > 1:
  76. gradient[..., 1:] = gradient[..., 1:] - gradient[..., :-1]
  77. loss: Tensor = (errors_sorted.relu() * gradient).sum(1).mean()
  78. return loss
  79. class LovaszHingeLoss(nn.Module):
  80. r"""Criterion that computes a surrogate binary intersection-over-union (IoU) loss.
  81. According to [2], we compute the IoU as follows:
  82. .. math::
  83. \text{IoU}(x, class) = \frac{|X \cap Y|}{|X \cup Y|}
  84. [1] approximates this fomular with a surrogate, which is fully differentable.
  85. Where:
  86. - :math:`X` expects to be the scores of each class.
  87. - :math:`Y` expects to be the binary tensor with the class labels.
  88. the loss, is finally computed as:
  89. .. math::
  90. \text{loss}(x, class) = 1 - \text{IoU}(x, class)
  91. Reference:
  92. [1] http://proceedings.mlr.press/v37/yub15.pdf
  93. [2] https://arxiv.org/pdf/1705.08790.pdf
  94. .. note::
  95. This loss function only supports binary labels. For multi-class labels please
  96. use the Lovasz-Softmax loss.
  97. Args:
  98. pred: logits tensor with shape :math:`(N, 1, H, W)`.
  99. labels: labels tensor with shape :math:`(N, H, W)` with binary values.
  100. Return:
  101. a scalar with the computed loss.
  102. Example:
  103. >>> N = 1 # num_classes
  104. >>> criterion = LovaszHingeLoss()
  105. >>> pred = torch.randn(1, N, 3, 5, requires_grad=True)
  106. >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
  107. >>> output = criterion(pred, target)
  108. >>> output.backward()
  109. """
  110. def __init__(self) -> None:
  111. super().__init__()
  112. def forward(self, pred: Tensor, target: Tensor) -> Tensor:
  113. return lovasz_hinge_loss(pred=pred, target=target)