# state_manager.py (19 KB)
  1. import dataclasses
  2. import inspect
  3. import json
  4. import logging
  5. from functools import wraps
  6. from typing import List, Optional, Tuple
  7. import aiohttp
  8. import grpc
  9. from grpc.aio._call import UnaryStreamCall
  10. import ray.dashboard.consts as dashboard_consts
  11. import ray.dashboard.modules.log.log_consts as log_consts
  12. from ray._common.network_utils import build_address
  13. from ray._common.utils import hex_to_binary
  14. from ray._private import ray_constants
  15. from ray._private.authentication.http_token_authentication import (
  16. get_auth_headers_if_auth_enabled,
  17. )
  18. from ray._raylet import ActorID, GcsClient, JobID, NodeID, TaskID
  19. from ray.core.generated import gcs_service_pb2_grpc
  20. from ray.core.generated.gcs_pb2 import ActorTableData, GcsNodeInfo
  21. from ray.core.generated.gcs_service_pb2 import (
  22. FilterPredicate,
  23. GetAllActorInfoReply,
  24. GetAllActorInfoRequest,
  25. GetAllNodeInfoReply,
  26. GetAllNodeInfoRequest,
  27. GetAllPlacementGroupReply,
  28. GetAllPlacementGroupRequest,
  29. GetAllWorkerInfoReply,
  30. GetAllWorkerInfoRequest,
  31. GetTaskEventsReply,
  32. GetTaskEventsRequest,
  33. )
  34. from ray.core.generated.node_manager_pb2 import (
  35. GetObjectsInfoReply,
  36. GetObjectsInfoRequest,
  37. )
  38. from ray.core.generated.node_manager_pb2_grpc import NodeManagerServiceStub
  39. from ray.core.generated.reporter_pb2 import (
  40. ListLogsReply,
  41. ListLogsRequest,
  42. StreamLogRequest,
  43. )
  44. from ray.core.generated.reporter_pb2_grpc import LogServiceStub
  45. from ray.core.generated.runtime_env_agent_pb2 import (
  46. GetRuntimeEnvsInfoReply,
  47. GetRuntimeEnvsInfoRequest,
  48. )
  49. from ray.dashboard.modules.job.common import JobInfoStorageClient
  50. from ray.dashboard.modules.job.pydantic_models import JobDetails, JobType
  51. from ray.dashboard.modules.job.utils import get_driver_jobs
  52. from ray.util.state.common import (
  53. RAY_MAX_LIMIT_FROM_DATA_SOURCE,
  54. PredicateType,
  55. SupportedFilterType,
  56. )
  57. from ray.util.state.exception import DataSourceUnavailable
  58. logger = logging.getLogger(__name__)
  59. _STATE_MANAGER_GRPC_OPTIONS = [
  60. *ray_constants.GLOBAL_GRPC_OPTIONS,
  61. ("grpc.max_send_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
  62. ("grpc.max_receive_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
  63. ]
  64. def handle_grpc_network_errors(func):
  65. """Decorator to add a network handling logic.
  66. It is a helper method for `StateDataSourceClient`.
  67. The method can only be used for async methods.
  68. """
  69. assert inspect.iscoroutinefunction(func)
  70. @wraps(func)
  71. async def api_with_network_error_handler(*args, **kwargs):
  72. """Apply the network error handling logic to each APIs,
  73. such as retry or exception policies.
  74. Returns:
  75. If RPC succeeds, it returns what the original function returns.
  76. If RPC fails, it raises exceptions.
  77. Raises:
  78. DataSourceUnavailable: if the source is unavailable because it is down
  79. or there's a slow network issue causing timeout.
  80. Exception: Otherwise, the raw network exceptions (e.g., gRPC) will be
  81. raised.
  82. """
  83. try:
  84. return await func(*args, **kwargs)
  85. except grpc.aio.AioRpcError as e:
  86. if (
  87. e.code() == grpc.StatusCode.DEADLINE_EXCEEDED
  88. or e.code() == grpc.StatusCode.UNAVAILABLE
  89. ):
  90. raise DataSourceUnavailable(
  91. "Failed to query the data source. "
  92. "It is either there's a network issue, or the source is down."
  93. ) from e
  94. else:
  95. logger.exception(e)
  96. raise e
  97. return api_with_network_error_handler
  98. class StateDataSourceClient:
  99. """The client to query states from various data sources such as Raylet, GCS, Agents.
  100. Note that it doesn't directly query core workers. They are proxied through raylets.
  101. The module is not in charge of service discovery. The caller is responsible for
  102. finding services and register stubs through `register*` APIs.
  103. Non `register*` APIs
  104. - Return the protobuf directly if it succeeds to query the source.
  105. - Raises an exception if there's any network issue.
  106. - throw a ValueError if it cannot find the source.
  107. """
  108. def __init__(self, gcs_channel: grpc.aio.Channel, gcs_client: GcsClient):
  109. self.register_gcs_client(gcs_channel)
  110. self._job_client = JobInfoStorageClient(gcs_client)
  111. self._gcs_client = gcs_client
  112. self._client_session = aiohttp.ClientSession()
  113. def register_gcs_client(self, gcs_channel: grpc.aio.Channel):
  114. self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
  115. gcs_channel
  116. )
  117. self._gcs_pg_info_stub = gcs_service_pb2_grpc.PlacementGroupInfoGcsServiceStub(
  118. gcs_channel
  119. )
  120. self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(
  121. gcs_channel
  122. )
  123. self._gcs_worker_info_stub = gcs_service_pb2_grpc.WorkerInfoGcsServiceStub(
  124. gcs_channel
  125. )
  126. self._gcs_task_info_stub = gcs_service_pb2_grpc.TaskInfoGcsServiceStub(
  127. gcs_channel
  128. )
  129. def get_raylet_stub(self, ip: str, port: int):
  130. from ray._private.grpc_utils import init_grpc_channel
  131. options = _STATE_MANAGER_GRPC_OPTIONS
  132. channel = init_grpc_channel(build_address(ip, port), options, asynchronous=True)
  133. return NodeManagerServiceStub(channel)
  134. async def get_log_service_stub(self, node_id: NodeID) -> LogServiceStub:
  135. """Returns None if the agent on the node is not registered in Internal KV."""
  136. from ray._private.grpc_utils import init_grpc_channel
  137. agent_addr = await self._gcs_client.async_internal_kv_get(
  138. f"{dashboard_consts.DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id.hex()}".encode(),
  139. namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
  140. timeout=dashboard_consts.GCS_RPC_TIMEOUT_SECONDS,
  141. )
  142. if not agent_addr:
  143. return None
  144. ip, http_port, grpc_port = json.loads(agent_addr)
  145. options = ray_constants.GLOBAL_GRPC_OPTIONS
  146. channel = init_grpc_channel(
  147. build_address(ip, grpc_port), options=options, asynchronous=True
  148. )
  149. return LogServiceStub(channel)
  150. async def ip_to_node_id(self, ip: Optional[str]) -> Optional[str]:
  151. """Return the node id in hex that corresponds to the given ip.
  152. Args:
  153. ip: The ip address.
  154. Returns:
  155. None if the corresponding id doesn't exist.
  156. Node id otherwise. If None node_ip is given,
  157. it will also return None.
  158. """
  159. if not ip:
  160. return None
  161. # Uses the dashboard agent keys to find ip -> id mapping.
  162. agent_addr = await self._gcs_client.async_internal_kv_get(
  163. f"{dashboard_consts.DASHBOARD_AGENT_ADDR_IP_PREFIX}{ip}".encode(),
  164. namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
  165. timeout=dashboard_consts.GCS_RPC_TIMEOUT_SECONDS,
  166. )
  167. if not agent_addr:
  168. return None
  169. node_id, http_port, grpc_port = json.loads(agent_addr)
  170. return node_id
  171. @handle_grpc_network_errors
  172. async def get_all_actor_info(
  173. self,
  174. timeout: int = None,
  175. limit: int = RAY_MAX_LIMIT_FROM_DATA_SOURCE,
  176. filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None,
  177. ) -> Optional[GetAllActorInfoReply]:
  178. if filters is None:
  179. filters = []
  180. req_filters = GetAllActorInfoRequest.Filters()
  181. for filter in filters:
  182. key, predicate, value = filter
  183. if predicate != "=":
  184. # We only support EQUAL predicate for source side filtering.
  185. continue
  186. if key == "actor_id":
  187. req_filters.actor_id = ActorID(hex_to_binary(value)).binary()
  188. elif key == "state":
  189. # Convert to uppercase.
  190. value = value.upper()
  191. if value not in ActorTableData.ActorState.keys():
  192. raise ValueError(f"Invalid actor state for filtering: {value}")
  193. req_filters.state = ActorTableData.ActorState.Value(value)
  194. elif key == "job_id":
  195. req_filters.job_id = JobID(hex_to_binary(value)).binary()
  196. request = GetAllActorInfoRequest(limit=limit, filters=req_filters)
  197. reply = await self._gcs_actor_info_stub.GetAllActorInfo(
  198. request, timeout=timeout
  199. )
  200. return reply
  201. @handle_grpc_network_errors
  202. async def get_all_task_info(
  203. self,
  204. timeout: int = None,
  205. limit: int = RAY_MAX_LIMIT_FROM_DATA_SOURCE,
  206. filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None,
  207. exclude_driver: bool = False,
  208. ) -> Optional[GetTaskEventsReply]:
  209. if filters is None:
  210. filters = []
  211. req_filters = GetTaskEventsRequest.Filters()
  212. for filter in filters:
  213. key, predicate, value = filter
  214. filter_predicate = None
  215. if predicate == "=":
  216. filter_predicate = FilterPredicate.EQUAL
  217. elif predicate == "!=":
  218. filter_predicate = FilterPredicate.NOT_EQUAL
  219. else:
  220. # We only support EQUAL and NOT_EQUAL predicate for source side
  221. # filtering. If invalid predicates were specified, it should already be
  222. # raised when the filters arguments are parsed
  223. assert False, "Invalid predicate: " + predicate
  224. if key == "actor_id":
  225. actor_filter = GetTaskEventsRequest.Filters.ActorIdFilter()
  226. actor_filter.actor_id = ActorID(hex_to_binary(value)).binary()
  227. actor_filter.predicate = filter_predicate
  228. req_filters.actor_filters.append(actor_filter)
  229. elif key == "job_id":
  230. job_filter = GetTaskEventsRequest.Filters.JobIdFilter()
  231. job_filter.job_id = JobID(hex_to_binary(value)).binary()
  232. job_filter.predicate = filter_predicate
  233. req_filters.job_filters.append(job_filter)
  234. elif key == "task_id":
  235. task_filter = GetTaskEventsRequest.Filters.TaskIdFilter()
  236. task_filter.task_id = TaskID(hex_to_binary(value)).binary()
  237. task_filter.predicate = filter_predicate
  238. req_filters.task_filters.append(task_filter)
  239. elif key == "name":
  240. task_name_filter = GetTaskEventsRequest.Filters.TaskNameFilter()
  241. task_name_filter.task_name = value
  242. task_name_filter.predicate = filter_predicate
  243. req_filters.task_name_filters.append(task_name_filter)
  244. elif key == "state":
  245. state_filter = GetTaskEventsRequest.Filters.StateFilter()
  246. state_filter.state = value
  247. state_filter.predicate = filter_predicate
  248. req_filters.state_filters.append(state_filter)
  249. else:
  250. continue
  251. req_filters.exclude_driver = exclude_driver
  252. request = GetTaskEventsRequest(limit=limit, filters=req_filters)
  253. reply = await self._gcs_task_info_stub.GetTaskEvents(request, timeout=timeout)
  254. return reply
  255. @handle_grpc_network_errors
  256. async def get_all_placement_group_info(
  257. self, timeout: int = None, limit: int = RAY_MAX_LIMIT_FROM_DATA_SOURCE
  258. ) -> Optional[GetAllPlacementGroupReply]:
  259. request = GetAllPlacementGroupRequest(limit=limit)
  260. reply = await self._gcs_pg_info_stub.GetAllPlacementGroup(
  261. request, timeout=timeout
  262. )
  263. return reply
  264. @handle_grpc_network_errors
  265. async def get_all_node_info(
  266. self,
  267. timeout: int = None,
  268. limit: int = RAY_MAX_LIMIT_FROM_DATA_SOURCE,
  269. filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None,
  270. ) -> Optional[GetAllNodeInfoReply]:
  271. # TODO(ryw): move this to GcsClient.async_get_all_node_info, i.e.
  272. # InnerGcsClient.async_get_all_node_info
  273. if filters is None:
  274. filters = []
  275. node_selectors = []
  276. state_filter = None
  277. for filter in filters:
  278. key, predicate, value = filter
  279. if predicate != "=":
  280. # We only support EQUAL predicate for source side filtering.
  281. continue
  282. if key == "node_id":
  283. node_selector = GetAllNodeInfoRequest.NodeSelector()
  284. node_selector.node_id = NodeID(hex_to_binary(value)).binary()
  285. node_selectors.append(node_selector)
  286. elif key == "state":
  287. value = value.upper()
  288. if value not in GcsNodeInfo.GcsNodeState.keys():
  289. raise ValueError(f"Invalid node state for filtering: {value}")
  290. state_filter = GcsNodeInfo.GcsNodeState.Value(value)
  291. elif key == "node_name":
  292. node_selector = GetAllNodeInfoRequest.NodeSelector()
  293. node_selector.node_name = value
  294. node_selectors.append(node_selector)
  295. else:
  296. continue
  297. request = GetAllNodeInfoRequest(
  298. limit=limit, node_selectors=node_selectors, state_filter=state_filter
  299. )
  300. reply = await self._gcs_node_info_stub.GetAllNodeInfo(request, timeout=timeout)
  301. return reply
  302. @handle_grpc_network_errors
  303. async def get_all_worker_info(
  304. self,
  305. timeout: int = None,
  306. limit: int = RAY_MAX_LIMIT_FROM_DATA_SOURCE,
  307. filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None,
  308. ) -> Optional[GetAllWorkerInfoReply]:
  309. if filters is None:
  310. filters = []
  311. req_filters = GetAllWorkerInfoRequest.Filters()
  312. for filter in filters:
  313. key, predicate, value = filter
  314. # Special treatments for the Ray Debugger.
  315. if (
  316. key == "num_paused_threads"
  317. and predicate in ("!=", ">")
  318. and value == "0"
  319. ):
  320. req_filters.exist_paused_threads = True
  321. continue
  322. if key == "is_alive" and predicate == "=" and value == "True":
  323. req_filters.is_alive = True
  324. continue
  325. else:
  326. continue
  327. request = GetAllWorkerInfoRequest(limit=limit, filters=req_filters)
  328. reply = await self._gcs_worker_info_stub.GetAllWorkerInfo(
  329. request, timeout=timeout
  330. )
  331. return reply
  332. # TODO(rickyx):
  333. # This is currently mirroring dashboard/modules/job/job_head.py::list_jobs
  334. # We should eventually unify the logic.
  335. async def get_job_info(self, timeout: int = None) -> List[JobDetails]:
  336. # Cannot use @handle_grpc_network_errors because async def is not supported yet.
  337. driver_jobs, submission_job_drivers = await get_driver_jobs(
  338. self._gcs_client, timeout=timeout
  339. )
  340. submission_jobs = await self._job_client.get_all_jobs(timeout=timeout)
  341. submission_jobs = [
  342. JobDetails(
  343. **dataclasses.asdict(job),
  344. submission_id=submission_id,
  345. job_id=submission_job_drivers.get(submission_id).id
  346. if submission_id in submission_job_drivers
  347. else None,
  348. driver_info=submission_job_drivers.get(submission_id),
  349. type=JobType.SUBMISSION,
  350. )
  351. for submission_id, job in submission_jobs.items()
  352. ]
  353. return list(driver_jobs.values()) + submission_jobs
  354. @handle_grpc_network_errors
  355. async def get_object_info(
  356. self,
  357. node_manager_ip: str,
  358. node_manager_port: int,
  359. timeout: int = None,
  360. limit: int = RAY_MAX_LIMIT_FROM_DATA_SOURCE,
  361. ) -> Optional[GetObjectsInfoReply]:
  362. stub = self.get_raylet_stub(node_manager_ip, node_manager_port)
  363. reply = await stub.GetObjectsInfo(
  364. GetObjectsInfoRequest(limit=limit),
  365. timeout=timeout,
  366. )
  367. return reply
  368. async def get_runtime_envs_info(
  369. self,
  370. node_ip: str,
  371. runtime_env_agent_port: int,
  372. timeout: int = None,
  373. limit: int = RAY_MAX_LIMIT_FROM_DATA_SOURCE,
  374. ) -> Optional[GetRuntimeEnvsInfoReply]:
  375. if not node_ip or not runtime_env_agent_port:
  376. raise ValueError(
  377. f"Expected non empty node ip and runtime env agent port, got {node_ip} and {runtime_env_agent_port}."
  378. )
  379. timeout = aiohttp.ClientTimeout(total=timeout)
  380. url = f"http://{build_address(node_ip, runtime_env_agent_port)}/get_runtime_envs_info"
  381. request = GetRuntimeEnvsInfoRequest(limit=limit)
  382. data = request.SerializeToString()
  383. headers = get_auth_headers_if_auth_enabled({})
  384. async with self._client_session.post(
  385. url, data=data, timeout=timeout, headers=headers
  386. ) as resp:
  387. if resp.status >= 200 and resp.status < 300:
  388. response_data = await resp.read()
  389. reply = GetRuntimeEnvsInfoReply()
  390. reply.ParseFromString(response_data)
  391. return reply
  392. else:
  393. raise DataSourceUnavailable(
  394. "Failed to query the runtime env agent for get_runtime_envs_info. "
  395. "Either there's a network issue, or the source is down. "
  396. f"Response is {resp.status}, reason {resp.reason}"
  397. )
  398. @handle_grpc_network_errors
  399. async def list_logs(
  400. self, node_id: str, glob_filter: str, timeout: int = None
  401. ) -> ListLogsReply:
  402. stub = await self.get_log_service_stub(NodeID.from_hex(node_id))
  403. if not stub:
  404. raise ValueError(f"Agent for node id: {node_id} doesn't exist.")
  405. return await stub.ListLogs(
  406. ListLogsRequest(glob_filter=glob_filter), timeout=timeout
  407. )
  408. @handle_grpc_network_errors
  409. async def stream_log(
  410. self,
  411. node_id: str,
  412. log_file_name: str,
  413. keep_alive: bool,
  414. lines: int,
  415. interval: Optional[float],
  416. timeout: int,
  417. start_offset: Optional[int] = None,
  418. end_offset: Optional[int] = None,
  419. ) -> UnaryStreamCall:
  420. stub = await self.get_log_service_stub(NodeID.from_hex(node_id))
  421. if not stub:
  422. raise ValueError(f"Agent for node id: {node_id} doesn't exist.")
  423. stream = stub.StreamLog(
  424. StreamLogRequest(
  425. keep_alive=keep_alive,
  426. log_file_name=log_file_name,
  427. lines=lines,
  428. interval=interval,
  429. start_offset=start_offset,
  430. end_offset=end_offset,
  431. ),
  432. timeout=timeout,
  433. )
  434. metadata = await stream.initial_metadata()
  435. if metadata.get(log_consts.LOG_GRPC_ERROR) is not None:
  436. raise ValueError(metadata.get(log_consts.LOG_GRPC_ERROR))
  437. return stream