external_env.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482
  1. import queue
  2. import threading
  3. import uuid
  4. from typing import TYPE_CHECKING, Callable, Optional, Tuple
  5. import gymnasium as gym
  6. from ray._common.deprecation import deprecation_warning
  7. from ray.rllib.env.base_env import BaseEnv
  8. from ray.rllib.utils.annotations import OldAPIStack, override
  9. from ray.rllib.utils.typing import (
  10. EnvActionType,
  11. EnvInfoDict,
  12. EnvObsType,
  13. EnvType,
  14. MultiEnvDict,
  15. )
  16. if TYPE_CHECKING:
  17. from ray.rllib.models.preprocessors import Preprocessor
  18. @OldAPIStack
  19. class ExternalEnv(threading.Thread):
  20. """An environment that interfaces with external agents.
  21. Unlike simulator envs, control is inverted: The environment queries the
  22. policy to obtain actions and in return logs observations and rewards for
  23. training. This is in contrast to gym.Env, where the algorithm drives the
  24. simulation through env.step() calls.
  25. You can use ExternalEnv as the backend for policy serving (by serving HTTP
  26. requests in the run loop), for ingesting offline logs data (by reading
  27. offline transitions in the run loop), or other custom use cases not easily
  28. expressed through gym.Env.
  29. ExternalEnv supports both on-policy actions (through self.get_action()),
  30. and off-policy actions (through self.log_action()).
  31. This env is thread-safe, but individual episodes must be executed serially.
  32. .. testcode::
  33. :skipif: True
  34. from ray.tune import register_env
  35. from ray.rllib.algorithms.dqn import DQN
  36. YourExternalEnv = ...
  37. register_env("my_env", lambda config: YourExternalEnv(config))
  38. algo = DQN(env="my_env")
  39. while True:
  40. print(algo.train())
  41. """
  42. def __init__(
  43. self,
  44. action_space: gym.Space,
  45. observation_space: gym.Space,
  46. max_concurrent: int = None,
  47. ):
  48. """Initializes an ExternalEnv instance.
  49. Args:
  50. action_space: Action space of the env.
  51. observation_space: Observation space of the env.
  52. """
  53. threading.Thread.__init__(self)
  54. self.daemon = True
  55. self.action_space = action_space
  56. self.observation_space = observation_space
  57. self._episodes = {}
  58. self._finished = set()
  59. self._results_avail_condition = threading.Condition()
  60. if max_concurrent is not None:
  61. deprecation_warning(
  62. "The `max_concurrent` argument has been deprecated. Please configure"
  63. "the number of episodes using the `rollout_fragment_length` and"
  64. "`batch_mode` arguments. Please raise an issue on the Ray Github if "
  65. "these arguments do not support your expected use case for ExternalEnv",
  66. error=True,
  67. )
  68. def run(self):
  69. """Override this to implement the run loop.
  70. Your loop should continuously:
  71. 1. Call self.start_episode(episode_id)
  72. 2. Call self.[get|log]_action(episode_id, obs, [action]?)
  73. 3. Call self.log_returns(episode_id, reward)
  74. 4. Call self.end_episode(episode_id, obs)
  75. 5. Wait if nothing to do.
  76. Multiple episodes may be started at the same time.
  77. """
  78. raise NotImplementedError
  79. def start_episode(
  80. self, episode_id: Optional[str] = None, training_enabled: bool = True
  81. ) -> str:
  82. """Record the start of an episode.
  83. Args:
  84. episode_id: Unique string id for the episode or
  85. None for it to be auto-assigned and returned.
  86. training_enabled: Whether to use experiences for this
  87. episode to improve the policy.
  88. Returns:
  89. Unique string id for the episode.
  90. """
  91. if episode_id is None:
  92. episode_id = uuid.uuid4().hex
  93. if episode_id in self._finished:
  94. raise ValueError("Episode {} has already completed.".format(episode_id))
  95. if episode_id in self._episodes:
  96. raise ValueError("Episode {} is already started".format(episode_id))
  97. self._episodes[episode_id] = _ExternalEnvEpisode(
  98. episode_id, self._results_avail_condition, training_enabled
  99. )
  100. return episode_id
  101. def get_action(self, episode_id: str, observation: EnvObsType) -> EnvActionType:
  102. """Record an observation and get the on-policy action.
  103. Args:
  104. episode_id: Episode id returned from start_episode().
  105. observation: Current environment observation.
  106. Returns:
  107. Action from the env action space.
  108. """
  109. episode = self._get(episode_id)
  110. return episode.wait_for_action(observation)
  111. def log_action(
  112. self, episode_id: str, observation: EnvObsType, action: EnvActionType
  113. ) -> None:
  114. """Record an observation and (off-policy) action taken.
  115. Args:
  116. episode_id: Episode id returned from start_episode().
  117. observation: Current environment observation.
  118. action: Action for the observation.
  119. """
  120. episode = self._get(episode_id)
  121. episode.log_action(observation, action)
  122. def log_returns(
  123. self, episode_id: str, reward: float, info: Optional[EnvInfoDict] = None
  124. ) -> None:
  125. """Records returns (rewards and infos) from the environment.
  126. The reward will be attributed to the previous action taken by the
  127. episode. Rewards accumulate until the next action. If no reward is
  128. logged before the next action, a reward of 0.0 is assumed.
  129. Args:
  130. episode_id: Episode id returned from start_episode().
  131. reward: Reward from the environment.
  132. info: Optional info dict.
  133. """
  134. episode = self._get(episode_id)
  135. episode.cur_reward += reward
  136. if info:
  137. episode.cur_info = info or {}
  138. def end_episode(self, episode_id: str, observation: EnvObsType) -> None:
  139. """Records the end of an episode.
  140. Args:
  141. episode_id: Episode id returned from start_episode().
  142. observation: Current environment observation.
  143. """
  144. episode = self._get(episode_id)
  145. self._finished.add(episode.episode_id)
  146. episode.done(observation)
  147. def _get(self, episode_id: str) -> "_ExternalEnvEpisode":
  148. """Get a started episode by its ID or raise an error."""
  149. if episode_id in self._finished:
  150. raise ValueError("Episode {} has already completed.".format(episode_id))
  151. if episode_id not in self._episodes:
  152. raise ValueError("Episode {} not found.".format(episode_id))
  153. return self._episodes[episode_id]
  154. def to_base_env(
  155. self,
  156. make_env: Optional[Callable[[int], EnvType]] = None,
  157. num_envs: int = 1,
  158. remote_envs: bool = False,
  159. remote_env_batch_wait_ms: int = 0,
  160. restart_failed_sub_environments: bool = False,
  161. ) -> "BaseEnv":
  162. """Converts an RLlib MultiAgentEnv into a BaseEnv object.
  163. The resulting BaseEnv is always vectorized (contains n
  164. sub-environments) to support batched forward passes, where n may
  165. also be 1. BaseEnv also supports async execution via the `poll` and
  166. `send_actions` methods and thus supports external simulators.
  167. Args:
  168. make_env: A callable taking an int as input (which indicates
  169. the number of individual sub-environments within the final
  170. vectorized BaseEnv) and returning one individual
  171. sub-environment.
  172. num_envs: The number of sub-environments to create in the
  173. resulting (vectorized) BaseEnv. The already existing `env`
  174. will be one of the `num_envs`.
  175. remote_envs: Whether each sub-env should be a @ray.remote
  176. actor. You can set this behavior in your config via the
  177. `remote_worker_envs=True` option.
  178. remote_env_batch_wait_ms: The wait time (in ms) to poll remote
  179. sub-environments for, if applicable. Only used if
  180. `remote_envs` is True.
  181. Returns:
  182. The resulting BaseEnv object.
  183. """
  184. if num_envs != 1:
  185. raise ValueError(
  186. "External(MultiAgent)Env does not currently support "
  187. "num_envs > 1. One way of solving this would be to "
  188. "treat your Env as a MultiAgentEnv hosting only one "
  189. "type of agent but with several copies."
  190. )
  191. env = ExternalEnvWrapper(self)
  192. return env
  193. @OldAPIStack
  194. class _ExternalEnvEpisode:
  195. """Tracked state for each active episode."""
  196. def __init__(
  197. self,
  198. episode_id: str,
  199. results_avail_condition: threading.Condition,
  200. training_enabled: bool,
  201. multiagent: bool = False,
  202. ):
  203. self.episode_id = episode_id
  204. self.results_avail_condition = results_avail_condition
  205. self.training_enabled = training_enabled
  206. self.multiagent = multiagent
  207. self.data_queue = queue.Queue()
  208. self.action_queue = queue.Queue()
  209. if multiagent:
  210. self.new_observation_dict = None
  211. self.new_action_dict = None
  212. self.cur_reward_dict = {}
  213. self.cur_terminated_dict = {"__all__": False}
  214. self.cur_truncated_dict = {"__all__": False}
  215. self.cur_info_dict = {}
  216. else:
  217. self.new_observation = None
  218. self.new_action = None
  219. self.cur_reward = 0.0
  220. self.cur_terminated = False
  221. self.cur_truncated = False
  222. self.cur_info = {}
  223. def get_data(self):
  224. if self.data_queue.empty():
  225. return None
  226. return self.data_queue.get_nowait()
  227. def log_action(self, observation, action):
  228. if self.multiagent:
  229. self.new_observation_dict = observation
  230. self.new_action_dict = action
  231. else:
  232. self.new_observation = observation
  233. self.new_action = action
  234. self._send()
  235. self.action_queue.get(True, timeout=60.0)
  236. def wait_for_action(self, observation):
  237. if self.multiagent:
  238. self.new_observation_dict = observation
  239. else:
  240. self.new_observation = observation
  241. self._send()
  242. return self.action_queue.get(True, timeout=300.0)
  243. def done(self, observation):
  244. if self.multiagent:
  245. self.new_observation_dict = observation
  246. self.cur_terminated_dict = {"__all__": True}
  247. # TODO(sven): External env API does not currently support truncated,
  248. # but we should deprecate external Env anyways in favor of a client-only
  249. # approach.
  250. self.cur_truncated_dict = {"__all__": False}
  251. else:
  252. self.new_observation = observation
  253. self.cur_terminated = True
  254. self.cur_truncated = False
  255. self._send()
  256. def _send(self):
  257. if self.multiagent:
  258. if not self.training_enabled:
  259. for agent_id in self.cur_info_dict:
  260. self.cur_info_dict[agent_id]["training_enabled"] = False
  261. item = {
  262. "obs": self.new_observation_dict,
  263. "reward": self.cur_reward_dict,
  264. "terminated": self.cur_terminated_dict,
  265. "truncated": self.cur_truncated_dict,
  266. "info": self.cur_info_dict,
  267. }
  268. if self.new_action_dict is not None:
  269. item["off_policy_action"] = self.new_action_dict
  270. self.new_observation_dict = None
  271. self.new_action_dict = None
  272. self.cur_reward_dict = {}
  273. else:
  274. item = {
  275. "obs": self.new_observation,
  276. "reward": self.cur_reward,
  277. "terminated": self.cur_terminated,
  278. "truncated": self.cur_truncated,
  279. "info": self.cur_info,
  280. }
  281. if self.new_action is not None:
  282. item["off_policy_action"] = self.new_action
  283. self.new_observation = None
  284. self.new_action = None
  285. self.cur_reward = 0.0
  286. if not self.training_enabled:
  287. item["info"]["training_enabled"] = False
  288. with self.results_avail_condition:
  289. self.data_queue.put_nowait(item)
  290. self.results_avail_condition.notify()
  291. @OldAPIStack
  292. class ExternalEnvWrapper(BaseEnv):
  293. """Internal adapter of ExternalEnv to BaseEnv."""
  294. def __init__(
  295. self, external_env: "ExternalEnv", preprocessor: "Preprocessor" = None
  296. ):
  297. from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
  298. self.external_env = external_env
  299. self.prep = preprocessor
  300. self.multiagent = issubclass(type(external_env), ExternalMultiAgentEnv)
  301. self._action_space = external_env.action_space
  302. if preprocessor:
  303. self._observation_space = preprocessor.observation_space
  304. else:
  305. self._observation_space = external_env.observation_space
  306. external_env.start()
  307. @override(BaseEnv)
  308. def poll(
  309. self,
  310. ) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]:
  311. with self.external_env._results_avail_condition:
  312. results = self._poll()
  313. while len(results[0]) == 0:
  314. self.external_env._results_avail_condition.wait()
  315. results = self._poll()
  316. if not self.external_env.is_alive():
  317. raise Exception("Serving thread has stopped.")
  318. return results
  319. @override(BaseEnv)
  320. def send_actions(self, action_dict: MultiEnvDict) -> None:
  321. from ray.rllib.env.base_env import _DUMMY_AGENT_ID
  322. if self.multiagent:
  323. for env_id, actions in action_dict.items():
  324. self.external_env._episodes[env_id].action_queue.put(actions)
  325. else:
  326. for env_id, action in action_dict.items():
  327. self.external_env._episodes[env_id].action_queue.put(
  328. action[_DUMMY_AGENT_ID]
  329. )
  330. def _poll(
  331. self,
  332. ) -> Tuple[
  333. MultiEnvDict,
  334. MultiEnvDict,
  335. MultiEnvDict,
  336. MultiEnvDict,
  337. MultiEnvDict,
  338. MultiEnvDict,
  339. ]:
  340. from ray.rllib.env.base_env import with_dummy_agent_id
  341. all_obs, all_rewards, all_terminateds, all_truncateds, all_infos = (
  342. {},
  343. {},
  344. {},
  345. {},
  346. {},
  347. )
  348. off_policy_actions = {}
  349. for eid, episode in self.external_env._episodes.copy().items():
  350. data = episode.get_data()
  351. cur_terminated = (
  352. episode.cur_terminated_dict["__all__"]
  353. if self.multiagent
  354. else episode.cur_terminated
  355. )
  356. cur_truncated = (
  357. episode.cur_truncated_dict["__all__"]
  358. if self.multiagent
  359. else episode.cur_truncated
  360. )
  361. if cur_terminated or cur_truncated:
  362. del self.external_env._episodes[eid]
  363. if data:
  364. if self.prep:
  365. all_obs[eid] = self.prep.transform(data["obs"])
  366. else:
  367. all_obs[eid] = data["obs"]
  368. all_rewards[eid] = data["reward"]
  369. all_terminateds[eid] = data["terminated"]
  370. all_truncateds[eid] = data["truncated"]
  371. all_infos[eid] = data["info"]
  372. if "off_policy_action" in data:
  373. off_policy_actions[eid] = data["off_policy_action"]
  374. if self.multiagent:
  375. # Ensure a consistent set of keys
  376. # rely on all_obs having all possible keys for now.
  377. for eid, eid_dict in all_obs.items():
  378. for agent_id in eid_dict.keys():
  379. def fix(d, zero_val):
  380. if agent_id not in d[eid]:
  381. d[eid][agent_id] = zero_val
  382. fix(all_rewards, 0.0)
  383. fix(all_terminateds, False)
  384. fix(all_truncateds, False)
  385. fix(all_infos, {})
  386. return (
  387. all_obs,
  388. all_rewards,
  389. all_terminateds,
  390. all_truncateds,
  391. all_infos,
  392. off_policy_actions,
  393. )
  394. else:
  395. return (
  396. with_dummy_agent_id(all_obs),
  397. with_dummy_agent_id(all_rewards),
  398. with_dummy_agent_id(all_terminateds, "__all__"),
  399. with_dummy_agent_id(all_truncateds, "__all__"),
  400. with_dummy_agent_id(all_infos),
  401. with_dummy_agent_id(off_policy_actions),
  402. )
  403. @property
  404. @override(BaseEnv)
  405. def observation_space(self) -> gym.spaces.Dict:
  406. return self._observation_space
  407. @property
  408. @override(BaseEnv)
  409. def action_space(self) -> gym.Space:
  410. return self._action_space