| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- """Define plots for clustering models built with scikit-learn."""
- from warnings import simplefilter
- import pandas as pd
- import sklearn
- import wandb
- from wandb.integration.sklearn import calculate, utils
- # ignore all future warnings
- simplefilter(action="ignore", category=FutureWarning)
def clusterer(model, X_train, cluster_labels, labels=None, model_name="Clusterer"):  # noqa: N803
    """Generates all sklearn clusterer plots supported by W&B.

    The following plots are generated:
        elbow curve (only when `model` is a `sklearn.cluster.KMeans`),
        silhouette plot (always).

    Should only be called with a fitted clusterer.

    Args:
        model: (clusterer) Takes in a fitted clusterer.
        X_train: (arr) Training set features.
        cluster_labels: (list) Names for cluster labels. Makes plots easier to read
            by replacing cluster indexes with corresponding names.
        labels: (list) Named labels for target variable (y). Makes plots easier to
            read by replacing target values with corresponding index.
            For example if `labels=['dog', 'cat', 'owl']` all 0s are
            replaced by dog, 1s by cat. Forwarded to `silhouette` for every
            clusterer type.
        model_name: (str) Model name. Defaults to 'Clusterer'

    Returns:
        None: To see plots, go to your W&B run page then expand the 'media' tab
              under 'auto visualizations'.

    Example:
    ```python
    wandb.sklearn.plot_clusterer(kmeans, X, cluster_labels, labels, "KMeans")
    ```
    """
    wandb.termlog(f"\nPlotting {model_name}.")

    is_kmeans = isinstance(model, sklearn.cluster.KMeans)

    # The elbow curve is only produced for KMeans models.
    if is_kmeans:
        elbow_curve(model, X_train)
        wandb.termlog("Logged elbow curve.")

    # Always forward `labels`: previously the non-KMeans path dropped the
    # caller-supplied value and silhouette() fell back to labels=None.
    silhouette(model, X_train, cluster_labels, labels=labels, kmeans=is_kmeans)
    wandb.termlog("Logged silhouette plot.")
def elbow_curve(
    clusterer=None,
    X=None,  # noqa: N803
    cluster_ranges=None,
    n_jobs=1,
    show_cluster_time=True,
):
    """Measures and plots variance explained as a function of the number of clusters.

    Useful in picking the optimal number of clusters.

    Should only be called with a fitted clusterer. If the clusterer has no
    `n_clusters` attribute, a message is logged and nothing is plotted. If any
    of the missing-value, type, or fitted checks fails, nothing is plotted.

    NOTE(review): earlier documentation states that the model is re-fit on the
    training set when this is called; that would happen inside
    `calculate.elbow_curve` -- confirm there.

    Args:
        clusterer: (clusterer) Takes in a fitted clusterer.
        X: (arr) Training set features.
        cluster_ranges: Forwarded unchanged to `calculate.elbow_curve`.
            Defaults to None (presumably the candidate cluster counts to
            evaluate -- confirm against `calculate.elbow_curve`).
        n_jobs: Forwarded unchanged to `calculate.elbow_curve`. Defaults to 1.
        show_cluster_time: Forwarded unchanged to `calculate.elbow_curve`.
            Defaults to True.

    Returns:
        None: To see plots, go to your W&B run page then expand the 'media' tab
              under 'auto visualizations'.

    Example:
    ```python
    wandb.sklearn.plot_elbow_curve(model, X_train)
    ```
    """
    if not hasattr(clusterer, "n_clusters"):
        wandb.termlog(
            "n_clusters attribute not in classifier. Cannot plot elbow method."
        )
        return

    not_missing = utils.test_missing(clusterer=clusterer)
    # Call the validator; a bare reference to the function object is always
    # truthy and would skip the type check entirely.
    correct_types = utils.test_types(clusterer=clusterer)
    is_fitted = utils.test_fitted(clusterer)

    if not_missing and correct_types and is_fitted:
        elbow_curve_chart = calculate.elbow_curve(
            clusterer, X, cluster_ranges, n_jobs, show_cluster_time
        )

        wandb.log({"elbow_curve": elbow_curve_chart})
def silhouette(
    clusterer=None,
    X=None,  # noqa: N803
    cluster_labels=None,
    labels=None,
    metric="euclidean",
    kmeans=True,
):
    """Measures & plots silhouette coefficients.

    A coefficient near +1 means the sample lies far from neighboring clusters;
    a value near 0 means it sits on or very close to the boundary between two
    neighboring clusters; a negative value suggests the sample may have been
    assigned to the wrong cluster.

    Should only be called with a fitted clusterer. The chart is computed and
    logged only when the missing-value, type, and fitted checks all pass;
    otherwise the function returns without logging.

    Args:
        clusterer: (clusterer) Takes in a fitted clusterer.
        X: (arr) Training set features. A pandas DataFrame is converted to its
            underlying `.values` array before being passed on.
        cluster_labels: (list) Names for cluster labels. Makes plots easier to read
            by replacing cluster indexes with corresponding names.
        labels: Forwarded unchanged to `calculate.silhouette`.
        metric: Forwarded unchanged to `calculate.silhouette`.
            Defaults to "euclidean".
        kmeans: Forwarded unchanged to `calculate.silhouette`. Defaults to True.

    Returns:
        None: To see plots, go to your W&B run page then expand the 'media' tab
              under 'auto visualizations'.

    Example:
    ```python
    wandb.sklearn.plot_silhouette(model, X_train, ["spam", "not spam"])
    ```
    """
    # Every validator is invoked, in this order, before the results are inspected.
    checks = (
        utils.test_missing(clusterer=clusterer),
        utils.test_types(clusterer=clusterer),
        utils.test_fitted(clusterer),
    )
    if not all(checks):
        return

    # Hand a plain array to the calculation when given a DataFrame.
    features = X.values if isinstance(X, pd.DataFrame) else X

    chart = calculate.silhouette(
        clusterer, features, cluster_labels, labels, metric, kmeans
    )
    wandb.log({"silhouette_plot": chart})
|