__init__.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
"""W&B callback for LightGBM.

A simple callback that logs LightGBM's evaluation results to W&B after each boosting iteration.

Example usage:
    param_list = [("eta", 0.08), ("max_depth", 6), ("subsample", 0.8), ("colsample_bytree", 0.8), ("alpha", 8), ("num_class", 10)]
    wandb.config.update(dict(param_list))
    model = lightgbm.train(dict(param_list), d_train, callbacks=[wandb_callback()])
"""
  8. from pathlib import Path
  9. from typing import TYPE_CHECKING, Callable
  10. import lightgbm # type: ignore
  11. from lightgbm import Booster
  12. import wandb
  13. from wandb.sdk.lib import telemetry as wb_telemetry
  14. MINIMIZE_METRICS = [
  15. "l1",
  16. "l2",
  17. "rmse",
  18. "mape",
  19. "huber",
  20. "fair",
  21. "poisson",
  22. "gamma",
  23. "binary_logloss",
  24. ]
  25. MAXIMIZE_METRICS = ["map", "auc", "average_precision"]
  26. if TYPE_CHECKING:
  27. from typing import Any, NamedTuple, Union
  28. # Note: upstream lightgbm has this defined incorrectly
  29. _EvalResultTuple = Union[
  30. tuple[str, str, float, bool], tuple[str, str, float, bool, float]
  31. ]
  32. class CallbackEnv(NamedTuple):
  33. model: Any
  34. params: dict
  35. iteration: int
  36. begin_interation: int
  37. end_iteration: int
  38. evaluation_result_list: list[_EvalResultTuple]
  39. def _define_metric(data: str, metric_name: str) -> None:
  40. """Capture model performance at the best step.
  41. instead of the last step, of training in your `wandb.summary`
  42. """
  43. if "loss" in str.lower(metric_name):
  44. wandb.define_metric(f"{data}_{metric_name}", summary="min")
  45. elif str.lower(metric_name) in MINIMIZE_METRICS:
  46. wandb.define_metric(f"{data}_{metric_name}", summary="min")
  47. elif str.lower(metric_name) in MAXIMIZE_METRICS:
  48. wandb.define_metric(f"{data}_{metric_name}", summary="max")
  49. def _checkpoint_artifact(
  50. model: "Booster", iteration: int, aliases: "list[str]"
  51. ) -> None:
  52. """Upload model checkpoint as W&B artifact."""
  53. # NOTE: type ignore required because wandb.run is improperly inferred as None type
  54. model_name = f"model_{wandb.run.id}" # type: ignore
  55. model_path = Path(wandb.run.dir) / f"model_ckpt_{iteration}.txt" # type: ignore
  56. model.save_model(model_path, num_iteration=iteration)
  57. model_artifact = wandb.Artifact(name=model_name, type="model")
  58. model_artifact.add_file(str(model_path))
  59. wandb.log_artifact(model_artifact, aliases=aliases)
  60. def _log_feature_importance(model: "Booster") -> None:
  61. """Log feature importance."""
  62. feat_imps = model.feature_importance()
  63. feats = model.feature_name()
  64. fi_data = [[feat, feat_imp] for feat, feat_imp in zip(feats, feat_imps)]
  65. table = wandb.Table(data=fi_data, columns=["Feature", "Importance"])
  66. wandb.log(
  67. {
  68. "Feature Importance": wandb.plot.bar(
  69. table, "Feature", "Importance", title="Feature Importance"
  70. )
  71. },
  72. commit=False,
  73. )
  74. class _WandbCallback:
  75. """Internal class to handle `wandb_callback` logic.
  76. This callback is adapted form the LightGBM's `_RecordEvaluationCallback`.
  77. """
  78. def __init__(self, log_params: bool = True, define_metric: bool = True) -> None:
  79. self.order = 20
  80. self.before_iteration = False
  81. self.log_params = log_params
  82. self.define_metric_bool = define_metric
  83. def _init(self, env: "CallbackEnv") -> None:
  84. with wb_telemetry.context() as tel:
  85. tel.feature.lightgbm_wandb_callback = True
  86. # log the params as W&B config.
  87. if self.log_params:
  88. wandb.config.update(env.params)
  89. # use `define_metric` to set the wandb summary to the best metric value.
  90. for item in env.evaluation_result_list:
  91. if self.define_metric_bool:
  92. if len(item) == 4:
  93. data_name, eval_name = item[:2]
  94. _define_metric(data_name, eval_name)
  95. else:
  96. data_name, eval_name = item[1].split()
  97. _define_metric(data_name, f"{eval_name}-mean")
  98. _define_metric(data_name, f"{eval_name}-stdv")
  99. def __call__(self, env: "CallbackEnv") -> None:
  100. if env.iteration == env.begin_iteration: # type: ignore
  101. self._init(env)
  102. for item in env.evaluation_result_list:
  103. if len(item) == 4:
  104. data_name, eval_name, result = item[:3]
  105. wandb.log(
  106. {data_name + "_" + eval_name: result},
  107. commit=False,
  108. )
  109. else:
  110. data_name, eval_name = item[1].split()
  111. res_mean = item[2]
  112. res_stdv = item[4]
  113. wandb.log(
  114. {
  115. data_name + "_" + eval_name + "-mean": res_mean,
  116. data_name + "_" + eval_name + "-stdv": res_stdv,
  117. },
  118. commit=False,
  119. )
  120. # call `commit=True` to log the data as a single W&B step.
  121. wandb.log({"iteration": env.iteration}, commit=True)
  122. def wandb_callback(log_params: bool = True, define_metric: bool = True) -> Callable:
  123. """Automatically integrates LightGBM with wandb.
  124. Args:
  125. log_params: (boolean) if True (default) logs params passed to lightgbm.train as W&B config
  126. define_metric: (boolean) if True (default) capture model performance at the best step, instead of the last step, of training in your `wandb.summary`
  127. Passing `wandb_callback` to LightGBM will:
  128. - log params passed to lightgbm.train as W&B config (default).
  129. - log evaluation metrics collected by LightGBM, such as rmse, accuracy etc to Weights & Biases
  130. - Capture the best metric in `wandb.summary` when `define_metric=True` (default).
  131. Use `log_summary` as an extension of this callback.
  132. Example:
  133. ```python
  134. params = {
  135. "boosting_type": "gbdt",
  136. "objective": "regression",
  137. }
  138. gbm = lgb.train(
  139. params,
  140. lgb_train,
  141. num_boost_round=10,
  142. valid_sets=lgb_eval,
  143. valid_names=("validation"),
  144. callbacks=[wandb_callback()],
  145. )
  146. ```
  147. """
  148. return _WandbCallback(log_params, define_metric)
  149. def log_summary(
  150. model: Booster, feature_importance: bool = True, save_model_checkpoint: bool = False
  151. ) -> None:
  152. """Log useful metrics about lightgbm model after training is done.
  153. Args:
  154. model: (Booster) is an instance of lightgbm.basic.Booster.
  155. feature_importance: (boolean) if True (default), logs the feature importance plot.
  156. save_model_checkpoint: (boolean) if True saves the best model and upload as W&B artifacts.
  157. Using this along with `wandb_callback` will:
  158. - log `best_iteration` and `best_score` as `wandb.summary`.
  159. - log feature importance plot.
  160. - save and upload your best trained model to Weights & Biases Artifacts (when `save_model_checkpoint = True`)
  161. Example:
  162. ```python
  163. params = {
  164. "boosting_type": "gbdt",
  165. "objective": "regression",
  166. }
  167. gbm = lgb.train(
  168. params,
  169. lgb_train,
  170. num_boost_round=10,
  171. valid_sets=lgb_eval,
  172. valid_names=("validation"),
  173. callbacks=[wandb_callback()],
  174. )
  175. log_summary(gbm)
  176. ```
  177. """
  178. if wandb.run is None:
  179. raise wandb.Error("You must call wandb.init() before WandbCallback()")
  180. if not isinstance(model, Booster):
  181. raise wandb.Error("Model should be an instance of lightgbm.basic.Booster")
  182. wandb.run.summary["best_iteration"] = model.best_iteration
  183. wandb.run.summary["best_score"] = model.best_score
  184. # Log feature importance
  185. if feature_importance:
  186. _log_feature_importance(model)
  187. if save_model_checkpoint:
  188. _checkpoint_artifact(model, model.best_iteration, aliases=["best"])
  189. with wb_telemetry.context() as tel:
  190. tel.feature.lightgbm_log_summary = True