learning_curve.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. from warnings import simplefilter
  2. import numpy as np
  3. from sklearn import model_selection
  4. import wandb
  5. from wandb.integration.sklearn import utils
  6. # ignore all future warnings
  7. simplefilter(action="ignore", category=FutureWarning)
  8. def learning_curve(
  9. model,
  10. X, # noqa: N803
  11. y,
  12. cv=None,
  13. shuffle=False,
  14. random_state=None,
  15. train_sizes=None,
  16. n_jobs=1,
  17. scoring=None,
  18. ):
  19. """Train model on datasets of varying size and generates plot of score vs size.
  20. Called by plot_learning_curve to visualize learning curve. Please use the function
  21. plot_learning_curve() if you wish to visualize your learning curves.
  22. """
  23. train_sizes, train_scores, test_scores = model_selection.learning_curve(
  24. model,
  25. X,
  26. y,
  27. cv=cv,
  28. n_jobs=n_jobs,
  29. train_sizes=train_sizes,
  30. scoring=scoring,
  31. shuffle=shuffle,
  32. random_state=random_state,
  33. )
  34. train_scores_mean = np.mean(train_scores, axis=1)
  35. test_scores_mean = np.mean(test_scores, axis=1)
  36. table = make_table(train_scores_mean, test_scores_mean, train_sizes)
  37. chart = wandb.visualize("wandb/learning_curve/v1", table)
  38. return chart
  39. def make_table(train, test, train_sizes):
  40. data = []
  41. for i in range(len(train)):
  42. if utils.check_against_limit(
  43. i,
  44. "learning_curve",
  45. utils.chart_limit / 2,
  46. ):
  47. break
  48. train_set = ["train", utils.round_2(train[i]), train_sizes[i]]
  49. test_set = ["test", utils.round_2(test[i]), train_sizes[i]]
  50. data.append(train_set)
  51. data.append(test_set)
  52. table = wandb.Table(columns=["dataset", "score", "train_size"], data=data)
  53. return table