utils.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. from collections.abc import Iterable, Sequence
  2. import wandb
  3. from wandb import util
  4. def test_missing(**kwargs):
  5. np = util.get_module("numpy", required="Logging plots requires numpy")
  6. pd = util.get_module("pandas", required="Logging dataframes requires pandas")
  7. scipy = util.get_module("scipy", required="Logging scipy matrices requires scipy")
  8. test_passed = True
  9. for k, v in kwargs.items():
  10. # Missing/empty params/datapoint arrays
  11. if v is None:
  12. wandb.termerror(f"{k} is None. Please try again.")
  13. test_passed = False
  14. if (k == "X") or (k == "X_test"):
  15. if isinstance(v, scipy.sparse.csr.csr_matrix):
  16. v = v.toarray()
  17. elif isinstance(v, (pd.DataFrame, pd.Series)):
  18. v = v.to_numpy()
  19. elif isinstance(v, list):
  20. v = np.asarray(v)
  21. # Warn the user about missing values
  22. missing = 0
  23. missing = np.count_nonzero(pd.isnull(v))
  24. if missing > 0:
  25. wandb.termwarn("%s contains %d missing values. " % (k, missing))
  26. test_passed = False
  27. # Ensure the dataset contains only integers
  28. non_nums = 0
  29. if v.ndim == 1:
  30. non_nums = sum(
  31. 1
  32. for val in v
  33. if (
  34. not isinstance(val, (int, float, complex))
  35. and not isinstance(val, np.number)
  36. )
  37. )
  38. else:
  39. non_nums = sum(
  40. 1
  41. for sl in v
  42. for val in sl
  43. if (
  44. not isinstance(val, (int, float, complex))
  45. and not isinstance(val, np.number)
  46. )
  47. )
  48. if non_nums > 0:
  49. wandb.termerror(
  50. f"{k} contains values that are not numbers. Please vectorize, "
  51. f"label encode or one hot encode {k} and call the plotting function again."
  52. )
  53. test_passed = False
  54. return test_passed
  55. def test_fitted(model):
  56. np = util.get_module("numpy", required="Logging plots requires numpy")
  57. _ = util.get_module("pandas", required="Logging dataframes requires pandas")
  58. _ = util.get_module("scipy", required="Logging scipy matrices requires scipy")
  59. scikit_utils = util.get_module(
  60. "sklearn.utils",
  61. required="roc requires the scikit utils submodule, install with `pip install scikit-learn`",
  62. )
  63. scikit_exceptions = util.get_module(
  64. "sklearn.exceptions",
  65. "roc requires the scikit preprocessing submodule, install with `pip install scikit-learn`",
  66. )
  67. try:
  68. model.predict(np.zeros((7, 3)))
  69. except scikit_exceptions.NotFittedError:
  70. wandb.termerror("Please fit the model before passing it in.")
  71. return False
  72. except AttributeError:
  73. # Some clustering models (LDA, PCA, Agglomerative) don't implement ``predict``
  74. try:
  75. scikit_utils.validation.check_is_fitted(
  76. model,
  77. [
  78. "coef_",
  79. "estimator_",
  80. "labels_",
  81. "n_clusters_",
  82. "children_",
  83. "components_",
  84. "n_components_",
  85. "n_iter_",
  86. "n_batch_iter_",
  87. "explained_variance_",
  88. "singular_values_",
  89. "mean_",
  90. ],
  91. all_or_any=any,
  92. )
  93. except scikit_exceptions.NotFittedError:
  94. wandb.termerror("Please fit the model before passing it in.")
  95. return False
  96. else:
  97. return True
  98. except Exception:
  99. # Assume it's fitted, since ``NotFittedError`` wasn't raised
  100. return True
  101. def encode_labels(df):
  102. _ = util.get_module("pandas", required="Logging dataframes requires pandas")
  103. preprocessing = util.get_module(
  104. "sklearn.preprocessing",
  105. "roc requires the scikit preprocessing submodule, install with `pip install scikit-learn`",
  106. )
  107. le = preprocessing.LabelEncoder()
  108. # apply le on categorical feature columns
  109. categorical_cols = df.select_dtypes(
  110. exclude=["int", "float", "float64", "float32", "int32", "int64"]
  111. ).columns
  112. df[categorical_cols] = df[categorical_cols].apply(lambda col: le.fit_transform(col))
  113. def test_types(**kwargs):
  114. np = util.get_module("numpy", required="Logging plots requires numpy")
  115. pd = util.get_module("pandas", required="Logging dataframes requires pandas")
  116. _ = util.get_module("scipy", required="Logging scipy matrices requires scipy")
  117. base = util.get_module(
  118. "sklearn.base",
  119. "roc requires the scikit base submodule, install with `pip install scikit-learn`",
  120. )
  121. test_passed = True
  122. for k, v in kwargs.items():
  123. # check for incorrect types
  124. if (
  125. (k == "X")
  126. or (k == "X_test")
  127. or (k == "y")
  128. or (k == "y_test")
  129. or (k == "y_true")
  130. or (k == "y_probas")
  131. or (k == "x_labels")
  132. or (k == "y_labels")
  133. or (k == "matrix_values")
  134. ) and not isinstance(
  135. v,
  136. (
  137. Sequence,
  138. Iterable,
  139. np.ndarray,
  140. np.generic,
  141. pd.DataFrame,
  142. pd.Series,
  143. list,
  144. ),
  145. ):
  146. # FIXME: do this individually
  147. wandb.termerror(f"{k} is not an array. Please try again.")
  148. test_passed = False
  149. # check for classifier types
  150. if k == "model":
  151. if (not base.is_classifier(v)) and (not base.is_regressor(v)):
  152. wandb.termerror(
  153. f"{k} is not a classifier or regressor. Please try again."
  154. )
  155. test_passed = False
  156. elif k == "clf" or k == "binary_clf":
  157. if not (base.is_classifier(v)):
  158. wandb.termerror(f"{k} is not a classifier. Please try again.")
  159. test_passed = False
  160. elif k == "regressor":
  161. if not base.is_regressor(v):
  162. wandb.termerror(f"{k} is not a regressor. Please try again.")
  163. test_passed = False
  164. elif k == "clusterer" and getattr(v, "_estimator_type", None) != "clusterer":
  165. wandb.termerror(f"{k} is not a clusterer. Please try again.")
  166. test_passed = False
  167. return test_passed