proxier.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936
  1. import atexit
  2. import json
  3. import logging
  4. import socket
  5. import sys
  6. import time
  7. import traceback
  8. import urllib
  9. from concurrent import futures
  10. from dataclasses import dataclass
  11. from itertools import chain
  12. from threading import Event, Lock, RLock, Thread
  13. from typing import Callable, Dict, List, Optional, Tuple
  14. from urllib.parse import urlparse, urlunparse
  15. import grpc
  16. import ray
  17. import ray.core.generated.ray_client_pb2 as ray_client_pb2
  18. import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
  19. import ray.core.generated.runtime_env_agent_pb2 as runtime_env_agent_pb2
  20. from ray._common.network_utils import (
  21. build_address,
  22. is_ipv6,
  23. is_localhost,
  24. )
  25. from ray._private.authentication.http_token_authentication import (
  26. format_authentication_http_error,
  27. get_auth_headers_if_auth_enabled,
  28. )
  29. from ray._private.client_mode_hook import disable_client_hook
  30. from ray._private.grpc_utils import init_grpc_channel
  31. from ray._private.parameter import RayParams
  32. from ray._private.runtime_env.context import RuntimeEnvContext
  33. from ray._private.services import (
  34. ProcessInfo,
  35. get_node_with_retry,
  36. start_ray_client_server,
  37. )
  38. from ray._private.tls_utils import add_port_to_grpc_server
  39. from ray._private.utils import detect_fate_sharing_support
  40. from ray._raylet import GcsClient
  41. from ray.cloudpickle.compat import pickle
  42. from ray.exceptions import AuthenticationError
  43. from ray.job_config import JobConfig
  44. from ray.util.client.common import (
  45. CLIENT_SERVER_MAX_THREADS,
  46. GRPC_OPTIONS,
  47. ClientServerHandle,
  48. _get_client_id_from_context,
  49. _propagate_error_in_context,
  50. )
  51. from ray.util.client.server.dataservicer import _get_reconnecting_from_context
  52. # Import psutil after ray so the packaged version is used.
  53. import psutil
  54. logger = logging.getLogger(__name__)
  55. CHECK_PROCESS_INTERVAL_S = 30
  56. MIN_SPECIFIC_SERVER_PORT = 23000
  57. MAX_SPECIFIC_SERVER_PORT = 24000
  58. CHECK_CHANNEL_TIMEOUT_S = 30
  59. LOGSTREAM_RETRIES = 5
  60. LOGSTREAM_RETRY_INTERVAL_SEC = 2
  61. @dataclass
  62. class SpecificServer:
  63. port: int
  64. process_handle_future: futures.Future
  65. channel: "grpc._channel.Channel"
  66. def is_ready(self) -> bool:
  67. """Check if the server is ready or not (doesn't block)."""
  68. return self.process_handle_future.done()
  69. def wait_ready(self, timeout: Optional[float] = None) -> None:
  70. """
  71. Wait for the server to actually start up.
  72. """
  73. res = self.process_handle_future.result(timeout=timeout)
  74. if res is None:
  75. # This is only set to none when server creation specifically fails.
  76. raise RuntimeError("Server startup failed.")
  77. def poll(self) -> Optional[int]:
  78. """Check if the process has exited."""
  79. try:
  80. proc = self.process_handle_future.result(timeout=0.1)
  81. if proc is not None:
  82. return proc.process.poll()
  83. except futures.TimeoutError:
  84. return
  85. def kill(self) -> None:
  86. """Try to send a KILL signal to the process."""
  87. try:
  88. proc = self.process_handle_future.result(timeout=0.1)
  89. if proc is not None:
  90. proc.process.kill()
  91. except futures.TimeoutError:
  92. # Server has not been started yet.
  93. pass
  94. def set_result(self, proc: Optional[ProcessInfo]) -> None:
  95. """Set the result of the internal future if it is currently unset."""
  96. if not self.is_ready():
  97. self.process_handle_future.set_result(proc)
  98. def _match_running_client_server(command: List[str]) -> bool:
  99. """
  100. Detects if the main process in the given command is the RayClient Server.
  101. This works by ensuring that the command is of the form:
  102. <py_executable> -m ray.util.client.server <args>
  103. """
  104. flattened = " ".join(command)
  105. return "-m ray.util.client.server" in flattened
  106. class ProxyManager:
  107. def __init__(
  108. self,
  109. address: Optional[str],
  110. runtime_env_agent_address: str,
  111. *,
  112. session_dir: Optional[str] = None,
  113. redis_username: Optional[str] = None,
  114. redis_password: Optional[str] = None,
  115. node_id: Optional[str] = None,
  116. ):
  117. self.servers: Dict[str, SpecificServer] = dict()
  118. self.server_lock = RLock()
  119. self._address = address
  120. self._redis_username = redis_username
  121. self._redis_password = redis_password
  122. self._free_ports: List[int] = list(
  123. range(MIN_SPECIFIC_SERVER_PORT, MAX_SPECIFIC_SERVER_PORT)
  124. )
  125. if runtime_env_agent_address:
  126. parsed = urlparse(runtime_env_agent_address)
  127. # runtime env agent self-assigns a free port, fetch it from GCS
  128. if parsed.port is None or parsed.port == 0:
  129. if node_id is None:
  130. raise ValueError(
  131. "node_id is required when runtime_env_agent_address "
  132. "has no port specified"
  133. )
  134. node_info = get_node_with_retry(address, node_id)
  135. runtime_env_agent_address = urlunparse(
  136. parsed._replace(
  137. netloc=f"{parsed.hostname}:{node_info['runtime_env_agent_port']}"
  138. )
  139. )
  140. self._runtime_env_agent_address = runtime_env_agent_address
  141. self._check_thread = Thread(target=self._check_processes, daemon=True)
  142. self._check_thread.start()
  143. self.fate_share = bool(detect_fate_sharing_support())
  144. self._node: Optional[ray._private.node.Node] = None
  145. atexit.register(self._cleanup)
  146. def _get_unused_port(self, family: int = socket.AF_INET) -> int:
  147. """
  148. Search for a port in _free_ports that is unused.
  149. """
  150. with self.server_lock:
  151. num_ports = len(self._free_ports)
  152. for _ in range(num_ports):
  153. port = self._free_ports.pop(0)
  154. s = socket.socket(family, socket.SOCK_STREAM)
  155. try:
  156. s.bind(("", port))
  157. except OSError:
  158. self._free_ports.append(port)
  159. continue
  160. finally:
  161. s.close()
  162. return port
  163. raise RuntimeError("Unable to succeed in selecting a random port.")
  164. @property
  165. def address(self) -> str:
  166. """
  167. Returns the provided Ray bootstrap address, or creates a new cluster.
  168. """
  169. if self._address:
  170. return self._address
  171. # Start a new, locally scoped cluster.
  172. connection_tuple = ray.init()
  173. self._address = connection_tuple["address"]
  174. self._session_dir = connection_tuple["session_dir"]
  175. return self._address
  176. @property
  177. def node(self) -> ray._private.node.Node:
  178. """Gets a 'ray.Node' object for this node (the head node).
  179. If it does not already exist, one is created using the bootstrap
  180. address.
  181. """
  182. if self._node:
  183. return self._node
  184. ray_params = RayParams(gcs_address=self.address)
  185. self._node = ray._private.node.Node(
  186. ray_params,
  187. head=False,
  188. shutdown_at_exit=False,
  189. spawn_reaper=False,
  190. connect_only=True,
  191. )
  192. return self._node
  193. def create_specific_server(self, client_id: str) -> SpecificServer:
  194. """
  195. Create, but not start a SpecificServer for a given client. This
  196. method must be called once per client.
  197. """
  198. with self.server_lock:
  199. assert (
  200. self.servers.get(client_id) is None
  201. ), f"Server already created for Client: {client_id}"
  202. host = "127.0.0.1"
  203. port = self._get_unused_port(
  204. socket.AF_INET6 if is_ipv6(host) else socket.AF_INET
  205. )
  206. server = SpecificServer(
  207. port=port,
  208. process_handle_future=futures.Future(),
  209. channel=init_grpc_channel(
  210. build_address(host, port), options=GRPC_OPTIONS
  211. ),
  212. )
  213. self.servers[client_id] = server
  214. return server
  215. def _create_runtime_env(
  216. self,
  217. serialized_runtime_env: str,
  218. runtime_env_config: str,
  219. specific_server: SpecificServer,
  220. ):
  221. """Increase the runtime_env reference by sending an RPC to the agent.
  222. Includes retry logic to handle the case when the agent is
  223. temporarily unreachable (e.g., hasn't been started up yet).
  224. """
  225. logger.info(
  226. f"Increasing runtime env reference for "
  227. f"ray_client_server_{specific_server.port}."
  228. f"Serialized runtime env is {serialized_runtime_env}."
  229. )
  230. assert (
  231. len(self._runtime_env_agent_address) > 0
  232. ), "runtime_env_agent_address not set"
  233. create_env_request = runtime_env_agent_pb2.GetOrCreateRuntimeEnvRequest(
  234. serialized_runtime_env=serialized_runtime_env,
  235. runtime_env_config=runtime_env_config,
  236. job_id=f"ray_client_server_{specific_server.port}".encode("utf-8"),
  237. source_process="client_server",
  238. )
  239. retries = 0
  240. max_retries = 5
  241. wait_time_s = 0.5
  242. last_exception = None
  243. while retries <= max_retries:
  244. try:
  245. url = urllib.parse.urljoin(
  246. self._runtime_env_agent_address, "/get_or_create_runtime_env"
  247. )
  248. data = create_env_request.SerializeToString()
  249. headers = {"Content-Type": "application/octet-stream"}
  250. headers.update(**get_auth_headers_if_auth_enabled(headers))
  251. req = urllib.request.Request(
  252. url, data=data, method="POST", headers=headers
  253. )
  254. response = urllib.request.urlopen(req, timeout=None)
  255. response_data = response.read()
  256. r = runtime_env_agent_pb2.GetOrCreateRuntimeEnvReply()
  257. r.ParseFromString(response_data)
  258. if r.status == runtime_env_agent_pb2.AgentRpcStatus.AGENT_RPC_STATUS_OK:
  259. return r.serialized_runtime_env_context
  260. elif (
  261. r.status
  262. == runtime_env_agent_pb2.AgentRpcStatus.AGENT_RPC_STATUS_FAILED
  263. ):
  264. raise RuntimeError(
  265. "Failed to create runtime_env for Ray client "
  266. f"server, it is caused by:\n{r.error_message}"
  267. )
  268. else:
  269. assert False, f"Unknown status: {r.status}."
  270. except urllib.error.HTTPError as e:
  271. body = ""
  272. try:
  273. body = e.read().decode("utf-8", "ignore")
  274. except Exception:
  275. body = e.reason if hasattr(e, "reason") else str(e)
  276. formatted_error = format_authentication_http_error(e.code, body or "")
  277. if formatted_error:
  278. raise AuthenticationError(formatted_error) from e
  279. # Treat non-auth HTTP errors like URLError (retry with backoff)
  280. last_exception = e
  281. logger.warning(
  282. f"GetOrCreateRuntimeEnv request failed with HTTP {e.code}: {body or e}. "
  283. f"Retrying after {wait_time_s}s. "
  284. f"{max_retries-retries} retries remaining."
  285. )
  286. except urllib.error.URLError as e:
  287. last_exception = e
  288. logger.warning(
  289. f"GetOrCreateRuntimeEnv request failed: {e}. "
  290. f"Retrying after {wait_time_s}s. "
  291. f"{max_retries-retries} retries remaining."
  292. )
  293. # Exponential backoff.
  294. time.sleep(wait_time_s)
  295. retries += 1
  296. wait_time_s *= 2
  297. raise TimeoutError(
  298. f"GetOrCreateRuntimeEnv request failed after {max_retries} attempts."
  299. f" Last exception: {last_exception}"
  300. )
  301. def start_specific_server(self, client_id: str, job_config: JobConfig) -> bool:
  302. """
  303. Start up a RayClient Server for an incoming client to
  304. communicate with. Returns whether creation was successful.
  305. """
  306. specific_server = self._get_server_for_client(client_id)
  307. assert specific_server, f"Server has not been created for: {client_id}"
  308. output, error = self.node.get_log_file_handles(
  309. f"ray_client_server_{specific_server.port}", unique=True
  310. )
  311. serialized_runtime_env = job_config._get_serialized_runtime_env()
  312. runtime_env_config = job_config._get_proto_runtime_env_config()
  313. if not serialized_runtime_env or serialized_runtime_env == "{}":
  314. # TODO(edoakes): can we just remove this case and always send it
  315. # to the agent?
  316. serialized_runtime_env_context = RuntimeEnvContext().serialize()
  317. else:
  318. serialized_runtime_env_context = self._create_runtime_env(
  319. serialized_runtime_env=serialized_runtime_env,
  320. runtime_env_config=runtime_env_config,
  321. specific_server=specific_server,
  322. )
  323. proc = start_ray_client_server(
  324. self.address,
  325. "127.0.0.1",
  326. specific_server.port,
  327. stdout_file=output,
  328. stderr_file=error,
  329. fate_share=self.fate_share,
  330. server_type="specific-server",
  331. serialized_runtime_env_context=serialized_runtime_env_context,
  332. redis_username=self._redis_username,
  333. redis_password=self._redis_password,
  334. )
  335. # Wait for the process being run transitions from the shim process
  336. # to the actual RayClient Server.
  337. pid = proc.process.pid
  338. if sys.platform != "win32":
  339. psutil_proc = psutil.Process(pid)
  340. else:
  341. psutil_proc = None
  342. # Don't use `psutil` on Win32
  343. while psutil_proc is not None:
  344. if proc.process.poll() is not None:
  345. logger.error(f"SpecificServer startup failed for client: {client_id}")
  346. break
  347. cmd = psutil_proc.cmdline()
  348. if _match_running_client_server(cmd):
  349. break
  350. logger.debug("Waiting for Process to reach the actual client server.")
  351. time.sleep(0.5)
  352. specific_server.set_result(proc)
  353. logger.info(
  354. f"SpecificServer started on port: {specific_server.port} "
  355. f"with PID: {pid} for client: {client_id}"
  356. )
  357. return proc.process.poll() is None
  358. def _get_server_for_client(self, client_id: str) -> Optional[SpecificServer]:
  359. with self.server_lock:
  360. client = self.servers.get(client_id)
  361. if client is None:
  362. logger.error(f"Unable to find channel for client: {client_id}")
  363. return client
  364. def has_channel(self, client_id: str) -> bool:
  365. server = self._get_server_for_client(client_id)
  366. if server is None:
  367. return False
  368. return server.is_ready()
  369. def get_channel(
  370. self,
  371. client_id: str,
  372. ) -> Optional["grpc._channel.Channel"]:
  373. """
  374. Find the gRPC Channel for the given client_id. This will block until
  375. the server process has started.
  376. """
  377. server = self._get_server_for_client(client_id)
  378. if server is None:
  379. return None
  380. # Wait for the SpecificServer to become ready.
  381. server.wait_ready()
  382. try:
  383. grpc.channel_ready_future(server.channel).result(
  384. timeout=CHECK_CHANNEL_TIMEOUT_S
  385. )
  386. return server.channel
  387. except grpc.FutureTimeoutError:
  388. logger.exception(f"Timeout waiting for channel for {client_id}")
  389. return None
  390. def _check_processes(self):
  391. """
  392. Keeps the internal servers dictionary up-to-date with running servers.
  393. """
  394. while True:
  395. with self.server_lock:
  396. for client_id, specific_server in list(self.servers.items()):
  397. if specific_server.poll() is not None:
  398. logger.info(
  399. f"Specific server {client_id} is no longer running"
  400. f", freeing its port {specific_server.port}"
  401. )
  402. del self.servers[client_id]
  403. # Port is available to use again.
  404. self._free_ports.append(specific_server.port)
  405. time.sleep(CHECK_PROCESS_INTERVAL_S)
  406. def _cleanup(self) -> None:
  407. """
  408. Forcibly kill all spawned RayClient Servers. This ensures cleanup
  409. for platforms where fate sharing is not supported.
  410. """
  411. for server in self.servers.values():
  412. server.kill()
  413. class RayletServicerProxy(ray_client_pb2_grpc.RayletDriverServicer):
  414. def __init__(self, ray_connect_handler: Callable, proxy_manager: ProxyManager):
  415. self.proxy_manager = proxy_manager
  416. self.ray_connect_handler = ray_connect_handler
  417. def _call_inner_function(
  418. self, request, context, method: str
  419. ) -> Optional[ray_client_pb2_grpc.RayletDriverStub]:
  420. client_id = _get_client_id_from_context(context)
  421. chan = self.proxy_manager.get_channel(client_id)
  422. if not chan:
  423. logger.error(f"Channel for Client: {client_id} not found!")
  424. context.set_code(grpc.StatusCode.NOT_FOUND)
  425. return None
  426. stub = ray_client_pb2_grpc.RayletDriverStub(chan)
  427. try:
  428. metadata = [("client_id", client_id)]
  429. if context:
  430. metadata = context.invocation_metadata()
  431. return getattr(stub, method)(request, metadata=metadata)
  432. except Exception as e:
  433. # Error while proxying -- propagate the error's context to user
  434. logger.exception(f"Proxying call to {method} failed!")
  435. _propagate_error_in_context(e, context)
  436. def _has_channel_for_request(self, context):
  437. client_id = _get_client_id_from_context(context)
  438. return self.proxy_manager.has_channel(client_id)
  439. def Init(self, request, context=None) -> ray_client_pb2.InitResponse:
  440. return self._call_inner_function(request, context, "Init")
  441. def KVPut(self, request, context=None) -> ray_client_pb2.KVPutResponse:
  442. """Proxies internal_kv.put.
  443. This is used by the working_dir code to upload to the GCS before
  444. ray.init is called. In that case (if we don't have a server yet)
  445. we directly make the internal KV call from the proxier.
  446. Otherwise, we proxy the call to the downstream server as usual.
  447. """
  448. if self._has_channel_for_request(context):
  449. return self._call_inner_function(request, context, "KVPut")
  450. with disable_client_hook():
  451. already_exists = ray.experimental.internal_kv._internal_kv_put(
  452. request.key, request.value, overwrite=request.overwrite
  453. )
  454. return ray_client_pb2.KVPutResponse(already_exists=already_exists)
  455. def KVGet(self, request, context=None) -> ray_client_pb2.KVGetResponse:
  456. """Proxies internal_kv.get.
  457. This is used by the working_dir code to upload to the GCS before
  458. ray.init is called. In that case (if we don't have a server yet)
  459. we directly make the internal KV call from the proxier.
  460. Otherwise, we proxy the call to the downstream server as usual.
  461. """
  462. if self._has_channel_for_request(context):
  463. return self._call_inner_function(request, context, "KVGet")
  464. with disable_client_hook():
  465. value = ray.experimental.internal_kv._internal_kv_get(request.key)
  466. return ray_client_pb2.KVGetResponse(value=value)
  467. def KVDel(self, request, context=None) -> ray_client_pb2.KVDelResponse:
  468. """Proxies internal_kv.delete.
  469. This is used by the working_dir code to upload to the GCS before
  470. ray.init is called. In that case (if we don't have a server yet)
  471. we directly make the internal KV call from the proxier.
  472. Otherwise, we proxy the call to the downstream server as usual.
  473. """
  474. if self._has_channel_for_request(context):
  475. return self._call_inner_function(request, context, "KVDel")
  476. with disable_client_hook():
  477. ray.experimental.internal_kv._internal_kv_del(request.key)
  478. return ray_client_pb2.KVDelResponse()
  479. def KVList(self, request, context=None) -> ray_client_pb2.KVListResponse:
  480. """Proxies internal_kv.list.
  481. This is used by the working_dir code to upload to the GCS before
  482. ray.init is called. In that case (if we don't have a server yet)
  483. we directly make the internal KV call from the proxier.
  484. Otherwise, we proxy the call to the downstream server as usual.
  485. """
  486. if self._has_channel_for_request(context):
  487. return self._call_inner_function(request, context, "KVList")
  488. with disable_client_hook():
  489. keys = ray.experimental.internal_kv._internal_kv_list(request.prefix)
  490. return ray_client_pb2.KVListResponse(keys=keys)
  491. def KVExists(self, request, context=None) -> ray_client_pb2.KVExistsResponse:
  492. """Proxies internal_kv.exists.
  493. This is used by the working_dir code to upload to the GCS before
  494. ray.init is called. In that case (if we don't have a server yet)
  495. we directly make the internal KV call from the proxier.
  496. Otherwise, we proxy the call to the downstream server as usual.
  497. """
  498. if self._has_channel_for_request(context):
  499. return self._call_inner_function(request, context, "KVExists")
  500. with disable_client_hook():
  501. exists = ray.experimental.internal_kv._internal_kv_exists(request.key)
  502. return ray_client_pb2.KVExistsResponse(exists=exists)
  503. def PinRuntimeEnvURI(
  504. self, request, context=None
  505. ) -> ray_client_pb2.ClientPinRuntimeEnvURIResponse:
  506. """Proxies internal_kv.pin_runtime_env_uri.
  507. This is used by the working_dir code to upload to the GCS before
  508. ray.init is called. In that case (if we don't have a server yet)
  509. we directly make the internal KV call from the proxier.
  510. Otherwise, we proxy the call to the downstream server as usual.
  511. """
  512. if self._has_channel_for_request(context):
  513. return self._call_inner_function(request, context, "PinRuntimeEnvURI")
  514. with disable_client_hook():
  515. ray.experimental.internal_kv._pin_runtime_env_uri(
  516. request.uri, expiration_s=request.expiration_s
  517. )
  518. return ray_client_pb2.ClientPinRuntimeEnvURIResponse()
  519. def ListNamedActors(
  520. self, request, context=None
  521. ) -> ray_client_pb2.ClientListNamedActorsResponse:
  522. return self._call_inner_function(request, context, "ListNamedActors")
  523. def ClusterInfo(self, request, context=None) -> ray_client_pb2.ClusterInfoResponse:
  524. # NOTE: We need to respond to the PING request here to allow the client
  525. # to continue with connecting.
  526. if request.type == ray_client_pb2.ClusterInfoType.PING:
  527. resp = ray_client_pb2.ClusterInfoResponse(json=json.dumps({}))
  528. return resp
  529. return self._call_inner_function(request, context, "ClusterInfo")
  530. def Terminate(self, req, context=None):
  531. return self._call_inner_function(req, context, "Terminate")
  532. def GetObject(self, request, context=None):
  533. try:
  534. yield from self._call_inner_function(request, context, "GetObject")
  535. except Exception as e:
  536. # Error while iterating over response from GetObject stream
  537. logger.exception("Proxying call to GetObject failed!")
  538. _propagate_error_in_context(e, context)
  539. def PutObject(
  540. self, request: ray_client_pb2.PutRequest, context=None
  541. ) -> ray_client_pb2.PutResponse:
  542. return self._call_inner_function(request, context, "PutObject")
  543. def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
  544. return self._call_inner_function(request, context, "WaitObject")
  545. def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket:
  546. return self._call_inner_function(task, context, "Schedule")
  547. def ray_client_server_env_prep(job_config: JobConfig) -> JobConfig:
  548. return job_config
  549. def prepare_runtime_init_req(
  550. init_request: ray_client_pb2.DataRequest,
  551. ) -> Tuple[ray_client_pb2.DataRequest, JobConfig]:
  552. """
  553. Extract JobConfig and possibly mutate InitRequest before it is passed to
  554. the specific RayClient Server.
  555. """
  556. init_type = init_request.WhichOneof("type")
  557. assert init_type == "init", (
  558. "Received initial message of type " f"{init_type}, not 'init'."
  559. )
  560. req = init_request.init
  561. job_config = JobConfig()
  562. if req.job_config:
  563. job_config = pickle.loads(req.job_config)
  564. new_job_config = ray_client_server_env_prep(job_config)
  565. modified_init_req = ray_client_pb2.InitRequest(
  566. job_config=pickle.dumps(new_job_config),
  567. ray_init_kwargs=init_request.init.ray_init_kwargs,
  568. reconnect_grace_period=init_request.init.reconnect_grace_period,
  569. )
  570. init_request.init.CopyFrom(modified_init_req)
  571. return (init_request, new_job_config)
  572. class RequestIteratorProxy:
  573. def __init__(self, request_iterator):
  574. self.request_iterator = request_iterator
  575. def __iter__(self):
  576. return self
  577. def __next__(self):
  578. try:
  579. return next(self.request_iterator)
  580. except grpc.RpcError as e:
  581. # To stop proxying already CANCLLED request stream gracefully,
  582. # we only translate the exact grpc.RpcError to StopIteration,
  583. # not its subsclasses. ex: grpc._Rendezvous
  584. # https://github.com/grpc/grpc/blob/v1.43.0/src/python/grpcio/grpc/_server.py#L353-L354
  585. # This fixes the https://github.com/ray-project/ray/issues/23865
  586. if type(e) is not grpc.RpcError:
  587. raise e # re-raise other grpc exceptions
  588. logger.exception(
  589. "Stop iterating cancelled request stream with the following exception:"
  590. )
  591. raise StopIteration
  592. class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer):
  593. def __init__(self, proxy_manager: ProxyManager):
  594. self.num_clients = 0
  595. # dictionary mapping client_id's to the last time they connected
  596. self.clients_last_seen: Dict[str, float] = {}
  597. self.reconnect_grace_periods: Dict[str, float] = {}
  598. self.clients_lock = Lock()
  599. self.proxy_manager = proxy_manager
  600. self.stopped = Event()
  601. def modify_connection_info_resp(
  602. self, init_resp: ray_client_pb2.DataResponse
  603. ) -> ray_client_pb2.DataResponse:
  604. """
  605. Modify the `num_clients` returned the ConnectionInfoResponse because
  606. individual SpecificServers only have **one** client.
  607. """
  608. init_type = init_resp.WhichOneof("type")
  609. if init_type != "connection_info":
  610. return init_resp
  611. modified_resp = ray_client_pb2.DataResponse()
  612. modified_resp.CopyFrom(init_resp)
  613. with self.clients_lock:
  614. modified_resp.connection_info.num_clients = self.num_clients
  615. return modified_resp
  616. def Datapath(self, request_iterator, context):
  617. request_iterator = RequestIteratorProxy(request_iterator)
  618. cleanup_requested = False
  619. start_time = time.time()
  620. client_id = _get_client_id_from_context(context)
  621. if client_id == "":
  622. return
  623. reconnecting = _get_reconnecting_from_context(context)
  624. if reconnecting:
  625. with self.clients_lock:
  626. if client_id not in self.clients_last_seen:
  627. # Client took too long to reconnect, session has already
  628. # been cleaned up
  629. context.set_code(grpc.StatusCode.NOT_FOUND)
  630. context.set_details(
  631. "Attempted to reconnect a session that has already "
  632. "been cleaned up"
  633. )
  634. return
  635. self.clients_last_seen[client_id] = start_time
  636. server = self.proxy_manager._get_server_for_client(client_id)
  637. channel = self.proxy_manager.get_channel(client_id)
  638. # iterator doesn't need modification on reconnect
  639. new_iter = request_iterator
  640. else:
  641. # Create Placeholder *before* reading the first request.
  642. server = self.proxy_manager.create_specific_server(client_id)
  643. with self.clients_lock:
  644. self.clients_last_seen[client_id] = start_time
  645. self.num_clients += 1
  646. try:
  647. if not reconnecting:
  648. logger.info(f"New data connection from client {client_id}: ")
  649. init_req = next(request_iterator)
  650. with self.clients_lock:
  651. self.reconnect_grace_periods[
  652. client_id
  653. ] = init_req.init.reconnect_grace_period
  654. try:
  655. modified_init_req, job_config = prepare_runtime_init_req(init_req)
  656. if not self.proxy_manager.start_specific_server(
  657. client_id, job_config
  658. ):
  659. logger.error(
  660. f"Server startup failed for client: {client_id}, "
  661. f"using JobConfig: {job_config}!"
  662. )
  663. raise RuntimeError(
  664. "Starting Ray client server failed. See "
  665. f"ray_client_server_{server.port}.err for "
  666. "detailed logs."
  667. )
  668. channel = self.proxy_manager.get_channel(client_id)
  669. if channel is None:
  670. logger.error(f"Channel not found for {client_id}")
  671. raise RuntimeError(
  672. "Proxy failed to Connect to backend! Check "
  673. "`ray_client_server.err` and "
  674. f"`ray_client_server_{server.port}.err` on the "
  675. "head node of the cluster for the relevant logs. "
  676. "By default these are located at "
  677. "/tmp/ray/session_latest/logs."
  678. )
  679. except Exception:
  680. init_resp = ray_client_pb2.DataResponse(
  681. init=ray_client_pb2.InitResponse(
  682. ok=False, msg=traceback.format_exc()
  683. )
  684. )
  685. init_resp.req_id = init_req.req_id
  686. yield init_resp
  687. return None
  688. new_iter = chain([modified_init_req], request_iterator)
  689. stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
  690. metadata = [("client_id", client_id), ("reconnecting", str(reconnecting))]
  691. resp_stream = stub.Datapath(new_iter, metadata=metadata)
  692. for resp in resp_stream:
  693. resp_type = resp.WhichOneof("type")
  694. if resp_type == "connection_cleanup":
  695. # Specific server is skipping cleanup, proxier should too
  696. cleanup_requested = True
  697. yield self.modify_connection_info_resp(resp)
  698. except Exception as e:
  699. logger.exception("Proxying Datapath failed!")
  700. # Propogate error through context
  701. recoverable = _propagate_error_in_context(e, context)
  702. if not recoverable:
  703. # Client shouldn't attempt to recover, clean up connection
  704. cleanup_requested = True
  705. finally:
  706. cleanup_delay = self.reconnect_grace_periods.get(client_id)
  707. if not cleanup_requested and cleanup_delay is not None:
  708. # Delay cleanup, since client may attempt a reconnect
  709. # Wait on stopped event in case the server closes and we
  710. # can clean up earlier
  711. self.stopped.wait(timeout=cleanup_delay)
  712. with self.clients_lock:
  713. if client_id not in self.clients_last_seen:
  714. logger.info(f"{client_id} not found. Skipping clean up.")
  715. # Connection has already been cleaned up
  716. return
  717. last_seen = self.clients_last_seen[client_id]
  718. logger.info(
  719. f"{client_id} last started stream at {last_seen}. Current "
  720. f"stream started at {start_time}."
  721. )
  722. if last_seen > start_time:
  723. logger.info("Client reconnected. Skipping cleanup.")
  724. # Client has reconnected, don't clean up
  725. return
  726. logger.debug(f"Client detached: {client_id}")
  727. self.num_clients -= 1
  728. del self.clients_last_seen[client_id]
  729. if client_id in self.reconnect_grace_periods:
  730. del self.reconnect_grace_periods[client_id]
  731. server.set_result(None)
  732. class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer):
  733. def __init__(self, proxy_manager: ProxyManager):
  734. super().__init__()
  735. self.proxy_manager = proxy_manager
  736. def Logstream(self, request_iterator, context):
  737. request_iterator = RequestIteratorProxy(request_iterator)
  738. client_id = _get_client_id_from_context(context)
  739. if client_id == "":
  740. return
  741. logger.debug(f"New logstream connection from client {client_id}: ")
  742. channel = None
  743. # We need to retry a few times because the LogClient *may* connect
  744. # Before the DataClient has finished connecting.
  745. for i in range(LOGSTREAM_RETRIES):
  746. channel = self.proxy_manager.get_channel(client_id)
  747. if channel is not None:
  748. break
  749. logger.warning(f"Retrying Logstream connection. {i+1} attempts failed.")
  750. time.sleep(LOGSTREAM_RETRY_INTERVAL_SEC)
  751. if channel is None:
  752. context.set_code(grpc.StatusCode.NOT_FOUND)
  753. context.set_details(
  754. "Logstream proxy failed to connect. Channel for client "
  755. f"{client_id} not found."
  756. )
  757. return None
  758. stub = ray_client_pb2_grpc.RayletLogStreamerStub(channel)
  759. resp_stream = stub.Logstream(
  760. request_iterator, metadata=[("client_id", client_id)]
  761. )
  762. try:
  763. for resp in resp_stream:
  764. yield resp
  765. except Exception:
  766. logger.exception("Proxying Logstream failed!")
  767. def serve_proxier(
  768. host: str,
  769. port: int,
  770. gcs_address: Optional[str],
  771. *,
  772. redis_username: Optional[str] = None,
  773. redis_password: Optional[str] = None,
  774. session_dir: Optional[str] = None,
  775. runtime_env_agent_address: Optional[str] = None,
  776. node_id: Optional[str] = None,
  777. ):
  778. # Initialize internal KV to be used to upload and download working_dir
  779. # before calling ray.init within the RayletServicers.
  780. # NOTE(edoakes): redis_address and redis_password should only be None in
  781. # tests.
  782. if gcs_address is not None:
  783. gcs_cli = GcsClient(address=gcs_address)
  784. ray.experimental.internal_kv._initialize_internal_kv(gcs_cli)
  785. from ray._private.grpc_utils import create_grpc_server_with_interceptors
  786. server = create_grpc_server_with_interceptors(
  787. max_workers=CLIENT_SERVER_MAX_THREADS,
  788. thread_name_prefix="ray_client_proxier",
  789. options=GRPC_OPTIONS,
  790. asynchronous=False,
  791. )
  792. proxy_manager = ProxyManager(
  793. gcs_address,
  794. session_dir=session_dir,
  795. redis_username=redis_username,
  796. redis_password=redis_password,
  797. runtime_env_agent_address=runtime_env_agent_address,
  798. node_id=node_id,
  799. )
  800. task_servicer = RayletServicerProxy(None, proxy_manager)
  801. data_servicer = DataServicerProxy(proxy_manager)
  802. logs_servicer = LogstreamServicerProxy(proxy_manager)
  803. ray_client_pb2_grpc.add_RayletDriverServicer_to_server(task_servicer, server)
  804. ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(data_servicer, server)
  805. ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(logs_servicer, server)
  806. if not is_localhost(host):
  807. add_port_to_grpc_server(server, f"127.0.0.1:{port}")
  808. add_port_to_grpc_server(server, f"{host}:{port}")
  809. server.start()
  810. return ClientServerHandle(
  811. task_servicer=task_servicer,
  812. data_servicer=data_servicer,
  813. logs_servicer=logs_servicer,
  814. grpc_server=server,
  815. )