dataclient.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599
"""This file implements a threaded stream controller to abstract a data stream
back to the Ray Client server.
"""
  4. import logging
  5. import math
  6. import queue
  7. import threading
  8. import warnings
  9. from collections import OrderedDict
  10. from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
  11. import grpc
  12. import ray.core.generated.ray_client_pb2 as ray_client_pb2
  13. import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
  14. from ray.util.client.common import (
  15. INT32_MAX,
  16. OBJECT_TRANSFER_CHUNK_SIZE,
  17. OBJECT_TRANSFER_WARNING_SIZE,
  18. )
  19. from ray.util.debug import log_once
  20. if TYPE_CHECKING:
  21. from ray.util.client.worker import Worker
  22. logger = logging.getLogger(__name__)
  23. ResponseCallable = Callable[[Union[ray_client_pb2.DataResponse, Exception]], None]
  24. # Send an acknowledge on every 32nd response received
  25. ACKNOWLEDGE_BATCH_SIZE = 32
  26. def chunk_put(req: ray_client_pb2.DataRequest):
  27. """
  28. Chunks a put request. Doing this lazily is important for large objects,
  29. since taking slices of bytes objects does a copy. This means if we
  30. immediately materialized every chunk of a large object and inserted them
  31. into the result_queue, we would effectively double the memory needed
  32. on the client to handle the put.
  33. """
  34. # When accessing a protobuf field, deserialization is performed, which will
  35. # generate a copy. So we need to avoid accessing the `data` field multiple
  36. # times in the loop
  37. request_data = req.put.data
  38. total_size = len(request_data)
  39. assert total_size > 0, "Cannot chunk object with missing data"
  40. if total_size >= OBJECT_TRANSFER_WARNING_SIZE and log_once(
  41. "client_object_put_size_warning"
  42. ):
  43. size_gb = total_size / 2**30
  44. warnings.warn(
  45. "Ray Client is attempting to send a "
  46. f"{size_gb:.2f} GiB object over the network, which may "
  47. "be slow. Consider serializing the object and using a remote "
  48. "URI to transfer via S3 or Google Cloud Storage instead. "
  49. "Documentation for doing this can be found here: "
  50. "https://docs.ray.io/en/latest/handling-dependencies.html#remote-uris",
  51. UserWarning,
  52. )
  53. total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE)
  54. for chunk_id in range(0, total_chunks):
  55. start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
  56. end = min(total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE)
  57. chunk = ray_client_pb2.PutRequest(
  58. client_ref_id=req.put.client_ref_id,
  59. data=request_data[start:end],
  60. chunk_id=chunk_id,
  61. total_chunks=total_chunks,
  62. total_size=total_size,
  63. owner_id=req.put.owner_id,
  64. )
  65. yield ray_client_pb2.DataRequest(req_id=req.req_id, put=chunk)
  66. def chunk_task(req: ray_client_pb2.DataRequest):
  67. """
  68. Chunks a client task. Doing this lazily is important with large arguments,
  69. since taking slices of bytes objects does a copy. This means if we
  70. immediately materialized every chunk of a large argument and inserted them
  71. into the result_queue, we would effectively double the memory needed
  72. on the client to handle the task.
  73. """
  74. # When accessing a protobuf field, deserialization is performed, which will
  75. # generate a copy. So we need to avoid accessing the `data` field multiple
  76. # times in the loop
  77. request_data = req.task.data
  78. total_size = len(request_data)
  79. assert total_size > 0, "Cannot chunk object with missing data"
  80. total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE)
  81. for chunk_id in range(0, total_chunks):
  82. start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
  83. end = min(total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE)
  84. chunk = ray_client_pb2.ClientTask(
  85. type=req.task.type,
  86. name=req.task.name,
  87. payload_id=req.task.payload_id,
  88. client_id=req.task.client_id,
  89. options=req.task.options,
  90. baseline_options=req.task.baseline_options,
  91. namespace=req.task.namespace,
  92. data=request_data[start:end],
  93. chunk_id=chunk_id,
  94. total_chunks=total_chunks,
  95. )
  96. yield ray_client_pb2.DataRequest(req_id=req.req_id, task=chunk)
  97. class ChunkCollector:
  98. """
  99. This object collects chunks from async get requests via __call__, and
  100. calls the underlying callback when the object is fully received, or if an
  101. exception while retrieving the object occurs.
  102. This is not used in synchronous gets (synchronous gets interact with the
  103. raylet servicer directly, not through the datapath).
  104. __call__ returns true once the underlying call back has been called.
  105. """
  106. def __init__(self, callback: ResponseCallable, request: ray_client_pb2.DataRequest):
  107. # Bytearray containing data received so far
  108. self.data = bytearray()
  109. # The callback that will be called once all data is received
  110. self.callback = callback
  111. # The id of the last chunk we've received, or -1 if haven't seen any yet
  112. self.last_seen_chunk = -1
  113. # The GetRequest that initiated the transfer. start_chunk_id will be
  114. # updated as chunks are received to avoid re-requesting chunks that
  115. # we've already received.
  116. self.request = request
  117. def __call__(self, response: Union[ray_client_pb2.DataResponse, Exception]) -> bool:
  118. if isinstance(response, Exception):
  119. self.callback(response)
  120. return True
  121. get_resp = response.get
  122. if not get_resp.valid:
  123. self.callback(response)
  124. return True
  125. if get_resp.total_size > OBJECT_TRANSFER_WARNING_SIZE and log_once(
  126. "client_object_transfer_size_warning"
  127. ):
  128. size_gb = get_resp.total_size / 2**30
  129. warnings.warn(
  130. "Ray Client is attempting to retrieve a "
  131. f"{size_gb:.2f} GiB object over the network, which may "
  132. "be slow. Consider serializing the object to a file and "
  133. "using rsync or S3 instead.",
  134. UserWarning,
  135. )
  136. chunk_data = get_resp.data
  137. chunk_id = get_resp.chunk_id
  138. if chunk_id == self.last_seen_chunk + 1:
  139. self.data.extend(chunk_data)
  140. self.last_seen_chunk = chunk_id
  141. # If we disconnect partway through, restart the get request
  142. # at the first chunk we haven't seen
  143. self.request.get.start_chunk_id = self.last_seen_chunk + 1
  144. elif chunk_id > self.last_seen_chunk + 1:
  145. # A chunk was skipped. This shouldn't happen in practice since
  146. # grpc guarantees that chunks will arrive in order.
  147. msg = (
  148. f"Received chunk {chunk_id} when we expected "
  149. f"{self.last_seen_chunk + 1} for request {response.req_id}"
  150. )
  151. logger.warning(msg)
  152. self.callback(RuntimeError(msg))
  153. return True
  154. else:
  155. # We received a chunk that've already seen before. Ignore, since
  156. # it should already be appended to self.data.
  157. logger.debug(
  158. f"Received a repeated chunk {chunk_id} "
  159. f"from request {response.req_id}."
  160. )
  161. if get_resp.chunk_id == get_resp.total_chunks - 1:
  162. self.callback(self.data)
  163. return True
  164. else:
  165. # Not done yet
  166. return False
  167. class DataClient:
  168. def __init__(self, client_worker: "Worker", client_id: str, metadata: list):
  169. """Initializes a thread-safe datapath over a Ray Client gRPC channel.
  170. Args:
  171. client_worker: The Ray Client worker that manages this client
  172. client_id: the generated ID representing this client
  173. metadata: metadata to pass to gRPC requests
  174. """
  175. self.client_worker = client_worker
  176. self._client_id = client_id
  177. self._metadata = metadata
  178. self.data_thread = self._start_datathread()
  179. # Track outstanding requests to resend in case of disconnection
  180. self.outstanding_requests: Dict[int, Any] = OrderedDict()
  181. # Serialize access to all mutable internal states: self.request_queue,
  182. # self.ready_data, self.asyncio_waiting_data,
  183. # self._in_shutdown, self._req_id, self.outstanding_requests and
  184. # calling self._next_id()
  185. self.lock = threading.Lock()
  186. # Waiting for response or shutdown.
  187. self.cv = threading.Condition(lock=self.lock)
  188. self.request_queue = self._create_queue()
  189. self.ready_data: Dict[int, Any] = {}
  190. # NOTE: Dictionary insertion is guaranteed to complete before lookup
  191. # and/or removal because of synchronization via the request_queue.
  192. self.asyncio_waiting_data: Dict[int, ResponseCallable] = {}
  193. self._in_shutdown = False
  194. self._req_id = 0
  195. self._last_exception = None
  196. self._acknowledge_counter = 0
  197. self.data_thread.start()
  198. # Must hold self.lock when calling this function.
  199. def _next_id(self) -> int:
  200. assert self.lock.locked()
  201. self._req_id += 1
  202. if self._req_id > INT32_MAX:
  203. self._req_id = 1
  204. # Responses that aren't tracked (like opportunistic releases)
  205. # have req_id=0, so make sure we never mint such an id.
  206. assert self._req_id != 0
  207. return self._req_id
  208. def _start_datathread(self) -> threading.Thread:
  209. return threading.Thread(
  210. target=self._data_main,
  211. name="ray_client_streaming_rpc",
  212. args=(),
  213. daemon=True,
  214. )
  215. # A helper that takes requests from queue. If the request wraps a PutRequest,
  216. # lazily chunks and yields the request. Otherwise, yields the request directly.
  217. def _requests(self):
  218. while True:
  219. req = self.request_queue.get()
  220. if req is None:
  221. # Stop when client signals shutdown.
  222. return
  223. req_type = req.WhichOneof("type")
  224. if req_type == "put":
  225. yield from chunk_put(req)
  226. elif req_type == "task":
  227. yield from chunk_task(req)
  228. else:
  229. yield req
  230. def _data_main(self) -> None:
  231. reconnecting = False
  232. try:
  233. while not self.client_worker._in_shutdown:
  234. stub = ray_client_pb2_grpc.RayletDataStreamerStub(
  235. self.client_worker.channel
  236. )
  237. metadata = self._metadata + [("reconnecting", str(reconnecting))]
  238. resp_stream = stub.Datapath(
  239. self._requests(),
  240. metadata=metadata,
  241. wait_for_ready=True,
  242. )
  243. try:
  244. for response in resp_stream:
  245. self._process_response(response)
  246. return
  247. except grpc.RpcError as e:
  248. reconnecting = self._can_reconnect(e)
  249. if not reconnecting:
  250. self._last_exception = e
  251. return
  252. self._reconnect_channel()
  253. except Exception as e:
  254. self._last_exception = e
  255. finally:
  256. logger.debug("Shutting down data channel.")
  257. self._shutdown()
  258. def _process_response(self, response: Any) -> None:
  259. """
  260. Process responses from the data servicer.
  261. """
  262. if response.req_id == 0:
  263. # This is not being waited for.
  264. logger.debug(f"Got unawaited response {response}")
  265. return
  266. if response.req_id in self.asyncio_waiting_data:
  267. can_remove = True
  268. try:
  269. callback = self.asyncio_waiting_data[response.req_id]
  270. if isinstance(callback, ChunkCollector):
  271. can_remove = callback(response)
  272. elif callback:
  273. callback(response)
  274. if can_remove:
  275. # NOTE: calling del self.asyncio_waiting_data results
  276. # in the destructor of ClientObjectRef running, which
  277. # calls ReleaseObject(). So self.asyncio_waiting_data
  278. # is accessed without holding self.lock. Holding the
  279. # lock shouldn't be necessary either.
  280. del self.asyncio_waiting_data[response.req_id]
  281. except Exception:
  282. logger.exception("Callback error:")
  283. with self.lock:
  284. # Update outstanding requests
  285. if response.req_id in self.outstanding_requests and can_remove:
  286. del self.outstanding_requests[response.req_id]
  287. # Acknowledge response
  288. self._acknowledge(response.req_id)
  289. else:
  290. with self.lock:
  291. self.ready_data[response.req_id] = response
  292. self.cv.notify_all()
  293. def _can_reconnect(self, e: grpc.RpcError) -> bool:
  294. """
  295. Processes RPC errors that occur while reading from data stream.
  296. Returns True if the error can be recovered from, False otherwise.
  297. """
  298. if not self.client_worker._can_reconnect(e):
  299. logger.error("Unrecoverable error in data channel.")
  300. logger.debug(e)
  301. return False
  302. logger.debug("Recoverable error in data channel.")
  303. logger.debug(e)
  304. return True
  305. def _shutdown(self) -> None:
  306. """
  307. Shutdown the data channel
  308. """
  309. with self.lock:
  310. self._in_shutdown = True
  311. self.cv.notify_all()
  312. callbacks = self.asyncio_waiting_data.values()
  313. self.asyncio_waiting_data = {}
  314. if self._last_exception:
  315. # Abort async requests with the error.
  316. err = ConnectionError(
  317. "Failed during this or a previous request. Exception that "
  318. f"broke the connection: {self._last_exception}"
  319. )
  320. else:
  321. err = ConnectionError(
  322. "Request cannot be fulfilled because the data client has "
  323. "disconnected."
  324. )
  325. for callback in callbacks:
  326. if callback:
  327. callback(err)
  328. # Since self._in_shutdown is set to True, no new item
  329. # will be added to self.asyncio_waiting_data
  330. def _acknowledge(self, req_id: int) -> None:
  331. """
  332. Puts an acknowledge request on the request queue periodically.
  333. Lock should be held before calling this. Used when an async or
  334. blocking response is received.
  335. """
  336. if not self.client_worker._reconnect_enabled:
  337. # Skip ACKs if reconnect isn't enabled
  338. return
  339. assert self.lock.locked()
  340. self._acknowledge_counter += 1
  341. if self._acknowledge_counter % ACKNOWLEDGE_BATCH_SIZE == 0:
  342. self.request_queue.put(
  343. ray_client_pb2.DataRequest(
  344. acknowledge=ray_client_pb2.AcknowledgeRequest(req_id=req_id)
  345. )
  346. )
  347. def _reconnect_channel(self) -> None:
  348. """
  349. Attempts to reconnect the gRPC channel and resend outstanding
  350. requests. First, the server is pinged to see if the current channel
  351. still works. If the ping fails, then the current channel is closed
  352. and replaced with a new one.
  353. Once a working channel is available, a new request queue is made
  354. and filled with any outstanding requests to be resent to the server.
  355. """
  356. try:
  357. # Ping the server to see if the current channel is reuseable, for
  358. # example if gRPC reconnected the channel on its own or if the
  359. # RPC error was transient and the channel is still open
  360. ping_succeeded = self.client_worker.ping_server(timeout=5)
  361. except grpc.RpcError:
  362. ping_succeeded = False
  363. if not ping_succeeded:
  364. # Ping failed, try refreshing the data channel
  365. logger.warning(
  366. "Encountered connection issues in the data channel. "
  367. "Attempting to reconnect."
  368. )
  369. try:
  370. self.client_worker._connect_channel(reconnecting=True)
  371. except ConnectionError:
  372. logger.warning("Failed to reconnect the data channel")
  373. raise
  374. logger.debug("Reconnection succeeded!")
  375. # Recreate the request queue, and resend outstanding requests
  376. with self.lock:
  377. self.request_queue = self._create_queue()
  378. for request in self.outstanding_requests.values():
  379. # Resend outstanding requests
  380. self.request_queue.put(request)
  381. # Use SimpleQueue to avoid deadlocks when appending to queue from __del__()
  382. @staticmethod
  383. def _create_queue():
  384. return queue.SimpleQueue()
  385. def close(self) -> None:
  386. thread = None
  387. with self.lock:
  388. self._in_shutdown = True
  389. # Notify blocking operations to fail.
  390. self.cv.notify_all()
  391. # Add sentinel to terminate streaming RPC.
  392. if self.request_queue is not None:
  393. # Intentional shutdown, tell server it can clean up the
  394. # connection immediately and ignore the reconnect grace period.
  395. cleanup_request = ray_client_pb2.DataRequest(
  396. connection_cleanup=ray_client_pb2.ConnectionCleanupRequest()
  397. )
  398. self.request_queue.put(cleanup_request)
  399. self.request_queue.put(None)
  400. if self.data_thread is not None:
  401. thread = self.data_thread
  402. # Wait until streaming RPCs are done.
  403. if thread is not None:
  404. thread.join()
  405. def _blocking_send(
  406. self, req: ray_client_pb2.DataRequest
  407. ) -> ray_client_pb2.DataResponse:
  408. with self.lock:
  409. self._check_shutdown()
  410. req_id = self._next_id()
  411. req.req_id = req_id
  412. self.request_queue.put(req)
  413. self.outstanding_requests[req_id] = req
  414. self.cv.wait_for(lambda: req_id in self.ready_data or self._in_shutdown)
  415. self._check_shutdown()
  416. data = self.ready_data[req_id]
  417. del self.ready_data[req_id]
  418. del self.outstanding_requests[req_id]
  419. self._acknowledge(req_id)
  420. return data
  421. def _async_send(
  422. self,
  423. req: ray_client_pb2.DataRequest,
  424. callback: Optional[ResponseCallable] = None,
  425. ) -> None:
  426. with self.lock:
  427. self._check_shutdown()
  428. req_id = self._next_id()
  429. req.req_id = req_id
  430. self.asyncio_waiting_data[req_id] = callback
  431. self.outstanding_requests[req_id] = req
  432. self.request_queue.put(req)
  433. # Must hold self.lock when calling this function.
  434. def _check_shutdown(self):
  435. assert self.lock.locked()
  436. if not self._in_shutdown:
  437. return
  438. self.lock.release()
  439. # Do not try disconnect() or throw exceptions in self.data_thread.
  440. # Otherwise deadlock can occur.
  441. if threading.current_thread().ident == self.data_thread.ident:
  442. return
  443. from ray.util import disconnect
  444. disconnect()
  445. self.lock.acquire()
  446. if self._last_exception is not None:
  447. msg = (
  448. "Request can't be sent because the Ray client has already "
  449. "been disconnected due to an error. Last exception: "
  450. f"{self._last_exception}"
  451. )
  452. else:
  453. msg = (
  454. "Request can't be sent because the Ray client has already "
  455. "been disconnected."
  456. )
  457. raise ConnectionError(msg)
  458. def Init(
  459. self, request: ray_client_pb2.InitRequest, context=None
  460. ) -> ray_client_pb2.InitResponse:
  461. datareq = ray_client_pb2.DataRequest(
  462. init=request,
  463. )
  464. resp = self._blocking_send(datareq)
  465. return resp.init
  466. def PrepRuntimeEnv(
  467. self, request: ray_client_pb2.PrepRuntimeEnvRequest, context=None
  468. ) -> ray_client_pb2.PrepRuntimeEnvResponse:
  469. datareq = ray_client_pb2.DataRequest(
  470. prep_runtime_env=request,
  471. )
  472. resp = self._blocking_send(datareq)
  473. return resp.prep_runtime_env
  474. def ConnectionInfo(self, context=None) -> ray_client_pb2.ConnectionInfoResponse:
  475. datareq = ray_client_pb2.DataRequest(
  476. connection_info=ray_client_pb2.ConnectionInfoRequest()
  477. )
  478. resp = self._blocking_send(datareq)
  479. return resp.connection_info
  480. def GetObject(
  481. self, request: ray_client_pb2.GetRequest, context=None
  482. ) -> ray_client_pb2.GetResponse:
  483. datareq = ray_client_pb2.DataRequest(
  484. get=request,
  485. )
  486. resp = self._blocking_send(datareq)
  487. return resp.get
  488. def RegisterGetCallback(
  489. self, request: ray_client_pb2.GetRequest, callback: ResponseCallable
  490. ) -> None:
  491. if len(request.ids) != 1:
  492. raise ValueError(
  493. "RegisterGetCallback() must have exactly 1 Object ID. "
  494. f"Actual: {request}"
  495. )
  496. datareq = ray_client_pb2.DataRequest(
  497. get=request,
  498. )
  499. collector = ChunkCollector(callback=callback, request=datareq)
  500. self._async_send(datareq, collector)
  501. # TODO: convert PutObject to async
  502. def PutObject(
  503. self, request: ray_client_pb2.PutRequest, context=None
  504. ) -> ray_client_pb2.PutResponse:
  505. datareq = ray_client_pb2.DataRequest(
  506. put=request,
  507. )
  508. resp = self._blocking_send(datareq)
  509. return resp.put
  510. def ReleaseObject(
  511. self, request: ray_client_pb2.ReleaseRequest, context=None
  512. ) -> None:
  513. datareq = ray_client_pb2.DataRequest(
  514. release=request,
  515. )
  516. self._async_send(datareq)
  517. def Schedule(self, request: ray_client_pb2.ClientTask, callback: ResponseCallable):
  518. datareq = ray_client_pb2.DataRequest(task=request)
  519. self._async_send(datareq, callback)
  520. def Terminate(
  521. self, request: ray_client_pb2.TerminateRequest
  522. ) -> ray_client_pb2.TerminateResponse:
  523. req = ray_client_pb2.DataRequest(
  524. terminate=request,
  525. )
  526. resp = self._blocking_send(req)
  527. return resp.terminate
  528. def ListNamedActors(
  529. self, request: ray_client_pb2.ClientListNamedActorsRequest
  530. ) -> ray_client_pb2.ClientListNamedActorsResponse:
  531. req = ray_client_pb2.DataRequest(
  532. list_named_actors=request,
  533. )
  534. resp = self._blocking_send(req)
  535. return resp.list_named_actors