ppo.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570
  1. """
  2. Proximal Policy Optimization (PPO)
  3. ==================================
  4. This file defines the distributed Algorithm class for proximal policy
  5. optimization.
  6. See `ppo_[tf|torch]_policy.py` for the definition of the policy loss.
  7. Detailed documentation: https://docs.ray.io/en/master/rllib-algorithms.html#ppo
  8. """
  9. import logging
  10. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
  11. from typing_extensions import Self
  12. from ray._common.deprecation import DEPRECATED_VALUE
  13. from ray.rllib.algorithms.algorithm import Algorithm
  14. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
  15. from ray.rllib.core.rl_module.rl_module import RLModuleSpec
  16. from ray.rllib.execution.rollout_ops import (
  17. standardize_fields,
  18. synchronous_parallel_sample,
  19. )
  20. from ray.rllib.execution.train_ops import (
  21. multi_gpu_train_one_step,
  22. train_one_step,
  23. )
  24. from ray.rllib.policy.policy import Policy
  25. from ray.rllib.utils.annotations import OldAPIStack, override
  26. from ray.rllib.utils.metrics import (
  27. ALL_MODULES,
  28. ENV_RUNNER_RESULTS,
  29. ENV_RUNNER_SAMPLING_TIMER,
  30. LEARNER_RESULTS,
  31. LEARNER_UPDATE_TIMER,
  32. NUM_AGENT_STEPS_SAMPLED,
  33. NUM_ENV_STEPS_SAMPLED,
  34. NUM_ENV_STEPS_SAMPLED_LIFETIME,
  35. NUM_MODULE_STEPS_TRAINED_LIFETIME,
  36. SAMPLE_TIMER,
  37. SYNCH_WORKER_WEIGHTS_TIMER,
  38. TIMERS,
  39. )
  40. from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
  41. from ray.rllib.utils.schedules.scheduler import Scheduler
  42. from ray.rllib.utils.typing import ResultDict
  43. from ray.util.debug import log_once
  44. if TYPE_CHECKING:
  45. from ray.rllib.core.learner.learner import Learner
  46. logger = logging.getLogger(__name__)
  47. LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY = "vf_loss_unclipped"
  48. LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY = "vf_explained_var"
  49. LEARNER_RESULTS_KL_KEY = "mean_kl_loss"
  50. LEARNER_RESULTS_CURR_KL_COEFF_KEY = "curr_kl_coeff"
  51. LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY = "curr_entropy_coeff"
  52. class PPOConfig(AlgorithmConfig):
  53. """Defines a configuration class from which a PPO Algorithm can be built.
  54. .. testcode::
  55. from ray.rllib.algorithms.ppo import PPOConfig
  56. config = PPOConfig()
  57. config.environment("CartPole-v1")
  58. config.env_runners(num_env_runners=1)
  59. config.training(
  60. gamma=0.9, lr=0.01, kl_coeff=0.3, train_batch_size_per_learner=256
  61. )
  62. # Build a Algorithm object from the config and run 1 training iteration.
  63. algo = config.build()
  64. algo.train()
  65. .. testcode::
  66. from ray.rllib.algorithms.ppo import PPOConfig
  67. from ray import tune
  68. config = (
  69. PPOConfig()
  70. # Set the config object's env.
  71. .environment(env="CartPole-v1")
  72. # Update the config object's training parameters.
  73. .training(
  74. lr=0.001, clip_param=0.2
  75. )
  76. )
  77. tune.Tuner(
  78. "PPO",
  79. run_config=tune.RunConfig(stop={"training_iteration": 1}),
  80. param_space=config,
  81. ).fit()
  82. .. testoutput::
  83. :hide:
  84. ...
  85. """
  86. def __init__(self, algo_class=None):
  87. """Initializes a PPOConfig instance."""
  88. self.exploration_config = {
  89. # The Exploration class to use. In the simplest case, this is the name
  90. # (str) of any class present in the `rllib.utils.exploration` package.
  91. # You can also provide the python class directly or the full location
  92. # of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy.
  93. # EpsilonGreedy").
  94. "type": "StochasticSampling",
  95. # Add constructor kwargs here (if any).
  96. }
  97. super().__init__(algo_class=algo_class or PPO)
  98. # fmt: off
  99. # __sphinx_doc_begin__
  100. self.lr = 5e-5
  101. self.rollout_fragment_length = "auto"
  102. self.train_batch_size = 4000
  103. # PPO specific settings:
  104. self.use_critic = True
  105. self.use_gae = True
  106. self.num_epochs = 30
  107. self.minibatch_size = 128
  108. self.shuffle_batch_per_epoch = True
  109. self.lambda_ = 1.0
  110. self.use_kl_loss = True
  111. self.kl_coeff = 0.2
  112. self.kl_target = 0.01
  113. self.vf_loss_coeff = 1.0
  114. self.entropy_coeff = 0.0
  115. self.clip_param = 0.3
  116. self.vf_clip_param = 10.0
  117. self.grad_clip = None
  118. # Override some of AlgorithmConfig's default values with PPO-specific values.
  119. self.num_env_runners = 2
  120. # __sphinx_doc_end__
  121. # fmt: on
  122. self.model["vf_share_layers"] = False # @OldAPIStack
  123. self.entropy_coeff_schedule = None # @OldAPIStack
  124. self.lr_schedule = None # @OldAPIStack
  125. # Deprecated keys.
  126. self.sgd_minibatch_size = DEPRECATED_VALUE
  127. self.vf_share_layers = DEPRECATED_VALUE
  128. @override(AlgorithmConfig)
  129. def get_default_rl_module_spec(self) -> RLModuleSpec:
  130. if self.framework_str == "torch":
  131. from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
  132. DefaultPPOTorchRLModule,
  133. )
  134. return RLModuleSpec(module_class=DefaultPPOTorchRLModule)
  135. else:
  136. raise ValueError(
  137. f"The framework {self.framework_str} is not supported. "
  138. "Use either 'torch' or 'tf2'."
  139. )
  140. @override(AlgorithmConfig)
  141. def get_default_learner_class(self) -> Union[Type["Learner"], str]:
  142. if self.framework_str == "torch":
  143. from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import (
  144. PPOTorchLearner,
  145. )
  146. return PPOTorchLearner
  147. elif self.framework_str in ["tf2", "tf"]:
  148. raise ValueError(
  149. "TensorFlow is no longer supported on the new API stack! "
  150. "Use `framework='torch'`."
  151. )
  152. else:
  153. raise ValueError(
  154. f"The framework {self.framework_str} is not supported. "
  155. "Use `framework='torch'`."
  156. )
  157. @override(AlgorithmConfig)
  158. def training(
  159. self,
  160. *,
  161. use_critic: Optional[bool] = NotProvided,
  162. use_gae: Optional[bool] = NotProvided,
  163. lambda_: Optional[float] = NotProvided,
  164. use_kl_loss: Optional[bool] = NotProvided,
  165. kl_coeff: Optional[float] = NotProvided,
  166. kl_target: Optional[float] = NotProvided,
  167. vf_loss_coeff: Optional[float] = NotProvided,
  168. entropy_coeff: Optional[float] = NotProvided,
  169. entropy_coeff_schedule: Optional[List[List[Union[int, float]]]] = NotProvided,
  170. clip_param: Optional[float] = NotProvided,
  171. vf_clip_param: Optional[float] = NotProvided,
  172. grad_clip: Optional[float] = NotProvided,
  173. # @OldAPIStack
  174. lr_schedule: Optional[List[List[Union[int, float]]]] = NotProvided,
  175. # Deprecated.
  176. vf_share_layers=DEPRECATED_VALUE,
  177. **kwargs,
  178. ) -> Self:
  179. """Sets the training related configuration.
  180. Args:
  181. use_critic: Should use a critic as a baseline (otherwise don't use value
  182. baseline; required for using GAE).
  183. use_gae: If true, use the Generalized Advantage Estimator (GAE)
  184. with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
  185. lambda_: The lambda parameter for General Advantage Estimation (GAE).
  186. Defines the exponential weight used between actually measured rewards
  187. vs value function estimates over multiple time steps. Specifically,
  188. `lambda_` balances short-term, low-variance estimates against long-term,
  189. high-variance returns. A `lambda_` of 0.0 makes the GAE rely only on
  190. immediate rewards (and vf predictions from there on, reducing variance,
  191. but increasing bias), while a `lambda_` of 1.0 only incorporates vf
  192. predictions at the truncation points of the given episodes or episode
  193. chunks (reducing bias but increasing variance).
  194. use_kl_loss: Whether to use the KL-term in the loss function.
  195. kl_coeff: Initial coefficient for KL divergence.
  196. kl_target: Target value for KL divergence.
  197. vf_loss_coeff: Coefficient of the value function loss. IMPORTANT: you must
  198. tune this if you set vf_share_layers=True inside your model's config.
  199. entropy_coeff: The entropy coefficient (float) or entropy coefficient
  200. schedule in the format of
  201. [[timestep, coeff-value], [timestep, coeff-value], ...]
  202. In case of a schedule, intermediary timesteps will be assigned to
  203. linearly interpolated coefficient values. A schedule config's first
  204. entry must start with timestep 0, i.e.: [[0, initial_value], [...]].
  205. clip_param: The PPO clip parameter.
  206. vf_clip_param: Clip param for the value function. Note that this is
  207. sensitive to the scale of the rewards. If your expected V is large,
  208. increase this.
  209. grad_clip: If specified, clip the global norm of gradients by this amount.
  210. Returns:
  211. This updated AlgorithmConfig object.
  212. """
  213. # Pass kwargs onto super's `training()` method.
  214. super().training(**kwargs)
  215. if use_critic is not NotProvided:
  216. self.use_critic = use_critic
  217. # TODO (Kourosh) This is experimental.
  218. # Don't forget to remove .use_critic from algorithm config.
  219. if use_gae is not NotProvided:
  220. self.use_gae = use_gae
  221. if lambda_ is not NotProvided:
  222. self.lambda_ = lambda_
  223. if use_kl_loss is not NotProvided:
  224. self.use_kl_loss = use_kl_loss
  225. if kl_coeff is not NotProvided:
  226. self.kl_coeff = kl_coeff
  227. if kl_target is not NotProvided:
  228. self.kl_target = kl_target
  229. if vf_loss_coeff is not NotProvided:
  230. self.vf_loss_coeff = vf_loss_coeff
  231. if entropy_coeff is not NotProvided:
  232. self.entropy_coeff = entropy_coeff
  233. if clip_param is not NotProvided:
  234. self.clip_param = clip_param
  235. if vf_clip_param is not NotProvided:
  236. self.vf_clip_param = vf_clip_param
  237. if grad_clip is not NotProvided:
  238. self.grad_clip = grad_clip
  239. # TODO (sven): Remove these once new API stack is only option for PPO.
  240. if lr_schedule is not NotProvided:
  241. self.lr_schedule = lr_schedule
  242. if entropy_coeff_schedule is not NotProvided:
  243. self.entropy_coeff_schedule = entropy_coeff_schedule
  244. return self
  245. @override(AlgorithmConfig)
  246. def validate(self) -> None:
  247. # Call super's validation method.
  248. super().validate()
  249. # Synchronous sampling, on-policy/PPO algos -> Check mismatches between
  250. # `rollout_fragment_length` and `train_batch_size_per_learner` to avoid user
  251. # confusion.
  252. # TODO (sven): Make rollout_fragment_length a property and create a private
  253. # attribute to store (possibly) user provided value (or "auto") in. Deprecate
  254. # `self.get_rollout_fragment_length()`.
  255. self.validate_train_batch_size_vs_rollout_fragment_length()
  256. # SGD minibatch size must be smaller than train_batch_size (b/c
  257. # we subsample a batch of `minibatch_size` from the train-batch for
  258. # each `num_epochs`).
  259. if (
  260. not self.enable_rl_module_and_learner
  261. and self.minibatch_size > self.train_batch_size
  262. ):
  263. self._value_error(
  264. f"`minibatch_size` ({self.minibatch_size}) must be <= "
  265. f"`train_batch_size` ({self.train_batch_size}). In PPO, the train batch"
  266. f" will be split into {self.minibatch_size} chunks, each of which "
  267. f"is iterated over (used for updating the policy) {self.num_epochs} "
  268. "times."
  269. )
  270. elif self.enable_rl_module_and_learner:
  271. mbs = self.minibatch_size
  272. tbs = self.train_batch_size_per_learner or self.train_batch_size
  273. if isinstance(mbs, int) and isinstance(tbs, int) and mbs > tbs:
  274. self._value_error(
  275. f"`minibatch_size` ({mbs}) must be <= "
  276. f"`train_batch_size_per_learner` ({tbs}). In PPO, the train batch"
  277. f" will be split into {mbs} chunks, each of which is iterated over "
  278. f"(used for updating the policy) {self.num_epochs} times."
  279. )
  280. # Episodes may only be truncated (and passed into PPO's
  281. # `postprocessing_fn`), iff generalized advantage estimation is used
  282. # (value function estimate at end of truncated episode to estimate
  283. # remaining value).
  284. if (
  285. not self.in_evaluation
  286. and self.batch_mode == "truncate_episodes"
  287. and not self.use_gae
  288. ):
  289. self._value_error(
  290. "Episode truncation is not supported without a value "
  291. "function (to estimate the return at the end of the truncated"
  292. " trajectory). Consider setting "
  293. "batch_mode=complete_episodes."
  294. )
  295. # New API stack checks.
  296. if self.enable_rl_module_and_learner:
  297. # `lr_schedule` checking.
  298. if self.lr_schedule is not None:
  299. self._value_error(
  300. "`lr_schedule` is deprecated and must be None! Use the "
  301. "`lr` setting to setup a schedule."
  302. )
  303. if self.entropy_coeff_schedule is not None:
  304. self._value_error(
  305. "`entropy_coeff_schedule` is deprecated and must be None! Use the "
  306. "`entropy_coeff` setting to setup a schedule."
  307. )
  308. Scheduler.validate(
  309. fixed_value_or_schedule=self.entropy_coeff,
  310. setting_name="entropy_coeff",
  311. description="entropy coefficient",
  312. )
  313. if isinstance(self.entropy_coeff, float) and self.entropy_coeff < 0.0:
  314. self._value_error("`entropy_coeff` must be >= 0.0")
  315. @property
  316. @override(AlgorithmConfig)
  317. def _model_config_auto_includes(self) -> Dict[str, Any]:
  318. return super()._model_config_auto_includes | {"vf_share_layers": False}
  319. class PPO(Algorithm):
  320. @classmethod
  321. @override(Algorithm)
  322. def get_default_config(cls) -> PPOConfig:
  323. return PPOConfig()
  324. @classmethod
  325. @override(Algorithm)
  326. def get_default_policy_class(
  327. cls, config: AlgorithmConfig
  328. ) -> Optional[Type[Policy]]:
  329. if config["framework"] == "torch":
  330. from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy
  331. return PPOTorchPolicy
  332. elif config["framework"] == "tf":
  333. from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy
  334. return PPOTF1Policy
  335. else:
  336. from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF2Policy
  337. return PPOTF2Policy
  338. @override(Algorithm)
  339. def training_step(self) -> None:
  340. # Old API stack (Policy, RolloutWorker, Connector).
  341. if not self.config.enable_env_runner_and_connector_v2:
  342. return self._training_step_old_api_stack()
  343. # Collect batches from sample workers until we have a full batch.
  344. with self.metrics.log_time((TIMERS, ENV_RUNNER_SAMPLING_TIMER)):
  345. # Sample in parallel from the workers.
  346. if self.config.count_steps_by == "agent_steps":
  347. episodes, env_runner_results = synchronous_parallel_sample(
  348. worker_set=self.env_runner_group,
  349. max_agent_steps=self.config.total_train_batch_size,
  350. sample_timeout_s=self.config.sample_timeout_s,
  351. _uses_new_env_runners=(
  352. self.config.enable_env_runner_and_connector_v2
  353. ),
  354. _return_metrics=True,
  355. )
  356. else:
  357. episodes, env_runner_results = synchronous_parallel_sample(
  358. worker_set=self.env_runner_group,
  359. max_env_steps=self.config.total_train_batch_size,
  360. sample_timeout_s=self.config.sample_timeout_s,
  361. _uses_new_env_runners=(
  362. self.config.enable_env_runner_and_connector_v2
  363. ),
  364. _return_metrics=True,
  365. )
  366. # Return early if all our workers failed.
  367. if not episodes:
  368. return
  369. # Reduce EnvRunner metrics over the n EnvRunners.
  370. self.metrics.aggregate(env_runner_results, key=ENV_RUNNER_RESULTS)
  371. # Perform a learner update step on the collected episodes.
  372. with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
  373. learner_results = self.learner_group.update(
  374. episodes=episodes,
  375. timesteps={
  376. NUM_ENV_STEPS_SAMPLED_LIFETIME: (
  377. self.metrics.peek(
  378. (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME)
  379. )
  380. ),
  381. NUM_MODULE_STEPS_TRAINED_LIFETIME: (
  382. self.metrics.peek(
  383. (
  384. LEARNER_RESULTS,
  385. ALL_MODULES,
  386. NUM_MODULE_STEPS_TRAINED_LIFETIME,
  387. ),
  388. default=0,
  389. )
  390. ),
  391. },
  392. num_epochs=self.config.num_epochs,
  393. minibatch_size=self.config.minibatch_size,
  394. shuffle_batch_per_epoch=self.config.shuffle_batch_per_epoch,
  395. )
  396. self.metrics.aggregate(learner_results, key=LEARNER_RESULTS)
  397. # Update weights - after learning on the local worker - on all remote
  398. # workers.
  399. with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
  400. # The train results's loss keys are ModuleIDs to their loss values.
  401. # But we also return a total_loss key at the same level as the ModuleID
  402. # keys. So we need to subtract that to get the correct set of ModuleIDs to
  403. # update.
  404. # TODO (sven): We should not be using `learner_results` as a messenger
  405. # to infer which modules to update. `policies_to_train` might also NOT work
  406. # as it might be a very large set (100s of Modules) vs a smaller Modules
  407. # set that's present in the current train batch.
  408. modules_to_update = set(learner_results[0].keys()) - {ALL_MODULES}
  409. self.env_runner_group.sync_weights(
  410. # Sync weights from learner_group to all EnvRunners.
  411. from_worker_or_learner_group=self.learner_group,
  412. policies=modules_to_update,
  413. inference_only=True,
  414. )
  415. @OldAPIStack
  416. def _training_step_old_api_stack(self) -> ResultDict:
  417. # Collect batches from sample workers until we have a full batch.
  418. with self._timers[SAMPLE_TIMER]:
  419. if self.config.count_steps_by == "agent_steps":
  420. train_batch = synchronous_parallel_sample(
  421. worker_set=self.env_runner_group,
  422. max_agent_steps=self.config.total_train_batch_size,
  423. sample_timeout_s=self.config.sample_timeout_s,
  424. )
  425. else:
  426. train_batch = synchronous_parallel_sample(
  427. worker_set=self.env_runner_group,
  428. max_env_steps=self.config.total_train_batch_size,
  429. sample_timeout_s=self.config.sample_timeout_s,
  430. )
  431. # Return early if all our workers failed.
  432. if not train_batch:
  433. return {}
  434. train_batch = train_batch.as_multi_agent()
  435. self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
  436. self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
  437. # Standardize advantages.
  438. train_batch = standardize_fields(train_batch, ["advantages"])
  439. if self.config.simple_optimizer:
  440. train_results = train_one_step(self, train_batch)
  441. else:
  442. train_results = multi_gpu_train_one_step(self, train_batch)
  443. policies_to_update = list(train_results.keys())
  444. global_vars = {
  445. "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
  446. # TODO (sven): num_grad_updates per each policy should be
  447. # accessible via `train_results` (and get rid of global_vars).
  448. "num_grad_updates_per_policy": {
  449. pid: self.env_runner.policy_map[pid].num_grad_updates
  450. for pid in policies_to_update
  451. },
  452. }
  453. # Update weights - after learning on the local worker - on all remote
  454. # workers.
  455. with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
  456. if self.env_runner_group.num_remote_workers() > 0:
  457. from_worker_or_learner_group = None
  458. self.env_runner_group.sync_weights(
  459. from_worker_or_learner_group=from_worker_or_learner_group,
  460. policies=policies_to_update,
  461. global_vars=global_vars,
  462. )
  463. # For each policy: Update KL scale and warn about possible issues
  464. for policy_id, policy_info in train_results.items():
  465. # Update KL loss with dynamic scaling
  466. # for each (possibly multiagent) policy we are training
  467. kl_divergence = policy_info[LEARNER_STATS_KEY].get("kl")
  468. self.get_policy(policy_id).update_kl(kl_divergence)
  469. # Warn about excessively high value function loss
  470. scaled_vf_loss = (
  471. self.config.vf_loss_coeff * policy_info[LEARNER_STATS_KEY]["vf_loss"]
  472. )
  473. policy_loss = policy_info[LEARNER_STATS_KEY]["policy_loss"]
  474. if (
  475. log_once("ppo_warned_lr_ratio")
  476. and self.config.get("model", {}).get("vf_share_layers")
  477. and scaled_vf_loss > 100
  478. ):
  479. logger.warning(
  480. "The magnitude of your value function loss for policy: {} is "
  481. "extremely large ({}) compared to the policy loss ({}). This "
  482. "can prevent the policy from learning. Consider scaling down "
  483. "the VF loss by reducing vf_loss_coeff, or disabling "
  484. "vf_share_layers.".format(policy_id, scaled_vf_loss, policy_loss)
  485. )
  486. # Warn about bad clipping configs.
  487. train_batch.policy_batches[policy_id].set_get_interceptor(None)
  488. mean_reward = train_batch.policy_batches[policy_id]["rewards"].mean()
  489. if (
  490. log_once("ppo_warned_vf_clip")
  491. and mean_reward > self.config.vf_clip_param
  492. ):
  493. self.warned_vf_clip = True
  494. logger.warning(
  495. f"The mean reward returned from the environment is {mean_reward}"
  496. f" but the vf_clip_param is set to {self.config['vf_clip_param']}."
  497. f" Consider increasing it for policy: {policy_id} to improve"
  498. " value function convergence."
  499. )
  500. # Update global vars on local worker as well.
  501. # TODO (simon): At least in RolloutWorker obsolete I guess as called in
  502. # `sync_weights()` called above if remote workers. Can we call this
  503. # where `set_weights()` is called on the local_worker?
  504. self.env_runner.set_global_vars(global_vars)
  505. return train_results