tversky.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from typing import Optional
  19. import torch
  20. import torch.nn.functional as F
  21. from torch import nn
  22. from kornia.losses._utils import mask_ignore_pixels
  23. # based on:
  24. # https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py
  25. def tversky_loss(
  26. pred: torch.Tensor,
  27. target: torch.Tensor,
  28. alpha: float,
  29. beta: float,
  30. eps: float = 1e-8,
  31. ignore_index: Optional[int] = -100,
  32. ) -> torch.Tensor:
  33. r"""Criterion that computes Tversky Coefficient loss.
  34. According to :cite:`salehi2017tversky`, we compute the Tversky Coefficient as follows:
  35. .. math::
  36. \text{S}(P, G, \alpha; \beta) =
  37. \frac{|PG|}{|PG| + \alpha |P \setminus G| + \beta |G \setminus P|}
  38. Where:
  39. - :math:`P` and :math:`G` are the predicted and ground truth binary
  40. labels.
  41. - :math:`\alpha` and :math:`\beta` control the magnitude of the
  42. penalties for FPs and FNs, respectively.
  43. Note:
  44. - :math:`\alpha = \beta = 0.5` => dice coeff
  45. - :math:`\alpha = \beta = 1` => tanimoto coeff
  46. - :math:`\alpha + \beta = 1` => F beta coeff
  47. Args:
  48. pred: logits tensor with shape :math:`(N, C, H, W)` where C = number of classes.
  49. target: labels tensor with shape :math:`(N, H, W)` where each value
  50. is :math:`0 ≤ targets[i] ≤ C-1`.
  51. alpha: the first coefficient in the denominator.
  52. beta: the second coefficient in the denominator.
  53. eps: scalar for numerical stability.
  54. ignore_index: labels with this value are ignored in the loss computation.
  55. Return:
  56. the computed loss.
  57. Example:
  58. >>> N = 5 # num_classes
  59. >>> pred = torch.randn(1, N, 3, 5, requires_grad=True)
  60. >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
  61. >>> output = tversky_loss(pred, target, alpha=0.5, beta=0.5)
  62. >>> output.backward()
  63. """
  64. if not isinstance(pred, torch.Tensor):
  65. raise TypeError(f"pred type is not a torch.Tensor. Got {type(pred)}")
  66. if not len(pred.shape) == 4:
  67. raise ValueError(f"Invalid pred shape, we expect BxNxHxW. Got: {pred.shape}")
  68. if not pred.shape[-2:] == target.shape[-2:]:
  69. raise ValueError(f"pred and target shapes must be the same. Got: {pred.shape} and {target.shape}")
  70. if not pred.device == target.device:
  71. raise ValueError(f"pred and target must be in the same device. Got: {pred.device} and {target.device}")
  72. # compute softmax over the classes axis
  73. pred_soft = F.softmax(pred, dim=1)
  74. target, target_mask = mask_ignore_pixels(target, ignore_index)
  75. p_true = pred_soft.gather(1, target.unsqueeze(1)) # (B,1,H,W)
  76. if target_mask is not None:
  77. m = target_mask.unsqueeze(1).to(dtype=pred.dtype)
  78. p_true = p_true * m
  79. total = m.sum((1, 2, 3))
  80. else:
  81. B, _, H, W = pred.shape
  82. total = torch.full((B,), H * W, dtype=pred.dtype, device=pred.device)
  83. intersection = p_true.sum((1, 2, 3))
  84. # denominator = intersection + (alpha + beta) * (total - intersection) + eps
  85. # instead of multiple ops, do it in one fused step:
  86. denominator = torch.addcmul(
  87. intersection, # base
  88. total - intersection, # tensor1
  89. torch.full_like(total, alpha + beta), # tensor2 (scalar as tensor)
  90. value=1.0, # (intersection) + 1 * (tensor1*tensor2)
  91. ).add_(eps) # in-place add eps
  92. score = intersection.div(denominator)
  93. return 1.0 - score.mean()
  94. class TverskyLoss(nn.Module):
  95. r"""Criterion that computes Tversky Coefficient loss.
  96. According to :cite:`salehi2017tversky`, we compute the Tversky Coefficient as follows:
  97. .. math::
  98. \text{S}(P, G, \alpha; \beta) =
  99. \frac{|PG|}{|PG| + \alpha |P \setminus G| + \beta |G \setminus P|}
  100. Where:
  101. - :math:`P` and :math:`G` are the predicted and ground truth binary
  102. labels.
  103. - :math:`\alpha` and :math:`\beta` control the magnitude of the
  104. penalties for FPs and FNs, respectively.
  105. Note:
  106. - :math:`\alpha = \beta = 0.5` => dice coeff
  107. - :math:`\alpha = \beta = 1` => tanimoto coeff
  108. - :math:`\alpha + \beta = 1` => F beta coeff
  109. Args:
  110. alpha: the first coefficient in the denominator.
  111. beta: the second coefficient in the denominator.
  112. eps: scalar for numerical stability.
  113. ignore_index: labels with this value are ignored in the loss computation.
  114. Shape:
  115. - Pred: :math:`(N, C, H, W)` where C = number of classes.
  116. - Target: :math:`(N, H, W)` where each value is
  117. :math:`0 ≤ targets[i] ≤ C-1`.
  118. Examples:
  119. >>> N = 5 # num_classes
  120. >>> criterion = TverskyLoss(alpha=0.5, beta=0.5)
  121. >>> pred = torch.randn(1, N, 3, 5, requires_grad=True)
  122. >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
  123. >>> output = criterion(pred, target)
  124. >>> output.backward()
  125. """
  126. def __init__(self, alpha: float, beta: float, eps: float = 1e-8, ignore_index: Optional[int] = -100) -> None:
  127. super().__init__()
  128. self.alpha: float = alpha
  129. self.beta: float = beta
  130. self.eps: float = eps
  131. self.ignore_index: Optional[int] = ignore_index
  132. def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
  133. return tversky_loss(pred, target, self.alpha, self.beta, self.eps, self.ignore_index)