import asyncio
import datetime
import os
import random
import threading
import time
from contextlib import asynccontextmanager
from copy import copy, deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import grpc
import httpx
import requests
from starlette.requests import Request

import ray
from ray import serve
from ray._common.network_utils import build_address
from ray._common.test_utils import wait_for_condition
from ray._common.utils import TimerBase
from ray._private.test_utils import (
    PrometheusTimeseries,
    fetch_prometheus_metric_timeseries,
)
from ray.actor import ActorHandle
from ray.serve._private.client import ServeControllerClient
from ray.serve._private.common import (
    CreatePlacementGroupRequest,
    DeploymentID,
    DeploymentStatus,
    RequestProtocol,
)
from ray.serve._private.constants import (
    SERVE_DEFAULT_APP_NAME,
    SERVE_NAMESPACE,
)
from ray.serve._private.deployment_state import ALL_REPLICA_STATES, ReplicaState
from ray.serve._private.proxy import DRAINING_MESSAGE
from ray.serve._private.usage import ServeUsageTag
from ray.serve.context import _get_global_client
from ray.serve.generated import serve_pb2, serve_pb2_grpc
from ray.serve.schema import ApplicationStatus, TargetGroup
from ray.util.state import list_actors

TELEMETRY_ROUTE_PREFIX = "/telemetry"
STORAGE_ACTOR_NAME = "storage"


class MockTimer(TimerBase):
    """Manually driven clock: `time()` only moves when a test advances it."""

    def __init__(self, start_time: Optional[float] = None):
        self._lock = threading.Lock()
        self.reset(start_time=start_time)

    def reset(self, start_time: Optional[float] = None):
        """Set the clock to `start_time`, or to the wall clock if it is None."""
        if start_time is None:
            start_time = time.time()
        self._curr = start_time

    def time(self) -> float:
        """Return the current mocked time."""
        return self._curr

    def advance(self, by: float):
        """Move the clock forward by exactly `by`."""
        with self._lock:
            self._curr += by

    def realistic_sleep(self, amt: float):
        """Move the clock forward by `amt` plus a fixed 0.001 overshoot."""
        with self._lock:
            self._curr += amt + 0.001


class MockAsyncTimer:
    """Manually driven clock for asyncio code, with a cooperative `sleep`."""

    def __init__(self, start_time: Optional[float] = 0):
        self.reset(start_time=start_time)
        self._num_sleepers = 0

    def reset(self, start_time: Optional[float] = 0):
        """Set the clock to `start_time` (defaults to 0, like `__init__`)."""
        self._curr = start_time

    def time(self) -> float:
        """Return the current mocked time."""
        return self._curr

    async def sleep(self, amt: float):
        """Block the calling coroutine until the clock has advanced by `amt`.

        The coroutine is counted in `num_sleepers()` for as long as it waits.
        """
        self._num_sleepers += 1
        end = self._curr + amt

        # Give up the event loop
        while self._curr < end:
            await asyncio.sleep(0)

        self._num_sleepers -= 1

    def advance(self, amt: float):
        """Move the clock forward by `amt`."""
        self._curr += amt

    def num_sleepers(self):
        """Return how many coroutines are currently inside `sleep()`."""
        return self._num_sleepers


class MockKVStore:
    """In-memory, dict-backed key-value store that only accepts `str` keys."""

    def __init__(self):
        self.store = dict()

    def put(self, key: str, val: Any) -> bool:
        """Store `val` under `key`; raises TypeError for non-string keys."""
        if not isinstance(key, str):
            raise TypeError("key must be a string, got: {}.".format(type(key)))
        self.store[key] = val
        return True

    def get(self, key: str) -> Any:
        """Return the value for `key`, or None if absent."""
        if not isinstance(key, str):
            raise TypeError("key must be a string, got: {}.".format(type(key)))
        return self.store.get(key, None)

    def delete(self, key: str) -> bool:
        """Remove `key`; returns True if it existed, False otherwise."""
        if not isinstance(key, str):
            raise TypeError("key must be a string, got: {}.".format(type(key)))

        if key in self.store:
            del self.store[key]
            return True

        return False


class MockClusterNodeInfoCache:
    """In-memory stand-in for the cluster node info cache.

    Tests populate it through `add_node` and by mutating the public
    attributes directly.
    """

    def __init__(self):
        self.alive_node_ids = set()
        self.total_resources_per_node = dict()
        self.available_resources_per_node = dict()
        self.draining_nodes = dict()
        self.node_labels = dict()

    def get_alive_node_ids(self):
        return self.alive_node_ids

    def get_draining_nodes(self):
        return self.draining_nodes

    def get_active_node_ids(self):
        """Return alive nodes that are not marked as draining."""
        return self.alive_node_ids - set(self.draining_nodes)

    def get_node_az(self, node_id):
        # This mock never reports an availability zone.
        return None

    def get_available_resources_per_node(self):
        return self.available_resources_per_node

    def get_total_resources_per_node(self):
        return self.total_resources_per_node

    def add_node(
        self,
        node_id: str,
        resources: Optional[Dict] = None,
        labels: Optional[Dict] = None,
    ):
        """Register `node_id` as alive with the given resources and labels.

        Total and available resources each get their own deep copy so that
        later mutations of one do not affect the other or the caller's dict.
        """
        self.alive_node_ids.add(node_id)
        self.total_resources_per_node[node_id] = deepcopy(resources) or {}
        self.available_resources_per_node[node_id] = deepcopy(resources) or {}
        self.node_labels[node_id] = labels or {}

    def set_available_resources_per_node(self, node_id: str, resources: Dict):
        self.available_resources_per_node[node_id] = deepcopy(resources)

    def get_node_labels(self, node_id: str):
        return self.node_labels.get(node_id, {})


class FakeRemoteFunction:
    """Object exposing a no-op `remote()` so mocks can mimic actor methods."""

    def remote(self):
        pass


class MockActorHandle:
    """Fake actor handle that records which methods were accessed."""

    def __init__(self, **kwargs):
        self._options = kwargs
        self._actor_id = "fake_id"
        self.initialize_and_get_metadata_called = False
        self.is_allocated_called = False

    @property
    def initialize_and_get_metadata(self):
        self.initialize_and_get_metadata_called = True
        # return a mock object so that we can call `remote()` on it.
        return FakeRemoteFunction()

    @property
    def is_allocated(self):
        self.is_allocated_called = True
        return FakeRemoteFunction()


class MockActorClass:
    """Fake actor class supporting the `.options(...).remote(...)` chain."""

    def __init__(self):
        self._init_args = ()
        self._options = dict()

    def options(self, **kwargs):
        """Return a copy of this class with `kwargs` merged into its options.

        `copy()` is shallow, so the copy is given its own options dict;
        otherwise every call would also mutate the options of `self` and of
        all previously returned copies.
        """
        res = copy(self)
        res._options = {**self._options, **kwargs}
        return res

    def remote(self, *args) -> MockActorHandle:
        """Create a `MockActorHandle` carrying the init args and options."""
        return MockActorHandle(init_args=args, **self._options)


class MockPlacementGroup:
    """Records the fields of the `CreatePlacementGroupRequest` it was built from."""

    def __init__(
        self,
        request: CreatePlacementGroupRequest,
    ):
        self._bundles = request.bundles
        self._strategy = request.strategy
        self._soft_target_node_id = request.target_node_id
        self._name = request.name
        self._lifetime = "detached"


class MockDeploymentHandle:
    """Minimal fake of a deployment handle identified by deployment/app name."""

    def __init__(self, deployment_name: str, app_name: str = SERVE_DEFAULT_APP_NAME):
        self._deployment_name = deployment_name
        self._app_name = app_name
        self._protocol = RequestProtocol.UNDEFINED
        self._running_replicas_populated = False
        self._initialized = False

    def is_initialized(self):
        return self._initialized

    def _init(self):
        """Mark the handle initialized; raises RuntimeError on a second call."""
        if self._initialized:
            raise RuntimeError("already initialized")
        self._initialized = True

    def options(self, *args, **kwargs):
        # Options are ignored by the mock; the same handle is returned.
        return self

    def __eq__(self, dep: Tuple[str]):
        """Compare against a `(deployment_name, app_name)` pair.

        NOTE: `dep` is unpacked unconditionally, so comparing with anything
        that is not a two-element iterable raises instead of returning False.
        """
        other_deployment_name, other_app_name = dep
        return (
            self._deployment_name == other_deployment_name
            and self._app_name == other_app_name
        )

    def _set_request_protocol(self, protocol: RequestProtocol):
        self._protocol = protocol

    def _get_or_create_router(self):
        pass

    def running_replicas_populated(self) -> bool:
        return self._running_replicas_populated

    def set_running_replicas_populated(self, val: bool):
        self._running_replicas_populated = val


@serve.deployment
class GetPID:
    """Deployment whose response is the PID of the process handling the call."""

    def __call__(self):
        return os.getpid()


get_pid_entrypoint = GetPID.bind()


def check_ray_stopped():
    """Return True if the version endpoint on localhost:8265 is unreachable.

    Any exception from the request counts as "stopped"; any response at all
    (regardless of status code) counts as "not stopped".
    """
    try:
        requests.get("http://localhost:8265/api/ray/version")
        return False
    except Exception:
        return True


def check_ray_started():
    """Return True if the version endpoint on localhost:8265 answers 200."""
    return requests.get("http://localhost:8265/api/ray/version").status_code == 200


def check_deployment_status(
    name: str, expected_status: DeploymentStatus, app_name=SERVE_DEFAULT_APP_NAME
) -> bool:
    """Assert that deployment `name` in `app_name` has `expected_status`."""
    app_status = serve.status().applications[app_name]
    assert app_status.deployments[name].status == expected_status
    return True


def get_num_alive_replicas(
    deployment_name: str, app_name: str = SERVE_DEFAULT_APP_NAME
) -> int:
    """Get the replicas currently running for the given deployment."""

    dep_id = DeploymentID(name=deployment_name, app_name=app_name)
    actors = list_actors(
        filters=[
            ("class_name", "=", dep_id.to_replica_actor_class_name()),
            ("state", "=", "ALIVE"),
        ]
    )
    return len(actors)


def check_num_replicas_gte(
    name: str, target: int, app_name: str = SERVE_DEFAULT_APP_NAME
) -> bool:
    """Check if num replicas is >= target."""

    assert get_num_alive_replicas(name, app_name) >= target
    return True


def check_num_replicas_eq(
    name: str,
    target: int,
    app_name: str = SERVE_DEFAULT_APP_NAME,
    use_controller: bool = False,
) -> bool:
    """Check if num replicas is == target.

    With `use_controller=True` the count of RUNNING replicas reported by
    `serve.status()` is used; otherwise ALIVE replica actors are counted.
    """

    if use_controller:
        dep = serve.status().applications[app_name].deployments[name]
        num_running_replicas = dep.replica_states.get(ReplicaState.RUNNING, 0)
        assert num_running_replicas == target
    else:
        assert get_num_alive_replicas(name, app_name) == target

    return True


def check_num_replicas_lte(
    name: str, target: int, app_name: str = SERVE_DEFAULT_APP_NAME
) -> bool:
    """Check if num replicas is <= target."""

    assert get_num_alive_replicas(name, app_name) <= target
    return True


def check_apps_running(apps: List):
    """Assert that every application named in `apps` has RUNNING status."""
    status = serve.status()
    for app_name in apps:
        assert status.applications[app_name].status == ApplicationStatus.RUNNING

    return True


def check_replica_counts(
    controller: ActorHandle,
    deployment_id: DeploymentID,
    total: Optional[int] = None,
    by_state: Optional[List[Tuple[ReplicaState, int, Callable]]] = None,
):
    """Uses _dump_replica_states_for_testing to check replica counts.

    Args:
        controller: A handle to the Serve controller.
        deployment_id: The deployment to check replica counts for.
        total: The total number of expected replicas for the deployment.
        by_state: A list of tuples of the form
            (replica state, number of replicas, filter function).
            Used for more fine grained checks.
    """
    replicas = ray.get(
        controller._dump_replica_states_for_testing.remote(deployment_id)
    )

    if total is not None:
        # Per-state breakdown (non-empty states only), used solely as the
        # assertion message so a failure shows where the replicas are.
        replica_counts = {}
        for state in ALL_REPLICA_STATES:
            in_state = replicas.get([state])
            if in_state:
                replica_counts[state] = len(in_state)
        assert replicas.count() == total, replica_counts

    if by_state is not None:
        for state, count, check in by_state:
            assert isinstance(state, ReplicaState)
            assert isinstance(count, int) and count >= 0
            if check:
                filtered = {r for r in replicas.get(states=[state]) if check(r)}
                curr_count = len(filtered)
            else:
                curr_count = replicas.count(states=[state])
            msg = f"Expected {count} for state {state} but got {curr_count}."
            assert curr_count == count, msg

    return True


@ray.remote(name=STORAGE_ACTOR_NAME, namespace=SERVE_NAMESPACE, num_cpus=0)
class TelemetryStorage:
    """Named actor holding the latest telemetry report and a receive counter."""

    def __init__(self):
        self.reports_received = 0
        self.current_report = dict()

    def store_report(self, report: Dict) -> None:
        self.reports_received += 1
        self.current_report = report

    def get_report(self) -> Dict:
        return self.current_report

    def get_reports_received(self) -> int:
        return self.reports_received


@serve.deployment(ray_actor_options={"num_cpus": 0})
class TelemetryReceiver:
    """Deployment that forwards each request's JSON body to the storage actor."""

    def __init__(self):
        self.storage = ray.get_actor(name=STORAGE_ACTOR_NAME, namespace=SERVE_NAMESPACE)

    async def __call__(self, request: Request) -> bool:
        report = await request.json()
        ray.get(self.storage.store_report.remote(report))
        return True


receiver_app = TelemetryReceiver.bind()


def start_telemetry_app():
    """Start a telemetry Serve app.

    Ray should be initialized before calling this method.

    NOTE: If you're running the TelemetryReceiver Serve app to check telemetry,
    remember that the receiver itself is counted in the telemetry. E.g. if you
    deploy a Serve app other than the receiver, the number of apps in the
    cluster is 2, not 1, since the receiver is also running.

    Returns a handle to a TelemetryStorage actor. You can use this actor
    to access the latest telemetry reports.
    """

    storage = TelemetryStorage.remote()
    serve.run(receiver_app, name="telemetry", route_prefix=TELEMETRY_ROUTE_PREFIX)
    return storage


def check_telemetry(
    tag: ServeUsageTag, expected: Any, storage_actor_name: str = STORAGE_ACTOR_NAME
):
    """Assert that `tag`'s value in the latest stored report equals `expected`.

    The report's "extra_usage_tags" entry is printed to aid debugging.
    """
    storage_handle = ray.get_actor(storage_actor_name, namespace=SERVE_NAMESPACE)
    report = ray.get(storage_handle.get_report.remote())
    print(report["extra_usage_tags"])
    assert tag.get_value_from_report(report) == expected
    return True


def ping_grpc_list_applications(channel, app_names, test_draining=False):
    """Call ListApplications and assert on the outcome.

    With `test_draining=True` the call must fail with UNAVAILABLE and the
    draining message; otherwise it must succeed and return `app_names`.
    """
    import pytest

    stub = serve_pb2_grpc.RayServeAPIServiceStub(channel)
    request = serve_pb2.ListApplicationsRequest()
    if test_draining:
        with pytest.raises(grpc.RpcError) as exception_info:
            _, _ = stub.ListApplications.with_call(request=request)
        rpc_error = exception_info.value
        assert rpc_error.code() == grpc.StatusCode.UNAVAILABLE
        assert rpc_error.details() == DRAINING_MESSAGE
    else:
        response, call = stub.ListApplications.with_call(request=request)
        assert call.code() == grpc.StatusCode.OK
        assert response.application_names == app_names
    return True


def ping_grpc_healthz(channel, test_draining=False):
    """Call Healthz and assert on the outcome.

    With `test_draining=True` the call must fail with UNAVAILABLE and the
    draining message; otherwise it must succeed with message "success".
    """
    import pytest

    stub = serve_pb2_grpc.RayServeAPIServiceStub(channel)
    request = serve_pb2.HealthzRequest()
    if test_draining:
        with pytest.raises(grpc.RpcError) as exception_info:
            _, _ = stub.Healthz.with_call(request=request)
        rpc_error = exception_info.value
        assert rpc_error.code() == grpc.StatusCode.UNAVAILABLE
        assert rpc_error.details() == DRAINING_MESSAGE
    else:
        response, call = stub.Healthz.with_call(request=request)
        assert call.code() == grpc.StatusCode.OK
        assert response.message == "success"


def ping_grpc_call_method(channel, app_name, test_not_found=False):
    """Invoke the user-defined `__call__` RPC for `app_name` and assert on it.

    With `test_not_found=True` the call must fail with NOT_FOUND and mention
    the missing application; otherwise the expected greeting must come back.
    """
    import pytest

    stub = serve_pb2_grpc.UserDefinedServiceStub(channel)
    request = serve_pb2.UserDefinedMessage(name="foo", num=30, foo="bar")
    metadata = (("application", app_name),)
    if test_not_found:
        with pytest.raises(grpc.RpcError) as exception_info:
            _, _ = stub.__call__.with_call(request=request, metadata=metadata)
        rpc_error = exception_info.value
        assert rpc_error.code() == grpc.StatusCode.NOT_FOUND, rpc_error.code()
        assert f"Application '{app_name}' not found." in rpc_error.details()
    else:
        response, call = stub.__call__.with_call(request=request, metadata=metadata)
        assert call.code() == grpc.StatusCode.OK, call.code()
        assert response.greeting == "Hello foo from bar", response.greeting


def ping_grpc_another_method(channel, app_name):
    """Invoke `Method1` for `app_name` and assert on the greeting."""
    stub = serve_pb2_grpc.UserDefinedServiceStub(channel)
    request = serve_pb2.UserDefinedMessage(name="foo", num=30, foo="bar")
    metadata = (("application", app_name),)
    response = stub.Method1(request=request, metadata=metadata)
    assert response.greeting == "Hello foo from method1"


def ping_grpc_model_multiplexing(channel, app_name):
    """Invoke `Method2` with a multiplexed model id and assert on the greeting."""
    stub = serve_pb2_grpc.UserDefinedServiceStub(channel)
    request = serve_pb2.UserDefinedMessage(name="foo", num=30, foo="bar")
    multiplexed_model_id = "999"
    metadata = (
        ("application", app_name),
        ("multiplexed_model_id", multiplexed_model_id),
    )
    response = stub.Method2(request=request, metadata=metadata)
    assert (
        response.greeting
        == f"Method2 called model, loading model: {multiplexed_model_id}"
    )


def ping_grpc_streaming(channel, app_name):
    """Invoke the `Streaming` RPC and assert on each streamed greeting."""
    stub = serve_pb2_grpc.UserDefinedServiceStub(channel)
    request = serve_pb2.UserDefinedMessage(name="foo", num=30, foo="bar")
    metadata = (("application", app_name),)
    responses = stub.Streaming(request=request, metadata=metadata)
    for idx, response in enumerate(responses):
        assert response.greeting == f"{idx}: Hello foo from bar"


def ping_fruit_stand(channel, app_name):
    """Invoke `FruitStand` with a fixed order and assert the cost is 32."""
    stub = serve_pb2_grpc.FruitServiceStub(channel)
    request = serve_pb2.FruitAmounts(orange=4, apple=8)
    metadata = (("application", app_name),)
    response = stub.FruitStand(request=request, metadata=metadata)
    assert response.costs == 32


@asynccontextmanager
async def send_signal_on_cancellation(signal_actor: ActorHandle):
    """Context manager that signals `signal_actor` when the task is cancelled.

    If a `CancelledError` is raised inside the `with` body (or during the
    sleep that follows it), `signal_actor.send` is invoked. If no cancellation
    was observed, a RuntimeError is raised.
    """
    cancelled = False
    try:
        yield
        await asyncio.sleep(100)
    except asyncio.CancelledError:
        cancelled = True
        # Clear the context var to avoid Ray recursively cancelling this method call.
        ray._raylet.async_task_id.set(None)
        await signal_actor.send.remote()

    if not cancelled:
        raise RuntimeError(
            "CancelledError wasn't raised during `send_signal_on_cancellation` block"
        )


class FakeGrpcContext:
    """Fake gRPC servicer context with fixed peer info and settable status."""

    def __init__(self):
        self._auth_context = {"key": "value"}
        self._peer = "peer"
        self._peer_identities = b"peer_identities"
        self._peer_identity_key = "peer_identity_key"
        self._code = None
        self._details = None
        self._trailing_metadata = []
        # Previously assigned twice; only this (empty) value ever took effect.
        self._invocation_metadata = []

    def auth_context(self):
        return self._auth_context

    def code(self):
        return self._code

    def details(self):
        return self._details

    def peer(self):
        return self._peer

    def peer_identities(self):
        return self._peer_identities

    def peer_identity_key(self):
        return self._peer_identity_key

    def trailing_metadata(self):
        return self._trailing_metadata

    def set_code(self, code):
        self._code = code

    def set_details(self, details):
        self._details = details

    def set_trailing_metadata(self, trailing_metadata):
        self._trailing_metadata = trailing_metadata

    def invocation_metadata(self):
        return self._invocation_metadata


class FakeGauge:
    """Fake gauge metric storing values in dicts nested by tag value.

    Nesting follows the order of `tag_keys`: one dict level per tag, with the
    last tag's value keying the recorded number.
    """

    def __init__(self, name: Optional[str] = None, tag_keys: Optional[Tuple[str]] = None):
        self.name = name
        self.values = dict()

        self.tags = tag_keys or ()
        self.default_tags = dict()

    def set_default_tags(self, tags: Dict[str, str]):
        """Set default tag values; every key must be one of the gauge's tags."""
        for key, tag in tags.items():
            assert key in self.tags
            self.default_tags[key] = tag

    def set(self, value: Union[int, float], tags: Optional[Dict[str, str]] = None):
        """Record `value` under default tags overlaid with `tags`.

        The merged tag keys must match the gauge's tag keys exactly.
        """
        merged_tags = self.default_tags.copy()
        merged_tags.update(tags or {})
        assert set(merged_tags.keys()) == set(self.tags)

        # Walk/create one nested dict per tag, except the last one.
        d = self.values
        for tag in self.tags[:-1]:
            tag_value = merged_tags[tag]
            if tag_value not in d:
                d[tag_value] = dict()
            d = d[tag_value]

        d[merged_tags[self.tags[-1]]] = value

    def get_value(self, tags: Dict[str, str]):
        """Return the value recorded for `tags`, or None if nothing was set."""
        value = self.values
        for tag in self.tags:
            tag_value = tags[tag]
            value = value.get(tag_value)
            if value is None:
                return

        return value


class FakeCounter:
    """Fake counter metric accumulating counts in dicts nested by tag value."""

    def __init__(self, name: Optional[str] = None, tag_keys: Optional[Tuple[str]] = None):
        self.name = name
        self.counts = dict()

        self.tags = tag_keys or ()
        self.default_tags = dict()

    def set_default_tags(self, tags: Dict[str, str]):
        """Set default tag values; every key must be one of the counter's tags."""
        for key, tag in tags.items():
            assert key in self.tags
            self.default_tags[key] = tag

    def inc(self, value: Union[int, float] = 1.0, tags: Optional[Dict[str, str]] = None):
        """Add `value` to the count under default tags overlaid with `tags`.

        The merged tag keys must match the counter's tag keys exactly.
        """
        merged_tags = self.default_tags.copy()
        merged_tags.update(tags or {})
        assert set(merged_tags.keys()) == set(self.tags)

        # Walk/create one nested dict per tag, except the last one.
        d = self.counts
        for tag in self.tags[:-1]:
            tag_value = merged_tags[tag]
            if tag_value not in d:
                d[tag_value] = dict()
            d = d[tag_value]

        key = merged_tags[self.tags[-1]]
        d[key] = d.get(key, 0) + value

    def get_count(self, tags: Dict[str, str]) -> Optional[Union[int, float]]:
        """Return the count recorded for `tags`, or None if never incremented."""
        value = self.counts
        for tag in self.tags:
            tag_value = tags[tag]
            value = value.get(tag_value)
            if value is None:
                return

        return value

    def get_tags(self):
        return self.tags


@ray.remote
def get_node_id():
    """Ray task returning the node id from its runtime context."""
    return ray.get_runtime_context().get_node_id()


def check_num_alive_nodes(target: int):
    """Assert that exactly `target` nodes in `ray.nodes()` are alive."""
    alive_nodes = [node for node in ray.nodes() if node["Alive"]]
    assert len(alive_nodes) == target
    return True


def get_deployment_details(
    deployment_name: str,
    app_name: str = SERVE_DEFAULT_APP_NAME,
    _client: Optional[ServeControllerClient] = None,
):
    """Return the serve-details entry for one deployment of `app_name`.

    Uses `_client` if given, otherwise the global Serve client.
    """
    client = _client or _get_global_client()
    details = client.get_serve_details()
    return details["applications"][app_name]["deployments"][deployment_name]


@ray.remote
class Counter:
    """Actor whose `wait()` resolves once `inc()` was called `target` times."""

    def __init__(self, target: int):
        self.count = 0
        self.target = target
        self.ready_event = asyncio.Event()

    def inc(self):
        self.count += 1
        if self.count == self.target:
            self.ready_event.set()

    async def wait(self):
        await self.ready_event.wait()


def tlog(s: str, level: str = "INFO"):
    """Convenient logging method for testing."""
    # Trim microseconds to milliseconds.
    now = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3]
    print(f"[{level}] {now} {s}")


def check_target_groups_ready(
    client: ServeControllerClient,
    app_name: str,
    protocol: Union[str, RequestProtocol] = RequestProtocol.HTTP,
):
    """Wait for target groups to be ready for the given app and protocol.

    Target groups are ready when there are at least one target for the given protocol. And it's
    possible that target groups are not ready immediately. An example is when the controller
    is recovering from a crash.
    """
    target_groups = ray.get(client._controller.get_target_groups.remote(app_name))
    target_groups = [
        target_group
        for target_group in target_groups
        if target_group.protocol == protocol
    ]
    all_targets = [
        target for target_group in target_groups for target in target_group.targets
    ]
    return len(all_targets) > 0


def get_application_urls(
    protocol: Union[str, RequestProtocol] = RequestProtocol.HTTP,
    app_name: str = SERVE_DEFAULT_APP_NAME,
    use_localhost: bool = True,
    is_websocket: bool = False,
    exclude_route_prefix: bool = False,
    from_proxy_manager: bool = False,
) -> List[str]:
    """Get the URL of the application.

    Args:
        protocol: The protocol to use for the application.
        app_name: The name of the application.
        use_localhost: Whether to use localhost instead of the IP address.
            Set to True if Serve deployments are not exposed publicly or
            for low latency benchmarking.
        is_websocket: Whether the url should be served as a websocket.
        exclude_route_prefix: The route prefix to exclude from the application.
        from_proxy_manager: Whether the caller is a proxy manager.

    Returns:
        The URLs of the application.

    Raises:
        ValueError: If no target group matches the protocol, the protocol is
            unsupported, or `is_websocket=True` is combined with gRPC.
    """
    client = _get_global_client(_health_check_controller=True)
    serve_details = client.get_serve_details()
    assert (
        app_name in serve_details["applications"]
    ), f"App {app_name} not found in serve details. Use this method only when the app is known to be running."
    route_prefix = serve_details["applications"][app_name]["route_prefix"]
    # route_prefix is set to None when route_prefix value is specifically set to None
    # in the config used to deploy the app.
    if exclude_route_prefix or route_prefix is None:
        route_prefix = ""
    if isinstance(protocol, str):
        protocol = RequestProtocol(protocol)
    target_groups: List[TargetGroup] = ray.get(
        client._controller.get_target_groups.remote(app_name, from_proxy_manager)
    )
    target_groups = [
        target_group
        for target_group in target_groups
        if target_group.protocol == protocol
    ]
    if len(target_groups) == 0:
        raise ValueError(
            f"No target group found for app {app_name} with protocol {protocol} and route prefix {route_prefix}"
        )
    urls = []
    for target_group in target_groups:
        for target in target_group.targets:
            ip = "localhost" if use_localhost else target.ip
            if protocol == RequestProtocol.HTTP:
                scheme = "ws" if is_websocket else "http"
                url = f"{scheme}://{build_address(ip, target.port)}{route_prefix}"
            elif protocol == RequestProtocol.GRPC:
                if is_websocket:
                    raise ValueError(
                        "is_websocket=True is not supported with gRPC protocol."
                    )
                url = build_address(ip, target.port)
            else:
                raise ValueError(f"Unsupported protocol: {protocol}")
            url = url.rstrip("/")
            urls.append(url)
    return urls


def get_application_url(
    protocol: Union[str, RequestProtocol] = RequestProtocol.HTTP,
    app_name: str = SERVE_DEFAULT_APP_NAME,
    use_localhost: bool = True,
    is_websocket: bool = False,
    exclude_route_prefix: bool = False,
    from_proxy_manager: bool = False,
) -> str:
    """Get the URL of the application.

    Args:
        protocol: The protocol to use for the application.
        app_name: The name of the application.
        use_localhost: Whether to use localhost instead of the IP address.
            Set to True if Serve deployments are not exposed publicly or
            for low latency benchmarking.
        is_websocket: Whether the url should be served as a websocket.
        exclude_route_prefix: The route prefix to exclude from the application.
        from_proxy_manager: Whether the caller is a proxy manager.

    Returns:
        The URL of the application. If there are multiple URLs, a random one is returned.
    """
    return random.choice(
        get_application_urls(
            protocol,
            app_name,
            use_localhost,
            is_websocket,
            exclude_route_prefix,
            from_proxy_manager,
        )
    )


def check_running(app_name: str = SERVE_DEFAULT_APP_NAME):
    """Assert that application `app_name` has RUNNING status."""
    assert serve.status().applications[app_name].status == ApplicationStatus.RUNNING
    return True


def request_with_retries(timeout=30, app_name=SERVE_DEFAULT_APP_NAME):
    """GET the app's HTTP URL, retrying until it succeeds or `timeout` expires.

    `timeout` bounds both each individual request and the overall retry loop.

    Returns:
        The `httpx` response of the first attempt that did not raise.

    Raises:
        TimeoutError: If no attempt succeeded within `timeout`.
    """
    result_holder = {"resp": None}

    def _attempt() -> bool:
        try:
            url = get_application_url("HTTP", app_name=app_name)
            result_holder["resp"] = httpx.get(url, timeout=timeout)
            return True
        except (httpx.RequestError, IndexError):
            return False

    try:
        wait_for_condition(_attempt, timeout=timeout)
        return result_holder["resp"]
    except RuntimeError as e:
        # Preserve previous API by raising TimeoutError on expiry
        raise TimeoutError from e


# Metrics test utilities
TEST_METRICS_EXPORT_PORT = 9999


def get_metric_float(
    metric: str,
    expected_tags: Optional[Dict[str, str]],
    timeseries: Optional[PrometheusTimeseries] = None,
) -> float:
    """Gets the float value of metric.

    If tags is specified, searched for metric with matching tags. With
    `expected_tags=None` (or empty) the first available sample is returned.

    Returns -1 if the metric isn't available.
    """
    if timeseries is None:
        timeseries = PrometheusTimeseries()
    samples = fetch_prometheus_metric_timeseries(
        [f"localhost:{TEST_METRICS_EXPORT_PORT}"], timeseries
    ).get(metric, [])
    # None means "no tag constraint"; an empty dict is a subset of any labels.
    wanted_tags = expected_tags or {}
    for sample in samples:
        if wanted_tags.items() <= sample.labels.items():
            return sample.value
    return -1


def check_metric_float_eq(
    metric: str,
    expected: float,
    expected_tags: Optional[Dict[str, str]],
    timeseries: Optional[PrometheusTimeseries] = None,
) -> bool:
    """Check if a metric's float value equals the expected value."""
    metric_value = get_metric_float(metric, expected_tags, timeseries)
    assert float(metric_value) == expected
    return True


def get_metric_dictionaries(
    name: str, timeout: float = 20, timeseries: Optional[PrometheusTimeseries] = None
) -> List[Dict]:
    """Wait for metric `name` to be exported, then return its label dicts.

    Returns one labels dict per sample in `timeseries.metric_samples` whose
    name equals `name`.
    """
    if timeseries is None:
        timeseries = PrometheusTimeseries()

    def metric_available() -> bool:
        assert name in fetch_prometheus_metric_timeseries(
            [f"localhost:{TEST_METRICS_EXPORT_PORT}"], timeseries
        )
        return True

    wait_for_condition(metric_available, retry_interval_ms=1000, timeout=timeout)
    return [
        sample.labels
        for sample in timeseries.metric_samples.values()
        if sample.name == name
    ]