| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183 |
- """catboost init."""
- from __future__ import annotations
- from pathlib import Path
- from types import SimpleNamespace
- from catboost import CatBoostClassifier, CatBoostRegressor # type: ignore
- import wandb
- from wandb.sdk.lib import telemetry as wb_telemetry
class WandbCallback:
    """`WandbCallback` automatically integrates CatBoost with wandb.

    Args:
        metric_period: (int) if you are passing `metric_period` to your CatBoost
            model please pass the same value here (default=1). Must be >= 1.

    Passing `WandbCallback` to CatBoost will:
    - log training and validation metrics at every `metric_period`
    - log iteration at every `metric_period`

    Raises:
        wandb.Error: if `wandb.init()` has not been called.
        ValueError: if `metric_period` is smaller than 1.

    Example:
        ```
        train_pool = Pool(
            train[features], label=train["label"], cat_features=cat_features
        )
        test_pool = Pool(test[features], label=test["label"], cat_features=cat_features)

        model = CatBoostRegressor(
            iterations=100,
            loss_function="Cox",
            eval_metric="Cox",
        )

        model.fit(
            train_pool,
            eval_set=test_pool,
            callbacks=[WandbCallback()],
        )
        ```
    """

    def __init__(self, metric_period: int = 1):
        if wandb.run is None:
            raise wandb.Error("You must call `wandb.init()` before `WandbCallback()`")

        # `after_iteration` computes `iteration % metric_period`; a period of 0
        # would raise ZeroDivisionError mid-training, so reject it up front.
        if metric_period < 1:
            raise ValueError(
                f"`metric_period` must be a positive integer, got {metric_period!r}"
            )

        with wb_telemetry.context() as tel:
            tel.feature.catboost_wandb_callback = True

        self.metric_period: int = metric_period

    def after_iteration(self, info: SimpleNamespace) -> bool:
        """Log the latest metric values every `metric_period` iterations.

        Args:
            info: CatBoost callback info. Only `info.iteration` and
                `info.metrics` are read; `info.metrics` is iterated as a mapping
                of dataset name -> {metric name -> sequence of values}, and the
                last value of each sequence is logged.

        Returns:
            Always True; CatBoost keeps training while the callback returns True.
        """
        if info.iteration % self.metric_period != 0:
            return True

        # Gather every dataset/metric pair plus the iteration counter into one
        # payload so the whole iteration is recorded in a single W&B step.
        payload = {
            f"{data}-{metric_name}": history[-1]
            for data, metrics in info.metrics.items()
            for metric_name, history in metrics.items()
        }
        payload[f"iteration@metric-period-{self.metric_period}"] = info.iteration
        # todo: replace with wandb.run._log once available
        wandb.log(payload)
        return True
def _checkpoint_artifact(
    model: CatBoostClassifier | CatBoostRegressor, aliases: list[str]
) -> None:
    """Save `model` into the run directory and log it as a W&B model artifact.

    The model is written to ``<run dir>/model`` and uploaded as an artifact
    named ``model_<run id>`` of type ``model``.

    Args:
        model: trained CatBoostClassifier or CatBoostRegressor.
        aliases: aliases attached to the logged artifact (e.g. ``["best"]``).

    Raises:
        wandb.Error: if no W&B run is active.
    """
    run = wandb.run
    if run is None:
        raise wandb.Error(
            "You must call `wandb.init()` before `_checkpoint_artifact()`"
        )

    checkpoint_file = Path(run.dir) / "model"
    # No `format` argument is passed, so CatBoost's default `cbm` format is used.
    model.save_model(checkpoint_file)

    artifact = wandb.Artifact(name=f"model_{run.id}", type="model")
    artifact.add_file(str(checkpoint_file))
    wandb.log_artifact(artifact, aliases=aliases)
def _log_feature_importance(
    model: CatBoostClassifier | CatBoostRegressor,
) -> None:
    """Log the model's feature importances as a W&B bar chart.

    Calls `model.get_feature_importance(prettified=True)` with otherwise default
    settings, reads its "Feature Id" and "Importances" columns, and logs a bar
    chart under the "Feature Importance" key without committing the step.

    Args:
        model: trained CatBoostClassifier or CatBoostRegressor.

    Raises:
        wandb.Error: if no W&B run is active.
    """
    if wandb.run is None:
        raise wandb.Error(
            "You must call `wandb.init()` before `_log_feature_importance()`"
        )

    feat_df = model.get_feature_importance(prettified=True)
    fi_data = [
        [feat, feat_imp]
        for feat, feat_imp in zip(feat_df["Feature Id"], feat_df["Importances"])
    ]
    table = wandb.Table(data=fi_data, columns=["Feature", "Importance"])
    # todo: replace with wandb.run._log once available
    wandb.log(
        {
            "Feature Importance": wandb.plot.bar(
                table, "Feature", "Importance", title="Feature Importance"
            )
        },
        # Leave the step open so this lands alongside the caller's other logs.
        commit=False,
    )
def log_summary(
    model: CatBoostClassifier | CatBoostRegressor,
    log_all_params: bool = True,
    save_model_checkpoint: bool = False,
    log_feature_importance: bool = True,
) -> None:
    """`log_summary` logs useful metrics about a CatBoost model after training is done.

    Args:
        model: it can be CatBoostClassifier or CatBoostRegressor.
        log_all_params: (boolean) if True (default) log the model hyperparameters as W&B config.
        save_model_checkpoint: (boolean) if True, save the model and upload it as a W&B artifact (default False).
        log_feature_importance: (boolean) if True (default) logs feature importance as W&B bar chart using the default setting of `get_feature_importance`.

    Using this along with `WandbCallback` will:
    - save the hyperparameters as W&B config (when `log_all_params=True`),
    - log `best_iteration` and `best_score` as `wandb.summary`,
    - save and upload your trained model to Weights & Biases Artifacts (when
      `save_model_checkpoint=True`), aliased "best" if `use_best_model` is set
      and "last" otherwise,
    - log feature importance plot (when `log_feature_importance=True`).

    Raises:
        wandb.Error: if `wandb.init()` has not been called, or if `model` is not
            a CatBoostClassifier or CatBoostRegressor.

    Example:
        ```python
        train_pool = Pool(
            train[features], label=train["label"], cat_features=cat_features
        )
        test_pool = Pool(test[features], label=test["label"], cat_features=cat_features)

        model = CatBoostRegressor(
            iterations=100,
            loss_function="Cox",
            eval_metric="Cox",
        )

        model.fit(
            train_pool,
            eval_set=test_pool,
            callbacks=[WandbCallback()],
        )

        log_summary(model)
        ```
    """
    if wandb.run is None:
        raise wandb.Error("You must call `wandb.init()` before `log_summary()`")

    if not isinstance(model, (CatBoostClassifier, CatBoostRegressor)):
        raise wandb.Error(
            "Model should be an instance of CatBoostClassifier or CatBoostRegressor"
        )

    with wb_telemetry.context() as tel:
        tel.feature.catboost_log_summary = True

    # log configs (params are also needed below for the checkpoint alias)
    params = model.get_all_params()
    if log_all_params:
        wandb.config.update(params)

    # log best score and iteration
    wandb.run.summary["best_iteration"] = model.get_best_iteration()
    wandb.run.summary["best_score"] = model.get_best_score()

    # log model
    if save_model_checkpoint:
        # Treat a missing `use_best_model` entry as "not used" rather than
        # failing with KeyError after config/summary have already been logged.
        aliases = ["best"] if params.get("use_best_model", False) else ["last"]
        _checkpoint_artifact(model, aliases=aliases)

    # Feature importance
    if log_feature_importance:
        _log_feature_importance(model)
|