| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
- from warnings import simplefilter
- from sklearn import model_selection
- import wandb
- from wandb.integration.sklearn import utils
# Ignore all FutureWarning messages.
# NOTE: this call runs when the module is imported and changes the
# process-wide warnings filter, so it affects code outside this module too.
simplefilter(action="ignore", category=FutureWarning)
def residuals(regressor, X, y):  # noqa: N803
    """Fit ``regressor`` on an 80/20 split and chart its residuals.

    The regressor is fitted on the training portion and scored on both
    portions. The residuals (prediction minus target) of each portion are
    passed to ``make_table``, and the resulting table is rendered with the
    ``wandb/residuals_plot/v1`` preset.

    Args:
        regressor: Estimator exposing ``fit``, ``score`` and ``predict``.
        X: Feature data, forwarded to ``train_test_split``.
        y: Target values, forwarded to ``train_test_split``.

    Returns:
        The value returned by ``wandb.visualize`` for the residuals table.
    """
    # Hold out 20% of the samples. No random_state is passed, so the split
    # is not pinned by this function.
    features_fit, features_held, target_fit, target_held = (
        model_selection.train_test_split(X, y, test_size=0.2)
    )

    regressor.fit(features_fit, target_fit)
    score_fit = regressor.score(features_fit, target_fit)
    score_held = regressor.score(features_held, target_held)

    # Residual = predicted value minus true value, for each portion.
    predicted_fit = regressor.predict(features_fit)
    error_fit = predicted_fit - target_fit
    predicted_held = regressor.predict(features_held)
    error_held = predicted_held - target_held

    residuals_table = make_table(
        y_pred_train=predicted_fit,
        residuals_train=error_fit,
        y_pred_test=predicted_held,
        residuals_test=error_held,
        train_score_=score_fit,
        test_score_=score_held,
    )
    return wandb.visualize("wandb/residuals_plot/v1", residuals_table)
def make_table(
    y_pred_train,
    residuals_train,
    y_pred_test,
    residuals_test,
    train_score_,
    test_score_,
):
    """Build the ``wandb.Table`` that backs the residuals chart.

    Each row holds one prediction with its residual, tagged ``"train"`` or
    ``"test"``, followed by the two scores (repeated on every row). Train rows
    come first, then test rows.

    Args:
        y_pred_train: Predictions for the training portion.
        residuals_train: Residuals paired element-wise with ``y_pred_train``.
        y_pred_test: Predictions for the test portion.
        residuals_test: Residuals paired element-wise with ``y_pred_test``.
        train_score_: Score stored in the ``train_score`` column of every row.
        test_score_: Score stored in the ``test_score`` column of every row.

    Returns:
        A ``wandb.Table`` with columns ``dataset``, ``y_pred``, ``residuals``,
        ``train_score`` and ``test_score``.
    """
    # The same cap is applied to the train rows and to the test rows.
    point_limit = 100
    partitions = (
        ("train", y_pred_train, residuals_train),
        ("test", y_pred_test, residuals_test),
    )

    rows = []
    for label, predictions, errors in partitions:
        # The counter restarts at 1 for each partition; the row is recorded
        # before the limit check, so the point that trips the check is kept.
        for seen, (prediction, error) in enumerate(
            zip(predictions, errors), start=1
        ):
            rows.append([label, prediction, error, train_score_, test_score_])
            if utils.check_against_limit(seen, "residuals", point_limit):
                break

    return wandb.Table(
        columns=["dataset", "y_pred", "residuals", "train_score", "test_score"],
        data=rows,
    )
|