env_runner_group.py 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298
  1. import importlib.util
  2. import logging
  3. import os
  4. from typing import (
  5. TYPE_CHECKING,
  6. Any,
  7. Callable,
  8. Collection,
  9. Dict,
  10. List,
  11. Optional,
  12. Tuple,
  13. Type,
  14. TypeVar,
  15. Union,
  16. )
  17. import gymnasium as gym
  18. import ray
  19. from ray._common.deprecation import (
  20. DEPRECATED_VALUE,
  21. deprecation_warning,
  22. )
  23. from ray.actor import ActorHandle
  24. from ray.exceptions import RayActorError
  25. from ray.rllib.core import (
  26. COMPONENT_ENV_TO_MODULE_CONNECTOR,
  27. COMPONENT_LEARNER,
  28. COMPONENT_MODULE_TO_ENV_CONNECTOR,
  29. COMPONENT_RL_MODULE,
  30. )
  31. from ray.rllib.core.learner import LearnerGroup
  32. from ray.rllib.core.rl_module import validate_module_id
  33. from ray.rllib.core.rl_module.rl_module import RLModuleSpec
  34. from ray.rllib.env.base_env import BaseEnv
  35. from ray.rllib.env.env_context import EnvContext
  36. from ray.rllib.env.env_runner import EnvRunner
  37. from ray.rllib.evaluation.rollout_worker import RolloutWorker
  38. from ray.rllib.offline import get_dataset_and_shards
  39. from ray.rllib.policy.policy import Policy, PolicyState
  40. from ray.rllib.utils.actor_manager import FaultTolerantActorManager
  41. from ray.rllib.utils.annotations import OldAPIStack
  42. from ray.rllib.utils.framework import try_import_tf
  43. from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED_LIFETIME, WEIGHTS_SEQ_NO
  44. from ray.rllib.utils.typing import (
  45. AgentID,
  46. EnvCreator,
  47. EnvType,
  48. EpisodeID,
  49. PartialAlgorithmConfigDict,
  50. PolicyID,
  51. SampleBatchType,
  52. TensorType,
  53. )
  54. from ray.util.annotations import DeveloperAPI
  # Only imported for static type checkers. At runtime, `AlgorithmConfig` is
  # referenced through string annotations and imported locally where needed.
  if TYPE_CHECKING:
      from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

  # The three handles returned by `try_import_tf()`.
  tf1, tf, tfv = try_import_tf()

  # Module-level logger, named after this module.
  logger = logging.getLogger(__name__)

  # Generic type var for foreach_* methods.
  T = TypeVar("T")
  61. @DeveloperAPI
  62. class EnvRunnerGroup:
  63. """Set of EnvRunners with n @ray.remote workers and zero or one local worker.
  64. Where: n >= 0.
  65. """
  def __init__(
      self,
      *,
      env_creator: Optional[EnvCreator] = None,
      validate_env: Optional[Callable[[EnvType], None]] = None,
      default_policy_class: Optional[Type[Policy]] = None,
      config: Optional["AlgorithmConfig"] = None,
      local_env_runner: bool = True,
      logdir: Optional[str] = None,
      _setup: bool = True,
      tune_trial_id: Optional[str] = None,
      pg_offset: int = 0,
      # Deprecated args.
      num_env_runners: Optional[int] = None,
      num_workers=DEPRECATED_VALUE,
      local_worker=DEPRECATED_VALUE,
  ):
      """Builds an EnvRunnerGroup and (optionally) sets up its EnvRunners.

      Args:
          env_creator: Function that returns env given env config.
          validate_env: Optional callable to validate the generated
              environment (only on worker=0). This callable should raise
              an exception if the environment is invalid.
          default_policy_class: An optional default Policy class to use inside
              the (multi-agent) `policies` dict. In case the PolicySpecs in there
              have no class defined, use this `default_policy_class`.
              If None, PolicySpecs will be using the Algorithm's default Policy
              class.
          config: Optional AlgorithmConfig (or config dict). A falsy value is
              replaced by a default `AlgorithmConfig()`; a non-empty dict is
              converted via `AlgorithmConfig.from_dict()`.
          local_env_runner: Whether to create a local (non @ray.remote) EnvRunner
              in the returned set as well (default: True). If `num_env_runners`
              is 0, always create a local EnvRunner.
          logdir: Optional logging directory for workers.
          _setup: Whether to actually set up workers. This is only for testing.
          tune_trial_id: The Ray Tune trial ID, if this EnvRunnerGroup is part of
              an Algorithm run as a Tune trial. None, otherwise.
          pg_offset: Stored on the instance as `_pg_offset` (presumably a
              placement group bundle offset; confirm against `_make_worker`).
          num_env_runners: If not None, overrides `config.num_env_runners` as
              the number of remote EnvRunners to create during setup.
          num_workers: Deprecated. Passing any value raises through
              `deprecation_warning(..., error=True)`.
          local_worker: Deprecated. Passing any value raises through
              `deprecation_warning(..., error=True)`.
      """
      # The old `WorkerSet` argument names are hard-deprecated.
      if num_workers != DEPRECATED_VALUE or local_worker != DEPRECATED_VALUE:
          deprecation_warning(
              old="WorkerSet(num_workers=..., local_worker=...)",
              new="EnvRunnerGroup(num_env_runners=..., local_env_runner=...)",
              error=True,
          )

      from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

      # Normalize `config` into an AlgorithmConfig object.
      if isinstance(config, dict) and config:
          config = AlgorithmConfig.from_dict(config)
      elif not config:
          config = AlgorithmConfig()

      self._env_creator = env_creator
      self._policy_class = default_policy_class
      self._logdir = logdir
      self._tune_trial_id = tune_trial_id
      self._pg_offset = pg_offset

      # Keep the config and an object-store reference to it for remote actors.
      self._remote_config = config
      self._remote_config_obj_ref = ray.put(self._remote_config)

      # Restarts are only granted if failed EnvRunners should be restarted.
      if config.restart_failed_env_runners:
          max_restarts = config.max_num_env_runner_restarts
      else:
          max_restarts = 0
      self._remote_args = {
          "num_cpus": self._remote_config.num_cpus_per_env_runner,
          "num_gpus": self._remote_config.num_gpus_per_env_runner,
          "resources": self._remote_config.custom_resources_per_env_runner,
          "max_restarts": max_restarts,
      }

      # Pick the EnvRunner subclass to be used as "workers", unless the config
      # names one explicitly.
      self.env_runner_cls = config.env_runner_cls
      if self.env_runner_cls is None:
          if not config.enable_env_runner_and_connector_v2:
              # Old API stack default.
              self.env_runner_cls = RolloutWorker
          elif config.output:
              # Experiences should be recorded -> offline env runner. There is
              # no multi-agent support for recording.
              if config.is_multi_agent:
                  raise ValueError("Multi-agent recording is not supported, yet.")
              from ray.rllib.offline.offline_env_runner import (
                  OfflineSingleAgentEnvRunner,
              )

              self.env_runner_cls = OfflineSingleAgentEnvRunner
          elif config.is_multi_agent:
              from ray.rllib.env.multi_agent_env_runner import (
                  MultiAgentEnvRunner,
              )

              self.env_runner_cls = MultiAgentEnvRunner
          else:
              from ray.rllib.env.single_agent_env_runner import (
                  SingleAgentEnvRunner,
              )

              self.env_runner_cls = SingleAgentEnvRunner

      self._ignore_ray_errors_on_env_runners = (
          config.ignore_env_runner_failures or config.restart_failed_env_runners
      )

      # Manager for the remote workers. ID=0 is used by the local worker, so
      # remote worker IDs start at 1 to avoid conflicts.
      self._worker_manager = FaultTolerantActorManager(
          max_remote_requests_in_flight_per_actor=(
              config.max_requests_in_flight_per_env_runner
          ),
          init_id=1,
      )

      if not _setup:
          return

      # An explicitly passed `num_env_runners` wins over the config's value.
      if num_env_runners is None:
          remote_count = config.num_env_runners
      else:
          remote_count = num_env_runners

      try:
          self._setup(
              validate_env=validate_env,
              config=config,
              num_env_runners=remote_count,
              local_env_runner=local_env_runner,
          )
      # EnvRunnerGroup creation possibly fails, if some (remote) workers cannot
      # be initialized properly (due to some errors in the EnvRunner's
      # constructor).
      except RayActorError as e:
          # Anything but an actor init failure is re-raised as-is.
          if not e.actor_init_failed:
              raise e
          # On an init failure the remote worker may still exist, but is not
          # usable. Surface the original error raised inside the actor's
          # constructor, so the user sees the real reason for the failure.
          # - e.args[0]: The RayTaskError (inside the caught RayActorError).
          # - e.args[0].args[2]: The original Exception (e.g. a ValueError due
          #   to a config mismatch) thrown inside the actor.
          raise e.args[0].args[2]
  def _setup(
      self,
      *,
      validate_env: Optional[Callable[[EnvType], None]] = None,
      config: Optional["AlgorithmConfig"] = None,
      num_env_runners: int = 0,
      local_env_runner: bool = True,
  ):
      """Creates the remote EnvRunners and, if needed, the local EnvRunner.

      Args:
          validate_env: Optional callable to validate the generated
              environment (only on worker=0).
          config: The AlgorithmConfig to build the EnvRunners from. Its
              attributes are read unconditionally here.
          num_env_runners: Number of remote EnvRunner workers to create.
          local_env_runner: Whether to create a local (non @ray.remote) EnvRunner
              in the returned set as well (default: True). If `num_env_runners`
              is 0, always create a local EnvRunner.
      """
      self._local_env_runner = None

      # Without remote workers, a local worker is mandatory. Otherwise, this
      # EnvRunnerGroup would be empty.
      if num_env_runners == 0:
          local_env_runner = True

      # The local worker's config differs from the remote one only in its
      # tf_session_args: `config.tf_session_args`, updated/overridden with
      # `config.local_tf_session_args`.
      merged_tf_session_args = config.tf_session_args.copy()
      merged_tf_session_args.update(config.local_tf_session_args)
      unfrozen_copy = config.copy(copy_frozen=False)
      self._local_config = unfrozen_copy.framework(
          tf_session_args=merged_tf_session_args
      )

      # For "dataset" input, create the dataset and its shards, to be shared by
      # all the rollout workers.
      if config.input_ == "dataset":
          dataset, shards = get_dataset_and_shards(config, num_env_runners)
      else:
          dataset, shards = None, None
      self._ds, self._ds_shards = dataset, shards

      # Create the @ray.remote workers.
      self.add_workers(
          num_env_runners,
          validate=config.validate_env_runners_after_construction,
      )

      # If there are remote workers and the local worker neither gets its own
      # env nor has both spaces configured, infer the spaces for each policy
      # from the first remote worker (which does have an env).
      should_infer_spaces = (
          local_env_runner
          and self._worker_manager.num_actors() > 0
          and not config.create_env_on_local_worker
          and not (config.observation_space and config.action_space)
      )
      spaces = self.get_spaces() if should_infer_spaces else None

      # Create a local worker, if needed.
      if local_env_runner:
          self._local_env_runner = self._make_worker(
              env_creator=self._env_creator,
              validate_env=validate_env,
              worker_index=0,
              num_workers=num_env_runners,
              config=self._local_config,
              spaces=spaces,
          )
  def get_spaces(self):
      """Infer observation and action spaces from one (local or remote) EnvRunner.

      The first remote EnvRunner known to the actor manager is queried. Only if
      there are no remote EnvRunners at all, the local EnvRunner is queried
      instead.

      Returns:
          A dict mapping from ModuleID to a 2-tuple containing obs- and action-space.
      """
      # Query the manager once and pick the first remote worker's ID (if any).
      actor_ids = self._worker_manager.actor_ids()
      remote_worker_ids = [actor_ids[0]] if actor_ids else []
      # Fall back to the local worker only if there is no remote worker.
      use_local = not remote_worker_ids

      spaces = self.foreach_env_runner(
          lambda env_runner: env_runner.get_spaces(),
          remote_worker_ids=remote_worker_ids,
          local_env_runner=use_local,
      )[0]

      # Report the actual source of the spaces. The remote-worker message is
      # kept unchanged.
      if use_local:
          source = "local worker"
      else:
          source = "remote worker (local worker has no env)"
      logger.info(f"Inferred observation/action spaces from {source}: {spaces}")
      return spaces
  @property
  def local_env_runner(self) -> EnvRunner:
      """The local EnvRunner of this group, i.e. the `_local_env_runner` attribute."""
      local_runner = self._local_env_runner
      return local_runner
  def healthy_env_runner_ids(self) -> List[int]:
      """Returns the IDs of the remote EnvRunners reported healthy by the manager."""
      manager = self._worker_manager
      return manager.healthy_actor_ids()
  def healthy_worker_ids(self) -> List[int]:
      """Alias that forwards to `healthy_env_runner_ids()`."""
      healthy_ids = self.healthy_env_runner_ids()
      return healthy_ids
  def num_remote_env_runners(self) -> int:
      """Returns how many remote EnvRunners (actors) the manager holds."""
      manager = self._worker_manager
      return manager.num_actors()
  def num_remote_workers(self) -> int:
      """Alias that forwards to `num_remote_env_runners()`."""
      count = self.num_remote_env_runners()
      return count
  def num_healthy_remote_env_runners(self) -> int:
      """Returns how many remote EnvRunners the manager reports as healthy."""
      manager = self._worker_manager
      return manager.num_healthy_actors()
  def num_healthy_remote_workers(self) -> int:
      """Alias that forwards to `num_healthy_remote_env_runners()`."""
      count = self.num_healthy_remote_env_runners()
      return count
  def num_healthy_env_runners(self) -> int:
      """Returns the number of healthy remote workers, plus one for a local worker.

      The local worker is counted whenever `_local_env_runner` is truthy.
      """
      local_count = 1 if self._local_env_runner else 0
      return local_count + self.num_healthy_remote_workers()
  def num_healthy_workers(self) -> int:
      """Alias that forwards to `num_healthy_env_runners()`."""
      count = self.num_healthy_env_runners()
      return count
  def num_in_flight_async_reqs(self, tag: Optional[str] = None) -> int:
      """Returns the manager's count of outstanding async requests.

      Args:
          tag: Forwarded unchanged to the manager's
              `num_outstanding_async_reqs()` call.
      """
      manager = self._worker_manager
      return manager.num_outstanding_async_reqs(tag=tag)
  def num_remote_worker_restarts(self) -> int:
      """Returns the manager's total count of remote worker restarts."""
      manager = self._worker_manager
      return manager.total_num_restarts()
  def sync_env_runner_states(
      self,
      *,
      config: "AlgorithmConfig",
      from_worker: Optional[EnvRunner] = None,
      env_steps_sampled: Optional[int] = None,
      connector_states: Optional[List[Dict[str, Any]]] = None,
      rl_module_state: Optional[Dict[str, Any]] = None,
      env_runner_indices_to_update: Optional[List[int]] = None,
      env_to_module=None,
      module_to_env=None,
  ) -> None:
      """Synchronizes the connectors of this EnvRunnerGroup's EnvRunners.

      The exact procedure works as follows:
      - If `from_worker` is None, set `from_worker=self.local_env_runner`.
      - Decide whether to merge: on the old API stack, if
        `config.use_worker_filter_stats` is True; on the new API stack, if
        `config.merge_env_runner_states` is True, or is "training_only" and
        `config.in_evaluation` is False.
      - Decide whether to broadcast: on the old API stack, if
        `config.update_worker_filter_stats` is True; on the new API stack, if
        `config.broadcast_env_runner_states` is True.
      - If there are no healthy remote EnvRunners and a local EnvRunner exists,
        only set `env_steps_sampled` and `rl_module_state` on the local
        EnvRunner and return.
      - When merging, gather the connector states of all remote EnvRunners and
        merge them into one resulting state. Otherwise, only use the connector
        states of `from_worker`.
      - Broadcast the resulting state back to the remote EnvRunners AND the
        local EnvRunner. Without broadcasting, only the local side receives
        the connector states.

      Args:
          config: The AlgorithmConfig object to use to determine, in which
              direction(s) we need to synch and what the timeouts are.
          from_worker: The EnvRunner from which to synch. If None, will use the local
              worker of this EnvRunnerGroup.
          env_steps_sampled: The total number of env steps taken thus far by all
              workers combined. Cast to int and written into the synched state
              under `NUM_ENV_STEPS_SAMPLED_LIFETIME`.
          connector_states: Optional, already gathered connector states (one
              dict per EnvRunner). If None and merging is enabled, the states
              are pulled from the remote EnvRunners. If an empty list, nothing
              is merged.
          rl_module_state: Optional state dict that is added to the state sent
              to the remote EnvRunners. It is also set on the local EnvRunner
              in the early-out case of no healthy remote EnvRunners.
          env_runner_indices_to_update: The indices of those EnvRunners to update
              with the merged state. Use None (default) to update all remote
              EnvRunners.
          env_to_module: Optional env-to-module connector, used for merging,
              getting and setting states if the local EnvRunner does not
              provide one. Must be None if the local EnvRunner has both an
              `_env_to_module` and a `_module_to_env` attribute.
          module_to_env: Optional module-to-env connector; same rules as for
              `env_to_module`.
      """
      if env_steps_sampled is not None:
          env_steps_sampled = int(env_steps_sampled)
      from_worker = from_worker or self.local_env_runner

      merge = (
          not config.enable_env_runner_and_connector_v2
          and config.use_worker_filter_stats
      ) or (
          config.enable_env_runner_and_connector_v2
          and (
              config.merge_env_runner_states is True
              or (
                  config.merge_env_runner_states == "training_only"
                  and not config.in_evaluation
              )
          )
      )
      broadcast = (
          not config.enable_env_runner_and_connector_v2
          and config.update_worker_filter_stats
      ) or (
          config.enable_env_runner_and_connector_v2
          and config.broadcast_env_runner_states
      )

      # Early out if the number of (healthy) remote workers is 0. In this case, the
      # local worker is the only operating worker and thus of course always holds
      # the reference connector state.
      if self.num_healthy_remote_workers() == 0 and self.local_env_runner:
          self.local_env_runner.set_state(
              {
                  **(
                      {NUM_ENV_STEPS_SAMPLED_LIFETIME: env_steps_sampled}
                      if env_steps_sampled is not None
                      else {}
                  ),
                  **(rl_module_state or {}),
              }
          )
          return

      # Also early out, if we don't merge AND don't broadcast.
      if not merge and not broadcast:
          return

      # Use states from all remote EnvRunners.
      if merge:
          if connector_states == []:
              env_runner_states = {}
          else:
              # Pull the connector states from the remote EnvRunners, unless
              # the caller already provided them.
              if connector_states is None:
                  connector_states = self.foreach_env_runner(
                      lambda w: w.get_state(
                          components=[
                              COMPONENT_ENV_TO_MODULE_CONNECTOR,
                              COMPONENT_MODULE_TO_ENV_CONNECTOR,
                          ]
                      ),
                      local_env_runner=False,
                      timeout_seconds=(
                          config.sync_filters_on_rollout_workers_timeout_s
                      ),
                  )
              env_to_module_states = [
                  s[COMPONENT_ENV_TO_MODULE_CONNECTOR]
                  for s in connector_states
                  if COMPONENT_ENV_TO_MODULE_CONNECTOR in s
              ]
              module_to_env_states = [
                  s[COMPONENT_MODULE_TO_ENV_CONNECTOR]
                  for s in connector_states
                  if COMPONENT_MODULE_TO_ENV_CONNECTOR in s
              ]

              # Prefer the local EnvRunner's connectors for merging. In this
              # case, the caller must not pass connectors explicitly.
              if (
                  self.local_env_runner is not None
                  and hasattr(self.local_env_runner, "_env_to_module")
                  and hasattr(self.local_env_runner, "_module_to_env")
              ):
                  assert env_to_module is None
                  env_to_module = self.local_env_runner._env_to_module
                  assert module_to_env is None
                  module_to_env = self.local_env_runner._module_to_env

              env_runner_states = {}
              if env_to_module_states:
                  env_runner_states.update(
                      {
                          COMPONENT_ENV_TO_MODULE_CONNECTOR: (
                              env_to_module.merge_states(env_to_module_states)
                          ),
                      }
                  )
              if module_to_env_states:
                  env_runner_states.update(
                      {
                          COMPONENT_MODULE_TO_ENV_CONNECTOR: (
                              module_to_env.merge_states(module_to_env_states)
                          ),
                      }
                  )
      # Ignore states from remote EnvRunners (use the current `from_worker` states
      # only).
      else:
          if from_worker is None:
              env_runner_states = {
                  COMPONENT_ENV_TO_MODULE_CONNECTOR: env_to_module.get_state(),
                  COMPONENT_MODULE_TO_ENV_CONNECTOR: module_to_env.get_state(),
              }
          else:
              env_runner_states = from_worker.get_state(
                  components=[
                      COMPONENT_ENV_TO_MODULE_CONNECTOR,
                      COMPONENT_MODULE_TO_ENV_CONNECTOR,
                  ]
              )

      # Update the global number of environment steps, if necessary.
      if env_steps_sampled is not None:
          env_runner_states[NUM_ENV_STEPS_SAMPLED_LIFETIME] = env_steps_sampled

      # If we do NOT want remote EnvRunners to get their Connector states updated,
      # only update the local worker here (with all state components, except the model
      # weights) and then remove the connector components.
      if not broadcast:
          if self.local_env_runner is not None:
              self.local_env_runner.set_state(env_runner_states)
          else:
              # Fall back to an empty state if a connector component is
              # missing. The `{}` default belongs to `dict.get()`, not to
              # `set_state()`.
              env_to_module.set_state(
                  env_runner_states.get(COMPONENT_ENV_TO_MODULE_CONNECTOR, {})
              )
              module_to_env.set_state(
                  env_runner_states.get(COMPONENT_MODULE_TO_ENV_CONNECTOR, {})
              )
          env_runner_states.pop(COMPONENT_ENV_TO_MODULE_CONNECTOR, None)
          env_runner_states.pop(COMPONENT_MODULE_TO_ENV_CONNECTOR, None)

      # If there are components in the state left -> Update remote workers with these
      # state components (and maybe the local worker, if it hasn't been updated yet).
      if env_runner_states:
          # Update the local EnvRunner, but NOT with the weights. If used at all for
          # evaluation (through the user calling `self.evaluate`), RLlib would update
          # the weights up front either way.
          if self.local_env_runner is not None and broadcast:
              self.local_env_runner.set_state(env_runner_states)

          # Send the model weights only to remote EnvRunners.
          # In case the local EnvRunner is ever needed for evaluation,
          # RLlib updates its weight right before such an eval step.
          if rl_module_state:
              env_runner_states.update(rl_module_state)

          # Broadcast updated states back to all workers.
          # We explicitly don't want to fire and forget here, because this can
          # lead to a lot of in-flight requests. When these pile up, object
          # store memory can spike.
          self.foreach_env_runner_async_fetch_ready(
              func="set_state",
              tag="set_state",
              kwargs=dict(state=env_runner_states),
              remote_worker_ids=env_runner_indices_to_update,
              timeout_seconds=0.0,
          )
  def foreach_env_runner_async_fetch_ready(
      self,
      func: Union[
          Callable[[EnvRunner], T], List[Callable[[EnvRunner], T]], str, List[str]
      ],
      kwargs: Optional[Dict[str, Any]] = None,
      tag: Optional[str] = None,
      timeout_seconds: Optional[float] = 0.0,
      return_obj_refs: bool = False,
      mark_healthy: bool = False,
      healthy_only: bool = True,
      remote_worker_ids: List[int] = None,
      return_actor_ids: bool = False,
  ) -> List[Union[Tuple[int, T], T]]:
      """Calls the given function asynchronously and returns previous results if any.

      Thin convenience wrapper around the underlying actor manager's
      `foreach_actor_async_fetch_ready()` method. All arguments are forwarded
      unchanged, with `remote_worker_ids` passed on as `remote_actor_ids`. The
      manager's `ignore_ray_errors` argument is taken from this group's
      `_ignore_ray_errors_on_env_runners` setting.

      Returns:
          Whatever the manager's `foreach_actor_async_fetch_ready()` returns.
      """
      manager = self._worker_manager
      forwarded = dict(
          func=func,
          tag=tag,
          kwargs=kwargs,
          timeout_seconds=timeout_seconds,
          return_obj_refs=return_obj_refs,
          mark_healthy=mark_healthy,
          healthy_only=healthy_only,
          # The manager names this argument after actors, not workers.
          remote_actor_ids=remote_worker_ids,
          ignore_ray_errors=self._ignore_ray_errors_on_env_runners,
          return_actor_ids=return_actor_ids,
      )
      return manager.foreach_actor_async_fetch_ready(**forwarded)
  def sync_weights(
      self,
      policies: Optional[List[PolicyID]] = None,
      from_worker_or_learner_group: Optional[Union[EnvRunner, "LearnerGroup"]] = None,
      to_worker_indices: Optional[List[int]] = None,
      global_vars: Optional[Dict[str, TensorType]] = None,
      timeout_seconds: Optional[float] = 0.0,
      inference_only: Optional[bool] = False,
  ) -> None:
      """Syncs model weights from the given weight source to all remote workers.

      Weight source can be either a (local) rollout worker or a learner_group. It
      should just implement a `get_weights` method (old API stack) or a
      `get_state` method (LearnerGroup and new API stack EnvRunners).

      Args:
          policies: Optional list of PolicyIDs to sync weights for.
              If None (default), sync weights to/from all policies.
          from_worker_or_learner_group: Optional (local) EnvRunner instance or
              LearnerGroup instance to sync from. If None (default),
              sync from this EnvRunnerGroup's local worker. If provided, the
              weights are also set on this EnvRunnerGroup's local worker.
          to_worker_indices: Optional list of worker indices to sync the
              weights to. If None (default), sync to all remote workers.
          global_vars: An optional global vars dict. On the old API stack, it
              is passed along with the weights to the remote workers. If not
              None, it is also set on the local worker (if one exists).
          timeout_seconds: Timeout in seconds to wait for the sync weights
              calls to complete. Default is 0.0 (fire-and-forget, do not wait
              for any sync calls to finish). Setting this to 0.0 might significantly
              improve algorithm performance, depending on the algo's `training_step`
              logic.
          inference_only: Sync weights with workers that keep inference-only
              modules. This is needed for algorithms in the new stack that
              use inference-only modules. In this case only a part of the
              parameters are synced to the workers. Default is False.

      Raises:
          TypeError: If this EnvRunnerGroup has no local EnvRunner and
              `from_worker_or_learner_group` is None.
      """
      if self.local_env_runner is None and from_worker_or_learner_group is None:
          raise TypeError(
              "No `local_env_runner` in EnvRunnerGroup! Must provide "
              "`from_worker_or_learner_group` arg in `sync_weights()`!"
          )

      # Only sync if we have remote workers or `from_worker_or_learner_group`
      # is provided.
      rl_module_state = None
      if self.num_remote_workers() or from_worker_or_learner_group is not None:
          # The check above guarantees that at least one of the two candidate
          # sources is not None, so `weights_src` is never None here.
          weights_src = (
              from_worker_or_learner_group
              if from_worker_or_learner_group is not None
              else self.local_env_runner
          )

          modules = (
              [COMPONENT_RL_MODULE + "/" + p for p in policies]
              if policies is not None
              else [COMPONENT_RL_MODULE]
          )
          # LearnerGroup has a Learner, which has an RLModule.
          if isinstance(weights_src, LearnerGroup):
              rl_module_state = weights_src.get_state(
                  components=[COMPONENT_LEARNER + "/" + m for m in modules],
                  inference_only=inference_only,
              )[COMPONENT_LEARNER]
          # EnvRunner (new API stack).
          elif self._remote_config.enable_env_runner_and_connector_v2:
              # EnvRunner (remote) has an RLModule.
              # TODO (sven): Replace this with a new ActorManager API:
              #  try_remote_request_till_success("get_state") -> tuple(int,
              #  remoteresult)
              #  `weights_src` could be the ActorManager, then. Then RLlib would know
              #  that it has to ping the manager to try all healthy actors until the
              #  first returns something.
              if isinstance(weights_src, ray.actor.ActorHandle):
                  rl_module_state = ray.get(
                      weights_src.get_state.remote(
                          components=modules,
                          inference_only=inference_only,
                      )
                  )
              # EnvRunner (local) has an RLModule.
              else:
                  rl_module_state = weights_src.get_state(
                      components=modules,
                      inference_only=inference_only,
                  )
          # RolloutWorker (old API stack).
          else:
              rl_module_state = weights_src.get_weights(
                  policies=policies,
                  inference_only=inference_only,
              )

          if self._remote_config.enable_env_runner_and_connector_v2:
              # Make sure `rl_module_state` only contains the weights and the
              # weight seq no, nothing else.
              rl_module_state = {
                  k: v
                  for k, v in rl_module_state.items()
                  if k in [COMPONENT_RL_MODULE, WEIGHTS_SEQ_NO]
              }

              # Move weights to the object store to avoid having to make n pickled
              # copies of the weights dict for each worker.
              rl_module_state_ref = ray.put(rl_module_state)

              # Sync to specified remote workers in this EnvRunnerGroup.
              # We explicitly don't want to fire and forget here, because this
              # can lead to a lot of in-flight requests. When these pile up,
              # object store memory can spike.
              self.foreach_env_runner_async_fetch_ready(
                  func="set_state",
                  tag="set_state",
                  kwargs=dict(state=rl_module_state_ref),
                  remote_worker_ids=to_worker_indices,
                  timeout_seconds=timeout_seconds,
              )
          else:
              rl_module_state_ref = ray.put(rl_module_state)

              def _set_weights(env_runner):
                  env_runner.set_weights(ray.get(rl_module_state_ref), global_vars)

              # Sync to specified remote workers in this EnvRunnerGroup.
              self.foreach_env_runner(
                  func=_set_weights,
                  local_env_runner=False,  # Do not sync back to local worker.
                  remote_worker_ids=to_worker_indices,
                  timeout_seconds=timeout_seconds,
              )

      # If `from_worker_or_learner_group` is provided, also sync to this
      # EnvRunnerGroup's local worker.
      if self.local_env_runner is not None:
          if from_worker_or_learner_group is not None:
              if self._remote_config.enable_env_runner_and_connector_v2:
                  self.local_env_runner.set_state(rl_module_state)
              else:
                  self.local_env_runner.set_weights(rl_module_state)
          # If `global_vars` is provided and local worker exists -> Update its
          # global_vars.
          if global_vars is not None:
              self.local_env_runner.set_global_vars(global_vars)
  def add_workers(self, num_workers: int, validate: bool = False) -> None:
      """Spawns additional remote EnvRunners and registers them with this group.

      May be invoked repeatedly on the same EnvRunnerGroup to grow the set of
      remote EnvRunners step by step.

      Args:
          num_workers: How many remote EnvRunners to create and add.
          validate: If True, `assert_healthy()` is called through the actor
              manager once the new EnvRunners have been registered.

      Raises:
          RayError: If validation of a remote EnvRunner fails and errors on
              EnvRunners are not configured to be ignored (the error object
              returned by the failed call is re-raised as is).
      """
      manager = self._worker_manager
      num_existing = manager.num_actors()

      created = []
      for offset in range(num_workers):
          created.append(
              self._make_worker(
                  env_creator=self._env_creator,
                  validate_env=None,
                  # New indices continue right after the already existing ones.
                  worker_index=num_existing + offset + 1,
                  num_workers=num_existing + num_workers,
                  # The remote config can be large, so hand over the object ref
                  # rather than the value itself
                  # (https://docs.ray.io/en/latest/ray-core/patterns/pass-large-arg-by-value.html).
                  config=self._remote_config_obj_ref,
              )
          )
      manager.add_actors(created)

      if not validate:
          return

      # Check that the remote EnvRunners are up and running.
      for outcome in manager.foreach_actor(lambda w: w.assert_healthy()):
          if outcome.ok:
              continue
          err = outcome.get()
          if not self._ignore_ray_errors_on_env_runners:
              # Simply raise; the caller is expected to handle the error.
              raise err
          logger.error(f"Validation of EnvRunner failed! Error={str(err)}")
  def reset(self, new_remote_workers: List[ActorHandle]) -> None:
      """Replaces all currently managed remote EnvRunners with the given ones.

      Args:
          new_remote_workers: The EnvRunners (as `ActorHandles`) that make up
              the set of remote workers from now on.
      """
      manager = self._worker_manager
      # Forget every actor tracked so far, then register the replacements.
      manager.clear()
      manager.add_actors(new_remote_workers)
  def stop(self) -> None:
      """Calls `stop` on all EnvRunners (local one included), then clears the manager.

      Failures while stopping are logged and not re-raised; the actor manager is
      cleared in any case.
      """

      def _stop_env_runner(env_runner):
          return env_runner.stop()

      try:
          # Also address EnvRunners that were just restarted / recovered or that
          # are tagged unhealthy (at least, we should try).
          self.foreach_env_runner(
              _stop_env_runner,
              healthy_only=False,
              local_env_runner=True,
          )
      except Exception:
          logger.exception("Failed to stop workers!")
      finally:
          self._worker_manager.clear()
  def foreach_env_runner(
      self,
      func: Union[
          Callable[[EnvRunner], T], List[Callable[[EnvRunner], T]], str, List[str]
      ],
      *,
      kwargs=None,
      local_env_runner: bool = True,
      healthy_only: bool = True,
      remote_worker_ids: Optional[List[int]] = None,
      timeout_seconds: Optional[float] = None,
      return_obj_refs: bool = False,
      mark_healthy: bool = False,
  ) -> List[T]:
      """Calls the given function with each EnvRunner as its argument.

      Args:
          func: The function to call for each EnvRunner. The only call argument
              is the respective EnvRunner instance. If a str is given, it is
              treated as the name of an EnvRunner method to call.
          kwargs: Optional kwargs forwarded to the remote calls. Must be None
              whenever `func` is also applied to an existing local EnvRunner.
          local_env_runner: Whether to apply `func` to the local EnvRunner, too.
              Default is True.
          healthy_only: Apply `func` on known-to-be healthy EnvRunners only.
          remote_worker_ids: Apply `func` on a selected set of remote EnvRunners.
              Use None (default) for all remote EnvRunners.
          timeout_seconds: Time to wait (in seconds) for results. Set this to 0.0
              for fire-and-forget. Set this to None (default) to wait infinitely
              (i.e. for synchronous execution).
          return_obj_refs: Whether to return ObjectRef instead of actual results.
              Note, for fault tolerance reasons, these returned ObjectRefs should
              never be resolved with ray.get() outside of this EnvRunnerGroup.
          mark_healthy: Whether to mark all those EnvRunners healthy again that
              are currently marked unhealthy AND that returned results from the
              remote call (within the given `timeout_seconds`).
              Note that EnvRunners are NOT set unhealthy, if they simply time out
              (only if they return a RayActorError).
              Also note that this setting is ignored if `healthy_only=True` (b/c
              `mark_healthy` only affects EnvRunners that are currently tagged as
              unhealthy).

      Returns:
          The list of return values of all calls to `func([worker])`. The local
          EnvRunner's result (if any) comes first, followed by the remote ones.
      """
      assert (
          not return_obj_refs or not local_env_runner
      ), "Can not return ObjectRef from local worker."

      local_result = []
      if local_env_runner and self.local_env_runner is not None:
          assert kwargs is None
          if isinstance(func, str):
              # `func` names a method: look it up AND invoke it, so that the
              # result list holds the return value (not the bound method).
              local_result = [getattr(self.local_env_runner, func)()]
          else:
              local_result = [func(self.local_env_runner)]

      # No remote EnvRunners -> The local result is all there is.
      if not self._worker_manager.actor_ids():
          return local_result

      remote_results = self._worker_manager.foreach_actor(
          func,
          kwargs=kwargs,
          healthy_only=healthy_only,
          remote_actor_ids=remote_worker_ids,
          timeout_seconds=timeout_seconds,
          return_obj_refs=return_obj_refs,
          mark_healthy=mark_healthy,
      )

      FaultTolerantActorManager.handle_remote_call_result_errors(
          remote_results, ignore_ray_errors=self._ignore_ray_errors_on_env_runners
      )

      # With application errors handled, return good results.
      remote_results = [r.get() for r in remote_results.ignore_errors()]

      return local_result + remote_results
  def foreach_env_runner_async(
      self,
      func: Union[
          Callable[[EnvRunner], T], List[Callable[[EnvRunner], T]], str, List[str]
      ],
      tag: Optional[str] = None,
      *,
      kwargs=None,
      healthy_only: bool = True,
      remote_worker_ids: List[int] = None,
  ) -> int:
      """Asynchronously calls `func` with each remote EnvRunner as the argument.

      No results are returned from this call. Use `fetch_ready_async_reqs()` to
      pull them in an async manner whenever they are available.

      Args:
          func: The function to call for each EnvRunner. The only call argument
              is the respective EnvRunner instance.
          tag: A tag to identify the results from this async call when fetching
              with `fetch_ready_async_reqs()`.
          kwargs: An optional kwargs dict to be passed to the remote function
              calls.
          healthy_only: Apply `func` on known-to-be healthy EnvRunners only.
          remote_worker_ids: Apply `func` on a selected set of remote EnvRunners.

      Returns:
          The number of async requests that have actually been made. This is the
          length of `remote_worker_ids` (or `self.num_remote_workers()` if
          `remote_worker_ids` is None) minus the number of requests that were NOT
          made b/c a remote EnvRunner already had its
          `max_remote_requests_in_flight_per_actor` counter reached for this tag.
      """
      manager = self._worker_manager
      num_requests_made = manager.foreach_actor_async(
          func,
          tag=tag,
          kwargs=kwargs,
          healthy_only=healthy_only,
          remote_actor_ids=remote_worker_ids,
      )
      return num_requests_made
  def fetch_ready_async_reqs(
      self,
      *,
      tags: Optional[Union[str, List[str], Tuple[str]]] = None,
      timeout_seconds: Optional[float] = 0.0,
      return_obj_refs: bool = False,
      mark_healthy: bool = False,
  ) -> List[Tuple[int, T]]:
      """Collects the results of outstanding async requests that are ready.

      Args:
          tags: Tags to identify the results from a specific async call.
              If None (default), returns results from all ready async requests.
              If a single string, returns results from all ready async requests
              with that tag.
          timeout_seconds: Time to wait for results. Default is 0, meaning
              those requests that are already ready.
          return_obj_refs: Whether to return ObjectRef instead of actual results.
          mark_healthy: Whether to mark all those workers healthy again that are
              currently marked unhealthy AND that returned results from the
              remote call (within the given `timeout_seconds`).
              Note that workers are NOT set unhealthy, if they simply time out
              (only if they return a RayActorError).
              Also note that this setting is ignored if `healthy_only=True` (b/c
              `mark_healthy` only affects workers that are currently tagged as
              unhealthy).

      Returns:
          A list of `(actor_id, result)` tuples, one per result successfully
          returned from an outstanding remote call.
      """
      ready = self._worker_manager.fetch_ready_async_reqs(
          tags=tags,
          timeout_seconds=timeout_seconds,
          return_obj_refs=return_obj_refs,
          mark_healthy=mark_healthy,
      )

      # Let the manager's error handling run before extracting the results.
      FaultTolerantActorManager.handle_remote_call_result_errors(
          ready,
          ignore_ray_errors=self._ignore_ray_errors_on_env_runners,
      )

      id_result_pairs = []
      for call_result in ready.ignore_errors():
          id_result_pairs.append((call_result.actor_id, call_result.get()))
      return id_result_pairs
  @OldAPIStack
  def foreach_env(self, func: Callable[[EnvType], List[T]]) -> List[List[T]]:
      """Applies `func` to the sub-environments of all workers.

      An "underlying sub environment" is a single clone of an env within
      a vectorized environment.
      `func` takes a single underlying sub environment as arg, e.g. a
      gym.Env object.

      Args:
          func: A function - taking an EnvType (normally a gym.Env object)
              as arg and returning a list of lists of return values, one
              value per underlying sub-environment per each worker.

      Returns:
          The list (workers) of lists (sub environments) of results.
      """

      def _apply_to_sub_envs(env_runner):
          return env_runner.foreach_env(func)

      per_runner_results = self.foreach_env_runner(
          _apply_to_sub_envs,
          local_env_runner=True,
      )
      return list(per_runner_results)
  @OldAPIStack
  def foreach_env_with_context(
      self, func: Callable[[BaseEnv, EnvContext], List[T]]
  ) -> List[List[T]]:
      """Applies `func` to all workers' sub-environments together with the env_ctx.

      An "underlying sub environment" is a single clone of an env within
      a vectorized environment.
      `func` takes a single underlying sub environment and the env_context
      as args.

      Args:
          func: A function - taking a BaseEnv object and an EnvContext as
              arg - and returning a list of lists of return values over envs
              of the worker.

      Returns:
          The list (1 item per workers) of lists (1 item per sub-environment)
          of results.
      """

      def _apply_with_context(env_runner):
          return env_runner.foreach_env_with_context(func)

      per_runner_results = self.foreach_env_runner(
          _apply_with_context,
          local_env_runner=True,
      )
      return list(per_runner_results)
  def probe_unhealthy_env_runners(self) -> List[int]:
      """Probes unhealthy workers through the actor manager.

      The probe uses the `env_runner_health_probe_timeout_s` setting of the
      remote config as its timeout and is issued with `mark_healthy=True`.

      Returns:
          List of IDs of the workers that were restored.
      """
      manager = self._worker_manager
      probe_timeout_s = self._remote_config.env_runner_health_probe_timeout_s
      restored_ids = manager.probe_unhealthy_actors(
          timeout_seconds=probe_timeout_s,
          mark_healthy=True,
      )
      return restored_ids
  @OldAPIStack
  def foreach_policy(self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
      """Applies `func` to every (policy, PolicyID) tuple of every worker.

      Note that in the multi-agent case, each worker may have more than one
      policy.

      Args:
          func: A function - taking a Policy and its ID - that is
              called on all workers' Policies.

      Returns:
          The flat list of return values of func over all workers' policies. The
          length of this list is:
          (num_workers + 1 (local-worker)) *
          [num policies in the multi-agent config dict].
          The local workers' results are first, followed by all remote
          workers' results
      """

      def _per_worker(env_runner):
          return env_runner.foreach_policy(func)

      per_worker_results = self.foreach_env_runner(_per_worker, local_env_runner=True)
      # Flatten the per-worker lists into one list, keeping the worker order.
      return [value for worker_values in per_worker_results for value in worker_values]
  @OldAPIStack
  def foreach_policy_to_train(self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
      """Applies `func` to all workers' Policies that are in `policies_to_train`.

      Args:
          func: A function - taking a Policy and its ID - that is
              called on all workers' Policies, for which
              `worker.is_policy_to_train()` returns True.

      Returns:
          List[any]: The flat list of n return values of all
          `func([trainable policy], [ID])`-calls.
      """

      def _per_worker(env_runner):
          return env_runner.foreach_policy_to_train(func)

      per_worker_results = self.foreach_env_runner(_per_worker, local_env_runner=True)
      # Flatten the per-worker lists into one list, keeping the worker order.
      return [value for worker_values in per_worker_results for value in worker_values]
  @OldAPIStack
  def is_policy_to_train(
      self, policy_id: PolicyID, batch: Optional[SampleBatchType] = None
  ) -> bool:
      """Returns whether the given PolicyID (optionally inside some batch) is trainable.

      The answer comes from the local EnvRunner's `is_policy_to_train`. If that
      attribute is None, every policy counts as trainable.

      Raises:
          NotImplementedError: If this group has no (truthy) local EnvRunner.
      """
      if not self.local_env_runner:
          raise NotImplementedError
      if self.local_env_runner.is_policy_to_train is None:
          return True
      return self.local_env_runner.is_policy_to_train(policy_id, batch)
  @OldAPIStack
  def add_policy(
      self,
      policy_id: PolicyID,
      policy_cls: Optional[Type[Policy]] = None,
      policy: Optional[Policy] = None,
      *,
      observation_space: Optional[gym.spaces.Space] = None,
      action_space: Optional[gym.spaces.Space] = None,
      config: Optional[Union["AlgorithmConfig", PartialAlgorithmConfigDict]] = None,
      policy_state: Optional[PolicyState] = None,
      policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None,
      policies_to_train: Optional[
          Union[
              Collection[PolicyID],
              Callable[[PolicyID, Optional[SampleBatchType]], bool],
          ]
      ] = None,
      module_spec: Optional[RLModuleSpec] = None,
      # Deprecated.
      workers: Optional[List[Union[EnvRunner, ActorHandle]]] = DEPRECATED_VALUE,
  ) -> None:
      """Adds a policy to this EnvRunnerGroup's workers.

      Args:
          policy_id: ID of the policy to add.
          policy_cls: The Policy class to use for constructing the new Policy.
              Note: Exactly one of `policy_cls` or `policy` must be provided.
          policy: The Policy instance to add to this EnvRunnerGroup. If not None,
              the given Policy object will be directly inserted into the
              local worker and clones of that Policy will be created on all
              remote workers.
              Note: Exactly one of `policy_cls` or `policy` must be provided.
          observation_space: The observation space of the policy to add.
              If None, try to infer this space from the environment.
          action_space: The action space of the policy to add.
              If None, try to infer this space from the environment.
          config: The config object or overrides for the policy to add.
          policy_state: Optional state dict to apply to the new
              policy instance, right after its construction.
          policy_mapping_fn: An optional (updated) policy mapping function
              to use from here on. Note that already ongoing episodes will
              not change their mapping but will use the old mapping till
              the end of the episode.
          policies_to_train: An optional list of policy IDs to be trained
              or a callable taking PolicyID and SampleBatchType and
              returning a bool (trainable or not?).
              If None, will keep the existing setup in place. Policies,
              whose IDs are not in the list (or for which the callable
              returns False) will not be updated.
          module_spec: In the new RLModule API we need to pass in the module_spec
              for the new module that is supposed to be added. Knowing the policy
              spec is not sufficient.
          workers: Deprecated. Passing any value triggers a
              `deprecation_warning(..., error=True)`.

      Raises:
          KeyError: If the given `policy_id` already exists in the local
              EnvRunner's policy map.
          ValueError: If both or neither of `policy_cls` and `policy` are given.
      """
      if self.local_env_runner and policy_id in self.local_env_runner.policy_map:
          raise KeyError(
              f"Policy ID '{policy_id}' already exists in policy map! "
              "Make sure you use a Policy ID that has not been taken yet."
              " Policy IDs that are already in your policy map: "
              f"{list(self.local_env_runner.policy_map.keys())}"
          )

      if workers is not DEPRECATED_VALUE:
          deprecation_warning(
              old="EnvRunnerGroup.add_policy(.., workers=..)",
              help=(
                  "The `workers` argument to `EnvRunnerGroup.add_policy()` is "
                  "deprecated! Please do not use it anymore."
              ),
              error=True,
          )

      if (policy_cls is None) == (policy is None):
          raise ValueError(
              "Only one of `policy_cls` or `policy` must be provided to "
              "staticmethod: `EnvRunnerGroup.add_policy()`!"
          )
      validate_module_id(policy_id, error=False)

      # `policies_to_train` may be a callable (see signature). A callable can
      # not be turned into a list, so hand it through unchanged. Collections are
      # copied into a list; empty/None values keep the existing setup (None).
      if callable(policies_to_train):
          policies_to_train_arg = policies_to_train
      elif policies_to_train:
          policies_to_train_arg = list(policies_to_train)
      else:
          policies_to_train_arg = None

      if policy_cls is not None:
          # Policy instance not provided: Use the information given here.
          ctor_cls = policy_cls
          ctor_obs_space = observation_space
          ctor_act_space = action_space
          ctor_config = config
          ctor_state = policy_state
      else:
          # Policy instance provided: Create clones of this very policy on the
          # different workers (copy all its properties here for the calls to
          # add_policy on the remote workers).
          ctor_cls = type(policy)
          ctor_obs_space = policy.observation_space
          ctor_act_space = policy.action_space
          ctor_config = policy.config
          ctor_state = policy.get_state()

      new_policy_instance_kwargs = dict(
          policy_id=policy_id,
          policy_cls=ctor_cls,
          observation_space=ctor_obs_space,
          action_space=ctor_act_space,
          config=ctor_config,
          policy_state=ctor_state,
          policy_mapping_fn=policy_mapping_fn,
          policies_to_train=policies_to_train_arg,
          module_spec=module_spec,
      )

      def _create_new_policy_fn(worker):
          # `foreach_env_runner` function: Adds the policy to the worker (and
          # maybe changes its policy_mapping_fn - if provided here).
          worker.add_policy(**new_policy_instance_kwargs)

      if self.local_env_runner is not None:
          # Add policy directly by (already instantiated) object.
          if policy is not None:
              self.local_env_runner.add_policy(
                  policy_id=policy_id,
                  policy=policy,
                  policy_mapping_fn=policy_mapping_fn,
                  policies_to_train=policies_to_train,
                  module_spec=module_spec,
              )
          # Add policy by constructor kwargs.
          else:
              self.local_env_runner.add_policy(**new_policy_instance_kwargs)

      # Add the policy to all remote workers.
      self.foreach_env_runner(_create_new_policy_fn, local_env_runner=False)
  def _make_worker(
      self,
      *,
      env_creator: EnvCreator,
      validate_env: Optional[Callable[[EnvType], None]],
      worker_index: int,
      num_workers: int,
      recreated_worker: bool = False,
      config: "AlgorithmConfig",
      spaces: Optional[
          Dict[PolicyID, Tuple[gym.spaces.Space, gym.spaces.Space]]
      ] = None,
  ) -> Union[EnvRunner, ActorHandle]:
      """Constructs one EnvRunner, either in-process or as a Ray actor.

      For `worker_index == 0`, `self.env_runner_cls` is instantiated directly in
      this process. For any other index, the class is wrapped via `ray.remote`
      (using `self._remote_args`) and a remote actor is created.

      Args:
          env_creator: Passed through to the EnvRunner constructor.
          validate_env: Passed through to the EnvRunner constructor.
          worker_index: Index of the EnvRunner to create (0 -> in-process).
          num_workers: Passed through to the EnvRunner constructor.
          recreated_worker: Passed through to the EnvRunner constructor.
          config: Passed through to the EnvRunner constructor.
          spaces: Passed through to the EnvRunner constructor.

      Returns:
          The EnvRunner instance (index 0) or the handle of the created actor.
      """
      ctor_kwargs = {
          "env_creator": env_creator,
          "validate_env": validate_env,
          "default_policy_class": self._policy_class,
          "config": config,
          "worker_index": worker_index,
          "num_workers": num_workers,
          "recreated_worker": recreated_worker,
          "log_dir": self._logdir,
          "spaces": spaces,
          "dataset_shards": self._ds_shards,
          "tune_trial_id": self._tune_trial_id,
      }

      if worker_index == 0:
          return self.env_runner_cls(**ctor_kwargs)

      # Without a current placement group use bundle index -1, otherwise
      # offset the worker index by `self._pg_offset`.
      if ray.util.get_current_placement_group() is None:
          pg_bundle_idx = -1
      else:
          pg_bundle_idx = self._pg_offset + worker_index

      remote_cls = ray.remote(**self._remote_args)(self.env_runner_cls)
      actor_factory = remote_cls.options(placement_group_bundle_index=pg_bundle_idx)
      return actor_factory.remote(**ctor_kwargs)
  1151. @staticmethod
  1152. def _valid_module(class_path):
  1153. if (
  1154. isinstance(class_path, str)
  1155. and not os.path.isfile(class_path)
  1156. and "." in class_path
  1157. ):
  1158. module_path, class_name = class_path.rsplit(".", 1)
  1159. try:
  1160. spec = importlib.util.find_spec(module_path)
  1161. if spec is not None:
  1162. return True
  1163. except (ModuleNotFoundError, ValueError) as e:
  1164. logger.warning(
  1165. f"module {module_path} not found using input {class_path} with error: {e}"
  1166. )
  1167. return False