jaccard.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.classification.base import _ClassificationTaskWrapper
  19. from torchmetrics.classification.confusion_matrix import (
  20. BinaryConfusionMatrix,
  21. MulticlassConfusionMatrix,
  22. MultilabelConfusionMatrix,
  23. )
  24. from torchmetrics.functional.classification.jaccard import (
  25. _jaccard_index_reduce,
  26. _multiclass_jaccard_index_arg_validation,
  27. _multilabel_jaccard_index_arg_validation,
  28. )
  29. from torchmetrics.metric import Metric
  30. from torchmetrics.utilities.enums import ClassificationTask
  31. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  32. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  33. if not _MATPLOTLIB_AVAILABLE:
  34. __doctest_skip__ = ["BinaryJaccardIndex.plot", "MulticlassJaccardIndex.plot", "MultilabelJaccardIndex.plot"]
  35. class BinaryJaccardIndex(BinaryConfusionMatrix):
  36. r"""Calculate the Jaccard index for binary tasks.
  37. The `Jaccard index`_ (also known as the intersection over union or jaccard similarity coefficient) is an statistic
  38. that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the
  39. intersection divided by the union of the sample sets:
  40. .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|}
  41. As input to ``forward`` and ``update`` the metric accepts the following input:
  42. - ``preds`` (:class:`~torch.Tensor`): A int or float tensor of shape ``(N, ...)``. If preds is a floating point
  43. tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.
  44. Additionally, we convert to int tensor with thresholding using the value in ``threshold``.
  45. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``.
  46. .. tip::
  47. Additional dimension ``...`` will be flattened into the batch dimension.
  48. As output to ``forward`` and ``compute`` the metric returns the following output:
  49. - ``bji`` (:class:`~torch.Tensor`): A tensor containing the Binary Jaccard Index.
  50. Args:
  51. threshold: Threshold for transforming probability to binary (0,1) predictions
  52. ignore_index:
  53. Specifies a target value that is ignored and does not contribute to the metric calculation
  54. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  55. Set to ``False`` for faster computations.
  56. zero_division:
  57. Value to replace when there is a division by zero. Should be `0` or `1`.
  58. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  59. Example (preds is int tensor):
  60. >>> from torch import tensor
  61. >>> from torchmetrics.classification import BinaryJaccardIndex
  62. >>> target = tensor([1, 1, 0, 0])
  63. >>> preds = tensor([0, 1, 0, 0])
  64. >>> metric = BinaryJaccardIndex()
  65. >>> metric(preds, target)
  66. tensor(0.5000)
  67. Example (preds is float tensor):
  68. >>> from torchmetrics.classification import BinaryJaccardIndex
  69. >>> target = tensor([1, 1, 0, 0])
  70. >>> preds = tensor([0.35, 0.85, 0.48, 0.01])
  71. >>> metric = BinaryJaccardIndex()
  72. >>> metric(preds, target)
  73. tensor(0.5000)
  74. """
  75. is_differentiable: bool = False
  76. higher_is_better: bool = True
  77. full_state_update: bool = False
  78. plot_lower_bound: float = 0.0
  79. plot_upper_bound: float = 1.0
  80. def __init__(
  81. self,
  82. threshold: float = 0.5,
  83. ignore_index: Optional[int] = None,
  84. validate_args: bool = True,
  85. zero_division: float = 0,
  86. **kwargs: Any,
  87. ) -> None:
  88. super().__init__(
  89. threshold=threshold, ignore_index=ignore_index, normalize=None, validate_args=validate_args, **kwargs
  90. )
  91. self.zero_division = zero_division
  92. def compute(self) -> Tensor:
  93. """Compute metric."""
  94. return _jaccard_index_reduce(self.confmat, average="binary", zero_division=self.zero_division)
  95. def plot( # type: ignore[override]
  96. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  97. ) -> _PLOT_OUT_TYPE:
  98. """Plot a single or multiple values from the metric.
  99. Args:
  100. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  101. If no value is provided, will automatically call `metric.compute` and plot that result.
  102. ax: An matplotlib axis object. If provided will add plot to that axis
  103. Returns:
  104. Figure object and Axes object
  105. Raises:
  106. ModuleNotFoundError:
  107. If `matplotlib` is not installed
  108. .. plot::
  109. :scale: 75
  110. >>> # Example plotting a single value
  111. >>> from torch import rand, randint
  112. >>> from torchmetrics.classification import BinaryJaccardIndex
  113. >>> metric = BinaryJaccardIndex()
  114. >>> metric.update(rand(10), randint(2,(10,)))
  115. >>> fig_, ax_ = metric.plot()
  116. .. plot::
  117. :scale: 75
  118. >>> # Example plotting multiple values
  119. >>> from torch import rand, randint
  120. >>> from torchmetrics.classification import BinaryJaccardIndex
  121. >>> metric = BinaryJaccardIndex()
  122. >>> values = [ ]
  123. >>> for _ in range(10):
  124. ... values.append(metric(rand(10), randint(2,(10,))))
  125. >>> fig_, ax_ = metric.plot(values)
  126. """
  127. return self._plot(val, ax)
  128. class MulticlassJaccardIndex(MulticlassConfusionMatrix):
  129. r"""Calculate the Jaccard index for multiclass tasks.
  130. The `Jaccard index`_ (also known as the intersection over union or jaccard similarity coefficient) is an statistic
  131. that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the
  132. intersection divided by the union of the sample sets:
  133. .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|}
  134. As input to ``forward`` and ``update`` the metric accepts the following input:
  135. - ``preds`` (:class:`~torch.Tensor`): A int tensor of shape ``(N, ...)`` or float tensor of shape ``(N, C, ..)``.
  136. If preds is a floating point we apply ``torch.argmax`` along the ``C`` dimension to automatically convert
  137. probabilities/logits into an int tensor.
  138. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``.
  139. .. tip::
  140. Additional dimension ``...`` will be flattened into the batch dimension.
  141. As output to ``forward`` and ``compute`` the metric returns the following output:
  142. - ``mcji`` (:class:`~torch.Tensor`): A tensor containing the Multi-class Jaccard Index.
  143. Args:
  144. num_classes: Integer specifying the number of classes
  145. ignore_index:
  146. Specifies a target value that is ignored and does not contribute to the metric calculation
  147. average:
  148. Defines the reduction that is applied over labels. Should be one of the following:
  149. - ``micro``: Sum statistics over all labels
  150. - ``macro``: Calculate statistics for each label and average them
  151. - ``weighted``: calculates statistics for each label and computes weighted average using their support
  152. - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
  153. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  154. Set to ``False`` for faster computations.
  155. zero_division:
  156. Value to replace when there is a division by zero. Should be `0` or `1`.
  157. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  158. Example (pred is integer tensor):
  159. >>> from torch import tensor
  160. >>> from torchmetrics.classification import MulticlassJaccardIndex
  161. >>> target = tensor([2, 1, 0, 0])
  162. >>> preds = tensor([2, 1, 0, 1])
  163. >>> metric = MulticlassJaccardIndex(num_classes=3)
  164. >>> metric(preds, target)
  165. tensor(0.6667)
  166. Example (pred is float tensor):
  167. >>> from torchmetrics.classification import MulticlassJaccardIndex
  168. >>> target = tensor([2, 1, 0, 0])
  169. >>> preds = tensor([[0.16, 0.26, 0.58],
  170. ... [0.22, 0.61, 0.17],
  171. ... [0.71, 0.09, 0.20],
  172. ... [0.05, 0.82, 0.13]])
  173. >>> metric = MulticlassJaccardIndex(num_classes=3)
  174. >>> metric(preds, target)
  175. tensor(0.6667)
  176. """
  177. is_differentiable: bool = False
  178. higher_is_better: bool = True
  179. full_state_update: bool = False
  180. plot_lower_bound: float = 0.0
  181. plot_upper_bound: float = 1.0
  182. plot_legend_name: str = "Class"
  183. def __init__(
  184. self,
  185. num_classes: int,
  186. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  187. ignore_index: Optional[int] = None,
  188. validate_args: bool = True,
  189. zero_division: float = 0,
  190. **kwargs: Any,
  191. ) -> None:
  192. super().__init__(
  193. num_classes=num_classes, ignore_index=ignore_index, normalize=None, validate_args=False, **kwargs
  194. )
  195. if validate_args:
  196. _multiclass_jaccard_index_arg_validation(num_classes, ignore_index, average)
  197. self.validate_args = validate_args
  198. self.average = average
  199. self.zero_division = zero_division
  200. def compute(self) -> Tensor:
  201. """Compute metric."""
  202. return _jaccard_index_reduce(
  203. self.confmat, average=self.average, ignore_index=self.ignore_index, zero_division=self.zero_division
  204. )
  205. def plot( # type: ignore[override]
  206. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  207. ) -> _PLOT_OUT_TYPE:
  208. """Plot a single or multiple values from the metric.
  209. Args:
  210. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  211. If no value is provided, will automatically call `metric.compute` and plot that result.
  212. ax: An matplotlib axis object. If provided will add plot to that axis
  213. Returns:
  214. Figure object and Axes object
  215. Raises:
  216. ModuleNotFoundError:
  217. If `matplotlib` is not installed
  218. .. plot::
  219. :scale: 75
  220. >>> # Example plotting a single value per class
  221. >>> from torch import randint
  222. >>> from torchmetrics.classification import MulticlassJaccardIndex
  223. >>> metric = MulticlassJaccardIndex(num_classes=3, average=None)
  224. >>> metric.update(randint(3, (20,)), randint(3, (20,)))
  225. >>> fig_, ax_ = metric.plot()
  226. .. plot::
  227. :scale: 75
  228. >>> # Example plotting a multiple values per class
  229. >>> from torch import randint
  230. >>> from torchmetrics.classification import MulticlassJaccardIndex
  231. >>> metric = MulticlassJaccardIndex(num_classes=3, average=None)
  232. >>> values = []
  233. >>> for _ in range(20):
  234. ... values.append(metric(randint(3, (20,)), randint(3, (20,))))
  235. >>> fig_, ax_ = metric.plot(values)
  236. """
  237. return self._plot(val, ax)
  238. class MultilabelJaccardIndex(MultilabelConfusionMatrix):
  239. r"""Calculate the Jaccard index for multilabel tasks.
  240. The `Jaccard index`_ (also known as the intersection over union or jaccard similarity coefficient) is an statistic
  241. that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the
  242. intersection divided by the union of the sample sets:
  243. .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|}
  244. As input to ``forward`` and ``update`` the metric accepts the following input:
  245. - ``preds`` (:class:`~torch.Tensor`): A int tensor or float tensor of shape ``(N, C, ...)``. If preds is a
  246. floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply
  247. sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``.
  248. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``
  249. .. tip::
  250. Additional dimension ``...`` will be flattened into the batch dimension.
  251. As output to ``forward`` and ``compute`` the metric returns the following output:
  252. - ``mlji`` (:class:`~torch.Tensor`): A tensor containing the Multi-label Jaccard Index loss.
  253. Args:
  254. num_classes: Integer specifying the number of labels
  255. threshold: Threshold for transforming probability to binary (0,1) predictions
  256. ignore_index:
  257. Specifies a target value that is ignored and does not contribute to the metric calculation
  258. average:
  259. Defines the reduction that is applied over labels. Should be one of the following:
  260. - ``micro``: Sum statistics over all labels
  261. - ``macro``: Calculate statistics for each label and average them
  262. - ``weighted``: calculates statistics for each label and computes weighted average using their support
  263. - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
  264. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  265. Set to ``False`` for faster computations.
  266. zero_division:
  267. Value to replace when there is a division by zero. Should be `0` or `1`.
  268. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  269. Example (preds is int tensor):
  270. >>> from torch import tensor
  271. >>> from torchmetrics.classification import MultilabelJaccardIndex
  272. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  273. >>> preds = tensor([[0, 0, 1], [1, 0, 1]])
  274. >>> metric = MultilabelJaccardIndex(num_labels=3)
  275. >>> metric(preds, target)
  276. tensor(0.5000)
  277. Example (preds is float tensor):
  278. >>> from torchmetrics.classification import MultilabelJaccardIndex
  279. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  280. >>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
  281. >>> metric = MultilabelJaccardIndex(num_labels=3)
  282. >>> metric(preds, target)
  283. tensor(0.5000)
  284. """
  285. is_differentiable: bool = False
  286. higher_is_better: bool = True
  287. full_state_update: bool = False
  288. plot_lower_bound: float = 0.0
  289. plot_upper_bound: float = 1.0
  290. plot_legend_name: str = "Label"
  291. def __init__(
  292. self,
  293. num_labels: int,
  294. threshold: float = 0.5,
  295. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  296. ignore_index: Optional[int] = None,
  297. validate_args: bool = True,
  298. zero_division: float = 0,
  299. **kwargs: Any,
  300. ) -> None:
  301. super().__init__(
  302. num_labels=num_labels,
  303. threshold=threshold,
  304. ignore_index=ignore_index,
  305. normalize=None,
  306. validate_args=False,
  307. **kwargs,
  308. )
  309. if validate_args:
  310. _multilabel_jaccard_index_arg_validation(num_labels, threshold, ignore_index, average)
  311. self.validate_args = validate_args
  312. self.average = average
  313. self.zero_division = zero_division
  314. def compute(self) -> Tensor:
  315. """Compute metric."""
  316. return _jaccard_index_reduce(self.confmat, average=self.average, zero_division=self.zero_division)
  317. def plot( # type: ignore[override]
  318. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  319. ) -> _PLOT_OUT_TYPE:
  320. """Plot a single or multiple values from the metric.
  321. Args:
  322. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  323. If no value is provided, will automatically call `metric.compute` and plot that result.
  324. ax: An matplotlib axis object. If provided will add plot to that axis
  325. Returns:
  326. Figure and Axes object
  327. Raises:
  328. ModuleNotFoundError:
  329. If `matplotlib` is not installed
  330. .. plot::
  331. :scale: 75
  332. >>> # Example plotting a single value
  333. >>> from torch import rand, randint
  334. >>> from torchmetrics.classification import MultilabelJaccardIndex
  335. >>> metric = MultilabelJaccardIndex(num_labels=3)
  336. >>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
  337. >>> fig_, ax_ = metric.plot()
  338. .. plot::
  339. :scale: 75
  340. >>> # Example plotting multiple values
  341. >>> from torch import rand, randint
  342. >>> from torchmetrics.classification import MultilabelJaccardIndex
  343. >>> metric = MultilabelJaccardIndex(num_labels=3)
  344. >>> values = [ ]
  345. >>> for _ in range(10):
  346. ... values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
  347. >>> fig_, ax_ = metric.plot(values)
  348. """
  349. return self._plot(val, ax)
  350. class JaccardIndex(_ClassificationTaskWrapper):
  351. r"""Calculate the Jaccard index for multilabel tasks.
  352. The `Jaccard index`_ (also known as the intersection over union or jaccard similarity coefficient) is an statistic
  353. that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the
  354. intersection divided by the union of the sample sets:
  355. .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|}
  356. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  357. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  358. :class:`~torchmetrics.classification.BinaryJaccardIndex`,
  359. :class:`~torchmetrics.classification.MulticlassJaccardIndex` and
  360. :class:`~torchmetrics.classification.MultilabelJaccardIndex` for the specific details of each argument influence
  361. and examples.
  362. Legacy Example:
  363. >>> from torch import randint, tensor
  364. >>> target = randint(0, 2, (10, 25, 25))
  365. >>> pred = tensor(target)
  366. >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
  367. >>> jaccard = JaccardIndex(task="multiclass", num_classes=2)
  368. >>> jaccard(pred, target)
  369. tensor(0.9660)
  370. """
  371. def __new__( # type: ignore[misc]
  372. cls: type["JaccardIndex"],
  373. task: Literal["binary", "multiclass", "multilabel"],
  374. threshold: float = 0.5,
  375. num_classes: Optional[int] = None,
  376. num_labels: Optional[int] = None,
  377. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  378. ignore_index: Optional[int] = None,
  379. validate_args: bool = True,
  380. **kwargs: Any,
  381. ) -> Metric:
  382. """Initialize task metric."""
  383. task = ClassificationTask.from_str(task)
  384. kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args})
  385. if task == ClassificationTask.BINARY:
  386. return BinaryJaccardIndex(threshold, **kwargs)
  387. if task == ClassificationTask.MULTICLASS:
  388. if not isinstance(num_classes, int):
  389. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  390. return MulticlassJaccardIndex(num_classes, average, **kwargs)
  391. if task == ClassificationTask.MULTILABEL:
  392. if not isinstance(num_labels, int):
  393. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  394. return MultilabelJaccardIndex(num_labels, threshold, average, **kwargs)
  395. raise ValueError(f"Task {task} not supported!")