confusion_matrix.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. import torch
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.utilities.checks import _check_same_shape
  19. from torchmetrics.utilities.compute import normalize_logits_if_needed
  20. from torchmetrics.utilities.data import _bincount
  21. from torchmetrics.utilities.enums import ClassificationTask
  22. from torchmetrics.utilities.prints import rank_zero_warn
  23. def _confusion_matrix_reduce(
  24. confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None
  25. ) -> Tensor:
  26. """Reduce an un-normalized confusion matrix.
  27. Args:
  28. confmat: un-normalized confusion matrix
  29. normalize: normalization method.
  30. - `"true"` will divide by the sum of the column dimension.
  31. - `"pred"` will divide by the sum of the row dimension.
  32. - `"all"` will divide by the sum of the full matrix
  33. - `"none"` or `None` will apply no reduction.
  34. Returns:
  35. Normalized confusion matrix
  36. """
  37. allowed_normalize = ("true", "pred", "all", "none", None)
  38. if normalize not in allowed_normalize:
  39. raise ValueError(f"Argument `normalize` needs to one of the following: {allowed_normalize}")
  40. if normalize is not None and normalize != "none":
  41. confmat = confmat.float() if not confmat.is_floating_point() else confmat
  42. if normalize == "true":
  43. confmat = confmat / confmat.sum(dim=-1, keepdim=True)
  44. elif normalize == "pred":
  45. confmat = confmat / confmat.sum(dim=-2, keepdim=True)
  46. elif normalize == "all":
  47. confmat = confmat / confmat.sum(dim=[-2, -1], keepdim=True)
  48. nan_elements = confmat[torch.isnan(confmat)].nelement()
  49. if nan_elements:
  50. confmat[torch.isnan(confmat)] = 0
  51. rank_zero_warn(f"{nan_elements} NaN values found in confusion matrix have been replaced with zeros.")
  52. return confmat
  53. def _binary_confusion_matrix_arg_validation(
  54. threshold: float = 0.5,
  55. ignore_index: Optional[int] = None,
  56. normalize: Optional[Literal["true", "pred", "all", "none"]] = None,
  57. ) -> None:
  58. """Validate non tensor input.
  59. - ``threshold`` has to be a float in the [0,1] range
  60. - ``ignore_index`` has to be None or int
  61. - ``normalize`` has to be "true" | "pred" | "all" | "none" | None
  62. """
  63. if not (isinstance(threshold, float) and (0 <= threshold <= 1)):
  64. raise ValueError(f"Expected argument `threshold` to be a float in the [0,1] range, but got {threshold}.")
  65. if ignore_index is not None and not isinstance(ignore_index, int):
  66. raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}")
  67. allowed_normalize = ("true", "pred", "all", "none", None)
  68. if normalize not in allowed_normalize:
  69. raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.")
  70. def _binary_confusion_matrix_tensor_validation(
  71. preds: Tensor, target: Tensor, ignore_index: Optional[int] = None
  72. ) -> None:
  73. """Validate tensor input.
  74. - tensors have to be of same shape
  75. - all values in target tensor that are not ignored have to be in {0, 1}
  76. - if pred tensor is not floating point, then all values also have to be in {0, 1}
  77. """
  78. # Check that they have same shape
  79. _check_same_shape(preds, target)
  80. # Check that target only contains {0,1} values or value in ignore_index
  81. unique_values = torch.unique(target, dim=None)
  82. if ignore_index is None:
  83. check = torch.any((unique_values != 0) & (unique_values != 1))
  84. else:
  85. check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index))
  86. if check:
  87. raise RuntimeError(
  88. f"Detected the following values in `target`: {unique_values} but expected only"
  89. f" the following values {[0, 1] if ignore_index is None else [ignore_index]}."
  90. )
  91. # If preds is label tensor, also check that it only contains {0,1} values
  92. if not preds.is_floating_point():
  93. unique_values = torch.unique(preds, dim=None)
  94. if torch.any((unique_values != 0) & (unique_values != 1)):
  95. raise RuntimeError(
  96. f"Detected the following values in `preds`: {unique_values} but expected only"
  97. " the following values [0,1] since preds is a label tensor."
  98. )
  99. def _binary_confusion_matrix_format(
  100. preds: Tensor,
  101. target: Tensor,
  102. threshold: float = 0.5,
  103. ignore_index: Optional[int] = None,
  104. convert_to_labels: bool = True,
  105. ) -> tuple[Tensor, Tensor]:
  106. """Convert all input to label format.
  107. - Remove all datapoints that should be ignored
  108. - If preds tensor is floating point, applies sigmoid if pred tensor not in [0,1] range
  109. - If preds tensor is floating point, thresholds afterwards
  110. """
  111. preds = preds.flatten()
  112. target = target.flatten()
  113. if ignore_index is not None:
  114. idx = target != ignore_index
  115. preds = preds[idx]
  116. target = target[idx]
  117. if preds.is_floating_point():
  118. preds = normalize_logits_if_needed(preds, "sigmoid")
  119. if convert_to_labels:
  120. preds = preds > threshold
  121. return preds, target
  122. def _binary_confusion_matrix_update(preds: Tensor, target: Tensor) -> Tensor:
  123. """Compute the bins to update the confusion matrix with."""
  124. unique_mapping = (target * 2 + preds).to(torch.long)
  125. bins = _bincount(unique_mapping, minlength=4)
  126. return bins.reshape(2, 2)
  127. def _binary_confusion_matrix_compute(
  128. confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None
  129. ) -> Tensor:
  130. """Reduces the confusion matrix to it's final form.
  131. Normalization technique can be chosen by ``normalize``.
  132. """
  133. return _confusion_matrix_reduce(confmat, normalize)
  134. def binary_confusion_matrix(
  135. preds: Tensor,
  136. target: Tensor,
  137. threshold: float = 0.5,
  138. normalize: Optional[Literal["true", "pred", "all", "none"]] = None,
  139. ignore_index: Optional[int] = None,
  140. validate_args: bool = True,
  141. ) -> Tensor:
  142. r"""Compute the `confusion matrix`_ for binary tasks.
  143. Accepts the following input tensors:
  144. - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside
  145. [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
  146. we convert to int tensor with thresholding using the value in ``threshold``.
  147. - ``target`` (int tensor): ``(N, ...)``
  148. Additional dimension ``...`` will be flattened into the batch dimension.
  149. Args:
  150. preds: Tensor with predictions
  151. target: Tensor with true labels
  152. threshold: Threshold for transforming probability to binary (0,1) predictions
  153. normalize: Normalization mode for confusion matrix. Choose from:
  154. - ``None`` or ``'none'``: no normalization (default)
  155. - ``'true'``: normalization over the targets (most commonly used)
  156. - ``'pred'``: normalization over the predictions
  157. - ``'all'``: normalization over the whole matrix
  158. ignore_index:
  159. Specifies a target value that is ignored and does not contribute to the metric calculation
  160. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  161. Set to ``False`` for faster computations.
  162. Returns:
  163. A ``[2, 2]`` tensor
  164. Example (preds is int tensor):
  165. >>> from torch import tensor
  166. >>> from torchmetrics.functional.classification import binary_confusion_matrix
  167. >>> target = tensor([1, 1, 0, 0])
  168. >>> preds = tensor([0, 1, 0, 0])
  169. >>> binary_confusion_matrix(preds, target)
  170. tensor([[2, 0],
  171. [1, 1]])
  172. Example (preds is float tensor):
  173. >>> from torchmetrics.functional.classification import binary_confusion_matrix
  174. >>> target = tensor([1, 1, 0, 0])
  175. >>> preds = tensor([0.35, 0.85, 0.48, 0.01])
  176. >>> binary_confusion_matrix(preds, target)
  177. tensor([[2, 0],
  178. [1, 1]])
  179. """
  180. if validate_args:
  181. _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize)
  182. _binary_confusion_matrix_tensor_validation(preds, target, ignore_index)
  183. preds, target = _binary_confusion_matrix_format(preds, target, threshold, ignore_index)
  184. confmat = _binary_confusion_matrix_update(preds, target)
  185. return _binary_confusion_matrix_compute(confmat, normalize)
  186. def _multiclass_confusion_matrix_arg_validation(
  187. num_classes: int,
  188. ignore_index: Optional[int] = None,
  189. normalize: Optional[Literal["true", "pred", "all", "none"]] = None,
  190. ) -> None:
  191. """Validate non tensor input.
  192. - ``num_classes`` has to be a int larger than 1
  193. - ``ignore_index`` has to be None or int
  194. - ``normalize`` has to be "true" | "pred" | "all" | "none" | None
  195. """
  196. if not isinstance(num_classes, int) or num_classes < 2:
  197. raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}")
  198. if ignore_index is not None and not isinstance(ignore_index, int):
  199. raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}")
  200. allowed_normalize = ("true", "pred", "all", "none", None)
  201. if normalize not in allowed_normalize:
  202. raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.")
  203. def _multiclass_confusion_matrix_tensor_validation(
  204. preds: Tensor, target: Tensor, num_classes: int, ignore_index: Optional[int] = None
  205. ) -> None:
  206. """Validate tensor input.
  207. - if target has one more dimension than preds, then all dimensions except for preds.shape[1] should match
  208. exactly. preds.shape[1] should have size equal to number of classes
  209. - if preds and target have same number of dims, then all dimensions should match
  210. - all values in target tensor that are not ignored have to be {0, ..., num_classes - 1}
  211. - if pred tensor is not floating point, then all values also have to be in {0, ..., num_classes - 1}
  212. """
  213. if preds.ndim == target.ndim + 1:
  214. if not preds.is_floating_point():
  215. raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.")
  216. if preds.shape[1] != num_classes:
  217. raise ValueError(
  218. "If `preds` have one dimension more than `target`, `preds.shape[1]` should be"
  219. " equal to number of classes."
  220. )
  221. if preds.shape[2:] != target.shape[1:]:
  222. raise ValueError(
  223. "If `preds` have one dimension more than `target`, the shape of `preds` should be"
  224. " (N, C, ...), and the shape of `target` should be (N, ...)."
  225. )
  226. elif preds.ndim == target.ndim:
  227. if preds.shape != target.shape:
  228. raise ValueError(
  229. "The `preds` and `target` should have the same shape,",
  230. f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}.",
  231. )
  232. else:
  233. raise ValueError(
  234. "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)"
  235. " and `preds` should be (N, C, ...)."
  236. )
  237. check_value = num_classes if ignore_index is None else num_classes + 1
  238. for t, name in ((target, "target"),) + ((preds, "preds"),) if not preds.is_floating_point() else (): # noqa: RUF005
  239. num_unique_values = len(torch.unique(t, dim=None))
  240. if num_unique_values > check_value:
  241. raise RuntimeError(
  242. f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found"
  243. f" {num_unique_values} in `target`."
  244. )
  245. def _multiclass_confusion_matrix_format(
  246. preds: Tensor,
  247. target: Tensor,
  248. ignore_index: Optional[int] = None,
  249. convert_to_labels: bool = True,
  250. ) -> tuple[Tensor, Tensor]:
  251. """Convert all input to label format.
  252. - Applies argmax if preds have one more dimension than target
  253. - Remove all datapoints that should be ignored
  254. """
  255. # Apply argmax if we have one more dimension
  256. if preds.ndim == target.ndim + 1 and convert_to_labels:
  257. preds = preds.argmax(dim=1)
  258. preds = preds.flatten() if convert_to_labels else torch.movedim(preds, 1, -1).reshape(-1, preds.shape[1])
  259. target = target.flatten()
  260. if ignore_index is not None:
  261. idx = target != ignore_index
  262. preds = preds[idx]
  263. target = target[idx]
  264. return preds, target
  265. def _multiclass_confusion_matrix_update(preds: Tensor, target: Tensor, num_classes: int) -> Tensor:
  266. """Compute the bins to update the confusion matrix with."""
  267. unique_mapping = target.to(torch.long) * num_classes + preds.to(torch.long)
  268. bins = _bincount(unique_mapping, minlength=num_classes**2)
  269. return bins.reshape(num_classes, num_classes)
  270. def _multiclass_confusion_matrix_compute(
  271. confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None
  272. ) -> Tensor:
  273. """Reduces the confusion matrix to it's final form.
  274. Normalization technique can be chosen by ``normalize``.
  275. """
  276. return _confusion_matrix_reduce(confmat, normalize)
  277. def multiclass_confusion_matrix(
  278. preds: Tensor,
  279. target: Tensor,
  280. num_classes: int,
  281. normalize: Optional[Literal["true", "pred", "all", "none"]] = None,
  282. ignore_index: Optional[int] = None,
  283. validate_args: bool = True,
  284. ) -> Tensor:
  285. r"""Compute the `confusion matrix`_ for multiclass tasks.
  286. Accepts the following input tensors:
  287. - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point
  288. we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into
  289. an int tensor.
  290. - ``target`` (int tensor): ``(N, ...)``
  291. Additional dimension ``...`` will be flattened into the batch dimension.
  292. Args:
  293. preds: Tensor with predictions
  294. target: Tensor with true labels
  295. num_classes: Integer specifying the number of classes
  296. normalize: Normalization mode for confusion matrix. Choose from:
  297. - ``None`` or ``'none'``: no normalization (default)
  298. - ``'true'``: normalization over the targets (most commonly used)
  299. - ``'pred'``: normalization over the predictions
  300. - ``'all'``: normalization over the whole matrix
  301. ignore_index:
  302. Specifies a target value that is ignored and does not contribute to the metric calculation
  303. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  304. Set to ``False`` for faster computations.
  305. Returns:
  306. A ``[num_classes, num_classes]`` tensor
  307. Example (pred is integer tensor):
  308. >>> from torch import tensor
  309. >>> from torchmetrics.functional.classification import multiclass_confusion_matrix
  310. >>> target = tensor([2, 1, 0, 0])
  311. >>> preds = tensor([2, 1, 0, 1])
  312. >>> multiclass_confusion_matrix(preds, target, num_classes=3)
  313. tensor([[1, 1, 0],
  314. [0, 1, 0],
  315. [0, 0, 1]])
  316. Example (pred is float tensor):
  317. >>> from torchmetrics.functional.classification import multiclass_confusion_matrix
  318. >>> target = tensor([2, 1, 0, 0])
  319. >>> preds = tensor([[0.16, 0.26, 0.58],
  320. ... [0.22, 0.61, 0.17],
  321. ... [0.71, 0.09, 0.20],
  322. ... [0.05, 0.82, 0.13]])
  323. >>> multiclass_confusion_matrix(preds, target, num_classes=3)
  324. tensor([[1, 1, 0],
  325. [0, 1, 0],
  326. [0, 0, 1]])
  327. """
  328. if validate_args:
  329. _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize)
  330. _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index)
  331. preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index)
  332. confmat = _multiclass_confusion_matrix_update(preds, target, num_classes)
  333. return _multiclass_confusion_matrix_compute(confmat, normalize)
  334. def _multilabel_confusion_matrix_arg_validation(
  335. num_labels: int,
  336. threshold: float = 0.5,
  337. ignore_index: Optional[int] = None,
  338. normalize: Optional[Literal["true", "pred", "all", "none"]] = None,
  339. ) -> None:
  340. """Validate non tensor input.
  341. - ``num_labels`` should be an int larger than 1
  342. - ``threshold`` has to be a float in the [0,1] range
  343. - ``ignore_index`` has to be None or int
  344. - ``normalize`` has to be "true" | "pred" | "all" | "none" | None
  345. """
  346. if not isinstance(num_labels, int) or num_labels < 2:
  347. raise ValueError(f"Expected argument `num_labels` to be an integer larger than 1, but got {num_labels}")
  348. if not (isinstance(threshold, float) and (0 <= threshold <= 1)):
  349. raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.")
  350. if ignore_index is not None and not isinstance(ignore_index, int):
  351. raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}")
  352. allowed_normalize = ("true", "pred", "all", "none", None)
  353. if normalize not in allowed_normalize:
  354. raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.")
  355. def _multilabel_confusion_matrix_tensor_validation(
  356. preds: Tensor, target: Tensor, num_labels: int, ignore_index: Optional[int] = None
  357. ) -> None:
  358. """Validate tensor input.
  359. - tensors have to be of same shape
  360. - the second dimension of both tensors need to be equal to the number of labels
  361. - all values in target tensor that are not ignored have to be in {0, 1}
  362. - if pred tensor is not floating point, then all values also have to be in {0, 1}
  363. """
  364. # Check that they have same shape
  365. _check_same_shape(preds, target)
  366. if preds.shape[1] != num_labels:
  367. raise ValueError(
  368. "Expected both `target.shape[1]` and `preds.shape[1]` to be equal to the number of labels"
  369. f" but got {preds.shape[1]} and expected {num_labels}"
  370. )
  371. # Check that target only contains [0,1] values or value in ignore_index
  372. unique_values = torch.unique(target, dim=None)
  373. if ignore_index is None:
  374. check = torch.any((unique_values != 0) & (unique_values != 1))
  375. else:
  376. check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index))
  377. if check:
  378. raise RuntimeError(
  379. f"Detected the following values in `target`: {unique_values} but expected only"
  380. f" the following values {[0, 1] if ignore_index is None else [ignore_index]}."
  381. )
  382. # If preds is label tensor, also check that it only contains [0,1] values
  383. if not preds.is_floating_point():
  384. unique_values = torch.unique(preds, dim=None)
  385. if torch.any((unique_values != 0) & (unique_values != 1)):
  386. raise RuntimeError(
  387. f"Detected the following values in `preds`: {unique_values} but expected only"
  388. " the following values [0,1] since preds is a label tensor."
  389. )
  390. def _multilabel_confusion_matrix_format(
  391. preds: Tensor,
  392. target: Tensor,
  393. num_labels: int,
  394. threshold: float = 0.5,
  395. ignore_index: Optional[int] = None,
  396. should_threshold: bool = True,
  397. ) -> tuple[Tensor, Tensor]:
  398. """Convert all input to label format.
  399. - If preds tensor is floating point, applies sigmoid if pred tensor not in [0,1] range
  400. - If preds tensor is floating point, thresholds afterwards
  401. - Mask all elements that should be ignored with negative numbers for later filtration
  402. """
  403. if preds.is_floating_point():
  404. preds = normalize_logits_if_needed(preds, "sigmoid")
  405. if should_threshold:
  406. preds = preds > threshold
  407. preds = torch.movedim(preds, 1, -1).reshape(-1, num_labels)
  408. target = torch.movedim(target, 1, -1).reshape(-1, num_labels)
  409. if ignore_index is not None:
  410. preds = preds.clone()
  411. target = target.clone()
  412. # Make sure that when we map, it will always result in a negative number that we can filter away
  413. # Each label correspond to a 2x2 matrix = 4 elements per label
  414. idx = target == ignore_index
  415. preds[idx] = -4 * num_labels
  416. target[idx] = -4 * num_labels
  417. return preds, target
  418. def _multilabel_confusion_matrix_update(preds: Tensor, target: Tensor, num_labels: int) -> Tensor:
  419. """Compute the bins to update the confusion matrix with."""
  420. unique_mapping = ((2 * target + preds) + 4 * torch.arange(num_labels, device=preds.device)).flatten()
  421. unique_mapping = unique_mapping[unique_mapping >= 0]
  422. bins = _bincount(unique_mapping, minlength=4 * num_labels)
  423. return bins.reshape(num_labels, 2, 2)
  424. def _multilabel_confusion_matrix_compute(
  425. confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None
  426. ) -> Tensor:
  427. """Reduces the confusion matrix to it's final form.
  428. Normalization technique can be chosen by ``normalize``.
  429. """
  430. return _confusion_matrix_reduce(confmat, normalize)
  431. def multilabel_confusion_matrix(
  432. preds: Tensor,
  433. target: Tensor,
  434. num_labels: int,
  435. threshold: float = 0.5,
  436. normalize: Optional[Literal["true", "pred", "all", "none"]] = None,
  437. ignore_index: Optional[int] = None,
  438. validate_args: bool = True,
  439. ) -> Tensor:
  440. r"""Compute the `confusion matrix`_ for multilabel tasks.
  441. Accepts the following input tensors:
  442. - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside
  443. [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
  444. we convert to int tensor with thresholding using the value in ``threshold``.
  445. - ``target`` (int tensor): ``(N, C, ...)``
  446. Additional dimension ``...`` will be flattened into the batch dimension.
  447. Args:
  448. preds: Tensor with predictions
  449. target: Tensor with true labels
  450. num_labels: Integer specifying the number of labels
  451. threshold: Threshold for transforming probability to binary (0,1) predictions
  452. normalize: Normalization mode for confusion matrix. Choose from:
  453. - ``None`` or ``'none'``: no normalization (default)
  454. - ``'true'``: normalization over the targets (most commonly used)
  455. - ``'pred'``: normalization over the predictions
  456. - ``'all'``: normalization over the whole matrix
  457. ignore_index:
  458. Specifies a target value that is ignored and does not contribute to the metric calculation
  459. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  460. Set to ``False`` for faster computations.
  461. Returns:
  462. A ``[num_labels, 2, 2]`` tensor
  463. Example (preds is int tensor):
  464. >>> from torch import tensor
  465. >>> from torchmetrics.functional.classification import multilabel_confusion_matrix
  466. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  467. >>> preds = tensor([[0, 0, 1], [1, 0, 1]])
  468. >>> multilabel_confusion_matrix(preds, target, num_labels=3)
  469. tensor([[[1, 0], [0, 1]],
  470. [[1, 0], [1, 0]],
  471. [[0, 1], [0, 1]]])
  472. Example (preds is float tensor):
  473. >>> from torchmetrics.functional.classification import multilabel_confusion_matrix
  474. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  475. >>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
  476. >>> multilabel_confusion_matrix(preds, target, num_labels=3)
  477. tensor([[[1, 0], [0, 1]],
  478. [[1, 0], [1, 0]],
  479. [[0, 1], [0, 1]]])
  480. """
  481. if validate_args:
  482. _multilabel_confusion_matrix_arg_validation(num_labels, threshold, ignore_index, normalize)
  483. _multilabel_confusion_matrix_tensor_validation(preds, target, num_labels, ignore_index)
  484. preds, target = _multilabel_confusion_matrix_format(preds, target, num_labels, threshold, ignore_index)
  485. confmat = _multilabel_confusion_matrix_update(preds, target, num_labels)
  486. return _multilabel_confusion_matrix_compute(confmat, normalize)
  487. def confusion_matrix(
  488. preds: Tensor,
  489. target: Tensor,
  490. task: Literal["binary", "multiclass", "multilabel"],
  491. threshold: float = 0.5,
  492. num_classes: Optional[int] = None,
  493. num_labels: Optional[int] = None,
  494. normalize: Optional[Literal["true", "pred", "all", "none"]] = None,
  495. ignore_index: Optional[int] = None,
  496. validate_args: bool = True,
  497. ) -> Tensor:
  498. r"""Compute the `confusion matrix`_.
  499. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  500. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  501. :func:`~torchmetrics.functional.classification.binary_confusion_matrix`,
  502. :func:`~torchmetrics.functional.classification.multiclass_confusion_matrix` and
  503. :func:`~torchmetrics.functional.classification.multilabel_confusion_matrix` for
  504. the specific details of each argument influence and examples.
  505. Legacy Example:
  506. >>> from torch import tensor
  507. >>> from torchmetrics.classification import ConfusionMatrix
  508. >>> target = tensor([1, 1, 0, 0])
  509. >>> preds = tensor([0, 1, 0, 0])
  510. >>> confmat = ConfusionMatrix(task="binary")
  511. >>> confmat(preds, target)
  512. tensor([[2, 0],
  513. [1, 1]])
  514. >>> target = tensor([2, 1, 0, 0])
  515. >>> preds = tensor([2, 1, 0, 1])
  516. >>> confmat = ConfusionMatrix(task="multiclass", num_classes=3)
  517. >>> confmat(preds, target)
  518. tensor([[1, 1, 0],
  519. [0, 1, 0],
  520. [0, 0, 1]])
  521. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  522. >>> preds = tensor([[0, 0, 1], [1, 0, 1]])
  523. >>> confmat = ConfusionMatrix(task="multilabel", num_labels=3)
  524. >>> confmat(preds, target)
  525. tensor([[[1, 0], [0, 1]],
  526. [[1, 0], [1, 0]],
  527. [[0, 1], [0, 1]]])
  528. """
  529. task = ClassificationTask.from_str(task)
  530. if task == ClassificationTask.BINARY:
  531. return binary_confusion_matrix(preds, target, threshold, normalize, ignore_index, validate_args)
  532. if task == ClassificationTask.MULTICLASS:
  533. if not isinstance(num_classes, int):
  534. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  535. return multiclass_confusion_matrix(preds, target, num_classes, normalize, ignore_index, validate_args)
  536. if task == ClassificationTask.MULTILABEL:
  537. if not isinstance(num_labels, int):
  538. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  539. return multilabel_confusion_matrix(preds, target, num_labels, threshold, normalize, ignore_index, validate_args)
  540. raise ValueError(f"Task {task} not supported.")