utils.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. import platform
  2. from typing import List
  3. import tree # pip install dm_tree
  4. import ray
  5. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  6. from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
  7. from ray.rllib.utils.actor_manager import FaultAwareApply
  8. from ray.rllib.utils.framework import try_import_torch
  9. from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
  10. from ray.rllib.utils.metrics.ray_metrics import (
  11. DEFAULT_HISTOGRAM_BOUNDARIES_SHORT_EVENTS,
  12. TimerAndPrometheusLogger,
  13. )
  14. from ray.rllib.utils.typing import EpisodeType
  15. from ray.util.annotations import DeveloperAPI
  16. from ray.util.metrics import Counter, Histogram
  17. torch, _ = try_import_torch()
  18. @DeveloperAPI(stability="alpha")
  19. class AggregatorActor(FaultAwareApply):
  20. """Runs episode lists through ConnectorV2 pipeline and creates train batches.
  21. The actor should be co-located with a Learner worker. Ideally, there should be one
  22. or two aggregator actors per Learner worker (having even more per Learner probably
  23. won't help. Then the main process driving the RL algo can perform the following
  24. execution logic:
  25. - query n EnvRunners to sample the environment and return n lists of episodes as
  26. Ray.ObjectRefs.
  27. - remote call the set of aggregator actors (in round-robin fashion) with these
  28. list[episodes] refs in async fashion.
  29. - gather the results asynchronously, as each actor returns refs pointing to
  30. ready-to-go train batches.
  31. - as soon as we have at least one train batch per Learner, call the LearnerGroup
  32. with the (already sharded) refs.
  33. - an aggregator actor - when receiving p refs to List[EpisodeType] - does:
  34. -- ray.get() the actual p lists and concatenate the p lists into one
  35. List[EpisodeType].
  36. -- pass the lists of episodes through its LearnerConnector pipeline
  37. -- buffer the output batches of this pipeline until enough batches have been
  38. collected for creating one train batch (matching the config's
  39. `train_batch_size_per_learner`).
  40. -- concatenate q batches into a train batch and return that train batch.
  41. - the algo main process then passes the ray.ObjectRef to the ready-to-go train batch
  42. to the LearnerGroup for calling each Learner with one train batch.
  43. """
  44. def __init__(self, config: AlgorithmConfig, rl_module_spec):
  45. self.config = config
  46. # Set device and node.
  47. self._node = platform.node()
  48. self._device = torch.device("cpu")
  49. self.metrics: MetricsLogger = MetricsLogger(
  50. stats_cls_lookup=config.stats_cls_lookup,
  51. root=True,
  52. )
  53. # Create the RLModule.
  54. # TODO (sven): For now, this RLModule (its weights) never gets updated.
  55. # The reason the module is needed is for the connector to know, which
  56. # sub-modules are stateful (and what their initial state tensors are), and
  57. # which IDs the submodules have (to figure out, whether its multi-agent or
  58. # not).
  59. self._module = rl_module_spec.build()
  60. self._module = self._module.as_multi_rl_module()
  61. # Create the Learner connector pipeline.
  62. self._learner_connector = self.config.build_learner_connector(
  63. input_observation_space=None,
  64. input_action_space=None,
  65. device=self._device,
  66. )
  67. # Ray metrics
  68. self._metrics_get_batch_time = Histogram(
  69. name="rllib_utils_aggregator_actor_get_batch_time",
  70. description="Time spent in AggregatorActor.get_batch()",
  71. boundaries=DEFAULT_HISTOGRAM_BOUNDARIES_SHORT_EVENTS,
  72. tag_keys=("rllib",),
  73. )
  74. self._metrics_get_batch_time.set_default_tags(
  75. {"rllib": self.__class__.__name__}
  76. )
  77. self._metrics_episode_owner_died = Counter(
  78. name="rllib_utils_aggregator_actor_episode_owner_died_counter",
  79. description="N times ray.get() on an episode ref failed ",
  80. tag_keys=("rllib",),
  81. )
  82. self._metrics_episode_owner_died.set_default_tags(
  83. {"rllib": self.__class__.__name__}
  84. )
  85. self._metrics_get_batch_input_episode_refs = Counter(
  86. name="rllib_utils_aggregator_actor_get_batch_input_episode_refs_counter",
  87. description="Number of episode refs received as input to get_batch()",
  88. tag_keys=("rllib",),
  89. )
  90. self._metrics_get_batch_input_episode_refs.set_default_tags(
  91. {"rllib": self.__class__.__name__}
  92. )
  93. self._metrics_get_batch_output_batches = Counter(
  94. name="rllib_utils_aggregator_actor_get_batch_output_batches_counter",
  95. description="Number of policy batches output by get_batch()",
  96. tag_keys=("rllib",),
  97. )
  98. self._metrics_get_batch_output_batches.set_default_tags(
  99. {"rllib": self.__class__.__name__}
  100. )
  101. def get_batch(self, episode_refs: List[ray.ObjectRef]):
  102. with TimerAndPrometheusLogger(self._metrics_get_batch_time):
  103. if len(episode_refs) > 0:
  104. self._metrics_get_batch_input_episode_refs.inc(value=len(episode_refs))
  105. episodes: List[EpisodeType] = []
  106. # It's possible that individual refs are invalid due to the EnvRunner
  107. # that produced the ref has crashed or had its entire node go down.
  108. # In this case, try each ref individually and collect only valid results.
  109. try:
  110. episodes = tree.flatten(ray.get(episode_refs))
  111. except ray.exceptions.OwnerDiedError:
  112. for ref in episode_refs:
  113. try:
  114. episodes.extend(ray.get(ref))
  115. except ray.exceptions.OwnerDiedError:
  116. self._metrics_episode_owner_died.inc(value=1)
  117. env_steps = sum(len(e) for e in episodes)
  118. # If we have enough episodes collected to create a single train batch, pass
  119. # them at once through the connector to receive a single train batch.
  120. batch = self._learner_connector(
  121. episodes=episodes,
  122. rl_module=self._module,
  123. metrics=self.metrics,
  124. )
  125. # Convert to a dict into a `MultiAgentBatch`.
  126. # TODO (sven): Try to get rid of dependency on MultiAgentBatch (once our mini-
  127. # batch iterators support splitting over a dict).
  128. ma_batch = MultiAgentBatch(
  129. policy_batches={
  130. pid: SampleBatch(pol_batch) for pid, pol_batch in batch.items()
  131. },
  132. env_steps=env_steps,
  133. )
  134. self._metrics_get_batch_output_batches.inc(value=1)
  135. return ma_batch
  136. def get_metrics(self):
  137. return self.metrics.reduce()
  138. def _get_env_runner_bundles(config):
  139. return [
  140. {
  141. "CPU": config.num_cpus_per_env_runner,
  142. "GPU": config.num_gpus_per_env_runner,
  143. **config.custom_resources_per_env_runner,
  144. }
  145. for _ in range(config.num_env_runners)
  146. ]
  147. def _get_offline_eval_runner_bundles(config):
  148. return [
  149. {
  150. "CPU": config.num_cpus_per_offline_eval_runner,
  151. "GPU": config.num_gpus_per_offline_eval_runner,
  152. **config.custom_resources_per_offline_eval_runner,
  153. }
  154. for _ in range(config.num_offline_eval_runners)
  155. ]
  156. def _get_learner_bundles(config):
  157. if config.num_learners == 0:
  158. if config.num_aggregator_actors_per_learner > 0:
  159. return [{"CPU": 1} for _ in range(config.num_aggregator_actors_per_learner)]
  160. else:
  161. return []
  162. if config.num_cpus_per_learner != "auto":
  163. num_cpus_per_learner = config.num_cpus_per_learner
  164. elif config.num_gpus_per_learner == 0:
  165. num_cpus_per_learner = 1
  166. else:
  167. num_cpus_per_learner = 0
  168. # aggregator actors are co-located with learners and use 1 CPU each
  169. bundles = [
  170. {
  171. "CPU": num_cpus_per_learner + config.num_aggregator_actors_per_learner,
  172. "GPU": config.num_gpus_per_learner,
  173. }
  174. for _ in range(config.num_learners)
  175. ]
  176. return bundles
  177. def _get_main_process_bundle(config):
  178. if config.num_learners == 0:
  179. if config.num_cpus_per_learner != "auto":
  180. num_cpus_per_learner = config.num_cpus_per_learner
  181. elif config.num_gpus_per_learner == 0:
  182. num_cpus_per_learner = 1
  183. else:
  184. num_cpus_per_learner = 0
  185. bundle = {
  186. "CPU": max(num_cpus_per_learner, config.num_cpus_for_main_process),
  187. "GPU": config.num_gpus_per_learner,
  188. }
  189. else:
  190. bundle = {"CPU": config.num_cpus_for_main_process, "GPU": 0}
  191. return bundle