iql.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. from typing import Optional, Type, Union
  2. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
  3. from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig
  4. from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
  5. AddObservationsFromEpisodesToBatch,
  6. )
  7. from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
  8. AddNextObservationsFromEpisodesToTrainBatch,
  9. )
  10. from ray.rllib.core.learner.learner import Learner
  11. from ray.rllib.core.rl_module.rl_module import RLModuleSpec
  12. from ray.rllib.utils.annotations import override
  13. from ray.rllib.utils.typing import LearningRateOrSchedule, RLModuleSpecType
  14. class IQLConfig(MARWILConfig):
  15. """Defines a configuration class from which a new IQL Algorithm can be built
  16. .. testcode::
  17. :skipif: True
  18. from ray.rllib.algorithms.iql import IQLConfig
  19. # Run this from the ray directory root.
  20. config = IQLConfig().training(actor_lr=0.00001, gamma=0.99)
  21. config = config.offline_data(
  22. input_="./rllib/offline/tests/data/pendulum/pendulum-v1_enormous")
  23. # Build an Algorithm object from the config and run 1 training iteration.
  24. algo = config.build()
  25. algo.train()
  26. .. testcode::
  27. :skipif: True
  28. from ray.rllib.algorithms.iql import IQLConfig
  29. from ray import tune
  30. config = IQLConfig()
  31. # Print out some default values.
  32. print(config.beta)
  33. # Update the config object.
  34. config.training(
  35. lr=tune.grid_search([0.001, 0.0001]), beta=0.75
  36. )
  37. # Set the config object's data path.
  38. # Run this from the ray directory root.
  39. config.offline_data(
  40. input_="./rllib/offline/tests/data/pendulum/pendulum-v1_enormous"
  41. )
  42. # Set the config object's env, used for evaluation.
  43. config.environment(env="Pendulum-v1")
  44. # Use to_dict() to get the old-style python config dict
  45. # when running with tune.
  46. tune.Tuner(
  47. "IQL",
  48. param_space=config.to_dict(),
  49. ).fit()
  50. """
  51. def __init__(self, algo_class=None):
  52. super().__init__(algo_class=algo_class or IQL)
  53. # fmt: off
  54. # __sphinx_doc_begin__
  55. # The temperature for the actor loss.
  56. self.beta = 0.1
  57. # The expectile to use in expectile regression.
  58. self.expectile = 0.8
  59. # The learning rates for the actor, critic and value network(s).
  60. self.actor_lr = 3e-4
  61. self.critic_lr = 3e-4
  62. self.value_lr = 3e-4
  63. # Set `lr` parameter to `None` and ensure it is not used.
  64. self.lr = None
  65. # If a twin-Q architecture should be used (advisable).
  66. self.twin_q = True
  67. # How often the target network should be updated.
  68. self.target_network_update_freq = 0
  69. # The weight for Polyak averaging.
  70. self.tau = 1.0
  71. # __sphinx_doc_end__
  72. # fmt: on
  73. @override(MARWILConfig)
  74. def training(
  75. self,
  76. *,
  77. twin_q: Optional[bool] = NotProvided,
  78. expectile: Optional[float] = NotProvided,
  79. actor_lr: Optional[LearningRateOrSchedule] = NotProvided,
  80. critic_lr: Optional[LearningRateOrSchedule] = NotProvided,
  81. value_lr: Optional[LearningRateOrSchedule] = NotProvided,
  82. target_network_update_freq: Optional[int] = NotProvided,
  83. tau: Optional[float] = NotProvided,
  84. **kwargs,
  85. ) -> "IQLConfig":
  86. """Sets the training related configuration.
  87. Args:
  88. beta: The temperature to scaling advantages in exponential terms.
  89. Must be >> 0.0. The higher this parameter the less greedy
  90. (exploitative) the policy becomes. It also means that the policy
  91. is fitting less to the best actions in the dataset.
  92. twin_q: If a twin-Q architecture should be used (advisable).
  93. expectile: The expectile to use in expectile regression for the value
  94. function. For high expectiles the value function tries to match
  95. the upper tail of the Q-value distribution.
  96. actor_lr: The learning rate for the actor network. Actor learning rates
  97. greater than critic learning rates work well in experiments.
  98. critic_lr: The learning rate for the Q-network. Critic learning rates
  99. greater than value function learning rates work well in experiments.
  100. value_lr: The learning rate for the value function network.
  101. target_network_update_freq: The number of timesteps in between the target
  102. Q-network is fixed. Note, too high values here could harm convergence.
  103. The target network is updated via Polyak-averaging.
  104. tau: The update parameter for Polyak-averaging of the target Q-network.
  105. The higher this value the faster the weights move towards the actual
  106. Q-network.
  107. Return:
  108. This updated `AlgorithmConfig` object.
  109. """
  110. super().training(**kwargs)
  111. if twin_q is not NotProvided:
  112. self.twin_q = twin_q
  113. if expectile is not NotProvided:
  114. self.expectile = expectile
  115. if actor_lr is not NotProvided:
  116. self.actor_lr = actor_lr
  117. if critic_lr is not NotProvided:
  118. self.critic_lr = critic_lr
  119. if value_lr is not NotProvided:
  120. self.value_lr = value_lr
  121. if target_network_update_freq is not NotProvided:
  122. self.target_network_update_freq = target_network_update_freq
  123. if tau is not NotProvided:
  124. self.tau = tau
  125. return self
  126. @override(MARWILConfig)
  127. def get_default_learner_class(self) -> Union[Type["Learner"], str]:
  128. if self.framework_str == "torch":
  129. from ray.rllib.algorithms.iql.torch.iql_torch_learner import IQLTorchLearner
  130. return IQLTorchLearner
  131. else:
  132. raise ValueError(
  133. f"The framework {self.framework_str} is not supported. "
  134. "Use `'torch'` instead."
  135. )
  136. @override(MARWILConfig)
  137. def get_default_rl_module_spec(self) -> RLModuleSpecType:
  138. if self.framework_str == "torch":
  139. from ray.rllib.algorithms.iql.torch.default_iql_torch_rl_module import (
  140. DefaultIQLTorchRLModule,
  141. )
  142. return RLModuleSpec(module_class=DefaultIQLTorchRLModule)
  143. else:
  144. raise ValueError(
  145. f"The framework {self.framework_str} is not supported. "
  146. "Use `torch` instead."
  147. )
  148. @override(MARWILConfig)
  149. def build_learner_connector(
  150. self,
  151. input_observation_space,
  152. input_action_space,
  153. device=None,
  154. ):
  155. pipeline = super().build_learner_connector(
  156. input_observation_space=input_observation_space,
  157. input_action_space=input_action_space,
  158. device=device,
  159. )
  160. # Remove unneeded connectors from the MARWIL connector pipeline.
  161. pipeline.remove("AddOneTsToEpisodesAndTruncate")
  162. pipeline.remove("GeneralAdvantageEstimation")
  163. # Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right
  164. # after the corresponding "add-OBS-..." default piece).
  165. pipeline.insert_after(
  166. AddObservationsFromEpisodesToBatch,
  167. AddNextObservationsFromEpisodesToTrainBatch(),
  168. )
  169. return pipeline
  170. @override(MARWILConfig)
  171. def validate(self) -> None:
  172. # Call super's validation method.
  173. super().validate()
  174. # Ensure hyperparameters are meaningful.
  175. if self.beta <= 0.0:
  176. self._value_error(
  177. "For meaningful results, `beta` (temperature) parameter must be >> 0.0!"
  178. )
  179. if not 0.0 < self.expectile < 1.0:
  180. self._value_error(
  181. "For meaningful results, `expectile` parameter must be in (0, 1)."
  182. )
  183. @property
  184. def _model_config_auto_includes(self):
  185. return super()._model_config_auto_includes | {"twin_q": self.twin_q}
  186. class IQL(MARWIL):
  187. """Implicit Q-learning (derived from MARWIL).
  188. Uses MARWIL training step.
  189. """
  190. @classmethod
  191. @override(MARWIL)
  192. def get_default_config(cls) -> AlgorithmConfig:
  193. return IQLConfig()