utils.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. """Shared utilities for the modules in wandb.sklearn."""
  2. from collections.abc import Iterable, Sequence
  3. import numpy as np
  4. import pandas as pd
  5. import scipy
  6. import sklearn
  7. import wandb
# Default datapoint threshold used by ``check_against_limit`` when the caller
# does not pass an explicit ``limit``.
chart_limit = 1000
  9. def check_against_limit(count, chart, limit=None):
  10. if limit is None:
  11. limit = chart_limit
  12. if count > limit:
  13. warn_chart_limit(limit, chart)
  14. return True
  15. else:
  16. return False
  17. def warn_chart_limit(limit, chart):
  18. warning = f"using only the first {limit} datapoints to create chart {chart}"
  19. wandb.termwarn(warning)
  20. def encode_labels(df):
  21. le = sklearn.preprocessing.LabelEncoder()
  22. # apply le on categorical feature columns
  23. categorical_cols = df.select_dtypes(
  24. exclude=["int", "float", "float64", "float32", "int32", "int64"]
  25. ).columns
  26. df[categorical_cols] = df[categorical_cols].apply(lambda col: le.fit_transform(col))
  27. def test_types(**kwargs):
  28. test_passed = True
  29. for k, v in kwargs.items():
  30. # check for incorrect types
  31. if (
  32. (k == "X")
  33. or (k == "X_test")
  34. or (k == "y")
  35. or (k == "y_test")
  36. or (k == "y_true")
  37. or (k == "y_probas")
  38. ) and not isinstance(
  39. v,
  40. (
  41. Sequence,
  42. Iterable,
  43. np.ndarray,
  44. np.generic,
  45. pd.DataFrame,
  46. pd.Series,
  47. list,
  48. ),
  49. ):
  50. # FIXME: do this individually
  51. wandb.termerror(f"{k} is not an array. Please try again.")
  52. test_passed = False
  53. # check for classifier types
  54. if k == "model":
  55. if (not sklearn.base.is_classifier(v)) and (
  56. not sklearn.base.is_regressor(v)
  57. ):
  58. wandb.termerror(
  59. f"{k} is not a classifier or regressor. Please try again."
  60. )
  61. test_passed = False
  62. elif k == "clf" or k == "binary_clf":
  63. if not (sklearn.base.is_classifier(v)):
  64. wandb.termerror(f"{k} is not a classifier. Please try again.")
  65. test_passed = False
  66. elif k == "regressor":
  67. if not sklearn.base.is_regressor(v):
  68. wandb.termerror(f"{k} is not a regressor. Please try again.")
  69. test_passed = False
  70. elif k == "clusterer" and getattr(v, "_estimator_type", None) != "clusterer":
  71. wandb.termerror(f"{k} is not a clusterer. Please try again.")
  72. test_passed = False
  73. return test_passed
  74. def test_fitted(model):
  75. try:
  76. model.predict(np.zeros((7, 3)))
  77. except sklearn.exceptions.NotFittedError:
  78. wandb.termerror("Please fit the model before passing it in.")
  79. return False
  80. except AttributeError:
  81. # Some clustering models (LDA, PCA, Agglomerative) don't implement ``predict``
  82. try:
  83. sklearn.utils.validation.check_is_fitted(
  84. model,
  85. [
  86. "coef_",
  87. "estimator_",
  88. "labels_",
  89. "n_clusters_",
  90. "children_",
  91. "components_",
  92. "n_components_",
  93. "n_iter_",
  94. "n_batch_iter_",
  95. "explained_variance_",
  96. "singular_values_",
  97. "mean_",
  98. ],
  99. all_or_any=any,
  100. )
  101. except sklearn.exceptions.NotFittedError:
  102. wandb.termerror("Please fit the model before passing it in.")
  103. return False
  104. else:
  105. return True
  106. except Exception:
  107. # Assume it's fitted, since ``NotFittedError`` wasn't raised
  108. return True
# Test assumptions for plotting parameters and datasets
  110. def test_missing(**kwargs):
  111. test_passed = True
  112. for k, v in kwargs.items():
  113. # Missing/empty params/datapoint arrays
  114. if v is None:
  115. wandb.termerror(f"{k} is None. Please try again.")
  116. test_passed = False
  117. if (k == "X") or (k == "X_test"):
  118. if isinstance(v, scipy.sparse.csr.csr_matrix):
  119. v = v.toarray()
  120. elif isinstance(v, (pd.DataFrame, pd.Series)):
  121. v = v.to_numpy()
  122. elif isinstance(v, list):
  123. v = np.asarray(v)
  124. # Warn the user about missing values
  125. missing = 0
  126. missing = np.count_nonzero(pd.isnull(v))
  127. if missing > 0:
  128. wandb.termwarn(f"{k} contains {missing} missing values. ")
  129. test_passed = False
  130. # Ensure the dataset contains only integers
  131. non_nums = 0
  132. if v.ndim == 1:
  133. non_nums = sum(
  134. 1
  135. for val in v
  136. if (
  137. not isinstance(val, (int, float, complex))
  138. and not isinstance(val, np.number)
  139. )
  140. )
  141. else:
  142. non_nums = sum(
  143. 1
  144. for sl in v
  145. for val in sl
  146. if (
  147. not isinstance(val, (int, float, complex))
  148. and not isinstance(val, np.number)
  149. )
  150. )
  151. if non_nums > 0:
  152. wandb.termerror(
  153. f"{k} contains values that are not numbers. Please vectorize, label encode or one hot encode {k} "
  154. "and call the plotting function again."
  155. )
  156. test_passed = False
  157. return test_passed
  158. def round_3(n):
  159. return round(n, 3)
  160. def round_2(n):
  161. return round(n, 2)