shared.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. """Define plots used by multiple sklearn model classes."""
  2. from warnings import simplefilter
  3. import numpy as np
  4. import wandb
  5. from wandb.integration.sklearn import calculate, utils
  6. # ignore all future warnings
  7. simplefilter(action="ignore", category=FutureWarning)
  8. def summary_metrics(model=None, X=None, y=None, X_test=None, y_test=None): # noqa: N803
  9. """Logs a chart depicting summary metrics for a model.
  10. Should only be called with a fitted model (otherwise an error is thrown).
  11. Args:
  12. model: (clf or reg) Takes in a fitted regressor or classifier.
  13. X: (arr) Training set features.
  14. y: (arr) Training set labels.
  15. X_test: (arr) Test set features.
  16. y_test: (arr) Test set labels.
  17. Returns:
  18. None: To see plots, go to your W&B run page then expand the 'media' tab
  19. under 'auto visualizations'.
  20. Example:
  21. ```python
  22. wandb.sklearn.plot_summary_metrics(model, X_train, y_train, X_test, y_test)
  23. ```
  24. """
  25. not_missing = utils.test_missing(
  26. model=model, X=X, y=y, X_test=X_test, y_test=y_test
  27. )
  28. correct_types = utils.test_types(
  29. model=model, X=X, y=y, X_test=X_test, y_test=y_test
  30. )
  31. model_fitted = utils.test_fitted(model)
  32. if not_missing and correct_types and model_fitted:
  33. metrics_chart = calculate.summary_metrics(model, X, y, X_test, y_test)
  34. wandb.log({"summary_metrics": metrics_chart})
  35. def learning_curve(
  36. model=None,
  37. X=None, # noqa: N803
  38. y=None,
  39. cv=None,
  40. shuffle=False,
  41. random_state=None,
  42. train_sizes=None,
  43. n_jobs=1,
  44. scoring=None,
  45. ):
  46. """Logs a plot depicting model performance against dataset size.
  47. Please note this function fits the model to datasets of varying sizes when called.
  48. Args:
  49. model: (clf or reg) Takes in a fitted regressor or classifier.
  50. X: (arr) Dataset features.
  51. y: (arr) Dataset labels.
  52. For details on the other keyword arguments, see the documentation for
  53. `sklearn.model_selection.learning_curve`.
  54. Returns:
  55. None: To see plots, go to your W&B run page then expand the 'media' tab
  56. under 'auto visualizations'.
  57. Example:
  58. ```python
  59. wandb.sklearn.plot_learning_curve(model, X, y)
  60. ```
  61. """
  62. not_missing = utils.test_missing(model=model, X=X, y=y)
  63. correct_types = utils.test_types(model=model, X=X, y=y)
  64. if not_missing and correct_types:
  65. if train_sizes is None:
  66. train_sizes = np.linspace(0.1, 1.0, 5)
  67. y = np.asarray(y)
  68. learning_curve_chart = calculate.learning_curve(
  69. model, X, y, cv, shuffle, random_state, train_sizes, n_jobs, scoring
  70. )
  71. wandb.log({"learning_curve": learning_curve_chart})