sb3.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. """W&B callback for sb3.
  2. A simple callback that logs Stable Baselines3 training runs to Weights & Biases.
  3. Example usage:
  4. ```python
  5. import gym
  6. from stable_baselines3 import PPO
  7. from stable_baselines3.common.monitor import Monitor
  8. from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
  9. import wandb
  10. from wandb.integration.sb3 import WandbCallback
  11. config = {
  12. "policy_type": "MlpPolicy",
  13. "total_timesteps": 25000,
  14. "env_name": "CartPole-v1",
  15. }
  16. run = wandb.init(
  17. project="sb3",
  18. config=config,
  19. sync_tensorboard=True, # auto-upload sb3's tensorboard metrics
  20. monitor_gym=True, # auto-upload the videos of agents playing the game
  21. save_code=True, # optional
  22. )
  23. def make_env():
  24. env = gym.make(config["env_name"])
  25. env = Monitor(env) # record stats such as returns
  26. return env
  27. env = DummyVecEnv([make_env])
  28. env = VecVideoRecorder(
  29. env, "videos", record_video_trigger=lambda x: x % 2000 == 0, video_length=200
  30. )
  31. model = PPO(config["policy_type"], env, verbose=1, tensorboard_log="runs")
  32. model.learn(
  33. total_timesteps=config["total_timesteps"],
  34. callback=WandbCallback(
  35. model_save_path=f"models/{run.id}",
  36. gradient_save_freq=100,
  37. log="all",
  38. ),
  39. )
  40. ```
  41. """
  42. from __future__ import annotations
  43. import logging
  44. import os
  45. from typing import Literal
  46. from stable_baselines3.common.callbacks import BaseCallback # type: ignore
  47. import wandb
  48. from wandb.sdk.lib import telemetry as wb_telemetry
  49. logger = logging.getLogger(__name__)
  50. class WandbCallback(BaseCallback):
  51. """Callback for logging experiments to Weights and Biases.
  52. Log SB3 experiments to Weights and Biases
  53. - Added model tracking and uploading
  54. - Added complete hyperparameters recording
  55. - Added gradient logging
  56. - Note that `wandb.init(...)` must be called before the WandbCallback can be used.
  57. Args:
  58. verbose: The verbosity of sb3 output
  59. model_save_path: Path to the folder where the model will be saved, The default value is `None` so the model is not logged
  60. model_save_freq: Frequency to save the model
  61. gradient_save_freq: Frequency to log gradient. The default value is 0 so the gradients are not logged
  62. log: What to log. One of "gradients", "parameters", or "all".
  63. """
  64. def __init__(
  65. self,
  66. verbose: int = 0,
  67. model_save_path: str | None = None,
  68. model_save_freq: int = 0,
  69. gradient_save_freq: int = 0,
  70. log: Literal["gradients", "parameters", "all"] | None = "all",
  71. ) -> None:
  72. super().__init__(verbose)
  73. if wandb.run is None:
  74. raise wandb.Error("You must call wandb.init() before WandbCallback()")
  75. with wb_telemetry.context() as tel:
  76. tel.feature.sb3 = True
  77. self.model_save_freq = model_save_freq
  78. self.model_save_path = model_save_path
  79. self.gradient_save_freq = gradient_save_freq
  80. if log not in ["gradients", "parameters", "all", None]:
  81. wandb.termwarn(
  82. "`log` must be one of `None`, 'gradients', 'parameters', or 'all', "
  83. "falling back to 'all'"
  84. )
  85. log = "all"
  86. self.log = log
  87. # Create folder if needed
  88. if self.model_save_path is not None:
  89. os.makedirs(self.model_save_path, exist_ok=True)
  90. self.path = os.path.join(self.model_save_path, "model.zip")
  91. else:
  92. assert self.model_save_freq == 0, (
  93. "to use the `model_save_freq` you have to set the `model_save_path` parameter"
  94. )
  95. def _init_callback(self) -> None:
  96. d = {}
  97. if "algo" not in d:
  98. d["algo"] = type(self.model).__name__
  99. for key in self.model.__dict__:
  100. if key in wandb.config:
  101. continue
  102. if type(self.model.__dict__[key]) in [float, int, str]:
  103. d[key] = self.model.__dict__[key]
  104. else:
  105. d[key] = str(self.model.__dict__[key])
  106. if self.gradient_save_freq > 0:
  107. wandb.watch(
  108. self.model.policy,
  109. log_freq=self.gradient_save_freq,
  110. log=self.log,
  111. )
  112. wandb.config.setdefaults(d)
  113. def _on_step(self) -> bool:
  114. if (
  115. self.model_save_freq > 0
  116. and self.model_save_path is not None
  117. and self.n_calls % self.model_save_freq == 0
  118. ):
  119. self.save_model()
  120. return True
  121. def _on_training_end(self) -> None:
  122. if self.model_save_path is not None:
  123. self.save_model()
  124. def save_model(self) -> None:
  125. self.model.save(self.path)
  126. wandb.save(self.path, base_path=self.model_save_path)
  127. if self.verbose > 1:
  128. logger.info(f"Saving model checkpoint to {self.path}")