| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- from warnings import simplefilter
- import numpy as np
- from sklearn import model_selection
- import wandb
- from wandb.integration.sklearn import utils
# Suppress FutureWarning as soon as this module is imported. simplefilter
# inserts into the interpreter-wide warnings filter list, so this silences
# FutureWarning for the whole process, not just for this module.
simplefilter("ignore", FutureWarning)
def learning_curve(
    model,
    X,  # noqa: N803
    y,
    cv=None,
    shuffle=False,
    random_state=None,
    train_sizes=None,
    n_jobs=1,
    scoring=None,
):
    """Train a model on training sets of increasing size and chart the scores.

    Called by ``plot_learning_curve`` to build the learning-curve chart. Use
    ``plot_learning_curve()`` if you want to visualize a learning curve.

    The heavy lifting is delegated to
    ``sklearn.model_selection.learning_curve``. Its per-fold train and
    validation scores are averaged per training-set size, and the averages
    are turned into a ``wandb`` chart.

    Arguments:
        model: Estimator forwarded to sklearn's ``learning_curve``.
        X: Training features forwarded to sklearn.
        y: Training targets forwarded to sklearn.
        cv: Cross-validation strategy forwarded to sklearn.
        shuffle: Forwarded to sklearn's ``shuffle`` argument.
        random_state: Forwarded to sklearn's ``random_state`` argument.
        train_sizes: Training-set sizes forwarded to sklearn. When ``None``,
            sklearn's own documented default ``np.linspace(0.1, 1.0, 5)`` is
            used.
        n_jobs: Number of parallel jobs forwarded to sklearn.
        scoring: Scoring specification forwarded to sklearn.

    Returns:
        The chart returned by
        ``wandb.visualize("wandb/learning_curve/v1", table)``.
    """
    if train_sizes is None:
        # sklearn's learning_curve does not accept None for train_sizes, so
        # the previous default made direct calls fail. Fall back to the same
        # value sklearn uses when the argument is omitted.
        train_sizes = np.linspace(0.1, 1.0, 5)

    train_sizes, train_scores, test_scores = model_selection.learning_curve(
        model,
        X,
        y,
        cv=cv,
        n_jobs=n_jobs,
        train_sizes=train_sizes,
        scoring=scoring,
        shuffle=shuffle,
        random_state=random_state,
    )
    # Average across CV folds (axis 1), leaving one score per training size.
    train_scores_mean = np.mean(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)
    table = make_table(train_scores_mean, test_scores_mean, train_sizes)
    chart = wandb.visualize("wandb/learning_curve/v1", table)
    return chart
def make_table(train, test, train_sizes):
    """Build the ``wandb.Table`` backing the learning-curve chart.

    Each index contributes two rows: one tagged ``"train"`` and one tagged
    ``"test"``. Both rows hold the score rounded with ``utils.round_2`` and
    the training-set size at that index. Before each index is processed,
    ``utils.check_against_limit`` is consulted with half of
    ``utils.chart_limit`` (two rows are emitted per index). Row generation
    stops as soon as that check returns a truthy value.

    Arguments:
        train: Per-size training scores (``learning_curve`` passes the mean
            over CV folds).
        test: Per-size validation scores, indexed like ``train``.
        train_sizes: Training-set sizes, indexed like ``train``.

    Returns:
        A ``wandb.Table`` with columns ``dataset``, ``score`` and
        ``train_size``.
    """
    rows = []
    for idx, train_score in enumerate(train):
        # Each index yields two rows, hence the halved limit.
        limit_hit = utils.check_against_limit(
            idx,
            "learning_curve",
            utils.chart_limit / 2,
        )
        if limit_hit:
            break
        train_row = ["train", utils.round_2(train_score), train_sizes[idx]]
        test_row = ["test", utils.round_2(test[idx]), train_sizes[idx]]
        rows.extend([train_row, test_row])
    return wandb.Table(columns=["dataset", "score", "train_size"], data=rows)
|