dice.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from typing import Optional
  19. import torch
  20. from torch import nn
  21. from kornia.core import Tensor
  22. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR
  23. from kornia.losses._utils import mask_ignore_pixels
  24. from kornia.utils.one_hot import one_hot
  25. # based on:
  26. # https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py
  27. # https://github.com/Lightning-AI/metrics/blob/v0.11.3/src/torchmetrics/functional/classification/dice.py#L66-L207
  28. def dice_loss(
  29. pred: Tensor,
  30. target: Tensor,
  31. average: str = "micro",
  32. eps: float = 1e-8,
  33. weight: Optional[Tensor] = None,
  34. ignore_index: Optional[int] = -100,
  35. ) -> Tensor:
  36. r"""Criterion that computes Sørensen-Dice Coefficient loss.
  37. According to [1], we compute the Sørensen-Dice Coefficient as follows:
  38. .. math::
  39. \text{Dice}(x, class) = \frac{2 |X \cap Y|}{|X| + |Y|}
  40. Where:
  41. - :math:`X` expects to be the scores of each class.
  42. - :math:`Y` expects to be the one-hot tensor with the class labels.
  43. the loss, is finally computed as:
  44. .. math::
  45. \text{loss}(x, class) = 1 - \text{Dice}(x, class)
  46. Reference:
  47. [1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
  48. Args:
  49. pred: logits tensor with shape :math:`(N, C, H, W)` where C = number of classes.
  50. target: labels tensor with shape :math:`(N, H, W)` where each value
  51. is :math:`0 ≤ targets[i] ≤ C-1`.
  52. average:
  53. Reduction applied in multi-class scenario:
  54. - ``'micro'`` [default]: Calculate the loss across all classes.
  55. - ``'macro'``: Calculate the loss for each class separately and average the metrics across classes.
  56. eps: Scalar to enforce numerical stabiliy.
  57. weight: weights for classes with shape :math:`(num\_of\_classes,)`.
  58. ignore_index: labels with this value are ignored in the loss computation.
  59. Return:
  60. One-element tensor of the computed loss.
  61. Example:
  62. >>> N = 5 # num_classes
  63. >>> pred = torch.randn(1, N, 3, 5, requires_grad=True)
  64. >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
  65. >>> output = dice_loss(pred, target)
  66. >>> output.backward()
  67. """
  68. KORNIA_CHECK_IS_TENSOR(pred)
  69. if not len(pred.shape) == 4:
  70. raise ValueError(f"Invalid pred shape, we expect BxNxHxW. Got: {pred.shape}")
  71. if not pred.shape[-2:] == target.shape[-2:]:
  72. raise ValueError(f"pred and target shapes must be the same. Got: {pred.shape} and {target.shape}")
  73. if not pred.device == target.device:
  74. raise ValueError(f"pred and target must be in the same device. Got: {pred.device} and {target.device}")
  75. num_of_classes = pred.shape[1]
  76. possible_average = {"micro", "macro"}
  77. KORNIA_CHECK(average in possible_average, f"The `average` has to be one of {possible_average}. Got: {average}")
  78. # compute softmax over the classes axis
  79. pred_soft: Tensor = pred.softmax(dim=1)
  80. target, target_mask = mask_ignore_pixels(target, ignore_index)
  81. # create the labels one hot tensor
  82. target_one_hot: Tensor = one_hot(target, num_classes=pred.shape[1], device=pred.device, dtype=pred.dtype)
  83. # mask ignore pixels
  84. if target_mask is not None:
  85. target_mask.unsqueeze_(1)
  86. target_one_hot = target_one_hot * target_mask
  87. pred_soft = pred_soft * target_mask
  88. # compute the actual dice score
  89. if weight is not None:
  90. KORNIA_CHECK_IS_TENSOR(weight, "weight must be Tensor or None.")
  91. KORNIA_CHECK(
  92. (weight.shape[0] == num_of_classes and weight.numel() == num_of_classes),
  93. f"weight shape must be (num_of_classes,): ({num_of_classes},), got {weight.shape}",
  94. )
  95. KORNIA_CHECK(
  96. weight.device == pred.device,
  97. f"weight and pred must be in the same device. Got: {weight.device} and {pred.device}",
  98. )
  99. else:
  100. weight = pred.new_ones(pred.shape[1])
  101. # set dimensions for the appropriate averaging
  102. dims: tuple[int, ...] = (2, 3)
  103. if average == "micro":
  104. dims = (1, *dims)
  105. weight = weight.view(-1, 1, 1)
  106. pred_soft = pred_soft * weight
  107. target_one_hot = target_one_hot * weight
  108. intersection = torch.sum(pred_soft * target_one_hot, dims)
  109. cardinality = torch.sum(pred_soft + target_one_hot, dims)
  110. dice_score = 2.0 * intersection / (cardinality + eps)
  111. dice_loss = -dice_score + 1.0
  112. # reduce the loss across samples (and classes in case of `macro` averaging)
  113. if average == "macro":
  114. dice_loss = (dice_loss * weight).sum(-1) / weight.sum()
  115. dice_loss = torch.mean(dice_loss)
  116. return dice_loss
  117. class DiceLoss(nn.Module):
  118. r"""Criterion that computes Sørensen-Dice Coefficient loss.
  119. According to [1], we compute the Sørensen-Dice Coefficient as follows:
  120. .. math::
  121. \text{Dice}(x, class) = \frac{2 |X| \cap |Y|}{|X| + |Y|}
  122. Where:
  123. - :math:`X` expects to be the scores of each class.
  124. - :math:`Y` expects to be the one-hot tensor with the class labels.
  125. the loss, is finally computed as:
  126. .. math::
  127. \text{loss}(x, class) = 1 - \text{Dice}(x, class)
  128. Reference:
  129. [1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
  130. Args:
  131. average:
  132. Reduction applied in multi-class scenario:
  133. - ``'micro'`` [default]: Calculate the loss across all classes.
  134. - ``'macro'``: Calculate the loss for each class separately and average the metrics across classes.
  135. eps: Scalar to enforce numerical stabiliy.
  136. weight: weights for classes with shape :math:`(num\_of\_classes,)`.
  137. ignore_index: labels with this value are ignored in the loss computation.
  138. Shape:
  139. - Pred: :math:`(N, C, H, W)` where C = number of classes.
  140. - Target: :math:`(N, H, W)` where each value is
  141. :math:`0 ≤ targets[i] ≤ C-1`.
  142. Example:
  143. >>> N = 5 # num_classes
  144. >>> criterion = DiceLoss()
  145. >>> pred = torch.randn(1, N, 3, 5, requires_grad=True)
  146. >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
  147. >>> output = criterion(pred, target)
  148. >>> output.backward()
  149. """
  150. def __init__(
  151. self,
  152. average: str = "micro",
  153. eps: float = 1e-8,
  154. weight: Optional[Tensor] = None,
  155. ignore_index: Optional[int] = -100,
  156. ) -> None:
  157. super().__init__()
  158. self.average = average
  159. self.eps = eps
  160. self.weight = weight
  161. self.ignore_index = ignore_index
  162. def forward(self, pred: Tensor, target: Tensor) -> Tensor:
  163. return dice_loss(pred, target, self.average, self.eps, self.weight, self.ignore_index)