impala.py 61 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391
  1. import copy
  2. import logging
  3. import queue
  4. from typing import Dict, List, Optional, Set, Tuple, Type, Union
  5. from typing_extensions import Self
  6. import ray
  7. from ray import ObjectRef
  8. from ray._common.deprecation import DEPRECATED_VALUE, deprecation_warning
  9. from ray.rllib import SampleBatch
  10. from ray.rllib.algorithms.algorithm import Algorithm
  11. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
  12. from ray.rllib.connectors.learner import AddOneTsToEpisodesAndTruncate, NumpyToTensor
  13. from ray.rllib.core import (
  14. COMPONENT_ENV_TO_MODULE_CONNECTOR,
  15. COMPONENT_MODULE_TO_ENV_CONNECTOR,
  16. )
  17. from ray.rllib.core.learner.training_data import TrainingData
  18. from ray.rllib.core.rl_module.rl_module import RLModuleSpec
  19. from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer
  20. from ray.rllib.execution.learner_thread import LearnerThread
  21. from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread
  22. from ray.rllib.policy.policy import Policy
  23. from ray.rllib.policy.sample_batch import concat_samples
  24. from ray.rllib.utils.annotations import OldAPIStack, override
  25. from ray.rllib.utils.metrics import (
  26. AGGREGATOR_ACTOR_RESULTS,
  27. ALL_MODULES,
  28. ENV_RUNNER_RESULTS,
  29. LEARNER_GROUP,
  30. LEARNER_RESULTS,
  31. LEARNER_UPDATE_TIMER,
  32. MEAN_NUM_EPISODE_LISTS_RECEIVED,
  33. MEAN_NUM_LEARNER_GROUP_UPDATE_CALLED,
  34. MEAN_NUM_LEARNER_RESULTS_RECEIVED,
  35. NUM_AGENT_STEPS_SAMPLED,
  36. NUM_AGENT_STEPS_TRAINED,
  37. NUM_ENV_STEPS_SAMPLED,
  38. NUM_ENV_STEPS_SAMPLED_LIFETIME,
  39. NUM_ENV_STEPS_TRAINED,
  40. NUM_ENV_STEPS_TRAINED_LIFETIME,
  41. NUM_SYNCH_WORKER_WEIGHTS,
  42. NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS,
  43. SAMPLE_TIMER,
  44. SYNCH_WORKER_WEIGHTS_TIMER,
  45. TIMERS,
  46. )
  47. from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
  48. from ray.rllib.utils.metrics.ray_metrics import (
  49. DEFAULT_HISTOGRAM_BOUNDARIES_SHORT_EVENTS,
  50. TimerAndPrometheusLogger,
  51. )
  52. from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode
  53. from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES
  54. from ray.rllib.utils.schedules.scheduler import Scheduler
  55. from ray.rllib.utils.typing import (
  56. LearningRateOrSchedule,
  57. PolicyID,
  58. ResultDict,
  59. SampleBatchType,
  60. )
  61. from ray.util.metrics import Counter, Histogram
  # Module-level logger, named after this module's import path.
  logger = logging.getLogger(__name__)
  # String key "curr_entropy_coeff". NOTE(review): presumably the key under which
  # Learners report their current entropy coefficient -- confirm against the
  # Learner implementations, which are not part of this file.
  LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY = "curr_entropy_coeff"
  64. class IMPALAConfig(AlgorithmConfig):
  65. """Defines a configuration class from which an Impala can be built.
  66. .. testcode::
  67. from ray.rllib.algorithms.impala import IMPALAConfig
  68. config = (
  69. IMPALAConfig()
  70. .environment("CartPole-v1")
  71. .env_runners(num_env_runners=1)
  72. .training(lr=0.0003, train_batch_size_per_learner=512)
  73. .learners(num_learners=1)
  74. )
  75. # Build a Algorithm object from the config and run 1 training iteration.
  76. algo = config.build()
  77. algo.train()
  78. del algo
  79. .. testcode::
  80. from ray.rllib.algorithms.impala import IMPALAConfig
  81. from ray import tune
  82. config = (
  83. IMPALAConfig()
  84. .environment("CartPole-v1")
  85. .env_runners(num_env_runners=1)
  86. .training(lr=tune.grid_search([0.0001, 0.0002]), grad_clip=20.0)
  87. .learners(num_learners=1)
  88. )
  89. # Run with tune.
  90. tune.Tuner(
  91. "IMPALA",
  92. param_space=config,
  93. run_config=tune.RunConfig(stop={"training_iteration": 1}),
  94. ).fit()
  95. """
  def __init__(self, algo_class=None):
      """Initializes an IMPALAConfig instance.

      Args:
          algo_class: The Algorithm class this config is tied to. Falls back to
              `IMPALA` if None (or otherwise falsy).
      """
      # NOTE(review): `exploration_config` is assigned *before*
      # `super().__init__()` runs. Keep this ordering unless the parent
      # constructor's handling of this attribute has been checked.
      self.exploration_config = {  # @OldAPIstack
          # The Exploration class to use. In the simplest case, this is the name
          # (str) of any class present in the `rllib.utils.exploration` package.
          # You can also provide the python class directly or the full location
          # of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy.
          # EpsilonGreedy").
          "type": "StochasticSampling",
          # Add constructor kwargs here (if any).
      }

      super().__init__(algo_class=algo_class or IMPALA)

      # fmt: off
      # __sphinx_doc_begin__

      # IMPALA specific settings:
      self.vtrace = True
      self.vtrace_clip_rho_threshold = 1.0
      self.vtrace_clip_pg_rho_threshold = 1.0
      self.learner_queue_size = 3
      self.timeout_s_sampler_manager = 0.0
      self.timeout_s_aggregator_manager = 0.0
      self.broadcast_interval = 1
      self.num_gpu_loader_threads = 8

      self.grad_clip = 40.0
      # Note: Only when using enable_rl_module_and_learner=True can the clipping mode
      # be configured by the user. On the old API stack, RLlib will always clip by
      # global_norm, no matter the value of `grad_clip_by`.
      self.grad_clip_by = "global_norm"

      self.vf_loss_coeff = 0.5
      self.entropy_coeff = 0.01

      # Override some of AlgorithmConfig's default values with IMPALA-specific values.
      self.num_learners = 1
      self.num_aggregator_actors_per_learner = 0
      self.rollout_fragment_length = 50
      self.train_batch_size = 500  # @OldAPIstack
      self.num_env_runners = 2
      self.lr = 0.0005
      self.min_time_s_per_iteration = 10
      # __sphinx_doc_end__
      # fmt: on

      # IMPALA takes care of its own EnvRunner (weights, connector, metrics) synching.
      self._dont_auto_sync_env_runner_states = True

      # `.debugging()`
      self._env_runners_only = False
      self._skip_learners = False

      # Settings below are only relevant on the old API stack (see the
      # `@OldAPIStack` markers); several of them are rejected by `validate()`
      # when the new API stack is active (e.g. `lr_schedule`,
      # `entropy_coeff_schedule`, a non-zero `replay_proportion`).
      self.lr_schedule = None  # @OldAPIStack
      self.entropy_coeff_schedule = None  # @OldAPIStack
      self.num_multi_gpu_tower_stacks = 1  # @OldAPIstack
      self.minibatch_buffer_size = 1  # @OldAPIstack
      self.replay_proportion = 0.0  # @OldAPIstack
      self.replay_buffer_num_slots = 0  # @OldAPIstack
      self.learner_queue_timeout = 300  # @OldAPIstack
      self.opt_type = "adam"  # @OldAPIstack
      self.decay = 0.99  # @OldAPIstack
      self.momentum = 0.0  # @OldAPIstack
      self.epsilon = 0.1  # @OldAPIstack
      self._separate_vf_optimizer = False  # @OldAPIstack
      self._lr_vf = 0.0005  # @OldAPIstack
      self.num_gpus = 1  # @OldAPIstack
      self._tf_policy_handles_more_than_one_loss = True  # @OldAPIstack

      # Deprecated settings. `training()` raises (via `deprecation_warning(...,
      # error=True)`) if a caller passes anything other than DEPRECATED_VALUE.
      self.num_aggregation_workers = DEPRECATED_VALUE
      self.max_requests_in_flight_per_aggregator_worker = DEPRECATED_VALUE
  @override(AlgorithmConfig)
  def training(
      self,
      *,
      vtrace: Optional[bool] = NotProvided,
      vtrace_clip_rho_threshold: Optional[float] = NotProvided,
      vtrace_clip_pg_rho_threshold: Optional[float] = NotProvided,
      num_gpu_loader_threads: Optional[int] = NotProvided,
      num_multi_gpu_tower_stacks: Optional[int] = NotProvided,
      minibatch_buffer_size: Optional[int] = NotProvided,
      replay_proportion: Optional[float] = NotProvided,
      replay_buffer_num_slots: Optional[int] = NotProvided,
      learner_queue_size: Optional[int] = NotProvided,
      learner_queue_timeout: Optional[float] = NotProvided,
      timeout_s_sampler_manager: Optional[float] = NotProvided,
      timeout_s_aggregator_manager: Optional[float] = NotProvided,
      broadcast_interval: Optional[int] = NotProvided,
      grad_clip: Optional[float] = NotProvided,
      opt_type: Optional[str] = NotProvided,
      lr_schedule: Optional[List[List[Union[int, float]]]] = NotProvided,
      decay: Optional[float] = NotProvided,
      momentum: Optional[float] = NotProvided,
      epsilon: Optional[float] = NotProvided,
      vf_loss_coeff: Optional[float] = NotProvided,
      entropy_coeff: Optional[LearningRateOrSchedule] = NotProvided,
      entropy_coeff_schedule: Optional[List[List[Union[int, float]]]] = NotProvided,
      _separate_vf_optimizer: Optional[bool] = NotProvided,
      _lr_vf: Optional[float] = NotProvided,
      # Deprecated args.
      num_aggregation_workers=DEPRECATED_VALUE,
      max_requests_in_flight_per_aggregator_worker=DEPRECATED_VALUE,
      **kwargs,
  ) -> Self:
      """Sets the training related configuration.

      Every argument left at `NotProvided` keeps its current value on this
      config. All additional keyword arguments are forwarded to the parent
      class' `training()` method.

      Args:
          vtrace: V-trace params (see vtrace_tf/torch.py). `validate()` rejects
              a falsy value.
          vtrace_clip_rho_threshold: Stored as `self.vtrace_clip_rho_threshold`.
              Presumably the clipping threshold for the V-trace importance
              weights (rho); see the V-trace implementation for the exact
              semantics.
          vtrace_clip_pg_rho_threshold: Stored as
              `self.vtrace_clip_pg_rho_threshold`. Presumably the clipping
              threshold for the importance weights used in the policy gradient
              term; see the V-trace implementation for the exact semantics.
          num_gpu_loader_threads: The number of GPU-loader threads (per Learner
              worker), used to load incoming (CPU) batches to the GPU, if
              applicable. The incoming batches are produced by each Learner's
              LearnerConnector pipeline. After loading the batches on the GPU,
              the threads place them on yet another queue for the Learner thread
              (only one per Learner worker) to pick up and perform
              `forward_train/loss` computations.
          num_multi_gpu_tower_stacks: For each stack of multi-GPU towers, how
              many slots to reserve for parallel data loading. Set this to >1 to
              load data into GPUs in parallel. This increases GPU memory usage
              proportionally with the number of stacks.
              Example: 2 GPUs and `num_multi_gpu_tower_stacks=3`:
              - One tower stack consists of 2 GPUs, each with a copy of the
              model/graph.
              - Each of the stacks creates 3 slots for batch data on each of its
              GPUs, increasing memory requirements on each GPU by 3x.
              - This enables preloading data into these stacks while another
              stack is performing gradient calculations.
          minibatch_buffer_size: How many train batches should be retained for
              minibatching. Only has an effect if `num_epochs > 1`.
          replay_proportion: Set >0 to enable experience replay. Saved samples
              are replayed with a p:1 proportion to new data samples.
          replay_buffer_num_slots: Number of sample batches to store for replay.
              The total number of transitions saved is
              (replay_buffer_num_slots * rollout_fragment_length).
          learner_queue_size: Max queue size for train batches feeding into the
              learner.
          learner_queue_timeout: Number of seconds to wait for train batches to
              become available in the minibatch buffer queue. May need to be
              increased, e.g. when training with a slow environment.
          timeout_s_sampler_manager: The timeout for waiting for sampling
              results from workers. If this is too low, the manager typically
              won't be able to retrieve ready sampling results.
          timeout_s_aggregator_manager: The timeout for waiting for replay
              worker results. If this is too low, the manager typically won't be
              able to retrieve ready replay requests.
          broadcast_interval: Number of training step calls before weights are
              broadcasted to rollout workers that are sampled during any
              iteration.
          grad_clip: If specified, clip the global norm of gradients by this
              amount.
          opt_type: Either "adam" or "rmsprop".
          lr_schedule: Learning rate schedule in the format
              [[timestep, lr-value], [timestep, lr-value], ...]. Intermediary
              timesteps are assigned interpolated learning rate values. A
              schedule should normally start from timestep 0.
          decay: Decay setting for the RMSProp optimizer (`opt_type=rmsprop`).
          momentum: Momentum setting for the RMSProp optimizer
              (`opt_type=rmsprop`).
          epsilon: Epsilon setting for the RMSProp optimizer
              (`opt_type=rmsprop`).
          vf_loss_coeff: Coefficient for the value function term in the loss
              function.
          entropy_coeff: Coefficient for the entropy regularizer term in the
              loss function.
          entropy_coeff_schedule: Decay schedule for the entropy regularizer.
          _separate_vf_optimizer: Set this to True to have two separate
              optimizers optimize the policy- and value networks. Only supported
              for some algorithms (APPO, IMPALA) on the old API stack.
          _lr_vf: If `_separate_vf_optimizer` is True, the separate learning
              rate for the value network.
          num_aggregation_workers: Deprecated. Passing any value triggers a
              deprecation error.
          max_requests_in_flight_per_aggregator_worker: Deprecated. Passing any
              value triggers a deprecation error.

      Returns:
          This updated AlgorithmConfig object.
      """
      # Deprecated arguments: (passed value, old usage string, help text).
      # Checked in order; `error=True` turns the first hit into a hard error.
      deprecated_args = (
          (
              num_aggregation_workers,
              "config.training(num_aggregation_workers=..)",
              "Aggregator workers are no longer supported on the old API "
              "stack! To use aggregation (and GPU pre-loading) on the new API "
              "stack, activate the new API stack, then set "
              "`config.learners(num_aggregator_actors_per_learner=..)`. Good "
              "choices are normally 1 or 2, but this depends on your overall "
              "setup, especially your `EnvRunner` throughput.",
          ),
          (
              max_requests_in_flight_per_aggregator_worker,
              "config.training(max_requests_in_flight_per_aggregator_worker=..)",
              "Aggregator workers are no longer supported on the old API "
              "stack! To use aggregation (and GPU pre-loading) on the new API "
              "stack, activate the new API stack and THEN set "
              "`config.learners(max_requests_in_flight_per_aggregator_actor=..)"
              "`.",
          ),
      )
      for passed_value, old_usage, help_text in deprecated_args:
          if passed_value != DEPRECATED_VALUE:
              deprecation_warning(old=old_usage, help=help_text, error=True)

      # Let the parent class handle all settings it knows about.
      super().training(**kwargs)

      # IMPALA-specific settings, keyed by the attribute they are stored under.
      # Only explicitly provided values overwrite the current ones.
      impala_settings = {
          "vtrace": vtrace,
          "vtrace_clip_rho_threshold": vtrace_clip_rho_threshold,
          "vtrace_clip_pg_rho_threshold": vtrace_clip_pg_rho_threshold,
          "num_gpu_loader_threads": num_gpu_loader_threads,
          "num_multi_gpu_tower_stacks": num_multi_gpu_tower_stacks,
          "minibatch_buffer_size": minibatch_buffer_size,
          "replay_proportion": replay_proportion,
          "replay_buffer_num_slots": replay_buffer_num_slots,
          "learner_queue_size": learner_queue_size,
          "learner_queue_timeout": learner_queue_timeout,
          "broadcast_interval": broadcast_interval,
          "timeout_s_sampler_manager": timeout_s_sampler_manager,
          "timeout_s_aggregator_manager": timeout_s_aggregator_manager,
          "grad_clip": grad_clip,
          "opt_type": opt_type,
          "lr_schedule": lr_schedule,
          "decay": decay,
          "momentum": momentum,
          "epsilon": epsilon,
          "vf_loss_coeff": vf_loss_coeff,
          "entropy_coeff": entropy_coeff,
          "entropy_coeff_schedule": entropy_coeff_schedule,
          "_separate_vf_optimizer": _separate_vf_optimizer,
          "_lr_vf": _lr_vf,
      }
      for attr_name, new_value in impala_settings.items():
          if new_value is not NotProvided:
              setattr(self, attr_name, new_value)

      return self
  def debugging(
      self,
      *,
      _env_runners_only: Optional[bool] = NotProvided,
      _skip_learners: Optional[bool] = NotProvided,
      **kwargs,
  ) -> Self:
      """Sets the debugging related configuration.

      All keyword arguments not listed below are forwarded to the parent class'
      `debugging()` method. Arguments left at `NotProvided` keep their current
      value.

      Args:
          _env_runners_only: If True, only run (remote) EnvRunner requests,
              discard their episode/training data, but log their metrics
              results. Aggregator- and Learner actors won't be used.
          _skip_learners: If True, no `update` requests are sent to the
              LearnerGroup and Learner actors. Only EnvRunners and aggregator
              actors (if applicable) are used.

      Returns:
          This config object (to allow call chaining).
      """
      super().debugging(**kwargs)

      debug_flags = (
          ("_env_runners_only", _env_runners_only),
          ("_skip_learners", _skip_learners),
      )
      for attr_name, flag in debug_flags:
          if flag is not NotProvided:
              setattr(self, attr_name, flag)

      return self
  @override(AlgorithmConfig)
  def validate(self) -> None:
      """Validates the IMPALA-specific settings of this config.

      Runs the parent class' validation first, then checks (in this order):
      that `vtrace` is enabled, the stack-specific settings (new API stack:
      no mixin replay, no `lr_schedule`, no `entropy_coeff_schedule`, a valid
      `entropy_coeff` value/schedule, a consistent `minibatch_size`; old API
      stack: a non-negative numeric `entropy_coeff`), and finally the TF
      multi-loss setting. Every problem found is reported through
      `self._value_error()`.
      """
      # Call the super class' validation method first.
      super().validate()

      # IMPALA and APPO need vtrace (A3C Policies no longer exist).
      if not self.vtrace:
          self._value_error(
              "IMPALA and APPO do NOT support vtrace=False anymore! Set "
              "`config.training(vtrace=True)`."
          )

      # New API stack checks.
      if self.enable_env_runner_and_connector_v2:
          # Does NOT support aggregation workers yet or a mixin replay buffer.
          if self.replay_ratio != 0.0:
              self._value_error(
                  "The new API stack in combination with the new EnvRunner API "
                  "does NOT support a mixin replay buffer yet for "
                  f"{self} (set `config.replay_proportion` to 0.0)!"
              )
          # `lr_schedule` checking.
          if self.lr_schedule is not None:
              self._value_error(
                  "`lr_schedule` is deprecated and must be None! Use the "
                  "`lr` setting to setup a schedule."
              )
          # Entropy coeff schedule checking.
          if self.entropy_coeff_schedule is not None:
              self._value_error(
                  "`entropy_coeff_schedule` is deprecated and must be None! Use the "
                  "`entropy_coeff` setting to setup a schedule."
              )
          Scheduler.validate(
              fixed_value_or_schedule=self.entropy_coeff,
              setting_name="entropy_coeff",
              description="entropy coefficient",
          )
          # A minibatch must consist of whole rollout fragments and must fit
          # into the total train batch.
          if self.minibatch_size is not None and not (
              (self.minibatch_size % self.rollout_fragment_length == 0)
              and self.minibatch_size <= self.total_train_batch_size
          ):
              self._value_error(
                  f"`minibatch_size` ({self.minibatch_size}) must either be None "
                  "or a multiple of `rollout_fragment_length` "
                  f"({self.rollout_fragment_length}) while at the same time smaller "
                  "than or equal to `total_train_batch_size` "
                  f"({self.total_train_batch_size})!"
              )
      # Old API stack checks.
      else:
          # Also catch negative ints (e.g. `entropy_coeff=-1`), not only floats.
          # Non-numeric values (e.g. schedules given as lists) are skipped here.
          if (
              isinstance(self.entropy_coeff, (int, float))
              and self.entropy_coeff < 0.0
          ):
              self._value_error("`entropy_coeff` must be >= 0.0")

      # If two separate optimizers/loss terms used for tf, must also set
      # `_tf_policy_handles_more_than_one_loss` to True.
      if (
          self.framework_str in ["tf", "tf2"]
          and self._separate_vf_optimizer is True
          and self._tf_policy_handles_more_than_one_loss is False
      ):
          self._value_error(
              "`_tf_policy_handles_more_than_one_loss` must be set to True, for "
              "TFPolicy to support more than one loss term/optimizer! Try setting "
              "config.training(_tf_policy_handles_more_than_one_loss=True)."
          )
  @property
  def replay_ratio(self) -> float:
      """Returns the replay ratio derived from `self.replay_proportion`.

      Formula: ratio = 1 / proportion if proportion > 0, else 0.0.

      Note that the result is not clamped: a proportion below 1.0 yields a
      ratio above 1.0 (e.g. a proportion of 0.5 gives 2.0).
      """
      proportion = self.replay_proportion
      if proportion > 0:
          return 1 / proportion
      # Replay disabled (or a non-positive proportion configured).
      return 0.0
  @override(AlgorithmConfig)
  def get_default_learner_class(self):
      """Returns the default Learner class for the configured framework.

      Returns:
          `IMPALATorchLearner` if `self.framework_str` is "torch".

      Raises:
          ValueError: If the framework is "tf"/"tf2" or any other
              non-torch value.
      """
      framework = self.framework_str

      if framework == "torch":
          # Imported lazily, only when the torch Learner is actually needed.
          from ray.rllib.algorithms.impala.torch.impala_torch_learner import (
              IMPALATorchLearner,
          )

          return IMPALATorchLearner

      if framework in ("tf2", "tf"):
          raise ValueError(
              "TensorFlow is no longer supported on the new API stack! "
              "Use `framework='torch'`."
          )

      raise ValueError(
          f"The framework {framework} is not supported. "
          "Use `framework='torch'`."
      )
  @override(AlgorithmConfig)
  def get_default_rl_module_spec(self) -> RLModuleSpec:
      """Returns the default RLModuleSpec for the configured framework.

      Returns:
          An `RLModuleSpec` using `DefaultPPOTorchRLModule` as module class, if
          `self.framework_str` is "torch".

      Raises:
          ValueError: If `self.framework_str` is anything other than "torch".
      """
      if self.framework_str == "torch":
          # Imported lazily, only when the torch module is actually needed.
          from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
              DefaultPPOTorchRLModule,
          )

          return RLModuleSpec(module_class=DefaultPPOTorchRLModule)

      # Only torch is accepted above, so don't advertise 'tf2' in the message.
      raise ValueError(
          f"The framework {self.framework_str} is not supported. "
          "Use `framework='torch'`."
      )
  @override(AlgorithmConfig)
  def build_learner_connector(
      self,
      input_observation_space,
      input_action_space,
      device=None,
  ):
      """Builds the learner connector pipeline and applies IMPALA's adjustments.

      Args:
          input_observation_space: Passed through to the parent implementation.
          input_action_space: Passed through to the parent implementation.
          device: Passed through to the parent implementation.

      Returns:
          The pipeline built by the parent class. If
          `add_default_connectors_to_learner_pipeline` is set, it additionally
          has `AddOneTsToEpisodesAndTruncate` prepended and, when aggregator
          actors are configured, `NumpyToTensor` removed.
      """
      pipeline = super().build_learner_connector(
          input_observation_space,
          input_action_space,
          device,
      )

      # Leave the pipeline untouched if no default connectors are wanted.
      if not self.add_default_connectors_to_learner_pipeline:
          return pipeline

      # Extend all episodes by one artificial timestep to allow the value function
      # net to compute the bootstrap values (and add a mask to the batch to know,
      # which slots to mask out).
      pipeline.prepend(AddOneTsToEpisodesAndTruncate())

      # Remove the NumpyToTensor connector if we have the GPULoaderThreads.
      uses_aggregator_actors = self.num_aggregator_actors_per_learner > 0
      if uses_aggregator_actors:
          pipeline.remove(NumpyToTensor)

      return pipeline
  # Alternative (CamelCase) name bound to the very same class object as
  # `IMPALAConfig`.
  ImpalaConfig = IMPALAConfig
  472. class IMPALA(Algorithm):
  473. """Importance weighted actor/learner architecture (IMPALA) Algorithm
  474. == Overview of data flow in IMPALA ==
  475. 1. Policy evaluation in parallel across `num_env_runners` actors produces
  476. batches of size `rollout_fragment_length * num_envs_per_env_runner`.
  477. 2. If enabled, the replay buffer stores and produces batches of size
  478. `rollout_fragment_length * num_envs_per_env_runner`.
  479. 3. If enabled, the minibatch ring buffer stores and replays batches of
  480. size `train_batch_size` up to `num_epochs` times per batch.
  481. 4. The learner thread executes data parallel SGD across `num_gpus` GPUs
  482. on batches of size `train_batch_size`.
  483. """
  @classmethod
  @override(Algorithm)
  def get_default_config(cls) -> IMPALAConfig:
      """Returns a newly created `IMPALAConfig` holding IMPALA's defaults."""
      default_config = IMPALAConfig()
      return default_config
  @classmethod
  @override(Algorithm)
  def get_default_policy_class(
      cls, config: AlgorithmConfig
  ) -> Optional[Type[Policy]]:
      """Returns the Policy class matching the config's framework.

      Args:
          config: The config whose `framework_str` selects the Policy class.

      Returns:
          `ImpalaTorchPolicy` for "torch", `ImpalaTF1Policy` for "tf", and
          `ImpalaTF2Policy` for every other framework value.
      """
      framework = config.framework_str

      # Each Policy class is imported lazily, only for the selected framework.
      if framework == "torch":
          from ray.rllib.algorithms.impala.impala_torch_policy import (
              ImpalaTorchPolicy,
          )

          return ImpalaTorchPolicy

      if framework == "tf":
          from ray.rllib.algorithms.impala.impala_tf_policy import (
              ImpalaTF1Policy,
          )

          return ImpalaTF1Policy

      # Fallback for anything that is neither "torch" nor "tf".
      from ray.rllib.algorithms.impala.impala_tf_policy import (
          ImpalaTF2Policy,
      )

      return ImpalaTF2Policy
  @override(Algorithm)
  def setup(self, config: AlgorithmConfig):
      """Sets up IMPALA's bookkeeping on top of the parent class' setup.

      Initializes the weight-sync counter and the containers used to collect
      data for the Learner(s). On the old API stack, it also creates the mixin
      replay buffer (if `replay_proportion > 0.0`) and creates and starts the
      learner thread.

      Args:
          config: The config, passed on to the parent class' `setup()`.
      """
      super().setup(config)
      cfg = self.config

      # Initialize so it does not default to None
      self._counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS] = 0

      # Queue of data to be sent to the Learner.
      self.data_to_place_on_learner = []
      self._batch_being_built = []  # @OldAPIStack
      self._episode_packs_being_built = []
      # One (initially empty) list per index in `range(num_learners or 1)`.
      num_slots = cfg.num_learners or 1
      self._ma_batches_being_built: Dict[int, list] = {
          idx: [] for idx in range(num_slots)
      }

      on_old_api_stack = not cfg.enable_rl_module_and_learner

      # Create local mixin buffer if on old API stack and replay
      # proportion is set.
      self.local_mixin_buffer = None  # @OldAPIStack
      if on_old_api_stack and cfg.replay_proportion > 0.0:
          configured_slots = cfg.replay_buffer_num_slots
          # Use a capacity of 1 if no (positive) number of slots is configured.
          buffer_capacity = configured_slots if configured_slots > 0 else 1
          self.local_mixin_buffer = MixInMultiAgentReplayBuffer(
              capacity=buffer_capacity,
              replay_ratio=cfg.replay_ratio,
              replay_mode=ReplayMode.LOCKSTEP,
          )

      # This variable is used to keep track of the statistics from the most recent
      # update of the learner group
      self._results = {}

      if on_old_api_stack:
          # Create and start the learner thread.
          learner_thread = make_learner_thread(self.env_runner, cfg)
          self._learner_thread = learner_thread
          learner_thread.start()
  543. @override(Algorithm)
  544. def training_step(self):
  545. with TimerAndPrometheusLogger(self._metrics_impala_training_step_time):
  546. # Old API stack.
  547. if not self.config.enable_rl_module_and_learner:
  548. return self._training_step_old_api_stack()
  549. do_async_updates = self.config.num_learners > 0
  550. # Asynchronously request all EnvRunners to sample and return their current
  551. # (e.g. ConnectorV2) states and sampling metrics/stats.
  552. # Note that each item in `episode_refs` is a reference to a list of Episodes.
  553. with self.metrics.log_time((TIMERS, SAMPLE_TIMER)):
  554. (
  555. episode_refs,
  556. connector_states,
  557. env_runner_metrics,
  558. env_runner_indices_to_update,
  559. ) = self._sample_and_get_connector_states()
  560. # Reduce EnvRunner metrics over the n EnvRunners.
  561. self.metrics.aggregate(
  562. env_runner_metrics,
  563. key=ENV_RUNNER_RESULTS,
  564. )
  565. # Log the average number of sample results (list of episodes) received.
  566. self.metrics.log_value(
  567. (ENV_RUNNER_RESULTS, MEAN_NUM_EPISODE_LISTS_RECEIVED),
  568. len(episode_refs),
  569. )
  570. # Only run EnvRunners, nothing else.
  571. if self.config._env_runners_only:
  572. return
  573. # "Batch" collected episode refs into groups, such that exactly
  574. # `total_train_batch_size` timesteps are sent to
  575. # `LearnerGroup.update()`.
  576. if self.config.num_aggregator_actors_per_learner > 0:
  577. with TimerAndPrometheusLogger(
  578. self._metrics_impala_training_step_aggregator_preprocessing_time
  579. ):
  580. data_packages_for_aggregators = self._pre_queue_episode_refs(
  581. episode_refs,
  582. package_size=self.config.train_batch_size_per_learner,
  583. )
  584. self.metrics.log_value(
  585. (AGGREGATOR_ACTOR_RESULTS, "mean_num_input_packages"),
  586. len(episode_refs),
  587. )
  588. ma_batches_refs_remote_results = (
  589. self._aggregator_actor_manager.fetch_ready_async_reqs(
  590. return_obj_refs=True,
  591. tags="get_batches",
  592. )
  593. )
  594. ma_batches_refs = []
  595. for call_result in ma_batches_refs_remote_results:
  596. ma_batches_refs.append(
  597. (call_result.actor_id, call_result.get())
  598. )
  599. self.metrics.log_value(
  600. (AGGREGATOR_ACTOR_RESULTS, "mean_num_output_batches"),
  601. len(ma_batches_refs),
  602. )
  603. while data_packages_for_aggregators:
  604. num_agg = self.config.num_aggregator_actors_per_learner * (
  605. self.config.num_learners or 1
  606. )
  607. packs, data_packages_for_aggregators = (
  608. data_packages_for_aggregators[:num_agg],
  609. data_packages_for_aggregators[num_agg:],
  610. )
  611. sent = self._aggregator_actor_manager.foreach_actor_async(
  612. func="get_batch",
  613. kwargs=[dict(episode_refs=p) for p in packs],
  614. tag="get_batches",
  615. )
  616. _dropped = self.config.train_batch_size_per_learner * (
  617. len(packs) - sent
  618. )
  619. if _dropped > 0:
  620. self._metrics_impala_training_step_env_steps_dropped.inc(
  621. value=_dropped
  622. )
  623. self.metrics.log_value(
  624. (
  625. AGGREGATOR_ACTOR_RESULTS,
  626. "num_env_steps_dropped_lifetime",
  627. ),
  628. _dropped,
  629. reduce="sum",
  630. )
  631. # Get n lists of m ObjRef[MABatch] (m=num_learners) to perform n calls to
  632. # all learner workers with the already GPU-located batches.
  633. data_packages_for_learner_group = self._pre_queue_batch_refs(
  634. ma_batches_refs
  635. )
  636. if len(data_packages_for_learner_group) > 0:
  637. self._metrics_impala_training_step_input_batches.inc(
  638. value=len(data_packages_for_learner_group)
  639. )
  640. else:
  641. self._metrics_impala_training_step_zero_input_batches.inc(
  642. value=1
  643. )
  644. self.metrics.log_value(
  645. (AGGREGATOR_ACTOR_RESULTS, "num_env_steps_aggregated_lifetime"),
  646. self.config.train_batch_size_per_learner
  647. * (self.config.num_learners or 1)
  648. * len(data_packages_for_learner_group),
  649. reduce="sum",
  650. with_throughput=True,
  651. )
  652. else:
  653. data_packages_for_learner_group = self._pre_queue_episode_refs(
  654. episode_refs, package_size=self.config.total_train_batch_size
  655. )
  656. # Skip Learner update calls.
  657. if self.config._skip_learners:
  658. return
  659. # Call the LearnerGroup's `update()` method.
  660. with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
  661. self.metrics.log_value(
  662. key=MEAN_NUM_LEARNER_GROUP_UPDATE_CALLED,
  663. value=len(data_packages_for_learner_group),
  664. )
  665. rl_module_state = None
  666. num_learner_group_results_received = 0
  667. return_state = (
  668. self._counters[
  669. NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS
  670. ]
  671. >= self.config.broadcast_interval
  672. )
  673. with TimerAndPrometheusLogger(
  674. self._metrics_impala_training_step_learner_group_loop_time
  675. ):
  676. for (
  677. batch_ref_or_episode_list_ref
  678. ) in data_packages_for_learner_group:
  679. timesteps = {
  680. NUM_ENV_STEPS_SAMPLED_LIFETIME: self.metrics.peek(
  681. (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME),
  682. default=0,
  683. ),
  684. NUM_ENV_STEPS_TRAINED_LIFETIME: self.metrics.peek(
  685. (
  686. LEARNER_RESULTS,
  687. ALL_MODULES,
  688. NUM_ENV_STEPS_TRAINED_LIFETIME,
  689. ),
  690. default=0,
  691. ),
  692. }
  693. # Update from batch refs coming from AggregatorActors.
  694. if self.config.num_aggregator_actors_per_learner > 0:
  695. assert len(batch_ref_or_episode_list_ref) == (
  696. self.config.num_learners or 1
  697. )
  698. training_data = TrainingData(
  699. batch_refs=batch_ref_or_episode_list_ref
  700. )
  701. # Update from episodes refs coming from EnvRunner actors.
  702. else:
  703. training_data = TrainingData(
  704. episodes_refs=batch_ref_or_episode_list_ref
  705. )
  706. learner_results = self.learner_group.update(
  707. training_data=training_data,
  708. async_update=do_async_updates,
  709. return_state=return_state,
  710. timesteps=timesteps,
  711. num_epochs=self.config.num_epochs,
  712. minibatch_size=self.config.minibatch_size,
  713. shuffle_batch_per_epoch=self.config.shuffle_batch_per_epoch,
  714. defer_solve_refs_to_learner=True,
  715. )
  716. # Only request weights from 1st Learner - at most - once per
  717. # `training_step` call.
  718. return_state = False
  719. num_learner_group_results_received += len(learner_results)
  720. # Extract the last (most recent) weights matrix, if available.
  721. for result_from_1_learner in learner_results:
  722. rl_module_state = result_from_1_learner.pop(
  723. "_rl_module_state_after_update", rl_module_state
  724. )
  725. self.metrics.aggregate(
  726. stats_dicts=learner_results,
  727. key=LEARNER_RESULTS,
  728. )
  729. self.metrics.log_value(
  730. key=(LEARNER_GROUP, MEAN_NUM_LEARNER_RESULTS_RECEIVED),
  731. value=num_learner_group_results_received,
  732. )
  733. self._counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS] += 1
  734. self.metrics.log_value(
  735. NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS,
  736. self._counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS],
  737. reduce="mean",
  738. )
  739. # Update LearnerGroup's own stats.
  740. self.metrics.log_dict(self.learner_group.get_stats(), key=LEARNER_GROUP)
  741. # Figure out, whether we should sync/broadcast the (remote) EnvRunner states.
  742. # Note: `learner_results` is a List of n (num async calls) Lists of m
  743. # (num Learner workers) ResultDicts each.
  744. if rl_module_state is not None:
  745. self._counters[
  746. NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS
  747. ] = 0
  748. self.metrics.log_value(NUM_SYNCH_WORKER_WEIGHTS, 1, reduce="sum")
  749. with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
  750. with TimerAndPrometheusLogger(
  751. self._metrics_impala_training_step_sync_env_runner_state_time
  752. ):
  753. self.env_runner_group.sync_env_runner_states(
  754. config=self.config,
  755. connector_states=connector_states,
  756. rl_module_state=rl_module_state,
  757. env_steps_sampled=self.metrics.peek(
  758. (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME),
  759. default=0,
  760. ),
  761. env_to_module=self.env_to_module_connector,
  762. module_to_env=self.module_to_env_connector,
  763. )
  def _sample_and_get_connector_states(self):
      """Gather episodes, connector states, and metrics from the EnvRunners.

      When the EnvRunner group reports at least one healthy remote worker, the
      ready results of the async `sample_get_state_and_metrics` calls are
      fetched. Otherwise, the local EnvRunner is sampled directly and its
      episodes are placed into the object store via `ray.put()`.

      Returns:
          A 4-tuple `(episode_refs, connector_states, env_runner_metrics,
          env_runner_indices_to_update)`. The last item is a set holding the
          first element of every ready async result; it stays empty when the
          local EnvRunner is used.
      """
      # Everything below is timed under one Prometheus histogram.
      with TimerAndPrometheusLogger(
          self._metrics_impala_sample_and_get_connector_states_time
      ):
          env_runner_indices_to_update = set()
          episode_refs, connector_states, env_runner_metrics = [], [], []

          if self.env_runner_group.num_healthy_remote_workers() > 0:
              # Remote path: fetch whatever async sample results are ready.
              ready_results = (
                  self.env_runner_group.foreach_env_runner_async_fetch_ready(
                      func="sample_get_state_and_metrics",
                      tag="sample_get_state_and_metrics",
                      timeout_seconds=self.config.timeout_s_sampler_manager,
                      return_obj_refs=False,
                      return_actor_ids=True,
                  )
              )
              # Each ready item is indexed as (id, (episodes, states, metrics)).
              for ready in ready_results:
                  env_runner_indices_to_update.add(ready[0])
                  episodes, states, metrics = ready[1]
                  episode_refs.append(episodes)
                  connector_states.append(states)
                  env_runner_metrics.append(metrics)
          else:
              # Local path: sample from the local EnvRunner.
              local_runner = self.env_runner
              episodes = local_runner.sample()
              env_runner_metrics = [local_runner.get_metrics()]
              episode_refs = [ray.put(episodes)]
              connector_states = [
                  local_runner.get_state(
                      components=[
                          COMPONENT_ENV_TO_MODULE_CONNECTOR,
                          COMPONENT_MODULE_TO_ENV_CONNECTOR,
                      ]
                  )
              ]

      return (
          episode_refs,
          connector_states,
          env_runner_metrics,
          env_runner_indices_to_update,
      )
  def _pre_queue_episode_refs(
      self, episode_refs: List[ObjectRef], package_size: int
  ) -> List[List[ObjectRef]]:
      """Bundle episode refs into packages worth at least `package_size` steps.

      Refs are accumulated in `self._episode_packs_being_built`, which persists
      across calls. Every ref is counted as
      `num_envs_per_env_runner * get_rollout_fragment_length()` timesteps; as
      soon as the accumulated estimate reaches `package_size`, the current pack
      is emitted and a new, empty one is started.

      Args:
          episode_refs: Refs to add to the pack currently being built.
          package_size: Estimated number of timesteps a pack must reach before
              it is emitted.

      Returns:
          The packs completed during this call (possibly none). Refs of a pack
          that is still incomplete stay buffered on `self`.
      """
      completed_packs: List[List[ObjectRef]] = []
      for episode_ref in episode_refs:
          current_pack = self._episode_packs_being_built
          current_pack.append(episode_ref)
          estimated_steps = (
              len(current_pack)
              * self.config.num_envs_per_env_runner
              * self.config.get_rollout_fragment_length()
          )
          if estimated_steps < package_size:
              # Not enough data yet; keep filling the same pack.
              continue
          completed_packs.append(current_pack)
          self._episode_packs_being_built = []
      return completed_packs
  def _pre_queue_batch_refs(
      self, batch_refs: List[Tuple[int, ObjectRef]]
  ) -> List[List[ObjectRef]]:
      """Sort incoming batch refs by learner and emit complete learner groups.

      Each incoming ref is appended to the queue of the learner that
      `self._aggregator_actor_to_learner` maps its aggregator actor to. The
      per-learner queues (`self._ma_batches_being_built`) persist across calls.
      As long as every queue holds at least one ref, the oldest ref of each
      queue is popped and the resulting list is emitted as one group.

      Args:
          batch_refs: List of `(aggregator_actor_id, ObjRef[MABatch])` tuples.
              Each ref was returned by one AggregatorActor from a single
              `get_batch()` call.
              TODO (sven): Add this comment, once valid:
              .. and the underlying MABatch is already located on a particular
              GPU (matching one particular Learner).

      Returns:
          A list of groups. Each group holds exactly one ref per learner queue,
          in the iteration order of `self._ma_batches_being_built`.
      """
      queues_by_learner = self._ma_batches_being_built
      for agg_actor_id, ma_batch_ref in batch_refs:
          learner_actor_id = self._aggregator_actor_to_learner[agg_actor_id]
          queues_by_learner[learner_actor_id].append(ma_batch_ref)

      # Construct an n-group of batches (n=num_learners) as long as we still
      # have at least one batch per learner in our queue.
      # `all()` over an empty iterable is True, so without the explicit
      # non-empty check an empty mapping would make this loop run forever
      # (appending empty groups).
      batch_refs_for_learner_group: List[List[ObjectRef]] = []
      while queues_by_learner and all(queues_by_learner.values()):
          batch_refs_for_learner_group.append(
              [learner_queue.pop(0) for learner_queue in queues_by_learner.values()]
          )
      return batch_refs_for_learner_group
  @override(Algorithm)
  def _set_up_metrics(self):
      """Create the IMPALA-specific histograms and counters.

      Calls the parent implementation first, then registers five timing
      histograms and three counters. Every metric is created with the single
      tag key `"rllib"`, whose default value is this class' name.
      """
      super()._set_up_metrics()

      def _with_default_tags(metric):
          # Attach the class name as default value for the "rllib" tag.
          metric.set_default_tags({"rllib": self.__class__.__name__})
          return metric

      def _short_event_histogram(name, description):
          # Histogram using the shared "short events" bucket boundaries.
          return _with_default_tags(
              Histogram(
                  name=name,
                  description=description,
                  boundaries=DEFAULT_HISTOGRAM_BOUNDARIES_SHORT_EVENTS,
                  tag_keys=("rllib",),
              )
          )

      def _counter(name, description):
          return _with_default_tags(
              Counter(
                  name=name,
                  description=description,
                  tag_keys=("rllib",),
              )
          )

      # Timing histograms.
      self._metrics_impala_training_step_time = _short_event_histogram(
          "rllib_algorithms_impala_training_step_time",
          "Time spent in IMPALA.training_step()",
      )
      self._metrics_impala_training_step_aggregator_preprocessing_time = _short_event_histogram(
          "rllib_algorithms_impala_training_step_aggregator_preprocessing_time",
          "Time spent preprocessing episodes with aggregator actor in the IMPALA.training_step()",
      )
      self._metrics_impala_training_step_learner_group_loop_time = _short_event_histogram(
          "rllib_algorithms_impala_training_step_learner_group_loop_time",
          "Time spent in the learner group update calls loop, in the IMPALA.training_step()",
      )
      self._metrics_impala_training_step_sync_env_runner_state_time = _short_event_histogram(
          "rllib_algorithms_impala_training_step_sync_env_runner_state_time",
          "Time spent on syncing EnvRunner states in the IMPALA.training_step()",
      )
      self._metrics_impala_sample_and_get_connector_states_time = _short_event_histogram(
          "rllib_algorithms_impala_sample_and_get_connector_states_time",
          "Time spent in IMPALA._sample_and_get_connector_states()",
      )

      # Counters.
      self._metrics_impala_training_step_input_batches = _counter(
          "rllib_algorithms_impala_training_step_input_batches_counter",
          "Number of input batches processed and passed to the learner in the IMPALA.training_step()",
      )
      self._metrics_impala_training_step_zero_input_batches = _counter(
          "rllib_algorithms_impala_training_step_zero_input_batches_counter",
          "Number of times zero input batches were ready in the IMPALA.training_step()",
      )
      self._metrics_impala_training_step_env_steps_dropped = _counter(
          "rllib_algorithms_impala_training_step_env_steps_dropped_counter",
          "Number of env steps dropped when sending data to the aggregator actors in the IMPALA.training_step()",
      )
  @OldAPIStack
  def _training_step_old_api_stack(self):
      """Run one training iteration on the old API stack.

      The steps are: verify the learner thread is alive, collect sample batches
      from the workers, process them on the driver, update the sampling
      counters, concatenate and enqueue train batches for the learner thread,
      gather the learner thread's results, and update the workers. Finally,
      unhealthy aggregator actors (if an aggregator manager exists) are probed.

      Returns:
          The results returned by `self._process_trained_results()`.

      Raises:
          RuntimeError: If the learner thread is no longer alive.
      """
      # Without a running learner thread, nothing below can make progress.
      if not self._learner_thread.is_alive():
          raise RuntimeError("The learner thread died while training!")

      # Fetch (worker_id, batch) pairs; the batches themselves are requested,
      # not object refs.
      worker_batch_pairs = self._get_samples_from_workers_old_api_stack(
          return_object_refs=False,
      )

      # Remember which workers delivered data in this iteration. Only those
      # are candidates for an update at the end of the iteration.
      workers_that_need_updates = set()
      for worker_id, _ in worker_batch_pairs:
          workers_that_need_updates.add(worker_id)

      # Resolve collected batches on the local process (using the mixin buffer).
      processed = self._process_experiences_old_api_stack(worker_batch_pairs)

      # The actual SampleBatches are local now, so their sizes can be measured.
      for sample_batch in processed:
          self._counters[NUM_ENV_STEPS_SAMPLED] += sample_batch.count
          self._counters[NUM_AGENT_STEPS_SAMPLED] += sample_batch.agent_steps()

      # Build batches of size `total_train_batch_size` ...
      self._concatenate_batches_and_pre_queue(processed)
      # ... and hand them to the learner thread's input queue.
      self._place_processed_samples_on_learner_thread_queue()

      # Pull the most recent train results out of the learner thread.
      train_results = self._process_trained_results()

      # Sync worker weights (only those policies that were actually updated).
      with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
          self._update_workers_old_api_stack(
              workers_that_need_updates=workers_that_need_updates,
              policy_ids=list(train_results.keys()),
          )

      # With a training step done, try to bring any aggregators back to life
      # if necessary. AggregatorActors are stateless, so no state has to be
      # restored here.
      aggregator_manager = self._aggregator_actor_manager
      if aggregator_manager:
          aggregator_manager.probe_unhealthy_actors(
              timeout_seconds=self.config.env_runner_health_probe_timeout_s,
              mark_healthy=True,
          )

      return train_results
  @OldAPIStack
  def _get_samples_from_workers_old_api_stack(
      self,
      return_object_refs: Optional[bool] = False,
  ) -> List[Tuple[int, Union[ObjectRef, SampleBatchType]]]:
      """Collect samples from the rollout workers for training.

      Args:
          return_object_refs: If True, return ObjectRefs instead of the samples
              themselves. Useful with aggregator workers: the data collected
              on the rollout workers is then dereferenced directly on the
              aggregator workers rather than first on the driver.

      Returns:
          A list of `(worker_index, sample batch or ObjectRef to a sample
          batch)` tuples. The list is empty if there is nothing to sample from.
      """
      with self._timers[SAMPLE_TIMER]:
          # Preferred source: healthy remote workers.
          if self.env_runner_group.num_healthy_remote_workers() > 0:
              # Kick off asynchronous sampling on all (remote) rollout workers,
              # then collect whatever requests are ready.
              self.env_runner_group.foreach_env_runner_async(
                  lambda w: w.sample()
              )
              return self.env_runner_group.fetch_ready_async_reqs(
                  timeout_seconds=self.config.timeout_s_sampler_manager,
                  return_obj_refs=return_object_refs,
              )

          # Fallback: the local worker. It is used if no remote workers are
          # configured at all, or if a local worker exists whose `async_env`
          # is set.
          if self.config.num_env_runners == 0 or (
              self.env_runner and self.env_runner.async_env is not None
          ):
              local_batch = self.env_runner.sample()
              if return_object_refs:
                  local_batch = ray.put(local_batch)
              # The local worker is reported under index 0.
              return [(0, local_batch)]

          # Not much we can do. Return empty list and wait.
          return []
  @OldAPIStack
  def _process_experiences_old_api_stack(
      self,
      worker_to_sample_batches: List[Tuple[int, SampleBatch]],
  ) -> List[SampleBatchType]:
      """Process sample batches on the driver so they can be trained on.

      Each batch is decompressed if needed. If a local mixin buffer exists, the
      batch is added to it and replaced by what the buffer replays; otherwise
      a copy of the batch is used. Falsy results are dropped.

      Args:
          worker_to_sample_batches: List of `(worker_id, sample_batch)` tuples.

      Returns:
          The processed batches (those that came back truthy).
      """
      # The worker IDs are irrelevant here.
      raw_batches = [sample_batch for _, sample_batch in worker_to_sample_batches]

      results = []
      for sample_batch in raw_batches:
          assert not isinstance(
              sample_batch, ObjectRef
          ), "`IMPALA._process_experiences_old_api_stack` can not handle ObjectRefs!"
          sample_batch = sample_batch.decompress_if_needed()

          if not self.local_mixin_buffer:
              # TODO(jjyao) somehow deep copy the batch
              # fix a memory leak issue. Need to investigate more
              # to know why.
              sample_batch = sample_batch.copy()
          else:
              # Only pass through the buffer if we actually have one (replay
              # proportion > 0.0).
              self.local_mixin_buffer.add(sample_batch)
              sample_batch = self.local_mixin_buffer.replay(_ALL_POLICIES)

          if sample_batch:
              results.append(sample_batch)

      return results
  @OldAPIStack
  def _concatenate_batches_and_pre_queue(self, batches: List[SampleBatch]) -> None:
      """Accumulate rollout batches and stage full train batches for the learner.

      Incoming batches are appended to `self._batch_being_built`. Whenever the
      summed `count` of the accumulated batches reaches
      `config.total_train_batch_size`, they are concatenated, appended to
      `self.data_to_place_on_learner`, and the accumulator is reset.

      Args:
          batches: List of batches of experiences from EnvRunners.
      """

      def _vf_preds_mismatch(ma_batch) -> bool:
          # True if any policy batch carries VF_PREDS whose first dimension
          # differs from that of REWARDS.
          for policy_batch in ma_batch.policy_batches.values():
              if SampleBatch.VF_PREDS not in policy_batch:
                  continue
              num_vf_preds = policy_batch[SampleBatch.VF_PREDS].shape[0]
              num_rewards = policy_batch[SampleBatch.REWARDS].shape[0]
              if num_vf_preds != num_rewards:
                  return True
          return False

      for batch in batches:
          # TODO (sven): Strange bug after a RolloutWorker crash and proper
          #  restart. It is related to (old, non-V2) connectors and seems to
          #  happen inside the AgentCollector's `add_action_reward_next_obs`
          #  method, at the end of which the number of vf_preds (and all other
          #  extra action outs) in the batch is one smaller than the number of
          #  obs/actions/rewards. This yields a malformed train batch, and
          #  IMPALA/APPO crash inside the loss function (during v-trace
          #  operations) b/c of the resulting shape mismatch. Skipping such
          #  batches prevents this. The check can be removed once we are on the
          #  new API stack for good (new connectors; no more AgentCollectors,
          #  RolloutWorkers, Policies, TrajectoryView API, etc..).
          if (
              self.config.batch_mode == "truncate_episodes"
              and self.config.restart_failed_env_runners
              and _vf_preds_mismatch(batch)
          ):
              continue

          self._batch_being_built.append(batch)

          # Emit a train batch as soon as enough timesteps have accumulated.
          accumulated_steps = sum(b.count for b in self._batch_being_built)
          if accumulated_steps >= self.config.total_train_batch_size:
              self.data_to_place_on_learner.append(
                  concat_samples(self._batch_being_built)
              )
              self._batch_being_built = []
  @OldAPIStack
  def _place_processed_samples_on_learner_thread_queue(self) -> None:
      """Move all staged train batches onto the learner thread's input queue.

      Batches that do not fit (queue full) are counted under
      `num_times_learner_queue_full` and dropped; successfully queued ones are
      counted under `num_samples_added_to_queue`. The staging list is emptied
      afterwards in either case.
      """
      staged = self.data_to_place_on_learner
      last_index = len(staged) - 1

      for index, train_batch in enumerate(staged):
          # Setting block = True for the very last item in our list prevents
          # the learner thread, this main thread, and the GPU loader threads
          # from thrashing when there are more samples than the learner can
          # reasonably process.
          # see https://github.com/ray-project/ray/pull/26581#issuecomment-1187877674 # noqa
          is_last = index == last_index
          try:
              self._learner_thread.inqueue.put(train_batch, block=is_last)
          except queue.Full:
              self._counters["num_times_learner_queue_full"] += 1
              continue

          # The batch made it onto the queue: count its steps.
          if self.config.count_steps_by == "agent_steps":
              num_steps = train_batch.agent_steps()
          else:
              num_steps = train_batch.count
          self._counters["num_samples_added_to_queue"] += num_steps

      staged.clear()
  @OldAPIStack
  def _process_trained_results(self) -> ResultDict:
      """Drain the learner thread's output queue and aggregate its results.

      Also adds the drained env/agent step counts to the
      `NUM_ENV_STEPS_TRAINED` and `NUM_AGENT_STEPS_TRAINED` counters.

      Returns:
          The learner info accumulated over all newly drained results, or - if
          no new results carried learner info - a deep copy of the learner
          thread's current `learner_info`.
      """
      env_steps_trained = 0
      agent_steps_trained = 0
      new_learner_infos = []

      # Read as many items as the queue reports to hold right now.
      outqueue = self._learner_thread.outqueue
      for _ in range(outqueue.qsize()):
          env_steps, agent_steps, learner_results = outqueue.get(timeout=0.001)
          env_steps_trained += env_steps
          agent_steps_trained += agent_steps
          if learner_results:
              new_learner_infos.append(learner_results)

      if new_learner_infos:
          # Accumulate learner stats using the `LearnerInfoBuilder` utility.
          info_builder = LearnerInfoBuilder()
          for learner_info in new_learner_infos:
              info_builder.add_learn_on_batch_results_multi_agent(learner_info)
          final_learner_info = info_builder.finalize()
      else:
          # Nothing new happened since last time, use the same learner stats.
          final_learner_info = copy.deepcopy(self._learner_thread.learner_info)

      # Update the steps trained counters.
      self._counters[NUM_ENV_STEPS_TRAINED] += env_steps_trained
      self._counters[NUM_AGENT_STEPS_TRAINED] += agent_steps_trained

      return final_learner_info
  @OldAPIStack
  def _update_workers_old_api_stack(
      self,
      workers_that_need_updates: Set[int],
      policy_ids: Optional[List[PolicyID]] = None,
  ) -> None:
      """Updates all RolloutWorkers that require updating.

      The local worker's global vars are always updated. Weights are only
      broadcast to remote workers if there are remote workers, the
      `NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS` counter has
      reached `config.broadcast_interval`, and at least one worker has sent
      samples in this iteration. Only the weights of the policies given via
      `policy_ids` are read from the local worker.

      If `config.policy_states_are_swappable` is set, the local worker is
      locked while it is accessed. The lock is released in a `finally` clause,
      so an exception raised in between cannot leave the worker locked.

      Args:
          workers_that_need_updates: Set of worker IDs that need to be updated.
          policy_ids: Optional list of Policy IDs to update. If None, no
              per-policy grad-update counts are collected and `None` is passed
              on to the local worker's `set_global_vars()` / `get_weights()`.
      """
      # Read the flag once so that every lock() is paired with an unlock().
      needs_lock = self.config.policy_states_are_swappable

      # Update global vars of the local worker.
      if needs_lock:
          self.env_runner.lock()
      try:
          global_vars = {
              "timestep": self._counters[NUM_AGENT_STEPS_TRAINED],
              "num_grad_updates_per_policy": {
                  pid: self.env_runner.policy_map[pid].num_grad_updates
                  for pid in policy_ids or []
              },
          }
          self.env_runner.set_global_vars(global_vars, policy_ids=policy_ids)
      finally:
          if needs_lock:
              self.env_runner.unlock()

      # Only need to update workers if there are remote workers.
      self._counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS] += 1
      if (
          self.env_runner_group.num_remote_workers() > 0
          and self._counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS]
          >= self.config.broadcast_interval
          and workers_that_need_updates
      ):
          if needs_lock:
              self.env_runner.lock()
          try:
              weights = self.env_runner.get_weights(policy_ids)
          finally:
              if needs_lock:
                  self.env_runner.unlock()

          # Put the weights into the object store once; the lambda below
          # resolves the ref on the receiving side.
          weights_ref = ray.put(weights)

          self._learner_thread.policy_ids_updated.clear()
          self._counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS] = 0
          self._counters[NUM_SYNCH_WORKER_WEIGHTS] += 1
          self.env_runner_group.foreach_env_runner(
              func=lambda w: w.set_weights(ray.get(weights_ref), global_vars),
              local_env_runner=False,
              remote_worker_ids=list(workers_that_need_updates),
              timeout_seconds=0,  # Don't wait for the workers to finish.
          )
  @override(Algorithm)
  def _compile_iteration_results_old_api_stack(self, *args, **kwargs):
      """Compile the iteration results, adding learner-thread metrics if needed.

      The parent class' results are returned unchanged when
      `config.enable_rl_module_and_learner` is set. Otherwise, they are passed
      through the learner thread's `add_learner_metrics()` (without
      overwriting the learner info).
      """
      compiled = super()._compile_iteration_results_old_api_stack(*args, **kwargs)
      if self.config.enable_rl_module_and_learner:
          return compiled
      return self._learner_thread.add_learner_metrics(
          compiled, overwrite_learner_info=False
      )
  # Alias: makes the `IMPALA` class available under the name `Impala` as well.
  Impala = IMPALA
  @OldAPIStack
  def make_learner_thread(local_worker, config):
      """Build the learner thread matching the given (old API stack) config.

      Args:
          local_worker: Passed through as first argument to the learner thread.
          config: Config mapping, accessed by key. In multi-GPU mode,
              `config["minibatch_buffer_size"]` is lowered in place to
              `config["num_multi_gpu_tower_stacks"]` if it exceeds that value.

      Returns:
          A `LearnerThread` if `config["simple_optimizer"]` is truthy,
          otherwise a `MultiGPULearnerThread`.
      """
      # Simple optimizer: plain learner thread, nothing else to set up.
      if config["simple_optimizer"]:
          return LearnerThread(
              local_worker,
              minibatch_buffer_size=config["minibatch_buffer_size"],
              num_sgd_iter=config["num_epochs"],
              learner_queue_size=config["learner_queue_size"],
              learner_queue_timeout=config["learner_queue_timeout"],
          )

      logger.info(
          "Enabling multi-GPU mode, {} GPUs, {} parallel tower-stacks".format(
              config["num_gpus"], config["num_multi_gpu_tower_stacks"]
          )
      )

      # The buffer must not have more stack-index slots than there are tower
      # stacks; shrink it (with a warning) if it does.
      tower_stacks = config["num_multi_gpu_tower_stacks"]
      buffer_slots = config["minibatch_buffer_size"]
      if tower_stacks < buffer_slots:
          logger.warning(
              "In multi-GPU mode you should have at least as many "
              "multi-GPU tower stacks (to load data into on one device) as "
              "you have stack-index slots in the buffer! You have "
              f"configured {tower_stacks} stacks and a buffer of size "
              f"{buffer_slots}. Setting "
              f"`minibatch_buffer_size={tower_stacks}`."
          )
          config["minibatch_buffer_size"] = tower_stacks

      return MultiGPULearnerThread(
          local_worker,
          num_gpus=config["num_gpus"],
          lr=config["lr"],
          train_batch_size=config["train_batch_size"],
          num_multi_gpu_tower_stacks=config["num_multi_gpu_tower_stacks"],
          num_sgd_iter=config["num_epochs"],
          learner_queue_size=config["learner_queue_size"],
          learner_queue_timeout=config["learner_queue_timeout"],
          num_data_load_threads=config["num_gpu_loader_threads"],
      )