confusion_matrix.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import itertools
  2. from warnings import simplefilter
  3. import numpy as np
  4. from sklearn import metrics
  5. from sklearn.utils.multiclass import unique_labels
  6. import wandb
  7. from .. import utils
  8. # ignore all future warnings
  9. simplefilter(action="ignore", category=FutureWarning)
  10. def validate_labels(*args, **kwargs): # FIXME
  11. raise AssertionError()
  12. def confusion_matrix(
  13. y_true=None,
  14. y_pred=None,
  15. labels=None,
  16. true_labels=None,
  17. pred_labels=None,
  18. normalize=False,
  19. ):
  20. """Compute the confusion matrix to evaluate the performance of a classification.
  21. Called by plot_confusion_matrix to visualize roc curves. Please use the function
  22. plot_confusion_matrix() if you wish to visualize your confusion matrix.
  23. """
  24. cm = metrics.confusion_matrix(y_true, y_pred)
  25. if labels is None:
  26. classes = unique_labels(y_true, y_pred)
  27. else:
  28. classes = np.asarray(labels)
  29. if normalize:
  30. cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
  31. cm = np.around(cm, decimals=2)
  32. cm[np.isnan(cm)] = 0.0
  33. if true_labels is None:
  34. true_classes = classes
  35. else:
  36. validate_labels(classes, true_labels, "true_labels")
  37. true_label_indexes = np.in1d(classes, true_labels)
  38. true_classes = classes[true_label_indexes]
  39. cm = cm[true_label_indexes]
  40. if pred_labels is None:
  41. pred_classes = classes
  42. else:
  43. validate_labels(classes, pred_labels, "pred_labels")
  44. pred_label_indexes = np.in1d(classes, pred_labels)
  45. pred_classes = classes[pred_label_indexes]
  46. cm = cm[:, pred_label_indexes]
  47. table = make_table(cm, pred_classes, true_classes, labels)
  48. chart = wandb.visualize("wandb/confusion_matrix/v1", table)
  49. return chart
  50. def make_table(cm, pred_classes, true_classes, labels):
  51. data, count = [], 0
  52. for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
  53. if labels is not None and (
  54. isinstance(pred_classes[i], int) or isinstance(pred_classes[0], np.integer)
  55. ):
  56. pred = labels[pred_classes[i]]
  57. true = labels[true_classes[j]]
  58. else:
  59. pred = pred_classes[i]
  60. true = true_classes[j]
  61. data.append([pred, true, cm[i, j]])
  62. count += 1
  63. if utils.check_against_limit(
  64. count,
  65. "confusion_matrix",
  66. utils.chart_limit,
  67. ):
  68. break
  69. table = wandb.Table(columns=["Predicted", "Actual", "Count"], data=data)
  70. return table