state_head.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. import asyncio
  2. import logging
  3. import re
  4. from concurrent.futures import ThreadPoolExecutor
  5. from dataclasses import asdict
  6. from datetime import datetime
  7. from typing import Optional
  8. import aiohttp.web
  9. from aiohttp.web import Response
  10. import ray
  11. from ray import ActorID
  12. from ray._common.usage.usage_lib import TagKey, record_extra_usage_tag
  13. from ray._private.ray_constants import env_integer
  14. from ray.core.generated.gcs_pb2 import ActorTableData
  15. from ray.dashboard.consts import (
  16. RAY_STATE_SERVER_MAX_HTTP_REQUEST,
  17. RAY_STATE_SERVER_MAX_HTTP_REQUEST_ALLOWED,
  18. RAY_STATE_SERVER_MAX_HTTP_REQUEST_ENV_NAME,
  19. )
  20. from ray.dashboard.modules.log.log_manager import LogsManager
  21. from ray.dashboard.state_aggregator import StateAPIManager
  22. from ray.dashboard.state_api_utils import (
  23. do_reply,
  24. handle_list_api,
  25. handle_summary_api,
  26. options_from_req,
  27. )
  28. from ray.dashboard.subprocesses.module import SubprocessModule
  29. from ray.dashboard.subprocesses.routes import SubprocessRouteTable as routes
  30. from ray.dashboard.subprocesses.utils import ResponseType
  31. from ray.dashboard.utils import HTTPStatusCode, RateLimitedModule
  32. from ray.util.state.common import (
  33. DEFAULT_DOWNLOAD_FILENAME,
  34. DEFAULT_LOG_LIMIT,
  35. DEFAULT_RPC_TIMEOUT,
  36. GetLogOptions,
  37. )
  38. from ray.util.state.exception import DataSourceUnavailable
  39. from ray.util.state.state_manager import StateDataSourceClient
  40. logger = logging.getLogger(__name__)
  41. # NOTE: Executor in this head is intentionally constrained to just 1 thread by
  42. # default to limit its concurrency, therefore reducing potential for
  43. # GIL contention
  44. RAY_DASHBOARD_STATE_HEAD_TPE_MAX_WORKERS = env_integer(
  45. "RAY_DASHBOARD_STATE_HEAD_TPE_MAX_WORKERS", 1
  46. )
  47. # For filtering ANSI escape codes; the byte string used in the regex is equivalent to r'\x1b\[[\d;]+m'.
  48. ANSI_ESC_PATTERN = re.compile(b"\x1b\\x5b[(\x30-\x39)\x3b]+\x6d")
class StateHead(SubprocessModule, RateLimitedModule):
    """Module to obtain state information from the Ray cluster.

    It is responsible for state observability APIs such as
    ray.list_actors(), ray.get_actor(), ray.summary_actors().

    The handlers are exposed as HTTP GET routes under ``/api/v0/``. Most of
    them are wrapped with ``RateLimitedModule.enforce_max_concurrent_calls``.
    """
  54. def __init__(self, *args, **kwargs):
  55. """Initialize for handling RESTful requests from State API Client"""
  56. SubprocessModule.__init__(self, *args, **kwargs)
  57. # We don't allow users to configure too high a rate limit
  58. RateLimitedModule.__init__(
  59. self,
  60. min(
  61. RAY_STATE_SERVER_MAX_HTTP_REQUEST,
  62. RAY_STATE_SERVER_MAX_HTTP_REQUEST_ALLOWED,
  63. ),
  64. )
  65. self._state_api_data_source_client = None
  66. self._state_api = None
  67. self._log_api = None
  68. self._executor = ThreadPoolExecutor(
  69. max_workers=RAY_DASHBOARD_STATE_HEAD_TPE_MAX_WORKERS,
  70. thread_name_prefix="state_head_executor",
  71. )
  72. # To make sure that the internal KV is initialized by getting the lazy property
  73. assert self.gcs_client is not None
  74. assert ray.experimental.internal_kv._internal_kv_initialized()
  75. async def limit_handler_(self):
  76. return do_reply(
  77. status_code=HTTPStatusCode.TOO_MANY_REQUESTS,
  78. error_message=(
  79. "Max number of in-progress requests="
  80. f"{self.max_num_call_} reached. "
  81. "To set a higher limit, set environment variable: "
  82. f"export {RAY_STATE_SERVER_MAX_HTTP_REQUEST_ENV_NAME}='xxx'. "
  83. f"Max allowed = {RAY_STATE_SERVER_MAX_HTTP_REQUEST_ALLOWED}"
  84. ),
  85. result=None,
  86. )
  87. @routes.get("/api/v0/actors")
  88. @RateLimitedModule.enforce_max_concurrent_calls
  89. async def list_actors(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
  90. record_extra_usage_tag(TagKey.CORE_STATE_API_LIST_ACTORS, "1")
  91. return await handle_list_api(self._state_api.list_actors, req)
  92. @routes.get("/api/v0/jobs")
  93. @RateLimitedModule.enforce_max_concurrent_calls
  94. async def list_jobs(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
  95. record_extra_usage_tag(TagKey.CORE_STATE_API_LIST_JOBS, "1")
  96. try:
  97. result = await self._state_api.list_jobs(option=options_from_req(req))
  98. return do_reply(
  99. status_code=HTTPStatusCode.OK,
  100. error_message="",
  101. result=asdict(result),
  102. )
  103. except DataSourceUnavailable as e:
  104. return do_reply(
  105. status_code=HTTPStatusCode.INTERNAL_ERROR,
  106. error_message=str(e),
  107. result=None,
  108. )
  109. @routes.get("/api/v0/nodes")
  110. @RateLimitedModule.enforce_max_concurrent_calls
  111. async def list_nodes(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
  112. record_extra_usage_tag(TagKey.CORE_STATE_API_LIST_NODES, "1")
  113. return await handle_list_api(self._state_api.list_nodes, req)
  114. @routes.get("/api/v0/placement_groups")
  115. @RateLimitedModule.enforce_max_concurrent_calls
  116. async def list_placement_groups(
  117. self, req: aiohttp.web.Request
  118. ) -> aiohttp.web.Response:
  119. record_extra_usage_tag(TagKey.CORE_STATE_API_LIST_PLACEMENT_GROUPS, "1")
  120. return await handle_list_api(self._state_api.list_placement_groups, req)
  121. @routes.get("/api/v0/workers")
  122. @RateLimitedModule.enforce_max_concurrent_calls
  123. async def list_workers(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
  124. record_extra_usage_tag(TagKey.CORE_STATE_API_LIST_WORKERS, "1")
  125. return await handle_list_api(self._state_api.list_workers, req)
  126. @routes.get("/api/v0/tasks")
  127. @RateLimitedModule.enforce_max_concurrent_calls
  128. async def list_tasks(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
  129. record_extra_usage_tag(TagKey.CORE_STATE_API_LIST_TASKS, "1")
  130. return await handle_list_api(self._state_api.list_tasks, req)
  131. @routes.get("/api/v0/objects")
  132. @RateLimitedModule.enforce_max_concurrent_calls
  133. async def list_objects(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
  134. record_extra_usage_tag(TagKey.CORE_STATE_API_LIST_OBJECTS, "1")
  135. return await handle_list_api(self._state_api.list_objects, req)
  136. @routes.get("/api/v0/runtime_envs")
  137. @RateLimitedModule.enforce_max_concurrent_calls
  138. async def list_runtime_envs(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
  139. record_extra_usage_tag(TagKey.CORE_STATE_API_LIST_RUNTIME_ENVS, "1")
  140. return await handle_list_api(self._state_api.list_runtime_envs, req)
  141. @routes.get("/api/v0/logs")
  142. @RateLimitedModule.enforce_max_concurrent_calls
  143. async def list_logs(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
  144. """Return a list of log files on a given node id.
  145. Unlike other list APIs that display all existing resources in the cluster,
  146. this API always require to specify node id and node ip.
  147. """
  148. record_extra_usage_tag(TagKey.CORE_STATE_API_LIST_LOGS, "1")
  149. glob_filter = req.query.get("glob", "*")
  150. node_id = req.query.get("node_id", None)
  151. node_ip = req.query.get("node_ip", None)
  152. timeout = int(req.query.get("timeout", DEFAULT_RPC_TIMEOUT))
  153. if not node_id and not node_ip:
  154. return do_reply(
  155. status_code=HTTPStatusCode.BAD_REQUEST,
  156. error_message=(
  157. "Both node id and node ip are not provided. "
  158. "Please provide at least one of them."
  159. ),
  160. result=None,
  161. )
  162. if not node_id:
  163. node_id = await self._log_api.ip_to_node_id(node_ip)
  164. if not node_id:
  165. return do_reply(
  166. status_code=HTTPStatusCode.NOT_FOUND,
  167. error_message=(
  168. f"Cannot find matching node_id for a given node ip {node_ip}"
  169. ),
  170. result=None,
  171. )
  172. try:
  173. result = await self._log_api.list_logs(
  174. node_id, timeout, glob_filter=glob_filter
  175. )
  176. except DataSourceUnavailable as e:
  177. return do_reply(
  178. status_code=HTTPStatusCode.INTERNAL_ERROR,
  179. error_message=str(e),
  180. result=None,
  181. )
  182. return do_reply(
  183. status_code=HTTPStatusCode.OK,
  184. error_message="",
  185. result=result,
  186. )
  187. @routes.get("/api/v0/logs/{media_type}", resp_type=ResponseType.STREAM)
  188. @RateLimitedModule.enforce_max_concurrent_calls
  189. async def get_logs(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
  190. """
  191. Fetches logs from the given criteria.
  192. """
  193. record_extra_usage_tag(TagKey.CORE_STATE_API_GET_LOG, "1")
  194. options = GetLogOptions(
  195. timeout=int(req.query.get("timeout", DEFAULT_RPC_TIMEOUT)),
  196. node_id=req.query.get("node_id", None),
  197. node_ip=req.query.get("node_ip", None),
  198. media_type=req.match_info.get("media_type", "file"),
  199. # The filename to match on the server side.
  200. filename=req.query.get("filename", None),
  201. # The filename to download the log as on the client side.
  202. download_filename=req.query.get(
  203. "download_filename", DEFAULT_DOWNLOAD_FILENAME
  204. ),
  205. actor_id=req.query.get("actor_id", None),
  206. task_id=req.query.get("task_id", None),
  207. submission_id=req.query.get("submission_id", None),
  208. pid=req.query.get("pid", None),
  209. lines=req.query.get("lines", DEFAULT_LOG_LIMIT),
  210. interval=req.query.get("interval", None),
  211. suffix=req.query.get("suffix", "out"),
  212. attempt_number=req.query.get("attempt_number", 0),
  213. )
  214. filtering_ansi_code = req.query.get("filter_ansi_code", False)
  215. if isinstance(filtering_ansi_code, str):
  216. filtering_ansi_code = filtering_ansi_code.lower() == "true"
  217. logger.info(f"Streaming logs with options: {options}")
  218. logger.info(f"Filtering ANSI escape codes: {filtering_ansi_code}")
  219. async def get_actor_fn(actor_id: ActorID) -> Optional[ActorTableData]:
  220. actor_info_dict = await self.gcs_client.async_get_all_actor_info(
  221. actor_id=actor_id
  222. )
  223. if len(actor_info_dict) == 0:
  224. return None
  225. return actor_info_dict[actor_id]
  226. response = aiohttp.web.StreamResponse(
  227. headers={
  228. "Content-Disposition": (
  229. f'attachment; filename="{options.download_filename}"'
  230. )
  231. },
  232. )
  233. response.content_type = "text/plain"
  234. logs_gen = self._log_api.stream_logs(options, get_actor_fn)
  235. # Handle the first chunk separately and returns 500 if an error occurs.
  236. try:
  237. first_chunk = await logs_gen.__anext__()
  238. # Filter ANSI escape codes
  239. if filtering_ansi_code:
  240. first_chunk = ANSI_ESC_PATTERN.sub(b"", first_chunk)
  241. await response.prepare(req)
  242. await response.write(first_chunk)
  243. except StopAsyncIteration:
  244. pass
  245. except asyncio.CancelledError:
  246. # This happens when the client side closes the connection.
  247. # Force close the connection and do no-op.
  248. response.force_close()
  249. raise
  250. except Exception as e:
  251. logger.exception("Error while streaming logs")
  252. raise aiohttp.web.HTTPInternalServerError(text=str(e))
  253. try:
  254. async for logs in logs_gen:
  255. # Filter ANSI escape codes
  256. if filtering_ansi_code:
  257. logs = ANSI_ESC_PATTERN.sub(b"", logs)
  258. await response.write(logs)
  259. except Exception:
  260. logger.exception("Error while streaming logs")
  261. response.force_close()
  262. raise
  263. await response.write_eof()
  264. return response
  265. @routes.get("/api/v0/tasks/summarize")
  266. @RateLimitedModule.enforce_max_concurrent_calls
  267. async def summarize_tasks(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
  268. record_extra_usage_tag(TagKey.CORE_STATE_API_SUMMARIZE_TASKS, "1")
  269. return await handle_summary_api(self._state_api.summarize_tasks, req)
  270. @routes.get("/api/v0/actors/summarize")
  271. @RateLimitedModule.enforce_max_concurrent_calls
  272. async def summarize_actors(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
  273. record_extra_usage_tag(TagKey.CORE_STATE_API_SUMMARIZE_ACTORS, "1")
  274. return await handle_summary_api(self._state_api.summarize_actors, req)
  275. @routes.get("/api/v0/objects/summarize")
  276. @RateLimitedModule.enforce_max_concurrent_calls
  277. async def summarize_objects(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
  278. record_extra_usage_tag(TagKey.CORE_STATE_API_SUMMARIZE_OBJECTS, "1")
  279. return await handle_summary_api(self._state_api.summarize_objects, req)
  280. @routes.get("/api/v0/tasks/timeline")
  281. @RateLimitedModule.enforce_max_concurrent_calls
  282. async def tasks_timeline(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
  283. job_id = req.query.get("job_id")
  284. download = req.query.get("download")
  285. result = await self._state_api.generate_task_timeline(job_id)
  286. if download == "1":
  287. # Support download if specified.
  288. now_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
  289. content_disposition = (
  290. f'attachment; filename="timeline-{job_id}-{now_str}.json"'
  291. )
  292. headers = {"Content-Disposition": content_disposition}
  293. else:
  294. headers = None
  295. return Response(text=result, content_type="application/json", headers=headers)
  296. @routes.get("/api/v0/delay/{delay_s}")
  297. async def delayed_response(self, req: aiohttp.web.Request):
  298. """Testing only. Response after a specified delay."""
  299. delay = int(req.match_info.get("delay_s", 10))
  300. await asyncio.sleep(delay)
  301. return do_reply(
  302. status_code=HTTPStatusCode.OK,
  303. error_message="",
  304. result={},
  305. partial_failure_warning=None,
  306. )
  307. async def run(self):
  308. await SubprocessModule.run(self)
  309. gcs_channel = self.aiogrpc_gcs_channel
  310. self._state_api_data_source_client = StateDataSourceClient(
  311. gcs_channel, self.gcs_client
  312. )
  313. self._state_api = StateAPIManager(
  314. self._state_api_data_source_client,
  315. self._executor,
  316. )
  317. self._log_api = LogsManager(self._state_api_data_source_client)