handle.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. import asyncio
  2. import logging
  3. import multiprocessing
  4. import os
  5. from typing import Optional, Union
  6. import multidict
  7. import ray.dashboard.consts as dashboard_consts
  8. from ray.dashboard.optional_deps import aiohttp
  9. from ray.dashboard.subprocesses.module import (
  10. SubprocessModule,
  11. SubprocessModuleConfig,
  12. run_module,
  13. )
  14. from ray.dashboard.subprocesses.utils import (
  15. ResponseType,
  16. get_http_session_to_module,
  17. module_logging_filename,
  18. )
  19. """
  20. This file contains code run in the parent process. It can start a subprocess and send
  21. messages to it. Requires non-minimal Ray.
  22. """
# Module-level logger named after this module; used by the proxy helpers and
# by SubprocessModuleHandle below.
logger = logging.getLogger(__name__)
  24. def filter_hop_by_hop_headers(
  25. headers: Union[dict[str, str], multidict.CIMultiDictProxy[str]],
  26. ) -> dict[str, str]:
  27. """
  28. Filter out hop-by-hop headers from the headers dict.
  29. """
  30. HOP_BY_HOP_HEADERS = {
  31. "connection",
  32. "keep-alive",
  33. "proxy-authenticate",
  34. "proxy-authorization",
  35. "te",
  36. "trailers",
  37. "transfer-encoding",
  38. "upgrade",
  39. }
  40. if isinstance(headers, multidict.CIMultiDictProxy):
  41. headers = dict(headers)
  42. filtered_headers = {
  43. key: value
  44. for key, value in headers.items()
  45. if key.lower() not in HOP_BY_HOP_HEADERS
  46. }
  47. return filtered_headers
  48. class SubprocessModuleHandle:
  49. """
  50. A handle to a module created as a subprocess. Can send messages to the module and
  51. receive responses. It only acts as a proxy to the aiohttp server running in the
  52. subprocess. On destruction, the subprocess is terminated.
  53. Lifecycle:
  54. 1. In SubprocessModuleHandle creation, the subprocess is started and runs an aiohttp
  55. server.
  56. 2. User must call start_module() and wait_for_module_ready() first.
  57. 3. SubprocessRouteTable.bind(handle)
  58. 4. app.add_routes(routes=SubprocessRouteTable.bound_routes())
  59. 5. Run the app.
  60. Health check (_do_periodic_health_check):
  61. Every 1s, do a health check by _do_once_health_check. If the module is
  62. unhealthy:
  63. 1. log the exception
  64. 2. log the last N lines of the log file
  65. 3. fail all active requests
  66. 4. restart the module
  67. TODO(ryw): define policy for health check:
  68. - check period (Now: 1s)
  69. - define unhealthy. (Now: process exits. TODO: check_health() for event loop hang)
  70. - check number of failures in a row before we deem it unhealthy (Now: N/A)
  71. - "max number of restarts"? (Now: infinite)
  72. """
  73. # Class variable. Force using spawn because Ray C bindings have static variables
  74. # that need to be re-initialized for a new process.
  75. mp_context = multiprocessing.get_context("spawn")
  76. def __init__(
  77. self,
  78. loop: asyncio.AbstractEventLoop,
  79. module_cls: type[SubprocessModule],
  80. config: SubprocessModuleConfig,
  81. ):
  82. self.loop = loop
  83. self.module_cls = module_cls
  84. self.config = config
  85. # Increment this when the module is restarted.
  86. self.incarnation = 0
  87. # Runtime states, set by start_module() and wait_for_module_ready(),
  88. # reset by destroy_module().
  89. self.parent_conn = None
  90. self.process = None
  91. self.http_client_session: Optional[aiohttp.ClientSession] = None
  92. self.health_check_task = None
  93. def str_for_state(self, incarnation: int, pid: Optional[int]):
  94. return f"SubprocessModuleHandle(module_cls={self.module_cls.__name__}, incarnation={incarnation}, pid={pid})"
  95. def __str__(self):
  96. return self.str_for_state(
  97. self.incarnation, self.process.pid if self.process else None
  98. )
  99. def start_module(self):
  100. """
  101. Start the module. Should be non-blocking.
  102. """
  103. self.parent_conn, child_conn = self.mp_context.Pipe()
  104. if not os.path.exists(self.config.socket_dir):
  105. os.makedirs(self.config.socket_dir)
  106. self.process = self.mp_context.Process(
  107. target=run_module,
  108. args=(
  109. self.module_cls,
  110. self.config,
  111. self.incarnation,
  112. child_conn,
  113. ),
  114. daemon=True,
  115. name=f"{self.module_cls.__name__}-{self.incarnation}",
  116. )
  117. self.process.start()
  118. child_conn.close()
  119. def wait_for_module_ready(self):
  120. """
  121. Wait for the module to be ready. This is called after start_module()
  122. and can be blocking.
  123. """
  124. if self.parent_conn.poll(dashboard_consts.SUBPROCESS_MODULE_WAIT_READY_TIMEOUT):
  125. try:
  126. self.parent_conn.recv()
  127. except EOFError:
  128. raise RuntimeError(
  129. f"Module {self.module_cls.__name__} failed to start. "
  130. "Received EOF from pipe."
  131. )
  132. self.parent_conn.close()
  133. self.parent_conn = None
  134. else:
  135. raise RuntimeError(
  136. f"Module {self.module_cls.__name__} failed to start. "
  137. f"Timeout after {dashboard_consts.SUBPROCESS_MODULE_WAIT_READY_TIMEOUT} seconds."
  138. )
  139. module_name = self.module_cls.__name__
  140. self.http_client_session = get_http_session_to_module(
  141. module_name, self.config.socket_dir, self.config.session_name
  142. )
  143. self.health_check_task = self.loop.create_task(self._do_periodic_health_check())
  144. async def destroy_module(self):
  145. """
  146. Destroy the module with complete resource cleanup.
  147. This is called when the module is unhealthy or being shut down.
  148. """
  149. self.incarnation += 1
  150. # 1. Cancel health check task first to avoid race conditions
  151. if self.health_check_task:
  152. # NOTE: destroy_module() can be invoked from within the periodic health
  153. # check task itself (see _do_periodic_health_check()).
  154. # Cancelling the *current* task would raise CancelledError at the next
  155. # await and prevent cleanup + restart from completing.
  156. current_task = asyncio.current_task()
  157. if current_task is None or self.health_check_task is not current_task:
  158. self.health_check_task.cancel()
  159. self.health_check_task = None
  160. # 2. Close parent connection
  161. if self.parent_conn:
  162. self.parent_conn.close()
  163. self.parent_conn = None
  164. # 3. Terminate process gracefully, then forcefully if needed
  165. if self.process:
  166. try:
  167. # First, try graceful termination
  168. if self.process.is_alive():
  169. self.process.terminate()
  170. logger.debug(
  171. f"Terminated process {self.process.pid}, waiting for exit..."
  172. )
  173. # Wait for process to exit (with timeout)
  174. self.process.join(
  175. timeout=dashboard_consts.SUBPROCESS_MODULE_GRACEFUL_SHUTDOWN_TIMEOUT
  176. )
  177. # Force kill if still alive
  178. if self.process.is_alive():
  179. logger.warning(
  180. f"Process {self.process.pid} did not exit gracefully, "
  181. "force killing..."
  182. )
  183. self.process.kill()
  184. self.process.join(
  185. timeout=dashboard_consts.SUBPROCESS_MODULE_JOIN_TIMEOUT
  186. )
  187. else:
  188. # Process already dead, just wait for it
  189. self.process.join(
  190. timeout=dashboard_consts.SUBPROCESS_MODULE_JOIN_TIMEOUT
  191. )
  192. logger.debug(f"Process {self.process.pid} terminated successfully")
  193. except Exception as e:
  194. logger.warning(f"Error terminating process: {e}")
  195. finally:
  196. self.process = None
  197. # 4. Close HTTP client session with proper cleanup
  198. if self.http_client_session:
  199. await self.http_client_session.close()
  200. self.http_client_session = None
  201. async def _health_check(self) -> aiohttp.web.Response:
  202. """
  203. Do internal health check. The module should respond immediately with a 200 OK.
  204. This can be used to measure module responsiveness in RTT, it also indicates
  205. subprocess event loop lag.
  206. Currently you get a 200 OK with body = b'success'. Later if we want we can add more
  207. observability payloads.
  208. """
  209. resp = await self.http_client_session.get("http://localhost/api/healthz")
  210. return aiohttp.web.Response(
  211. status=resp.status,
  212. headers=filter_hop_by_hop_headers(resp.headers),
  213. body=await resp.read(),
  214. )
  215. async def _do_once_health_check(self):
  216. """
  217. Do a health check once. We check for:
  218. 1. if the process exits, it's considered died.
  219. 2. if the health check endpoint returns non-200, it's considered unhealthy.
  220. """
  221. if self.process.exitcode is not None:
  222. raise RuntimeError(f"Process exited with code {self.process.exitcode}")
  223. resp = await self._health_check()
  224. if resp.status != 200:
  225. raise RuntimeError(f"Health check failed: status code is {resp.status}")
  226. async def _do_periodic_health_check(self):
  227. """
  228. Every 1s, do a health check. If the module is unhealthy:
  229. 1. log the exception
  230. 2. log the last N lines of the log file
  231. 3. restart the module
  232. """
  233. while True:
  234. try:
  235. await self._do_once_health_check()
  236. except Exception:
  237. filename = module_logging_filename(
  238. self.module_cls.__name__, self.config.logging_filename
  239. )
  240. logger.exception(
  241. f"Module {self.module_cls.__name__} is unhealthy. Please refer to "
  242. f"{self.config.log_dir}/{filename} for more details. Failing all "
  243. "active requests."
  244. )
  245. await self.destroy_module()
  246. self.start_module()
  247. self.wait_for_module_ready()
  248. return
  249. await asyncio.sleep(1)
  250. async def proxy_request(
  251. self, request: aiohttp.web.Request, resp_type: ResponseType = ResponseType.HTTP
  252. ) -> aiohttp.web.StreamResponse:
  253. """
  254. Sends a new request to the subprocess and returns the response.
  255. """
  256. if resp_type == ResponseType.HTTP:
  257. return await self.proxy_http(request)
  258. if resp_type == ResponseType.STREAM:
  259. return await self.proxy_stream(request)
  260. if resp_type == ResponseType.WEBSOCKET:
  261. return await self.proxy_websocket(request)
  262. raise ValueError(f"Unknown response type: {resp_type}")
  263. async def proxy_http(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
  264. """
  265. Proxy handler for non-streaming HTTP API
  266. It forwards the method, query string, headers, and body to the backend.
  267. """
  268. url = f"http://localhost{request.path_qs}"
  269. body = await request.read()
  270. async with self.http_client_session.request(
  271. request.method,
  272. url,
  273. data=body,
  274. headers=filter_hop_by_hop_headers(request.headers),
  275. allow_redirects=False,
  276. ) as backend_resp:
  277. resp_body = await backend_resp.read()
  278. return aiohttp.web.Response(
  279. status=backend_resp.status,
  280. headers=filter_hop_by_hop_headers(backend_resp.headers),
  281. body=resp_body,
  282. )
  283. async def proxy_stream(
  284. self, request: aiohttp.web.Request
  285. ) -> aiohttp.web.StreamResponse:
  286. """
  287. Proxy handler for streaming HTTP API.
  288. It forwards the method, query string, and body to the backend.
  289. """
  290. url = f"http://localhost{request.path_qs}"
  291. body = await request.read()
  292. async with self.http_client_session.request(
  293. request.method,
  294. url,
  295. data=body,
  296. headers=filter_hop_by_hop_headers(request.headers),
  297. ) as backend_resp:
  298. proxy_resp = aiohttp.web.StreamResponse(
  299. status=backend_resp.status,
  300. headers=filter_hop_by_hop_headers(backend_resp.headers),
  301. )
  302. await proxy_resp.prepare(request)
  303. async for chunk, _ in backend_resp.content.iter_chunks():
  304. await proxy_resp.write(chunk)
  305. await proxy_resp.write_eof()
  306. return proxy_resp
  307. async def proxy_websocket(
  308. self, request: aiohttp.web.Request
  309. ) -> aiohttp.web.StreamResponse:
  310. """
  311. Proxy handler for WebSocket API
  312. It establishes a WebSocket connection with the client and simultaneously connects
  313. to the backend server's WebSocket endpoint. Messages are forwarded in single
  314. direction from the backend to the client.
  315. If the backend responds with normal HTTP response, then try to treat it as a normal
  316. HTTP request and calls proxy_http instead.
  317. TODO: Support bidirectional communication if needed. We only support one direction
  318. because it's sufficient for the current use case.
  319. """
  320. url = f"http://localhost{request.path_qs}"
  321. try:
  322. async with self.http_client_session.ws_connect(
  323. url, headers=filter_hop_by_hop_headers(request.headers)
  324. ) as ws_to_backend:
  325. ws_from_client = aiohttp.web.WebSocketResponse()
  326. await ws_from_client.prepare(request)
  327. async for msg in ws_to_backend:
  328. if msg.type == aiohttp.WSMsgType.TEXT:
  329. await ws_from_client.send_str(msg.data)
  330. elif msg.type == aiohttp.WSMsgType.BINARY:
  331. await ws_from_client.send_bytes(msg.data)
  332. else:
  333. logger.error(f"Unknown msg type: {msg.type}")
  334. await ws_from_client.close()
  335. return ws_from_client
  336. except aiohttp.WSServerHandshakeError as e:
  337. logger.warning(f"WebSocket handshake error: {repr(e)}")
  338. # Try to treat it as a normal HTTP request
  339. return await self.proxy_http(request)
  340. except Exception as e:
  341. logger.error(f"WebSocket proxy error: {repr(e)}")
  342. raise aiohttp.web.HTTPInternalServerError(reason="WebSocket proxy error")