_mean_ap.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. from collections.abc import Sequence
  16. from typing import Any, Callable, List, Literal, Optional, Union
  17. import numpy as np
  18. import torch
  19. import torch.distributed as dist
  20. from torch import IntTensor, Tensor
  21. from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator
  22. from torchmetrics.metric import Metric
  23. from torchmetrics.utilities.data import _cumsum
  24. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _PYCOCOTOOLS_AVAILABLE, _TORCHVISION_AVAILABLE
  25. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  26. if not _MATPLOTLIB_AVAILABLE:
  27. __doctest_skip__ = ["MeanAveragePrecision.plot"]
  28. if not _TORCHVISION_AVAILABLE or not _PYCOCOTOOLS_AVAILABLE:
  29. __doctest_skip__ = ["MeanAveragePrecision.plot", "MeanAveragePrecision"]
  30. log = logging.getLogger(__name__)
  31. def compute_area(inputs: list[Any], iou_type: Literal["bbox", "segm"] = "bbox") -> Tensor:
  32. """Compute area of input depending on the specified iou_type.
  33. Default output for empty input is :class:`~torch.Tensor`
  34. """
  35. import pycocotools.mask as mask_utils
  36. from torchvision.ops import box_area
  37. if len(inputs) == 0:
  38. return Tensor([])
  39. if iou_type == "bbox":
  40. return box_area(torch.stack(inputs))
  41. if iou_type == "segm":
  42. inputs = [{"size": i[0], "counts": i[1]} for i in inputs]
  43. return torch.tensor(mask_utils.area(inputs).astype("float"))
  44. raise Exception(f"IOU type {iou_type} is not supported")
  45. def compute_iou(
  46. det: list[Any],
  47. gt: list[Any],
  48. iou_type: Literal["bbox", "segm"] = "bbox",
  49. ) -> Tensor:
  50. """Compute IOU between detections and ground-truth using the specified iou_type."""
  51. from torchvision.ops import box_iou
  52. if iou_type == "bbox":
  53. return box_iou(torch.stack(det), torch.stack(gt))
  54. if iou_type == "segm":
  55. return _segm_iou(det, gt)
  56. raise Exception(f"IOU type {iou_type} is not supported")
  57. class BaseMetricResults(dict):
  58. """Base metric class, that allows fields for pre-defined metrics."""
  59. def __getattr__(self, key: str) -> Tensor:
  60. """Get a specific metric attribute."""
  61. # Using this you get the correct error message, an AttributeError instead of a KeyError
  62. if key in self:
  63. return self[key]
  64. raise AttributeError(f"No such attribute: {key}")
  65. def __setattr__(self, key: str, value: Tensor) -> None:
  66. """Set a specific metric attribute."""
  67. self[key] = value
  68. def __delattr__(self, key: str) -> None:
  69. """Delete a specific metric attribute."""
  70. if key in self:
  71. del self[key]
  72. raise AttributeError(f"No such attribute: {key}")
  73. class MAPMetricResults(BaseMetricResults):
  74. """Class to wrap the final mAP results."""
  75. __slots__ = ("classes", "map", "map_50", "map_75", "map_large", "map_medium", "map_small")
  76. class MARMetricResults(BaseMetricResults):
  77. """Class to wrap the final mAR results."""
  78. __slots__ = ("mar_1", "mar_10", "mar_100", "mar_large", "mar_medium", "mar_small")
  79. class COCOMetricResults(BaseMetricResults):
  80. """Class to wrap the final COCO metric results including various mAP/mAR values."""
  81. __slots__ = (
  82. "map",
  83. "map_50",
  84. "map_75",
  85. "map_large",
  86. "map_medium",
  87. "map_per_class",
  88. "map_small",
  89. "mar_1",
  90. "mar_10",
  91. "mar_100",
  92. "mar_100_per_class",
  93. "mar_large",
  94. "mar_medium",
  95. "mar_small",
  96. )
  97. def _segm_iou(det: list[tuple[np.ndarray, np.ndarray]], gt: list[tuple[np.ndarray, np.ndarray]]) -> Tensor:
  98. """Compute IOU between detections and ground-truths using mask-IOU.
  99. Implementation is based on pycocotools toolkit for mask_utils.
  100. Args:
  101. det: A list of detection masks as ``[(RLE_SIZE, RLE_COUNTS)]``, where ``RLE_SIZE`` is (width, height) dimension
  102. of the input and RLE_COUNTS is its RLE representation;
  103. gt: A list of ground-truth masks as ``[(RLE_SIZE, RLE_COUNTS)]``, where ``RLE_SIZE`` is (width, height) dimension
  104. of the input and RLE_COUNTS is its RLE representation;
  105. """
  106. import pycocotools.mask as mask_utils
  107. det_coco_format = [{"size": i[0], "counts": i[1]} for i in det]
  108. gt_coco_format = [{"size": i[0], "counts": i[1]} for i in gt]
  109. return torch.tensor(mask_utils.iou(det_coco_format, gt_coco_format, [False for _ in gt]))
  110. class MeanAveragePrecision(Metric):
  111. r"""Compute the `Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR)`_ for object detection predictions.
  112. .. math::
  113. \text{mAP} = \frac{1}{n} \sum_{i=1}^{n} AP_i
  114. where :math:`AP_i` is the average precision for class :math:`i` and :math:`n` is the number of classes. The average
  115. precision is defined as the area under the precision-recall curve. If argument `class_metrics` is set to ``True``,
  116. the metric will also return the mAP/mAR per class.
  117. As input to ``forward`` and ``update`` the metric accepts the following input:
  118. - ``preds`` (:class:`~List`): A list consisting of dictionaries each containing the key-values
  119. (each dictionary corresponds to a single image). Parameters that should be provided per dict:
  120. - boxes: (:class:`~torch.FloatTensor`) of shape ``(num_boxes, 4)`` containing ``num_boxes`` detection
  121. boxes of the format specified in the constructor.
  122. By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
  123. - scores: :class:`~torch.FloatTensor` of shape ``(num_boxes)`` containing detection scores for the boxes.
  124. - labels: :class:`~torch.IntTensor` of shape ``(num_boxes)`` containing 0-indexed detection classes for
  125. the boxes.
  126. - masks: :class:`~torch.bool` of shape ``(num_boxes, image_height, image_width)`` containing boolean masks.
  127. Only required when `iou_type="segm"`.
  128. - ``target`` (:class:`~List`) A list consisting of dictionaries each containing the key-values
  129. (each dictionary corresponds to a single image). Parameters that should be provided per dict:
  130. - boxes: :class:`~torch.FloatTensor` of shape ``(num_boxes, 4)`` containing ``num_boxes`` ground truth
  131. boxes of the format specified in the constructor.
  132. By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
  133. - labels: :class:`~torch.IntTensor` of shape ``(num_boxes)`` containing 0-indexed ground truth
  134. classes for the boxes.
  135. - masks: :class:`~torch.bool` of shape ``(num_boxes, image_height, image_width)`` containing boolean masks.
  136. Only required when `iou_type="segm"`.
  137. As output of ``forward`` and ``compute`` the metric returns the following output:
  138. - ``map_dict``: A dictionary containing the following key-values:
  139. - map: (:class:`~torch.Tensor`)
  140. - map_small: (:class:`~torch.Tensor`)
  141. - map_medium: (:class:`~torch.Tensor`)
  142. - map_large: (:class:`~torch.Tensor`)
  143. - mar_1: (:class:`~torch.Tensor`)
  144. - mar_10: (:class:`~torch.Tensor`)
  145. - mar_100: (:class:`~torch.Tensor`)
  146. - mar_small: (:class:`~torch.Tensor`)
  147. - mar_medium: (:class:`~torch.Tensor`)
  148. - mar_large: (:class:`~torch.Tensor`)
  149. - map_50: (:class:`~torch.Tensor`) (-1 if 0.5 not in the list of iou thresholds)
  150. - map_75: (:class:`~torch.Tensor`) (-1 if 0.75 not in the list of iou thresholds)
  151. - map_per_class: (:class:`~torch.Tensor`) (-1 if class metrics are disabled)
  152. - mar_100_per_class: (:class:`~torch.Tensor`) (-1 if class metrics are disabled)
  153. - classes (:class:`~torch.Tensor`)
  154. For an example on how to use this metric check the `torchmetrics mAP example`_.
  155. .. attention::
  156. The ``map`` score is calculated with @[ IoU=self.iou_thresholds | area=all | max_dets=max_detection_thresholds ]
  157. **Caution:** If the initialization parameters are changed, dictionary keys for mAR can change as well.
  158. The default properties are also accessible via fields and will raise an ``AttributeError`` if not available.
  159. .. important::
  160. This metric is following the mAP implementation of `pycocotools`_ a standard implementation for the mAP metric
  161. for object detection.
  162. .. hint::
  163. This metric requires you to have `torchvision` version 0.8.0 or newer installed
  164. (with corresponding version 1.7.0 of torch or newer) as well as `pycocotools`, regardless of
  165. the chosen ``iou_type``. Please install with ``pip install torchvision pycocotools`` or
  166. ``pip install torchmetrics[detection]``.
  167. Args:
  168. box_format:
  169. Input format of given boxes. Supported formats are ``[`xyxy`, `xywh`, `cxcywh`]``.
  170. iou_type:
  171. Type of input (either masks or bounding-boxes) used for computing IOU.
  172. Supported IOU types are ``["bbox", "segm"]``.
  173. If using ``"segm"``, masks should be provided (see :meth:`update`).
  174. iou_thresholds:
  175. IoU thresholds for evaluation. If set to ``None`` it corresponds to the stepped range ``[0.5,...,0.95]``
  176. with step ``0.05``. Else provide a list of floats.
  177. rec_thresholds:
  178. Recall thresholds for evaluation. If set to ``None`` it corresponds to the stepped range ``[0,...,1]``
  179. with step ``0.01``. Else provide a list of floats.
  180. max_detection_thresholds:
  181. Thresholds on max detections per image. If set to `None` will use thresholds ``[1, 10, 100]``.
  182. Else, please provide a list of ints.
  183. class_metrics:
  184. Option to enable per-class metrics for mAP and mAR_100. Has a performance impact.
  185. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  186. Raises:
  187. ModuleNotFoundError:
  188. If ``torchvision`` is not installed or version installed is lower than 0.8.0
  189. ModuleNotFoundError:
  190. If ``pycocotools`` is not installed (it is required for every ``iou_type``)
  191. ValueError:
  192. If ``class_metrics`` is not a boolean
  193. ValueError:
  194. If ``preds`` is not of type (:class:`~List[Dict[str, Tensor]]`)
  195. ValueError:
  196. If ``target`` is not of type ``List[Dict[str, Tensor]]``
  197. ValueError:
  198. If ``preds`` and ``target`` are not of the same length
  199. ValueError:
  200. If any of ``preds.boxes``, ``preds.scores`` and ``preds.labels`` are not of the same length
  201. ValueError:
  202. If any of ``target.boxes`` and ``target.labels`` are not of the same length
  203. ValueError:
  204. If any box is not type float and of length 4
  205. ValueError:
  206. If any class is not type int and of length 1
  207. ValueError:
  208. If any score is not type float and of length 1
  209. Example:
  210. >>> from torch import tensor
  211. >>> from torchmetrics.detection import MeanAveragePrecision
  212. >>> preds = [
  213. ... dict(
  214. ... boxes=tensor([[258.0, 41.0, 606.0, 285.0]]),
  215. ... scores=tensor([0.536]),
  216. ... labels=tensor([0]),
  217. ... )
  218. ... ]
  219. >>> target = [
  220. ... dict(
  221. ... boxes=tensor([[214.0, 41.0, 562.0, 285.0]]),
  222. ... labels=tensor([0]),
  223. ... )
  224. ... ]
  225. >>> metric = MeanAveragePrecision()
  226. >>> metric.update(preds, target)
  227. >>> from pprint import pprint
  228. >>> pprint(metric.compute())
  229. {'classes': tensor(0, dtype=torch.int32),
  230. 'map': tensor(0.6000),
  231. 'map_50': tensor(1.),
  232. 'map_75': tensor(1.),
  233. 'map_large': tensor(0.6000),
  234. 'map_medium': tensor(-1.),
  235. 'map_per_class': tensor(-1.),
  236. 'map_small': tensor(-1.),
  237. 'mar_1': tensor(0.6000),
  238. 'mar_10': tensor(0.6000),
  239. 'mar_100': tensor(0.6000),
  240. 'mar_100_per_class': tensor(-1.),
  241. 'mar_large': tensor(0.6000),
  242. 'mar_medium': tensor(-1.),
  243. 'mar_small': tensor(-1.)}
  244. """
  245. is_differentiable: bool = False
  246. higher_is_better: Optional[bool] = True
  247. full_state_update: bool = True
  248. plot_lower_bound: float = 0.0
  249. plot_upper_bound: float = 1.0
  250. detections: List[Tensor]
  251. detection_scores: List[Tensor]
  252. detection_labels: List[Tensor]
  253. groundtruths: List[Tensor]
  254. groundtruth_labels: List[Tensor]
  255. def __init__(
  256. self,
  257. box_format: str = "xyxy",
  258. iou_type: Literal["bbox", "segm"] = "bbox",
  259. iou_thresholds: Optional[list[float]] = None,
  260. rec_thresholds: Optional[list[float]] = None,
  261. max_detection_thresholds: Optional[list[int]] = None,
  262. class_metrics: bool = False,
  263. **kwargs: Any,
  264. ) -> None:
  265. super().__init__(**kwargs)
  266. if not _PYCOCOTOOLS_AVAILABLE:
  267. raise ModuleNotFoundError(
  268. "`MAP` metric requires that `pycocotools` installed."
  269. " Please install with `pip install pycocotools` or `pip install torchmetrics[detection]`"
  270. )
  271. if not _TORCHVISION_AVAILABLE:
  272. raise ModuleNotFoundError(
  273. "`MeanAveragePrecision` metric requires that `torchvision` is installed."
  274. " Please install with `pip install torchmetrics[detection]`."
  275. )
  276. allowed_box_formats = ("xyxy", "xywh", "cxcywh")
  277. allowed_iou_types = ("segm", "bbox")
  278. if box_format not in allowed_box_formats:
  279. raise ValueError(f"Expected argument `box_format` to be one of {allowed_box_formats} but got {box_format}")
  280. self.box_format = box_format
  281. self.iou_thresholds = iou_thresholds or torch.linspace(0.5, 0.95, round((0.95 - 0.5) / 0.05) + 1).tolist()
  282. self.rec_thresholds = rec_thresholds or torch.linspace(0.0, 1.00, round(1.00 / 0.01) + 1).tolist()
  283. max_det_threshold, _ = torch.sort(IntTensor(max_detection_thresholds or [1, 10, 100]))
  284. self.max_detection_thresholds = max_det_threshold.tolist()
  285. if iou_type not in allowed_iou_types:
  286. raise ValueError(f"Expected argument `iou_type` to be one of {allowed_iou_types} but got {iou_type}")
  287. if iou_type == "segm" and not _PYCOCOTOOLS_AVAILABLE:
  288. raise ModuleNotFoundError("When `iou_type` is set to 'segm', pycocotools need to be installed")
  289. self.iou_type = iou_type
  290. self.bbox_area_ranges = {
  291. "all": (float(0**2), float(1e5**2)),
  292. "small": (float(0**2), float(32**2)),
  293. "medium": (float(32**2), float(96**2)),
  294. "large": (float(96**2), float(1e5**2)),
  295. }
  296. if not isinstance(class_metrics, bool):
  297. raise ValueError("Expected argument `class_metrics` to be a boolean")
  298. self.class_metrics = class_metrics
  299. self.add_state("detections", default=[], dist_reduce_fx=None)
  300. self.add_state("detection_scores", default=[], dist_reduce_fx=None)
  301. self.add_state("detection_labels", default=[], dist_reduce_fx=None)
  302. self.add_state("groundtruths", default=[], dist_reduce_fx=None)
  303. self.add_state("groundtruth_labels", default=[], dist_reduce_fx=None)
  304. def update(self, preds: list[dict[str, Tensor]], target: list[dict[str, Tensor]]) -> None:
  305. """Update state with predictions and targets."""
  306. _input_validator(preds, target, iou_type=self.iou_type)
  307. for item in preds:
  308. detections = self._get_safe_item_values(item)
  309. self.detections.append(detections) # type: ignore[arg-type]
  310. self.detection_labels.append(item["labels"])
  311. self.detection_scores.append(item["scores"])
  312. for item in target:
  313. groundtruths = self._get_safe_item_values(item)
  314. self.groundtruths.append(groundtruths) # type: ignore[arg-type]
  315. self.groundtruth_labels.append(item["labels"])
  316. def _move_list_states_to_cpu(self) -> None:
  317. """Move list states to cpu to save GPU memory."""
  318. for key in self._defaults:
  319. current_val = getattr(self, key)
  320. current_to_cpu = []
  321. if isinstance(current_val, Sequence):
  322. for cur_v in current_val:
  323. # Cannot handle RLE as Tensor
  324. if not isinstance(cur_v, tuple):
  325. cur_v = cur_v.to("cpu")
  326. current_to_cpu.append(cur_v)
  327. setattr(self, key, current_to_cpu)
  328. def _get_safe_item_values(self, item: dict[str, Any]) -> Union[Tensor, tuple]:
  329. import pycocotools.mask as mask_utils
  330. from torchvision.ops import box_convert
  331. if self.iou_type == "bbox":
  332. boxes = _fix_empty_tensors(item["boxes"])
  333. if boxes.numel() > 0:
  334. boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xyxy")
  335. return boxes
  336. if self.iou_type == "segm":
  337. masks = []
  338. for i in item["masks"].cpu().numpy():
  339. rle = mask_utils.encode(np.asfortranarray(i))
  340. masks.append((tuple(rle["size"]), rle["counts"]))
  341. return tuple(masks)
  342. raise Exception(f"IOU type {self.iou_type} is not supported")
  343. def _get_classes(self) -> list:
  344. """Return a list of unique classes found in ground truth and detection data."""
  345. if len(self.detection_labels) > 0 or len(self.groundtruth_labels) > 0:
  346. return torch.cat(self.detection_labels + self.groundtruth_labels).unique().tolist()
  347. return []
  348. def _compute_iou(self, idx: int, class_id: int, max_det: int) -> Tensor:
  349. """Compute the Intersection over Union (IoU) between bounding boxes for the given image and class.
  350. Args:
  351. idx:
  352. Image Id, equivalent to the index of supplied samples
  353. class_id:
  354. Class Id of the supplied ground truth and detection labels
  355. max_det:
  356. Maximum number of evaluated detection bounding boxes
  357. """
  358. # if self.iou_type == "bbox":
  359. gt = self.groundtruths[idx]
  360. det = self.detections[idx]
  361. gt_label_mask = (self.groundtruth_labels[idx] == class_id).nonzero().squeeze(1)
  362. det_label_mask = (self.detection_labels[idx] == class_id).nonzero().squeeze(1)
  363. if len(gt_label_mask) == 0 or len(det_label_mask) == 0:
  364. return Tensor([])
  365. gt = [gt[i] for i in gt_label_mask]
  366. det = [det[i] for i in det_label_mask]
  367. if len(gt) == 0 or len(det) == 0:
  368. return Tensor([])
  369. # Sort by scores and use only max detections
  370. scores = self.detection_scores[idx]
  371. scores_filtered = scores[self.detection_labels[idx] == class_id]
  372. inds = torch.argsort(scores_filtered, descending=True)
  373. # TODO Fix (only for masks is necessary)
  374. det = [det[i] for i in inds]
  375. if len(det) > max_det:
  376. det = det[:max_det]
  377. return compute_iou(det, gt, self.iou_type).to(self.device)
  378. def __evaluate_image_gt_no_preds(
  379. self, gt: Tensor, gt_label_mask: Tensor, area_range: tuple[int, int], num_iou_thrs: int
  380. ) -> dict[str, Any]:
  381. """Evaluate images with a ground truth but no predictions."""
  382. # GTs
  383. gt = [gt[i] for i in gt_label_mask]
  384. num_gt = len(gt)
  385. areas = compute_area(gt, iou_type=self.iou_type).to(self.device)
  386. ignore_area = (areas < area_range[0]) | (areas > area_range[1])
  387. gt_ignore, _ = torch.sort(ignore_area.to(torch.uint8))
  388. gt_ignore = gt_ignore.to(torch.bool)
  389. # Detections
  390. num_det = 0
  391. det_ignore = torch.zeros((num_iou_thrs, num_det), dtype=torch.bool, device=self.device)
  392. return {
  393. "dtMatches": torch.zeros((num_iou_thrs, num_det), dtype=torch.bool, device=self.device),
  394. "gtMatches": torch.zeros((num_iou_thrs, num_gt), dtype=torch.bool, device=self.device),
  395. "dtScores": torch.zeros(num_det, dtype=torch.float32, device=self.device),
  396. "gtIgnore": gt_ignore,
  397. "dtIgnore": det_ignore,
  398. }
  399. def __evaluate_image_preds_no_gt(
  400. self,
  401. det: Tensor,
  402. idx: int,
  403. det_label_mask: Tensor,
  404. max_det: int,
  405. area_range: tuple[int, int],
  406. num_iou_thrs: int,
  407. ) -> dict[str, Any]:
  408. """Evaluate images with a prediction but no ground truth."""
  409. # GTs
  410. num_gt = 0
  411. gt_ignore = torch.zeros(num_gt, dtype=torch.bool, device=self.device)
  412. # Detections
  413. det = [det[i] for i in det_label_mask]
  414. scores = self.detection_scores[idx]
  415. scores_filtered = scores[det_label_mask]
  416. scores_sorted, dtind = torch.sort(scores_filtered, descending=True)
  417. det = [det[i] for i in dtind]
  418. if len(det) > max_det:
  419. det = det[:max_det]
  420. num_det = len(det)
  421. det_areas = compute_area(det, iou_type=self.iou_type).to(self.device)
  422. det_ignore_area = (det_areas < area_range[0]) | (det_areas > area_range[1])
  423. ar = det_ignore_area.reshape((1, num_det))
  424. det_ignore = torch.repeat_interleave(ar, num_iou_thrs, 0)
  425. return {
  426. "dtMatches": torch.zeros((num_iou_thrs, num_det), dtype=torch.bool, device=self.device),
  427. "gtMatches": torch.zeros((num_iou_thrs, num_gt), dtype=torch.bool, device=self.device),
  428. "dtScores": scores_sorted.to(self.device),
  429. "gtIgnore": gt_ignore.to(self.device),
  430. "dtIgnore": det_ignore.to(self.device),
  431. }
  432. def _evaluate_image(
  433. self, idx: int, class_id: int, area_range: tuple[int, int], max_det: int, ious: dict
  434. ) -> Optional[dict]:
  435. """Perform evaluation for single class and image.
  436. Args:
  437. idx:
  438. Image Id, equivalent to the index of supplied samples.
  439. class_id:
  440. Class Id of the supplied ground truth and detection labels.
  441. area_range:
  442. List of lower and upper bounding box area threshold.
  443. max_det:
  444. Maximum number of evaluated detection bounding boxes.
  445. ious:
  446. IoU results for image and class.
  447. """
  448. gt = self.groundtruths[idx]
  449. det = self.detections[idx]
  450. gt_label_mask = (self.groundtruth_labels[idx] == class_id).nonzero().squeeze(1)
  451. det_label_mask = (self.detection_labels[idx] == class_id).nonzero().squeeze(1)
  452. # No Gt and No predictions --> ignore image
  453. if len(gt_label_mask) == 0 and len(det_label_mask) == 0:
  454. return None
  455. num_iou_thrs = len(self.iou_thresholds)
  456. # Some GT but no predictions
  457. if len(gt_label_mask) > 0 and len(det_label_mask) == 0:
  458. return self.__evaluate_image_gt_no_preds(gt, gt_label_mask, area_range, num_iou_thrs)
  459. # Some predictions but no GT
  460. if len(gt_label_mask) == 0 and len(det_label_mask) > 0:
  461. return self.__evaluate_image_preds_no_gt(det, idx, det_label_mask, max_det, area_range, num_iou_thrs)
  462. gt = [gt[i] for i in gt_label_mask]
  463. det = [det[i] for i in det_label_mask]
  464. if len(gt) == 0 and len(det) == 0:
  465. return None
  466. if isinstance(det, dict):
  467. det = [det]
  468. if isinstance(gt, dict):
  469. gt = [gt]
  470. areas = compute_area(gt, iou_type=self.iou_type).to(self.device)
  471. ignore_area = torch.logical_or(areas < area_range[0], areas > area_range[1])
  472. # sort dt highest score first, sort gt ignore last
  473. ignore_area_sorted, gtind = torch.sort(ignore_area.to(torch.uint8))
  474. # Convert to uint8 temporarily and back to bool, because "Sort currently does not support bool dtype on CUDA"
  475. ignore_area_sorted = ignore_area_sorted.to(torch.bool).to(self.device)
  476. gt = [gt[i] for i in gtind]
  477. scores = self.detection_scores[idx]
  478. scores_filtered = scores[det_label_mask]
  479. scores_sorted, dtind = torch.sort(scores_filtered, descending=True)
  480. det = [det[i] for i in dtind]
  481. if len(det) > max_det:
  482. det = det[:max_det]
  483. # load computed ious
  484. ious = ious[idx, class_id][:, gtind] if len(ious[idx, class_id]) > 0 else ious[idx, class_id]
  485. num_iou_thrs = len(self.iou_thresholds)
  486. num_gt = len(gt)
  487. num_det = len(det)
  488. gt_matches = torch.zeros((num_iou_thrs, num_gt), dtype=torch.bool, device=self.device)
  489. det_matches = torch.zeros((num_iou_thrs, num_det), dtype=torch.bool, device=self.device)
  490. gt_ignore = ignore_area_sorted
  491. det_ignore = torch.zeros((num_iou_thrs, num_det), dtype=torch.bool, device=self.device)
  492. if torch.numel(ious) > 0:
  493. for idx_iou, t in enumerate(self.iou_thresholds):
  494. for idx_det, _ in enumerate(det):
  495. m = MeanAveragePrecision._find_best_gt_match(t, gt_matches, idx_iou, gt_ignore, ious, idx_det)
  496. if m == -1:
  497. continue
  498. det_ignore[idx_iou, idx_det] = gt_ignore[m]
  499. det_matches[idx_iou, idx_det] = 1
  500. gt_matches[idx_iou, m] = 1
  501. # set unmatched detections outside of area range to ignore
  502. det_areas = compute_area(det, iou_type=self.iou_type).to(self.device)
  503. det_ignore_area = (det_areas < area_range[0]) | (det_areas > area_range[1])
  504. ar = det_ignore_area.reshape((1, num_det))
  505. det_ignore = torch.logical_or(
  506. det_ignore, torch.logical_and(det_matches == 0, torch.repeat_interleave(ar, num_iou_thrs, 0))
  507. )
  508. return {
  509. "dtMatches": det_matches.to(self.device),
  510. "gtMatches": gt_matches.to(self.device),
  511. "dtScores": scores_sorted.to(self.device),
  512. "gtIgnore": gt_ignore.to(self.device),
  513. "dtIgnore": det_ignore.to(self.device),
  514. }
  515. @staticmethod
  516. def _find_best_gt_match(
  517. threshold: int, gt_matches: Tensor, idx_iou: float, gt_ignore: Tensor, ious: Tensor, idx_det: int
  518. ) -> int:
  519. """Return id of best ground truth match with current detection.
  520. Args:
  521. threshold:
  522. Current threshold value.
  523. gt_matches:
  524. Tensor showing if a ground truth matches for threshold ``t`` exists.
  525. idx_iou:
  526. Id of threshold ``t``.
  527. gt_ignore:
  528. Tensor showing if ground truth should be ignored.
  529. ious:
  530. IoUs for all combinations of detection and ground truth.
  531. idx_det:
  532. Id of current detection.
  533. """
  534. previously_matched = gt_matches[idx_iou] # type: ignore[index]
  535. # Remove previously matched or ignored gts
  536. remove_mask = previously_matched | gt_ignore
  537. gt_ious = ious[idx_det] * ~remove_mask
  538. match_idx = gt_ious.argmax().item()
  539. if gt_ious[match_idx] > threshold: # type: ignore[index]
  540. return match_idx # type: ignore[return-value]
  541. return -1
  542. def _summarize(
  543. self,
  544. results: dict,
  545. avg_prec: bool = True,
  546. iou_threshold: Optional[float] = None,
  547. area_range: str = "all",
  548. max_dets: int = 100,
  549. ) -> Tensor:
  550. """Perform evaluation for single class and image.
  551. Args:
  552. results:
  553. Dictionary including precision, recall and scores for all combinations.
  554. avg_prec:
  555. Calculate average precision. Else calculate average recall.
  556. iou_threshold:
  557. IoU threshold. If set to ``None`` it all values are used. Else results are filtered.
  558. area_range:
  559. Bounding box area range key.
  560. max_dets:
  561. Maximum detections.
  562. """
  563. area_inds = [i for i, k in enumerate(self.bbox_area_ranges.keys()) if k == area_range]
  564. mdet_inds = [i for i, k in enumerate(self.max_detection_thresholds) if k == max_dets]
  565. if avg_prec:
  566. # dimension of precision: [TxRxKxAxM]
  567. prec = results["precision"]
  568. # IoU
  569. if iou_threshold is not None:
  570. threshold = self.iou_thresholds.index(iou_threshold)
  571. prec = prec[threshold, :, :, area_inds, mdet_inds]
  572. else:
  573. prec = prec[:, :, :, area_inds, mdet_inds]
  574. else:
  575. # dimension of recall: [TxKxAxM]
  576. prec = results["recall"]
  577. if iou_threshold is not None:
  578. threshold = self.iou_thresholds.index(iou_threshold)
  579. prec = prec[threshold, :, :, area_inds, mdet_inds]
  580. else:
  581. prec = prec[:, :, area_inds, mdet_inds]
  582. return torch.tensor([-1.0]) if len(prec[prec > -1]) == 0 else torch.mean(prec[prec > -1])
  583. def _calculate(self, class_ids: list) -> tuple[MAPMetricResults, MARMetricResults]:
  584. """Calculate the precision and recall for all supplied classes to calculate mAP/mAR.
  585. Args:
  586. class_ids:
  587. List of label class Ids.
  588. """
  589. img_ids = range(len(self.groundtruths))
  590. max_detections = self.max_detection_thresholds[-1]
  591. area_ranges = self.bbox_area_ranges.values()
  592. ious = {
  593. (idx, class_id): self._compute_iou(idx, class_id, max_detections)
  594. for idx in img_ids
  595. for class_id in class_ids
  596. }
  597. eval_imgs = [
  598. self._evaluate_image(img_id, class_id, area, max_detections, ious) # type: ignore[arg-type]
  599. for class_id in class_ids
  600. for area in area_ranges
  601. for img_id in img_ids
  602. ]
  603. num_iou_thrs = len(self.iou_thresholds)
  604. num_rec_thrs = len(self.rec_thresholds)
  605. num_classes = len(class_ids)
  606. num_bbox_areas = len(self.bbox_area_ranges)
  607. num_max_det_thresholds = len(self.max_detection_thresholds)
  608. num_imgs = len(img_ids)
  609. precision = -torch.ones((num_iou_thrs, num_rec_thrs, num_classes, num_bbox_areas, num_max_det_thresholds))
  610. recall = -torch.ones((num_iou_thrs, num_classes, num_bbox_areas, num_max_det_thresholds))
  611. scores = -torch.ones((num_iou_thrs, num_rec_thrs, num_classes, num_bbox_areas, num_max_det_thresholds))
  612. # move tensors if necessary
  613. rec_thresholds_tensor = torch.tensor(self.rec_thresholds)
  614. # retrieve E at each category, area range, and max number of detections
  615. for idx_cls, _ in enumerate(class_ids):
  616. for idx_bbox_area, _ in enumerate(self.bbox_area_ranges):
  617. for idx_max_det_thresholds, max_det in enumerate(self.max_detection_thresholds):
  618. recall, precision, scores = MeanAveragePrecision.__calculate_recall_precision_scores(
  619. recall,
  620. precision,
  621. scores,
  622. idx_cls=idx_cls,
  623. idx_bbox_area=idx_bbox_area,
  624. idx_max_det_thresholds=idx_max_det_thresholds,
  625. eval_imgs=eval_imgs,
  626. rec_thresholds=rec_thresholds_tensor,
  627. max_det=max_det,
  628. num_imgs=num_imgs,
  629. num_bbox_areas=num_bbox_areas,
  630. )
  631. return precision, recall # type: ignore[return-value]
  632. def _summarize_results(self, precisions: Tensor, recalls: Tensor) -> tuple[MAPMetricResults, MARMetricResults]:
  633. """Summarizes the precision and recall values to calculate mAP/mAR.
  634. Args:
  635. precisions:
  636. Precision values for different thresholds
  637. recalls:
  638. Recall values for different thresholds
  639. """
  640. results = {"precision": precisions, "recall": recalls}
  641. map_metrics = MAPMetricResults()
  642. last_max_det_threshold = self.max_detection_thresholds[-1]
  643. map_metrics.map = self._summarize(results, True, max_dets=last_max_det_threshold)
  644. if 0.5 in self.iou_thresholds:
  645. map_metrics.map_50 = self._summarize(results, True, iou_threshold=0.5, max_dets=last_max_det_threshold)
  646. else:
  647. map_metrics.map_50 = torch.tensor([-1])
  648. if 0.75 in self.iou_thresholds:
  649. map_metrics.map_75 = self._summarize(results, True, iou_threshold=0.75, max_dets=last_max_det_threshold)
  650. else:
  651. map_metrics.map_75 = torch.tensor([-1])
  652. map_metrics.map_small = self._summarize(results, True, area_range="small", max_dets=last_max_det_threshold)
  653. map_metrics.map_medium = self._summarize(results, True, area_range="medium", max_dets=last_max_det_threshold)
  654. map_metrics.map_large = self._summarize(results, True, area_range="large", max_dets=last_max_det_threshold)
  655. mar_metrics = MARMetricResults()
  656. for max_det in self.max_detection_thresholds:
  657. mar_metrics[f"mar_{max_det}"] = self._summarize(results, False, max_dets=max_det)
  658. mar_metrics.mar_small = self._summarize(results, False, area_range="small", max_dets=last_max_det_threshold)
  659. mar_metrics.mar_medium = self._summarize(results, False, area_range="medium", max_dets=last_max_det_threshold)
  660. mar_metrics.mar_large = self._summarize(results, False, area_range="large", max_dets=last_max_det_threshold)
  661. return map_metrics, mar_metrics
  662. @staticmethod
  663. def __calculate_recall_precision_scores(
  664. recall: Tensor,
  665. precision: Tensor,
  666. scores: Tensor,
  667. idx_cls: int,
  668. idx_bbox_area: int,
  669. idx_max_det_thresholds: int,
  670. eval_imgs: list,
  671. rec_thresholds: Tensor,
  672. max_det: int,
  673. num_imgs: int,
  674. num_bbox_areas: int,
  675. ) -> tuple[Tensor, Tensor, Tensor]:
  676. num_rec_thrs = len(rec_thresholds)
  677. idx_cls_pointer = idx_cls * num_bbox_areas * num_imgs
  678. idx_bbox_area_pointer = idx_bbox_area * num_imgs
  679. # Load all image evals for current class_id and area_range
  680. img_eval_cls_bbox = [eval_imgs[idx_cls_pointer + idx_bbox_area_pointer + i] for i in range(num_imgs)]
  681. img_eval_cls_bbox = [e for e in img_eval_cls_bbox if e is not None]
  682. if not img_eval_cls_bbox:
  683. return recall, precision, scores
  684. det_scores = torch.cat([e["dtScores"][:max_det] for e in img_eval_cls_bbox])
  685. # different sorting method generates slightly different results.
  686. # mergesort is used to be consistent as Matlab implementation.
  687. # Sort in PyTorch does not support bool types on CUDA (yet, 1.11.0)
  688. dtype = torch.uint8 if det_scores.is_cuda and det_scores.dtype is torch.bool else det_scores.dtype
  689. # Explicitly cast to uint8 to avoid error for bool inputs on CUDA to argsort
  690. inds = torch.argsort(det_scores.to(dtype), descending=True)
  691. det_scores_sorted = det_scores[inds]
  692. det_matches = torch.cat([e["dtMatches"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds] # type: ignore[call-overload]
  693. det_ignore = torch.cat([e["dtIgnore"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds] # type: ignore[call-overload]
  694. gt_ignore = torch.cat([e["gtIgnore"] for e in img_eval_cls_bbox])
  695. npig = torch.count_nonzero(gt_ignore == False) # noqa: E712
  696. if npig == 0:
  697. return recall, precision, scores
  698. tps = torch.logical_and(det_matches, torch.logical_not(det_ignore))
  699. fps = torch.logical_and(torch.logical_not(det_matches), torch.logical_not(det_ignore))
  700. tp_sum = _cumsum(tps, dim=1, dtype=torch.float)
  701. fp_sum = _cumsum(fps, dim=1, dtype=torch.float)
  702. for idx, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
  703. tp_len = len(tp)
  704. rc = tp / npig
  705. pr = tp / (fp + tp + torch.finfo(torch.float64).eps)
  706. prec = torch.zeros((num_rec_thrs,))
  707. score = torch.zeros((num_rec_thrs,))
  708. recall[idx, idx_cls, idx_bbox_area, idx_max_det_thresholds] = rc[-1] if tp_len else 0
  709. # Remove zigzags for AUC
  710. diff_zero = torch.zeros((1,), device=pr.device)
  711. diff = torch.ones((1,), device=pr.device)
  712. while not torch.all(diff == 0):
  713. diff = torch.clamp(torch.cat(((pr[1:] - pr[:-1]), diff_zero), 0), min=0)
  714. pr += diff
  715. inds = torch.searchsorted(rc, rec_thresholds.to(rc.device), right=False)
  716. num_inds = inds.argmax() if inds.max() >= tp_len else num_rec_thrs
  717. inds = inds[:num_inds]
  718. prec[:num_inds] = pr[inds]
  719. score[:num_inds] = det_scores_sorted[inds]
  720. precision[idx, :, idx_cls, idx_bbox_area, idx_max_det_thresholds] = prec
  721. scores[idx, :, idx_cls, idx_bbox_area, idx_max_det_thresholds] = score
  722. return recall, precision, scores
  723. def compute(self) -> dict:
  724. """Compute metric."""
  725. classes = self._get_classes()
  726. precisions, recalls = self._calculate(classes)
  727. map_val, mar_val = self._summarize_results(precisions, recalls) # type: ignore[arg-type]
  728. # if class mode is enabled, evaluate metrics per class
  729. map_per_class_values: Tensor = torch.tensor([-1.0])
  730. mar_max_dets_per_class_values: Tensor = torch.tensor([-1.0])
  731. if self.class_metrics:
  732. map_per_class_list = []
  733. mar_max_dets_per_class_list = []
  734. for class_idx, _ in enumerate(classes):
  735. cls_precisions = precisions[:, :, class_idx].unsqueeze(dim=2)
  736. cls_recalls = recalls[:, class_idx].unsqueeze(dim=1)
  737. cls_map, cls_mar = self._summarize_results(cls_precisions, cls_recalls)
  738. map_per_class_list.append(cls_map.map)
  739. mar_max_dets_per_class_list.append(cls_mar[f"mar_{self.max_detection_thresholds[-1]}"])
  740. map_per_class_values = torch.tensor(map_per_class_list, dtype=torch.float)
  741. mar_max_dets_per_class_values = torch.tensor(mar_max_dets_per_class_list, dtype=torch.float)
  742. metrics = COCOMetricResults()
  743. metrics.update(map_val)
  744. metrics.update(mar_val)
  745. metrics.map_per_class = map_per_class_values
  746. metrics[f"mar_{self.max_detection_thresholds[-1]}_per_class"] = mar_max_dets_per_class_values
  747. metrics.classes = torch.tensor(classes, dtype=torch.int)
  748. return metrics
  749. def _apply(self, fn: Callable) -> torch.nn.Module: # type: ignore[override]
  750. """Custom apply function.
  751. Excludes the detections and groundtruths from the casting when the iou_type is set to `segm` as the state is
  752. no longer a tensor but a tuple.
  753. """
  754. if self.iou_type == "segm":
  755. this = super()._apply(fn, exclude_state=("detections", "groundtruths"))
  756. else:
  757. this = super()._apply(fn)
  758. return this
  759. def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Optional[Any] = None) -> None:
  760. """Custom sync function.
  761. For the iou_type `segm` the detections and groundtruths are no longer tensors but tuples. Therefore, we need
  762. to gather the list of tuples and then convert it back to a list of tuples.
  763. """
  764. super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group) # type: ignore[arg-type]
  765. if self.iou_type == "segm":
  766. self.detections = self._gather_tuple_list(self.detections, process_group) # type: ignore[arg-type]
  767. self.groundtruths = self._gather_tuple_list(self.groundtruths, process_group) # type: ignore[arg-type]
  768. @staticmethod
  769. def _gather_tuple_list(
  770. list_to_gather: list[Union[tuple, Tensor]], process_group: Optional[Any] = None
  771. ) -> list[Any]:
  772. """Gather a list of tuples over multiple devices."""
  773. world_size = dist.get_world_size(group=process_group)
  774. dist.barrier(group=process_group)
  775. list_gathered = [None for _ in range(world_size)]
  776. dist.all_gather_object(list_gathered, list_to_gather, group=process_group)
  777. return [list_gathered[rank][idx] for idx in range(len(list_gathered[0])) for rank in range(world_size)] # type: ignore[arg-type,index]
  778. def plot(
  779. self, val: Optional[Union[dict[str, Tensor], Sequence[dict[str, Tensor]]]] = None, ax: Optional[_AX_TYPE] = None
  780. ) -> _PLOT_OUT_TYPE:
  781. """Plot a single or multiple values from the metric.
  782. Args:
  783. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  784. If no value is provided, will automatically call `metric.compute` and plot that result.
  785. ax: An matplotlib axis object. If provided will add plot to that axis
  786. Returns:
  787. Figure object and Axes object
  788. Raises:
  789. ModuleNotFoundError:
  790. If `matplotlib` is not installed
  791. .. plot::
  792. :scale: 75
  793. >>> from torch import tensor
  794. >>> from torchmetrics.detection.mean_ap import MeanAveragePrecision
  795. >>> preds = [dict(
  796. ... boxes=tensor([[258.0, 41.0, 606.0, 285.0]]),
  797. ... scores=tensor([0.536]),
  798. ... labels=tensor([0]),
  799. ... )]
  800. >>> target = [dict(
  801. ... boxes=tensor([[214.0, 41.0, 562.0, 285.0]]),
  802. ... labels=tensor([0]),
  803. ... )]
  804. >>> metric = MeanAveragePrecision()
  805. >>> metric.update(preds, target)
  806. >>> fig_, ax_ = metric.plot()
  807. .. plot::
  808. :scale: 75
  809. >>> # Example plotting multiple values
  810. >>> import torch
  811. >>> from torchmetrics.detection.mean_ap import MeanAveragePrecision
  812. >>> preds = lambda: [dict(
  813. ... boxes=torch.tensor([[258.0, 41.0, 606.0, 285.0]]) + torch.randint(10, (1,4)),
  814. ... scores=torch.tensor([0.536]) + 0.1*torch.rand(1),
  815. ... labels=torch.tensor([0]),
  816. ... )]
  817. >>> target = [dict(
  818. ... boxes=torch.tensor([[214.0, 41.0, 562.0, 285.0]]),
  819. ... labels=torch.tensor([0]),
  820. ... )]
  821. >>> metric = MeanAveragePrecision()
  822. >>> vals = []
  823. >>> for _ in range(20):
  824. ... vals.append(metric(preds(), target))
  825. >>> fig_, ax_ = metric.plot(vals)
  826. """
  827. return self._plot(val, ax)