mean_iou.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import torch
  18. from .confusion_matrix import confusion_matrix
  19. def mean_iou(pred: torch.Tensor, target: torch.Tensor, num_classes: int, eps: float = 1e-6) -> torch.Tensor:
  20. r"""Calculate mean Intersection-Over-Union (mIOU).
  21. The function internally computes the confusion matrix.
  22. Args:
  23. pred : tensor with estimated targets returned by a
  24. classifier. The shape can be :math:`(B, *)` and must contain integer
  25. values between 0 and K-1.
  26. target: tensor with ground truth (correct) target
  27. values. The shape can be :math:`(B, *)` and must contain integer
  28. values between 0 and K-1, where targets are assumed to be provided as
  29. one-hot vectors.
  30. num_classes: total possible number of classes in target.
  31. eps: epsilon for numerical stability.
  32. Returns:
  33. a tensor representing the mean intersection-over union
  34. with shape :math:`(B, K)` where K is the number of classes.
  35. Example:
  36. >>> logits = torch.tensor([[0, 1, 0]])
  37. >>> target = torch.tensor([[0, 1, 0]])
  38. >>> mean_iou(logits, target, num_classes=3)
  39. tensor([[1., 1., 1.]])
  40. """
  41. if not torch.is_tensor(pred) and pred.dtype is not torch.int64:
  42. raise TypeError(f"Input pred type is not a torch.Tensor with torch.int64 dtype. Got {type(pred)}")
  43. if not torch.is_tensor(target) and target.dtype is not torch.int64:
  44. raise TypeError(f"Input target type is not a torch.Tensor with torch.int64 dtype. Got {type(target)}")
  45. if not pred.shape == target.shape:
  46. raise ValueError(f"Inputs pred and target must have the same shape. Got: {pred.shape} and {target.shape}")
  47. if not pred.device == target.device:
  48. raise ValueError(f"Inputs must be in the same device. Got: {pred.device} - {target.device}")
  49. if not isinstance(num_classes, int) or num_classes < 2:
  50. raise ValueError(f"The number of classes must be an integer bigger than two. Got: {num_classes}")
  51. # we first compute the confusion matrix
  52. conf_mat: torch.Tensor = confusion_matrix(pred, target, num_classes)
  53. # compute the actual intersection over union
  54. sum_over_row = torch.sum(conf_mat, dim=1)
  55. sum_over_col = torch.sum(conf_mat, dim=2)
  56. conf_mat_diag = torch.diagonal(conf_mat, dim1=-2, dim2=-1)
  57. denominator = sum_over_row + sum_over_col - conf_mat_diag
  58. # NOTE: we add epsilon so that samples that are neither in the
  59. # prediction or ground truth are taken into account.
  60. ious = (conf_mat_diag + eps) / (denominator + eps)
  61. return ious
  62. def mean_iou_bbox(boxes_1: torch.Tensor, boxes_2: torch.Tensor) -> torch.Tensor:
  63. """Compute the IoU of the cartesian product of two sets of boxes.
  64. Each box in each set shall be (x1, y1, x2, y2).
  65. Args:
  66. boxes_1: a tensor of bounding boxes in :math:`(B1, 4)`.
  67. boxes_2: a tensor of bounding boxes in :math:`(B2, 4)`.
  68. Returns:
  69. a tensor in dimensions :math:`(B1, B2)`, representing the
  70. intersection of each of the boxes in set 1 with respect to each of the boxes in set 2.
  71. Example:
  72. >>> boxes_1 = torch.tensor([[40, 40, 60, 60], [30, 40, 50, 60]])
  73. >>> boxes_2 = torch.tensor([[40, 50, 60, 70], [30, 40, 40, 50]])
  74. >>> mean_iou_bbox(boxes_1, boxes_2)
  75. tensor([[0.3333, 0.0000],
  76. [0.1429, 0.2500]])
  77. """
  78. # TODO: support more box types. e.g. xywh,
  79. if not (((boxes_1[:, 2] - boxes_1[:, 0]) > 0).all() or ((boxes_1[:, 3] - boxes_1[:, 1]) > 0).all()):
  80. raise AssertionError("Boxes_1 does not follow (x1, y1, x2, y2) format.")
  81. if not (((boxes_2[:, 2] - boxes_2[:, 0]) > 0).all() or ((boxes_2[:, 3] - boxes_2[:, 1]) > 0).all()):
  82. raise AssertionError("Boxes_2 does not follow (x1, y1, x2, y2) format.")
  83. # find intersection
  84. lower_bounds = torch.max(boxes_1[:, :2].unsqueeze(1), boxes_2[:, :2].unsqueeze(0)) # (n1, n2, 2)
  85. upper_bounds = torch.min(boxes_1[:, 2:].unsqueeze(1), boxes_2[:, 2:].unsqueeze(0)) # (n1, n2, 2)
  86. intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0) # (n1, n2, 2)
  87. intersection = intersection_dims[:, :, 0] * intersection_dims[:, :, 1] # (n1, n2)
  88. # Find areas of each box in both sets
  89. areas_set_1 = (boxes_1[:, 2] - boxes_1[:, 0]) * (boxes_1[:, 3] - boxes_1[:, 1]) # (n1)
  90. areas_set_2 = (boxes_2[:, 2] - boxes_2[:, 0]) * (boxes_2[:, 3] - boxes_2[:, 1]) # (n2)
  91. # Find the union
  92. # PyTorch auto-broadcasts singleton dimensions
  93. union = areas_set_1.unsqueeze(1) + areas_set_2.unsqueeze(0) - intersection # (n1, n2)
  94. return intersection / union # (n1, n2)