import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from itertools import islice
from typing import List, Optional

import ray.dashboard.memory_utils as memory_utils
from ray import NodeID
from ray._common.utils import get_or_create_event_loop
from ray._private.profiling import chrome_tracing_dump
from ray._private.ray_constants import env_integer
from ray.dashboard.state_api_utils import do_filter
from ray.dashboard.utils import compose_state_message
from ray.runtime_env import RuntimeEnv
from ray.util.state.common import (
    RAY_MAX_LIMIT_FROM_API_SERVER,
    ActorState,
    ActorSummaries,
    JobState,
    ListApiOptions,
    ListApiResponse,
    NodeState,
    ObjectState,
    ObjectSummaries,
    PlacementGroupState,
    RuntimeEnvState,
    StateSummary,
    SummaryApiOptions,
    SummaryApiResponse,
    TaskState,
    TaskSummaries,
    WorkerState,
    protobuf_message_to_dict,
    protobuf_to_task_state_dict,
)
from ray.util.state.state_manager import DataSourceUnavailable, StateDataSourceClient

logger = logging.getLogger(__name__)

GCS_QUERY_FAILURE_WARNING = (
    "Failed to query data from GCS. It is due to "
    "(1) GCS is unexpectedly failed. "
    "(2) GCS is overloaded. "
    "(3) There's an unexpected network issue. "
    "Please check the gcs_server.out log to find the root cause."
)

NODE_QUERY_FAILURE_WARNING = (
    "Failed to query data from {type}. "
    "Queried {total} {type} "
    "and {network_failures} {type} failed to reply. It is due to "
    "(1) {type} is unexpectedly failed. "
    "(2) {type} is overloaded. "
    "(3) There's an unexpected network issue. Please check the "
    "{log_command} to find the root cause."
)


# TODO(sang): Move the class to state/state_manager.py.
# TODO(sang): Remove *State and replaces with Pydantic or protobuf.
# (depending on API interface standardization).
class StateAPIManager:
    """Query states from the data source client and post-process the entries.

    Every ``list_*`` method follows the same pipeline: fetch the raw reply
    through the data source client, then (inside the thread pool executor, so
    the event loop is not blocked) convert entries to dicts, apply
    ``option.filters``, sort for a deterministic order, and truncate to
    ``option.limit``.
    """

    def __init__(
        self,
        state_data_source_client: StateDataSourceClient,
        thread_pool_executor: ThreadPoolExecutor,
    ):
        """Store the client used for queries and the executor used for transforms.

        Args:
            state_data_source_client: Client used to query the data sources.
            thread_pool_executor: Executor in which the per-reply ``transform``
                functions are run.
        """
        self._client = state_data_source_client
        self._thread_pool_executor = thread_pool_executor

    @property
    def data_source_client(self):
        """The underlying data source client passed at construction time."""
        return self._client

    async def list_actors(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all actor information from the cluster.

        Args:
            option: Timeout, filters, detail flag and limit for the query.

        Returns:
            A ListApiResponse whose ``result`` is a list of actor dicts
            (schema is in ActorState), sorted by ``actor_id``.

        Raises:
            DataSourceUnavailable: If the client reports that the data source
                is unavailable.
        """
        try:
            reply = await self._client.get_all_actor_info(
                timeout=option.timeout, filters=option.filters
            )
        except DataSourceUnavailable as err:
            raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) from err

        def transform(reply) -> ListApiResponse:
            result = []
            for message in reply.actor_table_data:
                # Note: this is different from actor_table_data_to_dict in
                # actor_head.py because we set preserving_proto_field_name=True
                # so fields are snake_case, while actor_table_data_to_dict in
                # actor_head.py is camelCase.
                # TODO(ryw): modify actor_table_data_to_dict to use snake_case,
                # and consolidate the code.
                data = protobuf_message_to_dict(
                    message=message,
                    fields_to_decode=[
                        "actor_id",
                        "owner_id",
                        "job_id",
                        "node_id",
                        "placement_group_id",
                    ],
                )
                result.append(data)

            # Entries received plus the ones the source reports as filtered.
            num_after_truncation = len(result) + reply.num_filtered
            result = do_filter(result, option.filters, ActorState, option.detail)
            num_filtered = len(result)

            # Sort to make the output deterministic.
            result.sort(key=lambda entry: entry["actor_id"])
            result = list(islice(result, option.limit))
            return ListApiResponse(
                result=result,
                total=reply.total,
                num_after_truncation=num_after_truncation,
                num_filtered=num_filtered,
            )

        return await get_or_create_event_loop().run_in_executor(
            self._thread_pool_executor, transform, reply
        )

    async def list_placement_groups(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all placement group information from the cluster.

        Args:
            option: Timeout, filters, detail flag and limit for the query.

        Returns:
            A ListApiResponse whose ``result`` is a list of placement group
            dicts (schema is in PlacementGroupState), sorted by
            ``placement_group_id``.

        Raises:
            DataSourceUnavailable: If the client reports that the data source
                is unavailable.
        """
        try:
            reply = await self._client.get_all_placement_group_info(
                timeout=option.timeout
            )
        except DataSourceUnavailable as err:
            raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) from err

        def transform(reply) -> ListApiResponse:
            result = []
            for message in reply.placement_group_table_data:
                data = protobuf_message_to_dict(
                    message=message,
                    fields_to_decode=[
                        "placement_group_id",
                        "creator_job_id",
                        "node_id",
                    ],
                )
                result.append(data)

            # No filters are sent with this query, so only received entries
            # are counted here.
            num_after_truncation = len(result)
            result = do_filter(
                result, option.filters, PlacementGroupState, option.detail
            )
            num_filtered = len(result)

            # Sort to make the output deterministic.
            result.sort(key=lambda entry: entry["placement_group_id"])
            return ListApiResponse(
                result=list(islice(result, option.limit)),
                total=reply.total,
                num_after_truncation=num_after_truncation,
                num_filtered=num_filtered,
            )

        return await get_or_create_event_loop().run_in_executor(
            self._thread_pool_executor, transform, reply
        )

    async def list_nodes(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all node information from the cluster.

        Args:
            option: Timeout, filters, detail flag and limit for the query.

        Returns:
            A ListApiResponse whose ``result`` is a list of node dicts
            (schema is in NodeState), sorted by ``node_id``.

        Raises:
            DataSourceUnavailable: If the client reports that the data source
                is unavailable.
        """
        try:
            reply = await self._client.get_all_node_info(
                timeout=option.timeout, filters=option.filters
            )
        except DataSourceUnavailable as err:
            raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) from err

        def transform(reply) -> ListApiResponse:
            result = []
            for message in reply.node_info_list:
                data = protobuf_message_to_dict(
                    message=message, fields_to_decode=["node_id"]
                )
                data["node_ip"] = data["node_manager_address"]
                data["start_time_ms"] = int(data["start_time_ms"])
                data["end_time_ms"] = int(data["end_time_ms"])
                # death_info may be absent; fall back to an empty dict so the
                # state message is composed from None values.
                death_info = data.get("death_info", {})
                data["state_message"] = compose_state_message(
                    death_info.get("reason", None),
                    death_info.get("reason_message", None),
                )
                result.append(data)

            # Entries received plus the ones the source reports as filtered.
            num_after_truncation = len(result) + reply.num_filtered
            result = do_filter(result, option.filters, NodeState, option.detail)
            num_filtered = len(result)

            # Sort to make the output deterministic.
            result.sort(key=lambda entry: entry["node_id"])
            result = list(islice(result, option.limit))
            return ListApiResponse(
                result=result,
                total=reply.total,
                num_after_truncation=num_after_truncation,
                num_filtered=num_filtered,
            )

        return await get_or_create_event_loop().run_in_executor(
            self._thread_pool_executor, transform, reply
        )

    async def list_workers(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all worker information from the cluster.

        Args:
            option: Timeout, filters, detail flag and limit for the query.

        Returns:
            A ListApiResponse whose ``result`` is a list of worker dicts
            (schema is in WorkerState), sorted by ``worker_id``.

        Raises:
            DataSourceUnavailable: If the client reports that the data source
                is unavailable.
        """
        try:
            reply = await self._client.get_all_worker_info(
                timeout=option.timeout,
                filters=option.filters,
            )
        except DataSourceUnavailable as err:
            raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) from err

        def transform(reply) -> ListApiResponse:
            result = []
            for message in reply.worker_table_data:
                data = protobuf_message_to_dict(
                    message=message, fields_to_decode=["worker_id", "node_id"]
                )
                # Hoist the fields nested under worker_address to the top level.
                data["worker_id"] = data["worker_address"]["worker_id"]
                data["node_id"] = data["worker_address"]["node_id"]
                data["ip"] = data["worker_address"]["ip_address"]
                data["start_time_ms"] = int(data["start_time_ms"])
                data["end_time_ms"] = int(data["end_time_ms"])
                data["worker_launch_time_ms"] = int(data["worker_launch_time_ms"])
                data["worker_launched_time_ms"] = int(data["worker_launched_time_ms"])
                result.append(data)

            # Entries received plus the ones the source reports as filtered.
            num_after_truncation = len(result) + reply.num_filtered
            result = do_filter(result, option.filters, WorkerState, option.detail)
            num_filtered = len(result)

            # Sort to make the output deterministic.
            result.sort(key=lambda entry: entry["worker_id"])
            result = list(islice(result, option.limit))
            return ListApiResponse(
                result=result,
                total=reply.total,
                num_after_truncation=num_after_truncation,
                num_filtered=num_filtered,
            )

        return await get_or_create_event_loop().run_in_executor(
            self._thread_pool_executor, transform, reply
        )

    async def list_jobs(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all job information from the cluster.

        Args:
            option: Timeout, filters, detail flag and limit for the query.

        Returns:
            A ListApiResponse whose ``result`` is a list of job dicts
            (filtered against JobState), sorted by ``job_id``.

        Raises:
            DataSourceUnavailable: If the client reports that the data source
                is unavailable.
        """
        try:
            reply = await self._client.get_job_info(timeout=option.timeout)
        except DataSourceUnavailable as err:
            raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) from err

        def transform(reply) -> ListApiResponse:
            result = [job.dict() for job in reply]
            total = len(result)
            result = do_filter(result, option.filters, JobState, option.detail)
            num_filtered = len(result)
            # job_id may be falsy (e.g. None); sort those as the empty string.
            result.sort(key=lambda entry: entry["job_id"] or "")
            result = list(islice(result, option.limit))
            return ListApiResponse(
                result=result,
                total=total,
                num_after_truncation=total,
                num_filtered=num_filtered,
            )

        return await get_or_create_event_loop().run_in_executor(
            self._thread_pool_executor, transform, reply
        )

    async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all task information from the cluster.

        Args:
            option: Timeout, filters, ``exclude_driver``, detail flag and limit
                for the query.

        Returns:
            A ListApiResponse whose ``result`` is a list of task dicts
            (schema is in TaskState), sorted by ``task_id``. If the reply
            carries a non-zero status code, an empty response with the status
            message as a warning is returned instead.

        Raises:
            DataSourceUnavailable: If the client reports that the data source
                is unavailable.
        """
        try:
            reply = await self._client.get_all_task_info(
                timeout=option.timeout,
                filters=option.filters,
                exclude_driver=option.exclude_driver,
            )
        except DataSourceUnavailable as err:
            raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) from err

        # In the error case
        if reply.status.code != 0:
            return ListApiResponse(
                result=[],
                total=0,
                num_after_truncation=0,
                num_filtered=0,
                warnings=[reply.status.message],
            )

        def transform(reply) -> ListApiResponse:
            """
            Transforms from proto to dict, applies filters, sorts, and truncates.
            This function is executed in a separate thread.
            """
            result = [
                protobuf_to_task_state_dict(message) for message in reply.events_by_task
            ]

            # Num pre-truncation is the number of tasks returned from
            # source + num filtered on source
            # NOTE(review): unlike the other list_* methods, nothing is added
            # for entries filtered on the source here -- confirm whether the
            # reply exposes such a count that should be included.
            num_after_truncation = len(result)
            num_total = len(result) + reply.num_status_task_events_dropped

            # Only certain filters are done on GCS, so here the filter function
            # is still needed to apply all the filters
            result = do_filter(result, option.filters, TaskState, option.detail)
            num_filtered = len(result)

            result.sort(key=lambda entry: entry["task_id"])
            result = list(islice(result, option.limit))

            # TODO(rickyx): we could do better with the warning logic. It's
            # messy now.
            return ListApiResponse(
                result=result,
                total=num_total,
                num_after_truncation=num_after_truncation,
                num_filtered=num_filtered,
            )

        return await get_or_create_event_loop().run_in_executor(
            self._thread_pool_executor, transform, reply
        )

    async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all object information from the cluster.

        Object information is gathered by querying every ALIVE node
        individually; nodes that fail to reply are reported through
        ``partial_failure_warning``.

        Args:
            option: Timeout, filters, detail flag and limit for the query.

        Returns:
            A ListApiResponse whose ``result`` is a list of object dicts
            (schema is in ObjectState), sorted by ``object_id``.

        Raises:
            DataSourceUnavailable: If the node list cannot be fetched, or if
                every queried node failed to reply.
            Exception: Any other exception returned by a per-node query is
                re-raised as is.
        """
        # Same error translation as the other GCS-backed queries in this class.
        try:
            all_node_info_reply = await self._client.get_all_node_info(
                timeout=option.timeout,
                limit=None,
                filters=[("state", "=", "ALIVE")],
            )
        except DataSourceUnavailable as err:
            raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) from err

        tasks = [
            self._client.get_object_info(
                node_info.node_manager_address,
                node_info.node_manager_port,
                timeout=option.timeout,
            )
            for node_info in all_node_info_reply.node_info_list
        ]

        # return_exceptions=True so a single failing node does not abort the
        # whole query; failures are classified inside transform().
        replies = await asyncio.gather(
            *tasks,
            return_exceptions=True,
        )

        def transform(replies) -> ListApiResponse:
            unresponsive_nodes = 0
            worker_stats = []
            total_objects = 0
            for reply in replies:
                if isinstance(reply, DataSourceUnavailable):
                    unresponsive_nodes += 1
                    continue
                elif isinstance(reply, Exception):
                    raise reply

                total_objects += reply.total
                for core_worker_stat in reply.core_workers_stats:
                    # NOTE: Set preserving_proto_field_name=False here because
                    # `construct_memory_table` requires a dictionary that has
                    # modified protobuf name
                    # (e.g., workerId instead of worker_id) as a key.
                    worker_stats.append(
                        protobuf_message_to_dict(
                            message=core_worker_stat,
                            fields_to_decode=["object_id"],
                            preserving_proto_field_name=False,
                        )
                    )

            partial_failure_warning = None
            if len(tasks) > 0 and unresponsive_nodes > 0:
                warning_msg = NODE_QUERY_FAILURE_WARNING.format(
                    type="raylet",
                    total=len(tasks),
                    network_failures=unresponsive_nodes,
                    log_command="raylet.out",
                )
                # Nothing at all could be fetched: surface it as an error
                # rather than as an empty, seemingly successful response.
                if unresponsive_nodes == len(tasks):
                    raise DataSourceUnavailable(warning_msg)
                partial_failure_warning = (
                    f"The returned data may contain incomplete result. {warning_msg}"
                )

            result = []
            memory_table = memory_utils.construct_memory_table(worker_stats)
            for entry in memory_table.table:
                data = entry.as_dict()
                # `construct_memory_table` returns object_ref field which is
                # indeed object_id. We do transformation here.
                # TODO(sang): Refactor `construct_memory_table`.
                data["object_id"] = data["object_ref"]
                del data["object_ref"]
                data["ip"] = data["node_ip_address"]
                del data["node_ip_address"]
                data["type"] = data["type"].upper()
                data["task_status"] = (
                    "NIL" if data["task_status"] == "-" else data["task_status"]
                )
                result.append(data)

            # Add callsite warnings if it is not configured.
            callsite_warning = []
            callsite_enabled = env_integer("RAY_record_ref_creation_sites", 0)
            if not callsite_enabled:
                callsite_warning.append(
                    "Callsite is not being recorded. "
                    "To record callsite information for each ObjectRef created, set "
                    "env variable RAY_record_ref_creation_sites=1 during `ray start` "
                    "and `ray.init`."
                )

            num_after_truncation = len(result)
            result = do_filter(result, option.filters, ObjectState, option.detail)
            num_filtered = len(result)

            # Sort to make the output deterministic.
            result.sort(key=lambda entry: entry["object_id"])
            result = list(islice(result, option.limit))
            return ListApiResponse(
                result=result,
                partial_failure_warning=partial_failure_warning,
                total=total_objects,
                num_after_truncation=num_after_truncation,
                num_filtered=num_filtered,
                warnings=callsite_warning,
            )

        return await get_or_create_event_loop().run_in_executor(
            self._thread_pool_executor, transform, replies
        )

    async def list_runtime_envs(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all runtime env information from the cluster.

        Runtime env information is gathered by querying every ALIVE node
        individually; nodes that fail to reply are reported through
        ``partial_failure_warning``.

        Args:
            option: Timeout, filters, detail flag and limit for the query.

        Returns:
            A ListApiResponse whose ``result`` is a list of runtime env dicts.
            The schema of each dict is equivalent to the `RuntimeEnvState`
            protobuf message, plus the ``node_id`` of the node it came from.
            We don't have id -> data mapping like other API because runtime
            env doesn't have unique ids.

        Raises:
            DataSourceUnavailable: If the node list cannot be fetched, or if
                every queried node failed to reply.
            Exception: Any other exception returned by a per-node query is
                re-raised as is.
        """
        # Same error translation as the other GCS-backed queries in this class.
        try:
            live_node_info_reply = await self._client.get_all_node_info(
                timeout=option.timeout,
                limit=None,
                filters=[("state", "=", "ALIVE")],
            )
        except DataSourceUnavailable as err:
            raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) from err

        node_infos = [
            node_info
            for node_info in live_node_info_reply.node_info_list
            if node_info.runtime_env_agent_port is not None
        ]

        tasks = [
            self._client.get_runtime_envs_info(
                node_info.node_manager_address,
                node_info.runtime_env_agent_port,
                timeout=option.timeout,
            )
            for node_info in node_infos
        ]

        # return_exceptions=True so a single failing node does not abort the
        # whole query; failures are classified inside transform().
        replies = await asyncio.gather(
            *tasks,
            return_exceptions=True,
        )

        def transform(replies) -> ListApiResponse:
            result = []
            unresponsive_nodes = 0
            total_runtime_envs = 0
            # replies are in the same order as node_infos, so zip pairs each
            # reply with the node it was requested from.
            for node_info, reply in zip(node_infos, replies):
                if isinstance(reply, DataSourceUnavailable):
                    unresponsive_nodes += 1
                    continue
                elif isinstance(reply, Exception):
                    raise reply

                total_runtime_envs += reply.total
                for state in reply.runtime_env_states:
                    data = protobuf_message_to_dict(message=state, fields_to_decode=[])
                    # Need to deserialize this field.
                    data["runtime_env"] = RuntimeEnv.deserialize(
                        data["runtime_env"]
                    ).to_dict()
                    data["node_id"] = NodeID(node_info.node_id).hex()
                    result.append(data)

            partial_failure_warning = None
            if len(tasks) > 0 and unresponsive_nodes > 0:
                warning_msg = NODE_QUERY_FAILURE_WARNING.format(
                    type="agent",
                    total=len(tasks),
                    network_failures=unresponsive_nodes,
                    log_command="dashboard_agent.log",
                )
                # Nothing at all could be fetched: surface it as an error
                # rather than as an empty, seemingly successful response.
                if unresponsive_nodes == len(tasks):
                    raise DataSourceUnavailable(warning_msg)
                partial_failure_warning = (
                    f"The returned data may contain incomplete result. {warning_msg}"
                )

            num_after_truncation = len(result)
            result = do_filter(result, option.filters, RuntimeEnvState, option.detail)
            num_filtered = len(result)

            # Sort to make the output deterministic.
            def sort_func(entry):
                # If creation time is not there yet (runtime env is failed
                # to be created or not created yet), they are the highest
                # priority. Otherwise, "bigger" creation time is coming first.
                creation_time_ms = entry.get("creation_time_ms")
                if creation_time_ms is None:
                    return float("inf")
                return float(creation_time_ms)

            result.sort(key=sort_func, reverse=True)
            result = list(islice(result, option.limit))
            return ListApiResponse(
                result=result,
                partial_failure_warning=partial_failure_warning,
                total=total_runtime_envs,
                num_after_truncation=num_after_truncation,
                num_filtered=num_filtered,
            )

        return await get_or_create_event_loop().run_in_executor(
            self._thread_pool_executor, transform, replies
        )

    async def summarize_tasks(self, option: SummaryApiOptions) -> SummaryApiResponse:
        """Summarize tasks, grouped by function name or by lineage.

        Args:
            option: Timeout, filters and ``summary_by`` for the summary.
                ``summary_by`` defaults to "func_name" when unset.

        Returns:
            A SummaryApiResponse whose ``result`` holds the summary under the
            "cluster" key. A warning is added when the summary accounts for
            fewer tasks than the listing returned after filtering.

        Raises:
            ValueError: If ``summary_by`` is neither "func_name" nor "lineage".
        """
        summary_by = option.summary_by or "func_name"
        if summary_by not in ["func_name", "lineage"]:
            raise ValueError('summary_by must be one of "func_name" or "lineage".')

        # For summary, try getting as many entries as possible to minimize
        # data loss.
        result = await self.list_tasks(
            option=ListApiOptions(
                timeout=option.timeout,
                limit=RAY_MAX_LIMIT_FROM_API_SERVER,
                filters=option.filters,
                detail=summary_by == "lineage",
            )
        )

        if summary_by == "func_name":
            summary_results = TaskSummaries.to_summary_by_func_name(tasks=result.result)
        else:
            # We will need the actors info for actor tasks.
            actors = await self.list_actors(
                option=ListApiOptions(
                    timeout=option.timeout,
                    limit=RAY_MAX_LIMIT_FROM_API_SERVER,
                    detail=True,
                )
            )
            summary_results = TaskSummaries.to_summary_by_lineage(
                tasks=result.result, actors=actors.result
            )
        summary = StateSummary(node_id_to_summary={"cluster": summary_results})

        warnings = result.warnings
        if (
            summary_results.total_actor_scheduled
            + summary_results.total_actor_tasks
            + summary_results.total_tasks
            < result.num_filtered
        ):
            # Build a new list instead of appending in place so the list owned
            # by the list_tasks response is not mutated.
            warnings = [
                *(warnings or []),
                "There is missing data in this aggregation. "
                "Possibly due to task data being evicted to preserve memory.",
            ]

        return SummaryApiResponse(
            total=result.total,
            result=summary,
            partial_failure_warning=result.partial_failure_warning,
            warnings=warnings,
            num_after_truncation=result.num_after_truncation,
            num_filtered=result.num_filtered,
        )

    async def summarize_actors(self, option: SummaryApiOptions) -> SummaryApiResponse:
        """Summarize actors in the cluster.

        Args:
            option: Timeout and filters for the summary.

        Returns:
            A SummaryApiResponse whose ``result`` holds the actor summary under
            the "cluster" key; counts and warnings are forwarded from the
            underlying ``list_actors`` response.
        """
        # For summary, try getting as many entries as possible to minimize
        # data loss.
        result = await self.list_actors(
            option=ListApiOptions(
                timeout=option.timeout,
                limit=RAY_MAX_LIMIT_FROM_API_SERVER,
                filters=option.filters,
            )
        )
        summary = StateSummary(
            node_id_to_summary={
                "cluster": ActorSummaries.to_summary(actors=result.result)
            }
        )
        return SummaryApiResponse(
            total=result.total,
            result=summary,
            partial_failure_warning=result.partial_failure_warning,
            warnings=result.warnings,
            num_after_truncation=result.num_after_truncation,
            num_filtered=result.num_filtered,
        )

    async def summarize_objects(self, option: SummaryApiOptions) -> SummaryApiResponse:
        """Summarize objects in the cluster.

        Args:
            option: Timeout and filters for the summary.

        Returns:
            A SummaryApiResponse whose ``result`` holds the object summary
            under the "cluster" key; counts and warnings are forwarded from
            the underlying ``list_objects`` response.
        """
        # For summary, try getting as many entries as possible to minimize
        # data loss.
        result = await self.list_objects(
            option=ListApiOptions(
                timeout=option.timeout,
                limit=RAY_MAX_LIMIT_FROM_API_SERVER,
                filters=option.filters,
            )
        )
        summary = StateSummary(
            node_id_to_summary={
                "cluster": ObjectSummaries.to_summary(objects=result.result)
            }
        )
        return SummaryApiResponse(
            total=result.total,
            result=summary,
            partial_failure_warning=result.partial_failure_warning,
            warnings=result.warnings,
            num_after_truncation=result.num_after_truncation,
            num_filtered=result.num_filtered,
        )

    async def generate_task_timeline(self, job_id: Optional[str]) -> List[dict]:
        """Produce a chrome-tracing dump of task events.

        Args:
            job_id: If given (and non-empty), only tasks of this job are
                included; otherwise no job filter is applied.

        Returns:
            The value produced by ``chrome_tracing_dump`` for the listed tasks.
        """
        filters = [("job_id", "=", job_id)] if job_id else None
        result = await self.list_tasks(
            option=ListApiOptions(detail=True, filters=filters, limit=10000)
        )
        return chrome_tracing_dump(result.result)