marwil.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. from typing import Callable, Optional, Type, Union
  2. from typing_extensions import Self
  3. from ray._common.deprecation import deprecation_warning
  4. from ray.rllib.algorithms.algorithm import Algorithm
  5. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
  6. from ray.rllib.connectors.learner import (
  7. AddNextObservationsFromEpisodesToTrainBatch,
  8. AddObservationsFromEpisodesToBatch,
  9. AddOneTsToEpisodesAndTruncate,
  10. GeneralAdvantageEstimation,
  11. )
  12. from ray.rllib.core.learner.learner import Learner
  13. from ray.rllib.core.learner.training_data import TrainingData
  14. from ray.rllib.core.rl_module.rl_module import RLModuleSpec
  15. from ray.rllib.execution.rollout_ops import (
  16. synchronous_parallel_sample,
  17. )
  18. from ray.rllib.execution.train_ops import (
  19. multi_gpu_train_one_step,
  20. train_one_step,
  21. )
  22. from ray.rllib.policy.policy import Policy
  23. from ray.rllib.utils.annotations import OldAPIStack, override
  24. from ray.rllib.utils.metrics import (
  25. LEARNER_RESULTS,
  26. LEARNER_UPDATE_TIMER,
  27. NUM_AGENT_STEPS_SAMPLED,
  28. NUM_ENV_STEPS_SAMPLED,
  29. NUM_ENV_STEPS_SAMPLED_LIFETIME,
  30. OFFLINE_SAMPLING_TIMER,
  31. SAMPLE_TIMER,
  32. SYNCH_WORKER_WEIGHTS_TIMER,
  33. TIMERS,
  34. )
  35. from ray.rllib.utils.typing import (
  36. EnvType,
  37. ResultDict,
  38. RLModuleSpecType,
  39. )
  40. from ray.tune.logger import Logger
  class MARWILConfig(AlgorithmConfig):
      """Defines a configuration class from which a MARWIL Algorithm can be built.

      Example: build a MARWIL algorithm on offline data and run one training
      iteration.

      .. testcode::

          import gymnasium as gym
          import numpy as np

          from pathlib import Path
          from ray.rllib.algorithms.marwil import MARWILConfig

          # Get the base path (to ray/rllib)
          base_path = Path(__file__).parents[2]
          # Get the path to the data in rllib folder.
          data_path = base_path / "offline/tests/data/cartpole/cartpole-v1_large"

          config = MARWILConfig()
          # Enable the new API stack.
          config.api_stack(
              enable_rl_module_and_learner=True,
              enable_env_runner_and_connector_v2=True,
          )
          # Define the environment for which to learn a policy
          # from offline data.
          config.environment(
              observation_space=gym.spaces.Box(
                  np.array([-4.8, -np.inf, -0.41887903, -np.inf]),
                  np.array([4.8, np.inf, 0.41887903, np.inf]),
                  shape=(4,),
                  dtype=np.float32,
              ),
              action_space=gym.spaces.Discrete(2),
          )
          # Set the training parameters.
          config.training(
              beta=1.0,
              lr=1e-5,
              gamma=0.99,
              # We must define a train batch size for each
              # learner (here 1 local learner).
              train_batch_size_per_learner=2000,
          )
          # Define the data source for offline data.
          config.offline_data(
              input_=[data_path.as_posix()],
              # Run exactly one update per training iteration.
              dataset_num_iters_per_learner=1,
          )

          # Build an `Algorithm` object from the config and run 1 training
          # iteration.
          algo = config.build()
          algo.train()

      Example: run a small hyperparameter search over the learning rate with
      Ray Tune.

      .. testcode::

          import gymnasium as gym
          import numpy as np

          from pathlib import Path
          from ray.rllib.algorithms.marwil import MARWILConfig
          from ray import tune

          # Get the base path (to ray/rllib)
          base_path = Path(__file__).parents[2]
          # Get the path to the data in rllib folder.
          data_path = base_path / "offline/tests/data/cartpole/cartpole-v1_large"

          config = MARWILConfig()
          # Enable the new API stack.
          config.api_stack(
              enable_rl_module_and_learner=True,
              enable_env_runner_and_connector_v2=True,
          )
          # Print out some default values
          print(f"beta: {config.beta}")
          # Update the config object.
          config.training(
              lr=tune.grid_search([1e-3, 1e-4]),
              beta=0.75,
              # We must define a train batch size for each
              # learner (here 1 local learner).
              train_batch_size_per_learner=2000,
          )
          # Set the config's data path.
          config.offline_data(
              input_=[data_path.as_posix()],
              # Set the number of updates to be run per learner
              # per training step.
              dataset_num_iters_per_learner=1,
          )
          # Set the config's environment for evaluation.
          config.environment(
              observation_space=gym.spaces.Box(
                  np.array([-4.8, -np.inf, -0.41887903, -np.inf]),
                  np.array([4.8, np.inf, 0.41887903, np.inf]),
                  shape=(4,),
                  dtype=np.float32,
              ),
              action_space=gym.spaces.Discrete(2),
          )
          # Set up a tuner to run the experiment.
          tuner = tune.Tuner(
              "MARWIL",
              param_space=config,
              run_config=tune.RunConfig(
                  stop={"training_iteration": 1},
              ),
          )
          # Run the experiment.
          tuner.fit()
      """

      def __init__(self, algo_class=None):
          """Initializes a MARWILConfig instance.

          Args:
              algo_class: Optional Algorithm class to associate with this
                  config. Falls back to `MARWIL` if not provided (or falsy).
          """
          # NOTE(review): `exploration_config` is deliberately assigned BEFORE
          # `super().__init__()` is called; presumably the base class only
          # installs its own default if the attribute is still missing. Keep
          # this ordering unless confirmed otherwise.
          self.exploration_config = {
              # The Exploration class to use. In the simplest case, this is the name
              # (str) of any class present in the `rllib.utils.exploration` package.
              # You can also provide the python class directly or the full location
              # of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy.
              # EpsilonGreedy").
              "type": "StochasticSampling",
              # Add constructor kwargs here (if any).
          }

          super().__init__(algo_class=algo_class or MARWIL)

          self._is_online = False

          # fmt: off
          # __sphinx_doc_begin__
          # MARWIL specific settings:
          self.beta = 1.0
          self.bc_logstd_coeff = 0.0
          self.moving_average_sqd_adv_norm_update_rate = 1e-8
          self.moving_average_sqd_adv_norm_start = 100.0
          self.vf_coeff = 1.0
          self.model["vf_share_layers"] = False
          self.grad_clip = None

          # Override some of AlgorithmConfig's default values with MARWIL-specific values.

          # You should override input_ to point to an offline dataset
          # (see algorithm.py and algorithm_config.py).
          # The dataset may have an arbitrary number of timesteps
          # (and even episodes) per line.
          # However, each line must only contain consecutive timesteps in
          # order for MARWIL to be able to calculate accumulated
          # discounted returns. It is ok, though, to have multiple episodes in
          # the same line.
          self.input_ = "sampler"
          self.postprocess_inputs = True
          self.lr = 1e-4
          self.lambda_ = 1.0
          self.train_batch_size = 2000
          self.burnin_len = 0

          # Materialize only the data in raw format, but not the mapped data b/c
          # MARWIL uses a connector to calculate values and therefore the module
          # needs to be updated frequently. This updating would not work if we
          # map the data once at the beginning.
          # TODO (simon, sven): The module is only updated when the OfflinePreLearner
          #  gets reinitiated, i.e. when the iterator gets reinitiated. This happens
          #  frequently enough with a small dataset, but with a big one this does not
          #  update often enough. We might need to put model weights every couple of
          #  iterations into the object storage (maybe also connector states).
          self.materialize_data = True
          self.materialize_mapped_data = False
          # __sphinx_doc_end__
          # fmt: on

          # Tracks whether the user explicitly chose off-policy estimation
          # methods via `evaluation()`; `build()` warns if they did not.
          self._set_off_policy_estimation_methods = False

      @override(AlgorithmConfig)
      def training(
          self,
          *,
          beta: Optional[float] = NotProvided,
          bc_logstd_coeff: Optional[float] = NotProvided,
          moving_average_sqd_adv_norm_update_rate: Optional[float] = NotProvided,
          moving_average_sqd_adv_norm_start: Optional[float] = NotProvided,
          vf_coeff: Optional[float] = NotProvided,
          grad_clip: Optional[float] = NotProvided,
          burnin_len: Optional[int] = NotProvided,
          **kwargs,
      ) -> Self:
          """Sets the training related configuration.

          Args:
              beta: Scaling of advantages in exponential terms. When beta is 0.0,
                  MARWIL is reduced to behavior cloning (imitation learning);
                  see bc.py algorithm in this same directory.
              bc_logstd_coeff: A coefficient to encourage higher action distribution
                  entropy for exploration.
              moving_average_sqd_adv_norm_update_rate: The rate for updating the
                  squared moving average advantage norm (c^2). A higher rate leads
                  to faster updates of this moving average.
              moving_average_sqd_adv_norm_start: Starting value for the
                  squared moving average advantage norm (c^2).
              vf_coeff: Balancing value estimation loss and policy optimization loss.
              grad_clip: If specified, clip the global norm of gradients by this amount.
              burnin_len: Number of initial time steps to "burn in" when using
                  RNNs. These time steps will not be included in the training loss.
              **kwargs: Any further settings; forwarded unchanged to
                  `AlgorithmConfig.training()`.

          Returns:
              This updated AlgorithmConfig object.
          """
          # Pass kwargs onto super's `training()` method.
          super().training(**kwargs)

          # Only overwrite those settings the caller explicitly provided.
          if beta is not NotProvided:
              self.beta = beta
          if bc_logstd_coeff is not NotProvided:
              self.bc_logstd_coeff = bc_logstd_coeff
          if moving_average_sqd_adv_norm_update_rate is not NotProvided:
              self.moving_average_sqd_adv_norm_update_rate = (
                  moving_average_sqd_adv_norm_update_rate
              )
          if moving_average_sqd_adv_norm_start is not NotProvided:
              self.moving_average_sqd_adv_norm_start = moving_average_sqd_adv_norm_start
          if vf_coeff is not NotProvided:
              self.vf_coeff = vf_coeff
          if grad_clip is not NotProvided:
              self.grad_clip = grad_clip
          if burnin_len is not NotProvided:
              self.burnin_len = burnin_len
          return self

      @override(AlgorithmConfig)
      def get_default_rl_module_spec(self) -> RLModuleSpecType:
          """Returns the default RLModule spec for the configured framework.

          Returns:
              An `RLModuleSpec` using `DefaultPPOTorchRLModule` as module class.

          Raises:
              ValueError: If `framework_str` is anything other than "torch".
          """
          if self.framework_str == "torch":
              from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
                  DefaultPPOTorchRLModule,
              )

              return RLModuleSpec(module_class=DefaultPPOTorchRLModule)
          else:
              raise ValueError(
                  f"The framework {self.framework_str} is not supported. "
                  "Use 'torch' instead."
              )

      @override(AlgorithmConfig)
      def get_default_learner_class(self) -> Union[Type["Learner"], str]:
          """Returns the default Learner class for the configured framework.

          Returns:
              The `MARWILTorchLearner` class.

          Raises:
              ValueError: If `framework_str` is anything other than "torch".
          """
          if self.framework_str == "torch":
              from ray.rllib.algorithms.marwil.torch.marwil_torch_learner import (
                  MARWILTorchLearner,
              )

              return MARWILTorchLearner
          else:
              raise ValueError(
                  f"The framework {self.framework_str} is not supported. "
                  "Use 'torch' instead."
              )

      @override(AlgorithmConfig)
      def evaluation(
          self,
          **kwargs,
      ) -> Self:
          """Sets the evaluation related configuration.

          Args:
              **kwargs: Forwarded unchanged to `AlgorithmConfig.evaluation()`.
                  If `off_policy_estimation_methods` is among them, this config
                  remembers that the user made an explicit choice (which
                  silences the warning otherwise issued by `build()`).

          Returns:
              This updated AlgorithmConfig object.
          """
          # Pass kwargs onto super's `evaluation()` method.
          super().evaluation(**kwargs)

          if "off_policy_estimation_methods" in kwargs:
              # User specified their OPE methods.
              self._set_off_policy_estimation_methods = True

          return self

      @override(AlgorithmConfig)
      def offline_data(self, **kwargs) -> Self:
          """Sets the offline-data related configuration.

          Args:
              **kwargs: Forwarded unchanged to `AlgorithmConfig.offline_data()`.
                  If `prelearner_class` is among them and is neither `None` nor
                  `NotProvided`, it must be a subclass of `OfflinePreLearner`.

          Returns:
              This updated AlgorithmConfig object.

          Raises:
              ValueError: If `prelearner_class` is given, but is not a subclass
                  of `OfflinePreLearner` (this includes non-class values).
          """
          # Check, if the passed in class incorporates the `OfflinePreLearner`
          # interface. This is done BEFORE forwarding to super, such that an
          # invalid value never ends up in the config.
          prelearner_class = kwargs.get("prelearner_class", NotProvided)
          if prelearner_class is not NotProvided and prelearner_class is not None:
              from ray.rllib.offline.offline_data import OfflinePreLearner

              # `issubclass()` raises a TypeError on non-class arguments, hence
              # the explicit `isinstance(..., type)` guard.
              if not (
                  isinstance(prelearner_class, type)
                  and issubclass(prelearner_class, OfflinePreLearner)
              ):
                  raise ValueError(
                      f"`prelearner_class` {prelearner_class} is not a "
                      "subclass of `OfflinePreLearner`. Any class passed to "
                      "`prelearner_class` needs to implement the interface given by "
                      "`OfflinePreLearner`."
                  )

          super().offline_data(**kwargs)

          return self

      @override(AlgorithmConfig)
      def build(
          self,
          env: Optional[Union[str, EnvType]] = None,
          logger_creator: Optional[Callable[[], Logger]] = None,
      ) -> "Algorithm":
          """Builds an Algorithm from this config.

          Issues a (non-raising) deprecation warning if the user never set
          `off_policy_estimation_methods` through `evaluation()`.

          Args:
              env: Optional environment (name or class) passed on to
                  `AlgorithmConfig.build()`.
              logger_creator: Optional logger factory passed on to
                  `AlgorithmConfig.build()`.

          Returns:
              Whatever `AlgorithmConfig.build()` returns for this config.
          """
          if not self._set_off_policy_estimation_methods:
              # Adjacent string literals are concatenated without any separator,
              # so each piece must carry its own trailing space.
              deprecation_warning(
                  old=(
                      "MARWIL used to have off_policy_estimation_methods "
                      "is and wis by default. This has "
                      "changed to off_policy_estimation_methods: {}. "
                      "If you want to use an off-policy estimator, specify it in "
                      ".evaluation(off_policy_estimation_methods=...)"
                  ),
                  error=False,
              )
          return super().build(env, logger_creator)

      @override(AlgorithmConfig)
      def build_learner_connector(
          self,
          input_observation_space,
          input_action_space,
          device=None,
      ):
          """Builds the learner connector pipeline with MARWIL's extra pieces.

          Starts from the pipeline built by the super class and adds the
          connector pieces described in the comments below.

          Args:
              input_observation_space: Passed on to the super implementation.
              input_action_space: Passed on to the super implementation.
              device: Passed on to the super implementation.

          Returns:
              The extended connector pipeline.
          """
          pipeline = super().build_learner_connector(
              input_observation_space=input_observation_space,
              input_action_space=input_action_space,
              device=device,
          )

          # Before anything, add one ts to each episode (and record this in the loss
          # mask, so that the computations at this extra ts are not used to compute
          # the loss).
          pipeline.prepend(AddOneTsToEpisodesAndTruncate())

          # Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right
          # after the corresponding "add-OBS-..." default piece).
          pipeline.insert_after(
              AddObservationsFromEpisodesToBatch,
              AddNextObservationsFromEpisodesToTrainBatch(),
          )

          # At the end of the pipeline (when the batch is already completed), add the
          # GAE connector, which performs a vf forward pass, then computes the GAE
          # computations, and puts the results of this (advantages, value targets)
          # directly back in the batch. This is then the batch used for
          # `forward_train` and `compute_losses`.
          pipeline.append(
              GeneralAdvantageEstimation(gamma=self.gamma, lambda_=self.lambda_)
          )

          return pipeline

      @override(AlgorithmConfig)
      def validate(self) -> None:
          """Validates the MARWIL-specific settings of this config.

          Runs the super class' validation first and then reports each of the
          following violations through `self._value_error(...)`:

          - `beta` outside of [0.0, 1.0].
          - `postprocess_inputs` is False while `beta` > 0.0.
          - A local learner (`num_learners == 0`) on the RLModule/Learner stack
            without `dataset_num_iters_per_learner` being set.
          - `burnin_len` > 0, but not smaller than `model["max_seq_len"]`.
          """
          # Call super's validation method.
          super().validate()

          if self.beta < 0.0 or self.beta > 1.0:
              self._value_error("`beta` must be within 0.0 and 1.0!")

          if self.postprocess_inputs is False and self.beta > 0.0:
              self._value_error(
                  "`postprocess_inputs` must be True for MARWIL (to "
                  "calculate accum., discounted returns)! Try setting "
                  "`config.offline_data(postprocess_inputs=True)`."
              )

          # Assert that for a local learner the number of iterations is defined
          # (i.e. neither None nor 0). Note, this is needed because we have no
          # iterators, but instead a single batch returned directly from the
          # `OfflineData.sample` method.
          if (
              self.num_learners == 0
              and not self.dataset_num_iters_per_learner
              and self.enable_rl_module_and_learner
          ):
              self._value_error(
                  "When using a local Learner (`config.num_learners=0`), the number of "
                  "iterations per learner (`dataset_num_iters_per_learner`) has to be "
                  "defined! Set this hyperparameter through `config.offline_data("
                  "dataset_num_iters_per_learner=...)`."
              )

          # Assert that burnin_len is smaller than max_seq_len.
          if self.burnin_len > 0:
              max_seq_len = self.model.get("max_seq_len", 0)
              if self.burnin_len >= max_seq_len:
                  self._value_error(
                      "`burnin_len` must be < `model.max_seq_len`! "
                      f"Got burnin_len={self.burnin_len}, "
                      f"model.max_seq_len={max_seq_len}."
                  )

      @property
      def _model_auto_keys(self):
          """The super class' auto model keys, extended by MARWIL's own.

          Adds `beta` (current config value) and `vf_share_layers` (always
          False) on top of whatever the super class provides.
          """
          return super()._model_auto_keys | {"beta": self.beta, "vf_share_layers": False}
  class MARWIL(Algorithm):
      """MARWIL Algorithm.

      On the new API stack, `training_step()` samples from `self.offline_data`
      and updates through `self.learner_group`. On the old API stack, it
      samples through `self.env_runner_group` and trains via the
      `train_one_step` / `multi_gpu_train_one_step` helpers.
      """

      @classmethod
      @override(Algorithm)
      def get_default_config(cls) -> MARWILConfig:
          """Returns a fresh `MARWILConfig` with MARWIL's default settings."""
          return MARWILConfig()

      @classmethod
      @override(Algorithm)
      def get_default_policy_class(
          cls, config: AlgorithmConfig
      ) -> Optional[Type[Policy]]:
          """Picks the (old API stack) Policy class matching `config["framework"]`.

          Args:
              config: The config whose "framework" entry selects the class.

          Returns:
              `MARWILTorchPolicy` for "torch", `MARWILTF1Policy` for "tf", and
              `MARWILTF2Policy` for any other value.
          """
          framework = config["framework"]

          if framework == "torch":
              from ray.rllib.algorithms.marwil.marwil_torch_policy import (
                  MARWILTorchPolicy,
              )

              return MARWILTorchPolicy

          if framework == "tf":
              from ray.rllib.algorithms.marwil.marwil_tf_policy import (
                  MARWILTF1Policy,
              )

              return MARWILTF1Policy

          # Any other framework setting falls through to the TF2 policy.
          from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTF2Policy

          return MARWILTF2Policy

      @override(Algorithm)
      def training_step(self) -> None:
          """Implements training logic for the new stack

          Note, this includes so far training with the `OfflineData`
          class (multi-/single-learner setup) and evaluation on
          `EnvRunner`s. Note further, evaluation on the dataset itself
          using estimators is not implemented, yet.

          If `enable_env_runner_and_connector_v2` is off, this hands over to
          `_training_step_old_api_stack()` and returns that method's result.
          """
          cfg = self.config

          # Old API stack (Policy, RolloutWorker, Connector).
          if not cfg.enable_env_runner_and_connector_v2:
              return self._training_step_old_api_stack()

          # TODO (simon): Take care of sampler metrics: right
          # now all rewards are `nan`, which possibly confuses
          # the user that sth. is not right, although it is as
          # we do not step the env.
          # NOTE(review): The indentation of the original was lost; the exact
          # extent of the two timer scopes below is reconstructed - confirm
          # against the upstream file.
          with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)):
              # Learners get an iterator, if there is at least one remote
              # learner (then always a list of iterators is needed) or if more
              # (or fewer) than exactly one update per iteration is configured.
              use_iterator = (
                  cfg.num_learners > 0 or cfg.dataset_num_iters_per_learner != 1
              )

              # Pull the next batch (or iterator(s)) from the offline data.
              sampled = self.offline_data.sample(
                  num_samples=cfg.train_batch_size_per_learner,
                  num_shards=cfg.num_learners,
                  return_iterator=use_iterator,
              )

              # Count the sampled env steps: One batch per learner, but at
              # least one batch (local learner).
              self.metrics.log_value(
                  key=NUM_ENV_STEPS_SAMPLED_LIFETIME,
                  value=cfg.train_batch_size_per_learner * max(1, cfg.num_learners),
                  reduce="lifetime_sum",
              )

              # Wrap the sample in the container the learner group expects.
              training_data = (
                  TrainingData(data_iterators=sampled)
                  if use_iterator
                  else TrainingData(batch=sampled)
              )

          with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
              # Run the actual policy update on the learner(s).
              results = self.learner_group.update(
                  training_data=training_data,
                  minibatch_size=cfg.train_batch_size_per_learner,
                  num_iters=cfg.dataset_num_iters_per_learner,
                  timesteps={
                      NUM_ENV_STEPS_SAMPLED_LIFETIME: self.metrics.peek(
                          NUM_ENV_STEPS_SAMPLED_LIFETIME
                      )
                  },
                  **self.offline_data.iter_batches_kwargs,
              )

              # Merge the learner results into the algorithm's metrics.
              self.metrics.aggregate(results, key=LEARNER_RESULTS)

      @OldAPIStack
      def _training_step_old_api_stack(self) -> ResultDict:
          """Implements training step for the old stack.

          Note, there is no hybrid stack anymore. If you need to use `RLModule`s,
          use the new api stack.

          Returns:
              The results returned by the train-one-step helper.
          """
          # Gather a train batch from the sample workers.
          with self._timers[SAMPLE_TIMER]:
              train_batch = synchronous_parallel_sample(worker_set=self.env_runner_group)

          first_policy_id = list(self.config.policies)[0]
          train_batch = train_batch.as_multi_agent(module_id=first_policy_id)

          self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
          self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()

          # Train, using the helper selected by `simple_optimizer`.
          train_fn = (
              train_one_step
              if self.config.simple_optimizer
              else multi_gpu_train_one_step
          )
          train_results = train_fn(self, train_batch)

          # TODO: Move training steps counter update outside of `train_one_step()` method.
          # # Update train step counters.
          # self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
          # self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

          global_vars = {"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED]}

          # After learning on the local worker, push the weights of those
          # policies that were actually trained to all remote workers.
          if self.env_runner_group.num_remote_env_runners() > 0:
              with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                  self.env_runner_group.sync_weights(
                      policies=list(train_results.keys()), global_vars=global_vars
                  )

          # The local worker needs the updated global vars as well.
          self.env_runner.set_global_vars(global_vars)

          return train_results