"""W&B callback for LightGBM.

A lightweight callback that logs LightGBM evaluation metrics to Weights & Biases
after every boosting iteration.

Example usage:
    param_list = [
        ("eta", 0.08),
        ("max_depth", 6),
        ("subsample", 0.8),
        ("colsample_bytree", 0.8),
        ("alpha", 8),
        ("num_class", 10),
    ]
    params = dict(param_list)
    wandb.config.update(params)
    gbm = lgb.train(params, d_train, callbacks=[wandb_callback()])
"""

from pathlib import Path
from typing import TYPE_CHECKING, Callable

import lightgbm  # type: ignore
from lightgbm import Booster

import wandb
from wandb.sdk.lib import telemetry as wb_telemetry

MINIMIZE_METRICS = [
    "l1",
    "l2",
    "rmse",
    "mape",
    "huber",
    "fair",
    "poisson",
    "gamma",
    "binary_logloss",
]

MAXIMIZE_METRICS = ["map", "auc", "average_precision"]


if TYPE_CHECKING:
    from typing import Any, NamedTuple, Union

    # Note: upstream lightgbm has this defined incorrectly
    _EvalResultTuple = Union[
        tuple[str, str, float, bool], tuple[str, str, float, bool, float]
    ]

    class CallbackEnv(NamedTuple):
        model: Any
        params: dict
        iteration: int
        begin_iteration: int
        end_iteration: int
        evaluation_result_list: list[_EvalResultTuple]


def _define_metric(data: str, metric_name: str) -> None:
    """Capture model performance at the best step, instead of the last step.

    Registers a ``wandb.define_metric`` summary for ``f"{data}_{metric_name}"``
    so that ``wandb.summary`` keeps the best value seen during training:
    ``min`` for names containing "loss" or listed in ``MINIMIZE_METRICS``,
    ``max`` for names listed in ``MAXIMIZE_METRICS``. Unrecognized metrics are
    left untouched.

    For the list lookup the name is reduced to its base: a trailing ``-mean``
    (appended for cross-validation aggregates) and a numeric ``@k`` cutoff
    (e.g. ``map@5``) are ignored. The full, original name is still used as the
    W&B metric key.

    Args:
        data: Dataset name reported with the evaluation result.
        metric_name: Metric name reported with the evaluation result.
    """
    lowered = metric_name.lower()

    base = lowered[: -len("-mean")] if lowered.endswith("-mean") else lowered
    head, sep, cutoff = base.partition("@")
    if sep and cutoff.isdigit():
        # Ranking metrics are reported per cutoff (e.g. "map@5"); classify as "map".
        base = head

    key = f"{data}_{metric_name}"
    if "loss" in lowered or base in MINIMIZE_METRICS:
        wandb.define_metric(key, summary="min")
    elif base in MAXIMIZE_METRICS:
        wandb.define_metric(key, summary="max")


def _checkpoint_artifact(
    model: "Booster", iteration: int, aliases: "list[str]"
) -> None:
    """Save ``model`` into the run directory and upload it as a W&B model artifact.

    Args:
        model: Booster to checkpoint.
        iteration: Passed to ``Booster.save_model`` as ``num_iteration`` and used
            in the checkpoint file name.
        aliases: Aliases attached to the logged artifact version.
    """
    # NOTE: type ignore required because wandb.run is improperly inferred as None type
    model_name = f"model_{wandb.run.id}"  # type: ignore
    model_path = Path(wandb.run.dir) / f"model_ckpt_{iteration}.txt"  # type: ignore
    # Pass a plain str: older LightGBM releases do not accept pathlib.Path here.
    model.save_model(str(model_path), num_iteration=iteration)

    model_artifact = wandb.Artifact(name=model_name, type="model")
    model_artifact.add_file(str(model_path))
    wandb.log_artifact(model_artifact, aliases=aliases)


def _log_feature_importance(model: "Booster") -> None:
    """Log a "Feature Importance" bar chart built from ``model.feature_importance()``.

    Logged with ``commit=False`` so it does not create a W&B step of its own.
    """
    fi_data = [
        [feat, feat_imp]
        for feat, feat_imp in zip(model.feature_name(), model.feature_importance())
    ]
    table = wandb.Table(data=fi_data, columns=["Feature", "Importance"])
    wandb.log(
        {
            "Feature Importance": wandb.plot.bar(
                table, "Feature", "Importance", title="Feature Importance"
            )
        },
        commit=False,
    )


def _split_cv_eval_name(name: str) -> "tuple[str, str]":
    """Split an aggregated ``"<data_name> <eval_name>"`` name at its first space.

    Only the first whitespace run separates the two parts, so an eval name that
    itself contains spaces is kept intact instead of failing to unpack.
    """
    data_name, eval_name = name.split(maxsplit=1)
    return data_name, eval_name


class _WandbCallback:
    """Internal class to handle `wandb_callback` logic.

    This callback is adapted from LightGBM's `_RecordEvaluationCallback`.
    """

    def __init__(self, log_params: bool = True, define_metric: bool = True) -> None:
        # `order` / `before_iteration` are the callback-protocol attributes LightGBM
        # uses to schedule callbacks; this one runs after each iteration.
        self.order = 20
        self.before_iteration = False
        self.log_params = log_params
        self.define_metric_bool = define_metric

    def _init(self, env: "CallbackEnv") -> None:
        """One-time setup on the first iteration: telemetry, config, metric summaries."""
        with wb_telemetry.context() as tel:
            tel.feature.lightgbm_wandb_callback = True

        # log the params as W&B config.
        if self.log_params:
            wandb.config.update(env.params)

        # use `define_metric` to set the wandb summary to the best metric value.
        if not self.define_metric_bool:
            return
        for item in env.evaluation_result_list:
            if len(item) == 4:
                data_name, eval_name = item[:2]
                _define_metric(data_name, eval_name)
            else:
                # Aggregated result: item[1] holds "<data_name> <eval_name>".
                data_name, eval_name = _split_cv_eval_name(item[1])
                _define_metric(data_name, f"{eval_name}-mean")
                _define_metric(data_name, f"{eval_name}-stdv")

    def __call__(self, env: "CallbackEnv") -> None:
        """Log every evaluation result of this iteration as a single W&B step."""
        if env.iteration == env.begin_iteration:
            self._init(env)

        metrics: "dict[str, Any]" = {}
        for item in env.evaluation_result_list:
            if len(item) == 4:
                data_name, eval_name, result = item[:3]
                metrics[data_name + "_" + eval_name] = result
            else:
                # Aggregated result: item[2] is the mean, item[4] the stdv.
                data_name, eval_name = _split_cv_eval_name(item[1])
                metrics[data_name + "_" + eval_name + "-mean"] = item[2]
                metrics[data_name + "_" + eval_name + "-stdv"] = item[4]

        # A single committed log keeps all values of this iteration in one W&B step.
        metrics["iteration"] = env.iteration
        wandb.log(metrics, commit=True)


def wandb_callback(log_params: bool = True, define_metric: bool = True) -> Callable:
    """Automatically integrates LightGBM with wandb.

    Args:
        log_params: (boolean) if True (default) logs params passed to lightgbm.train as W&B config
        define_metric: (boolean) if True (default) capture model performance at the best step, instead of the last step, of training in your `wandb.summary`

    Passing `wandb_callback` to LightGBM will:
      - log params passed to lightgbm.train as W&B config (default).
      - log evaluation metrics collected by LightGBM, such as rmse, accuracy etc to Weights & Biases
      - Capture the best metric in `wandb.summary` when `define_metric=True` (default).

    Use `log_summary` as an extension of this callback.

    Example:
        ```python
        params = {
            "boosting_type": "gbdt",
            "objective": "regression",
        }
        gbm = lgb.train(
            params,
            lgb_train,
            num_boost_round=10,
            valid_sets=[lgb_eval],
            valid_names=["validation"],
            callbacks=[wandb_callback()],
        )
        ```
    """
    return _WandbCallback(log_params, define_metric)


def log_summary(
    model: Booster, feature_importance: bool = True, save_model_checkpoint: bool = False
) -> None:
    """Log useful metrics about lightgbm model after training is done.

    Args:
        model: (Booster) is an instance of lightgbm.basic.Booster.
        feature_importance: (boolean) if True (default), logs the feature importance plot.
        save_model_checkpoint: (boolean) if True, saves the best model and uploads it as a W&B artifact.

    Using this along with `wandb_callback` will:

    - log `best_iteration` and `best_score` as `wandb.summary`.
    - log feature importance plot.
    - save and upload your best trained model to Weights & Biases Artifacts (when `save_model_checkpoint = True`)

    Raises:
        wandb.Error: if no run is active or `model` is not a `Booster`.

    Example:
        ```python
        params = {
            "boosting_type": "gbdt",
            "objective": "regression",
        }
        gbm = lgb.train(
            params,
            lgb_train,
            num_boost_round=10,
            valid_sets=[lgb_eval],
            valid_names=["validation"],
            callbacks=[wandb_callback()],
        )

        log_summary(gbm)
        ```
    """
    if wandb.run is None:
        raise wandb.Error("You must call wandb.init() before log_summary()")

    if not isinstance(model, Booster):
        raise wandb.Error("Model should be an instance of lightgbm.basic.Booster")

    wandb.run.summary["best_iteration"] = model.best_iteration
    wandb.run.summary["best_score"] = model.best_score

    # Log feature importance
    if feature_importance:
        _log_feature_importance(model)

    if save_model_checkpoint:
        _checkpoint_artifact(model, model.best_iteration, aliases=["best"])

    with wb_telemetry.context() as tel:
        tel.feature.lightgbm_log_summary = True