catboost.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. """catboost init."""
  2. from __future__ import annotations
  3. from pathlib import Path
  4. from types import SimpleNamespace
  5. from catboost import CatBoostClassifier, CatBoostRegressor # type: ignore
  6. import wandb
  7. from wandb.sdk.lib import telemetry as wb_telemetry
  8. class WandbCallback:
  9. """`WandbCallback` automatically integrates CatBoost with wandb.
  10. Args:
  11. - metric_period: (int) if you are passing `metric_period` to your CatBoost model please pass the same value here (default=1).
  12. Passing `WandbCallback` to CatBoost will:
  13. - log training and validation metrics at every `metric_period`
  14. - log iteration at every `metric_period`
  15. Example:
  16. ```
  17. train_pool = Pool(
  18. train[features], label=train["label"], cat_features=cat_features
  19. )
  20. test_pool = Pool(test[features], label=test["label"], cat_features=cat_features)
  21. model = CatBoostRegressor(
  22. iterations=100,
  23. loss_function="Cox",
  24. eval_metric="Cox",
  25. )
  26. model.fit(
  27. train_pool,
  28. eval_set=test_pool,
  29. callbacks=[WandbCallback()],
  30. )
  31. ```
  32. """
  33. def __init__(self, metric_period: int = 1):
  34. if wandb.run is None:
  35. raise wandb.Error("You must call `wandb.init()` before `WandbCallback()`")
  36. with wb_telemetry.context() as tel:
  37. tel.feature.catboost_wandb_callback = True
  38. self.metric_period: int = metric_period
  39. def after_iteration(self, info: SimpleNamespace) -> bool:
  40. if info.iteration % self.metric_period == 0:
  41. for data, metric in info.metrics.items():
  42. for metric_name, log in metric.items():
  43. # todo: replace with wandb.run._log once available
  44. wandb.log({f"{data}-{metric_name}": log[-1]}, commit=False)
  45. # todo: replace with wandb.run._log once available
  46. wandb.log({f"iteration@metric-period-{self.metric_period}": info.iteration})
  47. return True
  48. def _checkpoint_artifact(
  49. model: CatBoostClassifier | CatBoostRegressor, aliases: list[str]
  50. ) -> None:
  51. """Upload model checkpoint as W&B artifact."""
  52. if wandb.run is None:
  53. raise wandb.Error(
  54. "You must call `wandb.init()` before `_checkpoint_artifact()`"
  55. )
  56. model_name = f"model_{wandb.run.id}"
  57. # save the model in the default `cbm` format
  58. model_path = Path(wandb.run.dir) / "model"
  59. model.save_model(model_path)
  60. model_artifact = wandb.Artifact(name=model_name, type="model")
  61. model_artifact.add_file(str(model_path))
  62. wandb.log_artifact(model_artifact, aliases=aliases)
  63. def _log_feature_importance(
  64. model: CatBoostClassifier | CatBoostRegressor,
  65. ) -> None:
  66. """Log feature importance with default settings."""
  67. if wandb.run is None:
  68. raise wandb.Error(
  69. "You must call `wandb.init()` before `_checkpoint_artifact()`"
  70. )
  71. feat_df = model.get_feature_importance(prettified=True)
  72. fi_data = [
  73. [feat, feat_imp]
  74. for feat, feat_imp in zip(feat_df["Feature Id"], feat_df["Importances"])
  75. ]
  76. table = wandb.Table(data=fi_data, columns=["Feature", "Importance"])
  77. # todo: replace with wandb.run._log once available
  78. wandb.log(
  79. {
  80. "Feature Importance": wandb.plot.bar(
  81. table, "Feature", "Importance", title="Feature Importance"
  82. )
  83. },
  84. commit=False,
  85. )
  86. def log_summary(
  87. model: CatBoostClassifier | CatBoostRegressor,
  88. log_all_params: bool = True,
  89. save_model_checkpoint: bool = False,
  90. log_feature_importance: bool = True,
  91. ) -> None:
  92. """`log_summary` logs useful metrics about catboost model after training is done.
  93. Args:
  94. model: it can be CatBoostClassifier or CatBoostRegressor.
  95. log_all_params: (boolean) if True (default) log the model hyperparameters as W&B config.
  96. save_model_checkpoint: (boolean) if True saves the model upload as W&B artifacts.
  97. log_feature_importance: (boolean) if True (default) logs feature importance as W&B bar chart using the default setting of `get_feature_importance`.
  98. Using this along with `wandb_callback` will:
  99. - save the hyperparameters as W&B config,
  100. - log `best_iteration` and `best_score` as `wandb.summary`,
  101. - save and upload your trained model to Weights & Biases Artifacts (when `save_model_checkpoint = True`)
  102. - log feature importance plot.
  103. Example:
  104. ```python
  105. train_pool = Pool(
  106. train[features], label=train["label"], cat_features=cat_features
  107. )
  108. test_pool = Pool(test[features], label=test["label"], cat_features=cat_features)
  109. model = CatBoostRegressor(
  110. iterations=100,
  111. loss_function="Cox",
  112. eval_metric="Cox",
  113. )
  114. model.fit(
  115. train_pool,
  116. eval_set=test_pool,
  117. callbacks=[WandbCallback()],
  118. )
  119. log_summary(model)
  120. ```
  121. """
  122. if wandb.run is None:
  123. raise wandb.Error("You must call `wandb.init()` before `log_summary()`")
  124. if not (isinstance(model, (CatBoostClassifier, CatBoostRegressor))):
  125. raise wandb.Error(
  126. "Model should be an instance of CatBoostClassifier or CatBoostRegressor"
  127. )
  128. with wb_telemetry.context() as tel:
  129. tel.feature.catboost_log_summary = True
  130. # log configs
  131. params = model.get_all_params()
  132. if log_all_params:
  133. wandb.config.update(params)
  134. # log best score and iteration
  135. wandb.run.summary["best_iteration"] = model.get_best_iteration()
  136. wandb.run.summary["best_score"] = model.get_best_score()
  137. # log model
  138. if save_model_checkpoint:
  139. aliases = ["best"] if params["use_best_model"] else ["last"]
  140. _checkpoint_artifact(model, aliases=aliases)
  141. # Feature importance
  142. if log_feature_importance:
  143. _log_feature_importance(model)