import asyncio import logging import os import pickle import time from collections import defaultdict from typing import ( Any, Dict, Iterable, List, Optional, Set, Tuple, Union, ) import ray from ray._common.network_utils import build_address from ray._common.utils import run_background_task from ray._raylet import GcsClient from ray.actor import ActorHandle from ray.serve._private.application_state import ApplicationStateManager, StatusOverview from ray.serve._private.autoscaling_state import AutoscalingStateManager from ray.serve._private.common import ( DeploymentID, DeploymentSnapshot, HandleMetricReport, NodeId, ReplicaMetricReport, RequestProtocol, RequestRoutingInfo, RunningReplicaInfo, TargetCapacityDirection, ) from ray.serve._private.config import DeploymentConfig from ray.serve._private.constants import ( CONTROL_LOOP_INTERVAL_S, RAY_SERVE_CONTROLLER_CALLBACK_IMPORT_PATH, RAY_SERVE_ENABLE_DIRECT_INGRESS, RAY_SERVE_RPC_LATENCY_WARNING_THRESHOLD_MS, RECOVERING_LONG_POLL_BROADCAST_TIMEOUT_S, SERVE_CONTROLLER_NAME, SERVE_DEFAULT_APP_NAME, SERVE_LOGGER_NAME, SERVE_NAMESPACE, ) from ray.serve._private.controller_health_metrics_tracker import ( ControllerHealthMetricsTracker, ) from ray.serve._private.default_impl import create_cluster_node_info_cache from ray.serve._private.deployment_info import DeploymentInfo from ray.serve._private.deployment_state import ( DeploymentReplica, DeploymentStateManager, ) from ray.serve._private.endpoint_state import EndpointState from ray.serve._private.exceptions import ExternalScalerDisabledError from ray.serve._private.grpc_util import set_proxy_default_grpc_options from ray.serve._private.http_util import ( configure_http_options_with_defaults, ) from ray.serve._private.logging_utils import ( configure_autoscaling_snapshot_logger, configure_component_logger, configure_component_memory_profiler, get_component_logger_file_path, ) from ray.serve._private.long_poll import LongPollHost, LongPollNamespace from ray.serve._private.node_port_manager import NodePortManager from ray.serve._private.proxy_state import ProxyStateManager from ray.serve._private.storage.kv_store import RayInternalKVStore from ray.serve._private.usage import ServeUsageTag from ray.serve._private.utils import ( call_function_from_import_path, get_all_live_placement_group_names, get_head_node_id, is_grpc_enabled, ) from ray.serve.config import DeploymentMode, HTTPOptions, ProxyLocation, gRPCOptions from ray.serve.generated.serve_pb2 import ( ActorNameList, ApplicationArgs, DeploymentArgs, DeploymentRoute, EndpointInfo as EndpointInfoProto, EndpointSet, ) from ray.serve.schema import ( APIType, ApplicationDetails, DeploymentDetails, HTTPOptionsSchema, LoggingConfig, ProxyDetails, ReplicaDetails, ReplicaRank, ServeActorDetails, ServeApplicationSchema, ServeDeploySchema, ServeInstanceDetails, Target, TargetGroup, gRPCOptionsSchema, ) from ray.util import metrics logger = logging.getLogger(SERVE_LOGGER_NAME) # Used for testing purposes only. If this is set, the controller will crash # after writing each checkpoint with the specified probability. _CRASH_AFTER_CHECKPOINT_PROBABILITY = 0 CONFIG_CHECKPOINT_KEY = "serve-app-config-checkpoint" LOGGING_CONFIG_CHECKPOINT_KEY = "serve-logging-config-checkpoint" class ServeController: """Responsible for managing the state of the serving system. The controller implements fault tolerance by persisting its state in a new checkpoint each time a state change is made. If the actor crashes, the latest checkpoint is loaded and the state is recovered. Checkpoints are written/read using a provided KV-store interface. All hard state in the system is maintained by this actor and persisted via these checkpoints. Soft state required by other components is fetched by those actors from this actor on startup and updates are pushed out from this actor. All other actors started by the controller are named, detached actors so they will not fate share with the controller if it crashes. The following guarantees are provided for state-changing calls to the controller: - If the call succeeds, the change was made and will be reflected in the system even if the controller or other actors die unexpectedly. - If the call fails, the change may have been made but isn't guaranteed to have been. The client should retry in this case. Note that this requires all implementations here to be idempotent. """ async def __init__( self, *, http_options: HTTPOptions, global_logging_config: LoggingConfig, grpc_options: Optional[gRPCOptions] = None, ): self._controller_node_id = ray.get_runtime_context().get_node_id() assert ( self._controller_node_id == get_head_node_id() ), "Controller must be on the head node." self.ray_worker_namespace = ray.get_runtime_context().namespace self.gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address) kv_store_namespace = f"ray-serve-{self.ray_worker_namespace}" self.kv_store = RayInternalKVStore(kv_store_namespace, self.gcs_client) self.long_poll_host = LongPollHost() self.done_recovering_event = asyncio.Event() # Autoscaling snapshot logger self._autoscaling_logger: Optional[logging.Logger] = None # Try to read config from checkpoint # logging config from checkpoint take precedence over the one passed in # the constructor. self.global_logging_config = None log_config_checkpoint = self.kv_store.get(LOGGING_CONFIG_CHECKPOINT_KEY) if log_config_checkpoint is not None: global_logging_config = pickle.loads(log_config_checkpoint) self.reconfigure_global_logging_config(global_logging_config) configure_component_memory_profiler( component_name="controller", component_id=str(os.getpid()) ) if RAY_SERVE_CONTROLLER_CALLBACK_IMPORT_PATH: logger.info( "Calling user-provided callback from import path " f"{RAY_SERVE_CONTROLLER_CALLBACK_IMPORT_PATH}." ) call_function_from_import_path(RAY_SERVE_CONTROLLER_CALLBACK_IMPORT_PATH) # Used to read/write checkpoints. self.cluster_node_info_cache = create_cluster_node_info_cache(self.gcs_client) self.cluster_node_info_cache.update() self._direct_ingress_enabled = RAY_SERVE_ENABLE_DIRECT_INGRESS if self._direct_ingress_enabled: logger.info( "Direct ingress is enabled in ServeController, enabling proxy " "on head node only." ) http_options.location = DeploymentMode.HeadOnly # Configure proxy default HTTP and gRPC options. self.proxy_state_manager = ProxyStateManager( http_options=configure_http_options_with_defaults(http_options), head_node_id=self._controller_node_id, cluster_node_info_cache=self.cluster_node_info_cache, logging_config=self.global_logging_config, grpc_options=set_proxy_default_grpc_options(grpc_options), ) # We modify the HTTP and gRPC options above, so delete them to avoid del http_options, grpc_options self.endpoint_state = EndpointState(self.kv_store, self.long_poll_host) # Fetch all running actors in current cluster as source of current # replica state for controller failure recovery all_current_actors = ray.util.list_named_actors(all_namespaces=True) all_serve_actor_names = [ actor["name"] for actor in all_current_actors if actor["namespace"] == SERVE_NAMESPACE ] self.autoscaling_state_manager = AutoscalingStateManager() self.deployment_state_manager = DeploymentStateManager( self.kv_store, self.long_poll_host, all_serve_actor_names, get_all_live_placement_group_names(), self.cluster_node_info_cache, self.autoscaling_state_manager, ) # Manage all applications' state self.application_state_manager = ApplicationStateManager( self.deployment_state_manager, self.autoscaling_state_manager, self.endpoint_state, self.kv_store, self.global_logging_config, ) # Controller actor details self._actor_details = ServeActorDetails( node_id=ray.get_runtime_context().get_node_id(), node_ip=ray.util.get_node_ip_address(), node_instance_id=ray.util.get_node_instance_id(), actor_id=ray.get_runtime_context().get_actor_id(), actor_name=SERVE_CONTROLLER_NAME, worker_id=ray.get_runtime_context().get_worker_id(), log_file_path=get_component_logger_file_path(), ) self._shutting_down = False self._shutdown_event = asyncio.Event() self._shutdown_start_time = None # Actors registered for cleanup on serve.shutdown(), keyed by actor ID self._registered_cleanup_actors: Dict[str, ActorHandle] = {} # Initialize health metrics tracker self._health_metrics_tracker = ControllerHealthMetricsTracker( controller_start_time=time.time() ) self._create_control_loop_metrics() run_background_task(self.run_control_loop()) # The target capacity percentage for all deployments across the cluster. self._target_capacity: Optional[float] = None self._target_capacity_direction: Optional[TargetCapacityDirection] = None self._recover_state_from_checkpoint() # Nodes where proxy actors should run. self._proxy_nodes = set() self._update_proxy_nodes() # Caches for autoscaling observability self._last_autoscaling_snapshots: Dict[DeploymentID, DeploymentSnapshot] = {} self._autoscaling_enabled_deployments_cache: List[ Tuple[str, str, DeploymentDetails, Any] ] = [] self._refresh_autoscaling_deployments_cache() self._last_broadcasted_target_groups: List[TargetGroup] = [] def reconfigure_global_logging_config(self, global_logging_config: LoggingConfig): if ( self.global_logging_config and self.global_logging_config == global_logging_config ): return self.kv_store.put( LOGGING_CONFIG_CHECKPOINT_KEY, pickle.dumps(global_logging_config) ) self.global_logging_config = global_logging_config self.long_poll_host.notify_changed( {LongPollNamespace.GLOBAL_LOGGING_CONFIG: global_logging_config} ) configure_component_logger( component_name="controller", component_id=str(os.getpid()), logging_config=global_logging_config, ) self._autoscaling_logger = configure_autoscaling_snapshot_logger( component_id=str(os.getpid()), logging_config=global_logging_config, ) logger.info( f"Controller starting (version='{ray.__version__}').", extra={"log_to_stderr": False}, ) logger.debug( "Configure the serve controller logger " f"with logging config: {self.global_logging_config}" ) def check_alive(self) -> None: """No-op to check if this controller is alive.""" return def get_pid(self) -> int: return os.getpid() def record_autoscaling_metrics_from_replica( self, replica_metric_report: ReplicaMetricReport ): latency = time.time() - replica_metric_report.timestamp latency_ms = latency * 1000 # Record the metrics delay for observability self.replica_metrics_delay_gauge.set( latency_ms, tags={ "deployment": replica_metric_report.replica_id.deployment_id.name, "application": replica_metric_report.replica_id.deployment_id.app_name, "replica": replica_metric_report.replica_id.unique_id, }, ) # Track in health metrics self._health_metrics_tracker.record_replica_metrics_delay(latency_ms) if latency_ms > RAY_SERVE_RPC_LATENCY_WARNING_THRESHOLD_MS: logger.warning( f"Received autoscaling metrics from replica {replica_metric_report.replica_id} with timestamp {replica_metric_report.timestamp} " f"which is {latency_ms}ms ago. " f"This is greater than the warning threshold RPC latency of {RAY_SERVE_RPC_LATENCY_WARNING_THRESHOLD_MS}ms. " "This may indicate a performance issue with the controller try increasing the RAY_SERVE_RPC_LATENCY_WARNING_THRESHOLD_MS environment variable." ) self.autoscaling_state_manager.record_request_metrics_for_replica( replica_metric_report ) def record_autoscaling_metrics_from_handle( self, handle_metric_report: HandleMetricReport ): latency = time.time() - handle_metric_report.timestamp latency_ms = latency * 1000 # Record the metrics delay for observability self.handle_metrics_delay_gauge.set( latency_ms, tags={ "deployment": handle_metric_report.deployment_id.name, "application": handle_metric_report.deployment_id.app_name, "handle": handle_metric_report.handle_id, }, ) # Track in health metrics self._health_metrics_tracker.record_handle_metrics_delay(latency_ms) if latency_ms > RAY_SERVE_RPC_LATENCY_WARNING_THRESHOLD_MS: logger.warning( f"Received autoscaling metrics from handle {handle_metric_report.handle_id} for deployment {handle_metric_report.deployment_id} with timestamp {handle_metric_report.timestamp} " f"which is {latency_ms}ms ago. " f"This is greater than the warning threshold RPC latency of {RAY_SERVE_RPC_LATENCY_WARNING_THRESHOLD_MS}ms. " "This may indicate a performance issue with the controller try increasing the RAY_SERVE_RPC_LATENCY_WARNING_THRESHOLD_MS environment variable." ) self.autoscaling_state_manager.record_request_metrics_for_handle( handle_metric_report ) def _get_total_num_requests_for_deployment_for_testing( self, deployment_id: DeploymentID ): return self.autoscaling_state_manager.get_total_num_requests_for_deployment( deployment_id ) def _get_metrics_for_deployment_for_testing(self, deployment_id: DeploymentID): return self.autoscaling_state_manager.get_metrics_for_deployment(deployment_id) def _dump_replica_states_for_testing(self, deployment_id: DeploymentID): return self.deployment_state_manager._deployment_states[deployment_id]._replicas def _stop_one_running_replica_for_testing(self, deployment_id): self.deployment_state_manager._deployment_states[ deployment_id ]._stop_one_running_replica_for_testing() async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]): """Proxy long pull client's listen request. Args: keys_to_snapshot_ids (Dict[str, int]): Snapshot IDs are used to determine whether or not the host should immediately return the data or wait for the value to be changed. """ if not self.done_recovering_event.is_set(): await self.done_recovering_event.wait() return await self.long_poll_host.listen_for_change(keys_to_snapshot_ids) async def listen_for_change_java(self, keys_to_snapshot_ids_bytes: bytes): """Proxy long pull client's listen request. Args: keys_to_snapshot_ids_bytes (Dict[str, int]): the protobuf bytes of keys_to_snapshot_ids (Dict[str, int]). """ if not self.done_recovering_event.is_set(): await self.done_recovering_event.wait() return await self.long_poll_host.listen_for_change_java( keys_to_snapshot_ids_bytes ) def get_all_endpoints(self) -> Dict[DeploymentID, Dict[str, Any]]: """Returns a dictionary of deployment name to config.""" return self.endpoint_state.get_endpoints() def get_all_endpoints_java(self) -> bytes: """Returns a dictionary of deployment name to config.""" endpoints = self.get_all_endpoints() # NOTE(zcin): Java only supports 1.x deployments, so only return # a dictionary of deployment name -> endpoint info data = { endpoint_tag.name: EndpointInfoProto(route=endpoint_dict["route"]) for endpoint_tag, endpoint_dict in endpoints.items() } return EndpointSet(endpoints=data).SerializeToString() def get_proxies(self) -> Dict[NodeId, ActorHandle]: """Returns a dictionary of node ID to proxy actor handles.""" if self.proxy_state_manager is None: return {} return self.proxy_state_manager.get_proxy_handles() def get_proxy_names(self) -> bytes: """Returns the proxy actor name list serialized by protobuf.""" if self.proxy_state_manager is None: return None actor_name_list = ActorNameList( names=self.proxy_state_manager.get_proxy_names().values() ) return actor_name_list.SerializeToString() def _update_proxy_nodes(self): """Update the nodes set where proxy actors should run. Controller decides where proxy actors should run (head node and nodes with deployment replicas). """ new_proxy_nodes = self.deployment_state_manager.get_active_node_ids() new_proxy_nodes = new_proxy_nodes - set( self.cluster_node_info_cache.get_draining_nodes() ) new_proxy_nodes.add(self._controller_node_id) self._proxy_nodes = new_proxy_nodes def _refresh_autoscaling_deployments_cache(self) -> None: result = [] active_dep_ids = set() for app_name in self.application_state_manager.list_app_names(): deployment_details = self.application_state_manager.list_deployment_details( app_name ) for dep_name, details in deployment_details.items(): active_dep_ids.add(DeploymentID(name=dep_name, app_name=app_name)) autoscaling_config = details.deployment_config.autoscaling_config if autoscaling_config: result.append((app_name, dep_name, details, autoscaling_config)) self._autoscaling_enabled_deployments_cache = result self._last_autoscaling_snapshots = { k: v for k, v in self._last_autoscaling_snapshots.items() if k in active_dep_ids } def _emit_deployment_autoscaling_snapshots(self) -> None: """Emit structured autoscaling snapshot logs in a single batch per loop.""" if self._autoscaling_logger is None: return snapshots_to_log: List[Dict[str, Any]] = [] for ( app_name, dep_name, details, autoscaling_config, ) in self._autoscaling_enabled_deployments_cache: dep_id = DeploymentID(name=dep_name, app_name=app_name) deployment_snapshot = ( self.autoscaling_state_manager.get_deployment_snapshot(dep_id) ) if deployment_snapshot is None: continue last = self._last_autoscaling_snapshots.get(dep_id) if last is not None and last.is_scaling_equivalent(deployment_snapshot): continue snapshots_to_log.append(deployment_snapshot.dict(exclude_none=True)) self._last_autoscaling_snapshots[dep_id] = deployment_snapshot if snapshots_to_log: # Single write per control-loop iteration self._autoscaling_logger.info({"snapshots": snapshots_to_log}) async def run_control_loop(self) -> None: # NOTE(edoakes): we catch all exceptions here and simply log them, # because an unhandled exception would cause the main control loop to # halt, which should *never* happen. recovering_timeout = RECOVERING_LONG_POLL_BROADCAST_TIMEOUT_S num_loops = 0 start_time = time.time() while True: loop_start_time = time.time() try: await self.run_control_loop_step( start_time, recovering_timeout, num_loops ) except Exception as e: # we never expect this to happen, but adding this to be safe logger.exception(f"There was an exception in the control loop: {e}") await asyncio.sleep(1) loop_duration = time.time() - loop_start_time if loop_duration > 10: logger.warning( f"The last control loop was slow (took {loop_duration}s). " "This is likely caused by running a large number of " "replicas in a single Ray cluster. Consider using " "multiple Ray clusters.", extra={"log_to_stderr": False}, ) self.control_loop_duration_gauge_s.set(loop_duration) # Track in health metrics self._health_metrics_tracker.record_loop_duration(loop_duration) num_loops += 1 self.num_control_loops_gauge.set(num_loops) self._health_metrics_tracker.num_control_loops = num_loops sleep_start_time = time.time() await asyncio.sleep(CONTROL_LOOP_INTERVAL_S) sleep_duration = time.time() - sleep_start_time self.sleep_duration_gauge_s.set(sleep_duration) self._health_metrics_tracker.last_sleep_duration_s = sleep_duration async def run_control_loop_step( self, start_time: float, recovering_timeout: float, num_loops: int ): try: self.cluster_node_info_cache.update() except Exception: logger.exception("Exception updating cluster node info cache.") if self._shutting_down: try: self.shutdown() except Exception: logger.exception("Exception during shutdown.") if ( not self.done_recovering_event.is_set() and time.time() - start_time > recovering_timeout ): logger.warning( f"Replicas still recovering after {recovering_timeout}s, " "setting done recovering event to broadcast long poll updates." ) self.done_recovering_event.set() # initialize any_recovering to None to indicate that we don't know if # we've recovered anything yet any_recovering: Optional[bool] = None try: dsm_update_start_time = time.time() any_recovering = self.deployment_state_manager.update() dsm_duration = time.time() - dsm_update_start_time self.dsm_update_duration_gauge_s.set(dsm_duration) self._health_metrics_tracker.record_dsm_update_duration(dsm_duration) if not self.done_recovering_event.is_set() and not any_recovering: self.done_recovering_event.set() if num_loops > 0: # Only log if we actually needed to recover anything. logger.info( "Finished recovering deployments after " f"{(time.time() - start_time):.2f}s.", extra={"log_to_stderr": False}, ) except Exception: logger.exception("Exception updating deployment state.") try: asm_update_start_time = time.time() any_target_state_changed = self.application_state_manager.update() if any_recovering or any_target_state_changed: self._refresh_autoscaling_deployments_cache() asm_duration = time.time() - asm_update_start_time self.asm_update_duration_gauge_s.set(asm_duration) self._health_metrics_tracker.record_asm_update_duration(asm_duration) except Exception: logger.exception("Exception updating application state.") try: # Emit one autoscaling snapshot per deployment per loop using existing state. self._emit_deployment_autoscaling_snapshots() except Exception: logger.exception("Exception emitting deployment autoscaling snapshots.") # Update the proxy nodes set before updating the proxy states, # so they are more consistent. node_update_start_time = time.time() self._update_proxy_nodes() node_update_duration = time.time() - node_update_start_time self.node_update_duration_gauge_s.set(node_update_duration) self._health_metrics_tracker.record_node_update_duration(node_update_duration) # Don't update proxy_state until after the done recovering event is set, # otherwise we may start a new proxy but not broadcast it any # info about available deployments & their replicas. if self.proxy_state_manager and self.done_recovering_event.is_set(): try: proxy_update_start_time = time.time() self.proxy_state_manager.update(proxy_nodes=self._proxy_nodes) proxy_update_duration = time.time() - proxy_update_start_time self.proxy_update_duration_gauge_s.set(proxy_update_duration) self._health_metrics_tracker.record_proxy_update_duration( proxy_update_duration ) except Exception: logger.exception("Exception updating proxy state.") # When the controller is done recovering, drop invalid handle metrics # that may be stale for autoscaling if any_recovering is False: self.autoscaling_state_manager.drop_stale_handle_metrics( self.deployment_state_manager.get_alive_replica_actor_ids() | self.proxy_state_manager.get_alive_proxy_actor_ids() ) # Direct ingress port management if self._direct_ingress_enabled: # Update port values for ingress replicas. # Non-ingress replicas are not expected to have ports allocated. ingress_replicas_info_list: List[ Tuple[str, str, int, int] ] = self.deployment_state_manager.get_ingress_replicas_info() NodePortManager.update_ports(ingress_replicas_info_list) # Clean up stale ports # get all alive replica ids and their node ids. NodePortManager.prune(self._get_node_id_to_alive_replica_ids()) def _create_control_loop_metrics(self): self.node_update_duration_gauge_s = metrics.Gauge( "serve_controller_node_update_duration_s", description="The control loop time spent on collecting proxy node info.", ) self.proxy_update_duration_gauge_s = metrics.Gauge( "serve_controller_proxy_state_update_duration_s", description="The control loop time spent on updating proxy state.", ) self.dsm_update_duration_gauge_s = metrics.Gauge( "serve_controller_deployment_state_update_duration_s", description="The control loop time spent on updating deployment state.", ) self.asm_update_duration_gauge_s = metrics.Gauge( "serve_controller_application_state_update_duration_s", description="The control loop time spent on updating application state.", ) self.sleep_duration_gauge_s = metrics.Gauge( "serve_controller_sleep_duration_s", description="The duration of the last control loop's sleep.", ) self.control_loop_duration_gauge_s = metrics.Gauge( "serve_controller_control_loop_duration_s", description="The duration of the last control loop.", ) self.num_control_loops_gauge = metrics.Gauge( "serve_controller_num_control_loops", description=( "The number of control loops performed by the controller. " "Increases monotonically over the controller's lifetime." ), tag_keys=("actor_id",), ) self.num_control_loops_gauge.set_default_tags( {"actor_id": ray.get_runtime_context().get_actor_id()} ) # Autoscaling metrics delay gauges self.replica_metrics_delay_gauge = metrics.Gauge( "serve_autoscaling_replica_metrics_delay_ms", description=( "Time taken for the replica metrics to be reported to the controller. " "High values may indicate a busy controller." ), tag_keys=("deployment", "application", "replica"), ) self.handle_metrics_delay_gauge = metrics.Gauge( "serve_autoscaling_handle_metrics_delay_ms", description=( "Time taken for the handle metrics to be reported to the controller. " "High values may indicate a busy controller." ), tag_keys=("deployment", "application", "handle"), ) def _recover_state_from_checkpoint(self): ( deployment_time, serve_config, target_capacity_direction, ) = self._read_config_checkpoint() self._target_capacity_direction = target_capacity_direction if serve_config is not None: logger.info( "Recovered config from checkpoint.", extra={"log_to_stderr": False} ) self.apply_config(serve_config, deployment_time=deployment_time) def _read_config_checkpoint( self, ) -> Tuple[float, Optional[ServeDeploySchema], Optional[TargetCapacityDirection]]: """Reads the current Serve config checkpoint. The Serve config checkpoint stores active application configs and other metadata. Returns: If the GCS contains a checkpoint, tuple of: 1. A deployment timestamp. 2. A Serve config. This Serve config is reconstructed from the active application states. It may not exactly match the submitted config (e.g. the top-level http options may be different). 3. The target_capacity direction calculated after the Serve was submitted. If the GCS doesn't contain a checkpoint, returns (0, None, None). """ checkpoint = self.kv_store.get(CONFIG_CHECKPOINT_KEY) if checkpoint is not None: ( deployment_time, target_capacity, target_capacity_direction, config_checkpoints_dict, ) = pickle.loads(checkpoint) return ( deployment_time, ServeDeploySchema( applications=list(config_checkpoints_dict.values()), target_capacity=target_capacity, ), target_capacity_direction, ) else: return (0.0, None, None) def _all_running_replicas(self) -> Dict[DeploymentID, List[RunningReplicaInfo]]: """Used for testing. Returned dictionary maps deployment names to replica infos. """ return self.deployment_state_manager.get_running_replica_infos() def get_actor_details(self) -> ServeActorDetails: """Returns the actor details for this controller. Currently used for test only. """ return self._actor_details def get_health_metrics(self) -> Dict[str, Any]: """Returns comprehensive health metrics for the controller. This method provides detailed performance metrics to help diagnose controller health issues, especially as cluster size increases. Returns: Dictionary containing health metrics including: - Control loop performance (iteration speed, durations) - Event loop health (task count, scheduling delay) - Component update latencies - Autoscaling metrics latency (handle/replica) - Memory usage """ try: return self._health_metrics_tracker.collect_metrics().dict() except Exception: logger.exception("Exception collecting controller health metrics.") raise def get_proxy_details(self, node_id: str) -> Optional[ProxyDetails]: """Returns the proxy details for the proxy on the given node. Currently used for test only. Will return None if the proxy doesn't exist on the given node. """ if self.proxy_state_manager is None: return None return self.proxy_state_manager.get_proxy_details().get(node_id) def get_deployment_timestamps(self, app_name: str) -> float: """Returns the deployment timestamp for the given app. Currently used for test only. """ for ( _app_name, app_status_info, ) in self.application_state_manager.list_app_statuses().items(): if app_name == _app_name: return app_status_info.deployment_timestamp def get_deployment_details( self, app_name: str, deployment_name: str ) -> DeploymentDetails: """Returns the deployment details for the app and deployment. Currently used for test only. """ return self.application_state_manager.list_deployment_details(app_name)[ deployment_name ] def get_http_config(self) -> HTTPOptions: """Return the HTTP proxy configuration.""" if self.proxy_state_manager is None: return HTTPOptions() return self.proxy_state_manager.get_config() def get_grpc_config(self) -> gRPCOptions: """Return the gRPC proxy configuration.""" if self.proxy_state_manager is None: return gRPCOptions() return self.proxy_state_manager.get_grpc_config() def get_root_url(self): """Return the root url for the serve instance.""" if self.proxy_state_manager is None: return None http_config = self.get_http_config() if http_config.root_url == "": # HTTP is disabled if http_config.host is None: return "" return ( f"http://{build_address(http_config.host, http_config.port)}" f"{http_config.root_path}" ) return http_config.root_url def config_checkpoint_deleted(self) -> bool: """Returns whether the config checkpoint has been deleted. Get the config checkpoint from the kv store. If it is None, then it has been deleted. """ return self.kv_store.get(CONFIG_CHECKPOINT_KEY) is None def _register_shutdown_cleanup_actor(self, actor_handle: ActorHandle) -> None: """Register an actor to be killed on serve.shutdown(). This allows deployments to register auxiliary actors (like caches, coordinators, etc.) that should be cleaned up when Serve shuts down. The actors must use lifetime="detached" to survive replica restarts, but will be explicitly killed during serve.shutdown(). Note: Registered actors are NOT persisted across controller restarts. For full persistence, use controller-managed deployment-scoped actors (see https://github.com/ray-project/ray/issues/60359). If the same actor is registered multiple times (e.g., from multiple router instances sharing a tree actor via get_if_exists=True), it will only be stored once. Args: actor_handle: The actor handle to register for cleanup. """ actor_id = actor_handle._actor_id.hex() self._registered_cleanup_actors[actor_id] = actor_handle def _kill_registered_cleanup_actors(self) -> None: """Kill all actors registered for shutdown cleanup.""" for actor in self._registered_cleanup_actors.values(): try: ray.kill(actor, no_restart=True) except Exception: pass # Actor may already be dead def shutdown(self): """Shuts down the serve instance completely. This method will only be triggered when `self._shutting_down` is true. It deletes the kv store for config checkpoints, sets application state to deleting, delete all deployments, and shuts down all proxies. Once all these resources are released, it then kills the controller actor. """ if not self._shutting_down: return if self._shutdown_start_time is None: self._shutdown_start_time = time.time() logger.info("Controller shutdown started.", extra={"log_to_stderr": False}) self.kv_store.delete(CONFIG_CHECKPOINT_KEY) self.kv_store.delete(LOGGING_CONFIG_CHECKPOINT_KEY) self.application_state_manager.shutdown() self.deployment_state_manager.shutdown() self.endpoint_state.shutdown() if self.proxy_state_manager: self.proxy_state_manager.shutdown() config_checkpoint_deleted = self.config_checkpoint_deleted() application_is_shutdown = self.application_state_manager.is_ready_for_shutdown() deployment_is_shutdown = self.deployment_state_manager.is_ready_for_shutdown() endpoint_is_shutdown = self.endpoint_state.is_ready_for_shutdown() proxy_state_is_shutdown = ( self.proxy_state_manager is None or self.proxy_state_manager.is_ready_for_shutdown() ) if ( config_checkpoint_deleted and application_is_shutdown and deployment_is_shutdown and endpoint_is_shutdown and proxy_state_is_shutdown ): self._kill_registered_cleanup_actors() logger.warning( "All resources have shut down, controller exiting.", extra={"log_to_stderr": False}, ) _controller_actor = ray.get_runtime_context().current_actor ray.kill(_controller_actor, no_restart=True) elif time.time() - self._shutdown_start_time > 10: if not config_checkpoint_deleted: logger.warning( f"{CONFIG_CHECKPOINT_KEY} not yet deleted", extra={"log_to_stderr": False}, ) if not application_is_shutdown: logger.warning( "application not yet shutdown", extra={"log_to_stderr": False}, ) if not deployment_is_shutdown: logger.warning( "deployment not yet shutdown", extra={"log_to_stderr": False}, ) if not endpoint_is_shutdown: logger.warning( "endpoint not yet shutdown", extra={"log_to_stderr": False}, ) if not proxy_state_is_shutdown: logger.warning( "proxy_state not yet shutdown", extra={"log_to_stderr": False}, ) def deploy_applications( self, name_to_deployment_args_list: Dict[str, List[bytes]], name_to_application_args: Dict[str, bytes], ) -> None: """ Takes in a list of dictionaries that contain deployment arguments. If same app name deployed, old application will be overwritten. Args: name: Application name. deployment_args_list: List of serialized deployment information, where each item in the list is bytes representing the serialized protobuf `DeploymentArgs` object. `DeploymentArgs` contains all the information for the single deployment. name_to_application_args: Dictionary mapping application names to serialized application arguments, where each item is bytes representing the serialized protobuf `ApplicationArgs` object. `ApplicationArgs` contains the information for the application. """ name_to_deployment_args = {} for name, deployment_args_list in name_to_deployment_args_list.items(): deployment_args_deserialized = [] for deployment_args_bytes in deployment_args_list: args = DeploymentArgs.FromString(deployment_args_bytes) deployment_args_deserialized.append( { "deployment_name": args.deployment_name, "deployment_config_proto_bytes": args.deployment_config, "replica_config_proto_bytes": args.replica_config, "deployer_job_id": args.deployer_job_id, "ingress": args.ingress, "route_prefix": ( args.route_prefix if args.HasField("route_prefix") else None ), } ) name_to_deployment_args[name] = deployment_args_deserialized name_to_application_args_deserialized = {} for name, application_args_bytes in name_to_application_args.items(): name_to_application_args_deserialized[name] = ApplicationArgs.FromString( application_args_bytes ) self.application_state_manager.deploy_apps( name_to_deployment_args, name_to_application_args_deserialized ) self.application_state_manager.save_checkpoint() def deploy_application( self, name: str, deployment_args_list: List[bytes], application_args: bytes, ) -> None: """ Deploy a single application (as deploy_applications(), but it only takes a single name and deployment args). This primarily exists as a shim to avoid changing Java code in https://github.com/ray-project/ray/pull/49168, and could be removed if the Java code was refactored to use the new bulk deploy_applications API. """ self.deploy_applications( {name: deployment_args_list}, {name: application_args}, ) def apply_config( self, config: ServeDeploySchema, deployment_time: float = 0.0, ) -> None: """Apply the config described in `ServeDeploySchema`. This will upgrade the applications to the goal state specified in the config. If `deployment_time` is not provided, `time.time()` is used. """ ServeUsageTag.API_VERSION.record("v2") if not deployment_time: deployment_time = time.time() new_config_checkpoint = {} _, curr_config, _ = self._read_config_checkpoint() self._target_capacity_direction = calculate_target_capacity_direction( curr_config=curr_config, new_config=config, curr_target_capacity_direction=self._target_capacity_direction, ) log_target_capacity_change( self._target_capacity, config.target_capacity, self._target_capacity_direction, ) self._target_capacity = config.target_capacity for app_config in config.applications: # If the application logging config is not set, use the global logging # config. if app_config.logging_config is None and config.logging_config: app_config.logging_config = config.logging_config app_config_dict = app_config.dict(exclude_unset=True) new_config_checkpoint[app_config.name] = app_config_dict self.kv_store.put( CONFIG_CHECKPOINT_KEY, pickle.dumps( ( deployment_time, self._target_capacity, self._target_capacity_direction, new_config_checkpoint, ) ), ) # Declaratively apply the new set of applications. # This will delete any applications no longer in the config that were # previously deployed via the REST API. self.application_state_manager.apply_app_configs( config.applications, deployment_time=deployment_time, target_capacity=self._target_capacity, target_capacity_direction=self._target_capacity_direction, ) self.application_state_manager.save_checkpoint() def get_deployment_info(self, name: str, app_name: str = "") -> bytes: """Get the current information about a deployment. Args: name: the name of the deployment. Returns: DeploymentRoute's protobuf serialized bytes Raises: KeyError: If the deployment doesn't exist. """ id = DeploymentID(name=name, app_name=app_name) deployment_info = self.deployment_state_manager.get_deployment(id) if deployment_info is None: app_msg = f" in application '{app_name}'" if app_name else "" raise KeyError(f"Deployment '{name}' does not exist{app_msg}.") route = self.endpoint_state.get_endpoint_route(id) deployment_route = DeploymentRoute( deployment_info=deployment_info.to_proto(), route=route ) return deployment_route.SerializeToString() def list_deployments_internal( self, ) -> Dict[DeploymentID, Tuple[DeploymentInfo, str]]: """Gets the current information about all deployments. Returns: Dict(deployment_id, (DeploymentInfo, route)) """ return { id: (info, self.endpoint_state.get_endpoint_route(id)) for id, info in self.deployment_state_manager.get_deployment_infos().items() } def get_deployment_config( self, deployment_id: DeploymentID ) -> Optional[DeploymentConfig]: """Get the deployment config for the given deployment id. Args: deployment_id: The deployment id to get the config for. Returns: A deployment config object if the deployment id exist, None otherwise. """ deployment_info = self.deployment_state_manager.get_deployment_infos().get( deployment_id ) return deployment_info.deployment_config if deployment_info else None def list_deployment_ids(self) -> List[DeploymentID]: """Gets the current list of all deployments' identifiers.""" return self.deployment_state_manager._deployment_states.keys() def update_deployment_replicas( self, deployment_id: DeploymentID, target_num_replicas: int ) -> None: """Update the target number of replicas for a deployment. Args: deployment_id: The deployment to update. target_num_replicas: The new target number of replicas. Raises: ExternalScalerDisabledError: If external_scaler_enabled is set to False for the application. """ # Check if external scaler is enabled for this application app_name = deployment_id.app_name if not self.application_state_manager.does_app_exist(app_name): raise ValueError(f"Application '{app_name}' not found") if not self.application_state_manager.get_external_scaler_enabled(app_name): raise ExternalScalerDisabledError( f"Cannot update replicas for deployment '{deployment_id.name}' in " f"application '{app_name}'. The external scaling API can only be used " f"when 'external_scaler_enabled' is set to true in the application " f"configuration. Current value: external_scaler_enabled=false. " f"To use this API, redeploy your application with " f"'external_scaler_enabled: true' in the config." ) self.deployment_state_manager.set_target_num_replicas( deployment_id, target_num_replicas ) def get_serve_instance_details(self, source: Optional[APIType] = None) -> Dict: """Gets details on all applications on the cluster and system-level info. The information includes application and deployment statuses, config options, error messages, etc. Args: source: If provided, returns application statuses for applications matching this API type. Defaults to None, which means all applications are returned. Returns: Dict that follows the format of the schema ServeInstanceDetails. """ http_config = self.get_http_config() grpc_config = self.get_grpc_config() applications = {} app_statuses = self.application_state_manager.list_app_statuses(source=source) # If there are no app statuses, there's no point getting the app configs. # Moreover, there might be no app statuses because the GCS is down, # in which case getting the app configs would fail anyway, # since they're stored in the checkpoint in the GCS. app_configs = self.get_app_configs() if app_statuses else {} for ( app_name, app_status_info, ) in app_statuses.items(): applications[app_name] = ApplicationDetails( name=app_name, route_prefix=self.application_state_manager.get_route_prefix(app_name), docs_path=self.get_docs_path(app_name), status=app_status_info.status, message=app_status_info.message, last_deployed_time_s=app_status_info.deployment_timestamp, # This can be none if the app was deployed through # serve.run, the app is in deleting state, # or a checkpoint hasn't been set yet deployed_app_config=app_configs.get(app_name), source=self.application_state_manager.get_app_source(app_name), deployments=self.application_state_manager.list_deployment_details( app_name ), external_scaler_enabled=self.application_state_manager.get_external_scaler_enabled( app_name ), deployment_topology=self.application_state_manager.get_deployment_topology( app_name ), ) # NOTE(zcin): We use exclude_unset here because we explicitly and intentionally # fill in all info that should be shown to users. http_options = HTTPOptionsSchema.parse_obj(http_config.dict(exclude_unset=True)) grpc_options = gRPCOptionsSchema.parse_obj(grpc_config.dict(exclude_unset=True)) return ServeInstanceDetails( target_capacity=self._target_capacity, controller_info=self._actor_details, proxy_location=ProxyLocation._from_deployment_mode(http_config.location), http_options=http_options, grpc_options=grpc_options, proxies=( self.proxy_state_manager.get_proxy_details() if self.proxy_state_manager else None ), applications=applications, target_groups=self.get_target_groups(), )._get_user_facing_json_serializable_dict(exclude_unset=True) def _get_proxy_target_groups(self) -> List[TargetGroup]: """Get target groups for proxy-based routing.""" target_groups: List[TargetGroup] = [] if self.proxy_state_manager.get_proxy_details(): # setting prefix route to "/" because in ray serve, proxy # accepts requests from the client and routes them to the # correct application. This is true for both HTTP and gRPC proxies. target_groups.append( TargetGroup( protocol=RequestProtocol.HTTP, route_prefix="/", targets=self.proxy_state_manager.get_targets(RequestProtocol.HTTP), ) ) if is_grpc_enabled(self.get_grpc_config()): target_groups.append( TargetGroup( protocol=RequestProtocol.GRPC, route_prefix="/", targets=self.proxy_state_manager.get_targets( RequestProtocol.GRPC ), ) ) return target_groups def get_target_groups( self, app_name: Optional[str] = None, from_proxy_manager: bool = False, ) -> List[TargetGroup]: """Get target groups for direct ingress deployments. This returns target groups that point directly to replica ports rather than proxy ports when direct ingress is enabled. Following situations are possible: 1. Direct ingress is not enabled. In this case, we just return the target groups from the proxy implementation. 2. Direct ingress is enabled and there are no applications. In this case, we return target groups for proxy. Serve controller is running but there are no applications to route traffic to. 3. Direct ingress is enabled and there are applications. All applications have atleast one running replica. In this case, we return target groups for all applications with targets pointing to the running replicas. 4. Direct ingress is enabled and there are applications. Some applications have no running replicas. In this case, for applications that have no running replicas, we return target groups for proxy and for applications that have running replicas, we return target groups for direct ingress. If there are multiple applications with no running replicas, we return one target group per application with unique route prefix. """ proxy_target_groups = self._get_proxy_target_groups() if not self._direct_ingress_enabled: return proxy_target_groups # Get all applications and their metadata if app_name is None: apps = [ _app_name for _app_name, _ in self.application_state_manager.list_app_statuses().items() ] else: apps = [app_name] # TODO(landscapepainter): A better way to handle this is to write an API that can tell # if the ingress deployment is healthy regardless of the application status. apps = [ app for app in apps if self.application_state_manager.get_route_prefix(app) is not None ] if not apps: return proxy_target_groups # Create target groups for each application target_groups = [] for app_name in apps: route_prefix = self.application_state_manager.get_route_prefix(app_name) app_target_groups = self._get_target_groups_for_app(app_name, route_prefix) if app_target_groups: target_groups.extend(app_target_groups) else: target_groups.extend( self._get_target_groups_for_app_with_no_running_replicas( route_prefix, app_name ) ) return target_groups def _get_running_replica_details_for_ingress_deployment( self, app_name: str ) -> List[ReplicaDetails]: """Get running replica details for a specific application.""" ingress_deployment_name = ( self.application_state_manager.get_ingress_deployment_name(app_name) ) deployment_id = DeploymentID(app_name=app_name, name=ingress_deployment_name) details = self.deployment_state_manager.get_deployment_details(deployment_id) if not details: return [] replica_details = details.replicas running_replica_ids = { replica_info.replica_id.unique_id for replica_info in self.deployment_state_manager.get_running_replica_infos().get( deployment_id, [] ) } return [ replica_detail for replica_detail in replica_details if replica_detail.replica_id in running_replica_ids ] def _get_target_groups_for_app( self, app_name: str, route_prefix: str ) -> List[TargetGroup]: """ Create HTTP and gRPC target groups for a specific application. This function can return empty list if there are no running replicas. Or replicas have not fully initialized yet, where their ports are not allocated yet. """ # Get running replicas for the ingress deployment replica_details = self._get_running_replica_details_for_ingress_deployment( app_name ) if not replica_details: return [] target_groups = [] # Create targets for each protocol http_targets = self._get_targets_for_protocol( replica_details, RequestProtocol.HTTP ) if http_targets: target_groups.append( TargetGroup( protocol=RequestProtocol.HTTP, route_prefix=route_prefix, targets=http_targets, app_name=app_name, ) ) # Add gRPC targets if enabled if is_grpc_enabled(self.get_grpc_config()): grpc_targets = self._get_targets_for_protocol( replica_details, RequestProtocol.GRPC ) if grpc_targets: target_groups.append( TargetGroup( protocol=RequestProtocol.GRPC, route_prefix=route_prefix, targets=grpc_targets, app_name=app_name, ) ) return target_groups def _get_target_groups_for_app_with_no_running_replicas( self, route_prefix: str, app_name: str ) -> List[TargetGroup]: """ For applications that have no running replicas, we return target groups for proxy. This will allow applications to be discoverable via the proxy in situations where their replicas have scaled down to 0. """ target_groups = [] http_targets = self.proxy_state_manager.get_targets(RequestProtocol.HTTP) grpc_targets = self.proxy_state_manager.get_targets(RequestProtocol.GRPC) if http_targets: target_groups.append( TargetGroup( protocol=RequestProtocol.HTTP, route_prefix=route_prefix, targets=http_targets, app_name=app_name, ) ) if grpc_targets: target_groups.append( TargetGroup( protocol=RequestProtocol.GRPC, route_prefix=route_prefix, targets=grpc_targets, app_name=app_name, ) ) return target_groups def _get_targets_for_protocol( self, replica_details: List[ReplicaDetails], protocol: RequestProtocol ) -> List[Target]: """Create targets for a specific protocol from a list of replicas.""" return [ Target( ip=replica_detail.node_ip, port=self._get_port(replica_detail, protocol), instance_id=replica_detail.node_instance_id, name=replica_detail.actor_name, ) for replica_detail in replica_details if self._is_port_allocated(replica_detail, protocol) ] def _get_node_id_to_alive_replica_ids(self) -> Dict[str, Set[str]]: node_id_to_alive_replica_ids = defaultdict(set) # TODO(abrar): Expose the right APIs in the DeploymentStateManager # to get the alive replicas for a deployment. for ds in self.deployment_state_manager._deployment_states.values(): # here we get all the replicas irrespective of their state # unlike in the get_running_replica_infos_for_ingress_deployment # where we only get the replicas that are running, because we dont # wish to agressively cleanup ports for replicas that are not running # and are in the process of being updated or are in the process of # being started. replicas: List[DeploymentReplica] = ds._replicas.get() for replica in replicas: node_id: Optional[str] = replica.actor_node_id if node_id is None: continue replica_unique_id = replica.replica_id.unique_id node_id_to_alive_replica_ids[node_id].add(replica_unique_id) return node_id_to_alive_replica_ids def allocate_replica_port( self, node_id: str, replica_id: str, protocol: RequestProtocol ) -> int: """Allocate an HTTP port for a replica in direct ingress mode.""" node_manager = NodePortManager.get_node_manager(node_id) return node_manager.allocate_port(replica_id, protocol) def release_replica_port( self, node_id: str, replica_id: str, port: int, protocol: RequestProtocol, block_port: bool = False, ): """Release an HTTP port for a replica in direct ingress mode.""" node_manager = NodePortManager.get_node_manager(node_id) node_manager.release_port(replica_id, port, protocol, block_port) def _get_port( self, replica_detail: ReplicaDetails, protocol: RequestProtocol ) -> int: """Get the port for a replica.""" node_manager = NodePortManager.get_node_manager(replica_detail.node_id) return node_manager.get_port(replica_detail.replica_id, protocol) def _is_port_allocated( self, replica_detail: ReplicaDetails, protocol: RequestProtocol ) -> bool: """Check if the port for a replica is allocated.""" node_manager = NodePortManager.get_node_manager(replica_detail.node_id) return node_manager.is_port_allocated(replica_detail.replica_id, protocol) def get_serve_status(self, name: str = SERVE_DEFAULT_APP_NAME) -> bytes: """Return application status Args: name: application name. If application name doesn't exist, app_status is NOT_STARTED. """ app_status = self.application_state_manager.get_app_status_info(name) deployment_statuses = self.application_state_manager.get_deployments_statuses( name ) status_info = StatusOverview( name=name, app_status=app_status, deployment_statuses=deployment_statuses, ) return status_info.to_proto().SerializeToString() def get_serve_statuses(self, names: List[str]) -> List[bytes]: statuses = [] for name in names: statuses.append(self.get_serve_status(name)) return statuses def list_serve_statuses(self) -> List[bytes]: statuses = [] for name in self.application_state_manager.list_app_statuses(): statuses.append(self.get_serve_status(name)) return statuses def get_app_configs(self) -> Dict[str, ServeApplicationSchema]: checkpoint = self.kv_store.get(CONFIG_CHECKPOINT_KEY) if checkpoint is None: return {} _, _, _, config_checkpoints_dict = pickle.loads(checkpoint) return { app: ServeApplicationSchema.parse_obj(config) for app, config in config_checkpoints_dict.items() } def get_external_scaler_enabled(self, app_name: str) -> bool: """Get the external_scaler_enabled flag value for an application. This is a helper method specifically for Java tests to verify the flag is correctly set, since Java cannot deserialize Python Pydantic objects. Args: app_name: Name of the application. Returns: True if external_scaler_enabled is set for the application, False otherwise. """ return self.application_state_manager.get_external_scaler_enabled(app_name) def get_all_deployment_statuses(self) -> List[bytes]: """Gets deployment status bytes for all live deployments.""" statuses = self.deployment_state_manager.get_deployment_statuses() return [status.to_proto().SerializeToString() for status in statuses] def get_deployment_status( self, name: str, app_name: str = "" ) -> Union[None, bytes]: """Get deployment status by deployment name. Args: name: Deployment name. app_name: Application name. Default is "" because 1.x deployments go through this API. """ id = DeploymentID(name=name, app_name=app_name) status = self.deployment_state_manager.get_deployment_statuses([id]) if not status: return None return status[0].to_proto().SerializeToString() def get_docs_path(self, name: str): """Docs path for application. Currently, this is the OpenAPI docs path for FastAPI-integrated applications.""" return self.application_state_manager.get_docs_path(name) def get_ingress_deployment_name(self, app_name: str) -> Optional[str]: """Name of the ingress deployment in an application. Returns: Ingress deployment name (str): if the application exists. None: if the application does not exist. """ return self.application_state_manager.get_ingress_deployment_name(app_name) def delete_apps(self, names: Iterable[str]): """Delete applications based on names During deletion, the application status is DELETING """ for name in names: self.application_state_manager.delete_app(name) self.application_state_manager.save_checkpoint() def record_request_routing_info(self, info: RequestRoutingInfo): """Record replica routing information for a replica. Args: info: RequestRoutingInfo including deployment name, replica tag, multiplex model ids, and routing stats. """ self.deployment_state_manager.record_request_routing_info(info) def _get_replica_ranks_mapping( self, deployment_id: DeploymentID ) -> Dict[str, ReplicaRank]: """Get the current rank mapping for all replicas in a deployment. Args: deployment_id: The deployment ID to get ranks for. Returns: Dictionary mapping replica_id to ReplicaRank object (with rank, node_rank, local_rank). """ return self.deployment_state_manager._get_replica_ranks_mapping(deployment_id) async def graceful_shutdown(self, wait: bool = True): """Set the shutting down flag on controller to signal shutdown in run_control_loop(). This is used to signal to the controller that it should proceed with shutdown process, so it can shut down gracefully. It also waits until the shutdown event is triggered if wait is true. Raises: RayActorError: if wait is True, the caller waits until the controller is killed, which raises a RayActorError. """ self._shutting_down = True if not wait: return # This event never gets set. The caller waits indefinitely on this event # until the controller is killed, which raises a RayActorError. await self._shutdown_event.wait() def _get_logging_config(self) -> Tuple: """Get the logging configuration (for testing purposes).""" log_file_path = None for handler in logger.handlers: if isinstance(handler, logging.handlers.MemoryHandler): log_file_path = handler.target.baseFilename return self.global_logging_config, log_file_path def _get_target_capacity_direction(self) -> Optional[TargetCapacityDirection]: """Gets the controller's scale direction (for testing purposes).""" return self._target_capacity_direction def calculate_target_capacity_direction( curr_config: Optional[ServeDeploySchema], new_config: ServeDeploySchema, curr_target_capacity_direction: Optional[float], ) -> Optional[TargetCapacityDirection]: """Compares two Serve configs to calculate the next scaling direction.""" curr_target_capacity = None next_target_capacity_direction = None if curr_config is not None and applications_match(curr_config, new_config): curr_target_capacity = curr_config.target_capacity next_target_capacity = new_config.target_capacity if curr_target_capacity == next_target_capacity: next_target_capacity_direction = curr_target_capacity_direction elif curr_target_capacity is None and next_target_capacity is not None: # target_capacity is scaling down from None to a number. next_target_capacity_direction = TargetCapacityDirection.DOWN elif next_target_capacity is None: next_target_capacity_direction = None elif curr_target_capacity < next_target_capacity: next_target_capacity_direction = TargetCapacityDirection.UP else: next_target_capacity_direction = TargetCapacityDirection.DOWN elif new_config.target_capacity is not None: # A config with different apps has been applied, and it contains a # target_capacity. Serve must start scaling this config up. next_target_capacity_direction = TargetCapacityDirection.UP else: next_target_capacity_direction = None return next_target_capacity_direction def applications_match(config1: ServeDeploySchema, config2: ServeDeploySchema) -> bool: """Checks whether the applications in config1 and config2 match. Two applications match if they have the same name. """ config1_app_names = {app.name for app in config1.applications} config2_app_names = {app.name for app in config2.applications} return config1_app_names == config2_app_names def log_target_capacity_change( curr_target_capacity: Optional[float], next_target_capacity: Optional[float], next_target_capacity_direction: Optional[TargetCapacityDirection], ): """Logs changes in the target_capacity.""" if curr_target_capacity != next_target_capacity: if isinstance(next_target_capacity_direction, TargetCapacityDirection): logger.info( "Target capacity scaling " f"{next_target_capacity_direction.value.lower()} " f"from {curr_target_capacity} to {next_target_capacity}." ) else: logger.info("Target capacity entering 100% at steady state.")