lovasz_softmax.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from typing import Optional
  19. import torch
  20. from torch import Tensor, nn
  21. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
  22. # based on:
  23. # https://github.com/bermanmaxim/LovaszSoftmax
  24. def lovasz_softmax_loss(pred: Tensor, target: Tensor, weight: Optional[Tensor] = None) -> Tensor:
  25. r"""Criterion that computes a surrogate multi-class intersection-over-union (IoU) loss.
  26. According to [1], we compute the IoU as follows:
  27. .. math::
  28. \text{IoU}(x, class) = \frac{|X \cap Y|}{|X \cup Y|}
  29. [1] approximates this fomular with a surrogate, which is fully differentable.
  30. Where:
  31. - :math:`X` expects to be the scores of each class.
  32. - :math:`Y` expects to be the long tensor with the class labels.
  33. the loss, is finally computed as:
  34. .. math::
  35. \text{loss}(x, class) = 1 - \text{IoU}(x, class)
  36. Reference:
  37. [1] https://arxiv.org/pdf/1705.08790.pdf
  38. .. note::
  39. This loss function only supports multi-class (C > 1) labels. For binary
  40. labels please use the Lovasz-Hinge loss.
  41. Args:
  42. pred: logits tensor with shape :math:`(N, C, H, W)` where C = number of classes > 1.
  43. target: labels tensor with shape :math:`(N, H, W)` where each value
  44. is :math:`0 ≤ targets[i] ≤ C-1`.
  45. weight: weights for classes with shape :math:`(num\_of\_classes,)`.
  46. Return:
  47. a scalar with the computed loss.
  48. Example:
  49. >>> N = 5 # num_classes
  50. >>> pred = torch.randn(1, N, 3, 5, requires_grad=True)
  51. >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
  52. >>> output = lovasz_softmax_loss(pred, target)
  53. >>> output.backward()
  54. """
  55. KORNIA_CHECK_SHAPE(pred, ["B", "N", "H", "W"])
  56. KORNIA_CHECK_SHAPE(target, ["B", "H", "W"])
  57. if not pred.shape[1] > 1:
  58. raise ValueError(f"Invalid pred shape, we expect BxNxHxW, with N > 1. Got: {pred.shape}")
  59. if not pred.shape[-2:] == target.shape[-2:]:
  60. raise ValueError(f"pred and target shapes must be the same. Got: {pred.shape} and {target.shape}")
  61. if not pred.device == target.device:
  62. raise ValueError(f"pred and target must be in the same device. Got: {pred.device} and {target.device}")
  63. num_of_classes = pred.shape[1]
  64. # compute the actual dice score
  65. if weight is not None:
  66. KORNIA_CHECK_IS_TENSOR(weight, "weight must be Tensor or None.")
  67. KORNIA_CHECK(
  68. (weight.shape[0] == num_of_classes and weight.numel() == num_of_classes),
  69. f"weight shape must be (num_of_classes,): ({num_of_classes},), got {weight.shape}",
  70. )
  71. KORNIA_CHECK(
  72. weight.device == pred.device,
  73. f"weight and pred must be in the same device. Got: {weight.device} and {pred.device}",
  74. )
  75. # flatten pred [B, C, -1] and target [B, -1] and to float
  76. pred_flatten: Tensor = pred.reshape(pred.shape[0], pred.shape[1], -1)
  77. target_flatten: Tensor = target.reshape(target.shape[0], -1)
  78. # get shapes
  79. B, C, N = pred_flatten.shape
  80. # compute softmax over the classes axis
  81. pred_soft: Tensor = pred_flatten.softmax(1)
  82. # compute actual loss
  83. foreground: Tensor = (
  84. torch.nn.functional.one_hot(target_flatten.to(torch.int64), num_classes=C).permute(0, 2, 1).to(pred.dtype)
  85. )
  86. errors: Tensor = (pred_soft - foreground).abs()
  87. errors_sorted, permutations = torch.sort(errors, dim=2, descending=True)
  88. batch_index = torch.arange(B, device=pred.device).unsqueeze(1).unsqueeze(2).expand(B, C, N)
  89. target_sorted = target_flatten[batch_index, permutations]
  90. target_sorted_sum = target_sorted.sum(2, keepdim=True)
  91. intersection = target_sorted_sum - target_sorted.cumsum(2)
  92. union = target_sorted_sum + (1.0 - target_sorted).cumsum(2)
  93. gradient = 1.0 - intersection / union
  94. if N > 1:
  95. gradient[..., 1:] = gradient[..., 1:] - gradient[..., :-1]
  96. weighted_errors = errors_sorted * gradient
  97. loss_per_class = weighted_errors.sum(2).mean(0)
  98. if weight is not None:
  99. loss_per_class *= weight
  100. final_loss: Tensor = loss_per_class.mean()
  101. return final_loss
  102. class LovaszSoftmaxLoss(nn.Module):
  103. r"""Criterion that computes a surrogate multi-class intersection-over-union (IoU) loss.
  104. According to [1], we compute the IoU as follows:
  105. .. math::
  106. \text{IoU}(x, class) = \frac{|X \cap Y|}{|X \cup Y|}
  107. [1] approximates this fomular with a surrogate, which is fully differentable.
  108. Where:
  109. - :math:`X` expects to be the scores of each class.
  110. - :math:`Y` expects to be the binary tensor with the class labels.
  111. the loss, is finally computed as:
  112. .. math::
  113. \text{loss}(x, class) = 1 - \text{IoU}(x, class)
  114. Reference:
  115. [1] https://arxiv.org/pdf/1705.08790.pdf
  116. .. note::
  117. This loss function only supports multi-class (C > 1) labels. For binary
  118. labels please use the Lovasz-Hinge loss.
  119. Args:
  120. pred: logits tensor with shape :math:`(N, C, H, W)` where C = number of classes > 1.
  121. labels: labels tensor with shape :math:`(N, H, W)` where each value
  122. is :math:`0 ≤ targets[i] ≤ C-1`.
  123. weight: weights for classes with shape :math:`(num\_of\_classes,)`.
  124. Return:
  125. a scalar with the computed loss.
  126. Example:
  127. >>> N = 5 # num_classes
  128. >>> criterion = LovaszSoftmaxLoss()
  129. >>> pred = torch.randn(1, N, 3, 5, requires_grad=True)
  130. >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
  131. >>> output = criterion(pred, target)
  132. >>> output.backward()
  133. """
  134. def __init__(self, weight: Optional[Tensor] = None) -> None:
  135. super().__init__()
  136. self.weight = weight
  137. def forward(self, pred: Tensor, target: Tensor) -> Tensor:
  138. return lovasz_softmax_loss(pred=pred, target=target, weight=self.weight)