ciou.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. from torch import Tensor
  17. from torchmetrics.detection.iou import IntersectionOverUnion
  18. from torchmetrics.functional.detection.ciou import _ciou_compute, _ciou_update
  19. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE
  20. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  21. if not _TORCHVISION_AVAILABLE:
  22. __doctest_skip__ = ["CompleteIntersectionOverUnion", "CompleteIntersectionOverUnion.plot"]
  23. elif not _MATPLOTLIB_AVAILABLE:
  24. __doctest_skip__ = ["CompleteIntersectionOverUnion.plot"]
  25. class CompleteIntersectionOverUnion(IntersectionOverUnion):
  26. r"""Computes Complete Intersection Over Union (`CIoU`_).
  27. As input to ``forward`` and ``update`` the metric accepts the following input:
  28. - ``preds`` (:class:`~List`): A list consisting of dictionaries each containing the key-values
  29. (each dictionary corresponds to a single image). Parameters that should be provided per dict:
  30. - ``boxes`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes, 4)`` containing ``num_boxes``
  31. detection boxes of the format specified in the constructor.
  32. By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
  33. - ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed detection
  34. classes for the boxes.
  35. - ``target`` (:class:`~List`): A list consisting of dictionaries each containing the key-values
  36. (each dictionary corresponds to a single image). Parameters that should be provided per dict:
  37. - ``boxes`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes, 4)`` containing ``num_boxes`` ground
  38. truth boxes of the format specified in the constructor.
  39. By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
  40. - ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed detection
  41. classes for the boxes.
  42. As output of ``forward`` and ``compute`` the metric returns the following output:
  43. - ``ciou_dict``: A dictionary containing the following key-values:
  44. - ciou: (:class:`~torch.Tensor`) with overall ciou value over all classes and samples.
  45. - ciou/cl_{cl}: (:class:`~torch.Tensor`), if argument ``class_metrics=True``
  46. Args:
  47. box_format:
  48. Input format of given boxes. Supported formats are ``[`xyxy`, `xywh`, `cxcywh`]``.
  49. iou_thresholds:
  50. Optional IoU thresholds for evaluation. If set to `None` the threshold is ignored.
  51. class_metrics:
  52. Option to enable per-class metrics for IoU. Has a performance impact.
  53. respect_labels:
  54. Ignore values from boxes that do not have the same label as the ground truth box. Else will compute Iou
  55. between all pairs of boxes.
  56. kwargs:
  57. Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  58. Example:
  59. >>> import torch
  60. >>> from torchmetrics.detection import CompleteIntersectionOverUnion
  61. >>> preds = [
  62. ... {
  63. ... "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
  64. ... "scores": torch.tensor([0.236, 0.56]),
  65. ... "labels": torch.tensor([4, 5]),
  66. ... }
  67. ... ]
  68. >>> target = [
  69. ... {
  70. ... "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
  71. ... "labels": torch.tensor([5]),
  72. ... }
  73. ... ]
  74. >>> metric = CompleteIntersectionOverUnion()
  75. >>> metric(preds, target)
  76. {'ciou': tensor(0.8611)}
  77. Raises:
  78. ModuleNotFoundError:
  79. If torchvision is not installed with version 0.13.0 or newer.
  80. """
  81. is_differentiable: bool = False
  82. higher_is_better: Optional[bool] = True
  83. full_state_update: bool = True
  84. _iou_type: str = "ciou"
  85. _invalid_val: float = -2.0 # unsure, min val could be just -1.5 as well
  86. def __init__(
  87. self,
  88. box_format: str = "xyxy",
  89. iou_threshold: Optional[float] = None,
  90. class_metrics: bool = False,
  91. respect_labels: bool = True,
  92. **kwargs: Any,
  93. ) -> None:
  94. if not _TORCHVISION_AVAILABLE:
  95. raise ModuleNotFoundError(
  96. f"Metric `{self._iou_type.upper()}` requires that `torchvision` is installed."
  97. " Please install with `pip install torchmetrics[detection]`."
  98. )
  99. super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs)
  100. @staticmethod
  101. def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor:
  102. return _ciou_update(*args, **kwargs)
  103. @staticmethod
  104. def _iou_compute_fn(*args: Any, **kwargs: Any) -> Tensor:
  105. return _ciou_compute(*args, **kwargs)
  106. def plot(
  107. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  108. ) -> _PLOT_OUT_TYPE:
  109. """Plot a single or multiple values from the metric.
  110. Args:
  111. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  112. If no value is provided, will automatically call `metric.compute` and plot that result.
  113. ax: An matplotlib axis object. If provided will add plot to that axis
  114. Returns:
  115. Figure object and Axes object
  116. Raises:
  117. ModuleNotFoundError:
  118. If `matplotlib` is not installed
  119. .. plot::
  120. :scale: 75
  121. >>> # Example plotting single value
  122. >>> import torch
  123. >>> from torchmetrics.detection import CompleteIntersectionOverUnion
  124. >>> preds = [
  125. ... {
  126. ... "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
  127. ... "scores": torch.tensor([0.236, 0.56]),
  128. ... "labels": torch.tensor([4, 5]),
  129. ... }
  130. ... ]
  131. >>> target = [
  132. ... {
  133. ... "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
  134. ... "labels": torch.tensor([5]),
  135. ... }
  136. ... ]
  137. >>> metric = CompleteIntersectionOverUnion()
  138. >>> metric.update(preds, target)
  139. >>> fig_, ax_ = metric.plot()
  140. .. plot::
  141. :scale: 75
  142. >>> # Example plotting multiple values
  143. >>> import torch
  144. >>> from torchmetrics.detection import CompleteIntersectionOverUnion
  145. >>> preds = [
  146. ... {
  147. ... "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
  148. ... "scores": torch.tensor([0.236, 0.56]),
  149. ... "labels": torch.tensor([4, 5]),
  150. ... }
  151. ... ]
  152. >>> target = lambda : [
  153. ... {
  154. ... "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]) + torch.randint(-10, 10, (1, 4)),
  155. ... "labels": torch.tensor([5]),
  156. ... }
  157. ... ]
  158. >>> metric = CompleteIntersectionOverUnion()
  159. >>> vals = []
  160. >>> for _ in range(20):
  161. ... vals.append(metric(preds, target()))
  162. >>> fig_, ax_ = metric.plot(vals)
  163. """
  164. return self._plot(val, ax)