residuals.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. from warnings import simplefilter
  2. from sklearn import model_selection
  3. import wandb
  4. from wandb.integration.sklearn import utils
  5. # ignore all future warnings
  6. simplefilter(action="ignore", category=FutureWarning)
  def residuals(regressor, X, y):  # noqa: N803
      """Fit ``regressor`` on a train/test split and chart its residuals.

      The data is divided with ``model_selection.train_test_split`` using
      ``test_size=0.2``; no ``random_state`` is supplied, so shuffling is left
      to that function's defaults. ``regressor.fit`` is called on the passed
      object itself, so the caller's estimator is fitted by this function.

      Residuals are computed as *prediction minus target* for both portions
      of the split.

      Args:
          regressor: Object exposing ``fit``, ``score`` and ``predict``.
          X: Feature data forwarded to ``train_test_split``.
          y: Target data forwarded to ``train_test_split``.

      Returns:
          Whatever ``wandb.visualize`` returns for the
          ``"wandb/residuals_plot/v1"`` chart fed with the residuals table.
      """
      features_fit, features_eval, target_fit, target_eval = (
          model_selection.train_test_split(X, y, test_size=0.2)
      )

      # Train on the larger portion only; the held-out part is used for scoring.
      regressor.fit(features_fit, target_fit)
      score_fit = regressor.score(features_fit, target_fit)
      score_eval = regressor.score(features_eval, target_eval)

      # Residual = predicted value - observed value, per portion of the split.
      predicted_fit = regressor.predict(features_fit)
      error_fit = predicted_fit - target_fit
      predicted_eval = regressor.predict(features_eval)
      error_eval = predicted_eval - target_eval

      residuals_table = make_table(
          y_pred_train=predicted_fit,
          residuals_train=error_fit,
          y_pred_test=predicted_eval,
          residuals_test=error_eval,
          train_score_=score_fit,
          test_score_=score_eval,
      )
      return wandb.visualize("wandb/residuals_plot/v1", residuals_table)
  def make_table(
      y_pred_train,
      residuals_train,
      y_pred_test,
      residuals_test,
      train_score_,
      test_score_,
  ):
      """Build the ``wandb.Table`` that feeds the residuals plot.

      Each row has the shape
      ``[dataset, y_pred, residuals, train_score, test_score]`` where
      ``dataset`` is the literal ``"train"`` or ``"test"``. All training rows
      come first, followed by all test rows. The two scores are repeated
      unchanged on every row.

      For each split, rows are added one by one and iteration stops as soon as
      ``utils.check_against_limit(count, "residuals", 100)`` returns a truthy
      value. The row that triggered the stop has already been added. The
      counter restarts at 1 for the test split, so each split is checked
      against the limit independently.

      Args:
          y_pred_train: Iterable of predictions for the training split.
          residuals_train: Iterable of residuals paired with ``y_pred_train``.
          y_pred_test: Iterable of predictions for the test split.
          residuals_test: Iterable of residuals paired with ``y_pred_test``.
          train_score_: Score value stored in the ``train_score`` column.
          test_score_: Score value stored in the ``test_score`` column.

      Returns:
          A ``wandb.Table`` with columns ``dataset``, ``y_pred``,
          ``residuals``, ``train_score`` and ``test_score``.
      """
      # One cap, applied separately to the train rows and to the test rows.
      max_datapoints_per_split = 100

      splits = (
          ("train", y_pred_train, residuals_train),
          ("test", y_pred_test, residuals_test),
      )

      data = []
      for dataset, predictions, split_residuals in splits:
          # ``zip`` pairs each prediction with its residual and stops at the
          # shorter of the two inputs; ``start=1`` makes ``count`` the number
          # of rows added so far for this split.
          for count, (pred, residual) in enumerate(
              zip(predictions, split_residuals), start=1
          ):
              data.append([dataset, pred, residual, train_score_, test_score_])
              # The check runs after the append, so the row that trips the
              # limit is kept.
              # NOTE(review): if the helper only fires once ``count`` exceeds
              # the limit, each split can contribute limit + 1 rows - confirm
              # against utils.check_against_limit.
              if utils.check_against_limit(
                  count, "residuals", max_datapoints_per_split
              ):
                  break

      columns = ["dataset", "y_pred", "residuals", "train_score", "test_score"]
      return wandb.Table(columns=columns, data=data)