panoptic_qualities.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. # Copyright The PyTorch Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Collection, Sequence
  15. from typing import Any, Optional, Union
  16. import torch
  17. from torch import Tensor
  18. from torchmetrics.functional.detection._panoptic_quality_common import (
  19. _get_category_id_to_continuous_id,
  20. _get_void_color,
  21. _panoptic_quality_compute,
  22. _panoptic_quality_update,
  23. _parse_categories,
  24. _prepocess_inputs,
  25. _validate_inputs,
  26. )
  27. from torchmetrics.metric import Metric
  28. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  29. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
# The ``plot`` doctests need matplotlib; list them in ``__doctest_skip__`` when it is not installed.
if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["PanopticQuality.plot", "ModifiedPanopticQuality.plot"]
  32. class PanopticQuality(Metric):
  33. r"""Compute the `Panoptic Quality`_ for panoptic segmentations.
  34. .. math::
  35. PQ = \frac{IOU}{TP + 0.5 FP + 0.5 FN}
  36. where IOU, TP, FP and FN are respectively the sum of the intersection over union for true positives,
  37. the number of true positives, false positives and false negatives. This metric is inspired by the PQ
  38. implementation of panopticapi, a standard implementation for the PQ metric for panoptic segmentation.
  39. .. note:
  40. Points in the target tensor that do not map to a known category ID are automatically ignored in the metric
  41. computation.
  42. As input to ``forward`` and ``update`` the metric accepts the following input:
  43. - ``preds`` (:class:`~torch.Tensor`): An int tensor of shape ``(B, *spatial_dims, 2)`` containing
  44. the pair ``(category_id, instance_id)`` for each point, where there needs to
  45. be at least one spatial dimension.
  46. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(B, *spatial_dims, 2)`` containing
  47. the pair ``(category_id, instance_id)`` for each point, where there needs to
  48. be at least one spatial dimension.
  49. As output to ``forward`` and ``compute`` the metric returns the following output:
  50. - ``quality`` (:class:`~torch.Tensor`): If ``return_sq_and_rq=False`` and ``return_per_class=False`` then a
  51. single scalar tensor is returned with average panoptic quality over all classes. If ``return_sq_and_rq=True``
  52. and ``return_per_class=False`` a tensor of length 3 is returned with panoptic, segmentation and recognition
  53. quality (in that order). If If ``return_sq_and_rq=False`` and ``return_per_class=True`` a tensor of length
  54. equal to the number of classes are returned, with panoptic quality for each class. The order of classes is
  55. ``things`` first and then ``stuffs``, and numerically sorted within each.
  56. (ex. with ``things=[4, 1], stuffs=[3, 2]``, the output classes are ordered by ``[1, 4, 2, 3]``)
  57. Finally, if both arguments are ``True`` a tensor of shape ``(3, C)`` is returned with individual panoptic,
  58. segmentation and recognition quality for each class.
  59. Args:
  60. things:
  61. Set of ``category_id`` for countable things.
  62. stuffs:
  63. Set of ``category_id`` for uncountable stuffs.
  64. allow_unknown_preds_category:
  65. Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric
  66. computation or raise an exception when found.
  67. return_sq_and_rq:
  68. Boolean flag to specify if Segmentation Quality and Recognition Quality should be also returned.
  69. return_per_class:
  70. Boolean flag to specify if the per-class values should be returned or the class average.
  71. Raises:
  72. ValueError:
  73. If ``things``, ``stuffs`` have at least one common ``category_id``.
  74. TypeError:
  75. If ``things``, ``stuffs`` contain non-integer ``category_id``.
  76. Example:
  77. >>> from torch import tensor
  78. >>> from torchmetrics.detection import PanopticQuality
  79. >>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
  80. ... [[0, 0], [0, 0], [6, 0], [0, 1]],
  81. ... [[0, 0], [0, 0], [6, 0], [0, 1]],
  82. ... [[0, 0], [7, 0], [6, 0], [1, 0]],
  83. ... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
  84. >>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
  85. ... [[0, 1], [0, 1], [6, 0], [0, 1]],
  86. ... [[0, 1], [0, 1], [6, 0], [1, 0]],
  87. ... [[0, 1], [7, 0], [1, 0], [1, 0]],
  88. ... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
  89. >>> panoptic_quality = PanopticQuality(things = {0, 1}, stuffs = {6, 7})
  90. >>> panoptic_quality(preds, target)
  91. tensor(0.5463, dtype=torch.float64)
  92. You can also return the segmentation and recognition quality alognside the PQ
  93. >>> from torch import tensor
  94. >>> from torchmetrics.detection import PanopticQuality
  95. >>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
  96. ... [[0, 0], [0, 0], [6, 0], [0, 1]],
  97. ... [[0, 0], [0, 0], [6, 0], [0, 1]],
  98. ... [[0, 0], [7, 0], [6, 0], [1, 0]],
  99. ... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
  100. >>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
  101. ... [[0, 1], [0, 1], [6, 0], [0, 1]],
  102. ... [[0, 1], [0, 1], [6, 0], [1, 0]],
  103. ... [[0, 1], [7, 0], [1, 0], [1, 0]],
  104. ... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
  105. >>> panoptic_quality = PanopticQuality(things = {0, 1}, stuffs = {6, 7}, return_sq_and_rq=True)
  106. >>> panoptic_quality(preds, target)
  107. tensor([0.5463, 0.6111, 0.6667], dtype=torch.float64)
  108. You can also specify to return the per-class metrics
  109. >>> from torch import tensor
  110. >>> from torchmetrics.detection import PanopticQuality
  111. >>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
  112. ... [[0, 0], [0, 0], [6, 0], [0, 1]],
  113. ... [[0, 0], [0, 0], [6, 0], [0, 1]],
  114. ... [[0, 0], [7, 0], [6, 0], [1, 0]],
  115. ... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
  116. >>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
  117. ... [[0, 1], [0, 1], [6, 0], [0, 1]],
  118. ... [[0, 1], [0, 1], [6, 0], [1, 0]],
  119. ... [[0, 1], [7, 0], [1, 0], [1, 0]],
  120. ... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
  121. >>> panoptic_quality = PanopticQuality(things = {0, 1}, stuffs = {6, 7}, return_per_class=True)
  122. >>> panoptic_quality(preds, target)
  123. tensor([[0.5185, 0.0000, 0.6667, 1.0000]], dtype=torch.float64)
  124. """
  125. is_differentiable: bool = False
  126. higher_is_better: bool = True
  127. full_state_update: bool = False
  128. plot_lower_bound: float = 0.0
  129. plot_upper_bound: float = 1.0
  130. iou_sum: Tensor
  131. true_positives: Tensor
  132. false_positives: Tensor
  133. false_negatives: Tensor
  134. def __init__(
  135. self,
  136. things: Collection[int],
  137. stuffs: Collection[int],
  138. allow_unknown_preds_category: bool = False,
  139. return_sq_and_rq: bool = False,
  140. return_per_class: bool = False,
  141. **kwargs: Any,
  142. ) -> None:
  143. super().__init__(**kwargs)
  144. things, stuffs = _parse_categories(things, stuffs)
  145. self.things = things
  146. self.stuffs = stuffs
  147. self.void_color = _get_void_color(things, stuffs)
  148. self.cat_id_to_continuous_id = _get_category_id_to_continuous_id(things, stuffs)
  149. self.allow_unknown_preds_category = allow_unknown_preds_category
  150. self.return_sq_and_rq = return_sq_and_rq
  151. self.return_per_class = return_per_class
  152. # per category intermediate metrics
  153. num_categories = len(things) + len(stuffs)
  154. self.add_state("iou_sum", default=torch.zeros(num_categories, dtype=torch.double), dist_reduce_fx="sum")
  155. self.add_state("true_positives", default=torch.zeros(num_categories, dtype=torch.int), dist_reduce_fx="sum")
  156. self.add_state("false_positives", default=torch.zeros(num_categories, dtype=torch.int), dist_reduce_fx="sum")
  157. self.add_state("false_negatives", default=torch.zeros(num_categories, dtype=torch.int), dist_reduce_fx="sum")
  158. def update(self, preds: Tensor, target: Tensor) -> None:
  159. r"""Update state with predictions and targets.
  160. Args:
  161. preds: panoptic detection of shape ``[batch, *spatial_dims, 2]`` containing
  162. the pair ``(category_id, instance_id)`` for each point.
  163. If the ``category_id`` refer to a stuff, the instance_id is ignored.
  164. target: ground truth of shape ``[batch, *spatial_dims, 2]`` containing
  165. the pair ``(category_id, instance_id)`` for each pixel of the image.
  166. If the ``category_id`` refer to a stuff, the instance_id is ignored.
  167. Raises:
  168. TypeError:
  169. If ``preds`` or ``target`` is not an ``torch.Tensor``.
  170. ValueError:
  171. If ``preds`` and ``target`` have different shape.
  172. ValueError:
  173. If ``preds`` has less than 3 dimensions.
  174. ValueError:
  175. If the final dimension of ``preds`` has size != 2.
  176. """
  177. _validate_inputs(preds, target)
  178. flatten_preds = _prepocess_inputs(
  179. self.things, self.stuffs, preds, self.void_color, self.allow_unknown_preds_category
  180. )
  181. flatten_target = _prepocess_inputs(self.things, self.stuffs, target, self.void_color, True)
  182. iou_sum, true_positives, false_positives, false_negatives = _panoptic_quality_update(
  183. flatten_preds, flatten_target, self.cat_id_to_continuous_id, self.void_color
  184. )
  185. self.iou_sum += iou_sum
  186. self.true_positives += true_positives
  187. self.false_positives += false_positives
  188. self.false_negatives += false_negatives
  189. def compute(self) -> Tensor:
  190. """Compute panoptic quality based on inputs passed in to ``update`` previously."""
  191. pq, sq, rq, pq_avg, sq_avg, rq_avg = _panoptic_quality_compute(
  192. self.iou_sum, self.true_positives, self.false_positives, self.false_negatives
  193. )
  194. if self.return_per_class:
  195. if self.return_sq_and_rq:
  196. return torch.stack((pq, sq, rq), dim=-1)
  197. return pq.view(1, -1)
  198. if self.return_sq_and_rq:
  199. return torch.stack((pq_avg, sq_avg, rq_avg), dim=0)
  200. return pq_avg
  201. def plot(
  202. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  203. ) -> _PLOT_OUT_TYPE:
  204. """Plot a single or multiple values from the metric.
  205. Args:
  206. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  207. If no value is provided, will automatically call `metric.compute` and plot that result.
  208. ax: An matplotlib axis object. If provided will add plot to that axis
  209. Returns:
  210. Figure object and Axes object
  211. Raises:
  212. ModuleNotFoundError:
  213. If `matplotlib` is not installed
  214. .. plot::
  215. :scale: 75
  216. >>> from torch import tensor
  217. >>> from torchmetrics.detection import PanopticQuality
  218. >>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
  219. ... [[0, 0], [0, 0], [6, 0], [0, 1]],
  220. ... [[0, 0], [0, 0], [6, 0], [0, 1]],
  221. ... [[0, 0], [7, 0], [6, 0], [1, 0]],
  222. ... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
  223. >>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
  224. ... [[0, 1], [0, 1], [6, 0], [0, 1]],
  225. ... [[0, 1], [0, 1], [6, 0], [1, 0]],
  226. ... [[0, 1], [7, 0], [1, 0], [1, 0]],
  227. ... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
  228. >>> metric = PanopticQuality(things = {0, 1}, stuffs = {6, 7})
  229. >>> metric.update(preds, target)
  230. >>> fig_, ax_ = metric.plot()
  231. .. plot::
  232. :scale: 75
  233. >>> # Example plotting multiple values
  234. >>> from torch import tensor
  235. >>> from torchmetrics.detection import PanopticQuality
  236. >>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
  237. ... [[0, 0], [0, 0], [6, 0], [0, 1]],
  238. ... [[0, 0], [0, 0], [6, 0], [0, 1]],
  239. ... [[0, 0], [7, 0], [6, 0], [1, 0]],
  240. ... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
  241. >>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
  242. ... [[0, 1], [0, 1], [6, 0], [0, 1]],
  243. ... [[0, 1], [0, 1], [6, 0], [1, 0]],
  244. ... [[0, 1], [7, 0], [1, 0], [1, 0]],
  245. ... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
  246. >>> metric = PanopticQuality(things = {0, 1}, stuffs = {6, 7})
  247. >>> vals = []
  248. >>> for _ in range(20):
  249. ... vals.append(metric(preds, target))
  250. >>> fig_, ax_ = metric.plot(vals)
  251. """
  252. return self._plot(val, ax)
  253. class ModifiedPanopticQuality(Metric):
  254. r"""Compute `Modified Panoptic Quality`_ for panoptic segmentations.
  255. The metric was introduced in `Seamless Scene Segmentation paper`_, and is an adaptation of the original
  256. `Panoptic Quality`_ where the metric for a stuff class is computed as
  257. .. math::
  258. PQ^{\dagger}_c = \frac{IOU_c}{|S_c|}
  259. where :math:`IOU_c` is the sum of the intersection over union of all matching segments for a given class, and
  260. :math:`|S_c|` is the overall number of segments in the ground truth for that class.
  261. .. note:
  262. Points in the target tensor that do not map to a known category ID are automatically ignored in the metric
  263. computation.
  264. Args:
  265. things:
  266. Set of ``category_id`` for countable things.
  267. stuffs:
  268. Set of ``category_id`` for uncountable stuffs.
  269. allow_unknown_preds_category:
  270. Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric
  271. computation or raise an exception when found.
  272. Raises:
  273. ValueError:
  274. If ``things``, ``stuffs`` have at least one common ``category_id``.
  275. TypeError:
  276. If ``things``, ``stuffs`` contain non-integer ``category_id``.
  277. Example:
  278. >>> from torch import tensor
  279. >>> from torchmetrics.detection import ModifiedPanopticQuality
  280. >>> preds = tensor([[[0, 0], [0, 1], [6, 0], [7, 0], [0, 2], [1, 0]]])
  281. >>> target = tensor([[[0, 1], [0, 0], [6, 0], [7, 0], [6, 0], [255, 0]]])
  282. >>> pq_modified = ModifiedPanopticQuality(things = {0, 1}, stuffs = {6, 7})
  283. >>> pq_modified(preds, target)
  284. tensor(0.7667, dtype=torch.float64)
  285. """
  286. is_differentiable: bool = False
  287. higher_is_better: bool = True
  288. full_state_update: bool = False
  289. plot_lower_bound: float = 0.0
  290. plot_upper_bound: float = 1.0
  291. iou_sum: Tensor
  292. true_positives: Tensor
  293. false_positives: Tensor
  294. false_negatives: Tensor
  295. def __init__(
  296. self,
  297. things: Collection[int],
  298. stuffs: Collection[int],
  299. allow_unknown_preds_category: bool = False,
  300. **kwargs: Any,
  301. ) -> None:
  302. super().__init__(**kwargs)
  303. things, stuffs = _parse_categories(things, stuffs)
  304. self.things = things
  305. self.stuffs = stuffs
  306. self.void_color = _get_void_color(things, stuffs)
  307. self.cat_id_to_continuous_id = _get_category_id_to_continuous_id(things, stuffs)
  308. self.allow_unknown_preds_category = allow_unknown_preds_category
  309. # per category intermediate metrics
  310. num_categories = len(things) + len(stuffs)
  311. self.add_state("iou_sum", default=torch.zeros(num_categories, dtype=torch.double), dist_reduce_fx="sum")
  312. self.add_state("true_positives", default=torch.zeros(num_categories, dtype=torch.int), dist_reduce_fx="sum")
  313. self.add_state("false_positives", default=torch.zeros(num_categories, dtype=torch.int), dist_reduce_fx="sum")
  314. self.add_state("false_negatives", default=torch.zeros(num_categories, dtype=torch.int), dist_reduce_fx="sum")
  315. def update(self, preds: Tensor, target: Tensor) -> None:
  316. r"""Update state with predictions and targets.
  317. Args:
  318. preds: panoptic detection of shape ``[batch, *spatial_dims, 2]`` containing
  319. the pair ``(category_id, instance_id)`` for each point.
  320. If the ``category_id`` refer to a stuff, the instance_id is ignored.
  321. target: ground truth of shape ``[batch, *spatial_dims, 2]`` containing
  322. the pair ``(category_id, instance_id)`` for each pixel of the image.
  323. If the ``category_id`` refer to a stuff, the instance_id is ignored.
  324. Raises:
  325. TypeError:
  326. If ``preds`` or ``target`` is not an ``torch.Tensor``.
  327. ValueError:
  328. If ``preds`` and ``target`` have different shape.
  329. ValueError:
  330. If ``preds`` has less than 3 dimensions.
  331. ValueError:
  332. If the final dimension of ``preds`` has size != 2.
  333. """
  334. _validate_inputs(preds, target)
  335. flatten_preds = _prepocess_inputs(
  336. self.things, self.stuffs, preds, self.void_color, self.allow_unknown_preds_category
  337. )
  338. flatten_target = _prepocess_inputs(self.things, self.stuffs, target, self.void_color, True)
  339. iou_sum, true_positives, false_positives, false_negatives = _panoptic_quality_update(
  340. flatten_preds,
  341. flatten_target,
  342. self.cat_id_to_continuous_id,
  343. self.void_color,
  344. modified_metric_stuffs=self.stuffs,
  345. )
  346. self.iou_sum += iou_sum
  347. self.true_positives += true_positives
  348. self.false_positives += false_positives
  349. self.false_negatives += false_negatives
  350. def compute(self) -> Tensor:
  351. """Compute panoptic quality based on inputs passed in to ``update`` previously."""
  352. _, _, _, pq_avg, _, _ = _panoptic_quality_compute(
  353. self.iou_sum, self.true_positives, self.false_positives, self.false_negatives
  354. )
  355. return pq_avg
  356. def plot(
  357. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  358. ) -> _PLOT_OUT_TYPE:
  359. """Plot a single or multiple values from the metric.
  360. Args:
  361. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  362. If no value is provided, will automatically call `metric.compute` and plot that result.
  363. ax: An matplotlib axis object. If provided will add plot to that axis
  364. Returns:
  365. Figure object and Axes object
  366. Raises:
  367. ModuleNotFoundError:
  368. If `matplotlib` is not installed
  369. .. plot::
  370. :scale: 75
  371. >>> from torch import tensor
  372. >>> from torchmetrics.detection import ModifiedPanopticQuality
  373. >>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
  374. ... [[0, 0], [0, 0], [6, 0], [0, 1]],
  375. ... [[0, 0], [0, 0], [6, 0], [0, 1]],
  376. ... [[0, 0], [7, 0], [6, 0], [1, 0]],
  377. ... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
  378. >>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
  379. ... [[0, 1], [0, 1], [6, 0], [0, 1]],
  380. ... [[0, 1], [0, 1], [6, 0], [1, 0]],
  381. ... [[0, 1], [7, 0], [1, 0], [1, 0]],
  382. ... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
  383. >>> metric = ModifiedPanopticQuality(things = {0, 1}, stuffs = {6, 7})
  384. >>> metric.update(preds, target)
  385. >>> fig_, ax_ = metric.plot()
  386. .. plot::
  387. :scale: 75
  388. >>> # Example plotting multiple values
  389. >>> from torch import tensor
  390. >>> from torchmetrics.detection import ModifiedPanopticQuality
  391. >>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
  392. ... [[0, 0], [0, 0], [6, 0], [0, 1]],
  393. ... [[0, 0], [0, 0], [6, 0], [0, 1]],
  394. ... [[0, 0], [7, 0], [6, 0], [1, 0]],
  395. ... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
  396. >>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
  397. ... [[0, 1], [0, 1], [6, 0], [0, 1]],
  398. ... [[0, 1], [0, 1], [6, 0], [1, 0]],
  399. ... [[0, 1], [7, 0], [1, 0], [1, 0]],
  400. ... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
  401. >>> metric = ModifiedPanopticQuality(things = {0, 1}, stuffs = {6, 7})
  402. >>> vals = []
  403. >>> for _ in range(20):
  404. ... vals.append(metric(preds, target))
  405. >>> fig_, ax_ = metric.plot(vals)
  406. """
  407. return self._plot(val, ax)