# calibration_curves.py (3.7 KB)

  1. from warnings import simplefilter
  2. import numpy as np
  3. import sklearn
  4. from sklearn import model_selection, naive_bayes
  5. from sklearn.calibration import CalibratedClassifierCV
  6. from sklearn.linear_model import LogisticRegression
  7. import wandb
  8. from wandb.integration.sklearn import utils
  9. # ignore all future warnings
  10. simplefilter(action="ignore", category=FutureWarning)
  11. def calibration_curves(clf, X, y, clf_name): # noqa: N803
  12. # ComplementNB (introduced in 0.20.0) requires non-negative features
  13. if int(sklearn.__version__.split(".")[1]) >= 20 and isinstance(
  14. clf, naive_bayes.ComplementNB
  15. ):
  16. X = X - X.min() # noqa:N806
  17. # Calibrated with isotonic calibration
  18. isotonic = CalibratedClassifierCV(clf, cv=2, method="isotonic")
  19. # Calibrated with sigmoid calibration
  20. sigmoid = CalibratedClassifierCV(clf, cv=2, method="sigmoid")
  21. # Logistic regression with no calibration as baseline
  22. lr = LogisticRegression(C=1.0)
  23. model_column = [] # color
  24. frac_positives_column = [] # y axis
  25. mean_pred_value_column = [] # x axis
  26. hist_column = [] # barchart y
  27. edge_column = [] # barchart x
  28. # Add curve for perfectly calibrated model
  29. # format: model, fraction_of_positives, mean_predicted_value
  30. model_column.append("Perfectly calibrated")
  31. frac_positives_column.append(0)
  32. mean_pred_value_column.append(0)
  33. hist_column.append(0)
  34. edge_column.append(0)
  35. model_column.append("Perfectly calibrated")
  36. hist_column.append(0)
  37. edge_column.append(0)
  38. frac_positives_column.append(1)
  39. mean_pred_value_column.append(1)
  40. x_train, x_test, y_train, y_test = model_selection.train_test_split(
  41. X, y, test_size=0.9, random_state=42
  42. )
  43. # Add curve for LogisticRegression baseline and other models
  44. models = [lr, isotonic, sigmoid]
  45. names = ["Logistic", f"{clf_name} Isotonic", f"{clf_name} Sigmoid"]
  46. for model, name in zip(models, names):
  47. model.fit(x_train, y_train)
  48. if hasattr(model, "predict_proba"):
  49. prob_pos = model.predict_proba(x_test)[:, 1]
  50. else: # use decision function
  51. prob_pos = model.decision_function(x_test)
  52. prob_pos = (prob_pos - prob_pos.min()) / (prob_pos.max() - prob_pos.min())
  53. hist, edges = np.histogram(prob_pos, bins=10, density=False)
  54. frac_positives, mean_pred_value = sklearn.calibration.calibration_curve(
  55. y_test, prob_pos, n_bins=10
  56. )
  57. # format: model, fraction_of_positives, mean_predicted_value
  58. num_entries = len(frac_positives)
  59. for i in range(num_entries):
  60. hist_column.append(hist[i])
  61. edge_column.append(edges[i])
  62. model_column.append(name)
  63. frac_positives_column.append(utils.round_3(frac_positives[i]))
  64. mean_pred_value_column.append(utils.round_3(mean_pred_value[i]))
  65. if utils.check_against_limit(
  66. i,
  67. "calibration_curve",
  68. utils.chart_limit - 2,
  69. ):
  70. break
  71. table = make_table(
  72. model_column,
  73. frac_positives_column,
  74. mean_pred_value_column,
  75. hist_column,
  76. edge_column,
  77. )
  78. chart = wandb.visualize("wandb/calibration/v1", table)
  79. return chart
  80. def make_table(
  81. model_column,
  82. frac_positives_column,
  83. mean_pred_value_column,
  84. hist_column,
  85. edge_column,
  86. ):
  87. columns = [
  88. "model",
  89. "fraction_of_positives",
  90. "mean_predicted_value",
  91. "hist_dict",
  92. "edge_dict",
  93. ]
  94. data = list(
  95. zip(
  96. model_column,
  97. frac_positives_column,
  98. mean_pred_value_column,
  99. hist_column,
  100. edge_column,
  101. )
  102. )
  103. return wandb.Table(columns=columns, data=data)