mean_iou.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. # Copyright The PyTorch Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional, Tuple, Union
  15. import torch
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.functional.segmentation.utils import _segmentation_inputs_format
  19. from torchmetrics.utilities.compute import _safe_divide
  20. def _mean_iou_reshape_args(
  21. preds: Tensor,
  22. targets: Tensor,
  23. input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
  24. ) -> Tuple[Tensor, Tensor]:
  25. """Reshape tensors to 3D if needed."""
  26. if input_format == "one-hot":
  27. return preds, targets
  28. if preds.dim() == 1:
  29. preds = preds.unsqueeze(0).unsqueeze(0)
  30. elif preds.dim() == 2:
  31. preds = preds.unsqueeze(0)
  32. if targets.dim() == 1:
  33. targets = targets.unsqueeze(0).unsqueeze(0)
  34. elif targets.dim() == 2:
  35. targets = targets.unsqueeze(0)
  36. return preds, targets
  37. def _mean_iou_validate_args(
  38. num_classes: Optional[int],
  39. include_background: bool,
  40. per_class: bool,
  41. input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
  42. ) -> None:
  43. """Validate the arguments of the metric."""
  44. if input_format in ["index"] and num_classes is None:
  45. raise ValueError("Argument `num_classes` must be provided when `input_format` is 'index'.")
  46. if num_classes is not None and num_classes <= 0:
  47. raise ValueError(
  48. f"Expected argument `num_classes` must be `None` or a positive integer, but got {num_classes}."
  49. )
  50. if not isinstance(include_background, bool):
  51. raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.")
  52. if not isinstance(per_class, bool):
  53. raise ValueError(f"Expected argument `per_class` must be a boolean, but got {per_class}.")
  54. if input_format not in ["one-hot", "index", "mixed"]:
  55. raise ValueError(
  56. f"Expected argument `input_format` to be one of 'one-hot', 'index', 'mixed', but got {input_format}."
  57. )
  58. def _mean_iou_update(
  59. preds: Tensor,
  60. target: Tensor,
  61. num_classes: Optional[int] = None,
  62. include_background: bool = False,
  63. input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
  64. ) -> tuple[Tensor, Tensor]:
  65. """Update the intersection and union counts for the mean IoU computation."""
  66. preds, target = _mean_iou_reshape_args(preds, target, input_format)
  67. preds, target = _segmentation_inputs_format(preds, target, include_background, num_classes, input_format)
  68. reduce_axis = list(range(2, preds.ndim))
  69. intersection = torch.sum(preds & target, dim=reduce_axis)
  70. target_sum = torch.sum(target, dim=reduce_axis)
  71. pred_sum = torch.sum(preds, dim=reduce_axis)
  72. union = target_sum + pred_sum - intersection
  73. return intersection, union
  74. def _mean_iou_compute(
  75. intersection: Tensor,
  76. union: Tensor,
  77. zero_division: Union[float, Literal["warn", "nan"]],
  78. ) -> Tensor:
  79. """Compute the mean IoU metric."""
  80. return _safe_divide(intersection, union, zero_division=zero_division)
  81. def mean_iou(
  82. preds: Tensor,
  83. target: Tensor,
  84. num_classes: Optional[int] = None,
  85. include_background: bool = True,
  86. per_class: bool = False,
  87. input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
  88. ) -> Tensor:
  89. """Calculates the mean Intersection over Union (mIoU) for semantic segmentation.
  90. Returns -1 if class is completely absent both from predictions and ground truth labels.
  91. Args:
  92. preds: Predictions from model
  93. target: Ground truth values
  94. num_classes: Number of classes
  95. (required when input_format="index", optional when input_format="one-hot" or "mixed")
  96. include_background: Whether to include the background class in the computation
  97. per_class: Whether to compute the IoU for each class separately, else average over all classes
  98. input_format: What kind of input the function receives.
  99. Choose between ``"one-hot"`` for one-hot encoded tensors, ``"index"`` for index tensors
  100. or ``"mixed"`` for one one-hot encoded and one index tensor
  101. Returns:
  102. The mean IoU score
  103. Example:
  104. >>> import torch
  105. >>> from torch import randint
  106. >>> from torchmetrics.functional.segmentation import mean_iou
  107. >>> # 4 samples, 5 classes, 16x16 prediction
  108. >>> preds = randint(0, 2, (4, 5, 16, 16), generator=torch.Generator().manual_seed(42))
  109. >>> # 4 samples, 5 classes, 16x16 target
  110. >>> target = randint(0, 2, (4, 5, 16, 16), generator=torch.Generator().manual_seed(43))
  111. >>> mean_iou(preds, target)
  112. tensor([0.3323, 0.3336, 0.3397, 0.3435])
  113. >>> mean_iou(preds, target, include_background=False, num_classes=5)
  114. tensor([0.3250, 0.3258, 0.3307, 0.3398])
  115. >>> mean_iou(preds, target, include_background=True, num_classes=5, per_class=True)
  116. tensor([[0.3617, 0.3128, 0.3366, 0.3242, 0.3263],
  117. [0.3646, 0.2893, 0.3297, 0.3073, 0.3770],
  118. [0.3756, 0.3168, 0.3505, 0.3400, 0.3155],
  119. [0.3579, 0.3317, 0.3797, 0.3523, 0.2957]])
  120. >>> # re-initialize tensors for ``input_format="index"``
  121. >>> preds = randint(0, 2, (4, 16, 16), generator=torch.Generator().manual_seed(42))
  122. >>> target = randint(0, 2, (4, 16, 16), generator=torch.Generator().manual_seed(43))
  123. >>> mean_iou(preds, target, num_classes=5, input_format = "index")
  124. tensor([0.3617, 0.3128, 0.3047, 0.3499])
  125. >>> mean_iou(preds, target, num_classes=5, per_class=True, input_format="index")
  126. tensor([[ 0.3617, 0.3617, -1.0000, -1.0000, -1.0000],
  127. [ 0.3128, 0.3128, -1.0000, -1.0000, -1.0000],
  128. [ 0.2727, 0.3366, -1.0000, -1.0000, -1.0000],
  129. [ 0.3756, 0.3242, -1.0000, -1.0000, -1.0000]])
  130. """
  131. _mean_iou_validate_args(num_classes, include_background, per_class, input_format)
  132. intersection, union = _mean_iou_update(preds, target, num_classes, include_background, input_format)
  133. scores = _mean_iou_compute(intersection, union, zero_division="nan")
  134. valid_classes = union > 0
  135. return scores.nan_to_num(-1.0) if per_class else scores.nansum(dim=-1) / valid_classes.sum(dim=-1)