summary_metrics.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. from warnings import simplefilter
  2. import numpy as np
  3. import sklearn
  4. import wandb
  5. from wandb.integration.sklearn import utils
  6. # ignore all future warnings
  7. simplefilter(action="ignore", category=FutureWarning)
  8. def summary_metrics(model=None, X=None, y=None, X_test=None, y_test=None): # noqa: N803
  9. """Calculate summary metrics for both regressors and classifiers.
  10. Called by plot_summary_metrics to visualize metrics. Please use the function
  11. plot_summary_metrics() if you wish to visualize your summary metrics.
  12. """
  13. y, y_test = np.asarray(y), np.asarray(y_test)
  14. metrics = {}
  15. model_name = model.__class__.__name__
  16. y_pred = model.predict(X_test)
  17. if sklearn.base.is_classifier(model):
  18. accuracy_score = sklearn.metrics.accuracy_score(y_test, y_pred)
  19. metrics["accuracy_score"] = accuracy_score
  20. precision = sklearn.metrics.precision_score(y_test, y_pred, average="weighted")
  21. metrics["precision"] = precision
  22. recall = sklearn.metrics.recall_score(y_test, y_pred, average="weighted")
  23. metrics["recall"] = recall
  24. f1_score = sklearn.metrics.f1_score(y_test, y_pred, average="weighted")
  25. metrics["f1_score"] = f1_score
  26. elif sklearn.base.is_regressor(model):
  27. mae = sklearn.metrics.mean_absolute_error(y_test, y_pred)
  28. metrics["mae"] = mae
  29. mse = sklearn.metrics.mean_squared_error(y_test, y_pred)
  30. metrics["mse"] = mse
  31. r2_score = sklearn.metrics.r2_score(y_test, y_pred)
  32. metrics["r2_score"] = r2_score
  33. metrics = {name: utils.round_2(metric) for name, metric in metrics.items()}
  34. table = make_table(metrics, model_name)
  35. chart = wandb.visualize("wandb/metrics/v1", table)
  36. return chart
  37. def make_table(metrics, model_name):
  38. columns = ["metric_name", "metric_value", "model_name"]
  39. table_content = [[name, value, model_name] for name, value in metrics.items()]
  40. table = wandb.Table(columns=columns, data=table_content)
  41. return table