| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from typing import Dict, List, Tuple
- import torch
- from kornia.core import Tensor, concatenate, tensor, zeros
- from .mean_iou import mean_iou_bbox
def mean_average_precision(
    pred_boxes: List[Tensor],
    pred_labels: List[Tensor],
    pred_scores: List[Tensor],
    gt_boxes: List[Tensor],
    gt_labels: List[Tensor],
    n_classes: int,
    threshold: float = 0.5,
) -> Tuple[Tensor, Dict[int, float]]:
    """Calculate the Mean Average Precision (mAP) of detected objects.

    Code altered from https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection/blob/master/utils.py#L271.
    Background class (0 index) is excluded.

    The average precision of each class is the 11-point interpolated AP: the mean, over the recall levels
    ``0.0, 0.1, ..., 1.0``, of the best precision reached at a recall greater than or equal to that level.
    Recall of a class is measured against the number of ground truth objects *of that class*.

    Args:
        pred_boxes: a tensor list of predicted bounding boxes, one :math:`(N_i, 4)` tensor per image.
        pred_labels: a tensor list of predicted labels, one :math:`(N_i)` tensor per image.
        pred_scores: a tensor list of predicted labels' scores, one :math:`(N_i)` tensor per image.
        gt_boxes: a tensor list of ground truth bounding boxes, one :math:`(M_i, 4)` tensor per image.
        gt_labels: a tensor list of ground truth labels, one :math:`(M_i)` tensor per image.
        n_classes: the number of classes, background (label 0) included.
        threshold: count as a positive if the overlap is greater than the threshold.

    Returns:
        mean average precision (mAP), and a dictionary mapping each class label in ``[1, n_classes - 1]``
        to its average precision. Classes without detections or without ground truth objects get ``0.0``.

    Raises:
        AssertionError: if the input lists do not all have the same length (number of images), or if the
            boxes / labels / scores of an image do not describe the same number of objects.

    Examples:
        >>> boxes, labels, scores = torch.tensor([[100, 50, 150, 100.]]), torch.tensor([1]), torch.tensor([.7])
        >>> gt_boxes, gt_labels = torch.tensor([[100, 50, 150, 100.]]), torch.tensor([1])
        >>> mean_average_precision([boxes], [labels], [scores], [gt_boxes], [gt_labels], 2)
        (tensor(1.), {1: 1.0})
        >>> boxes = torch.tensor([[100, 50, 150, 100.], [10, 10, 20, 20.]])
        >>> labels, scores = torch.tensor([1, 2]), torch.tensor([.7, .9])
        >>> mean_average_precision([boxes], [labels], [scores], [boxes], [labels], 3)
        (tensor(1.), {1: 1.0, 2: 1.0})

    """
    # these are all lists of tensors of the same length, i.e. number of images
    if not len(pred_boxes) == len(pred_labels) == len(pred_scores) == len(gt_boxes) == len(gt_labels):
        raise AssertionError(
            "pred_boxes, pred_labels, pred_scores, gt_boxes and gt_labels must all have one entry per image. "
            f"Got lengths {len(pred_boxes)}, {len(pred_labels)}, {len(pred_scores)}, "
            f"{len(gt_boxes)}, {len(gt_labels)}."
        )

    # Store all (true) objects in a single continuous tensor while keeping track of the image it is from
    gt_images: List[int] = []
    for i, labels in enumerate(gt_labels):
        gt_images.extend([i] * labels.size(0))
    # n_objects is the total no. of objects across all images
    _gt_boxes = concatenate(gt_boxes, 0)  # (n_objects, 4)
    _gt_labels = concatenate(gt_labels, 0)  # (n_objects)
    _gt_images = tensor(gt_images, device=_gt_boxes.device, dtype=torch.long)  # (n_objects)

    if not _gt_images.size(0) == _gt_boxes.size(0) == _gt_labels.size(0):
        raise AssertionError(
            "Ground truth boxes and labels must describe the same number of objects. "
            f"Got {_gt_boxes.size(0)} boxes and {_gt_labels.size(0)} labels."
        )

    # Store all detections in a single continuous tensor while keeping track of the image it is from
    pred_images: List[int] = []
    for i, labels in enumerate(pred_labels):
        pred_images.extend([i] * labels.size(0))
    _pred_boxes = concatenate(pred_boxes, 0)  # (n_detections, 4)
    _pred_labels = concatenate(pred_labels, 0)  # (n_detections)
    _pred_scores = concatenate(pred_scores, 0)  # (n_detections)
    _pred_images = tensor(pred_images, device=_pred_boxes.device, dtype=torch.long)  # (n_detections)

    if not _pred_images.size(0) == _pred_boxes.size(0) == _pred_labels.size(0) == _pred_scores.size(0):
        raise AssertionError(
            "Predicted boxes, labels and scores must describe the same number of detections. "
            f"Got {_pred_boxes.size(0)} boxes, {_pred_labels.size(0)} labels and {_pred_scores.size(0)} scores."
        )

    # Recall levels of the 11-point interpolated AP; they do not depend on the class.
    recall_thresholds = torch.arange(start=0, end=1.1, step=0.1).tolist()  # (11)

    # Calculate APs for each class (except background)
    average_precisions = zeros((n_classes - 1), device=_pred_boxes.device, dtype=_pred_boxes.dtype)  # (n_classes - 1)
    for c in range(1, n_classes):
        # Extract only objects with this class
        gt_is_class = _gt_labels == c
        gt_class_images = _gt_images[gt_is_class]  # (n_class_objects)
        gt_class_boxes = _gt_boxes[gt_is_class]  # (n_class_objects, 4)
        n_class_objects = gt_class_boxes.size(0)

        # Extract only detections with this class
        pred_is_class = _pred_labels == c
        pred_class_images = _pred_images[pred_is_class]  # (n_class_detections)
        pred_class_boxes = _pred_boxes[pred_is_class]  # (n_class_detections, 4)
        pred_class_scores = _pred_scores[pred_is_class]  # (n_class_detections)
        n_class_detections = pred_class_boxes.size(0)

        # Without detections nothing can be recalled, and without ground truth objects every detection is a
        # false positive: in both cases the AP of this class stays at its initial value of 0.
        if n_class_detections == 0 or n_class_objects == 0:
            continue

        # Keep track of which true objects with this class have already been 'detected'
        gt_class_boxes_detected = zeros(
            (n_class_objects), dtype=torch.uint8, device=gt_class_images.device
        )  # (n_class_objects)

        # Sort detections in decreasing order of confidence/scores
        pred_class_scores, sort_ind = torch.sort(pred_class_scores, dim=0, descending=True)  # (n_class_detections)
        pred_class_images = pred_class_images[sort_ind]  # (n_class_detections)
        pred_class_boxes = pred_class_boxes[sort_ind]  # (n_class_detections, 4)

        # In the order of decreasing scores, check if true or false positive
        gt_positives = zeros(
            (n_class_detections,), dtype=pred_class_boxes.dtype, device=pred_class_boxes.device
        )  # (n_class_detections)
        false_positives = zeros(
            (n_class_detections,), dtype=pred_class_boxes.dtype, device=pred_class_boxes.device
        )  # (n_class_detections)
        for d in range(n_class_detections):
            this_detection_box = pred_class_boxes[d].unsqueeze(0)  # (1, 4)
            this_image = pred_class_images[d]  # (), scalar

            # Find objects in the image with this class
            in_this_image = gt_class_images == this_image  # (n_class_objects)
            object_boxes = gt_class_boxes[in_this_image]  # (n_class_objects_in_img, 4)
            # If no such object in this image, then the detection is a false positive
            if object_boxes.size(0) == 0:
                false_positives[d] = 1
                continue

            # Find maximum overlap of this detection with objects in this image of this class
            overlaps = mean_iou_bbox(this_detection_box, object_boxes)  # (1, n_class_objects_in_img)
            max_overlap, ind = torch.max(overlaps.squeeze(0), dim=0)  # (), () - scalars

            # 'ind' indexes the image-level tensor 'object_boxes'; map it back to the class-level index
            # needed to update 'gt_class_boxes_detected'.
            original_ind = torch.nonzero(in_this_image).squeeze(1)[ind]

            # A detection only matches when its best overlap exceeds the threshold AND the matched object
            # has not been claimed by a higher-scoring detection; everything else is a false positive.
            if max_overlap.item() > threshold and gt_class_boxes_detected[original_ind] == 0:
                gt_positives[d] = 1
                gt_class_boxes_detected[original_ind] = 1  # this object has now been detected/accounted for
            else:
                false_positives[d] = 1

        # Compute cumulative precision and recall at each detection in the order of decreasing scores
        cumul_gt_positives = torch.cumsum(gt_positives, dim=0)  # (n_class_detections)
        cumul_false_positives = torch.cumsum(false_positives, dim=0)  # (n_class_detections)
        cumul_precision = cumul_gt_positives / (
            cumul_gt_positives + cumul_false_positives + 1e-10
        )  # (n_class_detections)
        # Recall is relative to the objects of THIS class only, not to all ground truth objects.
        cumul_recall = cumul_gt_positives / n_class_objects  # (n_class_detections)

        # Find the mean of the maximum of the precisions corresponding to recalls above the threshold 't'
        precisions = zeros((len(recall_thresholds)), device=_gt_boxes.device, dtype=_gt_boxes.dtype)  # (11)
        for i, t in enumerate(recall_thresholds):
            recalls_above_t = cumul_recall >= t
            if recalls_above_t.any():
                precisions[i] = cumul_precision[recalls_above_t].max()
            else:
                precisions[i] = 0.0
        average_precisions[c - 1] = precisions.mean()  # c is in [1, n_classes - 1]

    # Calculate Mean Average Precision (mAP)
    mean_ap = average_precisions.mean()

    # Keep class-wise average precisions in a dictionary
    ap_dict = {c + 1: float(v) for c, v in enumerate(average_precisions.tolist())}

    return mean_ap, ap_dict
|