confusion_matrix.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. from __future__ import annotations
  2. from collections.abc import Sequence
  3. from typing import TYPE_CHECKING, TypeVar
  4. import wandb
  5. from wandb import util
  6. from wandb.plot.custom_chart import plot_table
  7. if TYPE_CHECKING:
  8. from wandb.plot.custom_chart import CustomChart
# Generic label type shared by the `y_true` and `preds` parameters of
# `confusion_matrix`, so both sequences are annotated with the same element type.
T = TypeVar("T")
  10. def confusion_matrix(
  11. probs: Sequence[Sequence[float]] | None = None,
  12. y_true: Sequence[T] | None = None,
  13. preds: Sequence[T] | None = None,
  14. class_names: Sequence[str] | None = None,
  15. title: str = "Confusion Matrix Curve",
  16. split_table: bool = False,
  17. ) -> CustomChart:
  18. """Constructs a confusion matrix from a sequence of probabilities or predictions.
  19. Args:
  20. probs: A sequence of predicted probabilities for each
  21. class. The sequence shape should be (N, K) where N is the number of samples
  22. and K is the number of classes. If provided, `preds` should not be provided.
  23. y_true: A sequence of true labels.
  24. preds: A sequence of predicted class labels. If provided,
  25. `probs` should not be provided.
  26. class_names: Sequence of class names. If not
  27. provided, class names will be defined as "Class_1", "Class_2", etc.
  28. title: Title of the confusion matrix chart.
  29. split_table: Whether the table should be split into a separate section
  30. in the W&B UI. If `True`, the table will be displayed in a section named
  31. "Custom Chart Tables". Default is `False`.
  32. Returns:
  33. CustomChart: A custom chart object that can be logged to W&B. To log the
  34. chart, pass it to `wandb.log()`.
  35. Raises:
  36. ValueError: If both `probs` and `preds` are provided or if the number of
  37. predictions and true labels are not equal. If the number of unique
  38. predicted classes exceeds the number of class names or if the number of
  39. unique true labels exceeds the number of class names.
  40. wandb.Error: If numpy is not installed.
  41. Examples:
  42. Logging a confusion matrix with random probabilities for wildlife
  43. classification:
  44. ```python
  45. import numpy as np
  46. import wandb
  47. # Define class names for wildlife
  48. wildlife_class_names = ["Lion", "Tiger", "Elephant", "Zebra"]
  49. # Generate random true labels (0 to 3 for 10 samples)
  50. wildlife_y_true = np.random.randint(0, 4, size=10)
  51. # Generate random probabilities for each class (10 samples x 4 classes)
  52. wildlife_probs = np.random.rand(10, 4)
  53. wildlife_probs = np.exp(wildlife_probs) / np.sum(
  54. np.exp(wildlife_probs),
  55. axis=1,
  56. keepdims=True,
  57. )
  58. # Initialize W&B run and log confusion matrix
  59. with wandb.init(project="wildlife_classification") as run:
  60. confusion_matrix = wandb.plot.confusion_matrix(
  61. probs=wildlife_probs,
  62. y_true=wildlife_y_true,
  63. class_names=wildlife_class_names,
  64. title="Wildlife Classification Confusion Matrix",
  65. )
  66. run.log({"wildlife_confusion_matrix": confusion_matrix})
  67. ```
  68. In this example, random probabilities are used to generate a confusion
  69. matrix.
  70. Logging a confusion matrix with simulated model predictions and 85%
  71. accuracy:
  72. ```python
  73. import numpy as np
  74. import wandb
  75. # Define class names for wildlife
  76. wildlife_class_names = ["Lion", "Tiger", "Elephant", "Zebra"]
  77. # Simulate true labels for 200 animal images (imbalanced distribution)
  78. wildlife_y_true = np.random.choice(
  79. [0, 1, 2, 3],
  80. size=200,
  81. p=[0.2, 0.3, 0.25, 0.25],
  82. )
  83. # Simulate model predictions with 85% accuracy
  84. wildlife_preds = [
  85. y_t
  86. if np.random.rand() < 0.85
  87. else np.random.choice([x for x in range(4) if x != y_t])
  88. for y_t in wildlife_y_true
  89. ]
  90. # Initialize W&B run and log confusion matrix
  91. with wandb.init(project="wildlife_classification") as run:
  92. confusion_matrix = wandb.plot.confusion_matrix(
  93. preds=wildlife_preds,
  94. y_true=wildlife_y_true,
  95. class_names=wildlife_class_names,
  96. title="Simulated Wildlife Classification Confusion Matrix",
  97. )
  98. run.log({"wildlife_confusion_matrix": confusion_matrix})
  99. ```
  100. In this example, predictions are simulated with 85% accuracy to generate a
  101. confusion matrix.
  102. """
  103. np = util.get_module(
  104. "numpy",
  105. required=(
  106. "numpy is required to use wandb.plot.confusion_matrix, "
  107. "install with `pip install numpy`",
  108. ),
  109. )
  110. if probs is not None and preds is not None:
  111. raise ValueError("Only one of `probs` or `preds` should be provided, not both.")
  112. if probs is not None:
  113. preds = np.argmax(probs, axis=1).tolist()
  114. if len(preds) != len(y_true):
  115. raise ValueError("The number of predictions and true labels must be equal.")
  116. if class_names is not None:
  117. n_classes = len(class_names)
  118. class_idx = list(range(n_classes))
  119. if len(set(preds)) > len(class_names):
  120. raise ValueError(
  121. "The number of unique predicted classes exceeds the number of class names."
  122. )
  123. if len(set(y_true)) > len(class_names):
  124. raise ValueError(
  125. "The number of unique true labels exceeds the number of class names."
  126. )
  127. else:
  128. class_idx = set(preds).union(set(y_true))
  129. n_classes = len(class_idx)
  130. class_names = [f"Class_{i + 1}" for i in range(n_classes)]
  131. # Create a mapping from class name to index
  132. class_mapping = {val: i for i, val in enumerate(sorted(list(class_idx)))}
  133. counts = np.zeros((n_classes, n_classes))
  134. for i in range(len(preds)):
  135. counts[class_mapping[y_true[i]], class_mapping[preds[i]]] += 1
  136. data = [
  137. [class_names[i], class_names[j], counts[i, j]]
  138. for i in range(n_classes)
  139. for j in range(n_classes)
  140. ]
  141. return plot_table(
  142. data_table=wandb.Table(
  143. columns=["Actual", "Predicted", "nPredictions"],
  144. data=data,
  145. ),
  146. vega_spec_name="wandb/confusion_matrix/v1",
  147. fields={
  148. "Actual": "Actual",
  149. "Predicted": "Predicted",
  150. "nPredictions": "nPredictions",
  151. },
  152. string_fields={"title": title},
  153. split_table=split_table,
  154. )