sac.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591
  1. import logging
  2. from typing import Any, Dict, Optional, Tuple, Type, Union
  3. from typing_extensions import Self
  4. from ray._common.deprecation import DEPRECATED_VALUE, deprecation_warning
  5. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
  6. from ray.rllib.algorithms.dqn.dqn import DQN
  7. from ray.rllib.algorithms.sac.sac_tf_policy import SACTFPolicy
  8. from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
  9. AddObservationsFromEpisodesToBatch,
  10. )
  11. from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
  12. AddNextObservationsFromEpisodesToTrainBatch,
  13. )
  14. from ray.rllib.core.learner import Learner
  15. from ray.rllib.core.rl_module.rl_module import RLModuleSpec
  16. from ray.rllib.policy.policy import Policy
  17. from ray.rllib.utils import deep_update
  18. from ray.rllib.utils.annotations import override
  19. from ray.rllib.utils.framework import try_import_tf, try_import_tfp
  20. from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer
  21. from ray.rllib.utils.typing import LearningRateOrSchedule, RLModuleSpecType
  22. tf1, tf, tfv = try_import_tf()
  23. tfp = try_import_tfp()
  24. logger = logging.getLogger(__name__)
  25. class SACConfig(AlgorithmConfig):
  26. """Defines a configuration class from which an SAC Algorithm can be built.
  27. .. testcode::
  28. config = (
  29. SACConfig()
  30. .environment("Pendulum-v1")
  31. .env_runners(num_env_runners=1)
  32. .training(
  33. gamma=0.9,
  34. actor_lr=0.001,
  35. critic_lr=0.002,
  36. train_batch_size_per_learner=32,
  37. )
  38. )
  39. # Build the SAC algo object from the config and run 1 training iteration.
  40. algo = config.build()
  41. algo.train()
  42. """
  43. def __init__(self, algo_class=None):
  44. self.exploration_config = {
  45. # The Exploration class to use. In the simplest case, this is the name
  46. # (str) of any class present in the `rllib.utils.exploration` package.
  47. # You can also provide the python class directly or the full location
  48. # of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy.
  49. # EpsilonGreedy").
  50. "type": "StochasticSampling",
  51. # Add constructor kwargs here (if any).
  52. }
  53. super().__init__(algo_class=algo_class or SAC)
  54. # fmt: off
  55. # __sphinx_doc_begin__
  56. # SAC-specific config settings.
  57. # `.training()`
  58. self.twin_q = True
  59. self.q_model_config = {
  60. "fcnet_hiddens": [256, 256],
  61. "fcnet_activation": "relu",
  62. "post_fcnet_hiddens": [],
  63. "post_fcnet_activation": None,
  64. "custom_model": None, # Use this to define custom Q-model(s).
  65. "custom_model_config": {},
  66. }
  67. self.policy_model_config = {
  68. "fcnet_hiddens": [256, 256],
  69. "fcnet_activation": "relu",
  70. "post_fcnet_hiddens": [],
  71. "post_fcnet_activation": None,
  72. "custom_model": None, # Use this to define a custom policy model.
  73. "custom_model_config": {},
  74. }
  75. self.clip_actions = False
  76. self.tau = 5e-3
  77. self.initial_alpha = 1.0
  78. self.target_entropy = "auto"
  79. self.n_step = 1
  80. # Replay buffer configuration.
  81. self.replay_buffer_config = {
  82. "type": "PrioritizedEpisodeReplayBuffer",
  83. # Size of the replay buffer. Note that if async_updates is set,
  84. # then each worker will have a replay buffer of this size.
  85. "capacity": int(1e6),
  86. "alpha": 0.6,
  87. # Beta parameter for sampling from prioritized replay buffer.
  88. "beta": 0.4,
  89. }
  90. self.store_buffer_in_checkpoints = False
  91. self.training_intensity = None
  92. self.optimization = {
  93. "actor_learning_rate": 3e-4,
  94. "critic_learning_rate": 3e-4,
  95. "entropy_learning_rate": 3e-4,
  96. }
  97. self.actor_lr = 3e-5
  98. self.critic_lr = 3e-4
  99. self.alpha_lr = 3e-4
  100. # Set `lr` parameter to `None` and ensure it is not used.
  101. self.lr = None
  102. self.grad_clip = None
  103. self.target_network_update_freq = 0
  104. # .env_runners()
  105. # Set to `self.n_step`, if 'auto'.
  106. self.rollout_fragment_length = "auto"
  107. # .training()
  108. self.train_batch_size_per_learner = 256
  109. self.train_batch_size = 256 # @OldAPIstack
  110. self.num_steps_sampled_before_learning_starts = 1500
  111. # .reporting()
  112. self.min_time_s_per_iteration = 1
  113. self.min_sample_timesteps_per_iteration = 100
  114. # __sphinx_doc_end__
  115. # fmt: on
  116. self._deterministic_loss = False
  117. self._use_beta_distribution = False
  118. self.use_state_preprocessor = DEPRECATED_VALUE
  119. self.worker_side_prioritization = DEPRECATED_VALUE
  120. @override(AlgorithmConfig)
  121. def training(
  122. self,
  123. *,
  124. twin_q: Optional[bool] = NotProvided,
  125. q_model_config: Optional[Dict[str, Any]] = NotProvided,
  126. policy_model_config: Optional[Dict[str, Any]] = NotProvided,
  127. tau: Optional[float] = NotProvided,
  128. initial_alpha: Optional[float] = NotProvided,
  129. target_entropy: Optional[Union[str, float]] = NotProvided,
  130. n_step: Optional[Union[int, Tuple[int, int]]] = NotProvided,
  131. store_buffer_in_checkpoints: Optional[bool] = NotProvided,
  132. replay_buffer_config: Optional[Dict[str, Any]] = NotProvided,
  133. training_intensity: Optional[float] = NotProvided,
  134. clip_actions: Optional[bool] = NotProvided,
  135. grad_clip: Optional[float] = NotProvided,
  136. optimization_config: Optional[Dict[str, Any]] = NotProvided,
  137. actor_lr: Optional[LearningRateOrSchedule] = NotProvided,
  138. critic_lr: Optional[LearningRateOrSchedule] = NotProvided,
  139. alpha_lr: Optional[LearningRateOrSchedule] = NotProvided,
  140. target_network_update_freq: Optional[int] = NotProvided,
  141. _deterministic_loss: Optional[bool] = NotProvided,
  142. _use_beta_distribution: Optional[bool] = NotProvided,
  143. num_steps_sampled_before_learning_starts: Optional[int] = NotProvided,
  144. **kwargs,
  145. ) -> Self:
  146. """Sets the training related configuration.
  147. Args:
  148. twin_q: Use two Q-networks (instead of one) for action-value estimation.
  149. Note: Each Q-network will have its own target network.
  150. q_model_config: Model configs for the Q network(s). These will override
  151. MODEL_DEFAULTS. This is treated just as the top-level `model` dict in
  152. setting up the Q-network(s) (2 if twin_q=True).
  153. That means, you can do for different observation spaces:
  154. `obs=Box(1D)` -> `Tuple(Box(1D) + Action)` -> `concat` -> `post_fcnet`
  155. obs=Box(3D) -> Tuple(Box(3D) + Action) -> vision-net -> concat w/ action
  156. -> post_fcnet
  157. obs=Tuple(Box(1D), Box(3D)) -> Tuple(Box(1D), Box(3D), Action)
  158. -> vision-net -> concat w/ Box(1D) and action -> post_fcnet
  159. You can also have SAC use your custom_model as Q-model(s), by simply
  160. specifying the `custom_model` sub-key in below dict (just like you would
  161. do in the top-level `model` dict.
  162. policy_model_config: Model options for the policy function (see
  163. `q_model_config` above for details). The difference to `q_model_config`
  164. above is that no action concat'ing is performed before the post_fcnet
  165. stack.
  166. tau: Update the target by \tau * policy + (1-\tau) * target_policy.
  167. initial_alpha: Initial value to use for the entropy weight alpha.
  168. target_entropy: Target entropy lower bound. If "auto", will be set
  169. to `-|A|` (e.g. -2.0 for Discrete(2), -3.0 for Box(shape=(3,))).
  170. This is the inverse of reward scale, and will be optimized
  171. automatically.
  172. n_step: N-step target updates. If >1, sars' tuples in trajectories will be
  173. postprocessed to become sa[discounted sum of R][s t+n] tuples. An
  174. integer will be interpreted as a fixed n-step value. If a tuple of 2
  175. ints is provided here, the n-step value will be drawn for each sample(!)
  176. in the train batch from a uniform distribution over the closed interval
  177. defined by `[n_step[0], n_step[1]]`.
  178. store_buffer_in_checkpoints: Set this to True, if you want the contents of
  179. your buffer(s) to be stored in any saved checkpoints as well.
  180. Warnings will be created if:
  181. - This is True AND restoring from a checkpoint that contains no buffer
  182. data.
  183. - This is False AND restoring from a checkpoint that does contain
  184. buffer data.
  185. replay_buffer_config: Replay buffer config.
  186. Examples:
  187. {
  188. "_enable_replay_buffer_api": True,
  189. "type": "MultiAgentReplayBuffer",
  190. "capacity": 50000,
  191. "replay_batch_size": 32,
  192. "replay_sequence_length": 1,
  193. }
  194. - OR -
  195. {
  196. "_enable_replay_buffer_api": True,
  197. "type": "MultiAgentPrioritizedReplayBuffer",
  198. "capacity": 50000,
  199. "prioritized_replay_alpha": 0.6,
  200. "prioritized_replay_beta": 0.4,
  201. "prioritized_replay_eps": 1e-6,
  202. "replay_sequence_length": 1,
  203. }
  204. - Where -
  205. prioritized_replay_alpha: Alpha parameter controls the degree of
  206. prioritization in the buffer. In other words, when a buffer sample has
  207. a higher temporal-difference error, with how much more probability
  208. should it drawn to use to update the parametrized Q-network. 0.0
  209. corresponds to uniform probability. Setting much above 1.0 may quickly
  210. result as the sampling distribution could become heavily “pointy” with
  211. low entropy.
  212. prioritized_replay_beta: Beta parameter controls the degree of
  213. importance sampling which suppresses the influence of gradient updates
  214. from samples that have higher probability of being sampled via alpha
  215. parameter and the temporal-difference error.
  216. prioritized_replay_eps: Epsilon parameter sets the baseline probability
  217. for sampling so that when the temporal-difference error of a sample is
  218. zero, there is still a chance of drawing the sample.
  219. training_intensity: The intensity with which to update the model (vs
  220. collecting samples from the env).
  221. If None, uses "natural" values of:
  222. `train_batch_size` / (`rollout_fragment_length` x `num_env_runners` x
  223. `num_envs_per_env_runner`).
  224. If not None, will make sure that the ratio between timesteps inserted
  225. into and sampled from th buffer matches the given values.
  226. Example:
  227. training_intensity=1000.0
  228. train_batch_size=250
  229. rollout_fragment_length=1
  230. num_env_runners=1 (or 0)
  231. num_envs_per_env_runner=1
  232. -> natural value = 250 / 1 = 250.0
  233. -> will make sure that replay+train op will be executed 4x asoften as
  234. rollout+insert op (4 * 250 = 1000).
  235. See: rllib/algorithms/dqn/dqn.py::calculate_rr_weights for further
  236. details.
  237. clip_actions: Whether to clip actions. If actions are already normalized,
  238. this should be set to False.
  239. grad_clip: If not None, clip gradients during optimization at this value.
  240. optimization_config: Config dict for optimization. Set the supported keys
  241. `actor_learning_rate`, `critic_learning_rate`, and
  242. `entropy_learning_rate` in here.
  243. actor_lr: The learning rate (float) or learning rate schedule for the
  244. policy in the format of
  245. [[timestep, lr-value], [timestep, lr-value], ...] In case of a
  246. schedule, intermediary timesteps will be assigned to linearly
  247. interpolated learning rate values. A schedule config's first entry
  248. must start with timestep 0, i.e.: [[0, initial_value], [...]].
  249. Note: It is common practice (two-timescale approach) to use a smaller
  250. learning rate for the policy than for the critic to ensure that the
  251. critic gives adequate values for improving the policy.
  252. Note: If you require a) more than one optimizer (per RLModule),
  253. b) optimizer types that are not Adam, c) a learning rate schedule that
  254. is not a linearly interpolated, piecewise schedule as described above,
  255. or d) specifying c'tor arguments of the optimizer that are not the
  256. learning rate (e.g. Adam's epsilon), then you must override your
  257. Learner's `configure_optimizer_for_module()` method and handle
  258. lr-scheduling yourself.
  259. The default value is 3e-5, one decimal less than the respective
  260. learning rate of the critic (see `critic_lr`).
  261. critic_lr: The learning rate (float) or learning rate schedule for the
  262. critic in the format of
  263. [[timestep, lr-value], [timestep, lr-value], ...] In case of a
  264. schedule, intermediary timesteps will be assigned to linearly
  265. interpolated learning rate values. A schedule config's first entry
  266. must start with timestep 0, i.e.: [[0, initial_value], [...]].
  267. Note: It is common practice (two-timescale approach) to use a smaller
  268. learning rate for the policy than for the critic to ensure that the
  269. critic gives adequate values for improving the policy.
  270. Note: If you require a) more than one optimizer (per RLModule),
  271. b) optimizer types that are not Adam, c) a learning rate schedule that
  272. is not a linearly interpolated, piecewise schedule as described above,
  273. or d) specifying c'tor arguments of the optimizer that are not the
  274. learning rate (e.g. Adam's epsilon), then you must override your
  275. Learner's `configure_optimizer_for_module()` method and handle
  276. lr-scheduling yourself.
  277. The default value is 3e-4, one decimal higher than the respective
  278. learning rate of the actor (policy) (see `actor_lr`).
  279. alpha_lr: The learning rate (float) or learning rate schedule for the
  280. hyperparameter alpha in the format of
  281. [[timestep, lr-value], [timestep, lr-value], ...] In case of a
  282. schedule, intermediary timesteps will be assigned to linearly
  283. interpolated learning rate values. A schedule config's first entry
  284. must start with timestep 0, i.e.: [[0, initial_value], [...]].
  285. Note: If you require a) more than one optimizer (per RLModule),
  286. b) optimizer types that are not Adam, c) a learning rate schedule that
  287. is not a linearly interpolated, piecewise schedule as described above,
  288. or d) specifying c'tor arguments of the optimizer that are not the
  289. learning rate (e.g. Adam's epsilon), then you must override your
  290. Learner's `configure_optimizer_for_module()` method and handle
  291. lr-scheduling yourself.
  292. The default value is 3e-4, identical to the critic learning rate (`lr`).
  293. target_network_update_freq: Update the target network every
  294. `target_network_update_freq` steps.
  295. num_steps_sampled_before_learning_starts: Number of timesteps (int)
  296. that we collect from the runners before we start sampling the
  297. replay buffers for learning. Whether we count this in agent steps
  298. or environment steps depends on the value of
  299. `config.multi_agent(count_steps_by=...)`.
  300. _deterministic_loss: Whether the loss should be calculated deterministically
  301. (w/o the stochastic action sampling step). True only useful for
  302. continuous actions and for debugging.
  303. _use_beta_distribution: Use a Beta-distribution instead of a
  304. `SquashedGaussian` for bounded, continuous action spaces (not
  305. recommended; for debugging only).
  306. Returns:
  307. This updated AlgorithmConfig object.
  308. """
  309. # Pass kwargs onto super's `training()` method.
  310. super().training(**kwargs)
  311. if twin_q is not NotProvided:
  312. self.twin_q = twin_q
  313. if q_model_config is not NotProvided:
  314. self.q_model_config.update(q_model_config)
  315. if policy_model_config is not NotProvided:
  316. self.policy_model_config.update(policy_model_config)
  317. if tau is not NotProvided:
  318. self.tau = tau
  319. if initial_alpha is not NotProvided:
  320. self.initial_alpha = initial_alpha
  321. if target_entropy is not NotProvided:
  322. self.target_entropy = target_entropy
  323. if n_step is not NotProvided:
  324. self.n_step = n_step
  325. if store_buffer_in_checkpoints is not NotProvided:
  326. self.store_buffer_in_checkpoints = store_buffer_in_checkpoints
  327. if replay_buffer_config is not NotProvided:
  328. # Override entire `replay_buffer_config` if `type` key changes.
  329. # Update, if `type` key remains the same or is not specified.
  330. new_replay_buffer_config = deep_update(
  331. {"replay_buffer_config": self.replay_buffer_config},
  332. {"replay_buffer_config": replay_buffer_config},
  333. False,
  334. ["replay_buffer_config"],
  335. ["replay_buffer_config"],
  336. )
  337. self.replay_buffer_config = new_replay_buffer_config["replay_buffer_config"]
  338. if training_intensity is not NotProvided:
  339. self.training_intensity = training_intensity
  340. if clip_actions is not NotProvided:
  341. self.clip_actions = clip_actions
  342. if grad_clip is not NotProvided:
  343. self.grad_clip = grad_clip
  344. if optimization_config is not NotProvided:
  345. self.optimization = optimization_config
  346. if actor_lr is not NotProvided:
  347. self.actor_lr = actor_lr
  348. if critic_lr is not NotProvided:
  349. self.critic_lr = critic_lr
  350. if alpha_lr is not NotProvided:
  351. self.alpha_lr = alpha_lr
  352. if target_network_update_freq is not NotProvided:
  353. self.target_network_update_freq = target_network_update_freq
  354. if _deterministic_loss is not NotProvided:
  355. self._deterministic_loss = _deterministic_loss
  356. if _use_beta_distribution is not NotProvided:
  357. self._use_beta_distribution = _use_beta_distribution
  358. if num_steps_sampled_before_learning_starts is not NotProvided:
  359. self.num_steps_sampled_before_learning_starts = (
  360. num_steps_sampled_before_learning_starts
  361. )
  362. return self
  363. @override(AlgorithmConfig)
  364. def validate(self) -> None:
  365. # Call super's validation method.
  366. super().validate()
  367. # Check rollout_fragment_length to be compatible with n_step.
  368. if isinstance(self.n_step, tuple):
  369. min_rollout_fragment_length = self.n_step[1]
  370. else:
  371. min_rollout_fragment_length = self.n_step
  372. if (
  373. not self.in_evaluation
  374. and self.rollout_fragment_length != "auto"
  375. and self.rollout_fragment_length
  376. < min_rollout_fragment_length # (self.n_step or 1)
  377. ):
  378. raise ValueError(
  379. f"Your `rollout_fragment_length` ({self.rollout_fragment_length}) is "
  380. f"smaller than needed for `n_step` ({self.n_step})! If `n_step` is "
  381. f"an integer try setting `rollout_fragment_length={self.n_step}`. If "
  382. "`n_step` is a tuple, try setting "
  383. f"`rollout_fragment_length={self.n_step[1]}`."
  384. )
  385. if self.use_state_preprocessor != DEPRECATED_VALUE:
  386. deprecation_warning(
  387. old="config['use_state_preprocessor']",
  388. error=False,
  389. )
  390. self.use_state_preprocessor = DEPRECATED_VALUE
  391. if self.grad_clip is not None and self.grad_clip <= 0.0:
  392. raise ValueError("`grad_clip` value must be > 0.0!")
  393. if self.framework in ["tf", "tf2"] and tfp is None:
  394. logger.warning(
  395. "You need `tensorflow_probability` in order to run SAC! "
  396. "Install it via `pip install tensorflow_probability`. Your "
  397. f"tf.__version__={tf.__version__ if tf else None}."
  398. "Trying to import tfp results in the following error:"
  399. )
  400. try_import_tfp(error=True)
  401. # Validate that we use the corresponding `EpisodeReplayBuffer` when using
  402. # episodes.
  403. if (
  404. self.enable_env_runner_and_connector_v2
  405. and self.replay_buffer_config["type"]
  406. not in [
  407. "EpisodeReplayBuffer",
  408. "PrioritizedEpisodeReplayBuffer",
  409. "MultiAgentEpisodeReplayBuffer",
  410. "MultiAgentPrioritizedEpisodeReplayBuffer",
  411. ]
  412. and not (
  413. # TODO (simon): Set up an indicator `is_offline_new_stack` that
  414. # includes all these variable checks.
  415. self.input_
  416. and (
  417. isinstance(self.input_, str)
  418. or (
  419. isinstance(self.input_, list)
  420. and isinstance(self.input_[0], str)
  421. )
  422. )
  423. and self.input_ != "sampler"
  424. and self.enable_rl_module_and_learner
  425. )
  426. ):
  427. raise ValueError(
  428. "When using the new `EnvRunner API` the replay buffer must be of type "
  429. "`EpisodeReplayBuffer`."
  430. )
  431. elif not self.enable_env_runner_and_connector_v2 and (
  432. (
  433. isinstance(self.replay_buffer_config["type"], str)
  434. and "Episode" in self.replay_buffer_config["type"]
  435. )
  436. or (
  437. isinstance(self.replay_buffer_config["type"], type)
  438. and issubclass(self.replay_buffer_config["type"], EpisodeReplayBuffer)
  439. )
  440. ):
  441. raise ValueError(
  442. "When using the old API stack the replay buffer must not be of type "
  443. "`EpisodeReplayBuffer`! We suggest you use the following config to run "
  444. "SAC on the old API stack: `config.training(replay_buffer_config={"
  445. "'type': 'MultiAgentPrioritizedReplayBuffer', "
  446. "'prioritized_replay_alpha': [alpha], "
  447. "'prioritized_replay_beta': [beta], "
  448. "'prioritized_replay_eps': [eps], "
  449. "})`."
  450. )
  451. if self.enable_rl_module_and_learner:
  452. if self.lr is not None:
  453. raise ValueError(
  454. "Basic learning rate parameter `lr` is not `None`. For SAC "
  455. "use the specific learning rate parameters `actor_lr`, `critic_lr` "
  456. "and `alpha_lr`, for the actor, critic, and the hyperparameter "
  457. "`alpha`, respectively and set `config.lr` to None."
  458. )
  459. # Warn about new API stack on by default.
  460. logger.warning(
  461. "You are running SAC on the new API stack! This is the new default "
  462. "behavior for this algorithm. If you don't want to use the new API "
  463. "stack, set `config.api_stack(enable_rl_module_and_learner=False, "
  464. "enable_env_runner_and_connector_v2=False)`. For a detailed "
  465. "migration guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html" # noqa
  466. )
  467. @override(AlgorithmConfig)
  468. def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
  469. if self.rollout_fragment_length == "auto":
  470. return (
  471. self.n_step[1]
  472. if isinstance(self.n_step, (tuple, list))
  473. else self.n_step
  474. )
  475. else:
  476. return self.rollout_fragment_length
  477. @override(AlgorithmConfig)
  478. def get_default_rl_module_spec(self) -> RLModuleSpecType:
  479. if self.framework_str == "torch":
  480. from ray.rllib.algorithms.sac.torch.default_sac_torch_rl_module import (
  481. DefaultSACTorchRLModule,
  482. )
  483. return RLModuleSpec(module_class=DefaultSACTorchRLModule)
  484. else:
  485. raise ValueError(
  486. f"The framework {self.framework_str} is not supported. Use `torch`."
  487. )
  488. @override(AlgorithmConfig)
  489. def get_default_learner_class(self) -> Union[Type["Learner"], str]:
  490. if self.framework_str == "torch":
  491. from ray.rllib.algorithms.sac.torch.sac_torch_learner import SACTorchLearner
  492. return SACTorchLearner
  493. else:
  494. raise ValueError(
  495. f"The framework {self.framework_str} is not supported. Use `torch`."
  496. )
  497. @override(AlgorithmConfig)
  498. def build_learner_connector(
  499. self,
  500. input_observation_space,
  501. input_action_space,
  502. device=None,
  503. ):
  504. pipeline = super().build_learner_connector(
  505. input_observation_space=input_observation_space,
  506. input_action_space=input_action_space,
  507. device=device,
  508. )
  509. # Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right
  510. # after the corresponding "add-OBS-..." default piece).
  511. pipeline.insert_after(
  512. AddObservationsFromEpisodesToBatch,
  513. AddNextObservationsFromEpisodesToTrainBatch(),
  514. )
  515. return pipeline
  516. @property
  517. def _model_config_auto_includes(self):
  518. return super()._model_config_auto_includes | {"twin_q": self.twin_q}
  519. class SAC(DQN):
  520. """Soft Actor Critic (SAC) Algorithm class.
  521. This file defines the distributed Algorithm class for the soft actor critic
  522. algorithm.
  523. See `sac_[tf|torch]_policy.py` for the definition of the policy loss.
  524. Detailed documentation:
  525. https://docs.ray.io/en/master/rllib-algorithms.html#sac
  526. """
  527. def __init__(self, *args, **kwargs):
  528. self._allow_unknown_subkeys += ["policy_model_config", "q_model_config"]
  529. super().__init__(*args, **kwargs)
  530. @classmethod
  531. @override(DQN)
  532. def get_default_config(cls) -> SACConfig:
  533. return SACConfig()
  534. @classmethod
  535. @override(DQN)
  536. def get_default_policy_class(
  537. cls, config: AlgorithmConfig
  538. ) -> Optional[Type[Policy]]:
  539. if config["framework"] == "torch":
  540. from ray.rllib.algorithms.sac.sac_torch_policy import SACTorchPolicy
  541. return SACTorchPolicy
  542. else:
  543. return SACTFPolicy