confusion_matrix.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import torch
  18. # Inspired by:
  19. # https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py#L68-L73
  20. def confusion_matrix(
  21. pred: torch.Tensor, target: torch.Tensor, num_classes: int, normalized: bool = False
  22. ) -> torch.Tensor:
  23. r"""Compute confusion matrix to evaluate the accuracy of a classification.
  24. Args:
  25. pred: tensor with estimated targets returned by a
  26. classifier. The shape can be :math:`(B, *)` and must contain integer
  27. values between 0 and K-1.
  28. target: tensor with ground truth (correct) target
  29. values. The shape can be :math:`(B, *)` and must contain integer
  30. values between 0 and K-1, where targets are assumed to be provided as
  31. one-hot vectors.
  32. num_classes: total possible number of classes in target.
  33. normalized: whether to return the confusion matrix normalized.
  34. Returns:
  35. a tensor containing the confusion matrix with shape
  36. :math:`(B, K, K)` where K is the number of classes.
  37. Example:
  38. >>> logits = torch.tensor([[0, 1, 0]])
  39. >>> target = torch.tensor([[0, 1, 0]])
  40. >>> confusion_matrix(logits, target, num_classes=3)
  41. tensor([[[2., 0., 0.],
  42. [0., 1., 0.],
  43. [0., 0., 0.]]])
  44. """
  45. if not torch.is_tensor(pred) and pred.dtype is not torch.int64:
  46. raise TypeError(f"Input pred type is not a torch.Tensor with torch.int64 dtype. Got {type(pred)}")
  47. if not torch.is_tensor(target) and target.dtype is not torch.int64:
  48. raise TypeError(f"Input target type is not a torch.Tensor with torch.int64 dtype. Got {type(target)}")
  49. if not pred.shape == target.shape:
  50. raise ValueError(f"Inputs pred and target must have the same shape. Got: {pred.shape} and {target.shape}")
  51. if not pred.device == target.device:
  52. raise ValueError(f"Inputs must be in the same device. Got: {pred.device} - {target.device}")
  53. if not isinstance(num_classes, int) or num_classes < 2:
  54. raise ValueError(f"The number of classes must be an integer bigger than two. Got: {num_classes}")
  55. batch_size: int = pred.shape[0]
  56. # hack for bitcounting 2 arrays together
  57. # NOTE: torch.bincount does not implement batched version
  58. pre_bincount: torch.Tensor = pred + target * num_classes
  59. pre_bincount_vec: torch.Tensor = pre_bincount.view(batch_size, -1)
  60. confusion_list = []
  61. for iter_id in range(batch_size):
  62. pb: torch.Tensor = pre_bincount_vec[iter_id]
  63. bin_count: torch.Tensor = torch.bincount(pb, minlength=num_classes**2)
  64. confusion_list.append(bin_count)
  65. confusion_vec: torch.Tensor = torch.stack(confusion_list)
  66. confusion_mat: torch.Tensor = confusion_vec.view(batch_size, num_classes, num_classes).to(torch.float32) # BxKxK
  67. if normalized:
  68. norm_val: torch.Tensor = torch.sum(confusion_mat, dim=1, keepdim=True)
  69. confusion_mat = confusion_mat / (norm_val + 1e-6)
  70. return confusion_mat