policy_client.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. import logging
  2. import threading
  3. import time
  4. from typing import Optional, Union
  5. import ray.cloudpickle as pickle
  6. # Backward compatibility.
  7. from ray.rllib.env.external.rllink import RLlink as Commands
  8. from ray.rllib.env.external_env import ExternalEnv
  9. from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
  10. from ray.rllib.env.multi_agent_env import MultiAgentEnv
  11. from ray.rllib.policy.sample_batch import MultiAgentBatch
  12. from ray.rllib.utils.annotations import OldAPIStack
  13. from ray.rllib.utils.typing import (
  14. EnvActionType,
  15. EnvInfoDict,
  16. EnvObsType,
  17. MultiAgentDict,
  18. )
  19. logger = logging.getLogger(__name__)
  20. try:
  21. import requests # `requests` is not part of stdlib.
  22. except ImportError:
  23. requests = None
  24. logger.warning(
  25. "Couldn't import `requests` library. Be sure to install it on"
  26. " the client side."
  27. )
  28. @OldAPIStack
  29. class PolicyClient:
  30. """REST client to interact with an RLlib policy server."""
  31. def __init__(
  32. self,
  33. address: str,
  34. inference_mode: str = "local",
  35. update_interval: float = 10.0,
  36. session: Optional[requests.Session] = None,
  37. ):
  38. self.address = address
  39. self.session = session
  40. self.env: ExternalEnv = None
  41. if inference_mode == "local":
  42. self.local = True
  43. self._setup_local_rollout_worker(update_interval)
  44. elif inference_mode == "remote":
  45. self.local = False
  46. else:
  47. raise ValueError("inference_mode must be either 'local' or 'remote'")
  48. def start_episode(
  49. self, episode_id: Optional[str] = None, training_enabled: bool = True
  50. ) -> str:
  51. if self.local:
  52. self._update_local_policy()
  53. return self.env.start_episode(episode_id, training_enabled)
  54. return self._send(
  55. {
  56. "episode_id": episode_id,
  57. "command": Commands.START_EPISODE,
  58. "training_enabled": training_enabled,
  59. }
  60. )["episode_id"]
  61. def get_action(
  62. self, episode_id: str, observation: Union[EnvObsType, MultiAgentDict]
  63. ) -> Union[EnvActionType, MultiAgentDict]:
  64. if self.local:
  65. self._update_local_policy()
  66. if isinstance(episode_id, (list, tuple)):
  67. actions = {
  68. eid: self.env.get_action(eid, observation[eid])
  69. for eid in episode_id
  70. }
  71. return actions
  72. else:
  73. return self.env.get_action(episode_id, observation)
  74. else:
  75. return self._send(
  76. {
  77. "command": Commands.GET_ACTION,
  78. "observation": observation,
  79. "episode_id": episode_id,
  80. }
  81. )["action"]
  82. def log_action(
  83. self,
  84. episode_id: str,
  85. observation: Union[EnvObsType, MultiAgentDict],
  86. action: Union[EnvActionType, MultiAgentDict],
  87. ) -> None:
  88. if self.local:
  89. self._update_local_policy()
  90. return self.env.log_action(episode_id, observation, action)
  91. self._send(
  92. {
  93. "command": Commands.LOG_ACTION,
  94. "observation": observation,
  95. "action": action,
  96. "episode_id": episode_id,
  97. }
  98. )
  99. def log_returns(
  100. self,
  101. episode_id: str,
  102. reward: float,
  103. info: Union[EnvInfoDict, MultiAgentDict] = None,
  104. multiagent_done_dict: Optional[MultiAgentDict] = None,
  105. ) -> None:
  106. if self.local:
  107. self._update_local_policy()
  108. if multiagent_done_dict is not None:
  109. assert isinstance(reward, dict)
  110. return self.env.log_returns(
  111. episode_id, reward, info, multiagent_done_dict
  112. )
  113. return self.env.log_returns(episode_id, reward, info)
  114. self._send(
  115. {
  116. "command": Commands.LOG_RETURNS,
  117. "reward": reward,
  118. "info": info,
  119. "episode_id": episode_id,
  120. "done": multiagent_done_dict,
  121. }
  122. )
  123. def end_episode(
  124. self, episode_id: str, observation: Union[EnvObsType, MultiAgentDict]
  125. ) -> None:
  126. if self.local:
  127. self._update_local_policy()
  128. return self.env.end_episode(episode_id, observation)
  129. self._send(
  130. {
  131. "command": Commands.END_EPISODE,
  132. "observation": observation,
  133. "episode_id": episode_id,
  134. }
  135. )
  136. def update_policy_weights(self) -> None:
  137. """Query the server for new policy weights, if local inference is enabled."""
  138. self._update_local_policy(force=True)
  139. def _send(self, data):
  140. payload = pickle.dumps(data)
  141. if self.session is None:
  142. response = requests.post(self.address, data=payload)
  143. else:
  144. response = self.session.post(self.address, data=payload)
  145. if response.status_code != 200:
  146. logger.error("Request failed {}: {}".format(response.text, data))
  147. response.raise_for_status()
  148. parsed = pickle.loads(response.content)
  149. return parsed
  150. def _setup_local_rollout_worker(self, update_interval):
  151. self.update_interval = update_interval
  152. self.last_updated = 0
  153. logger.info("Querying server for rollout worker settings.")
  154. kwargs = self._send(
  155. {
  156. "command": Commands.GET_WORKER_ARGS,
  157. }
  158. )["worker_args"]
  159. (self.rollout_worker, self.inference_thread) = _create_embedded_rollout_worker(
  160. kwargs, self._send
  161. )
  162. self.env = self.rollout_worker.env
  163. def _update_local_policy(self, force=False):
  164. assert self.inference_thread.is_alive()
  165. if (
  166. self.update_interval
  167. and time.time() - self.last_updated > self.update_interval
  168. ) or force:
  169. logger.info("Querying server for new policy weights.")
  170. resp = self._send(
  171. {
  172. "command": Commands.GET_WEIGHTS,
  173. }
  174. )
  175. weights = resp["weights"]
  176. global_vars = resp["global_vars"]
  177. logger.info(
  178. "Updating rollout worker weights and global vars {}.".format(
  179. global_vars
  180. )
  181. )
  182. self.rollout_worker.set_weights(weights, global_vars)
  183. self.last_updated = time.time()
  184. @OldAPIStack
  185. class _LocalInferenceThread(threading.Thread):
  186. def __init__(self, rollout_worker, send_fn):
  187. super().__init__()
  188. self.daemon = True
  189. self.rollout_worker = rollout_worker
  190. self.send_fn = send_fn
  191. def run(self):
  192. try:
  193. while True:
  194. logger.info("Generating new batch of experiences.")
  195. samples = self.rollout_worker.sample()
  196. metrics = self.rollout_worker.get_metrics()
  197. if isinstance(samples, MultiAgentBatch):
  198. logger.info(
  199. "Sending batch of {} env steps ({} agent steps) to "
  200. "server.".format(samples.env_steps(), samples.agent_steps())
  201. )
  202. else:
  203. logger.info(
  204. "Sending batch of {} steps back to server.".format(
  205. samples.count
  206. )
  207. )
  208. self.send_fn(
  209. {
  210. "command": Commands.REPORT_SAMPLES,
  211. "samples": samples,
  212. "metrics": metrics,
  213. }
  214. )
  215. except Exception as e:
  216. logger.error("Error: inference worker thread died!", e)
  217. @OldAPIStack
  218. def _auto_wrap_external(real_env_creator):
  219. def wrapped_creator(env_config):
  220. real_env = real_env_creator(env_config)
  221. if not isinstance(real_env, (ExternalEnv, ExternalMultiAgentEnv)):
  222. logger.info(
  223. "The env you specified is not a supported (sub-)type of "
  224. "ExternalEnv. Attempting to convert it automatically to "
  225. "ExternalEnv."
  226. )
  227. if isinstance(real_env, MultiAgentEnv):
  228. external_cls = ExternalMultiAgentEnv
  229. else:
  230. external_cls = ExternalEnv
  231. class _ExternalEnvWrapper(external_cls):
  232. def __init__(self, real_env):
  233. super().__init__(
  234. observation_space=real_env.observation_space,
  235. action_space=real_env.action_space,
  236. )
  237. def run(self):
  238. # Since we are calling methods on this class in the
  239. # client, run doesn't need to do anything.
  240. time.sleep(999999)
  241. return _ExternalEnvWrapper(real_env)
  242. return real_env
  243. return wrapped_creator
  244. @OldAPIStack
  245. def _create_embedded_rollout_worker(kwargs, send_fn):
  246. # Since the server acts as an input datasource, we have to reset the
  247. # input config to the default, which runs env rollouts.
  248. kwargs = kwargs.copy()
  249. kwargs["config"] = kwargs["config"].copy(copy_frozen=False)
  250. config = kwargs["config"]
  251. config.output = None
  252. config.input_ = "sampler"
  253. config.input_config = {}
  254. # If server has no env (which is the expected case):
  255. # Generate a dummy ExternalEnv here using RandomEnv and the
  256. # given observation/action spaces.
  257. if config.env is None:
  258. from ray.rllib.examples.envs.classes.random_env import (
  259. RandomEnv,
  260. RandomMultiAgentEnv,
  261. )
  262. env_config = {
  263. "action_space": config.action_space,
  264. "observation_space": config.observation_space,
  265. }
  266. is_ma = config.is_multi_agent
  267. kwargs["env_creator"] = _auto_wrap_external(
  268. lambda _: (RandomMultiAgentEnv if is_ma else RandomEnv)(env_config)
  269. )
  270. # kwargs["config"].env = True
  271. # Otherwise, use the env specified by the server args.
  272. else:
  273. real_env_creator = kwargs["env_creator"]
  274. kwargs["env_creator"] = _auto_wrap_external(real_env_creator)
  275. logger.info("Creating rollout worker with kwargs={}".format(kwargs))
  276. from ray.rllib.evaluation.rollout_worker import RolloutWorker
  277. rollout_worker = RolloutWorker(**kwargs)
  278. inference_thread = _LocalInferenceThread(rollout_worker, send_fn)
  279. inference_thread.start()
  280. return rollout_worker, inference_thread