| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091 |
- """Define plots used by multiple sklearn model classes."""
- from warnings import simplefilter
- import numpy as np
- import wandb
- from wandb.integration.sklearn import calculate, utils
# Silence every FutureWarning from here on, for the whole process.
# NOTE(review): this installs a global filter at import time, so it also hides
# FutureWarnings raised by unrelated code; consider scoping it with
# warnings.catch_warnings() around the calls that actually need it.
simplefilter("ignore", category=FutureWarning)
def summary_metrics(model=None, X=None, y=None, X_test=None, y_test=None):  # noqa: N803
    """Log a chart of summary metrics for an already-fitted model.

    The chart is computed and logged only when three checks all pass:
    ``utils.test_missing``, ``utils.test_types`` (both run on the model and
    all four arrays), and ``utils.test_fitted`` on the model. If any check
    fails, nothing is logged.

    Args:
        model: (clf or reg) A fitted regressor or classifier.
        X: (arr) Training set features.
        y: (arr) Training set labels.
        X_test: (arr) Test set features.
        y_test: (arr) Test set labels.

    Returns:
        None: To see plots, go to your W&B run page then expand the 'media' tab
            under 'auto visualizations'.

    Example:
    ```python
    wandb.sklearn.plot_summary_metrics(model, X_train, y_train, X_test, y_test)
    ```
    """
    arrays = {"X": X, "y": y, "X_test": X_test, "y_test": y_test}

    # All three checks are evaluated before any result is inspected, so a
    # failing early check does not skip the later ones (presumably so each
    # can report its own problem).
    checks = (
        utils.test_missing(model=model, **arrays),
        utils.test_types(model=model, **arrays),
        utils.test_fitted(model),
    )
    if not all(checks):
        return

    chart = calculate.summary_metrics(model, X, y, X_test, y_test)
    wandb.log({"summary_metrics": chart})
def learning_curve(
    model=None,
    X=None,  # noqa: N803
    y=None,
    cv=None,
    shuffle=False,
    random_state=None,
    train_sizes=None,
    n_jobs=1,
    scoring=None,
):
    """Log a plot of model performance against training-set size.

    The model is fitted to subsets of the data of varying sizes while the
    curve is computed, so this call can be slow. The chart is logged only
    when ``utils.test_missing`` and ``utils.test_types`` both pass for the
    model, ``X`` and ``y``. No fitted-model check is performed.

    Args:
        model: (clf or reg) Regressor or classifier to evaluate. Its fitted
            state is not checked here.
        X: (arr) Dataset features.
        y: (arr) Dataset labels; converted with ``np.asarray`` before use.
        cv, shuffle, random_state, n_jobs, scoring: Passed through unchanged
            to ``calculate.learning_curve``; see the documentation for
            ``sklearn.model_selection.learning_curve`` for their meaning.
        train_sizes: Training-set sizes to evaluate. When None, defaults to
            ``np.linspace(0.1, 1.0, 5)`` (0.1, 0.325, 0.55, 0.775, 1.0),
            presumably fractions of the data as in sklearn.

    Returns:
        None: To see plots, go to your W&B run page then expand the 'media' tab
            under 'auto visualizations'.

    Example:
    ```python
    wandb.sklearn.plot_learning_curve(model, X, y)
    ```
    """
    inputs_present = utils.test_missing(model=model, X=X, y=y)
    inputs_typed = utils.test_types(model=model, X=X, y=y)
    if not (inputs_present and inputs_typed):
        return

    # Build the default lazily so no array is shared between calls.
    sizes = np.linspace(0.1, 1.0, 5) if train_sizes is None else train_sizes

    chart = calculate.learning_curve(
        model,
        X,
        np.asarray(y),
        cv,
        shuffle,
        random_state,
        sizes,
        n_jobs,
        scoring,
    )
    wandb.log({"learning_curve": chart})
|