xgboost.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. """xgboost init!"""
  2. from __future__ import annotations
  3. import json
  4. import warnings
  5. from pathlib import Path
  6. from typing import TYPE_CHECKING, Callable, NamedTuple, Union, cast
  7. import xgboost as xgb
  8. import xgboost.callback
  9. from typing_extensions import TypeAlias, override
  10. import wandb
  11. from wandb.sdk.lib import telemetry as wb_telemetry
  12. MINIMIZE_METRICS = [
  13. "rmse",
  14. "rmsle",
  15. "mae",
  16. "mape",
  17. "mphe",
  18. "logloss",
  19. "error",
  20. "error@t",
  21. "merror",
  22. ]
  23. MAXIMIZE_METRICS = ["auc", "aucpr", "ndcg", "map", "ndcg@n", "map@n"]
  24. if TYPE_CHECKING:
  25. class CallbackEnv(NamedTuple):
  26. evaluation_result_list: list
  27. # Copied from xgboost's source code. These types are not exported.
  28. _ScoreList = Union[list[float], list[tuple[float, float]]]
  29. _EvalsLog: TypeAlias = dict[str, dict[str, _ScoreList]]
  30. def wandb_callback() -> Callable:
  31. """Old style callback that will be deprecated in favor of WandbCallback. Please try the new logger for more features."""
  32. warnings.warn(
  33. "wandb_callback will be deprecated in favor of WandbCallback. Please use WandbCallback for more features.",
  34. UserWarning,
  35. stacklevel=2,
  36. )
  37. with wb_telemetry.context() as tel:
  38. tel.feature.xgboost_old_wandb_callback = True
  39. def callback(env: CallbackEnv) -> None:
  40. for k, v in env.evaluation_result_list:
  41. wandb.log({k: v}, commit=False)
  42. wandb.log({})
  43. return callback
  44. class WandbCallback(xgboost.callback.TrainingCallback):
  45. """`WandbCallback` automatically integrates XGBoost with wandb.
  46. Args:
  47. log_model: (boolean) if True save and upload the model to Weights & Biases Artifacts
  48. log_feature_importance: (boolean) if True log a feature importance bar plot
  49. importance_type: (str) one of {weight, gain, cover, total_gain, total_cover} for tree model. weight for linear model.
  50. define_metric: (boolean) if True (default) capture model performance at the best step, instead of the last step, of training in your `wandb.summary`.
  51. Passing `WandbCallback` to XGBoost will:
  52. - log the booster model configuration to Weights & Biases
  53. - log evaluation metrics collected by XGBoost, such as rmse, accuracy etc. to Weights & Biases
  54. - log training metric collected by XGBoost (if you provide training data to eval_set)
  55. - log the best score and the best iteration
  56. - save and upload your trained model to Weights & Biases Artifacts (when `log_model = True`)
  57. - log feature importance plot when `log_feature_importance=True` (default).
  58. - Capture the best eval metric in `wandb.summary` when `define_metric=True` (default).
  59. Example:
  60. ```python
  61. bst_params = dict(
  62. objective="reg:squarederror",
  63. colsample_bytree=0.3,
  64. learning_rate=0.1,
  65. max_depth=5,
  66. alpha=10,
  67. n_estimators=10,
  68. tree_method="hist",
  69. callbacks=[WandbCallback()],
  70. )
  71. xg_reg = xgb.XGBRegressor(**bst_params)
  72. xg_reg.fit(
  73. X_train,
  74. y_train,
  75. eval_set=[(X_test, y_test)],
  76. )
  77. ```
  78. """
  79. def __init__(
  80. self,
  81. log_model: bool = False,
  82. log_feature_importance: bool = True,
  83. importance_type: str = "gain",
  84. define_metric: bool = True,
  85. ):
  86. super().__init__()
  87. self.log_model: bool = log_model
  88. self.log_feature_importance: bool = log_feature_importance
  89. self.importance_type: str = importance_type
  90. self.define_metric: bool = define_metric
  91. if wandb.run is None:
  92. raise wandb.Error("You must call wandb.init() before WandbCallback()")
  93. with wb_telemetry.context() as tel:
  94. tel.feature.xgboost_wandb_callback = True
  95. @override
  96. def before_training(self, model: xgb.Booster) -> xgb.Booster:
  97. """Run before training is finished."""
  98. # Update W&B config
  99. config = model.save_config()
  100. wandb.config.update(json.loads(config))
  101. return model
  102. @override
  103. def after_training(self, model: xgb.Booster) -> xgb.Booster:
  104. """Run after training is finished."""
  105. # Log the booster model as artifacts
  106. if self.log_model:
  107. self._log_model_as_artifact(model)
  108. # Plot feature importance
  109. if self.log_feature_importance:
  110. self._log_feature_importance(model)
  111. # Log the best score and best iteration
  112. if model.attr("best_score") is not None:
  113. wandb.log(
  114. {
  115. "best_score": float(cast(str, model.attr("best_score"))),
  116. "best_iteration": int(cast(str, model.attr("best_iteration"))),
  117. }
  118. )
  119. return model
  120. @override
  121. def after_iteration(
  122. self,
  123. model: xgb.Booster,
  124. epoch: int,
  125. evals_log: _EvalsLog,
  126. ) -> bool:
  127. """Run after each iteration. Return True when training should stop."""
  128. # Log metrics
  129. for data, metric in evals_log.items():
  130. for metric_name, log in metric.items():
  131. if self.define_metric:
  132. self._define_metric(data, metric_name)
  133. wandb.log({f"{data}-{metric_name}": log[-1]}, commit=False)
  134. else:
  135. wandb.log({f"{data}-{metric_name}": log[-1]}, commit=False)
  136. wandb.log({"epoch": epoch})
  137. self.define_metric = False
  138. return False
  139. def _log_model_as_artifact(self, model: xgb.Booster) -> None:
  140. model_name = f"{wandb.run.id}_model.json" # type: ignore
  141. model_path = Path(wandb.run.dir) / model_name # type: ignore
  142. model.save_model(str(model_path))
  143. model_artifact = wandb.Artifact(name=model_name, type="model")
  144. model_artifact.add_file(str(model_path))
  145. wandb.log_artifact(model_artifact)
  146. def _log_feature_importance(self, model: xgb.Booster) -> None:
  147. fi = model.get_score(importance_type=self.importance_type)
  148. fi_data = [[k, fi[k]] for k in fi]
  149. table = wandb.Table(data=fi_data, columns=["Feature", "Importance"])
  150. wandb.log(
  151. {
  152. "Feature Importance": wandb.plot.bar(
  153. table, "Feature", "Importance", title="Feature Importance"
  154. )
  155. }
  156. )
  157. def _define_metric(self, data: str, metric_name: str) -> None:
  158. if "loss" in str.lower(metric_name):
  159. wandb.define_metric(f"{data}-{metric_name}", summary="min")
  160. elif str.lower(metric_name) in MINIMIZE_METRICS:
  161. wandb.define_metric(f"{data}-{metric_name}", summary="min")
  162. elif str.lower(metric_name) in MAXIMIZE_METRICS:
  163. wandb.define_metric(f"{data}-{metric_name}", summary="max")
  164. else:
  165. pass