| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171 |
- """
- TQC (Truncated Quantile Critics) Algorithm.
- Paper: https://arxiv.org/abs/2005.04269
- "Controlling Overestimation Bias with Truncated Mixture of Continuous
- Distributional Quantile Critics"
- TQC extends SAC by using distributional RL with quantile regression to
- control overestimation bias in the Q-function.
- """
- import logging
- from typing import Optional, Type, Union
- from ray.rllib.algorithms.algorithm import Algorithm
- from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
- from ray.rllib.algorithms.sac.sac import SAC, SACConfig
- from ray.rllib.core.learner import Learner
- from ray.rllib.core.rl_module.rl_module import RLModuleSpec
- from ray.rllib.utils.annotations import override
- from ray.rllib.utils.typing import RLModuleSpecType
- logger = logging.getLogger(__name__)
class TQCConfig(SACConfig):
    """Configuration object for the TQC (Truncated Quantile Critics) algorithm.

    Builds on `SACConfig` and adds three TQC-specific settings:
    `n_quantiles`, `n_critics` and `top_quantiles_to_drop_per_net`.

    Example:
        >>> from ray.rllib.algorithms.tqc import TQCConfig
        >>> config = (
        ...     TQCConfig()
        ...     .environment("Pendulum-v1")
        ...     .training(
        ...         n_quantiles=25,
        ...         n_critics=2,
        ...         top_quantiles_to_drop_per_net=2,
        ...     )
        ... )
        >>> algo = config.build()
    """

    def __init__(self, algo_class=None):
        """Creates a TQCConfig; `TQC` is used when `algo_class` is falsy."""
        resolved_algo_class = algo_class or TQC
        super().__init__(algo_class=resolved_algo_class)

        # Defaults for the settings TQC adds on top of the SAC config.
        self.n_quantiles = 25
        self.n_critics = 2
        self.top_quantiles_to_drop_per_net = 2

    @override(SACConfig)
    def training(
        self,
        *,
        n_quantiles: Optional[int] = NotProvided,
        n_critics: Optional[int] = NotProvided,
        top_quantiles_to_drop_per_net: Optional[int] = NotProvided,
        **kwargs,
    ):
        """Updates the training-related settings of this config.

        Args:
            n_quantiles: Number of quantiles for each critic network.
                Default is 25.
            n_critics: Number of critic networks. Default is 2.
            top_quantiles_to_drop_per_net: Number of quantiles to drop per
                network when computing the target Q-value. This controls
                the overestimation bias. Default is 2.
            **kwargs: Forwarded unchanged to `SACConfig.training()`.

        Returns:
            This updated TQCConfig object.
        """
        # Let the parent class consume everything that is not TQC-specific.
        super().training(**kwargs)

        # Only overwrite a setting when the caller explicitly passed a value.
        tqc_overrides = (
            ("n_quantiles", n_quantiles),
            ("n_critics", n_critics),
            ("top_quantiles_to_drop_per_net", top_quantiles_to_drop_per_net),
        )
        for attr_name, new_value in tqc_overrides:
            if new_value is not NotProvided:
                setattr(self, attr_name, new_value)

        return self

    @override(AlgorithmConfig)
    def validate(self) -> None:
        """Checks the TQC-specific settings after the parent validation.

        Raises:
            ValueError: If `n_quantiles` or `n_critics` is below 1, if
                `top_quantiles_to_drop_per_net` is negative, or if the
                number of quantiles dropped across all critics is not
                strictly smaller than the total number of quantiles.
        """
        super().validate()

        # Lower bounds on the sizes of the critic ensemble.
        if self.n_quantiles < 1:
            message = f"`n_quantiles` must be >= 1, got {self.n_quantiles}"
            raise ValueError(message)
        if self.n_critics < 1:
            message = f"`n_critics` must be >= 1, got {self.n_critics}"
            raise ValueError(message)

        # A negative drop count is meaningless.
        if self.top_quantiles_to_drop_per_net < 0:
            message = (
                f"`top_quantiles_to_drop_per_net` must be >= 0, got "
                f"{self.top_quantiles_to_drop_per_net}"
            )
            raise ValueError(message)

        # At least one quantile has to survive the truncation.
        total_quantiles = self.n_quantiles * self.n_critics
        quantiles_to_drop = self.top_quantiles_to_drop_per_net * self.n_critics
        if quantiles_to_drop >= total_quantiles:
            message = (
                f"Cannot drop {quantiles_to_drop} quantiles when only "
                f"{total_quantiles} total quantiles are available. "
                f"Reduce `top_quantiles_to_drop_per_net` or increase "
                f"`n_quantiles` or `n_critics`."
            )
            raise ValueError(message)

    @override(AlgorithmConfig)
    def get_default_rl_module_spec(self) -> RLModuleSpecType:
        """Returns the default RLModule spec for the configured framework.

        Raises:
            ValueError: If the framework is anything other than `torch`.
        """
        if self.framework_str != "torch":
            raise ValueError(
                f"The framework {self.framework_str} is not supported. Use `torch`."
            )

        # Imported here rather than at module level, as in the original.
        from ray.rllib.algorithms.tqc.torch.default_tqc_torch_rl_module import (
            DefaultTQCTorchRLModule,
        )

        return RLModuleSpec(module_class=DefaultTQCTorchRLModule)

    @override(AlgorithmConfig)
    def get_default_learner_class(self) -> Union[Type["Learner"], str]:
        """Returns the default Learner class for the configured framework.

        Raises:
            ValueError: If the framework is anything other than `torch`.
        """
        if self.framework_str != "torch":
            raise ValueError(
                f"The framework {self.framework_str} is not supported. Use `torch`."
            )

        # Imported here rather than at module level, as in the original.
        from ray.rllib.algorithms.tqc.torch.tqc_torch_learner import (
            TQCTorchLearner,
        )

        return TQCTorchLearner

    @property
    @override(AlgorithmConfig)
    def _model_config_auto_includes(self):
        """Parent auto-includes merged with the three TQC settings."""
        inherited_includes = super()._model_config_auto_includes
        tqc_includes = {
            "n_quantiles": self.n_quantiles,
            "n_critics": self.n_critics,
            "top_quantiles_to_drop_per_net": self.top_quantiles_to_drop_per_net,
        }
        # On a key clash, the TQC values on the right-hand side win.
        return inherited_includes | tqc_includes
class TQC(SAC):
    """Truncated Quantile Critics (TQC) algorithm.

    A SAC subclass that relies on distributional critics trained with
    quantile regression, and truncates the top quantiles to control
    overestimation bias.

    How it differs from SAC:
    - Several critic networks are used, each producing multiple quantiles.
    - Target Q-values are obtained by sorting the quantiles and dropping
      the top ones.
    - The critics are trained with a quantile Huber loss.

    Paper: https://arxiv.org/abs/2005.04269
    """

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> TQCConfig:
        """Returns a freshly constructed `TQCConfig`."""
        default_config = TQCConfig()
        return default_config
|