mean_average_precision.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  17. from typing import Dict, List, Tuple
  18. import torch
  19. from kornia.core import Tensor, concatenate, tensor, zeros
  20. from .mean_iou import mean_iou_bbox
  21. def mean_average_precision(
  22. pred_boxes: List[Tensor],
  23. pred_labels: List[Tensor],
  24. pred_scores: List[Tensor],
  25. gt_boxes: List[Tensor],
  26. gt_labels: List[Tensor],
  27. n_classes: int,
  28. threshold: float = 0.5,
  29. ) -> Tuple[Tensor, Dict[int, float]]:
  30. """Calculate the Mean Average Precision (mAP) of detected objects.
  31. Code altered from https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection/blob/master/utils.py#L271.
  32. Background class (0 index) is excluded.
  33. Args:
  34. pred_boxes: a tensor list of predicted bounding boxes.
  35. pred_labels: a tensor list of predicted labels.
  36. pred_scores: a tensor list of predicted labels' scores.
  37. gt_boxes: a tensor list of ground truth bounding boxes.
  38. gt_labels: a tensor list of ground truth labels.
  39. n_classes: the number of classes.
  40. threshold: count as a positive if the overlap is greater than the threshold.
  41. Returns:
  42. mean average precision (mAP), list of average precisions for each class.
  43. Examples:
  44. >>> boxes, labels, scores = torch.tensor([[100, 50, 150, 100.]]), torch.tensor([1]), torch.tensor([.7])
  45. >>> gt_boxes, gt_labels = torch.tensor([[100, 50, 150, 100.]]), torch.tensor([1])
  46. >>> mean_average_precision([boxes], [labels], [scores], [gt_boxes], [gt_labels], 2)
  47. (tensor(1.), {1: 1.0})
  48. """
  49. # these are all lists of tensors of the same length, i.e. number of images
  50. if not len(pred_boxes) == len(pred_labels) == len(pred_scores) == len(gt_boxes) == len(gt_labels):
  51. raise AssertionError
  52. # Store all (true) objects in a single continuous tensor while keeping track of the image it is from
  53. gt_images = []
  54. for i, labels in enumerate(gt_labels):
  55. gt_images.extend([i] * labels.size(0))
  56. # (n_objects), n_objects is the total no. of objects across all images
  57. _gt_boxes = concatenate(gt_boxes, 0) # (n_objects, 4)
  58. _gt_labels = concatenate(gt_labels, 0) # (n_objects)
  59. _gt_images = tensor(gt_images, device=_gt_boxes.device, dtype=torch.long)
  60. if not _gt_images.size(0) == _gt_boxes.size(0) == _gt_labels.size(0):
  61. raise AssertionError
  62. # Store all detections in a single continuous tensor while keeping track of the image it is from
  63. pred_images = []
  64. for i, labels in enumerate(pred_labels):
  65. pred_images.extend([i] * labels.size(0))
  66. _pred_boxes = concatenate(pred_boxes, 0) # (n_detections, 4)
  67. _pred_labels = concatenate(pred_labels, 0) # (n_detections)
  68. _pred_scores = concatenate(pred_scores, 0) # (n_detections)
  69. _pred_images = tensor(pred_images, device=_pred_boxes.device, dtype=torch.long) # (n_detections)
  70. if not _pred_images.size(0) == _pred_boxes.size(0) == _pred_labels.size(0) == _pred_scores.size(0):
  71. raise AssertionError
  72. # Calculate APs for each class (except background)
  73. average_precisions = zeros((n_classes - 1), device=_pred_boxes.device, dtype=_pred_boxes.dtype) # (n_classes - 1)
  74. for c in range(1, n_classes):
  75. # Extract only objects with this class
  76. gt_class_images = _gt_images[_gt_labels == c] # (n_class_objects)
  77. gt_class_boxes = _gt_boxes[_gt_labels == c] # (n_class_objects, 4)
  78. # Keep track of which true objects with this class have already been 'detected'
  79. # (n_class_objects)
  80. gt_class_boxes_detected = zeros((gt_class_images.size(0)), dtype=torch.uint8, device=gt_class_images.device)
  81. # Extract only detections with this class
  82. pred_class_images = _pred_images[_pred_labels == c] # (n_class_detections)
  83. pred_class_boxes = _pred_boxes[_pred_labels == c] # (n_class_detections, 4)
  84. pred_class_scores = _pred_scores[_pred_labels == c] # (n_class_detections)
  85. n_class_detections = pred_class_boxes.size(0)
  86. if n_class_detections == 0:
  87. continue
  88. # Sort detections in decreasing order of confidence/scores
  89. pred_class_scores, sort_ind = torch.sort(pred_class_scores, dim=0, descending=True) # (n_class_detections)
  90. pred_class_images = pred_class_images[sort_ind] # (n_class_detections)
  91. pred_class_boxes = pred_class_boxes[sort_ind] # (n_class_detections, 4)
  92. # In the order of decreasing scores, check if true or false positive
  93. gt_positives = zeros(
  94. (n_class_detections,), dtype=pred_class_boxes.dtype, device=pred_class_boxes.device
  95. ) # (n_class_detections)
  96. false_positives = zeros(
  97. (n_class_detections,), dtype=pred_class_boxes.dtype, device=pred_class_boxes.device
  98. ) # (n_class_detections)
  99. for d in range(n_class_detections):
  100. this_detection_box = pred_class_boxes[d].unsqueeze(0) # (1, 4)
  101. this_image = pred_class_images[d] # (), scalar
  102. # Find objects in the image with this class, their difficulties, and whether they have been detected before
  103. object_boxes = gt_class_boxes[gt_class_images == this_image] # (n_class_objects_in_img)
  104. # If no such object in this image, then the detection is a false positive
  105. if object_boxes.size(0) == 0:
  106. false_positives[d] = 1
  107. continue
  108. # Find maximum overlap of this detection with objects in this image of this class
  109. overlaps = mean_iou_bbox(this_detection_box, object_boxes) # (1, n_class_objects_in_img)
  110. max_overlap, ind = torch.max(overlaps.squeeze(0), dim=0) # (), () - scalars
  111. # 'ind' is the index of the object in these image-level tensors 'object_boxes', 'object_difficulties'
  112. # In the original class-level tensors 'gt_class_boxes', etc., 'ind' corresponds to object with index...
  113. original_ind = tensor(
  114. range(gt_class_boxes.size(0)), device=gt_class_boxes_detected.device, dtype=torch.long
  115. )[gt_class_images == this_image][ind]
  116. # We need 'original_ind' to update 'gt_class_boxes_detected'
  117. # If the maximum overlap is greater than the threshold of 0.5, it's a match
  118. if max_overlap.item() > threshold:
  119. # If this object has already not been detected, it's a true positive
  120. if gt_class_boxes_detected[original_ind] == 0:
  121. gt_positives[d] = 1
  122. gt_class_boxes_detected[original_ind] = 1 # this object has now been detected/accounted for
  123. # Otherwise, it's a false positive (since this object is already accounted for)
  124. else:
  125. false_positives[d] = 1
  126. # Otherwise, the detection occurs in a different location than the actual object, and is a false positive
  127. else:
  128. false_positives[d] = 1
  129. # Compute cumulative precision and recall at each detection in the order of decreasing scores
  130. cumul_gt_positives = torch.cumsum(gt_positives, dim=0) # (n_class_detections)
  131. cumul_false_positives = torch.cumsum(false_positives, dim=0) # (n_class_detections)
  132. cumul_precision = cumul_gt_positives / (
  133. cumul_gt_positives + cumul_false_positives + 1e-10
  134. ) # (n_class_detections)
  135. cumul_recall = cumul_gt_positives / _gt_boxes.size(0) # (n_class_detections)
  136. # Find the mean of the maximum of the precisions corresponding to recalls above the threshold 't'
  137. recall_thresholds = torch.arange(start=0, end=1.1, step=0.1).tolist() # (11)
  138. precisions = zeros((len(recall_thresholds)), device=_gt_boxes.device, dtype=_gt_boxes.dtype) # (11)
  139. for i, t in enumerate(recall_thresholds):
  140. recalls_above_t = cumul_recall >= t
  141. if recalls_above_t.any():
  142. precisions[i] = cumul_precision[recalls_above_t].max()
  143. else:
  144. precisions[i] = 0.0
  145. average_precisions[c - 1] = precisions.mean() # c is in [1, n_classes - 1]
  146. # Calculate Mean Average Precision (mAP)
  147. mean_ap = average_precisions.mean()
  148. # Keep class-wise average precisions in a dictionary
  149. ap_dict = {c + 1: float(v) for c, v in enumerate(average_precisions.tolist())}
  150. return mean_ap, ap_dict