| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182 |
- """Shared utilities for the modules in wandb.sklearn."""
- from collections.abc import Iterable, Sequence
- import numpy as np
- import pandas as pd
- import scipy
- import sklearn
- import wandb
# Default cap applied by ``check_against_limit`` when no explicit limit is given.
chart_limit = 1000


def check_against_limit(count, chart, limit=None):
    """Report whether ``count`` exceeds the limit for ``chart``.

    Args:
        count: Value compared against the limit (the emitted warning refers
            to it as a number of datapoints).
        chart: Chart name; only used in the warning text.
        limit: Maximum allowed value. Falls back to the module-level
            ``chart_limit`` when ``None``.

    Returns:
        bool: ``True`` (after emitting a warning via ``warn_chart_limit``)
        when ``count`` is strictly greater than the limit, else ``False``.
    """
    max_points = chart_limit if limit is None else limit
    if not count > max_points:
        return False
    warn_chart_limit(max_points, chart)
    return True
def warn_chart_limit(limit, chart):
    """Emit a wandb terminal warning that ``chart`` is truncated to ``limit`` datapoints.

    Args:
        limit: Number of datapoints mentioned in the warning.
        chart: Chart name mentioned in the warning.
    """
    wandb.termwarn(
        f"using only the first {limit} datapoints to create chart {chart}"
    )
def encode_labels(df):
    """Label-encode the non-numeric columns of ``df`` in place.

    Columns whose dtype is not one of the listed int/float dtypes are selected
    and overwritten with the output of ``LabelEncoder.fit_transform``.

    Args:
        df: ``pandas.DataFrame`` to modify. Nothing is returned; the frame
            itself is updated.
    """
    encoder = sklearn.preprocessing.LabelEncoder()
    numeric_dtypes = ["int", "float", "float64", "float32", "int32", "int64"]
    # Everything that is not one of the dtypes above is treated as categorical.
    to_encode = df.select_dtypes(exclude=numeric_dtypes).columns
    # The same encoder instance is reused, calling fit_transform once per column.
    df[to_encode] = df[to_encode].apply(encoder.fit_transform)
def test_types(**kwargs):
    """Validate the types of named plotting inputs.

    Two independent checks are applied to every keyword argument:

    * ``X``, ``X_test``, ``y``, ``y_test``, ``y_true`` and ``y_probas`` must be
      array-like (any sequence/iterable, NumPy array or scalar, pandas
      ``DataFrame``/``Series``).
    * ``model`` must be a classifier or a regressor, ``clf``/``binary_clf`` a
      classifier, ``regressor`` a regressor, and ``clusterer`` must expose
      ``_estimator_type == "clusterer"``.

    Args:
        **kwargs: Inputs to validate, keyed by their parameter name.

    Returns:
        bool: ``True`` when every check passes; ``False`` if at least one
        failed. Each failure is reported through ``wandb.termerror``.
    """
    array_keys = ("X", "X_test", "y", "y_test", "y_true", "y_probas")
    array_types = (
        Sequence,
        Iterable,
        np.ndarray,
        np.generic,
        pd.DataFrame,
        pd.Series,
        list,
    )

    all_ok = True
    for name, value in kwargs.items():
        # Array-like inputs.
        if name in array_keys and not isinstance(value, array_types):
            # FIXME: do this individually
            wandb.termerror(f"{name} is not an array. Please try again.")
            all_ok = False

        # Estimator inputs.
        if name == "model":
            if not (
                sklearn.base.is_classifier(value)
                or sklearn.base.is_regressor(value)
            ):
                wandb.termerror(
                    f"{name} is not a classifier or regressor. Please try again."
                )
                all_ok = False
        elif name in ("clf", "binary_clf"):
            if not sklearn.base.is_classifier(value):
                wandb.termerror(f"{name} is not a classifier. Please try again.")
                all_ok = False
        elif name == "regressor":
            if not sklearn.base.is_regressor(value):
                wandb.termerror(f"{name} is not a regressor. Please try again.")
                all_ok = False
        elif name == "clusterer":
            if getattr(value, "_estimator_type", None) != "clusterer":
                wandb.termerror(f"{name} is not a clusterer. Please try again.")
                all_ok = False
    return all_ok
- def test_fitted(model):
- try:
- model.predict(np.zeros((7, 3)))
- except sklearn.exceptions.NotFittedError:
- wandb.termerror("Please fit the model before passing it in.")
- return False
- except AttributeError:
- # Some clustering models (LDA, PCA, Agglomerative) don't implement ``predict``
- try:
- sklearn.utils.validation.check_is_fitted(
- model,
- [
- "coef_",
- "estimator_",
- "labels_",
- "n_clusters_",
- "children_",
- "components_",
- "n_components_",
- "n_iter_",
- "n_batch_iter_",
- "explained_variance_",
- "singular_values_",
- "mean_",
- ],
- all_or_any=any,
- )
- except sklearn.exceptions.NotFittedError:
- wandb.termerror("Please fit the model before passing it in.")
- return False
- else:
- return True
- except Exception:
- # Assume it's fitted, since ``NotFittedError`` wasn't raised
- return True
# Test assumptions for plotting parameters and datasets
def test_missing(**kwargs):
    """Check plotting inputs for ``None``, missing entries and non-numeric data.

    Every keyword argument is checked for ``None``. Arguments named ``X`` or
    ``X_test`` are additionally converted to a dense NumPy array and scanned
    for null entries (as reported by ``pd.isnull``) and for values that are
    not numbers.

    Args:
        **kwargs: Inputs to validate, keyed by their parameter name.

    Returns:
        bool: ``True`` when no problem was found, ``False`` otherwise. Each
        problem is reported through ``wandb.termerror`` / ``wandb.termwarn``.
    """
    test_passed = True
    for k, v in kwargs.items():
        # Missing/empty params/datapoint arrays
        if v is None:
            wandb.termerror(f"{k} is None. Please try again.")
            test_passed = False
            # Nothing else can be inspected on a missing argument; carrying on
            # used to crash on ``v.ndim`` for X / X_test.
            continue
        if k not in ("X", "X_test"):
            continue

        # Normalize to a dense ndarray. ``issparse`` covers every scipy sparse
        # format, and the final branch handles lists, tuples and other
        # array-likes that have no ``ndim`` attribute.
        if scipy.sparse.issparse(v):
            v = v.toarray()
        elif isinstance(v, (pd.DataFrame, pd.Series)):
            v = v.to_numpy()
        elif not isinstance(v, np.ndarray):
            v = np.asarray(v)

        # Warn the user about missing values
        missing = np.count_nonzero(pd.isnull(v))
        if missing > 0:
            wandb.termwarn(f"{k} contains {missing} missing values. ")
            test_passed = False

        # Ensure the dataset contains only numbers. Arrays with a numeric
        # dtype can only hold numbers, so the element-by-element scan is
        # needed only for other dtypes (object, str, bool, ...).
        if np.issubdtype(v.dtype, np.number):
            non_nums = 0
        else:
            non_nums = sum(
                1
                for val in v.flat
                if not isinstance(val, (int, float, complex, np.number))
            )
        if non_nums > 0:
            wandb.termerror(
                f"{k} contains values that are not numbers. Please vectorize, label encode or one hot encode {k} "
                "and call the plotting function again."
            )
            test_passed = False
    return test_passed
def round_3(n):
    """Return ``n`` rounded to three decimal places using the built-in ``round``."""
    rounded = round(n, 3)
    return rounded
def round_2(n):
    """Return ``n`` rounded to two decimal places using the built-in ``round``."""
    rounded = round(n, 2)
    return rounded
|