dice.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # Copyright The PyTorch Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. import torch
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.functional.segmentation.utils import _segmentation_inputs_format
  19. from torchmetrics.utilities import rank_zero_warn
  20. from torchmetrics.utilities.compute import _safe_divide
  21. def _dice_score_validate_args(
  22. num_classes: int,
  23. include_background: bool,
  24. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
  25. input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
  26. aggregation_level: Optional[Literal["samplewise", "global"]] = "samplewise",
  27. ) -> None:
  28. """Validate the arguments of the metric."""
  29. if not isinstance(num_classes, int) or num_classes <= 0:
  30. raise ValueError(f"Expected argument `num_classes` must be a positive integer, but got {num_classes}.")
  31. if not isinstance(include_background, bool):
  32. raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.")
  33. allowed_average = ["micro", "macro", "weighted", "none"]
  34. if average is not None and average not in allowed_average:
  35. raise ValueError(f"Expected argument `average` to be one of {allowed_average} or None, but got {average}.")
  36. if input_format not in ["one-hot", "index", "mixed"]:
  37. raise ValueError(
  38. f"Expected argument `input_format` to be one of 'one-hot', 'index', 'mixed', but got {input_format}."
  39. )
  40. if aggregation_level not in ("samplewise", "global"):
  41. raise ValueError(
  42. f"Expected argument `aggregation_level` to be one of `samplewise`, `global`, but got {aggregation_level}"
  43. )
  44. def _dice_score_update(
  45. preds: Tensor,
  46. target: Tensor,
  47. num_classes: int,
  48. include_background: bool,
  49. input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
  50. ) -> tuple[Tensor, Tensor, Tensor]:
  51. """Update the state with the current prediction and target."""
  52. preds, target = _segmentation_inputs_format(preds, target, include_background, num_classes, input_format)
  53. reduce_axis = list(range(2, target.ndim))
  54. intersection = torch.sum(preds * target, dim=reduce_axis)
  55. target_sum = torch.sum(target, dim=reduce_axis)
  56. pred_sum = torch.sum(preds, dim=reduce_axis)
  57. numerator = 2 * intersection
  58. denominator = pred_sum + target_sum
  59. support = target_sum
  60. return numerator, denominator, support
  61. def _dice_score_compute(
  62. numerator: Tensor,
  63. denominator: Tensor,
  64. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
  65. aggregation_level: Optional[Literal["samplewise", "global"]] = "samplewise",
  66. support: Optional[Tensor] = None,
  67. ) -> Tensor:
  68. """Compute the Dice score from the numerator and denominator."""
  69. if aggregation_level == "global":
  70. numerator = torch.sum(numerator, dim=0).unsqueeze(0)
  71. denominator = torch.sum(denominator, dim=0).unsqueeze(0)
  72. support = torch.sum(support, dim=0) if support is not None else None
  73. if average == "micro":
  74. numerator = torch.sum(numerator, dim=-1)
  75. denominator = torch.sum(denominator, dim=-1)
  76. return _safe_divide(numerator, denominator, zero_division="nan")
  77. dice = _safe_divide(numerator, denominator, zero_division="nan")
  78. if average == "macro":
  79. return torch.nanmean(dice, dim=-1)
  80. if average == "weighted":
  81. if not isinstance(support, torch.Tensor):
  82. raise ValueError(f"Expected argument `support` to be a tensor, got: {type(support)}.")
  83. weights = _safe_divide(support, torch.sum(support, dim=-1, keepdim=True), zero_division="nan")
  84. nan_mask = dice.isnan().all(dim=-1)
  85. dice = torch.nansum(dice * weights, dim=-1)
  86. dice[nan_mask] = torch.nan
  87. return dice
  88. if average in ("none", None):
  89. return dice
  90. raise ValueError(f"Invalid value for `average`: {average}.")
  91. def dice_score(
  92. preds: Tensor,
  93. target: Tensor,
  94. num_classes: int,
  95. include_background: bool = True,
  96. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  97. input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
  98. aggregation_level: Optional[Literal["samplewise", "global"]] = "samplewise",
  99. ) -> Tensor:
  100. """Compute the Dice score for semantic segmentation.
  101. Args:
  102. preds: Predictions from model
  103. target: Ground truth values
  104. num_classes: Number of classes
  105. include_background: Whether to include the background class in the computation
  106. average: The method to average the dice score. Options are ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"``
  107. or ``None``. This determines how to average the dice score across different classes.
  108. input_format: What kind of input the function receives.
  109. Choose between ``"one-hot"`` for one-hot encoded tensors, ``"index"`` for index tensors
  110. or ``"mixed"`` for one one-hot encoded and one index tensor
  111. aggregation_level: The level at which to aggregate the dice score. Options are ``"samplewise"`` or ``"global"``.
  112. For ``"samplewise"`` the dice score is computed for each sample and then averaged. For ``"global"`` the dice
  113. score is computed globally over all samples.
  114. Returns:
  115. The Dice score.
  116. Example (with one-hot encoded tensors):
  117. >>> from torch import randint
  118. >>> from torchmetrics.functional.segmentation import dice_score
  119. >>> _ = torch.manual_seed(42)
  120. >>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction
  121. >>> target = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target
  122. >>> # dice score micro averaged over all classes
  123. >>> dice_score(preds, target, num_classes=5, average="micro")
  124. tensor([0.4842, 0.4968, 0.5053, 0.4902])
  125. >>> # dice score per sample and class
  126. >>> dice_score(preds, target, num_classes=5, average="none")
  127. tensor([[0.4724, 0.5185, 0.4710, 0.5062, 0.4500],
  128. [0.4571, 0.4980, 0.5191, 0.4380, 0.5649],
  129. [0.5428, 0.4904, 0.5358, 0.4830, 0.4724],
  130. [0.4715, 0.4925, 0.4797, 0.5267, 0.4788]])
  131. >>> # global dice score over all samples with macro averaging
  132. >>> dice_score(preds, target, num_classes=5, average="macro", aggregation_level="global")
  133. tensor([0.4942])
  134. Example (with index tensors):
  135. >>> from torch import randint
  136. >>> from torchmetrics.functional.segmentation import dice_score
  137. >>> _ = torch.manual_seed(42)
  138. >>> preds = randint(0, 5, (4, 16, 16)) # 4 samples, 5 classes, 16x16 prediction
  139. >>> target = randint(0, 5, (4, 16, 16)) # 4 samples, 5 classes, 16x16 target
  140. >>> # dice score micro averaged over all classes
  141. >>> dice_score(preds, target, num_classes=5, average="micro", input_format="index")
  142. tensor([0.2031, 0.1914, 0.2266, 0.1641])
  143. >>> # dice score per sample and class
  144. >>> dice_score(preds, target, num_classes=5, average="none", input_format="index")
  145. tensor([[0.1731, 0.1667, 0.2400, 0.2424, 0.1947],
  146. [0.2245, 0.2247, 0.2321, 0.1132, 0.1682],
  147. [0.2500, 0.2476, 0.1887, 0.1818, 0.2718],
  148. [0.1308, 0.1800, 0.1980, 0.1607, 0.1522]])
  149. >>> # global dice score over all samples with macro averaging
  150. >>> dice_score(preds, target, num_classes=5, average="macro", aggregation_level="global", input_format="index")
  151. tensor([0.1965])
  152. """
  153. if average == "micro":
  154. rank_zero_warn(
  155. "dice_score metric currently defaults to `average=micro`, but will change to"
  156. "`average=macro` in the v1.9 release."
  157. " If you've explicitly set this parameter, you can ignore this warning.",
  158. UserWarning,
  159. )
  160. _dice_score_validate_args(num_classes, include_background, average, input_format, aggregation_level)
  161. numerator, denominator, support = _dice_score_update(preds, target, num_classes, include_background, input_format)
  162. return _dice_score_compute(numerator, denominator, average, aggregation_level=aggregation_level, support=support)