tqc.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. """
  2. TQC (Truncated Quantile Critics) Algorithm.
  3. Paper: https://arxiv.org/abs/2005.04269
  4. "Controlling Overestimation Bias with Truncated Mixture of Continuous
  5. Distributional Quantile Critics"
  6. TQC extends SAC by using distributional RL with quantile regression to
  7. control overestimation bias in the Q-function.
  8. """
  9. import logging
  10. from typing import Optional, Type, Union
  11. from ray.rllib.algorithms.algorithm import Algorithm
  12. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
  13. from ray.rllib.algorithms.sac.sac import SAC, SACConfig
  14. from ray.rllib.core.learner import Learner
  15. from ray.rllib.core.rl_module.rl_module import RLModuleSpec
  16. from ray.rllib.utils.annotations import override
  17. from ray.rllib.utils.typing import RLModuleSpecType
  18. logger = logging.getLogger(__name__)
  19. class TQCConfig(SACConfig):
  20. """Configuration for the TQC algorithm.
  21. TQC extends SAC with distributional critics using quantile regression.
  22. Example:
  23. >>> from ray.rllib.algorithms.tqc import TQCConfig
  24. >>> config = (
  25. ... TQCConfig()
  26. ... .environment("Pendulum-v1")
  27. ... .training(
  28. ... n_quantiles=25,
  29. ... n_critics=2,
  30. ... top_quantiles_to_drop_per_net=2,
  31. ... )
  32. ... )
  33. >>> algo = config.build()
  34. """
  35. def __init__(self, algo_class=None):
  36. """Initializes a TQCConfig instance."""
  37. super().__init__(algo_class=algo_class or TQC)
  38. # TQC-specific parameters
  39. self.n_quantiles = 25
  40. self.n_critics = 2
  41. self.top_quantiles_to_drop_per_net = 2
  42. @override(SACConfig)
  43. def training(
  44. self,
  45. *,
  46. n_quantiles: Optional[int] = NotProvided,
  47. n_critics: Optional[int] = NotProvided,
  48. top_quantiles_to_drop_per_net: Optional[int] = NotProvided,
  49. **kwargs,
  50. ):
  51. """Sets the training-related configuration.
  52. Args:
  53. n_quantiles: Number of quantiles for each critic network.
  54. Default is 25.
  55. n_critics: Number of critic networks. Default is 2.
  56. top_quantiles_to_drop_per_net: Number of quantiles to drop per
  57. network when computing the target Q-value. This controls
  58. the overestimation bias. Default is 2.
  59. **kwargs: Additional arguments passed to SACConfig.training().
  60. Returns:
  61. This updated TQCConfig object.
  62. """
  63. super().training(**kwargs)
  64. if n_quantiles is not NotProvided:
  65. self.n_quantiles = n_quantiles
  66. if n_critics is not NotProvided:
  67. self.n_critics = n_critics
  68. if top_quantiles_to_drop_per_net is not NotProvided:
  69. self.top_quantiles_to_drop_per_net = top_quantiles_to_drop_per_net
  70. return self
  71. @override(AlgorithmConfig)
  72. def validate(self) -> None:
  73. """Validates the TQC configuration."""
  74. super().validate()
  75. # Validate TQC-specific parameters
  76. if self.n_quantiles < 1:
  77. raise ValueError(f"`n_quantiles` must be >= 1, got {self.n_quantiles}")
  78. if self.n_critics < 1:
  79. raise ValueError(f"`n_critics` must be >= 1, got {self.n_critics}")
  80. # Ensure top_quantiles_to_drop_per_net is non-negative
  81. if self.top_quantiles_to_drop_per_net < 0:
  82. raise ValueError(
  83. f"`top_quantiles_to_drop_per_net` must be >= 0, got "
  84. f"{self.top_quantiles_to_drop_per_net}"
  85. )
  86. # Ensure we don't drop more quantiles than we have
  87. total_quantiles = self.n_quantiles * self.n_critics
  88. quantiles_to_drop = self.top_quantiles_to_drop_per_net * self.n_critics
  89. if quantiles_to_drop >= total_quantiles:
  90. raise ValueError(
  91. f"Cannot drop {quantiles_to_drop} quantiles when only "
  92. f"{total_quantiles} total quantiles are available. "
  93. f"Reduce `top_quantiles_to_drop_per_net` or increase "
  94. f"`n_quantiles` or `n_critics`."
  95. )
  96. @override(AlgorithmConfig)
  97. def get_default_rl_module_spec(self) -> RLModuleSpecType:
  98. if self.framework_str == "torch":
  99. from ray.rllib.algorithms.tqc.torch.default_tqc_torch_rl_module import (
  100. DefaultTQCTorchRLModule,
  101. )
  102. return RLModuleSpec(module_class=DefaultTQCTorchRLModule)
  103. else:
  104. raise ValueError(
  105. f"The framework {self.framework_str} is not supported. Use `torch`."
  106. )
  107. @override(AlgorithmConfig)
  108. def get_default_learner_class(self) -> Union[Type["Learner"], str]:
  109. if self.framework_str == "torch":
  110. from ray.rllib.algorithms.tqc.torch.tqc_torch_learner import (
  111. TQCTorchLearner,
  112. )
  113. return TQCTorchLearner
  114. else:
  115. raise ValueError(
  116. f"The framework {self.framework_str} is not supported. Use `torch`."
  117. )
  118. @property
  119. @override(AlgorithmConfig)
  120. def _model_config_auto_includes(self):
  121. return super()._model_config_auto_includes | {
  122. "n_quantiles": self.n_quantiles,
  123. "n_critics": self.n_critics,
  124. "top_quantiles_to_drop_per_net": self.top_quantiles_to_drop_per_net,
  125. }
  126. class TQC(SAC):
  127. """TQC (Truncated Quantile Critics) Algorithm.
  128. TQC extends SAC by using distributional critics with quantile regression
  129. and truncating the top quantiles to control overestimation bias.
  130. Key differences from SAC:
  131. - Uses multiple critic networks, each outputting multiple quantiles
  132. - Computes target Q-values by sorting and truncating top quantiles
  133. - Uses quantile Huber loss for critic training
  134. See the paper for more details:
  135. https://arxiv.org/abs/2005.04269
  136. """
  137. @classmethod
  138. @override(Algorithm)
  139. def get_default_config(cls) -> TQCConfig:
  140. return TQCConfig()