| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671 |
- import asyncio
- import logging
- from concurrent.futures import ThreadPoolExecutor
- from itertools import islice
- from typing import List, Optional
- import ray.dashboard.memory_utils as memory_utils
- from ray import NodeID
- from ray._common.utils import get_or_create_event_loop
- from ray._private.profiling import chrome_tracing_dump
- from ray._private.ray_constants import env_integer
- from ray.dashboard.state_api_utils import do_filter
- from ray.dashboard.utils import compose_state_message
- from ray.runtime_env import RuntimeEnv
- from ray.util.state.common import (
- RAY_MAX_LIMIT_FROM_API_SERVER,
- ActorState,
- ActorSummaries,
- JobState,
- ListApiOptions,
- ListApiResponse,
- NodeState,
- ObjectState,
- ObjectSummaries,
- PlacementGroupState,
- RuntimeEnvState,
- StateSummary,
- SummaryApiOptions,
- SummaryApiResponse,
- TaskState,
- TaskSummaries,
- WorkerState,
- protobuf_message_to_dict,
- protobuf_to_task_state_dict,
- )
- from ray.util.state.state_manager import DataSourceUnavailable, StateDataSourceClient
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Error message raised when a query against GCS fails.
GCS_QUERY_FAILURE_WARNING = " ".join(
    (
        "Failed to query data from GCS. It is due to",
        "(1) GCS is unexpectedly failed.",
        "(2) GCS is overloaded.",
        "(3) There's an unexpected network issue.",
        "Please check the gcs_server.out log to find the root cause.",
    )
)

# Template for failures of per-node queries. Format it with ``type``,
# ``total``, ``network_failures`` and ``log_command``.
NODE_QUERY_FAILURE_WARNING = " ".join(
    (
        "Failed to query data from {type}.",
        "Queried {total} {type}",
        "and {network_failures} {type} failed to reply. It is due to",
        "(1) {type} is unexpectedly failed.",
        "(2) {type} is overloaded.",
        "(3) There's an unexpected network issue. Please check the",
        "{log_command} to find the root cause.",
    )
)
- # TODO(sang): Move the class to state/state_manager.py.
- # TODO(sang): Remove *State and replace them with Pydantic or protobuf.
- # (depending on API interface standardization).
class StateAPIManager:
    """Query states from the data source and post-process the entries.

    Every ``list_*`` coroutine fetches raw replies through the
    ``StateDataSourceClient`` and then converts, filters, sorts and truncates
    them inside the thread pool executor handed to the constructor.
    """

    def __init__(
        self,
        state_data_source_client: StateDataSourceClient,
        thread_pool_executor: ThreadPoolExecutor,
    ):
        """Store the collaborators used by every query.

        Args:
            state_data_source_client: Client used to query the data sources.
            thread_pool_executor: Executor on which replies are transformed.
        """
        self._client = state_data_source_client
        self._thread_pool_executor = thread_pool_executor

    @property
    def data_source_client(self):
        """The ``StateDataSourceClient`` this manager queries."""
        return self._client

    # ------------------------------------------------------------------
    # Private helpers shared by the public API below.
    # ------------------------------------------------------------------

    async def _run_in_executor(self, transform, payload):
        """Run ``transform(payload)`` on the thread pool executor and await it."""
        return await get_or_create_event_loop().run_in_executor(
            self._thread_pool_executor, transform, payload
        )

    async def _get_alive_node_infos(self, timeout):
        """Return the node infos of all ALIVE nodes reported by GCS.

        Raises:
            DataSourceUnavailable: If the node query fails. The error carries
                ``GCS_QUERY_FAILURE_WARNING`` like every other GCS query here.
        """
        try:
            reply = await self._client.get_all_node_info(
                timeout=timeout,
                limit=None,
                filters=[("state", "=", "ALIVE")],
            )
        except DataSourceUnavailable as e:
            raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) from e
        return reply.node_info_list

    @staticmethod
    def _filter_sort_truncate(entries, option, state_cls, sort_key, *, reverse=False):
        """Filter ``entries``, sort them, and keep at most ``option.limit``.

        Args:
            entries: List of dict entries to post-process.
            option: The ``ListApiOptions`` providing filters, detail and limit.
            state_cls: State schema class handed to ``do_filter``.
            sort_key: Key function used to make the output order deterministic.
            reverse: Whether to sort in descending order.

        Returns:
            A ``(entries, num_filtered)`` tuple where ``num_filtered`` is the
            number of entries left after filtering, before the limit is applied.
        """
        filtered = do_filter(entries, option.filters, state_cls, option.detail)
        num_filtered = len(filtered)
        filtered.sort(key=sort_key, reverse=reverse)
        return list(islice(filtered, option.limit)), num_filtered

    @staticmethod
    def _partial_failure_warning(num_queried, num_unresponsive, node_type, log_command):
        """Describe a partially failed per-node fan-out query.

        Returns:
            ``None`` when nothing was queried or every node replied, otherwise
            a warning string saying the result may be incomplete.

        Raises:
            DataSourceUnavailable: If every queried node failed to reply.
        """
        if num_queried == 0 or num_unresponsive == 0:
            return None
        warning_msg = NODE_QUERY_FAILURE_WARNING.format(
            type=node_type,
            total=num_queried,
            network_failures=num_unresponsive,
            log_command=log_command,
        )
        if num_unresponsive == num_queried:
            raise DataSourceUnavailable(warning_msg)
        return f"The returned data may contain incomplete result. {warning_msg}"

    @staticmethod
    def _summary_response(listing, cluster_summary, warnings):
        """Wrap ``cluster_summary`` into a ``SummaryApiResponse``.

        The counters and the partial failure warning are copied from
        ``listing``, the ``ListApiResponse`` the summary was computed from.
        """
        return SummaryApiResponse(
            total=listing.total,
            result=StateSummary(node_id_to_summary={"cluster": cluster_summary}),
            partial_failure_warning=listing.partial_failure_warning,
            warnings=warnings,
            num_after_truncation=listing.num_after_truncation,
            num_filtered=listing.num_filtered,
        )

    # ------------------------------------------------------------------
    # List APIs.
    # ------------------------------------------------------------------

    async def list_actors(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all actor information from the cluster.

        Args:
            option: Timeout, filters, detail flag and limit for the query.

        Returns:
            A ``ListApiResponse`` whose ``result`` is a list of dicts sorted by
            ``actor_id``; entries are filtered against ``ActorState``.

        Raises:
            DataSourceUnavailable: If the GCS query fails.
        """
        try:
            reply = await self._client.get_all_actor_info(
                timeout=option.timeout, filters=option.filters
            )
        except DataSourceUnavailable as e:
            raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) from e

        def transform(reply) -> ListApiResponse:
            # Note: this is different from actor_table_data_to_dict in
            # actor_head.py because we set preserving_proto_field_name=True so
            # fields are snake_case, while actor_table_data_to_dict in
            # actor_head.py is camelCase.
            # TODO(ryw): modify actor_table_data_to_dict to use snake_case, and
            # consolidate the code.
            entries = [
                protobuf_message_to_dict(
                    message=message,
                    fields_to_decode=[
                        "actor_id",
                        "owner_id",
                        "job_id",
                        "node_id",
                        "placement_group_id",
                    ],
                )
                for message in reply.actor_table_data
            ]
            # Entries already filtered out by the source still count here.
            num_after_truncation = len(entries) + reply.num_filtered
            result, num_filtered = self._filter_sort_truncate(
                entries, option, ActorState, lambda entry: entry["actor_id"]
            )
            return ListApiResponse(
                result=result,
                total=reply.total,
                num_after_truncation=num_after_truncation,
                num_filtered=num_filtered,
            )

        return await self._run_in_executor(transform, reply)

    async def list_placement_groups(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all placement group information from the cluster.

        Args:
            option: Timeout, filters, detail flag and limit for the query.

        Returns:
            A ``ListApiResponse`` whose ``result`` is a list of dicts sorted by
            ``placement_group_id``; entries are filtered against
            ``PlacementGroupState``.

        Raises:
            DataSourceUnavailable: If the GCS query fails.
        """
        try:
            reply = await self._client.get_all_placement_group_info(
                timeout=option.timeout
            )
        except DataSourceUnavailable as e:
            raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) from e

        def transform(reply) -> ListApiResponse:
            entries = [
                protobuf_message_to_dict(
                    message=message,
                    fields_to_decode=[
                        "placement_group_id",
                        "creator_job_id",
                        "node_id",
                    ],
                )
                for message in reply.placement_group_table_data
            ]
            num_after_truncation = len(entries)
            result, num_filtered = self._filter_sort_truncate(
                entries,
                option,
                PlacementGroupState,
                lambda entry: entry["placement_group_id"],
            )
            return ListApiResponse(
                result=result,
                total=reply.total,
                num_after_truncation=num_after_truncation,
                num_filtered=num_filtered,
            )

        return await self._run_in_executor(transform, reply)

    async def list_nodes(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all node information from the cluster.

        Args:
            option: Timeout, filters, detail flag and limit for the query.

        Returns:
            A ``ListApiResponse`` whose ``result`` is a list of dicts sorted by
            ``node_id``; entries are filtered against ``NodeState``.

        Raises:
            DataSourceUnavailable: If the GCS query fails.
        """
        try:
            reply = await self._client.get_all_node_info(
                timeout=option.timeout, filters=option.filters
            )
        except DataSourceUnavailable as e:
            raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) from e

        def transform(reply) -> ListApiResponse:
            entries = []
            for message in reply.node_info_list:
                data = protobuf_message_to_dict(
                    message=message, fields_to_decode=["node_id"]
                )
                data["node_ip"] = data["node_manager_address"]
                data["start_time_ms"] = int(data["start_time_ms"])
                data["end_time_ms"] = int(data["end_time_ms"])
                death_info = data.get("death_info", {})
                data["state_message"] = compose_state_message(
                    death_info.get("reason", None),
                    death_info.get("reason_message", None),
                )
                entries.append(data)

            # Entries already filtered out by the source still count here.
            num_after_truncation = len(entries) + reply.num_filtered
            result, num_filtered = self._filter_sort_truncate(
                entries, option, NodeState, lambda entry: entry["node_id"]
            )
            return ListApiResponse(
                result=result,
                total=reply.total,
                num_after_truncation=num_after_truncation,
                num_filtered=num_filtered,
            )

        return await self._run_in_executor(transform, reply)

    async def list_workers(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all worker information from the cluster.

        Args:
            option: Timeout, filters, detail flag and limit for the query.

        Returns:
            A ``ListApiResponse`` whose ``result`` is a list of dicts sorted by
            ``worker_id``; entries are filtered against ``WorkerState``.

        Raises:
            DataSourceUnavailable: If the GCS query fails.
        """
        try:
            reply = await self._client.get_all_worker_info(
                timeout=option.timeout,
                filters=option.filters,
            )
        except DataSourceUnavailable as e:
            raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) from e

        def transform(reply) -> ListApiResponse:
            entries = []
            for message in reply.worker_table_data:
                data = protobuf_message_to_dict(
                    message=message, fields_to_decode=["worker_id", "node_id"]
                )
                # Hoist the address fields to the top level of the entry.
                address = data["worker_address"]
                data["worker_id"] = address["worker_id"]
                data["node_id"] = address["node_id"]
                data["ip"] = address["ip_address"]
                for time_field in (
                    "start_time_ms",
                    "end_time_ms",
                    "worker_launch_time_ms",
                    "worker_launched_time_ms",
                ):
                    data[time_field] = int(data[time_field])
                entries.append(data)

            # Entries already filtered out by the source still count here.
            num_after_truncation = len(entries) + reply.num_filtered
            result, num_filtered = self._filter_sort_truncate(
                entries, option, WorkerState, lambda entry: entry["worker_id"]
            )
            return ListApiResponse(
                result=result,
                total=reply.total,
                num_after_truncation=num_after_truncation,
                num_filtered=num_filtered,
            )

        return await self._run_in_executor(transform, reply)

    async def list_jobs(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all job information from the cluster.

        Args:
            option: Timeout, filters, detail flag and limit for the query.

        Returns:
            A ``ListApiResponse`` whose ``result`` is a list of dicts sorted by
            ``job_id`` (a falsy ``job_id`` sorts as the empty string); entries
            are filtered against ``JobState``.

        Raises:
            DataSourceUnavailable: If the GCS query fails.
        """
        try:
            reply = await self._client.get_job_info(timeout=option.timeout)
        except DataSourceUnavailable as e:
            raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) from e

        def transform(reply) -> ListApiResponse:
            entries = [job.dict() for job in reply]
            total = len(entries)
            result, num_filtered = self._filter_sort_truncate(
                entries, option, JobState, lambda entry: entry["job_id"] or ""
            )
            return ListApiResponse(
                result=result,
                total=total,
                num_after_truncation=total,
                num_filtered=num_filtered,
            )

        return await self._run_in_executor(transform, reply)

    async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all task information from the cluster.

        Args:
            option: Timeout, filters, detail flag, limit and ``exclude_driver``
                for the query.

        Returns:
            A ``ListApiResponse`` whose ``result`` is a list of dicts sorted by
            ``task_id``; entries are filtered against ``TaskState``. When the
            reply carries a non-zero status code, an empty response is returned
            with the status message as its only warning.

        Raises:
            DataSourceUnavailable: If the GCS query fails.
        """
        try:
            reply = await self._client.get_all_task_info(
                timeout=option.timeout,
                filters=option.filters,
                exclude_driver=option.exclude_driver,
            )
        except DataSourceUnavailable as e:
            raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) from e

        # In the error case there is nothing to transform.
        if reply.status.code != 0:
            return ListApiResponse(
                result=[],
                total=0,
                num_after_truncation=0,
                num_filtered=0,
                warnings=[reply.status.message],
            )

        def transform(reply) -> ListApiResponse:
            """
            Transforms from proto to dict, applies filters, sorts, and truncates.
            This function is executed in a separate thread.
            """
            entries = [
                protobuf_to_task_state_dict(message) for message in reply.events_by_task
            ]
            # The dropped status events only contribute to the total.
            num_after_truncation = len(entries)
            num_total = len(entries) + reply.num_status_task_events_dropped

            # Only certain filters are done on GCS, so here the filter function
            # is still needed to apply all the filters.
            result, num_filtered = self._filter_sort_truncate(
                entries, option, TaskState, lambda entry: entry["task_id"]
            )
            # TODO(rickyx): we could do better with the warning logic. It's
            # messy now.
            return ListApiResponse(
                result=result,
                total=num_total,
                num_after_truncation=num_after_truncation,
                num_filtered=num_filtered,
            )

        return await self._run_in_executor(transform, reply)

    async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all object information from the cluster.

        Object info is requested from every ALIVE node concurrently.

        Args:
            option: Timeout, filters, detail flag and limit for the query.

        Returns:
            A ``ListApiResponse`` whose ``result`` is a list of dicts sorted by
            ``object_id``; entries are filtered against ``ObjectState``. A
            partial failure warning is set when only some nodes replied.

        Raises:
            DataSourceUnavailable: If the node query fails or if every queried
                node failed to reply.
        """
        node_infos = await self._get_alive_node_infos(option.timeout)
        tasks = [
            self._client.get_object_info(
                node_info.node_manager_address,
                node_info.node_manager_port,
                timeout=option.timeout,
            )
            for node_info in node_infos
        ]
        replies = await asyncio.gather(*tasks, return_exceptions=True)

        def transform(replies) -> ListApiResponse:
            unresponsive_nodes = 0
            worker_stats = []
            total_objects = 0

            for reply in replies:
                if isinstance(reply, DataSourceUnavailable):
                    unresponsive_nodes += 1
                    continue
                if isinstance(reply, Exception):
                    raise reply

                total_objects += reply.total
                # NOTE: Set preserving_proto_field_name=False here because
                # `construct_memory_table` requires a dictionary that has
                # modified protobuf name
                # (e.g., workerId instead of worker_id) as a key.
                worker_stats.extend(
                    protobuf_message_to_dict(
                        message=core_worker_stat,
                        fields_to_decode=["object_id"],
                        preserving_proto_field_name=False,
                    )
                    for core_worker_stat in reply.core_workers_stats
                )

            partial_failure_warning = self._partial_failure_warning(
                len(tasks), unresponsive_nodes, "raylet", "raylet.out"
            )

            entries = []
            memory_table = memory_utils.construct_memory_table(worker_stats)
            for entry in memory_table.table:
                data = entry.as_dict()
                # `construct_memory_table` returns object_ref field which is
                # indeed object_id. We do transformation here.
                # TODO(sang): Refactor `construct_memory_table`.
                data["object_id"] = data.pop("object_ref")
                data["ip"] = data.pop("node_ip_address")
                data["type"] = data["type"].upper()
                if data["task_status"] == "-":
                    data["task_status"] = "NIL"
                entries.append(data)

            # Add callsite warnings if it is not configured.
            callsite_warning = []
            if not env_integer("RAY_record_ref_creation_sites", 0):
                callsite_warning.append(
                    "Callsite is not being recorded. "
                    "To record callsite information for each ObjectRef created, set "
                    "env variable RAY_record_ref_creation_sites=1 during `ray start` "
                    "and `ray.init`."
                )

            num_after_truncation = len(entries)
            result, num_filtered = self._filter_sort_truncate(
                entries, option, ObjectState, lambda entry: entry["object_id"]
            )
            return ListApiResponse(
                result=result,
                partial_failure_warning=partial_failure_warning,
                total=total_objects,
                num_after_truncation=num_after_truncation,
                num_filtered=num_filtered,
                warnings=callsite_warning,
            )

        return await self._run_in_executor(transform, replies)

    async def list_runtime_envs(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all runtime env information from the cluster.

        Runtime env info is requested from every ALIVE node concurrently.

        Args:
            option: Timeout, filters, detail flag and limit for the query.

        Returns:
            A ``ListApiResponse`` whose ``result`` is a list of dicts filtered
            against ``RuntimeEnvState``. Entries without a creation time come
            first, followed by the others from newest to oldest. There is no
            id -> data mapping like other APIs because runtime envs do not
            have unique ids.

        Raises:
            DataSourceUnavailable: If the node query fails or if every queried
                node failed to reply.
        """
        # NOTE(review): a protobuf scalar field is normally never None, so this
        # filter may be a no-op -- confirm how an unset agent port is reported.
        node_infos = [
            node_info
            for node_info in await self._get_alive_node_infos(option.timeout)
            if node_info.runtime_env_agent_port is not None
        ]
        tasks = [
            self._client.get_runtime_envs_info(
                node_info.node_manager_address,
                node_info.runtime_env_agent_port,
                timeout=option.timeout,
            )
            for node_info in node_infos
        ]
        replies = await asyncio.gather(*tasks, return_exceptions=True)

        def sort_key(entry):
            # If creation time is not there yet (runtime env failed to be
            # created or is not created yet), the entry has the highest
            # priority. Otherwise, "bigger" creation time comes first.
            creation_time_ms = entry.get("creation_time_ms")
            if creation_time_ms is None:
                return float("inf")
            return float(creation_time_ms)

        def transform(replies) -> ListApiResponse:
            entries = []
            unresponsive_nodes = 0
            total_runtime_envs = 0

            for node_info, reply in zip(node_infos, replies):
                if isinstance(reply, DataSourceUnavailable):
                    unresponsive_nodes += 1
                    continue
                if isinstance(reply, Exception):
                    raise reply

                total_runtime_envs += reply.total
                for state in reply.runtime_env_states:
                    data = protobuf_message_to_dict(message=state, fields_to_decode=[])
                    # Need to deserialize this field.
                    data["runtime_env"] = RuntimeEnv.deserialize(
                        data["runtime_env"]
                    ).to_dict()
                    data["node_id"] = NodeID(node_info.node_id).hex()
                    entries.append(data)

            partial_failure_warning = self._partial_failure_warning(
                len(tasks), unresponsive_nodes, "agent", "dashboard_agent.log"
            )

            num_after_truncation = len(entries)
            result, num_filtered = self._filter_sort_truncate(
                entries, option, RuntimeEnvState, sort_key, reverse=True
            )
            return ListApiResponse(
                result=result,
                partial_failure_warning=partial_failure_warning,
                total=total_runtime_envs,
                num_after_truncation=num_after_truncation,
                num_filtered=num_filtered,
            )

        return await self._run_in_executor(transform, replies)

    # ------------------------------------------------------------------
    # Summary APIs.
    # ------------------------------------------------------------------

    async def summarize_tasks(self, option: SummaryApiOptions) -> SummaryApiResponse:
        """Summarize the tasks of the cluster by function name or lineage.

        Args:
            option: Timeout, filters and ``summary_by`` for the summary.
                ``summary_by`` defaults to ``"func_name"`` when falsy.

        Returns:
            A ``SummaryApiResponse`` with a single ``"cluster"`` summary. A
            warning is added when the summary counts fewer entries than the
            filtered task listing contains.

        Raises:
            ValueError: If ``summary_by`` is neither ``"func_name"`` nor
                ``"lineage"``.
        """
        summary_by = option.summary_by or "func_name"
        if summary_by not in ("func_name", "lineage"):
            raise ValueError('summary_by must be one of "func_name" or "lineage".')

        # For summary, try getting as many entries as possible to minimize
        # data loss.
        result = await self.list_tasks(
            option=ListApiOptions(
                timeout=option.timeout,
                limit=RAY_MAX_LIMIT_FROM_API_SERVER,
                filters=option.filters,
                detail=summary_by == "lineage",
            )
        )

        if summary_by == "func_name":
            summary_results = TaskSummaries.to_summary_by_func_name(tasks=result.result)
        else:
            # We will need the actors info for actor tasks.
            actors = await self.list_actors(
                option=ListApiOptions(
                    timeout=option.timeout,
                    limit=RAY_MAX_LIMIT_FROM_API_SERVER,
                    detail=True,
                )
            )
            summary_results = TaskSummaries.to_summary_by_lineage(
                tasks=result.result, actors=actors.result
            )

        warnings = result.warnings
        num_summarized = (
            summary_results.total_actor_scheduled
            + summary_results.total_actor_tasks
            + summary_results.total_tasks
        )
        if num_summarized < result.num_filtered:
            # Build a new list so the warnings owned by `result` are not
            # mutated in place.
            warnings = [
                *(warnings or []),
                "There is missing data in this aggregation. "
                "Possibly due to task data being evicted to preserve memory.",
            ]

        return self._summary_response(result, summary_results, warnings)

    async def summarize_actors(self, option: SummaryApiOptions) -> SummaryApiResponse:
        """Summarize the actors of the cluster.

        Args:
            option: Timeout and filters for the summary.

        Returns:
            A ``SummaryApiResponse`` with a single ``"cluster"`` summary.
        """
        # For summary, try getting as many entries as possible to minimize
        # data loss.
        result = await self.list_actors(
            option=ListApiOptions(
                timeout=option.timeout,
                limit=RAY_MAX_LIMIT_FROM_API_SERVER,
                filters=option.filters,
            )
        )
        return self._summary_response(
            result, ActorSummaries.to_summary(actors=result.result), result.warnings
        )

    async def summarize_objects(self, option: SummaryApiOptions) -> SummaryApiResponse:
        """Summarize the objects of the cluster.

        Args:
            option: Timeout and filters for the summary.

        Returns:
            A ``SummaryApiResponse`` with a single ``"cluster"`` summary.
        """
        # For summary, try getting as many entries as possible to minimize
        # data loss.
        result = await self.list_objects(
            option=ListApiOptions(
                timeout=option.timeout,
                limit=RAY_MAX_LIMIT_FROM_API_SERVER,
                filters=option.filters,
            )
        )
        return self._summary_response(
            result, ObjectSummaries.to_summary(objects=result.result), result.warnings
        )

    async def generate_task_timeline(self, job_id: Optional[str]) -> List[dict]:
        """Dump the tasks of a job (or of all jobs) as chrome tracing events.

        Args:
            job_id: Job to restrict the timeline to. A falsy value disables
                the job filter.

        Returns:
            The output of ``chrome_tracing_dump`` for up to 10000 detailed
            task entries.
        """
        filters = [("job_id", "=", job_id)] if job_id else None
        result = await self.list_tasks(
            option=ListApiOptions(detail=True, filters=filters, limit=10000)
        )
        return chrome_tracing_dump(result.result)
|