common.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954
  1. import inspect
  2. import logging
  3. import os
  4. import pickle
  5. import threading
  6. import uuid
  7. from collections import OrderedDict
  8. from concurrent.futures import Future
  9. from dataclasses import dataclass
  10. from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  11. import grpc
  12. import ray._raylet as raylet
  13. import ray.core.generated.ray_client_pb2 as ray_client_pb2
  14. import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
  15. from ray._common.signature import extract_signature, get_signature
  16. from ray._private import ray_constants
  17. from ray._private.inspect_util import (
  18. is_class_method,
  19. is_cython,
  20. is_function_or_method,
  21. is_static_method,
  22. )
  23. from ray._private.utils import check_oversized_function
  24. from ray.util.client import ray
  25. from ray.util.client.options import validate_options
  26. from ray.util.common import INT32_MAX
  logger = logging.getLogger(__name__)

  # gRPC status codes that the client should not attempt to recover from:
  #   RESOURCE_EXHAUSTED:  the server is low on resources, or has hit the
  #                        maximum number of client connections.
  #   INVALID_ARGUMENT:    reserved for application errors.
  #   NOT_FOUND:           set if the client is attempting to reconnect to a
  #                        session that does not exist.
  #   FAILED_PRECONDITION: reserved for application errors.
  #   ABORTED:             set when an error is serialized into the details of
  #                        the context; signals that the error should be
  #                        deserialized on the client side.
  GRPC_UNRECOVERABLE_ERRORS = (
      grpc.StatusCode.RESOURCE_EXHAUSTED,
      grpc.StatusCode.INVALID_ARGUMENT,
      grpc.StatusCode.NOT_FOUND,
      grpc.StatusCode.FAILED_PRECONDITION,
      grpc.StatusCode.ABORTED,
  )

  # TODO: Rather than just making the max message size large, the right fix is
  # to split the bytes of serialized data into multiple messages and
  # reassemble them on the other end. Since clients are drivers that mostly
  # feed initial inputs in and final results out (when not going through S3 or
  # similar), a large limit suffices for many use cases.
  #
  # 2**31 - 1 bytes (2 GiB minus one byte): the largest signed 32-bit int.
  GRPC_MAX_MESSAGE_SIZE = 2**31 - 1
  # 30 seconds, because the ELB timeout is 60 seconds.
  GRPC_KEEPALIVE_TIME_MS = 30_000
  # Long timeout because we do not want gRPC ending a connection.
  GRPC_KEEPALIVE_TIMEOUT_MS = 600_000
  GRPC_OPTIONS = [
      *ray_constants.GLOBAL_GRPC_OPTIONS,
      ("grpc.max_send_message_length", GRPC_MAX_MESSAGE_SIZE),
      ("grpc.max_receive_message_length", GRPC_MAX_MESSAGE_SIZE),
      ("grpc.keepalive_time_ms", GRPC_KEEPALIVE_TIME_MS),
      ("grpc.keepalive_timeout_ms", GRPC_KEEPALIVE_TIMEOUT_MS),
      ("grpc.keepalive_permit_without_calls", 1),
      # Send an unlimited number of pings.
      ("grpc.http2.max_pings_without_data", 0),
      ("grpc.http2.min_ping_interval_without_data_ms", GRPC_KEEPALIVE_TIME_MS - 50),
      # Allow many strikes.
      ("grpc.http2.max_ping_strikes", 0),
  ]
  # Read from RAY_CLIENT_SERVER_MAX_THREADS (default 100), stored as a float.
  CLIENT_SERVER_MAX_THREADS = float(os.getenv("RAY_CLIENT_SERVER_MAX_THREADS", 100))
  # Large objects are chunked into 5 MiB messages, ref PR #35025.
  OBJECT_TRANSFER_CHUNK_SIZE = 5 * 1024 * 1024
  # Warn the user if the object being transferred is larger than 2 GiB.
  OBJECT_TRANSFER_WARNING_SIZE = 2 * 1024**3
  class ClientObjectRef(raylet.ObjectRef):
      """ObjectRef held by a Ray Client driver, possibly with a pending ID.

      The ID is either passed in directly as ``bytes`` or delivered later by a
      ``Future``. Every accessor that needs the ID first blocks on that future
      (see ``_wait_for_id``). Once the ID is known, the client worker is asked
      to retain it. The destructor asks the worker to release it again.
      """

      def __init__(self, id: Union[bytes, Future]):
          self._mutex = threading.Lock()
          # Assigned before anything below can raise, so that __del__ and
          # _wait_for_id never find the attribute missing.
          self._id_future = None
          self._worker = ray.get_context().client_worker
          if isinstance(id, bytes):
              self._set_id(id)
          elif isinstance(id, Future):
              self._id_future = id
          else:
              raise TypeError("Unexpected type for id {}".format(id))

      def __del__(self):
          # __init__ may have raised before _worker was assigned. Python still
          # runs __del__ on such partially-initialized objects, so don't assume
          # the attribute exists.
          worker = getattr(self, "_worker", None)
          if worker is not None and worker.is_connected():
              try:
                  if not self.is_nil():
                      worker.call_release(self.id)
              except Exception:
                  logger.info(
                      "Exception in ObjectRef is ignored in destructor. "
                      "To receive this exception in application code, call "
                      "a method on the object reference before its destructor "
                      "is run."
                  )

      def binary(self):
          self._wait_for_id()
          return super().binary()

      def hex(self):
          self._wait_for_id()
          return super().hex()

      def is_nil(self):
          self._wait_for_id()
          return super().is_nil()

      def __hash__(self):
          self._wait_for_id()
          return hash(self.id)

      def task_id(self):
          self._wait_for_id()
          return super().task_id()

      @property
      def id(self):
          """The binary ID, blocking until it is known."""
          return self.binary()

      def future(self) -> Future:
          """Return a ``concurrent.futures.Future`` for this object's value.

          The future is completed with the deserialized value, or with the
          exception if the callback receives an ``Exception``.
          """
          fut = Future()

          def set_future(data: Any) -> None:
              """Schedules a callback to set the exception or result
              in the Future."""
              if isinstance(data, Exception):
                  fut.set_exception(data)
              else:
                  fut.set_result(data)

          self._on_completed(set_future)
          # Prevent this object ref from being released.
          fut.object_ref = self
          return fut

      def _on_completed(self, py_callback: Callable[[Any], None]) -> None:
          """Register a callback that will be called after Object is ready.
          If the ObjectRef is already ready, the callback will be called soon.
          The callback should take the result as the only argument. The result
          can be an exception object in case of task error.
          """

          def deserialize_obj(
              resp: Union[ray_client_pb2.DataResponse, Exception]
          ) -> None:
              from ray.util.client.client_pickler import loads_from_server

              if isinstance(resp, Exception):
                  data = resp
              elif isinstance(resp, bytearray):
                  data = loads_from_server(resp)
              else:
                  get_resp = resp.get
                  # An invalid response carries a serialized error instead
                  # of data.
                  payload = get_resp.data if get_resp.valid else get_resp.error
                  data = loads_from_server(payload)
              py_callback(data)

          self._worker.register_callback(self, deserialize_obj)

      def _set_id(self, id):
          super()._set_id(id)
          self._worker.call_retain(id)

      def _wait_for_id(self, timeout=None):
          # Check, lock, re-check: the unlocked test keeps the already-resolved
          # path lock-free. The second test under the lock makes sure only one
          # thread calls _set_id with the future's result.
          if self._id_future:
              with self._mutex:
                  if self._id_future:
                      self._set_id(self._id_future.result(timeout=timeout))
                      self._id_future = None
  class ClientActorRef(raylet.ActorID):
      """Client-side ActorID whose value may still be pending.

      The ID is given as ``bytes`` or arrives later through a ``Future`` and
      is resolved lazily by the accessors. Unless ``weak_ref`` is true, the
      destructor asks the client worker to release the ID.
      """

      def __init__(
          self,
          id: Union[bytes, Future],
          weak_ref: Optional[bool] = False,
      ):
          self._weak_ref = weak_ref
          self._mutex = threading.Lock()
          # Initialized up front, as in ClientObjectRef. Previously it was only
          # set after a successful _set_id or for a Future id, so a bad id type
          # or a failing _set_id left the object without _id_future.
          self._id_future = None
          self._worker = ray.get_context().client_worker
          if isinstance(id, bytes):
              self._set_id(id)
          elif isinstance(id, Future):
              self._id_future = id
          else:
              raise TypeError("Unexpected type for id {}".format(id))

      def __del__(self):
          if self._weak_ref:
              return
          # __init__ may have raised before _worker was assigned.
          worker = getattr(self, "_worker", None)
          if worker is not None and worker.is_connected():
              try:
                  if not self.is_nil():
                      worker.call_release(self.id)
              except Exception:
                  logger.debug(
                      "Exception from actor creation is ignored in destructor. "
                      "To receive this exception in application code, call "
                      "a method on the actor reference before its destructor "
                      "is run."
                  )

      def binary(self):
          self._wait_for_id()
          return super().binary()

      def hex(self):
          self._wait_for_id()
          return super().hex()

      def is_nil(self):
          self._wait_for_id()
          return super().is_nil()

      def __hash__(self):
          self._wait_for_id()
          return hash(self.id)

      @property
      def id(self):
          """The binary ID, blocking until it is known."""
          return self.binary()

      def _set_id(self, id):
          # NOTE(review): the ID is retained even when weak_ref is true, but
          # __del__ never releases weak refs; confirm this asymmetry is intended.
          super()._set_id(id)
          self._worker.call_retain(id)

      def _wait_for_id(self, timeout=None):
          # Check, lock, re-check so that only one thread calls _set_id.
          if self._id_future:
              with self._mutex:
                  if self._id_future:
                      self._set_id(self._id_future.result(timeout=timeout))
                      self._id_future = None
  class ClientStub:
      """Common base class of the Ray Client stub types.

      Subclassed by the stubs for remote functions, actor classes, actor
      handles and actor methods. It defines no behavior of its own.
      """
  class ClientRemoteFunc(ClientStub):
      """A stub created on the Ray Client to represent a remote
      function that can be executed on the cluster.

      This class is allowed to be passed around between remote functions.

      Args:
          f: The function to execute remotely.
          options: Default task options, passed through ``validate_options``.

      Attributes:
          _func: The actual function to execute remotely.
          _name: The original name of the function.
          _ref: The ClientObjectRef of the pickled code of the function,
              _func. It is None until first use, and an InProgressSentinel
              while the upload is running.
      """

      def __init__(self, f, options=None):
          self._lock = threading.Lock()
          self._func = f
          self._name = f.__name__
          self._signature = get_signature(f)
          self._ref = None
          self._client_side_ref = ClientSideRefID.generate_id()
          self._options = validate_options(options)

      def __call__(self, *args, **kwargs):
          raise TypeError(
              "Remote function cannot be called directly. "
              f"Use {self._name}.remote method instead"
          )

      def remote(self, *args, **kwargs):
          # Check if supplied parameters match the function signature. Same case
          # at the other callsites.
          self._signature.bind(*args, **kwargs)
          return return_refs(ray.call_remote(self, *args, **kwargs))

      def options(self, **kwargs):
          return OptionWrapper(self, kwargs)

      def _remote(self, args=None, kwargs=None, **option_args):
          if args is None:
              args = []
          if kwargs is None:
              kwargs = {}
          return self.options(**option_args).remote(*args, **kwargs)

      def __repr__(self):
          return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref)

      def _ensure_ref(self):
          """Pickle and upload ``_func`` once, storing the result in ``_ref``."""
          with self._lock:
              if self._ref is not None:
                  return
              # While calling ray.put() on our function, if our function is
              # recursive, it will attempt to encode the ClientRemoteFunc --
              # itself -- and infinitely recurse on _ensure_ref.
              #
              # So we set the state of the reference to be an in-progress self
              # reference value, which the encoding can detect and handle
              # correctly.
              self._ref = InProgressSentinel()
              try:
                  data = ray.worker._dumps_from_client(self._func)
                  # Check pickled size before sending it to server, which is
                  # more efficient and can be done synchronously inside
                  # remote() call.
                  check_oversized_function(data, self._name, "remote function", None)
                  self._ref = ray.worker._put_pickled(
                      data, client_ref_id=self._client_side_ref.id
                  )
              except BaseException:
                  # Don't leave the sentinel behind. Otherwise every later call
                  # would skip the upload and fail on ``self._ref.id``
                  # (InProgressSentinel has no ``id``) instead of retrying.
                  self._ref = None
                  raise

      def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
          self._ensure_ref()
          task = ray_client_pb2.ClientTask()
          task.type = ray_client_pb2.ClientTask.FUNCTION
          task.name = self._name
          task.payload_id = self._ref.id
          set_task_options(task, self._options, "baseline_options")
          return task

      def _num_returns(self) -> Optional[int]:
          """Return the ``num_returns`` option, or None if it was not given."""
          if not self._options:
              return None
          return self._options.get("num_returns")
  class ClientActorClass(ClientStub):
      """A stub created on the Ray Client to represent an actor class.

      It is wrapped by ray.remote and can be instantiated on the cluster with
      ``.remote()``.

      Args:
          actor_cls: The actual class to execute remotely.
          options: Default actor options, passed through ``validate_options``.

      Attributes:
          _name: The original name of the class.
          _ref: The ClientObjectRef of the pickled ``actor_cls``. It is None
              until first use, and an InProgressSentinel while the upload is
              running.
      """

      def __init__(self, actor_cls, options=None):
          self.actor_cls = actor_cls
          self._lock = threading.Lock()
          self._name = actor_cls.__name__
          # Signature of __init__ without its first parameter; remote() checks
          # constructor arguments against it.
          self._init_signature = inspect.Signature(
              parameters=extract_signature(actor_cls.__init__, ignore_first=True)
          )
          self._ref = None
          self._client_side_ref = ClientSideRefID.generate_id()
          self._options = validate_options(options)

      def __call__(self, *args, **kwargs):
          raise TypeError(
              "Remote actor cannot be instantiated directly. "
              f"Use {self._name}.remote() instead"
          )

      def _ensure_ref(self):
          """Pickle and upload ``actor_cls`` once, storing the result in ``_ref``."""
          with self._lock:
              if self._ref is not None:
                  return
              # Set the state of the reference to be an in-progress self
              # reference value, which the encoding can detect and handle
              # correctly (guards against recursion while pickling).
              self._ref = InProgressSentinel()
              try:
                  data = ray.worker._dumps_from_client(self.actor_cls)
                  # Check pickled size before sending it to server, which is
                  # more efficient and can be done synchronously inside
                  # remote() call.
                  check_oversized_function(data, self._name, "actor", None)
                  self._ref = ray.worker._put_pickled(
                      data, client_ref_id=self._client_side_ref.id
                  )
              except BaseException:
                  # Don't leave the sentinel behind. Otherwise every later call
                  # would skip the upload and fail on ``self._ref.id`` instead
                  # of retrying.
                  self._ref = None
                  raise

      def remote(self, *args, **kwargs) -> "ClientActorHandle":
          self._init_signature.bind(*args, **kwargs)
          # Actually instantiate the actor
          futures = ray.call_remote(self, *args, **kwargs)
          assert len(futures) == 1
          return ClientActorHandle(ClientActorRef(futures[0]), actor_class=self)

      def options(self, **kwargs):
          return ActorOptionWrapper(self, kwargs)

      def _remote(self, args=None, kwargs=None, **option_args):
          if args is None:
              args = []
          if kwargs is None:
              kwargs = {}
          return self.options(**option_args).remote(*args, **kwargs)

      def __repr__(self):
          return "ClientActorClass(%s, %s)" % (self._name, self._ref)

      def __getattr__(self, key):
          # NOTE(review): __getattr__ only runs after normal lookup fails, so
          # ``key`` is normally absent from __dict__. The NotImplementedError
          # below looks unreachable in practice.
          if key not in self.__dict__:
              raise AttributeError("Not a class attribute")
          raise NotImplementedError("static methods")

      def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
          self._ensure_ref()
          task = ray_client_pb2.ClientTask()
          task.type = ray_client_pb2.ClientTask.ACTOR
          task.name = self._name
          task.payload_id = self._ref.id
          set_task_options(task, self._options, "baseline_options")
          return task

      @staticmethod
      def _num_returns() -> int:
          # Creating an actor yields exactly one handle.
          return 1
  class ClientActorHandle(ClientStub):
      """Client-side stub for instantiated actor.

      A stub created on the Ray Client to represent a remote actor that
      has been started on the cluster. This class is allowed to be passed
      around between remote functions.

      Args:
          actor_ref: A reference to the running actor given to the client. This
              is a serialized version of the actual handle as an opaque token.
          actor_class: The ClientActorClass the actor was created from, if
              known. When given, method num_returns and signatures are read
              from the class locally. Otherwise they are fetched from the
              cluster on first use (see ``_init_class_info``).
      """

      def __init__(
          self,
          actor_ref: ClientActorRef,
          actor_class: Optional[ClientActorClass] = None,
      ):
          self.actor_ref = actor_ref
          self._dir: Optional[List[str]] = None
          if actor_class is not None:
              self._method_num_returns = {}
              self._method_signatures = {}
              for method_name, method_obj in inspect.getmembers(
                  actor_class.actor_cls, is_function_or_method
              ):
                  self._method_num_returns[method_name] = getattr(
                      method_obj, "__ray_num_returns__", None
                  )
                  # Drop the leading parameter except for class and static
                  # methods.
                  self._method_signatures[method_name] = inspect.Signature(
                      parameters=extract_signature(
                          method_obj,
                          ignore_first=(
                              not (
                                  is_class_method(method_obj)
                                  or is_static_method(actor_class.actor_cls, method_name)
                              )
                          ),
                      )
                  )
          else:
              self._method_num_returns = None
              self._method_signatures = None

      def __dir__(self) -> List[str]:
          """List the actor's method names, fetching them first if needed."""
          if self._method_num_returns is not None:
              return list(self._method_num_returns)
          if ray.is_connected():
              self._init_class_info()
              return list(self._method_num_returns)
          return super().__dir__()

      # For compatibility with core worker ActorHandle._actor_id which returns
      # ActorID
      @property
      def _actor_id(self) -> ClientActorRef:
          return self.actor_ref

      def __hash__(self) -> int:
          return hash(self._actor_id)

      def __eq__(self, __value) -> bool:
          # NOTE(review): equality is defined via hash values, so any object
          # whose hash collides with this handle's compares equal. Confirm this
          # is intended.
          return hash(self) == hash(__value)

      def __getattr__(self, key):
          if key in ("_method_num_returns", "_method_signatures"):
              # Both are read below. If either is missing (e.g. after
              # unpickling, before they are populated), looking it up would
              # re-enter __getattr__ and recurse forever.
              raise AttributeError(f"ClientActorHandle has no attribute '{key}'")
          if self._method_num_returns is None:
              self._init_class_info()
          if key not in self._method_signatures:
              raise AttributeError(f"ClientActorHandle has no attribute '{key}'")
          return ClientRemoteMethod(
              self,
              key,
              self._method_num_returns.get(key),
              self._method_signatures.get(key),
          )

      def __repr__(self):
          return "ClientActorHandle(%s)" % (self.actor_ref.id.hex())

      def _init_class_info(self):
          """Fetch per-method num_returns and signatures from the cluster.

          Runs a ``num_cpus=0`` remote task that receives this handle and
          returns its ``_ray_method_num_returns`` and
          ``_ray_method_signatures``.
          """
          # TODO: fetch Ray method decorators
          @ray.remote(num_cpus=0)
          def get_class_info(x):
              return x._ray_method_num_returns, x._ray_method_signatures

          self._method_num_returns, method_parameters = ray.get(
              get_class_info.remote(self)
          )
          self._method_signatures = {}
          for method, parameters in method_parameters.items():
              self._method_signatures[method] = inspect.Signature(parameters=parameters)
  class ClientRemoteMethod(ClientStub):
      """Stub for one method of a remote actor; accepts execution options.

      Args:
          actor_handle: The ClientActorHandle that produced this stub. The
              method is invoked on the actor it refers to.
          method_name: Name of the method.
          num_returns: Recorded number of return values (may be None).
          signature: Signature that ``remote()`` validates arguments against.
      """

      def __init__(
          self,
          actor_handle: ClientActorHandle,
          method_name: str,
          num_returns: int,
          signature: inspect.Signature,
      ):
          self._actor_handle = actor_handle
          self._method_name = method_name
          self._method_num_returns = num_returns
          self._signature = signature

      def __call__(self, *args, **kwargs):
          name = self._method_name
          raise TypeError(
              "Actor methods cannot be called directly. Instead "
              f"of running 'object.{name}()', try "
              f"'object.{name}.remote()'."
          )

      def remote(self, *args, **kwargs):
          # Reject mismatched arguments locally before anything is submitted.
          self._signature.bind(*args, **kwargs)
          futures = ray.call_remote(self, *args, **kwargs)
          return return_refs(futures)

      def __repr__(self):
          parts = (self._method_name, self._actor_handle, self._method_num_returns)
          return "ClientRemoteMethod({}, {}, {})".format(*map(str, parts))

      def options(self, **kwargs):
          return OptionWrapper(self, kwargs)

      def _remote(self, args=None, kwargs=None, **option_args):
          call_args = [] if args is None else args
          call_kwargs = {} if kwargs is None else kwargs
          return self.options(**option_args).remote(*call_args, **call_kwargs)

      def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
          task = ray_client_pb2.ClientTask()
          task.type = ray_client_pb2.ClientTask.METHOD
          task.name = self._method_name
          # Method tasks carry the actor's ID as their payload.
          task.payload_id = self._actor_handle.actor_ref.id
          return task

      def _num_returns(self) -> int:
          return self._method_num_returns
  class OptionWrapper:
      """Attach per-call options to a remote function or actor-method stub.

      ``remote()`` validates arguments against the wrapped stub's signature
      and submits the call through this wrapper. The task is built by the
      stub, and these options are then written to its ``options`` field. Any
      other attribute is looked up on the wrapped stub.
      """

      def __init__(self, stub: ClientStub, options: Optional[Dict[str, Any]]):
          self._remote_stub = stub
          self._options = validate_options(options)

      def remote(self, *args, **kwargs):
          self._remote_stub._signature.bind(*args, **kwargs)
          return return_refs(ray.call_remote(self, *args, **kwargs))

      def __getattr__(self, key):
          # __getattr__ only runs when normal lookup fails. If _remote_stub is
          # itself missing (e.g. an instance created by copy/pickle before its
          # state is restored), forwarding would re-enter __getattr__ until a
          # RecursionError. Raise a plain AttributeError instead.
          if key == "_remote_stub":
              raise AttributeError(key)
          return getattr(self._remote_stub, key)

      def _prepare_client_task(self):
          task = self._remote_stub._prepare_client_task()
          set_task_options(task, self._options)
          return task

      def _num_returns(self) -> int:
          # An explicit num_returns option overrides the stub's own value.
          if self._options:
              num = self._options.get("num_returns")
              if num is not None:
                  return num
          return self._remote_stub._num_returns()
  class ActorOptionWrapper(OptionWrapper):
      """OptionWrapper whose ``remote()`` creates an actor and returns its handle."""

      def remote(self, *args, **kwargs):
          stub = self._remote_stub
          # Validate the constructor arguments locally first.
          stub._init_signature.bind(*args, **kwargs)
          futures = ray.call_remote(self, *args, **kwargs)
          assert len(futures) == 1
          # Only a real ClientActorClass can supply method metadata locally.
          actor_class = stub if isinstance(stub, ClientActorClass) else None
          return ClientActorHandle(ClientActorRef(futures[0]), actor_class=actor_class)
  def set_task_options(
      task: ray_client_pb2.ClientTask,
      options: Optional[Dict[str, Any]],
      field: str = "options",
  ) -> None:
      """Store ``options`` in ``task.<field>`` as a pickle, or clear the field.

      Args:
          task: The task message to modify in place.
          options: Options to pickle into ``<field>.pickled_options``. If None,
              the field is cleared instead (an empty dict is still pickled).
          field: Name of the options sub-message on ``task``.
      """
      if options is not None:
          target = getattr(task, field)
          target.pickled_options = pickle.dumps(options)
      else:
          task.ClearField(field)
  def return_refs(
      futures: List[Future],
  ) -> Union[None, ClientObjectRef, List[ClientObjectRef]]:
      """Wrap the futures of a remote call in ClientObjectRefs.

      Returns None for no futures, a single ClientObjectRef for exactly one,
      and a list of ClientObjectRefs otherwise.
      """
      if not futures:
          return None
      if len(futures) != 1:
          return [ClientObjectRef(pending) for pending in futures]
      return ClientObjectRef(futures[0])
  class InProgressSentinel:
      """Placeholder stored in a stub's ``_ref`` while its upload is running."""

      def __repr__(self) -> str:
          return type(self).__name__
  class ClientSideRefID:
      """An ID generated by the client for objects not yet given an ObjectRef"""

      def __init__(self, id: bytes):
          assert len(id) > 0
          self.id = id

      @staticmethod
      def generate_id() -> "ClientSideRefID":
          """Create a fresh ID: a 0xCC marker byte followed by 16 random UUID4 bytes."""
          return ClientSideRefID(b"\xcc" + uuid.uuid4().bytes)
  def remote_decorator(options: Optional[Dict[str, Any]]):
      """Build the client-side ``@ray.remote`` decorator for ``options``.

      The returned decorator wraps functions (including Cython ones) in a
      ClientRemoteFunc and classes in a ClientActorClass. Anything else raises
      TypeError.
      """

      def decorator(function_or_class) -> ClientStub:
          looks_like_function = inspect.isfunction(function_or_class) or is_cython(
              function_or_class
          )
          if looks_like_function:
              return ClientRemoteFunc(function_or_class, options=options)
          if inspect.isclass(function_or_class):
              return ClientActorClass(function_or_class, options=options)
          raise TypeError(
              "The @ray.remote decorator must be applied to "
              "either a function or to a class."
          )

      return decorator
  @dataclass
  class ClientServerHandle:
      """Holds the handles to the registered gRPC servicers and their server.

      Attribute lookups that fail on this object are forwarded to
      ``grpc_server``, so it can stand in where a bare gRPC server was
      expected.
      """

      task_servicer: ray_client_pb2_grpc.RayletDriverServicer
      data_servicer: ray_client_pb2_grpc.RayletDataStreamerServicer
      logs_servicer: ray_client_pb2_grpc.RayletLogStreamerServicer
      grpc_server: grpc.Server

      def stop(self, grace: int) -> None:
          """Stop the gRPC server, then wake the data servicer.

          Args:
              grace: Passed straight through to ``grpc_server.stop``.
          """
          # The data servicer might be sleeping while waiting for clients to
          # reconnect. Signal that they no longer have to sleep and can exit
          # immediately, since the RPC server is stopped.
          self.grpc_server.stop(grace)
          self.data_servicer.stopped.set()

      # Add a hook for all the cases that previously
      # expected simply a gRPC server
      def __getattr__(self, attr):
          # __getattr__ only runs when normal lookup fails. If grpc_server is
          # itself missing (e.g. an instance created by copy/pickle before its
          # state is restored), forwarding would re-enter __getattr__ until a
          # RecursionError. Fail with AttributeError instead.
          if attr == "grpc_server":
              raise AttributeError(attr)
          return getattr(self.grpc_server, attr)
  574. def _get_client_id_from_context(context: Any) -> str:
  575. """
  576. Get `client_id` from gRPC metadata. If the `client_id` is not present,
  577. this function logs an error and sets the status_code.
  578. """
  579. metadata = dict(context.invocation_metadata())
  580. client_id = metadata.get("client_id") or ""
  581. if client_id == "":
  582. logger.error("Client connecting with no client_id")
  583. context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
  584. return client_id
  def _propagate_error_in_context(e: Exception, context: Any) -> bool:
      """Encode ``e`` into the RPC ``context`` and report if it is recoverable.

      A gRPC error has its own code and details copied over. It counts as
      recoverable unless its code is in GRPC_UNRECOVERABLE_ERRORS. Any other
      error, or a gRPC error whose status cannot be copied, is encoded as
      FAILED_PRECONDITION with ``str(e)`` as details.

      Returns:
          True if the error can be recovered from, False otherwise.
      """
      try:
          if isinstance(e, grpc.RpcError):
              context.set_code(e.code())
              context.set_details(e.details())
              recoverable = e.code() not in GRPC_UNRECOVERABLE_ERRORS
              return recoverable
      except Exception:
          # Extra precaution: if copying the RPC status fails, fall through and
          # treat it like any other error.
          pass
      context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
      context.set_details(str(e))
      return False
  def _id_is_newer(id1: int, id2: int) -> bool:
      """Return True if request id ``id1`` is newer than ``id2``.

      Cache entries should only be replaced by responses for newer IDs.
      Normally the larger ID is newer, but the req_id counter can roll over.
      When the two IDs are more than ``INT32_MAX // 2`` apart, a rollover is
      assumed and the smaller ID is treated as the newer one. Equal IDs are
      never newer than each other.
      """
      # Int32 max is also the maximum number of simultaneous in-flight requests.
      rolled_over = abs(id2 - id1) > INT32_MAX // 2
      return id1 < id2 if rolled_over else id1 > id2
  class ResponseCache:
      """
      Cache for blocking method calls. Needed to prevent retried requests from
      being applied multiple times on the server, for example when the client
      disconnects. This is used to cache requests/responses sent through
      unary-unary RPCs to the RayletServicer.

      Note that no clean up logic is used. The last response for each thread
      will always be remembered, so at most the cache will hold N entries,
      where N is the number of threads on the client side. This relies on the
      assumption that a thread will not make a new blocking request until it has
      received a response for a previous one, at which point it's safe to
      overwrite the old response.

      The high level logic is:

      1. Before making a call, check the cache for the current thread.
      2. If present in the cache, check the request id of the cached
         response.
         a. If it matches the current request_id, then the request has been
            received before and we shouldn't re-attempt the logic. Wait for
            the response to become available in the cache, and then return it.
         b. If it doesn't match, then this is a new request and we can
            proceed with calling the real stub. While the response is still
            being generated, temporarily keep (req_id, None) in the cache.
            Once the call is finished, update the cache entry with the
            new (req_id, response) pair. Notify other threads that may
            have been waiting for the response to be prepared.
      """

      def __init__(self):
          self.cv = threading.Condition()
          # thread_id -> (request_id, response). The response is None while
          # the request is still being processed.
          self.cache: Dict[int, Tuple[int, Any]] = {}

      def check_cache(self, thread_id: int, request_id: int) -> Optional[Any]:
          """
          Check the cache for a given thread, and see if the entry in the cache
          matches the current request_id. Returns None if the request_id has
          not been seen yet (and records a (request_id, None) placeholder),
          otherwise returns the cached result.

          Raises RuntimeError if the placeholder in the cache doesn't match the
          request_id. This means that a new request evicted the old value in
          the cache, and that the RPC for `request_id` is redundant and the
          result can be discarded, i.e.:

          1. Request A is sent (A1)
          2. Channel disconnects
          3. Request A is resent (A2)
          4. A1 is received
          5. A2 is received, waits for A1 to finish
          6. A1 finishes and is sent back to client
          7. Request B is sent
          8. Request B overwrites cache entry
          9. A2 wakes up extremely late, but cache is now invalid

          In practice this is VERY unlikely to happen, but the error can at
          least serve as a sanity check or catch invalid request id's. A
          RuntimeError is also raised if `request_id` is older than the cached
          one.
          """
          with self.cv:
              if thread_id in self.cache:
                  cached_request_id, cached_resp = self.cache[thread_id]
                  if cached_request_id == request_id:
                      while cached_resp is None:
                          # The call was started, but the response hasn't yet
                          # been added to the cache. Let go of the lock and
                          # wait until the response is ready.
                          self.cv.wait()
                          cached_request_id, cached_resp = self.cache[thread_id]
                          if cached_request_id != request_id:
                              raise RuntimeError(
                                  "Cached response doesn't match the id of the "
                                  "original request. This might happen if this "
                                  "request was received out of order. The "
                                  "result of the caller is no longer needed. "
                                  f"({request_id} != {cached_request_id})"
                              )
                      return cached_resp
                  if not _id_is_newer(request_id, cached_request_id):
                      raise RuntimeError(
                          "Attempting to replace newer cache entry with older "
                          "one. This might happen if this request was received "
                          "out of order. The result of the caller is no "
                          f"longer needed. ({request_id} != {cached_request_id})"
                      )
              # New request for this thread: reserve the slot until
              # update_cache stores the response.
              self.cache[thread_id] = (request_id, None)
              return None

      def update_cache(self, thread_id: int, request_id: int, response: Any) -> None:
          """
          Inserts `response` into the cache for `request_id` and wakes any
          threads waiting in check_cache. Raises RuntimeError if the slot no
          longer holds this request's placeholder.
          """
          with self.cv:
              cached_request_id, cached_resp = self.cache[thread_id]
              if cached_request_id != request_id or cached_resp is not None:
                  # The cache was overwritten by a newer requester between
                  # our call to check_cache and our call to update it.
                  # This can't happen if the assumption that the cached requests
                  # are all blocking on the client side, so if you encounter
                  # this, check if any async requests are being cached.
                  raise RuntimeError(
                      "Attempting to update the cache, but placeholder's "
                      "do not match the current request_id. This might happen "
                      "if this request was received out of order. The result "
                      f"of the caller is no longer needed. ({request_id} != "
                      f"{cached_request_id})"
                  )
              self.cache[thread_id] = (request_id, response)
              self.cv.notify_all()
  717. class OrderedResponseCache:
  718. """
  719. Cache for streaming RPCs, i.e. the DataServicer. Relies on explicit
  720. ack's from the client to determine when it can clean up cache entries.
  721. """
  722. def __init__(self):
  723. self.last_received = 0
  724. self.cv = threading.Condition()
  725. self.cache: Dict[int, Any] = OrderedDict()
  726. def check_cache(self, req_id: int) -> Optional[Any]:
  727. """
  728. Check the cache for a given thread, and see if the entry in the cache
  729. matches the current request_id. Returns None if the request_id has
  730. not been seen yet, otherwise returns the cached result.
  731. """
  732. with self.cv:
  733. if _id_is_newer(self.last_received, req_id) or self.last_received == req_id:
  734. # Request is for an id that has already been cleared from
  735. # cache/acknowledged.
  736. raise RuntimeError(
  737. "Attempting to accesss a cache entry that has already "
  738. "cleaned up. The client has already acknowledged "
  739. f"receiving this response. ({req_id}, "
  740. f"{self.last_received})"
  741. )
  742. if req_id in self.cache:
  743. cached_resp = self.cache[req_id]
  744. while cached_resp is None:
  745. # The call was started, but the response hasn't yet been
  746. # added to the cache. Let go of the lock and wait until
  747. # the response is ready
  748. self.cv.wait()
  749. if req_id not in self.cache:
  750. raise RuntimeError(
  751. "Cache entry was removed. This likely means that "
  752. "the result of this call is no longer needed."
  753. )
  754. cached_resp = self.cache[req_id]
  755. return cached_resp
  756. self.cache[req_id] = None
  757. return None
  758. def update_cache(self, req_id: int, resp: Any) -> None:
  759. """
  760. Inserts `response` into the cache for `request_id`.
  761. """
  762. with self.cv:
  763. self.cv.notify_all()
  764. if req_id not in self.cache:
  765. raise RuntimeError(
  766. "Attempting to update the cache, but placeholder is "
  767. "missing. This might happen on a redundant call to "
  768. f"update_cache. ({req_id})"
  769. )
  770. self.cache[req_id] = resp
  771. def invalidate(self, e: Exception) -> bool:
  772. """
  773. Invalidate any partially populated cache entries, replacing their
  774. placeholders with the passed in exception. Useful to prevent a thread
  775. from waiting indefinitely on a failed call.
  776. Returns True if the cache contains an error, False otherwise
  777. """
  778. with self.cv:
  779. invalid = False
  780. for req_id in self.cache:
  781. if self.cache[req_id] is None:
  782. self.cache[req_id] = e
  783. if isinstance(self.cache[req_id], Exception):
  784. invalid = True
  785. self.cv.notify_all()
  786. return invalid
  787. def cleanup(self, last_received: int) -> None:
  788. """
  789. Cleanup all of the cached requests up to last_received. Assumes that
  790. the cache entries were inserted in ascending order.
  791. """
  792. with self.cv:
  793. if _id_is_newer(last_received, self.last_received):
  794. self.last_received = last_received
  795. to_remove = []
  796. for req_id in self.cache:
  797. if _id_is_newer(last_received, req_id) or last_received == req_id:
  798. to_remove.append(req_id)
  799. else:
  800. break
  801. for req_id in to_remove:
  802. del self.cache[req_id]
  803. self.cv.notify_all()