dataservicer.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. import logging
  2. import sys
  3. import time
  4. from collections import defaultdict
  5. from queue import Queue
  6. from threading import Event, Lock, Thread
  7. from typing import TYPE_CHECKING, Any, Dict, Iterator, Union
  8. import grpc
  9. import ray
  10. import ray.core.generated.ray_client_pb2 as ray_client_pb2
  11. import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
  12. from ray._private.client_mode_hook import disable_client_hook
  13. from ray.util.client.common import (
  14. CLIENT_SERVER_MAX_THREADS,
  15. OrderedResponseCache,
  16. _propagate_error_in_context,
  17. )
  18. from ray.util.client.server.server_pickler import loads_from_client
  19. from ray.util.debug import log_once
  20. if TYPE_CHECKING:
  21. from ray.util.client.server.server import RayletServicer
  # Module-level logger, named after this module.
  logger = logging.getLogger(__name__)

  # Timeout, in seconds, passed to `join()` on the request-reader thread when
  # a data stream ends.
  QUEUE_JOIN_SECONDS = 10
  24. def _get_reconnecting_from_context(context: Any) -> bool:
  25. """
  26. Get `reconnecting` from gRPC metadata, or False if missing.
  27. """
  28. metadata = dict(context.invocation_metadata())
  29. val = metadata.get("reconnecting")
  30. if val is None or val not in ("True", "False"):
  31. logger.error(
  32. f'Client connecting with invalid value for "reconnecting": {val}, '
  33. "This may be because you have a mismatched client and server "
  34. "version."
  35. )
  36. return False
  37. return val == "True"
  38. def _should_cache(req: ray_client_pb2.DataRequest) -> bool:
  39. """
  40. Returns True if the response should to the given request should be cached,
  41. false otherwise. At the moment the only requests we do not cache are:
  42. - asynchronous gets: These arrive out of order. Skipping caching here
  43. is fine, since repeating an async get is idempotent
  44. - acks: Repeating acks is idempotent
  45. - clean up requests: Also idempotent, and client has likely already
  46. wrapped up the data connection by this point.
  47. - puts: We should only cache when we receive the final chunk, since
  48. any earlier chunks won't generate a response
  49. - tasks: We should only cache when we receive the final chunk,
  50. since any earlier chunks won't generate a response
  51. """
  52. req_type = req.WhichOneof("type")
  53. if req_type == "get" and req.get.asynchronous:
  54. return False
  55. if req_type == "put":
  56. return req.put.chunk_id == req.put.total_chunks - 1
  57. if req_type == "task":
  58. return req.task.chunk_id == req.task.total_chunks - 1
  59. return req_type not in ("acknowledge", "connection_cleanup")
  def fill_queue(
      grpc_input_generator: "Iterator[ray_client_pb2.DataRequest]",
      output_queue: "Queue[Union[ray_client_pb2.DataRequest, ray_client_pb2.DataResponse]]",  # noqa: E501
  ) -> None:
      """
      Drain `grpc_input_generator` into the shared `output_queue`.

      Every request read from the generator is put on the queue in order.
      A `grpc.RpcError` raised while reading is logged at debug level and
      ends the loop. In all cases a trailing `None` is put on the queue as
      the end-of-stream sentinel.
      """
      enqueue = output_queue.put
      try:
          for incoming in grpc_input_generator:
              enqueue(incoming)
      except grpc.RpcError as rpc_err:
          logger.debug(
              "closing dataservicer reader thread "
              f"grpc error reading request_iterator: {rpc_err}"
          )
      finally:
          # Sentinel telling the consumer that no more requests will arrive.
          enqueue(None)
  class ChunkCollector:
      """
      Helper class for collecting chunks from PutObject or ClientTask messages.

      Chunks of a single request must arrive in order. Their `data` payloads
      are concatenated into `self.data` until the final chunk is seen.
      """

      def __init__(self):
          # Id of the request whose chunks are currently being collected.
          self.curr_req_id = None
          # Highest chunk id appended so far; -1 means nothing collected yet.
          self.last_seen_chunk_id = -1
          # Concatenated payload of all chunks received so far.
          self.data = bytearray()

      def add_chunk(
          self,
          req: "ray_client_pb2.DataRequest",
          chunk: "Union[ray_client_pb2.PutRequest, ray_client_pb2.ClientTask]",
      ) -> bool:
          """
          Append one chunk of `req` to the collected data.

          Args:
              req: The request carrying the chunk; only `req_id` is read.
              chunk: The chunk message; `chunk_id`, `total_chunks` and `data`
                  are read.

          Returns:
              True if this chunk was the final one of the request, False if
              more chunks are expected or the chunk was a repeat of one
              already collected (repeats are ignored).

          Raises:
              RuntimeError: If the chunk belongs to a different request than
                  the one in progress, or if a chunk id was skipped.
          """
          if self.curr_req_id is not None and self.curr_req_id != req.req_id:
              raise RuntimeError(
                  "Expected to receive a chunk from request with id "
                  f"{self.curr_req_id}, but found {req.req_id} instead."
              )
          self.curr_req_id = req.req_id

          next_chunk = self.last_seen_chunk_id + 1
          if chunk.chunk_id < next_chunk:
              # Repeated chunk, ignore. Return an explicit False so the
              # method always returns a bool.
              return False
          if chunk.chunk_id > next_chunk:
              raise RuntimeError(
                  f"A chunk {chunk.chunk_id} of request {req.req_id} was "
                  "received out of order."
              )

          # chunk.chunk_id == next_chunk here: this is the expected chunk.
          self.data.extend(chunk.data)
          self.last_seen_chunk_id = chunk.chunk_id
          return chunk.chunk_id + 1 == chunk.total_chunks

      def reset(self):
          """Discard all collected state so a new request can be collected."""
          self.curr_req_id = None
          self.last_seen_chunk_id = -1
          self.data = bytearray()
  class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
      """
      gRPC servicer for the Ray client data stream.

      Keeps per-client session state (last connection time, reconnect grace
      period, response cache) and dispatches every incoming `DataRequest` to
      `basic_service`.
      """

      def __init__(self, basic_service: "RayletServicer"):
          """
          Args:
              basic_service: Servicer that the individual data-stream
                  requests are delegated to.
          """
          self.basic_service = basic_service
          self.clients_lock = Lock()
          self.num_clients = 0  # guarded by self.clients_lock
          # dictionary mapping client_id's to the last time they connected
          self.client_last_seen: Dict[str, float] = {}
          # dictionary mapping client_id's to their reconnect grace periods
          self.reconnect_grace_periods: Dict[str, float] = {}
          # dictionary mapping client_id's to their response cache
          # NOTE: this is a defaultdict, so a plain lookup of an unknown
          # client_id inserts a new cache entry.
          self.response_caches: Dict[str, OrderedResponseCache] = defaultdict(
              OrderedResponseCache
          )
          # stopped event, useful for signals that the server is shut down
          self.stopped = Event()
          # Helper for collecting chunks from PutObject calls. Assumes that
          # that put requests from different objects aren't interleaved.
          self.put_request_chunk_collector = ChunkCollector()
          # Helper for collecting chunks from ClientTask calls. Assumes that
          # schedule requests from different remote calls aren't interleaved.
          self.client_task_chunk_collector = ChunkCollector()

      def Datapath(self, request_iterator, context):
          """
          Serve one data stream from a client.

          This is a generator: it yields one `DataResponse` per handled
          request (and the results of asynchronous gets as they are pushed
          onto the internal queue). When the stream ends, the client's
          session is cleaned up, either immediately or after the client's
          reconnect grace period.

          Args:
              request_iterator: Iterator of incoming requests; it is consumed
                  on a background thread by `fill_queue`.
              context: gRPC context of the call. The `client_id` is read from
                  its invocation metadata, and status codes are set on it
                  when the connection is rejected or fails.
          """
          start_time = time.time()
          # set to True if client shuts down gracefully
          cleanup_requested = False
          metadata = dict(context.invocation_metadata())
          client_id = metadata.get("client_id")
          if client_id is None:
              logger.error("Client connecting with no client_id")
              return
          logger.debug(f"New data connection from client {client_id}: ")
          accepted_connection = self._init(client_id, context, start_time)
          if not accepted_connection:
              return
          # Look up the cache only after the connection has been accepted:
          # response_caches is a defaultdict, so looking it up for a rejected
          # client would insert an entry that is never cleaned up.
          response_cache = self.response_caches[client_id]
          # Set to False if client requests a reconnect grace period of 0
          reconnect_enabled = True
          # Stays None unless the reader thread was actually started, so the
          # cleanup below never joins a thread that does not exist.
          queue_filler_thread = None
          try:
              request_queue = Queue()
              filler = Thread(
                  target=fill_queue, daemon=True, args=(request_iterator, request_queue)
              )
              filler.start()
              queue_filler_thread = filler
              # For non `async get` requests, this loop yields immediately
              # For `async get` requests, this loop:
              #     1) does not yield, it just continues
              #     2) When the result is ready, it yields
              for req in iter(request_queue.get, None):
                  if isinstance(req, ray_client_pb2.DataResponse):
                      # Early shortcut if this is the result of an async get.
                      yield req
                      continue

                  assert isinstance(req, ray_client_pb2.DataRequest)
                  if _should_cache(req) and reconnect_enabled:
                      cached_resp = response_cache.check_cache(req.req_id)
                      if isinstance(cached_resp, Exception):
                          # Cache state is invalid, raise exception
                          raise cached_resp
                      if cached_resp is not None:
                          yield cached_resp
                          continue

                  resp = None
                  req_type = req.WhichOneof("type")
                  if req_type == "init":
                      resp_init = self.basic_service.Init(req.init)
                      resp = ray_client_pb2.DataResponse(
                          init=resp_init,
                      )
                      with self.clients_lock:
                          self.reconnect_grace_periods[
                              client_id
                          ] = req.init.reconnect_grace_period
                          if req.init.reconnect_grace_period == 0:
                              reconnect_enabled = False
                  elif req_type == "get":
                      if req.get.asynchronous:
                          get_resp = self.basic_service._async_get_object(
                              req.get, client_id, req.req_id, request_queue
                          )
                          if get_resp is None:
                              # Skip sending a response for this request and
                              # continue to the next request. The response for
                              # this request will be sent when the object is
                              # ready.
                              continue
                      else:
                          get_resp = self.basic_service._get_object(req.get, client_id)
                      resp = ray_client_pb2.DataResponse(get=get_resp)
                  elif req_type == "put":
                      if not self.put_request_chunk_collector.add_chunk(req, req.put):
                          # Put request still in progress
                          continue
                      put_resp = self.basic_service._put_object(
                          self.put_request_chunk_collector.data,
                          req.put.client_ref_id,
                          client_id,
                          req.put.owner_id,
                      )
                      self.put_request_chunk_collector.reset()
                      resp = ray_client_pb2.DataResponse(put=put_resp)
                  elif req_type == "release":
                      released = [
                          self.basic_service.release(client_id, rel_id)
                          for rel_id in req.release.ids
                      ]
                      resp = ray_client_pb2.DataResponse(
                          release=ray_client_pb2.ReleaseResponse(ok=released)
                      )
                  elif req_type == "connection_info":
                      resp = ray_client_pb2.DataResponse(
                          connection_info=self._build_connection_response()
                      )
                  elif req_type == "prep_runtime_env":
                      with self.clients_lock:
                          resp_prep = self.basic_service.PrepRuntimeEnv(
                              req.prep_runtime_env
                          )
                          resp = ray_client_pb2.DataResponse(prep_runtime_env=resp_prep)
                  elif req_type == "connection_cleanup":
                      cleanup_requested = True
                      cleanup_resp = ray_client_pb2.ConnectionCleanupResponse()
                      resp = ray_client_pb2.DataResponse(connection_cleanup=cleanup_resp)
                  elif req_type == "acknowledge":
                      # Clean up acknowledged cache entries
                      response_cache.cleanup(req.acknowledge.req_id)
                      continue
                  elif req_type == "task":
                      with self.clients_lock:
                          task = req.task
                          if not self.client_task_chunk_collector.add_chunk(req, task):
                              # Not all serialized arguments have arrived
                              continue
                          # NOTE(review): this deserializes client-supplied
                          # bytes; presumably pickle-based (server_pickler), so
                          # it should only be exposed to trusted clients.
                          # TODO confirm.
                          arglist, kwargs = loads_from_client(
                              self.client_task_chunk_collector.data, self.basic_service
                          )
                          self.client_task_chunk_collector.reset()
                          resp_ticket = self.basic_service.Schedule(
                              req.task, arglist, kwargs, context
                          )
                          resp = ray_client_pb2.DataResponse(task_ticket=resp_ticket)
                          # Drop the local references to the deserialized
                          # arguments before handling the next request.
                          del arglist
                          del kwargs
                  elif req_type == "terminate":
                      with self.clients_lock:
                          response = self.basic_service.Terminate(req.terminate, context)
                          resp = ray_client_pb2.DataResponse(terminate=response)
                  elif req_type == "list_named_actors":
                      with self.clients_lock:
                          response = self.basic_service.ListNamedActors(
                              req.list_named_actors
                          )
                          resp = ray_client_pb2.DataResponse(list_named_actors=response)
                  else:
                      raise Exception(
                          f"Unreachable code: Request type "
                          f"{req_type} not handled in Datapath"
                      )
                  resp.req_id = req.req_id
                  if _should_cache(req) and reconnect_enabled:
                      response_cache.update_cache(req.req_id, resp)
                  yield resp
          except Exception as e:
              logger.exception("Error in data channel:")
              recoverable = _propagate_error_in_context(e, context)
              invalid_cache = response_cache.invalidate(e)
              if not recoverable or invalid_cache:
                  context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
                  # Connection isn't recoverable, skip cleanup
                  cleanup_requested = True
          finally:
              logger.debug(f"Stream is broken with client {client_id}")
              if queue_filler_thread is not None:
                  queue_filler_thread.join(QUEUE_JOIN_SECONDS)
                  if queue_filler_thread.is_alive():
                      logger.error(
                          "Queue filler thread failed to join before timeout: "
                          f"{QUEUE_JOIN_SECONDS}"
                      )
              cleanup_delay = self.reconnect_grace_periods.get(client_id)
              if not cleanup_requested and cleanup_delay is not None:
                  logger.debug(
                      "Cleanup wasn't requested, delaying cleanup by "
                      f"{cleanup_delay} seconds."
                  )
                  # Delay cleanup, since client may attempt a reconnect
                  # Wait on the "stopped" event in case the grpc server is
                  # stopped and we can clean up earlier.
                  self.stopped.wait(timeout=cleanup_delay)
              else:
                  logger.debug("Cleanup was requested, cleaning up immediately.")
              with self.clients_lock:
                  if client_id not in self.client_last_seen:
                      logger.debug("Connection already cleaned up.")
                      # Some other connection has already cleaned up this
                      # client's session. This can happen if the client
                      # reconnects and then gracefully shuts down immediately.
                      return
                  last_seen = self.client_last_seen[client_id]
                  if last_seen > start_time:
                      # The client successfully reconnected and updated
                      # last seen some time during the grace period
                      logger.debug("Client reconnected, skipping cleanup")
                      return
                  # Either the client shut down gracefully, or the client
                  # failed to reconnect within the grace period. Clean up
                  # the connection.
                  self.basic_service.release_all(client_id)
                  del self.client_last_seen[client_id]
                  self.reconnect_grace_periods.pop(client_id, None)
                  self.response_caches.pop(client_id, None)
                  self.num_clients -= 1
                  logger.debug(
                      f"Removed client {client_id}, remaining={self.num_clients}"
                  )

                  # It's important to keep the Ray shutdown
                  # within this locked context or else Ray could hang.
                  # NOTE: it is strange to start ray in server.py but shut it
                  # down here. Consider consolidating ray lifetime management.
                  with disable_client_hook():
                      if self.num_clients == 0:
                          logger.debug("Shutting down ray.")
                          ray.shutdown()

      def _init(self, client_id: str, context: Any, start_time: float) -> bool:
          """
          Checks if resources allow for another client.

          On rejection a gRPC status code is set on `context`:
          RESOURCE_EXHAUSTED when the client threshold has been reached, or
          NOT_FOUND when a reconnecting client's session no longer exists.
          On success the client's last-seen time is set to `start_time`, and
          `num_clients` is incremented if the client was not already known.

          Returns a boolean indicating if initialization was successful.
          """
          with self.clients_lock:
              reconnecting = _get_reconnecting_from_context(context)
              threshold = int(CLIENT_SERVER_MAX_THREADS / 2)
              if self.num_clients >= threshold:
                  logger.warning(
                      f"[Data Servicer]: Num clients {self.num_clients} "
                      f"has reached the threshold {threshold}. "
                      f"Rejecting client: {client_id}. "
                  )
                  if log_once("client_threshold"):
                      logger.warning(
                          "You can configure the client connection "
                          "threshold by setting the "
                          "RAY_CLIENT_SERVER_MAX_THREADS env var "
                          f"(currently set to {CLIENT_SERVER_MAX_THREADS})."
                      )
                  context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
                  return False
              if reconnecting and client_id not in self.client_last_seen:
                  # Client took too long to reconnect, session has been
                  # cleaned up.
                  context.set_code(grpc.StatusCode.NOT_FOUND)
                  context.set_details(
                      "Attempted to reconnect to a session that has already "
                      "been cleaned up."
                  )
                  return False
              if client_id in self.client_last_seen:
                  logger.debug(f"Client {client_id} has reconnected.")
              else:
                  self.num_clients += 1
                  logger.debug(
                      f"Accepted data connection from {client_id}. "
                      f"Total clients: {self.num_clients}"
                  )
              self.client_last_seen[client_id] = start_time
              return True

      def _build_connection_response(self):
          """
          Build a `ConnectionInfoResponse` with the current number of clients
          and the Python / Ray version information of this process.
          """
          # Read the counter under the lock; build the message outside it.
          with self.clients_lock:
              cur_num_clients = self.num_clients
          return ray_client_pb2.ConnectionInfoResponse(
              num_clients=cur_num_clients,
              python_version="{}.{}.{}".format(
                  sys.version_info[0], sys.version_info[1], sys.version_info[2]
              ),
              ray_version=ray.__version__,
              ray_commit=ray.__commit__,
          )
  391. )