worker.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968
  1. """This file includes the Worker class which sits on the client side.
  2. It implements the Ray API functions that are forwarded through grpc calls
  3. to the server.
  4. """
  5. import base64
  6. import json
  7. import logging
  8. import os
  9. import queue
  10. import tempfile
  11. import threading
  12. import time
  13. import uuid
  14. import warnings
  15. from collections import defaultdict
  16. from concurrent.futures import Future
  17. from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
  18. import grpc
  19. import ray.cloudpickle as cloudpickle
  20. import ray.core.generated.ray_client_pb2 as ray_client_pb2
  21. import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
  22. from ray._private.ray_constants import (
  23. DEFAULT_CLIENT_RECONNECT_GRACE_PERIOD,
  24. env_float,
  25. env_integer,
  26. )
  27. from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
  28. from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
  29. # Use cloudpickle's version of pickle for UnpicklingError
  30. from ray.cloudpickle.compat import pickle
  31. from ray.exceptions import GetTimeoutError
  32. from ray.job_config import JobConfig
  33. from ray.util.client.client_pickler import dumps_from_client, loads_from_server
  34. from ray.util.client.common import (
  35. GRPC_OPTIONS,
  36. GRPC_UNRECOVERABLE_ERRORS,
  37. INT32_MAX,
  38. OBJECT_TRANSFER_WARNING_SIZE,
  39. ClientActorClass,
  40. ClientActorHandle,
  41. ClientActorRef,
  42. ClientObjectRef,
  43. ClientRemoteFunc,
  44. ClientStub,
  45. )
  46. from ray.util.client.dataclient import DataClient
  47. from ray.util.client.logsclient import LogstreamClient
  48. from ray.util.debug import log_once
  49. if TYPE_CHECKING:
  50. from ray.actor import ActorClass
  51. from ray.remote_function import RemoteFunction
  52. logger = logging.getLogger(__name__)
  53. INITIAL_TIMEOUT_SEC = env_integer("RAY_CLIENT_INITIAL_CONNECTION_TIMEOUT_S", 5)
  54. MAX_TIMEOUT_SEC = env_integer("RAY_CLIENT_MAX_CONNECTION_TIMEOUT_S", 30)
  55. # The max amount of time an operation can run blocking in the server. This
  56. # allows for Ctrl-C of the client to work without explicitly cancelling server
  57. # operations.
  58. MAX_BLOCKING_OPERATION_TIME_S: float = env_float(
  59. "RAY_CLIENT_MAX_BLOCKING_OPERATION_TIME_S", 2.0
  60. )
  61. # If the total size (bytes) of all outbound messages to schedule tasks since
  62. # the connection began exceeds this value, a warning should be raised
  63. MESSAGE_SIZE_THRESHOLD = 10 * 2**20 # 10 MB
  64. # Links to the Ray Design Pattern doc to use in the task overhead warning
  65. # message
  66. DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK = "https://docs.ray.io/en/latest/ray-core/patterns/too-fine-grained-tasks.html" # noqa E501
  67. DESIGN_PATTERN_LARGE_OBJECTS_LINK = "https://docs.ray.io/en/latest/ray-core/patterns/closure-capture-large-objects.html" # noqa E501
  68. def backoff(timeout: int) -> int:
  69. timeout = timeout + 5
  70. if timeout > MAX_TIMEOUT_SEC:
  71. timeout = MAX_TIMEOUT_SEC
  72. return timeout
  73. class Worker:
    def __init__(
        self,
        conn_str: str = "",
        secure: bool = False,
        metadata: Optional[List[Tuple[str, str]]] = None,
        connection_retries: int = 3,
        _credentials: Optional[grpc.ChannelCredentials] = None,
    ):
        """Initializes the worker side grpc client.

        Connects the gRPC channel to the Ray client server, blocking until it
        responds or the retries are exhausted. It then opens the data and log
        streams and starts a daemon thread that releases objects on the server.

        Args:
            conn_str: The host:port connection string for the ray server.
            secure: whether to use SSL secure channel or not. A secure channel
                is also used when the RAY_USE_TLS environment variable is "1"
                or "true", or when ``_credentials`` is provided.
            metadata: additional metadata passed in the grpc request headers.
                A ("client_id", <generated id>) entry is always prepended.
            connection_retries: Number of times to attempt to reconnect to the
                ray server if it doesn't respond immediately. Setting to 0 tries
                at least once. For infinite retries, catch the ConnectionError
                exception.
            _credentials: gRPC channel credentials. Default ones will be used
                if None.

        Raises:
            ConnectionError: if the initial connection cannot be established
                (raised by ``_connect_channel``).
        """
        self._client_id = make_client_id()
        # The generated client id is always the first metadata entry; any
        # user-supplied metadata follows it.
        self.metadata = [("client_id", self._client_id)] + (
            metadata if metadata else []
        )
        self.channel = None
        self.server = None
        # Latest connectivity state; updated by _on_channel_state_change.
        self._conn_state = grpc.ChannelConnectivity.IDLE
        self._converted: Dict[str, ClientStub] = {}
        # TLS is enabled either explicitly or via the RAY_USE_TLS env var.
        self._secure = secure or os.environ.get("RAY_USE_TLS", "0").lower() in (
            "1",
            "true",
        )
        self._conn_str = conn_str
        self._connection_retries = connection_retries
        # Explicit credentials always imply a secure channel.
        if _credentials is not None:
            self._credentials = _credentials
            self._secure = True
        else:
            self._credentials = None

        self._reconnect_grace_period = DEFAULT_CLIENT_RECONNECT_GRACE_PERIOD
        if "RAY_CLIENT_RECONNECT_GRACE_PERIOD" in os.environ:
            # Use value in environment variable if available
            # NOTE(review): parsed with int(), so a fractional value such as
            # "2.5" raises ValueError here.
            self._reconnect_grace_period = int(
                os.environ["RAY_CLIENT_RECONNECT_GRACE_PERIOD"]
            )
        # Disable retries if grace period is set to 0
        self._reconnect_enabled = self._reconnect_grace_period != 0

        # Set to True when the connection cannot be recovered and reconnect
        # attempts should be stopped
        self._in_shutdown = False
        # Set to True after initial connection succeeds
        self._has_connected = False

        # Blocks until the server responds; raises ConnectionError otherwise.
        self._connect_channel()
        self._has_connected = True

        # Has Ray been initialized on the server?
        self._serverside_ray_initialized = False

        # Initialize the streams to finish protocol negotiation.
        self.data_client = DataClient(self, self._client_id, self.metadata)
        # Local per-object reference counts (see call_retain / call_release).
        self.reference_count: Dict[bytes, int] = defaultdict(int)

        self.log_client = LogstreamClient(self, self.metadata)
        self.log_client.set_logstream_level(logging.INFO)

        self.closed = False

        # Track this value to raise a warning if a lot of data are transferred.
        self.total_outbound_message_size_bytes = 0

        # Used to create unique IDs for RPCs to the RayletServicer
        self._req_id_lock = threading.Lock()
        self._req_id = 0

        # ReleaseObject grabs a lock, so it should not be called directly from
        # __del__ methods that may be executed at any time on the Python main thread.
        self._release_queue = queue.SimpleQueue()
        self._release_thread = threading.Thread(
            target=self._release_server_worker, daemon=True
        )
        self._release_thread.start()
    def _connect_channel(self, reconnecting=False) -> None:
        """
        Attempts to connect to the server specified by conn_str. If
        reconnecting after an RPC error, cleans up the old channel and
        continues to attempt to connect until the grace period is over.

        Args:
            reconnecting: When False (initial connect), at most
                ``max(self._connection_retries, 1)`` attempts are made and the
                per-attempt timeout grows via ``backoff``. When True, attempts
                continue with a fixed timeout until
                ``self._reconnect_grace_period`` seconds have elapsed.

        Raises:
            ConnectionError: if the grace period expires while reconnecting,
                or if the server never became ready. This includes the case
                where the worker is shut down mid-connect.
        """
        # Drop the previous channel (if any) before creating a new one.
        if self.channel is not None:
            self.channel.unsubscribe(self._on_channel_state_change)
            self.channel.close()

        # NOTE(review): imported lazily here, presumably to avoid an import
        # cycle or import-time cost -- confirm.
        from ray._private.grpc_utils import init_grpc_channel

        # Prepare credentials if secure connection is requested
        credentials = None
        if self._secure:
            if self._credentials is not None:
                credentials = self._credentials
            elif os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"):
                # init_grpc_channel will handle this via load_certs_from_env()
                credentials = None
            else:
                # Default SSL credentials (no specific certs)
                credentials = grpc.ssl_channel_credentials()

        # Create channel with auth interceptors via helper
        # This automatically adds auth interceptors when token auth is enabled
        self.channel = init_grpc_channel(
            self._conn_str,
            options=GRPC_OPTIONS,
            asynchronous=False,
            credentials=credentials,
        )

        self.channel.subscribe(self._on_channel_state_change)

        # Retry the connection until the channel responds to something
        # looking like a gRPC connection, though it may be a proxy.
        start_time = time.time()
        conn_attempts = 0
        timeout = INITIAL_TIMEOUT_SEC
        service_ready = False
        # When reconnecting, the attempt count is ignored. The loop then ends
        # only on success, shutdown, or the grace-period check below.
        while conn_attempts < max(self._connection_retries, 1) or reconnecting:
            conn_attempts += 1
            if self._in_shutdown:
                # User manually closed the worker before connection finished
                break
            elapsed_time = time.time() - start_time
            if reconnecting and elapsed_time > self._reconnect_grace_period:
                self._in_shutdown = True
                raise ConnectionError(
                    "Failed to reconnect within the reconnection grace period "
                    f"({self._reconnect_grace_period}s)"
                )
            try:
                # Let gRPC wait for us to see if the channel becomes ready.
                # If it throws, we couldn't connect.
                grpc.channel_ready_future(self.channel).result(timeout=timeout)
                # The HTTP2 channel is ready. Wrap the channel with the
                # RayletDriverStub, allowing for unary requests.
                self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
                service_ready = bool(self.ping_server())
                if service_ready:
                    break
                # Ray is not ready yet, wait a timeout
                time.sleep(timeout)
            except grpc.FutureTimeoutError:
                logger.debug(f"Couldn't connect channel in {timeout} seconds, retrying")
                # Note that channel_ready_future constitutes its own timeout,
                # which is why we do not sleep here.
            except grpc.RpcError as e:
                logger.debug(
                    f"Ray client server unavailable, retrying in {timeout}s..."
                )
                logger.debug(f"Received when checking init: {e.details()}")
                # Ray is not ready yet, wait a timeout.
                time.sleep(timeout)
            # Fallthrough, backoff, and retry at the top of the loop
            logger.debug(
                f"Waiting for Ray to become ready on the server, retry in {timeout}s..."
            )
            if not reconnecting:
                # Don't increase backoff when trying to reconnect --
                # we already know the server exists, attempt to reconnect
                # as soon as we can
                timeout = backoff(timeout)

        # If we made it through the loop without service_ready
        # it means we've used up our retries and
        # should error back to the user.
        if not service_ready:
            self._in_shutdown = True
            if log_once("ray_client_security_groups"):
                warnings.warn(
                    "Ray Client connection timed out. Ensure that "
                    "the Ray Client port on the head node is reachable "
                    "from your local machine. See https://docs.ray.io/en"
                    "/latest/cluster/ray-client.html#step-2-check-ports for "
                    "more information."
                )
            raise ConnectionError("ray client connection timeout")
  242. def _can_reconnect(self, e: grpc.RpcError) -> bool:
  243. """
  244. Returns True if the RPC error can be recovered from and a retry is
  245. appropriate, false otherwise.
  246. """
  247. if not self._reconnect_enabled:
  248. return False
  249. if self._in_shutdown:
  250. # Channel is being shutdown, don't try to reconnect
  251. return False
  252. if e.code() in GRPC_UNRECOVERABLE_ERRORS:
  253. # Unrecoverable error -- These errors are specifically raised
  254. # by the server's application logic
  255. return False
  256. if e.code() == grpc.StatusCode.INTERNAL:
  257. details = e.details()
  258. if details == "Exception serializing request!":
  259. # The client failed tried to send a bad request (for example,
  260. # passing "None" instead of a valid grpc message). Don't
  261. # try to reconnect/retry.
  262. return False
  263. # All other errors can be treated as recoverable
  264. return True
  265. def _call_stub(self, stub_name: str, *args, **kwargs) -> Any:
  266. """
  267. Calls the stub specified by stub_name (Schedule, WaitObject, etc...).
  268. If a recoverable error occurrs while calling the stub, attempts to
  269. retry the RPC.
  270. """
  271. while not self._in_shutdown:
  272. try:
  273. return getattr(self.server, stub_name)(*args, **kwargs)
  274. except grpc.RpcError as e:
  275. if self._can_reconnect(e):
  276. time.sleep(0.5)
  277. continue
  278. raise
  279. except ValueError:
  280. # Trying to use the stub on a cancelled channel will raise
  281. # ValueError. This should only happen when the data client
  282. # is attempting to reset the connection -- sleep and try
  283. # again.
  284. time.sleep(0.5)
  285. continue
  286. raise ConnectionError("Client is shutting down.")
  287. def _get_object_iterator(
  288. self, req: ray_client_pb2.GetRequest, *args, **kwargs
  289. ) -> Any:
  290. """
  291. Calls the stub for GetObject on the underlying server stub. If a
  292. recoverable error occurs while streaming the response, attempts
  293. to retry the get starting from the first chunk that hasn't been
  294. received.
  295. """
  296. last_seen_chunk = -1
  297. while not self._in_shutdown:
  298. # If we disconnect partway through, restart the get request
  299. # at the first chunk we haven't seen
  300. req.start_chunk_id = last_seen_chunk + 1
  301. try:
  302. for chunk in self.server.GetObject(req, *args, **kwargs):
  303. if chunk.chunk_id <= last_seen_chunk:
  304. # Ignore repeat chunks
  305. logger.debug(
  306. f"Received a repeated chunk {chunk.chunk_id} "
  307. f"from request {req.req_id}."
  308. )
  309. continue
  310. if last_seen_chunk + 1 != chunk.chunk_id:
  311. raise RuntimeError(
  312. f"Received chunk {chunk.chunk_id} when we expected "
  313. f"{self.last_seen_chunk + 1}"
  314. )
  315. last_seen_chunk = chunk.chunk_id
  316. yield chunk
  317. if last_seen_chunk == chunk.total_chunks - 1:
  318. # We've yielded the last chunk, exit early
  319. return
  320. return
  321. except grpc.RpcError as e:
  322. if self._can_reconnect(e):
  323. time.sleep(0.5)
  324. continue
  325. raise
  326. except ValueError:
  327. # Trying to use the stub on a cancelled channel will raise
  328. # ValueError. This should only happen when the data client
  329. # is attempting to reset the connection -- sleep and try
  330. # again.
  331. time.sleep(0.5)
  332. continue
  333. raise ConnectionError("Client is shutting down.")
  334. def _add_ids_to_metadata(self, metadata: Any):
  335. """
  336. Adds a unique req_id and the current thread's identifier to the
  337. metadata. These values are useful for preventing mutating operations
  338. from being replayed on the server side in the event that the client
  339. must retry a requsest.
  340. Args:
  341. metadata: the gRPC metadata to append the IDs to
  342. """
  343. if not self._reconnect_enabled:
  344. # IDs not needed if the reconnects are disabled
  345. return metadata
  346. thread_id = str(threading.get_ident())
  347. with self._req_id_lock:
  348. self._req_id += 1
  349. if self._req_id > INT32_MAX:
  350. self._req_id = 1
  351. req_id = str(self._req_id)
  352. return metadata + [("thread_id", thread_id), ("req_id", req_id)]
  353. def _on_channel_state_change(self, conn_state: grpc.ChannelConnectivity):
  354. logger.debug(f"client gRPC channel state change: {conn_state}")
  355. self._conn_state = conn_state
  356. def connection_info(self):
  357. try:
  358. data = self.data_client.ConnectionInfo()
  359. except grpc.RpcError as e:
  360. raise decode_exception(e)
  361. return {
  362. "num_clients": data.num_clients,
  363. "python_version": data.python_version,
  364. "ray_version": data.ray_version,
  365. "ray_commit": data.ray_commit,
  366. }
  367. def register_callback(
  368. self,
  369. ref: ClientObjectRef,
  370. callback: Callable[[ray_client_pb2.DataResponse], None],
  371. ) -> None:
  372. req = ray_client_pb2.GetRequest(ids=[ref.id], asynchronous=True)
  373. self.data_client.RegisterGetCallback(req, callback)
  374. def get(self, vals, *, timeout: Optional[float] = None) -> Any:
  375. if isinstance(vals, list):
  376. if not vals:
  377. return []
  378. to_get = vals
  379. elif isinstance(vals, ClientObjectRef):
  380. to_get = [vals]
  381. else:
  382. raise Exception(
  383. "Can't get something that's not a "
  384. "list of IDs or just an ID: %s" % type(vals)
  385. )
  386. if timeout is None:
  387. deadline = None
  388. else:
  389. deadline = time.monotonic() + timeout
  390. while True:
  391. if deadline:
  392. op_timeout = min(
  393. MAX_BLOCKING_OPERATION_TIME_S,
  394. max(deadline - time.monotonic(), 0.001),
  395. )
  396. else:
  397. op_timeout = MAX_BLOCKING_OPERATION_TIME_S
  398. try:
  399. res = self._get(to_get, op_timeout)
  400. break
  401. except GetTimeoutError:
  402. if deadline and time.monotonic() > deadline:
  403. raise
  404. logger.debug("Internal retry for get {}".format(to_get))
  405. if len(to_get) != len(res):
  406. raise Exception(
  407. "Mismatched number of items in request ({}) and response ({})".format(
  408. len(to_get), len(res)
  409. )
  410. )
  411. if isinstance(vals, ClientObjectRef):
  412. res = res[0]
  413. return res
  414. def _get(self, ref: List[ClientObjectRef], timeout: float):
  415. req = ray_client_pb2.GetRequest(ids=[r.id for r in ref], timeout=timeout)
  416. data = bytearray()
  417. try:
  418. resp = self._get_object_iterator(req, metadata=self.metadata)
  419. for chunk in resp:
  420. if not chunk.valid:
  421. try:
  422. err = cloudpickle.loads(chunk.error)
  423. except (pickle.UnpicklingError, TypeError):
  424. logger.exception("Failed to deserialize {}".format(chunk.error))
  425. raise
  426. raise err
  427. if chunk.total_size > OBJECT_TRANSFER_WARNING_SIZE and log_once(
  428. "client_object_transfer_size_warning"
  429. ):
  430. size_gb = chunk.total_size / 2**30
  431. warnings.warn(
  432. "Ray Client is attempting to retrieve a "
  433. f"{size_gb:.2f} GiB object over the network, which may "
  434. "be slow. Consider serializing the object to a file "
  435. "and using S3 or rsync instead.",
  436. UserWarning,
  437. stacklevel=5,
  438. )
  439. data.extend(chunk.data)
  440. except grpc.RpcError as e:
  441. raise decode_exception(e)
  442. return loads_from_server(data)
  443. def put(
  444. self,
  445. val,
  446. *,
  447. client_ref_id: bytes = None,
  448. _owner: Optional[ClientActorHandle] = None,
  449. ):
  450. if isinstance(val, ClientObjectRef):
  451. raise TypeError(
  452. "Calling 'put' on an ObjectRef is not allowed "
  453. "(similarly, returning an ObjectRef from a remote "
  454. "function is not allowed). If you really want to "
  455. "do this, you can wrap the ObjectRef in a list and "
  456. "call 'put' on it (or return it)."
  457. )
  458. data = dumps_from_client(val, self._client_id)
  459. return self._put_pickled(data, client_ref_id, _owner)
  460. def _put_pickled(
  461. self, data, client_ref_id: bytes, owner: Optional[ClientActorHandle] = None
  462. ):
  463. req = ray_client_pb2.PutRequest(data=data)
  464. if client_ref_id is not None:
  465. req.client_ref_id = client_ref_id
  466. if owner is not None:
  467. req.owner_id = owner.actor_ref.id
  468. resp = self.data_client.PutObject(req)
  469. if not resp.valid:
  470. try:
  471. raise cloudpickle.loads(resp.error)
  472. except (pickle.UnpicklingError, TypeError):
  473. logger.exception("Failed to deserialize {}".format(resp.error))
  474. raise
  475. return ClientObjectRef(resp.id)
  476. # TODO(ekl) respect MAX_BLOCKING_OPERATION_TIME_S for wait too
  477. def wait(
  478. self,
  479. object_refs: List[ClientObjectRef],
  480. *,
  481. num_returns: int = 1,
  482. timeout: float = None,
  483. fetch_local: bool = True,
  484. ) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]:
  485. if not isinstance(object_refs, list):
  486. raise TypeError(
  487. f"wait() expected a list of ClientObjectRef, got {type(object_refs)}"
  488. )
  489. for ref in object_refs:
  490. if not isinstance(ref, ClientObjectRef):
  491. raise TypeError(
  492. "wait() expected a list of ClientObjectRef, "
  493. f"got list containing {type(ref)}"
  494. )
  495. data = {
  496. "object_ids": [object_ref.id for object_ref in object_refs],
  497. "num_returns": num_returns,
  498. "timeout": timeout if (timeout is not None) else -1,
  499. "client_id": self._client_id,
  500. }
  501. req = ray_client_pb2.WaitRequest(**data)
  502. resp = self._call_stub("WaitObject", req, metadata=self.metadata)
  503. if not resp.valid:
  504. # TODO(ameer): improve error/exceptions messages.
  505. raise Exception("Client Wait request failed. Reference invalid?")
  506. client_ready_object_ids = [
  507. ClientObjectRef(ref) for ref in resp.ready_object_ids
  508. ]
  509. client_remaining_object_ids = [
  510. ClientObjectRef(ref) for ref in resp.remaining_object_ids
  511. ]
  512. return (client_ready_object_ids, client_remaining_object_ids)
  513. def call_remote(self, instance, *args, **kwargs) -> List[Future]:
  514. task = instance._prepare_client_task()
  515. # data is serialized tuple of (args, kwargs)
  516. task.data = dumps_from_client((args, kwargs), self._client_id)
  517. num_returns = instance._num_returns()
  518. if num_returns == "dynamic":
  519. num_returns = -1
  520. if num_returns == "streaming":
  521. raise RuntimeError(
  522. 'Streaming actor methods (num_returns="streaming") '
  523. "are not currently supported when using Ray Client."
  524. )
  525. return self._call_schedule_for_task(task, num_returns)
  526. def _call_schedule_for_task(
  527. self, task: ray_client_pb2.ClientTask, num_returns: Optional[int]
  528. ) -> List[Future]:
  529. logger.debug(f"Scheduling task {task.name} {task.type} {task.payload_id}")
  530. task.client_id = self._client_id
  531. if num_returns is None:
  532. num_returns = 1
  533. num_return_refs = num_returns
  534. if num_return_refs == -1:
  535. num_return_refs = 1
  536. id_futures = [Future() for _ in range(num_return_refs)]
  537. def populate_ids(resp: Union[ray_client_pb2.DataResponse, Exception]) -> None:
  538. if isinstance(resp, Exception):
  539. if isinstance(resp, grpc.RpcError):
  540. resp = decode_exception(resp)
  541. for future in id_futures:
  542. future.set_exception(resp)
  543. return
  544. ticket = resp.task_ticket
  545. if not ticket.valid:
  546. try:
  547. ex = cloudpickle.loads(ticket.error)
  548. except (pickle.UnpicklingError, TypeError) as e_new:
  549. ex = e_new
  550. for future in id_futures:
  551. future.set_exception(ex)
  552. return
  553. if len(ticket.return_ids) != num_return_refs:
  554. exc = ValueError(
  555. f"Expected {num_return_refs} returns but received "
  556. f"{len(ticket.return_ids)}"
  557. )
  558. for future, raw_id in zip(id_futures, ticket.return_ids):
  559. future.set_exception(exc)
  560. return
  561. for future, raw_id in zip(id_futures, ticket.return_ids):
  562. future.set_result(raw_id)
  563. self.data_client.Schedule(task, populate_ids)
  564. self.total_outbound_message_size_bytes += task.ByteSize()
  565. if (
  566. self.total_outbound_message_size_bytes > MESSAGE_SIZE_THRESHOLD
  567. and log_once("client_communication_overhead_warning")
  568. ):
  569. warnings.warn(
  570. "More than 10MB of messages have been created to schedule "
  571. "tasks on the server. This can be slow on Ray Client due to "
  572. "communication overhead over the network. If you're running "
  573. "many fine-grained tasks, consider running them inside a "
  574. 'single remote function. See the section on "Too '
  575. 'fine-grained tasks" in the Ray Design Patterns document for '
  576. f"more details: {DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK}. If "
  577. "your functions frequently use large objects, consider "
  578. "storing the objects remotely with ray.put. An example of "
  579. 'this is shown in the "Closure capture of large / '
  580. 'unserializable object" section of the Ray Design Patterns '
  581. "document, available here: "
  582. f"{DESIGN_PATTERN_LARGE_OBJECTS_LINK}",
  583. UserWarning,
  584. )
  585. return id_futures
  586. def call_release(self, id: bytes) -> None:
  587. if self.closed:
  588. return
  589. self.reference_count[id] -= 1
  590. if self.reference_count[id] == 0:
  591. self._release_server(id)
  592. del self.reference_count[id]
  593. def _release_server(self, id: bytes) -> None:
  594. if self.data_client is not None:
  595. logger.debug(f"Put {id.hex()} to release queue")
  596. self._release_queue.put(id)
  597. def _release_server_worker(self):
  598. """Background thread to release objects from the server.
  599. Runs forever until a sentinel is received.
  600. """
  601. while not self.closed:
  602. try:
  603. id = self._release_queue.get(timeout=1)
  604. if id is None: # Sentinel value for shutdown
  605. logger.debug("Received sentinel, will stop release thread.")
  606. break
  607. if self.data_client is not None:
  608. logger.debug(f"Releasing {id.hex()}")
  609. try:
  610. self.data_client.ReleaseObject(
  611. ray_client_pb2.ReleaseRequest(ids=[id])
  612. )
  613. except Exception as e:
  614. # Log the error but continue processing
  615. # This prevents the release thread from crashing
  616. logger.warning(
  617. f"Failed to release object {id.hex()}: {e}. "
  618. "This is expected if the connection is closed."
  619. )
  620. except queue.Empty:
  621. continue
  622. logger.debug("Release thread finished.")
  623. def call_retain(self, id: bytes) -> None:
  624. logger.debug(f"Retaining {id.hex()}")
  625. self.reference_count[id] += 1
  626. def close(self):
  627. self._in_shutdown = True
  628. self._release_queue.put(None) # Sentinel
  629. timeout = 5
  630. self._release_thread.join(timeout=timeout)
  631. if self._release_thread.is_alive():
  632. logger.warning(f"The release thread failed to join in {timeout}s.")
  633. self.closed = True
  634. self.data_client.close()
  635. self.log_client.close()
  636. self.server = None
  637. if self.channel:
  638. self.channel.close()
  639. self.channel = None
  640. def get_actor(
  641. self, name: str, namespace: Optional[str] = None
  642. ) -> ClientActorHandle:
  643. task = ray_client_pb2.ClientTask()
  644. task.type = ray_client_pb2.ClientTask.NAMED_ACTOR
  645. task.name = name
  646. task.namespace = namespace or ""
  647. # Populate task.data with empty args and kwargs
  648. task.data = dumps_from_client(([], {}), self._client_id)
  649. futures = self._call_schedule_for_task(task, 1)
  650. assert len(futures) == 1
  651. handle = ClientActorHandle(ClientActorRef(futures[0], weak_ref=True))
  652. # `actor_ref.is_nil()` waits until the underlying ID is resolved.
  653. # This is needed because `get_actor` is often used to check the
  654. # existence of an actor.
  655. if handle.actor_ref.is_nil():
  656. raise ValueError(f"ActorID for {name} is empty")
  657. return handle
  658. def terminate_actor(self, actor: ClientActorHandle, no_restart: bool) -> None:
  659. if not isinstance(actor, ClientActorHandle):
  660. raise ValueError(
  661. "ray.kill() only supported for actors. Got: {}.".format(type(actor))
  662. )
  663. term_actor = ray_client_pb2.TerminateRequest.ActorTerminate()
  664. term_actor.id = actor.actor_ref.id
  665. term_actor.no_restart = no_restart
  666. term = ray_client_pb2.TerminateRequest(actor=term_actor)
  667. term.client_id = self._client_id
  668. try:
  669. self.data_client.Terminate(term)
  670. except grpc.RpcError as e:
  671. raise decode_exception(e)
  672. def terminate_task(
  673. self, obj: ClientObjectRef, force: bool, recursive: bool
  674. ) -> None:
  675. if not isinstance(obj, ClientObjectRef):
  676. raise TypeError(
  677. "ray.cancel() only supported for non-actor object refs. "
  678. f"Got: {type(obj)}."
  679. )
  680. term_object = ray_client_pb2.TerminateRequest.TaskObjectTerminate()
  681. term_object.id = obj.id
  682. term_object.force = force
  683. term_object.recursive = recursive
  684. term = ray_client_pb2.TerminateRequest(task_object=term_object)
  685. term.client_id = self._client_id
  686. try:
  687. self.data_client.Terminate(term)
  688. except grpc.RpcError as e:
  689. raise decode_exception(e)
  690. def get_cluster_info(
  691. self,
  692. req_type: ray_client_pb2.ClusterInfoType.TypeEnum,
  693. timeout: Optional[float] = None,
  694. ):
  695. req = ray_client_pb2.ClusterInfoRequest()
  696. req.type = req_type
  697. resp = self.server.ClusterInfo(req, timeout=timeout, metadata=self.metadata)
  698. if resp.WhichOneof("response_type") == "resource_table":
  699. # translate from a proto map to a python dict
  700. output_dict = dict(resp.resource_table.table)
  701. return output_dict
  702. elif resp.WhichOneof("response_type") == "runtime_context":
  703. return resp.runtime_context
  704. return json.loads(resp.json)
  705. def internal_kv_get(self, key: bytes, namespace: Optional[bytes]) -> bytes:
  706. req = ray_client_pb2.KVGetRequest(key=key, namespace=namespace)
  707. try:
  708. resp = self._call_stub("KVGet", req, metadata=self.metadata)
  709. except grpc.RpcError as e:
  710. raise decode_exception(e)
  711. if resp.HasField("value"):
  712. return resp.value
  713. # Value is None when the key does not exist in the KV.
  714. return None
  715. def internal_kv_exists(self, key: bytes, namespace: Optional[bytes]) -> bool:
  716. req = ray_client_pb2.KVExistsRequest(key=key, namespace=namespace)
  717. try:
  718. resp = self._call_stub("KVExists", req, metadata=self.metadata)
  719. except grpc.RpcError as e:
  720. raise decode_exception(e)
  721. return resp.exists
  722. def internal_kv_put(
  723. self, key: bytes, value: bytes, overwrite: bool, namespace: Optional[bytes]
  724. ) -> bool:
  725. req = ray_client_pb2.KVPutRequest(
  726. key=key, value=value, overwrite=overwrite, namespace=namespace
  727. )
  728. metadata = self._add_ids_to_metadata(self.metadata)
  729. try:
  730. resp = self._call_stub("KVPut", req, metadata=metadata)
  731. except grpc.RpcError as e:
  732. raise decode_exception(e)
  733. return resp.already_exists
  734. def internal_kv_del(
  735. self, key: bytes, del_by_prefix: bool, namespace: Optional[bytes]
  736. ) -> int:
  737. req = ray_client_pb2.KVDelRequest(
  738. key=key, del_by_prefix=del_by_prefix, namespace=namespace
  739. )
  740. metadata = self._add_ids_to_metadata(self.metadata)
  741. try:
  742. resp = self._call_stub("KVDel", req, metadata=metadata)
  743. except grpc.RpcError as e:
  744. raise decode_exception(e)
  745. return resp.deleted_num
  746. def internal_kv_list(
  747. self, prefix: bytes, namespace: Optional[bytes]
  748. ) -> List[bytes]:
  749. try:
  750. req = ray_client_pb2.KVListRequest(prefix=prefix, namespace=namespace)
  751. return self._call_stub("KVList", req, metadata=self.metadata).keys
  752. except grpc.RpcError as e:
  753. raise decode_exception(e)
  754. def pin_runtime_env_uri(self, uri: str, expiration_s: int) -> None:
  755. req = ray_client_pb2.ClientPinRuntimeEnvURIRequest(
  756. uri=uri, expiration_s=expiration_s
  757. )
  758. self._call_stub("PinRuntimeEnvURI", req, metadata=self.metadata)
  759. def list_named_actors(self, all_namespaces: bool) -> List[Dict[str, str]]:
  760. req = ray_client_pb2.ClientListNamedActorsRequest(all_namespaces=all_namespaces)
  761. return json.loads(self.data_client.ListNamedActors(req).actors_json)
  762. def is_initialized(self) -> bool:
  763. if not self.is_connected() or self.server is None:
  764. return False
  765. if not self._serverside_ray_initialized:
  766. # We only check that Ray is initialized on the server once to
  767. # avoid making an RPC every time this function is called. This is
  768. # safe to do because Ray only 'un-initializes' on the server when
  769. # the Client connection is torn down.
  770. self._serverside_ray_initialized = self.get_cluster_info(
  771. ray_client_pb2.ClusterInfoType.IS_INITIALIZED
  772. )
  773. return self._serverside_ray_initialized
  774. def ping_server(self, timeout=None) -> bool:
  775. """Simple health check.
  776. Piggybacks the IS_INITIALIZED call to check if the server provides
  777. an actual response.
  778. """
  779. if self.server is not None:
  780. logger.debug("Pinging server.")
  781. result = self.get_cluster_info(
  782. ray_client_pb2.ClusterInfoType.PING, timeout=timeout
  783. )
  784. return result is not None
  785. return False
  786. def is_connected(self) -> bool:
  787. return not self._in_shutdown and self._has_connected
  788. def _server_init(
  789. self, job_config: JobConfig, ray_init_kwargs: Optional[Dict[str, Any]] = None
  790. ):
  791. """Initialize the server"""
  792. if ray_init_kwargs is None:
  793. ray_init_kwargs = {}
  794. try:
  795. if job_config is None:
  796. serialized_job_config = None
  797. else:
  798. with tempfile.TemporaryDirectory() as tmp_dir:
  799. from ray._private.ray_constants import (
  800. RAY_RUNTIME_ENV_IGNORE_GITIGNORE,
  801. )
  802. runtime_env = job_config.runtime_env or {}
  803. # Determine whether to respect .gitignore files based on environment variable
  804. # Default is True (respect .gitignore). Set to False if env var is "1".
  805. include_gitignore = (
  806. os.environ.get(RAY_RUNTIME_ENV_IGNORE_GITIGNORE, "0") != "1"
  807. )
  808. runtime_env = upload_py_modules_if_needed(
  809. runtime_env,
  810. scratch_dir=tmp_dir,
  811. include_gitignore=include_gitignore,
  812. logger=logger,
  813. )
  814. runtime_env = upload_working_dir_if_needed(
  815. runtime_env,
  816. scratch_dir=tmp_dir,
  817. include_gitignore=include_gitignore,
  818. logger=logger,
  819. )
  820. # Remove excludes, it isn't relevant after the upload step.
  821. runtime_env.pop("excludes", None)
  822. job_config.set_runtime_env(runtime_env, validate=True)
  823. serialized_job_config = pickle.dumps(job_config)
  824. response = self.data_client.Init(
  825. ray_client_pb2.InitRequest(
  826. job_config=serialized_job_config,
  827. ray_init_kwargs=json.dumps(ray_init_kwargs),
  828. reconnect_grace_period=self._reconnect_grace_period,
  829. )
  830. )
  831. if not response.ok:
  832. raise ConnectionAbortedError(
  833. f"Initialization failure from server:\n{response.msg}"
  834. )
  835. except grpc.RpcError as e:
  836. raise decode_exception(e)
  837. def _convert_actor(self, actor: "ActorClass") -> str:
  838. """Register a ClientActorClass for the ActorClass and return a UUID"""
  839. key = uuid.uuid4().hex
  840. cls = actor.__ray_metadata__.modified_class
  841. self._converted[key] = ClientActorClass(cls, options=actor._default_options)
  842. return key
  843. def _convert_function(self, func: "RemoteFunction") -> str:
  844. """Register a ClientRemoteFunc for the ActorClass and return a UUID"""
  845. key = uuid.uuid4().hex
  846. self._converted[key] = ClientRemoteFunc(
  847. func._function, options=func._default_options
  848. )
  849. return key
  850. def _get_converted(self, key: str) -> "ClientStub":
  851. """Given a UUID, return the converted object"""
  852. return self._converted[key]
  853. def _converted_key_exists(self, key: str) -> bool:
  854. """Check if a key UUID is present in the store of converted objects."""
  855. return key in self._converted
  856. def _dumps_from_client(self, val) -> bytes:
  857. return dumps_from_client(val, self._client_id)
  858. def make_client_id() -> str:
  859. id = uuid.uuid4()
  860. return id.hex
  861. def decode_exception(e: grpc.RpcError) -> Exception:
  862. if e.code() != grpc.StatusCode.ABORTED:
  863. # The ABORTED status code is used by the server when an application
  864. # error is serialized into the exception details. If the code
  865. # isn't ABORTED, then return the original error since there's no
  866. # serialized error to decode.
  867. # See server.py::return_exception_in_context for details
  868. return ConnectionError(f"GRPC connection failed: {e}")
  869. data = base64.standard_b64decode(e.details())
  870. return loads_from_server(data)