server.py 38 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962
  1. import base64
  2. import functools
  3. import gc
  4. import inspect
  5. import json
  6. import logging
  7. import math
  8. import pickle
  9. import queue
  10. import threading
  11. import time
  12. from collections import defaultdict
  13. from typing import Any, Callable, Dict, List, Optional, Set, Union
  14. import grpc
  15. import ray
  16. import ray._private.state
  17. import ray.core.generated.ray_client_pb2 as ray_client_pb2
  18. import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
  19. from ray import cloudpickle
  20. from ray._common.network_utils import build_address, is_localhost
  21. from ray._private import ray_constants
  22. from ray._private.client_mode_hook import disable_client_hook
  23. from ray._private.ray_constants import env_integer
  24. from ray._private.ray_logging import setup_logger
  25. from ray._private.services import canonicalize_bootstrap_address_or_die
  26. from ray._private.tls_utils import add_port_to_grpc_server
  27. from ray._raylet import GcsClient
  28. from ray.job_config import JobConfig
  29. from ray.util.client.common import (
  30. CLIENT_SERVER_MAX_THREADS,
  31. GRPC_OPTIONS,
  32. OBJECT_TRANSFER_CHUNK_SIZE,
  33. ClientServerHandle,
  34. ResponseCache,
  35. )
  36. from ray.util.client.server.dataservicer import DataServicer
  37. from ray.util.client.server.logservicer import LogstreamServicer
  38. from ray.util.client.server.proxier import serve_proxier
  39. from ray.util.client.server.server_pickler import dumps_from_server, loads_from_client
  40. from ray.util.client.server.server_stubs import current_server
  41. logger = logging.getLogger(__name__)
  42. TIMEOUT_FOR_SPECIFIC_SERVER_S = env_integer("TIMEOUT_FOR_SPECIFIC_SERVER_S", 30)
  43. def _use_response_cache(func):
  44. """
  45. Decorator for gRPC stubs. Before calling the real stubs, checks if there's
  46. an existing entry in the caches. If there is, then return the cached
  47. entry. Otherwise, call the real function and use the real cache
  48. """
  49. @functools.wraps(func)
  50. def wrapper(self, request, context):
  51. metadata = dict(context.invocation_metadata())
  52. expected_ids = ("client_id", "thread_id", "req_id")
  53. if any(i not in metadata for i in expected_ids):
  54. # Missing IDs, skip caching and call underlying stub directly
  55. return func(self, request, context)
  56. # Get relevant IDs to check cache
  57. client_id = metadata["client_id"]
  58. thread_id = metadata["thread_id"]
  59. req_id = int(metadata["req_id"])
  60. # Check if response already cached
  61. response_cache = self.response_caches[client_id]
  62. cached_entry = response_cache.check_cache(thread_id, req_id)
  63. if cached_entry is not None:
  64. if isinstance(cached_entry, Exception):
  65. # Original call errored, propogate error
  66. context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
  67. context.set_details(str(cached_entry))
  68. raise cached_entry
  69. return cached_entry
  70. try:
  71. # Response wasn't cached, call underlying stub and cache result
  72. resp = func(self, request, context)
  73. except Exception as e:
  74. # Unexpected error in underlying stub -- update cache and
  75. # propagate to user through context
  76. response_cache.update_cache(thread_id, req_id, e)
  77. context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
  78. context.set_details(str(e))
  79. raise
  80. response_cache.update_cache(thread_id, req_id, resp)
  81. return resp
  82. return wrapper
  83. class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
  84. def __init__(self, ray_connect_handler: Callable):
  85. """Construct a raylet service
  86. Args:
  87. ray_connect_handler: Function to connect to ray cluster
  88. """
  89. # Stores client_id -> (ref_id -> ObjectRef)
  90. self.object_refs: Dict[str, Dict[bytes, ray.ObjectRef]] = defaultdict(dict)
  91. # Stores client_id -> (client_ref_id -> ref_id (in self.object_refs))
  92. self.client_side_ref_map: Dict[str, Dict[bytes, bytes]] = defaultdict(dict)
  93. self.function_refs = {}
  94. self.actor_refs: Dict[bytes, ray.ActorHandle] = {}
  95. self.actor_owners: Dict[str, Set[bytes]] = defaultdict(set)
  96. self.registered_actor_classes = {}
  97. self.named_actors = set()
  98. self.state_lock = threading.Lock()
  99. self.ray_connect_handler = ray_connect_handler
  100. self.response_caches: Dict[str, ResponseCache] = defaultdict(ResponseCache)
  101. def Init(
  102. self, request: ray_client_pb2.InitRequest, context=None
  103. ) -> ray_client_pb2.InitResponse:
  104. if request.job_config:
  105. job_config = pickle.loads(request.job_config)
  106. job_config._client_job = True
  107. else:
  108. job_config = None
  109. current_job_config = None
  110. with disable_client_hook():
  111. if ray.is_initialized():
  112. worker = ray._private.worker.global_worker
  113. current_job_config = worker.core_worker.get_job_config()
  114. else:
  115. extra_kwargs = json.loads(request.ray_init_kwargs or "{}")
  116. try:
  117. self.ray_connect_handler(job_config, **extra_kwargs)
  118. except Exception as e:
  119. logger.exception("Running Ray Init failed:")
  120. return ray_client_pb2.InitResponse(
  121. ok=False,
  122. msg=f"Call to `ray.init()` on the server failed with: {e}",
  123. )
  124. if job_config is None:
  125. return ray_client_pb2.InitResponse(ok=True)
  126. # NOTE(edoakes): this code should not be necessary anymore because we
  127. # only allow a single client/job per server. There is an existing test
  128. # that tests the behavior of multiple clients with the same job config
  129. # connecting to one server (test_client_init.py::test_num_clients),
  130. # so I'm leaving it here for now.
  131. job_config = job_config._get_proto_job_config()
  132. # If the server has been initialized, we need to compare whether the
  133. # runtime env is compatible.
  134. if current_job_config:
  135. job_uris = set(job_config.runtime_env_info.uris.working_dir_uri)
  136. job_uris.update(job_config.runtime_env_info.uris.py_modules_uris)
  137. current_job_uris = set(
  138. current_job_config.runtime_env_info.uris.working_dir_uri
  139. )
  140. current_job_uris.update(
  141. current_job_config.runtime_env_info.uris.py_modules_uris
  142. )
  143. if job_uris != current_job_uris and len(job_uris) > 0:
  144. return ray_client_pb2.InitResponse(
  145. ok=False,
  146. msg="Runtime environment doesn't match "
  147. f"request one {job_config.runtime_env_info.uris} "
  148. f"current one {current_job_config.runtime_env_info.uris}",
  149. )
  150. return ray_client_pb2.InitResponse(ok=True)
  151. @_use_response_cache
  152. def KVPut(self, request, context=None) -> ray_client_pb2.KVPutResponse:
  153. try:
  154. with disable_client_hook():
  155. already_exists = ray.experimental.internal_kv._internal_kv_put(
  156. request.key,
  157. request.value,
  158. overwrite=request.overwrite,
  159. namespace=request.namespace,
  160. )
  161. except Exception as e:
  162. return_exception_in_context(e, context)
  163. already_exists = False
  164. return ray_client_pb2.KVPutResponse(already_exists=already_exists)
  165. def KVGet(self, request, context=None) -> ray_client_pb2.KVGetResponse:
  166. try:
  167. with disable_client_hook():
  168. value = ray.experimental.internal_kv._internal_kv_get(
  169. request.key, namespace=request.namespace
  170. )
  171. except Exception as e:
  172. return_exception_in_context(e, context)
  173. value = b""
  174. return ray_client_pb2.KVGetResponse(value=value)
  175. @_use_response_cache
  176. def KVDel(self, request, context=None) -> ray_client_pb2.KVDelResponse:
  177. try:
  178. with disable_client_hook():
  179. deleted_num = ray.experimental.internal_kv._internal_kv_del(
  180. request.key,
  181. del_by_prefix=request.del_by_prefix,
  182. namespace=request.namespace,
  183. )
  184. except Exception as e:
  185. return_exception_in_context(e, context)
  186. deleted_num = 0
  187. return ray_client_pb2.KVDelResponse(deleted_num=deleted_num)
  188. def KVList(self, request, context=None) -> ray_client_pb2.KVListResponse:
  189. try:
  190. with disable_client_hook():
  191. keys = ray.experimental.internal_kv._internal_kv_list(
  192. request.prefix, namespace=request.namespace
  193. )
  194. except Exception as e:
  195. return_exception_in_context(e, context)
  196. keys = []
  197. return ray_client_pb2.KVListResponse(keys=keys)
  198. def KVExists(self, request, context=None) -> ray_client_pb2.KVExistsResponse:
  199. try:
  200. with disable_client_hook():
  201. exists = ray.experimental.internal_kv._internal_kv_exists(
  202. request.key, namespace=request.namespace
  203. )
  204. except Exception as e:
  205. return_exception_in_context(e, context)
  206. exists = False
  207. return ray_client_pb2.KVExistsResponse(exists=exists)
  208. def ListNamedActors(
  209. self, request, context=None
  210. ) -> ray_client_pb2.ClientListNamedActorsResponse:
  211. with disable_client_hook():
  212. actors = ray.util.list_named_actors(all_namespaces=request.all_namespaces)
  213. return ray_client_pb2.ClientListNamedActorsResponse(
  214. actors_json=json.dumps(actors)
  215. )
  216. def ClusterInfo(self, request, context=None) -> ray_client_pb2.ClusterInfoResponse:
  217. resp = ray_client_pb2.ClusterInfoResponse()
  218. resp.type = request.type
  219. if request.type == ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES:
  220. with disable_client_hook():
  221. resources = ray.cluster_resources()
  222. # Normalize resources into floats
  223. # (the function may return values that are ints)
  224. float_resources = {k: float(v) for k, v in resources.items()}
  225. resp.resource_table.CopyFrom(
  226. ray_client_pb2.ClusterInfoResponse.ResourceTable(table=float_resources)
  227. )
  228. elif request.type == ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES:
  229. with disable_client_hook():
  230. resources = ray.available_resources()
  231. # Normalize resources into floats
  232. # (the function may return values that are ints)
  233. float_resources = {k: float(v) for k, v in resources.items()}
  234. resp.resource_table.CopyFrom(
  235. ray_client_pb2.ClusterInfoResponse.ResourceTable(table=float_resources)
  236. )
  237. elif request.type == ray_client_pb2.ClusterInfoType.RUNTIME_CONTEXT:
  238. ctx = ray_client_pb2.ClusterInfoResponse.RuntimeContext()
  239. with disable_client_hook():
  240. rtc = ray.get_runtime_context()
  241. ctx.job_id = ray._common.utils.hex_to_binary(rtc.get_job_id())
  242. ctx.node_id = ray._common.utils.hex_to_binary(rtc.get_node_id())
  243. ctx.namespace = rtc.namespace
  244. ctx.capture_client_tasks = (
  245. rtc.should_capture_child_tasks_in_placement_group
  246. )
  247. ctx.gcs_address = rtc.gcs_address
  248. ctx.runtime_env = rtc.get_runtime_env_string()
  249. ctx.session_name = rtc.get_session_name()
  250. resp.runtime_context.CopyFrom(ctx)
  251. else:
  252. with disable_client_hook():
  253. resp.json = self._return_debug_cluster_info(request, context)
  254. return resp
  255. def _return_debug_cluster_info(self, request, context=None) -> str:
  256. """Handle ClusterInfo requests that only return a json blob."""
  257. data = None
  258. if request.type == ray_client_pb2.ClusterInfoType.NODES:
  259. data = ray.nodes()
  260. elif request.type == ray_client_pb2.ClusterInfoType.IS_INITIALIZED:
  261. data = ray.is_initialized()
  262. elif request.type == ray_client_pb2.ClusterInfoType.TIMELINE:
  263. data = ray.timeline()
  264. elif request.type == ray_client_pb2.ClusterInfoType.PING:
  265. data = {}
  266. elif request.type == ray_client_pb2.ClusterInfoType.DASHBOARD_URL:
  267. data = {"dashboard_url": ray._private.worker.get_dashboard_url()}
  268. else:
  269. raise TypeError("Unsupported cluster info type")
  270. return json.dumps(data)
  271. def release(self, client_id: str, id: bytes) -> bool:
  272. with self.state_lock:
  273. if client_id in self.object_refs:
  274. if id in self.object_refs[client_id]:
  275. logger.debug(f"Releasing object {id.hex()} for {client_id}")
  276. del self.object_refs[client_id][id]
  277. return True
  278. if client_id in self.actor_owners:
  279. if id in self.actor_owners[client_id]:
  280. logger.debug(f"Releasing actor {id.hex()} for {client_id}")
  281. self.actor_owners[client_id].remove(id)
  282. if self._can_remove_actor_ref(id):
  283. logger.debug(f"Deleting reference to actor {id.hex()}")
  284. del self.actor_refs[id]
  285. return True
  286. return False
  287. def release_all(self, client_id):
  288. with self.state_lock:
  289. self._release_objects(client_id)
  290. self._release_actors(client_id)
  291. # NOTE: Try to actually dereference the object and actor refs.
  292. # Otherwise dereferencing will happen later, which may run concurrently
  293. # with ray.shutdown() and will crash the process. The crash is a bug
  294. # that should be fixed eventually.
  295. gc.collect()
  296. def _can_remove_actor_ref(self, actor_id_bytes):
  297. no_owner = not any(
  298. actor_id_bytes in actor_list for actor_list in self.actor_owners.values()
  299. )
  300. return no_owner and actor_id_bytes not in self.named_actors
  301. def _release_objects(self, client_id):
  302. if client_id not in self.object_refs:
  303. logger.debug(f"Releasing client with no references: {client_id}")
  304. return
  305. count = len(self.object_refs[client_id])
  306. del self.object_refs[client_id]
  307. if client_id in self.client_side_ref_map:
  308. del self.client_side_ref_map[client_id]
  309. if client_id in self.response_caches:
  310. del self.response_caches[client_id]
  311. logger.debug(f"Released all {count} objects for client {client_id}")
  312. def _release_actors(self, client_id):
  313. if client_id not in self.actor_owners:
  314. logger.debug(f"Releasing client with no actors: {client_id}")
  315. return
  316. count = 0
  317. actors_to_remove = self.actor_owners.pop(client_id)
  318. for id_bytes in actors_to_remove:
  319. count += 1
  320. if self._can_remove_actor_ref(id_bytes):
  321. logger.debug(f"Deleting reference to actor {id_bytes.hex()}")
  322. del self.actor_refs[id_bytes]
  323. logger.debug(f"Released all {count} actors for client: {client_id}")
  324. @_use_response_cache
  325. def Terminate(self, req, context=None):
  326. if req.WhichOneof("terminate_type") == "task_object":
  327. try:
  328. object_ref = self.object_refs[req.client_id][req.task_object.id]
  329. with disable_client_hook():
  330. ray.cancel(
  331. object_ref,
  332. force=req.task_object.force,
  333. recursive=req.task_object.recursive,
  334. )
  335. except Exception as e:
  336. return_exception_in_context(e, context)
  337. elif req.WhichOneof("terminate_type") == "actor":
  338. try:
  339. actor_ref = self.actor_refs[req.actor.id]
  340. with disable_client_hook():
  341. ray.kill(actor_ref, no_restart=req.actor.no_restart)
  342. except Exception as e:
  343. return_exception_in_context(e, context)
  344. else:
  345. raise RuntimeError(
  346. "Client requested termination without providing a valid terminate_type"
  347. )
  348. return ray_client_pb2.TerminateResponse(ok=True)
  349. def _async_get_object(
  350. self,
  351. request: ray_client_pb2.GetRequest,
  352. client_id: str,
  353. req_id: int,
  354. result_queue: queue.Queue,
  355. context=None,
  356. ) -> Optional[ray_client_pb2.GetResponse]:
  357. """Attempts to schedule a callback to push the GetResponse to the
  358. main loop when the desired object is ready. If there is some failure
  359. in scheduling, a GetResponse will be immediately returned.
  360. """
  361. if len(request.ids) != 1:
  362. raise ValueError(
  363. f"Async get() must have exactly 1 Object ID. Actual: {request}"
  364. )
  365. rid = request.ids[0]
  366. ref = self.object_refs[client_id].get(rid, None)
  367. if not ref:
  368. return ray_client_pb2.GetResponse(
  369. valid=False,
  370. error=cloudpickle.dumps(
  371. ValueError(
  372. f"ClientObjectRef with id {rid} not found for "
  373. f"client {client_id}"
  374. )
  375. ),
  376. )
  377. try:
  378. logger.debug("async get: %s" % ref)
  379. with disable_client_hook():
  380. def send_get_response(result: Any) -> None:
  381. """Pushes GetResponses to the main DataPath loop to send
  382. to the client. This is called when the object is ready
  383. on the server side."""
  384. try:
  385. serialized = dumps_from_server(result, client_id, self)
  386. total_size = len(serialized)
  387. assert total_size > 0, "Serialized object cannot be zero bytes"
  388. total_chunks = math.ceil(
  389. total_size / OBJECT_TRANSFER_CHUNK_SIZE
  390. )
  391. for chunk_id in range(request.start_chunk_id, total_chunks):
  392. start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
  393. end = min(
  394. total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE
  395. )
  396. get_resp = ray_client_pb2.GetResponse(
  397. valid=True,
  398. data=serialized[start:end],
  399. chunk_id=chunk_id,
  400. total_chunks=total_chunks,
  401. total_size=total_size,
  402. )
  403. chunk_resp = ray_client_pb2.DataResponse(
  404. get=get_resp, req_id=req_id
  405. )
  406. result_queue.put(chunk_resp)
  407. except Exception as exc:
  408. get_resp = ray_client_pb2.GetResponse(
  409. valid=False, error=cloudpickle.dumps(exc)
  410. )
  411. resp = ray_client_pb2.DataResponse(get=get_resp, req_id=req_id)
  412. result_queue.put(resp)
  413. ref._on_completed(send_get_response)
  414. return None
  415. except Exception as e:
  416. return ray_client_pb2.GetResponse(valid=False, error=cloudpickle.dumps(e))
  417. def GetObject(self, request: ray_client_pb2.GetRequest, context):
  418. metadata = dict(context.invocation_metadata())
  419. client_id = metadata.get("client_id")
  420. if client_id is None:
  421. yield ray_client_pb2.GetResponse(
  422. valid=False,
  423. error=cloudpickle.dumps(
  424. ValueError("client_id is not specified in request metadata")
  425. ),
  426. )
  427. else:
  428. yield from self._get_object(request, client_id)
  429. def _get_object(self, request: ray_client_pb2.GetRequest, client_id: str):
  430. objectrefs = []
  431. for rid in request.ids:
  432. ref = self.object_refs[client_id].get(rid, None)
  433. if ref:
  434. objectrefs.append(ref)
  435. else:
  436. yield ray_client_pb2.GetResponse(
  437. valid=False,
  438. error=cloudpickle.dumps(
  439. ValueError(
  440. f"ClientObjectRef {rid} is not found for client {client_id}"
  441. )
  442. ),
  443. )
  444. return
  445. try:
  446. logger.debug("get: %s" % objectrefs)
  447. with disable_client_hook():
  448. items = ray.get(objectrefs, timeout=request.timeout)
  449. except Exception as e:
  450. yield ray_client_pb2.GetResponse(valid=False, error=cloudpickle.dumps(e))
  451. return
  452. serialized = dumps_from_server(items, client_id, self)
  453. total_size = len(serialized)
  454. assert total_size > 0, "Serialized object cannot be zero bytes"
  455. total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE)
  456. for chunk_id in range(request.start_chunk_id, total_chunks):
  457. start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
  458. end = min(total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE)
  459. yield ray_client_pb2.GetResponse(
  460. valid=True,
  461. data=serialized[start:end],
  462. chunk_id=chunk_id,
  463. total_chunks=total_chunks,
  464. total_size=total_size,
  465. )
  466. def PutObject(
  467. self, request: ray_client_pb2.PutRequest, context=None
  468. ) -> ray_client_pb2.PutResponse:
  469. """gRPC entrypoint for unary PutObject"""
  470. return self._put_object(
  471. request.data, request.client_ref_id, "", request.owner_id, context
  472. )
  473. def _put_object(
  474. self,
  475. data: Union[bytes, bytearray],
  476. client_ref_id: bytes,
  477. client_id: str,
  478. owner_id: bytes,
  479. context=None,
  480. ):
  481. """Put an object in the cluster with ray.put() via gRPC.
  482. Args:
  483. data: Pickled data. Can either be bytearray if this is called
  484. from the dataservicer, or bytes if called from PutObject.
  485. client_ref_id: The id associated with this object on the client.
  486. client_id: The client who owns this data, for tracking when to
  487. delete this reference.
  488. owner_id: The owner id of the object.
  489. context: gRPC context.
  490. """
  491. try:
  492. obj = loads_from_client(data, self)
  493. if owner_id:
  494. owner = self.actor_refs[owner_id]
  495. else:
  496. owner = None
  497. with disable_client_hook():
  498. objectref = ray.put(obj, _owner=owner)
  499. except Exception as e:
  500. logger.exception("Put failed:")
  501. return ray_client_pb2.PutResponse(
  502. id=b"", valid=False, error=cloudpickle.dumps(e)
  503. )
  504. self.object_refs[client_id][objectref.binary()] = objectref
  505. if len(client_ref_id) > 0:
  506. self.client_side_ref_map[client_id][client_ref_id] = objectref.binary()
  507. logger.debug("put: %s" % objectref)
  508. return ray_client_pb2.PutResponse(id=objectref.binary(), valid=True)
  509. def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
  510. object_refs = []
  511. for rid in request.object_ids:
  512. if rid not in self.object_refs[request.client_id]:
  513. raise Exception(
  514. "Asking for a ref not associated with this client: %s" % str(rid)
  515. )
  516. object_refs.append(self.object_refs[request.client_id][rid])
  517. num_returns = request.num_returns
  518. timeout = request.timeout
  519. try:
  520. with disable_client_hook():
  521. ready_object_refs, remaining_object_refs = ray.wait(
  522. object_refs,
  523. num_returns=num_returns,
  524. timeout=timeout if timeout != -1 else None,
  525. )
  526. except Exception as e:
  527. # TODO(ameer): improve exception messages.
  528. logger.error(f"Exception {e}")
  529. return ray_client_pb2.WaitResponse(valid=False)
  530. logger.debug(
  531. "wait: %s %s" % (str(ready_object_refs), str(remaining_object_refs))
  532. )
  533. ready_object_ids = [
  534. ready_object_ref.binary() for ready_object_ref in ready_object_refs
  535. ]
  536. remaining_object_ids = [
  537. remaining_object_ref.binary()
  538. for remaining_object_ref in remaining_object_refs
  539. ]
  540. return ray_client_pb2.WaitResponse(
  541. valid=True,
  542. ready_object_ids=ready_object_ids,
  543. remaining_object_ids=remaining_object_ids,
  544. )
  545. def Schedule(
  546. self,
  547. task: ray_client_pb2.ClientTask,
  548. arglist: List[Any],
  549. kwargs: Dict[str, Any],
  550. context=None,
  551. ) -> ray_client_pb2.ClientTaskTicket:
  552. logger.debug(
  553. "schedule: %s %s"
  554. % (task.name, ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))
  555. )
  556. try:
  557. with disable_client_hook():
  558. if task.type == ray_client_pb2.ClientTask.FUNCTION:
  559. result = self._schedule_function(task, arglist, kwargs, context)
  560. elif task.type == ray_client_pb2.ClientTask.ACTOR:
  561. result = self._schedule_actor(task, arglist, kwargs, context)
  562. elif task.type == ray_client_pb2.ClientTask.METHOD:
  563. result = self._schedule_method(task, arglist, kwargs, context)
  564. elif task.type == ray_client_pb2.ClientTask.NAMED_ACTOR:
  565. result = self._schedule_named_actor(task, context)
  566. else:
  567. raise NotImplementedError(
  568. "Unimplemented Schedule task type: %s"
  569. % ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)
  570. )
  571. result.valid = True
  572. return result
  573. except Exception as e:
  574. logger.debug("Caught schedule exception", exc_info=True)
  575. return ray_client_pb2.ClientTaskTicket(
  576. valid=False, error=cloudpickle.dumps(e)
  577. )
  578. def _schedule_method(
  579. self,
  580. task: ray_client_pb2.ClientTask,
  581. arglist: List[Any],
  582. kwargs: Dict[str, Any],
  583. context=None,
  584. ) -> ray_client_pb2.ClientTaskTicket:
  585. actor_handle = self.actor_refs.get(task.payload_id)
  586. if actor_handle is None:
  587. raise Exception("Can't run an actor the server doesn't have a handle for")
  588. method = getattr(actor_handle, task.name)
  589. opts = decode_options(task.options)
  590. if opts is not None:
  591. method = method.options(**opts)
  592. output = method.remote(*arglist, **kwargs)
  593. ids = self.unify_and_track_outputs(output, task.client_id)
  594. return ray_client_pb2.ClientTaskTicket(return_ids=ids)
  595. def _schedule_actor(
  596. self,
  597. task: ray_client_pb2.ClientTask,
  598. arglist: List[Any],
  599. kwargs: Dict[str, Any],
  600. context=None,
  601. ) -> ray_client_pb2.ClientTaskTicket:
  602. remote_class = self.lookup_or_register_actor(
  603. task.payload_id, task.client_id, decode_options(task.baseline_options)
  604. )
  605. opts = decode_options(task.options)
  606. if opts is not None:
  607. remote_class = remote_class.options(**opts)
  608. with current_server(self):
  609. actor = remote_class.remote(*arglist, **kwargs)
  610. self.actor_refs[actor._actor_id.binary()] = actor
  611. self.actor_owners[task.client_id].add(actor._actor_id.binary())
  612. return ray_client_pb2.ClientTaskTicket(return_ids=[actor._actor_id.binary()])
  613. def _schedule_function(
  614. self,
  615. task: ray_client_pb2.ClientTask,
  616. arglist: List[Any],
  617. kwargs: Dict[str, Any],
  618. context=None,
  619. ) -> ray_client_pb2.ClientTaskTicket:
  620. remote_func = self.lookup_or_register_func(
  621. task.payload_id, task.client_id, decode_options(task.baseline_options)
  622. )
  623. opts = decode_options(task.options)
  624. if opts is not None:
  625. remote_func = remote_func.options(**opts)
  626. with current_server(self):
  627. output = remote_func.remote(*arglist, **kwargs)
  628. ids = self.unify_and_track_outputs(output, task.client_id)
  629. return ray_client_pb2.ClientTaskTicket(return_ids=ids)
  630. def _schedule_named_actor(
  631. self, task: ray_client_pb2.ClientTask, context=None
  632. ) -> ray_client_pb2.ClientTaskTicket:
  633. assert len(task.payload_id) == 0
  634. # Convert empty string back to None.
  635. actor = ray.get_actor(task.name, task.namespace or None)
  636. bin_actor_id = actor._actor_id.binary()
  637. if bin_actor_id not in self.actor_refs:
  638. self.actor_refs[bin_actor_id] = actor
  639. self.actor_owners[task.client_id].add(bin_actor_id)
  640. self.named_actors.add(bin_actor_id)
  641. return ray_client_pb2.ClientTaskTicket(return_ids=[actor._actor_id.binary()])
  642. def lookup_or_register_func(
  643. self, id: bytes, client_id: str, options: Optional[Dict]
  644. ) -> ray.remote_function.RemoteFunction:
  645. with disable_client_hook():
  646. if id not in self.function_refs:
  647. funcref = self.object_refs[client_id][id]
  648. func = ray.get(funcref)
  649. if not inspect.isfunction(func):
  650. raise Exception(
  651. "Attempting to register function that isn't a function."
  652. )
  653. if options is None or len(options) == 0:
  654. self.function_refs[id] = ray.remote(func)
  655. else:
  656. self.function_refs[id] = ray.remote(**options)(func)
  657. return self.function_refs[id]
  658. def lookup_or_register_actor(
  659. self, id: bytes, client_id: str, options: Optional[Dict]
  660. ):
  661. with disable_client_hook():
  662. if id not in self.registered_actor_classes:
  663. actor_class_ref = self.object_refs[client_id][id]
  664. actor_class = ray.get(actor_class_ref)
  665. if not inspect.isclass(actor_class):
  666. raise Exception("Attempting to schedule actor that isn't a class.")
  667. if options is None or len(options) == 0:
  668. reg_class = ray.remote(actor_class)
  669. else:
  670. reg_class = ray.remote(**options)(actor_class)
  671. self.registered_actor_classes[id] = reg_class
  672. return self.registered_actor_classes[id]
  673. def unify_and_track_outputs(self, output, client_id):
  674. if output is None:
  675. outputs = []
  676. elif isinstance(output, list):
  677. outputs = output
  678. else:
  679. outputs = [output]
  680. for out in outputs:
  681. if out.binary() in self.object_refs[client_id]:
  682. logger.warning(f"Already saw object_ref {out}")
  683. self.object_refs[client_id][out.binary()] = out
  684. return [out.binary() for out in outputs]
  685. def return_exception_in_context(err, context):
  686. if context is not None:
  687. context.set_details(encode_exception(err))
  688. # Note: https://grpc.github.io/grpc/core/md_doc_statuscodes.html
  689. # ABORTED used here since it should never be generated by the
  690. # grpc lib -- this way we know the error was generated by ray logic
  691. context.set_code(grpc.StatusCode.ABORTED)
  692. def encode_exception(exception) -> str:
  693. data = cloudpickle.dumps(exception)
  694. return base64.standard_b64encode(data).decode()
  695. def decode_options(options: ray_client_pb2.TaskOptions) -> Optional[Dict[str, Any]]:
  696. if not options.pickled_options:
  697. return None
  698. opts = pickle.loads(options.pickled_options)
  699. assert isinstance(opts, dict)
  700. return opts
  701. def serve(host: str, port: int, ray_connect_handler=None):
  702. def default_connect_handler(
  703. job_config: JobConfig = None, **ray_init_kwargs: Dict[str, Any]
  704. ):
  705. with disable_client_hook():
  706. if not ray.is_initialized():
  707. return ray.init(job_config=job_config, **ray_init_kwargs)
  708. from ray._private.grpc_utils import create_grpc_server_with_interceptors
  709. ray_connect_handler = ray_connect_handler or default_connect_handler
  710. server = create_grpc_server_with_interceptors(
  711. max_workers=CLIENT_SERVER_MAX_THREADS,
  712. thread_name_prefix="ray_client_server",
  713. options=GRPC_OPTIONS,
  714. asynchronous=False,
  715. )
  716. task_servicer = RayletServicer(ray_connect_handler)
  717. data_servicer = DataServicer(task_servicer)
  718. logs_servicer = LogstreamServicer()
  719. ray_client_pb2_grpc.add_RayletDriverServicer_to_server(task_servicer, server)
  720. ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(data_servicer, server)
  721. ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(logs_servicer, server)
  722. if not is_localhost(host):
  723. add_port_to_grpc_server(server, f"127.0.0.1:{port}")
  724. add_port_to_grpc_server(server, f"{host}:{port}")
  725. current_handle = ClientServerHandle(
  726. task_servicer=task_servicer,
  727. data_servicer=data_servicer,
  728. logs_servicer=logs_servicer,
  729. grpc_server=server,
  730. )
  731. server.start()
  732. return current_handle
  733. def init_and_serve(host: str, port: int, *args, **kwargs):
  734. with disable_client_hook():
  735. # Disable client mode inside the worker's environment
  736. info = ray.init(*args, **kwargs)
  737. def ray_connect_handler(job_config=None, **ray_init_kwargs):
  738. # Ray client will disconnect from ray when
  739. # num_clients == 0.
  740. if ray.is_initialized():
  741. return info
  742. else:
  743. return ray.init(job_config=job_config, *args, **kwargs)
  744. server_handle = serve(host, port, ray_connect_handler=ray_connect_handler)
  745. return (server_handle, info)
  746. def shutdown_with_server(server, _exiting_interpreter=False):
  747. server.stop(1)
  748. with disable_client_hook():
  749. ray.shutdown(_exiting_interpreter)
  750. def create_ray_handler(address, redis_password, redis_username=None):
  751. def ray_connect_handler(job_config: JobConfig = None, **ray_init_kwargs):
  752. if address:
  753. if redis_password:
  754. ray.init(
  755. address=address,
  756. _redis_username=redis_username,
  757. _redis_password=redis_password,
  758. job_config=job_config,
  759. **ray_init_kwargs,
  760. )
  761. else:
  762. ray.init(address=address, job_config=job_config, **ray_init_kwargs)
  763. else:
  764. ray.init(job_config=job_config, **ray_init_kwargs)
  765. return ray_connect_handler
  766. def try_create_gcs_client(address: Optional[str]) -> Optional[GcsClient]:
  767. """
  768. Try to create a gcs client based on the command line args or by
  769. autodetecting a running Ray cluster.
  770. """
  771. address = canonicalize_bootstrap_address_or_die(address)
  772. return GcsClient(address=address)
  773. def main():
  774. import argparse
  775. parser = argparse.ArgumentParser()
  776. parser.add_argument(
  777. "--host", type=str, default="0.0.0.0", help="Host IP to bind to"
  778. )
  779. parser.add_argument("-p", "--port", type=int, default=10001, help="Port to bind to")
  780. parser.add_argument(
  781. "--mode",
  782. type=str,
  783. choices=["proxy", "legacy", "specific-server"],
  784. default="proxy",
  785. )
  786. parser.add_argument(
  787. "--address", required=False, type=str, help="Address to use to connect to Ray"
  788. )
  789. parser.add_argument(
  790. "--redis-username",
  791. required=False,
  792. type=str,
  793. help="username for connecting to Redis",
  794. )
  795. parser.add_argument(
  796. "--redis-password",
  797. required=False,
  798. type=str,
  799. help="Password for connecting to Redis",
  800. )
  801. parser.add_argument(
  802. "--runtime-env-agent-address",
  803. required=False,
  804. type=str,
  805. default=None,
  806. help="The port to use for connecting to the runtime_env_agent.",
  807. )
  808. parser.add_argument(
  809. "--node-id",
  810. required=False,
  811. type=str,
  812. default=None,
  813. help="The hex ID of this node.",
  814. )
  815. args, _ = parser.parse_known_args()
  816. setup_logger(ray_constants.LOGGER_LEVEL, ray_constants.LOGGER_FORMAT)
  817. ray_connect_handler = create_ray_handler(
  818. args.address, args.redis_password, args.redis_username
  819. )
  820. hostport = build_address(args.host, args.port)
  821. args_str = str(args)
  822. if args.redis_password:
  823. args_str = args_str.replace(args.redis_password, "****")
  824. logger.info(f"Starting Ray Client server on {hostport}, args {args_str}")
  825. if args.mode == "proxy":
  826. server = serve_proxier(
  827. args.host,
  828. args.port,
  829. args.address,
  830. redis_username=args.redis_username,
  831. redis_password=args.redis_password,
  832. runtime_env_agent_address=args.runtime_env_agent_address,
  833. node_id=args.node_id,
  834. )
  835. else:
  836. server = serve(args.host, args.port, ray_connect_handler)
  837. try:
  838. idle_checks_remaining = TIMEOUT_FOR_SPECIFIC_SERVER_S
  839. while True:
  840. health_report = {
  841. "time": time.time(),
  842. }
  843. try:
  844. if not ray.experimental.internal_kv._internal_kv_initialized():
  845. gcs_client = try_create_gcs_client(args.address)
  846. ray.experimental.internal_kv._initialize_internal_kv(gcs_client)
  847. ray.experimental.internal_kv._internal_kv_put(
  848. "ray_client_server",
  849. json.dumps(health_report),
  850. namespace=ray_constants.KV_NAMESPACE_HEALTHCHECK,
  851. )
  852. except Exception as e:
  853. logger.error(
  854. f"[{args.mode}] Failed to put health check on {args.address}"
  855. )
  856. logger.exception(e)
  857. time.sleep(1)
  858. if args.mode == "specific-server":
  859. if server.data_servicer.num_clients > 0:
  860. idle_checks_remaining = TIMEOUT_FOR_SPECIFIC_SERVER_S
  861. else:
  862. idle_checks_remaining -= 1
  863. if idle_checks_remaining == 0:
  864. raise KeyboardInterrupt()
  865. if (
  866. idle_checks_remaining % 5 == 0
  867. and idle_checks_remaining != TIMEOUT_FOR_SPECIFIC_SERVER_S
  868. ):
  869. logger.info(f"{idle_checks_remaining} idle checks before shutdown.")
  870. except KeyboardInterrupt:
  871. server.stop(0)
  872. if __name__ == "__main__":
  873. main()