confusion_matrix.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Optional
  15. import torch
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.classification.base import _ClassificationTaskWrapper
  19. from torchmetrics.functional.classification.confusion_matrix import (
  20. _binary_confusion_matrix_arg_validation,
  21. _binary_confusion_matrix_compute,
  22. _binary_confusion_matrix_format,
  23. _binary_confusion_matrix_tensor_validation,
  24. _binary_confusion_matrix_update,
  25. _multiclass_confusion_matrix_arg_validation,
  26. _multiclass_confusion_matrix_compute,
  27. _multiclass_confusion_matrix_format,
  28. _multiclass_confusion_matrix_tensor_validation,
  29. _multiclass_confusion_matrix_update,
  30. _multilabel_confusion_matrix_arg_validation,
  31. _multilabel_confusion_matrix_compute,
  32. _multilabel_confusion_matrix_format,
  33. _multilabel_confusion_matrix_tensor_validation,
  34. _multilabel_confusion_matrix_update,
  35. )
  36. from torchmetrics.metric import Metric
  37. from torchmetrics.utilities.enums import ClassificationTask
  38. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  39. from torchmetrics.utilities.plot import _AX_TYPE, _CMAP_TYPE, _PLOT_OUT_TYPE, plot_confusion_matrix
  40. if not _MATPLOTLIB_AVAILABLE:
  41. __doctest_skip__ = [
  42. "BinaryConfusionMatrix.plot",
  43. "MulticlassConfusionMatrix.plot",
  44. "MultilabelConfusionMatrix.plot",
  45. ]
  46. class BinaryConfusionMatrix(Metric):
  47. r"""Compute the `confusion matrix`_ for binary tasks.
  48. The confusion matrix :math:`C` is constructed such that :math:`C_{i, j}` is equal to the number of observations
  49. known to be in class :math:`i` but predicted to be in class :math:`j`. Thus row indices of the confusion matrix
  50. correspond to the true class labels and column indices correspond to the predicted class labels.
  51. For binary tasks, the confusion matrix is a 2x2 matrix with the following structure:
  52. - :math:`C_{0, 0}`: True negatives
  53. - :math:`C_{0, 1}`: False positives
  54. - :math:`C_{1, 0}`: False negatives
  55. - :math:`C_{1, 1}`: True positives
  56. As input to ``forward`` and ``update`` the metric accepts the following input:
  57. - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, ...)``. If preds is a floating point
  58. tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per
  59. element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``.
  60. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``.
  61. As output to ``forward`` and ``compute`` the metric returns the following output:
  62. - ``confusion_matrix`` (:class:`~torch.Tensor`): A tensor containing a ``(2, 2)`` matrix
  63. Additional dimension ``...`` will be flattened into the batch dimension.
  64. Args:
  65. threshold: Threshold for transforming probability to binary (0,1) predictions
  66. ignore_index:
  67. Specifies a target value that is ignored and does not contribute to the metric calculation
  68. normalize: Normalization mode for confusion matrix. Choose from:
  69. - ``None`` or ``'none'``: no normalization (default)
  70. - ``'true'``: normalization over the targets (most commonly used)
  71. - ``'pred'``: normalization over the predictions
  72. - ``'all'``: normalization over the whole matrix
  73. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  74. Set to ``False`` for faster computations.
  75. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  76. Example (preds is int tensor):
  77. >>> from torchmetrics.classification import BinaryConfusionMatrix
  78. >>> target = torch.tensor([1, 1, 0, 0])
  79. >>> preds = torch.tensor([0, 1, 0, 0])
  80. >>> bcm = BinaryConfusionMatrix()
  81. >>> bcm(preds, target)
  82. tensor([[2, 0],
  83. [1, 1]])
  84. Example (preds is float tensor):
  85. >>> from torchmetrics.classification import BinaryConfusionMatrix
  86. >>> target = torch.tensor([1, 1, 0, 0])
  87. >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01])
  88. >>> bcm = BinaryConfusionMatrix()
  89. >>> bcm(preds, target)
  90. tensor([[2, 0],
  91. [1, 1]])
  92. """
  93. is_differentiable: bool = False
  94. higher_is_better: Optional[bool] = None
  95. full_state_update: bool = False
  96. confmat: Tensor
  97. def __init__(
  98. self,
  99. threshold: float = 0.5,
  100. ignore_index: Optional[int] = None,
  101. normalize: Optional[Literal["true", "pred", "all", "none"]] = None,
  102. validate_args: bool = True,
  103. **kwargs: Any,
  104. ) -> None:
  105. super().__init__(**kwargs)
  106. if validate_args:
  107. _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize)
  108. self.threshold = threshold
  109. self.ignore_index = ignore_index
  110. self.normalize = normalize
  111. self.validate_args = validate_args
  112. self.add_state("confmat", torch.zeros(2, 2, dtype=torch.long), dist_reduce_fx="sum")
  113. def update(self, preds: Tensor, target: Tensor) -> None:
  114. """Update state with predictions and targets."""
  115. if self.validate_args:
  116. _binary_confusion_matrix_tensor_validation(preds, target, self.ignore_index)
  117. preds, target = _binary_confusion_matrix_format(preds, target, self.threshold, self.ignore_index)
  118. confmat = _binary_confusion_matrix_update(preds, target)
  119. self.confmat += confmat
  120. def compute(self) -> Tensor:
  121. """Compute confusion matrix."""
  122. return _binary_confusion_matrix_compute(self.confmat, self.normalize)
  123. def plot(
  124. self,
  125. val: Optional[Tensor] = None,
  126. ax: Optional[_AX_TYPE] = None,
  127. add_text: bool = True,
  128. labels: Optional[list[str]] = None,
  129. cmap: Optional[_CMAP_TYPE] = None,
  130. ) -> _PLOT_OUT_TYPE:
  131. """Plot a single or multiple values from the metric.
  132. Args:
  133. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  134. If no value is provided, will automatically call `metric.compute` and plot that result.
  135. ax: An matplotlib axis object. If provided will add plot to that axis
  136. add_text: if the value of each cell should be added to the plot
  137. labels: a list of strings, if provided will be added to the plot to indicate the different classes
  138. cmap: matplotlib colormap to use for the confusion matrix
  139. https://matplotlib.org/stable/users/explain/colors/colormaps.html
  140. Returns:
  141. Figure and Axes object
  142. Raises:
  143. ModuleNotFoundError:
  144. If `matplotlib` is not installed
  145. .. plot::
  146. :scale: 75
  147. >>> from torch import randint
  148. >>> from torchmetrics.classification import MulticlassConfusionMatrix
  149. >>> metric = MulticlassConfusionMatrix(num_classes=5)
  150. >>> metric.update(randint(5, (20,)), randint(5, (20,)))
  151. >>> fig_, ax_ = metric.plot()
  152. """
  153. val = val if val is not None else self.compute()
  154. if not isinstance(val, Tensor):
  155. raise TypeError(f"Expected val to be a single tensor but got {val}")
  156. fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap)
  157. return fig, ax
  158. class MulticlassConfusionMatrix(Metric):
  159. r"""Compute the `confusion matrix`_ for multiclass tasks.
  160. The confusion matrix :math:`C` is constructed such that :math:`C_{i, j}` is equal to the number of observations
  161. known to be in class :math:`i` but predicted to be in class :math:`j`. Thus row indices of the confusion matrix
  162. correspond to the true class labels and column indices correspond to the predicted class labels.
  163. For multiclass tasks, the confusion matrix is a NxN matrix, where:
  164. - :math:`C_{i, i}` represents the number of true positives for class :math:`i`
  165. - :math:`\sum_{j=1, j\neq i}^N C_{i, j}` represents the number of false negatives for class :math:`i`
  166. - :math:`\sum_{j=1, j\neq i}^N C_{j, i}` represents the number of false positives for class :math:`i`
  167. - the sum of the remaining cells in the matrix represents the number of true negatives for class :math:`i`
  168. As input to ``forward`` and ``update`` the metric accepts the following input:
  169. - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, ...)``. If preds is a floating point
  170. tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per
  171. element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``.
  172. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``.
  173. As output to ``forward`` and ``compute`` the metric returns the following output:
  174. - ``confusion_matrix``: [num_classes, num_classes] matrix
  175. Args:
  176. num_classes: Integer specifying the number of classes
  177. ignore_index:
  178. Specifies a target value that is ignored and does not contribute to the metric calculation
  179. normalize: Normalization mode for confusion matrix. Choose from:
  180. - ``None`` or ``'none'``: no normalization (default)
  181. - ``'true'``: normalization over the targets (most commonly used)
  182. - ``'pred'``: normalization over the predictions
  183. - ``'all'``: normalization over the whole matrix
  184. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  185. Set to ``False`` for faster computations.
  186. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  187. Example (pred is integer tensor):
  188. >>> from torch import tensor
  189. >>> from torchmetrics.classification import MulticlassConfusionMatrix
  190. >>> target = tensor([2, 1, 0, 0])
  191. >>> preds = tensor([2, 1, 0, 1])
  192. >>> metric = MulticlassConfusionMatrix(num_classes=3)
  193. >>> metric(preds, target)
  194. tensor([[1, 1, 0],
  195. [0, 1, 0],
  196. [0, 0, 1]])
  197. Example (pred is float tensor):
  198. >>> from torchmetrics.classification import MulticlassConfusionMatrix
  199. >>> target = tensor([2, 1, 0, 0])
  200. >>> preds = tensor([[0.16, 0.26, 0.58],
  201. ... [0.22, 0.61, 0.17],
  202. ... [0.71, 0.09, 0.20],
  203. ... [0.05, 0.82, 0.13]])
  204. >>> metric = MulticlassConfusionMatrix(num_classes=3)
  205. >>> metric(preds, target)
  206. tensor([[1, 1, 0],
  207. [0, 1, 0],
  208. [0, 0, 1]])
  209. """
  210. is_differentiable: bool = False
  211. higher_is_better: Optional[bool] = None
  212. full_state_update: bool = False
  213. confmat: Tensor
  214. def __init__(
  215. self,
  216. num_classes: int,
  217. ignore_index: Optional[int] = None,
  218. normalize: Optional[Literal["none", "true", "pred", "all"]] = None,
  219. validate_args: bool = True,
  220. **kwargs: Any,
  221. ) -> None:
  222. super().__init__(**kwargs)
  223. if validate_args:
  224. _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize)
  225. self.num_classes = num_classes
  226. self.ignore_index = ignore_index
  227. self.normalize = normalize
  228. self.validate_args = validate_args
  229. self.add_state("confmat", torch.zeros(num_classes, num_classes, dtype=torch.long), dist_reduce_fx="sum")
  230. def update(self, preds: Tensor, target: Tensor) -> None:
  231. """Update state with predictions and targets."""
  232. if self.validate_args:
  233. _multiclass_confusion_matrix_tensor_validation(preds, target, self.num_classes, self.ignore_index)
  234. preds, target = _multiclass_confusion_matrix_format(preds, target, self.ignore_index)
  235. confmat = _multiclass_confusion_matrix_update(preds, target, self.num_classes)
  236. self.confmat += confmat
  237. def compute(self) -> Tensor:
  238. """Compute confusion matrix."""
  239. return _multiclass_confusion_matrix_compute(self.confmat, self.normalize)
  240. def plot(
  241. self,
  242. val: Optional[Tensor] = None,
  243. ax: Optional[_AX_TYPE] = None,
  244. add_text: bool = True,
  245. labels: Optional[list[str]] = None,
  246. cmap: Optional[_CMAP_TYPE] = None,
  247. ) -> _PLOT_OUT_TYPE:
  248. """Plot a single or multiple values from the metric.
  249. Args:
  250. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  251. If no value is provided, will automatically call `metric.compute` and plot that result.
  252. ax: An matplotlib axis object. If provided will add plot to that axis
  253. add_text: if the value of each cell should be added to the plot
  254. labels: a list of strings, if provided will be added to the plot to indicate the different classes
  255. cmap: matplotlib colormap to use for the confusion matrix
  256. https://matplotlib.org/stable/users/explain/colors/colormaps.html
  257. Returns:
  258. Figure and Axes object
  259. Raises:
  260. ModuleNotFoundError:
  261. If `matplotlib` is not installed
  262. .. plot::
  263. :scale: 75
  264. >>> from torch import randint
  265. >>> from torchmetrics.classification import MulticlassConfusionMatrix
  266. >>> metric = MulticlassConfusionMatrix(num_classes=5)
  267. >>> metric.update(randint(5, (20,)), randint(5, (20,)))
  268. >>> fig_, ax_ = metric.plot()
  269. """
  270. val = val if val is not None else self.compute()
  271. if not isinstance(val, Tensor):
  272. raise TypeError(f"Expected val to be a single tensor but got {val}")
  273. fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap)
  274. return fig, ax
  275. class MultilabelConfusionMatrix(Metric):
  276. r"""Compute the `confusion matrix`_ for multilabel tasks.
  277. The confusion matrix :math:`C` is constructed such that :math:`C_{i, j}` is equal to the number of observations
  278. known to be in class :math:`i` but predicted to be in class :math:`j`. Thus row indices of the confusion matrix
  279. correspond to the true class labels and column indices correspond to the predicted class labels.
  280. For multilabel tasks, the confusion matrix is a Nx2x2 tensor, where each 2x2 matrix corresponds to the confusion
  281. for that label. The structure of each 2x2 matrix is as follows:
  282. - :math:`C_{0, 0}`: True negatives
  283. - :math:`C_{0, 1}`: False positives
  284. - :math:`C_{1, 0}`: False negatives
  285. - :math:`C_{1, 1}`: True positives
  286. As input to 'update' the metric accepts the following input:
  287. - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside
  288. [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
  289. we convert to int tensor with thresholding using the value in ``threshold``.
  290. - ``target`` (int tensor): ``(N, C, ...)``
  291. As output of 'compute' the metric returns the following output:
  292. - ``confusion matrix``: [num_labels,2,2] matrix
  293. Args:
  294. num_classes: Integer specifying the number of labels
  295. threshold: Threshold for transforming probability to binary (0,1) predictions
  296. ignore_index:
  297. Specifies a target value that is ignored and does not contribute to the metric calculation
  298. normalize: Normalization mode for confusion matrix. Choose from:
  299. - ``None`` or ``'none'``: no normalization (default)
  300. - ``'true'``: normalization over the targets (most commonly used)
  301. - ``'pred'``: normalization over the predictions
  302. - ``'all'``: normalization over the whole matrix
  303. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  304. Set to ``False`` for faster computations.
  305. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  306. Example (preds is int tensor):
  307. >>> from torch import tensor
  308. >>> from torchmetrics.classification import MultilabelConfusionMatrix
  309. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  310. >>> preds = tensor([[0, 0, 1], [1, 0, 1]])
  311. >>> metric = MultilabelConfusionMatrix(num_labels=3)
  312. >>> metric(preds, target)
  313. tensor([[[1, 0], [0, 1]],
  314. [[1, 0], [1, 0]],
  315. [[0, 1], [0, 1]]])
  316. Example (preds is float tensor):
  317. >>> from torchmetrics.classification import MultilabelConfusionMatrix
  318. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  319. >>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
  320. >>> metric = MultilabelConfusionMatrix(num_labels=3)
  321. >>> metric(preds, target)
  322. tensor([[[1, 0], [0, 1]],
  323. [[1, 0], [1, 0]],
  324. [[0, 1], [0, 1]]])
  325. """
  326. is_differentiable: bool = False
  327. higher_is_better: Optional[bool] = None
  328. full_state_update: bool = False
  329. confmat: Tensor
  330. def __init__(
  331. self,
  332. num_labels: int,
  333. threshold: float = 0.5,
  334. ignore_index: Optional[int] = None,
  335. normalize: Optional[Literal["none", "true", "pred", "all"]] = None,
  336. validate_args: bool = True,
  337. **kwargs: Any,
  338. ) -> None:
  339. super().__init__(**kwargs)
  340. if validate_args:
  341. _multilabel_confusion_matrix_arg_validation(num_labels, threshold, ignore_index, normalize)
  342. self.num_labels = num_labels
  343. self.threshold = threshold
  344. self.ignore_index = ignore_index
  345. self.normalize = normalize
  346. self.validate_args = validate_args
  347. self.add_state("confmat", torch.zeros(num_labels, 2, 2, dtype=torch.long), dist_reduce_fx="sum")
  348. def update(self, preds: Tensor, target: Tensor) -> None:
  349. """Update state with predictions and targets."""
  350. if self.validate_args:
  351. _multilabel_confusion_matrix_tensor_validation(preds, target, self.num_labels, self.ignore_index)
  352. preds, target = _multilabel_confusion_matrix_format(
  353. preds, target, self.num_labels, self.threshold, self.ignore_index
  354. )
  355. confmat = _multilabel_confusion_matrix_update(preds, target, self.num_labels)
  356. self.confmat += confmat
  357. def compute(self) -> Tensor:
  358. """Compute confusion matrix."""
  359. return _multilabel_confusion_matrix_compute(self.confmat, self.normalize)
  360. def plot(
  361. self,
  362. val: Optional[Tensor] = None,
  363. ax: Optional[_AX_TYPE] = None,
  364. add_text: bool = True,
  365. labels: Optional[list[str]] = None,
  366. cmap: Optional[_CMAP_TYPE] = None,
  367. ) -> _PLOT_OUT_TYPE:
  368. """Plot a single or multiple values from the metric.
  369. Args:
  370. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  371. If no value is provided, will automatically call `metric.compute` and plot that result.
  372. ax: An matplotlib axis object. If provided will add plot to that axis
  373. add_text: if the value of each cell should be added to the plot
  374. labels: a list of strings, if provided will be added to the plot to indicate the different classes
  375. cmap: matplotlib colormap to use for the confusion matrix
  376. https://matplotlib.org/stable/users/explain/colors/colormaps.html
  377. Returns:
  378. Figure and Axes object
  379. Raises:
  380. ModuleNotFoundError:
  381. If `matplotlib` is not installed
  382. .. plot::
  383. :scale: 75
  384. >>> from torch import randint
  385. >>> from torchmetrics.classification import MulticlassConfusionMatrix
  386. >>> metric = MulticlassConfusionMatrix(num_classes=5)
  387. >>> metric.update(randint(5, (20,)), randint(5, (20,)))
  388. >>> fig_, ax_ = metric.plot()
  389. """
  390. val = val if val is not None else self.compute()
  391. if not isinstance(val, Tensor):
  392. raise TypeError(f"Expected val to be a single tensor but got {val}")
  393. fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap)
  394. return fig, ax
  395. class ConfusionMatrix(_ClassificationTaskWrapper):
  396. r"""Compute the `confusion matrix`_.
  397. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  398. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  399. :class:`~torchmetrics.classification.BinaryConfusionMatrix`,
  400. :class:`~torchmetrics.classification.MulticlassConfusionMatrix` and
  401. :class:`~torchmetrics.classification.MultilabelConfusionMatrix` for the specific details of each argument influence
  402. and examples.
  403. Legacy Example:
  404. >>> from torch import tensor
  405. >>> target = tensor([1, 1, 0, 0])
  406. >>> preds = tensor([0, 1, 0, 0])
  407. >>> confmat = ConfusionMatrix(task="binary", num_classes=2)
  408. >>> confmat(preds, target)
  409. tensor([[2, 0],
  410. [1, 1]])
  411. >>> target = tensor([2, 1, 0, 0])
  412. >>> preds = tensor([2, 1, 0, 1])
  413. >>> confmat = ConfusionMatrix(task="multiclass", num_classes=3)
  414. >>> confmat(preds, target)
  415. tensor([[1, 1, 0],
  416. [0, 1, 0],
  417. [0, 0, 1]])
  418. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  419. >>> preds = tensor([[0, 0, 1], [1, 0, 1]])
  420. >>> confmat = ConfusionMatrix(task="multilabel", num_labels=3)
  421. >>> confmat(preds, target)
  422. tensor([[[1, 0], [0, 1]],
  423. [[1, 0], [1, 0]],
  424. [[0, 1], [0, 1]]])
  425. """
  426. def __new__( # type: ignore[misc]
  427. cls: type["ConfusionMatrix"],
  428. task: Literal["binary", "multiclass", "multilabel"],
  429. threshold: float = 0.5,
  430. num_classes: Optional[int] = None,
  431. num_labels: Optional[int] = None,
  432. normalize: Optional[Literal["true", "pred", "all", "none"]] = None,
  433. ignore_index: Optional[int] = None,
  434. validate_args: bool = True,
  435. **kwargs: Any,
  436. ) -> Metric:
  437. """Initialize task metric."""
  438. task = ClassificationTask.from_str(task)
  439. kwargs.update({"normalize": normalize, "ignore_index": ignore_index, "validate_args": validate_args})
  440. if task == ClassificationTask.BINARY:
  441. return BinaryConfusionMatrix(threshold, **kwargs)
  442. if task == ClassificationTask.MULTICLASS:
  443. if not isinstance(num_classes, int):
  444. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  445. return MulticlassConfusionMatrix(num_classes, **kwargs)
  446. if task == ClassificationTask.MULTILABEL:
  447. if not isinstance(num_labels, int):
  448. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  449. return MultilabelConfusionMatrix(num_labels, threshold, **kwargs)
  450. raise ValueError(f"Task {task} not supported!")