regressor.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. """Define plots for regression models built with scikit-learn."""
  2. from warnings import simplefilter
  3. import numpy as np
  4. import wandb
  5. from wandb.integration.sklearn import calculate, utils
  6. from . import shared
  7. # ignore all future warnings
  8. simplefilter(action="ignore", category=FutureWarning)
  9. def regressor(model, X_train, X_test, y_train, y_test, model_name="Regressor"): # noqa: N803
  10. """Generates all sklearn regressor plots supported by W&B.
  11. The following plots are generated:
  12. learning curve, summary metrics, residuals plot, outlier candidates.
  13. Should only be called with a fitted regressor (otherwise an error is thrown).
  14. Args:
  15. model: (regressor) Takes in a fitted regressor.
  16. X_train: (arr) Training set features.
  17. y_train: (arr) Training set labels.
  18. X_test: (arr) Test set features.
  19. y_test: (arr) Test set labels.
  20. model_name: (str) Model name. Defaults to 'Regressor'
  21. Returns:
  22. None: To see plots, go to your W&B run page then expand the 'media' tab
  23. under 'auto visualizations'.
  24. Example:
  25. ```python
  26. wandb.sklearn.plot_regressor(reg, X_train, X_test, y_train, y_test, "Ridge")
  27. ```
  28. """
  29. wandb.termlog(f"\nPlotting {model_name}.")
  30. shared.summary_metrics(model, X_train, y_train, X_test, y_test)
  31. wandb.termlog("Logged summary metrics.")
  32. shared.learning_curve(model, X_train, y_train)
  33. wandb.termlog("Logged learning curve.")
  34. outlier_candidates(model, X_train, y_train)
  35. wandb.termlog("Logged outlier candidates.")
  36. residuals(model, X_train, y_train)
  37. wandb.termlog("Logged residuals.")
  38. def outlier_candidates(regressor=None, X=None, y=None): # noqa: N803
  39. """Measures a datapoint's influence on regression model via cook's distance.
  40. Instances with high influences could potentially be outliers.
  41. Should only be called with a fitted regressor (otherwise an error is thrown).
  42. Please note this function fits the model on the training set when called.
  43. Args:
  44. model: (regressor) Takes in a fitted regressor.
  45. X: (arr) Training set features.
  46. y: (arr) Training set labels.
  47. Returns:
  48. None: To see plots, go to your W&B run page then expand the 'media' tab
  49. under 'auto visualizations'.
  50. Example:
  51. ```python
  52. wandb.sklearn.plot_outlier_candidates(model, X, y)
  53. ```
  54. """
  55. is_missing = utils.test_missing(regressor=regressor, X=X, y=y)
  56. correct_types = utils.test_types(regressor=regressor, X=X, y=y)
  57. is_fitted = utils.test_fitted(regressor)
  58. if is_missing and correct_types and is_fitted:
  59. y = np.asarray(y)
  60. outliers_chart = calculate.outlier_candidates(regressor, X, y)
  61. wandb.log({"outlier_candidates": outliers_chart})
  62. def residuals(regressor=None, X=None, y=None): # noqa: N803
  63. """Measures and plots the regressor's predicted value against the residual.
  64. The marginal distribution of residuals is also calculated and plotted.
  65. Should only be called with a fitted regressor (otherwise an error is thrown).
  66. Please note this function fits variations of the model on the training set when called.
  67. Args:
  68. regressor: (regressor) Takes in a fitted regressor.
  69. X: (arr) Training set features.
  70. y: (arr) Training set labels.
  71. Returns:
  72. None: To see plots, go to your W&B run page then expand the 'media' tab
  73. under 'auto visualizations'.
  74. Example:
  75. ```python
  76. wandb.sklearn.plot_residuals(model, X, y)
  77. ```
  78. """
  79. not_missing = utils.test_missing(regressor=regressor, X=X, y=y)
  80. correct_types = utils.test_types(regressor=regressor, X=X, y=y)
  81. is_fitted = utils.test_fitted(regressor)
  82. if not_missing and correct_types and is_fitted:
  83. y = np.asarray(y)
  84. residuals_chart = calculate.residuals(regressor, X, y)
  85. wandb.log({"residuals": residuals_chart})