cql.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. import logging
  2. from typing import Optional, Type, Union
  3. from typing_extensions import Self
  4. from ray._common.deprecation import (
  5. DEPRECATED_VALUE,
  6. deprecation_warning,
  7. )
  8. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
  9. from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
  10. from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
  11. from ray.rllib.algorithms.sac.sac import (
  12. SAC,
  13. SACConfig,
  14. )
  15. from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
  16. AddObservationsFromEpisodesToBatch,
  17. )
  18. from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
  19. AddNextObservationsFromEpisodesToTrainBatch,
  20. )
  21. from ray.rllib.core.learner.learner import Learner
  22. from ray.rllib.core.rl_module.rl_module import RLModuleSpec
  23. from ray.rllib.execution.rollout_ops import (
  24. synchronous_parallel_sample,
  25. )
  26. from ray.rllib.execution.train_ops import (
  27. multi_gpu_train_one_step,
  28. train_one_step,
  29. )
  30. from ray.rllib.policy.policy import Policy
  31. from ray.rllib.utils.annotations import OldAPIStack, override
  32. from ray.rllib.utils.framework import try_import_tf, try_import_tfp
  33. from ray.rllib.utils.metrics import (
  34. LAST_TARGET_UPDATE_TS,
  35. LEARNER_RESULTS,
  36. LEARNER_UPDATE_TIMER,
  37. NUM_AGENT_STEPS_SAMPLED,
  38. NUM_AGENT_STEPS_TRAINED,
  39. NUM_ENV_STEPS_SAMPLED,
  40. NUM_ENV_STEPS_TRAINED,
  41. NUM_TARGET_UPDATES,
  42. OFFLINE_SAMPLING_TIMER,
  43. SAMPLE_TIMER,
  44. SYNCH_WORKER_WEIGHTS_TIMER,
  45. TARGET_NET_UPDATE_TIMER,
  46. TIMERS,
  47. )
  48. from ray.rllib.utils.typing import ResultDict, RLModuleSpecType
  # Optional deep-learning backends. `CQLConfig.validate()` treats `tf` as
  # possibly falsy (it guards `tf.__version__`) and checks `tfp is None`, so
  # neither handle may be assumed to be a usable module at import time.
  49. tf1, tf, tfv = try_import_tf()
  50. tfp = try_import_tfp()
  # Module-level logger named after this module; used for the
  # `tensorflow_probability` warning emitted in `CQLConfig.validate()`.
  51. logger = logging.getLogger(__name__)
  52. class CQLConfig(SACConfig):
  53. """Defines a configuration class from which a CQL can be built.
  54. .. testcode::
  55. :skipif: True
  56. from ray.rllib.algorithms.cql import CQLConfig
  57. config = CQLConfig().training(gamma=0.9, lr=0.01)
  58. config = config.resources(num_gpus=0)
  59. config = config.env_runners(num_env_runners=4)
  60. print(config.to_dict())
  61. # Build a Algorithm object from the config and run 1 training iteration.
  62. algo = config.build(env="CartPole-v1")
  63. algo.train()
  64. """
  65. def __init__(self, algo_class=None):
  66. super().__init__(algo_class=algo_class or CQL)
  67. # fmt: off
  68. # __sphinx_doc_begin__
  69. # CQL-specific config settings:
  70. self.bc_iters = 20000
  71. self.temperature = 1.0
  72. self.num_actions = 10
  73. self.lagrangian = False
  74. self.lagrangian_thresh = 5.0
  75. self.min_q_weight = 5.0
  76. self.deterministic_backup = True
  77. self.lr = 3e-4
  78. # Note, the new stack defines learning rates for each component.
  79. # The base learning rate `lr` has to be set to `None`, if using
  80. # the new stack.
  81. self.actor_lr = 1e-4
  82. self.critic_lr = 1e-3
  83. self.alpha_lr = 1e-3
  84. self.replay_buffer_config = {
  85. "_enable_replay_buffer_api": True,
  86. "type": "MultiAgentPrioritizedReplayBuffer",
  87. "capacity": int(1e6),
  88. # If True prioritized replay buffer will be used.
  89. "prioritized_replay": False,
  90. "prioritized_replay_alpha": 0.6,
  91. "prioritized_replay_beta": 0.4,
  92. "prioritized_replay_eps": 1e-6,
  93. # Whether to compute priorities already on the remote worker side.
  94. "worker_side_prioritization": False,
  95. }
  96. # Changes to Algorithm's/SACConfig's default:
  97. # .reporting()
  98. self.min_sample_timesteps_per_iteration = 0
  99. self.min_train_timesteps_per_iteration = 100
  100. # fmt: on
  101. # __sphinx_doc_end__
  102. self.timesteps_per_iteration = DEPRECATED_VALUE
  103. @override(SACConfig)
  104. def training(
  105. self,
  106. *,
  107. bc_iters: Optional[int] = NotProvided,
  108. temperature: Optional[float] = NotProvided,
  109. num_actions: Optional[int] = NotProvided,
  110. lagrangian: Optional[bool] = NotProvided,
  111. lagrangian_thresh: Optional[float] = NotProvided,
  112. min_q_weight: Optional[float] = NotProvided,
  113. deterministic_backup: Optional[bool] = NotProvided,
  114. **kwargs,
  115. ) -> Self:
  116. """Sets the training-related configuration.
  117. Args:
  118. bc_iters: Number of iterations with Behavior Cloning pretraining.
  119. temperature: CQL loss temperature.
  120. num_actions: Number of actions to sample for CQL loss
  121. lagrangian: Whether to use the Lagrangian for Alpha Prime (in CQL loss).
  122. lagrangian_thresh: Lagrangian threshold.
  123. min_q_weight: in Q weight multiplier.
  124. deterministic_backup: If the target in the Bellman update should have an
  125. entropy backup. Defaults to `True`.
  126. Returns:
  127. This updated AlgorithmConfig object.
  128. """
  129. # Pass kwargs onto super's `training()` method.
  130. super().training(**kwargs)
  131. if bc_iters is not NotProvided:
  132. self.bc_iters = bc_iters
  133. if temperature is not NotProvided:
  134. self.temperature = temperature
  135. if num_actions is not NotProvided:
  136. self.num_actions = num_actions
  137. if lagrangian is not NotProvided:
  138. self.lagrangian = lagrangian
  139. if lagrangian_thresh is not NotProvided:
  140. self.lagrangian_thresh = lagrangian_thresh
  141. if min_q_weight is not NotProvided:
  142. self.min_q_weight = min_q_weight
  143. if deterministic_backup is not NotProvided:
  144. self.deterministic_backup = deterministic_backup
  145. return self
  146. @override(AlgorithmConfig)
  147. def offline_data(self, **kwargs) -> Self:
  148. super().offline_data(**kwargs)
  149. # Check, if the passed in class incorporates the `OfflinePreLearner`
  150. # interface.
  151. if "prelearner_class" in kwargs:
  152. from ray.rllib.offline.offline_data import OfflinePreLearner
  153. if not issubclass(kwargs.get("prelearner_class"), OfflinePreLearner):
  154. raise ValueError(
  155. f"`prelearner_class` {kwargs.get('prelearner_class')} is not a "
  156. "subclass of `OfflinePreLearner`. Any class passed to "
  157. "`prelearner_class` needs to implement the interface given by "
  158. "`OfflinePreLearner`."
  159. )
  160. return self
  161. @override(SACConfig)
  162. def get_default_learner_class(self) -> Union[Type["Learner"], str]:
  163. if self.framework_str == "torch":
  164. from ray.rllib.algorithms.cql.torch.cql_torch_learner import CQLTorchLearner
  165. return CQLTorchLearner
  166. else:
  167. raise ValueError(
  168. f"The framework {self.framework_str} is not supported. "
  169. "Use `'torch'` instead."
  170. )
  171. @override(AlgorithmConfig)
  172. def build_learner_connector(
  173. self,
  174. input_observation_space,
  175. input_action_space,
  176. device=None,
  177. ):
  178. pipeline = super().build_learner_connector(
  179. input_observation_space=input_observation_space,
  180. input_action_space=input_action_space,
  181. device=device,
  182. )
  183. # Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right
  184. # after the corresponding "add-OBS-..." default piece).
  185. pipeline.insert_after(
  186. AddObservationsFromEpisodesToBatch,
  187. AddNextObservationsFromEpisodesToTrainBatch(),
  188. )
  189. return pipeline
  190. @override(SACConfig)
  191. def validate(self) -> None:
  192. # First check, whether old `timesteps_per_iteration` is used.
  193. if self.timesteps_per_iteration != DEPRECATED_VALUE:
  194. deprecation_warning(
  195. old="timesteps_per_iteration",
  196. new="min_train_timesteps_per_iteration",
  197. error=True,
  198. )
  199. # Call super's validation method.
  200. super().validate()
  201. # CQL-torch performs the optimizer steps inside the loss function.
  202. # Using the multi-GPU optimizer will therefore not work (see multi-GPU
  203. # check above) and we must use the simple optimizer for now.
  204. if self.simple_optimizer is not True and self.framework_str == "torch":
  205. self.simple_optimizer = True
  206. if self.framework_str in ["tf", "tf2"] and tfp is None:
  207. logger.warning(
  208. "You need `tensorflow_probability` in order to run CQL! "
  209. "Install it via `pip install tensorflow_probability`. Your "
  210. f"tf.__version__={tf.__version__ if tf else None}."
  211. "Trying to import tfp results in the following error:"
  212. )
  213. try_import_tfp(error=True)
  214. # Assert that for a local learner the number of iterations is 1. Note,
  215. # this is needed because we have no iterators, but instead a single
  216. # batch returned directly from the `OfflineData.sample` method.
  217. if (
  218. self.num_learners == 0
  219. and not self.dataset_num_iters_per_learner
  220. and self.enable_rl_module_and_learner
  221. ):
  222. self._value_error(
  223. "When using a single local learner the number of iterations "
  224. "per learner, `dataset_num_iters_per_learner` has to be defined. "
  225. "Set this hyperparameter in the `AlgorithmConfig.offline_data`."
  226. )
  227. @override(SACConfig)
  228. def get_default_rl_module_spec(self) -> RLModuleSpecType:
  229. if self.framework_str == "torch":
  230. from ray.rllib.algorithms.cql.torch.default_cql_torch_rl_module import (
  231. DefaultCQLTorchRLModule,
  232. )
  233. return RLModuleSpec(module_class=DefaultCQLTorchRLModule)
  234. else:
  235. raise ValueError(
  236. f"The framework {self.framework_str} is not supported. Use `torch`."
  237. )
  238. @property
  239. def _model_config_auto_includes(self):
  240. return super()._model_config_auto_includes | {
  241. "num_actions": self.num_actions,
  242. }
  243. class CQL(SAC):
  244. """CQL (derived from SAC)."""
  245. @classmethod
  246. @override(SAC)
  247. def get_default_config(cls) -> CQLConfig:
  248. return CQLConfig()
  249. @classmethod
  250. @override(SAC)
  251. def get_default_policy_class(
  252. cls, config: AlgorithmConfig
  253. ) -> Optional[Type[Policy]]:
  254. if config["framework"] == "torch":
  255. return CQLTorchPolicy
  256. else:
  257. return CQLTFPolicy
  258. @override(SAC)
  259. def training_step(self) -> None:
  260. # Old API stack (Policy, RolloutWorker, Connector).
  261. if not self.config.enable_env_runner_and_connector_v2:
  262. return self._training_step_old_api_stack()
  263. # Sampling from offline data.
  264. with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)):
  265. # If we should use an iterator in the learner(s). Note, in case of
  266. # multiple learners we must always return a list of iterators.
  267. return_iterator = return_iterator = (
  268. self.config.num_learners > 0
  269. or self.config.dataset_num_iters_per_learner != 1
  270. )
  271. # Return an iterator in case we are using remote learners.
  272. batch_or_iterator = self.offline_data.sample(
  273. num_samples=self.config.train_batch_size_per_learner,
  274. num_shards=self.config.num_learners,
  275. # Return an iterator, if a `Learner` should update
  276. # multiple times per RLlib iteration.
  277. return_iterator=return_iterator,
  278. )
  279. # Updating the policy.
  280. with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
  281. learner_results = self.learner_group.update(
  282. data_iterators=batch_or_iterator,
  283. minibatch_size=self.config.train_batch_size_per_learner,
  284. num_iters=self.config.dataset_num_iters_per_learner,
  285. )
  286. # Log training results.
  287. self.metrics.aggregate(learner_results, key=LEARNER_RESULTS)
  288. @OldAPIStack
  289. def _training_step_old_api_stack(self) -> ResultDict:
  290. # Collect SampleBatches from sample workers.
  291. with self._timers[SAMPLE_TIMER]:
  292. train_batch = synchronous_parallel_sample(worker_set=self.env_runner_group)
  293. train_batch = train_batch.as_multi_agent()
  294. self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
  295. self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
  296. # Postprocess batch before we learn on it.
  297. post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
  298. train_batch = post_fn(train_batch, self.env_runner_group, self.config)
  299. # Learn on training batch.
  300. # Use simple optimizer (only for multi-agent or tf-eager; all other
  301. # cases should use the multi-GPU optimizer, even if only using 1 GPU)
  302. if self.config.get("simple_optimizer") is True:
  303. train_results = train_one_step(self, train_batch)
  304. else:
  305. train_results = multi_gpu_train_one_step(self, train_batch)
  306. # Update target network every `target_network_update_freq` training steps.
  307. cur_ts = self._counters[
  308. NUM_AGENT_STEPS_TRAINED
  309. if self.config.count_steps_by == "agent_steps"
  310. else NUM_ENV_STEPS_TRAINED
  311. ]
  312. last_update = self._counters[LAST_TARGET_UPDATE_TS]
  313. if cur_ts - last_update >= self.config.target_network_update_freq:
  314. with self._timers[TARGET_NET_UPDATE_TIMER]:
  315. to_update = self.env_runner.get_policies_to_train()
  316. self.env_runner.foreach_policy_to_train(
  317. lambda p, pid: pid in to_update and p.update_target()
  318. )
  319. self._counters[NUM_TARGET_UPDATES] += 1
  320. self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
  321. # Update remote workers's weights after learning on local worker
  322. # (only those policies that were actually trained).
  323. if self.env_runner_group.num_remote_workers() > 0:
  324. with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
  325. self.env_runner_group.sync_weights(policies=list(train_results.keys()))
  326. # Return all collected metrics for the iteration.
  327. return train_results