| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930 |
- import asyncio
- import datetime
- import os
- import random
- import threading
- import time
- from contextlib import asynccontextmanager
- from copy import copy, deepcopy
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
- import grpc
- import httpx
- import requests
- from starlette.requests import Request
- import ray
- from ray import serve
- from ray._common.network_utils import build_address
- from ray._common.test_utils import wait_for_condition
- from ray._common.utils import TimerBase
- from ray._private.test_utils import (
- PrometheusTimeseries,
- fetch_prometheus_metric_timeseries,
- )
- from ray.actor import ActorHandle
- from ray.serve._private.client import ServeControllerClient
- from ray.serve._private.common import (
- CreatePlacementGroupRequest,
- DeploymentID,
- DeploymentStatus,
- RequestProtocol,
- )
- from ray.serve._private.constants import (
- SERVE_DEFAULT_APP_NAME,
- SERVE_NAMESPACE,
- )
- from ray.serve._private.deployment_state import ALL_REPLICA_STATES, ReplicaState
- from ray.serve._private.proxy import DRAINING_MESSAGE
- from ray.serve._private.usage import ServeUsageTag
- from ray.serve.context import _get_global_client
- from ray.serve.generated import serve_pb2, serve_pb2_grpc
- from ray.serve.schema import ApplicationStatus, TargetGroup
- from ray.util.state import list_actors
# Route prefix used when deploying the telemetry receiver Serve app.
TELEMETRY_ROUTE_PREFIX: str = "/telemetry"
# Actor name under which `TelemetryStorage` is registered.
STORAGE_ACTOR_NAME: str = "storage"
class MockTimer(TimerBase):
    """Manually driven clock for tests; time only moves when told to."""

    def __init__(self, start_time: Optional[float] = None):
        self._lock = threading.Lock()
        self.reset(start_time=start_time)

    def reset(self, start_time: Optional[float] = None):
        """Set the current time, defaulting to the real wall-clock time."""
        self._curr = time.time() if start_time is None else start_time

    def time(self) -> float:
        """Return the current mocked time."""
        return self._curr

    def advance(self, by: float):
        """Move the clock forward by ``by`` seconds."""
        with self._lock:
            self._curr += by

    def realistic_sleep(self, amt: float):
        """Advance the clock by ``amt`` seconds plus a fixed 1 ms overshoot."""
        overshoot = 0.001
        with self._lock:
            self._curr += amt + overshoot
class MockAsyncTimer:
    """Manually driven clock for asyncio code.

    ``sleep`` cooperatively yields to the event loop until the mocked time has
    been moved forward (via ``advance``) by at least the requested amount.
    """

    def __init__(self, start_time: Optional[float] = 0):
        self.reset(start_time=start_time)
        # Number of coroutines currently blocked inside `sleep`.
        self._num_sleepers = 0

    def reset(self, start_time: Optional[float] = 0):
        """Set the mocked time to ``start_time`` (``None`` is treated as 0)."""
        self._curr = 0 if start_time is None else start_time

    def time(self) -> float:
        """Return the current mocked time."""
        return self._curr

    async def sleep(self, amt: float):
        """Wait until the mocked clock has advanced by ``amt``.

        The sleeper count is decremented even if the waiting task is
        cancelled, so ``num_sleepers`` never drifts upward.
        """
        self._num_sleepers += 1
        end = self._curr + amt
        try:
            # Give up the event loop until someone advances the clock far enough.
            while self._curr < end:
                await asyncio.sleep(0)
        finally:
            self._num_sleepers -= 1

    def advance(self, amt: float):
        """Move the mocked clock forward by ``amt``."""
        self._curr += amt

    def num_sleepers(self):
        """Return how many coroutines are currently inside ``sleep``."""
        return self._num_sleepers
class MockKVStore:
    """In-memory stand-in for a string-keyed key/value store."""

    def __init__(self):
        self.store = {}

    @staticmethod
    def _validate_key(key):
        # Only string keys are accepted, mirroring the store being mocked.
        if not isinstance(key, str):
            raise TypeError("key must be a string, got: {}.".format(type(key)))

    def put(self, key: str, val: Any) -> bool:
        """Store ``val`` under ``key``; always returns True."""
        self._validate_key(key)
        self.store[key] = val
        return True

    def get(self, key: str) -> Any:
        """Return the value under ``key``, or None if absent."""
        self._validate_key(key)
        return self.store.get(key)

    def delete(self, key: str) -> bool:
        """Remove ``key``; return whether it was present."""
        self._validate_key(key)
        try:
            del self.store[key]
        except KeyError:
            return False
        return True
class MockClusterNodeInfoCache:
    """Test double exposing node membership, resources and labels from plain containers."""

    def __init__(self):
        self.alive_node_ids = set()
        self.total_resources_per_node = {}
        self.available_resources_per_node = {}
        # Keys are the IDs of draining nodes.
        self.draining_nodes = {}
        self.node_labels = {}

    def get_alive_node_ids(self):
        return self.alive_node_ids

    def get_draining_nodes(self):
        return self.draining_nodes

    def get_active_node_ids(self):
        """Return alive node IDs that are not draining."""
        return self.alive_node_ids.difference(self.draining_nodes)

    def get_node_az(self, node_id):
        # Availability zones are not modeled by this mock.
        return None

    def get_available_resources_per_node(self):
        return self.available_resources_per_node

    def get_total_resources_per_node(self):
        return self.total_resources_per_node

    def add_node(self, node_id: str, resources: Dict = None, labels: Dict = None):
        """Register an alive node; total and available resources get independent copies."""
        self.alive_node_ids.add(node_id)
        for per_node in (
            self.total_resources_per_node,
            self.available_resources_per_node,
        ):
            per_node[node_id] = deepcopy(resources) or {}
        self.node_labels[node_id] = labels or {}

    def set_available_resources_per_node(self, node_id: str, resources: Dict):
        self.available_resources_per_node[node_id] = deepcopy(resources)

    def get_node_labels(self, node_id: str):
        """Return the labels for ``node_id`` (empty dict if unknown)."""
        return self.node_labels.get(node_id, {})
class FakeRemoteFunction:
    """Stand-in for a remote method whose ``.remote()`` call does nothing."""

    def remote(self):
        """No-op; returns None."""
        return None
class MockActorHandle:
    """Fake actor handle that records which remote methods were accessed."""

    def __init__(self, **kwargs):
        # Everything the handle was created with (e.g. `init_args` and options).
        self._options = kwargs
        self._actor_id = "fake_id"
        # Flags flipped when the corresponding property is accessed.
        self.initialize_and_get_metadata_called = False
        self.is_allocated_called = False

    @property
    def initialize_and_get_metadata(self):
        """Flag the access and return an object with a no-op ``remote()``."""
        self.initialize_and_get_metadata_called = True
        return FakeRemoteFunction()

    @property
    def is_allocated(self):
        """Flag the access and return an object with a no-op ``remote()``."""
        self.is_allocated_called = True
        return FakeRemoteFunction()
class MockActorClass:
    """Fake actor class supporting ``.options(...).remote(...)`` chains."""

    def __init__(self):
        self._init_args = ()
        self._options = dict()

    def options(self, **kwargs):
        """Return a copy of this class with ``kwargs`` merged into its options.

        The receiver is left unchanged, so separate ``options()`` calls on the
        same object do not affect one another.
        """
        res = copy(self)
        # `copy` is shallow: give the copy its own dict, otherwise the merge
        # would write through to `self._options` as well.
        res._options = {**self._options, **kwargs}
        return res

    def remote(self, *args) -> "MockActorHandle":
        """Create a handle recording ``args`` (as ``init_args``) and the options."""
        return MockActorHandle(init_args=args, **self._options)
class MockPlacementGroup:
    """Fake placement group that copies its fields from a creation request."""

    def __init__(self, request: CreatePlacementGroupRequest):
        self._name = request.name
        self._bundles = request.bundles
        self._strategy = request.strategy
        self._soft_target_node_id = request.target_node_id
        # Lifetime is always "detached", regardless of the request.
        self._lifetime = "detached"
class MockDeploymentHandle:
    """Minimal stand-in for a deployment handle used in unit tests."""

    def __init__(self, deployment_name: str, app_name: str = SERVE_DEFAULT_APP_NAME):
        self._deployment_name = deployment_name
        self._app_name = app_name
        self._protocol = RequestProtocol.UNDEFINED
        self._running_replicas_populated = False
        self._initialized = False

    def is_initialized(self):
        return self._initialized

    def _init(self):
        """Mark the handle initialized; a second call raises RuntimeError."""
        if self._initialized:
            raise RuntimeError("already initialized")
        self._initialized = True

    def options(self, *args, **kwargs):
        """Ignore all options and return this same handle."""
        return self

    def __eq__(self, dep: Tuple[str, str]):
        """Compare against a ``(deployment_name, app_name)`` pair.

        Objects that cannot be unpacked into two items yield NotImplemented,
        so e.g. ``handle == None`` evaluates to False instead of raising.
        (Defining ``__eq__`` leaves the class unhashable.)
        """
        try:
            other_deployment_name, other_app_name = dep
        except (TypeError, ValueError):
            return NotImplemented
        return (
            self._deployment_name == other_deployment_name
            and self._app_name == other_app_name
        )

    def _set_request_protocol(self, protocol: RequestProtocol):
        self._protocol = protocol

    def _get_or_create_router(self):
        pass

    def running_replicas_populated(self) -> bool:
        return self._running_replicas_populated

    def set_running_replicas_populated(self, val: bool):
        self._running_replicas_populated = val
@serve.deployment
class GetPID:
    """Serve deployment whose replicas respond with their own process ID."""

    def __call__(self):
        """Return ``os.getpid()`` of the process handling the request."""
        pid = os.getpid()
        return pid


# Bound application built from the `GetPID` deployment.
get_pid_entrypoint = GetPID.bind()
def check_ray_stopped(request_timeout_s: float = 5) -> bool:
    """Return True if ``http://localhost:8265/api/ray/version`` cannot be reached.

    Any exception raised by the request (connection refused, timeout, ...) is
    treated as "Ray is stopped".

    Args:
        request_timeout_s: Seconds to wait for the endpoint. Without a timeout
            a server that accepts the connection but never answers would block
            this check indefinitely.
    """
    try:
        requests.get(
            "http://localhost:8265/api/ray/version", timeout=request_timeout_s
        )
        return False
    except Exception:
        return True
def check_ray_started(request_timeout_s: float = 5) -> bool:
    """Return True if ``http://localhost:8265/api/ray/version`` answers with HTTP 200.

    Request errors propagate to the caller.

    Args:
        request_timeout_s: Seconds to wait for the endpoint, so an unresponsive
            server cannot block the check indefinitely.
    """
    response = requests.get(
        "http://localhost:8265/api/ray/version", timeout=request_timeout_s
    )
    return response.status_code == 200
def check_deployment_status(
    name: str, expected_status: DeploymentStatus, app_name=SERVE_DEFAULT_APP_NAME
) -> bool:
    """Assert that deployment ``name`` in ``app_name`` has ``expected_status``.

    Returns True on success (usable as a polling predicate).
    """
    deployments = serve.status().applications[app_name].deployments
    actual_status = deployments[name].status
    assert actual_status == expected_status
    return True
def get_num_alive_replicas(
    deployment_name: str, app_name: str = SERVE_DEFAULT_APP_NAME
) -> int:
    """Count the ALIVE replica actors of the given deployment via the state API."""
    replica_class_name = DeploymentID(
        name=deployment_name, app_name=app_name
    ).to_replica_actor_class_name()
    alive_replicas = list_actors(
        filters=[
            ("class_name", "=", replica_class_name),
            ("state", "=", "ALIVE"),
        ]
    )
    return len(alive_replicas)
def check_num_replicas_gte(
    name: str, target: int, app_name: str = SERVE_DEFAULT_APP_NAME
) -> bool:
    """Assert that at least ``target`` replicas of deployment ``name`` are alive.

    Args:
        name: Deployment name.
        target: Minimum number of alive replica actors.
        app_name: Application the deployment belongs to.

    Returns:
        True when the check passes (usable as a polling predicate).
    """
    num_alive = get_num_alive_replicas(name, app_name)
    assert (
        num_alive >= target
    ), f"Expected at least {target} alive replicas of {name}, got {num_alive}."
    return True
def check_num_replicas_eq(
    name: str,
    target: int,
    app_name: str = SERVE_DEFAULT_APP_NAME,
    use_controller: bool = False,
) -> bool:
    """Assert that exactly ``target`` replicas of deployment ``name`` exist.

    Args:
        name: Deployment name.
        target: Expected replica count.
        app_name: Application the deployment belongs to.
        use_controller: If True, compare against the RUNNING count reported by
            ``serve.status()``; otherwise count ALIVE replica actors.
    """
    if not use_controller:
        assert get_num_alive_replicas(name, app_name) == target
        return True

    deployment_status = serve.status().applications[app_name].deployments[name]
    running = deployment_status.replica_states.get(ReplicaState.RUNNING, 0)
    assert running == target
    return True
def check_num_replicas_lte(
    name: str, target: int, app_name: str = SERVE_DEFAULT_APP_NAME
) -> bool:
    """Assert that at most ``target`` replicas of deployment ``name`` are alive.

    Args:
        name: Deployment name.
        target: Maximum number of alive replica actors.
        app_name: Application the deployment belongs to.

    Returns:
        True when the check passes (usable as a polling predicate).
    """
    num_alive = get_num_alive_replicas(name, app_name)
    assert (
        num_alive <= target
    ), f"Expected at most {target} alive replicas of {name}, got {num_alive}."
    return True
def check_apps_running(apps: List):
    """Assert that every application named in ``apps`` reports RUNNING."""
    applications = serve.status().applications
    for name in apps:
        assert applications[name].status == ApplicationStatus.RUNNING
    return True
def check_replica_counts(
    controller: ActorHandle,
    deployment_id: DeploymentID,
    total: Optional[int] = None,
    by_state: Optional[List[Tuple[ReplicaState, int, Callable]]] = None,
):
    """Check replica counts using the controller's ``_dump_replica_states_for_testing``.

    Args:
        controller: A handle to the Serve controller.
        deployment_id: The deployment whose replicas are checked.
        total: Expected total number of replicas, if given.
        by_state: Optional list of ``(replica state, expected count, filter)``
            tuples. When ``filter`` is truthy, only replicas in that state for
            which ``filter(replica)`` is truthy are counted.

    Returns:
        True when every check passes.
    """
    replicas = ray.get(
        controller._dump_replica_states_for_testing.remote(deployment_id)
    )

    if total is not None:
        # Non-empty per-state breakdown, used as the assertion message.
        replica_counts = {}
        for state in ALL_REPLICA_STATES:
            in_state = replicas.get([state])
            if in_state:
                replica_counts[state] = len(in_state)
        assert replicas.count() == total, replica_counts

    for state, count, check in by_state or ():
        assert isinstance(state, ReplicaState)
        assert isinstance(count, int) and count >= 0
        if not check:
            curr_count = replicas.count(states=[state])
        else:
            curr_count = len(
                {replica for replica in replicas.get(states=[state]) if check(replica)}
            )
        msg = f"Expected {count} for state {state} but got {curr_count}."
        assert curr_count == count, msg

    return True
@ray.remote(name=STORAGE_ACTOR_NAME, namespace=SERVE_NAMESPACE, num_cpus=0)
class TelemetryStorage:
    """Named actor holding the latest telemetry report and a count of reports."""

    def __init__(self):
        self.current_report = dict()
        self.reports_received = 0

    def store_report(self, report: Dict) -> None:
        """Replace the latest report with ``report`` and bump the counter."""
        self.current_report = report
        self.reports_received += 1

    def get_report(self) -> Dict:
        """Return the most recently stored report (empty dict if none yet)."""
        return self.current_report

    def get_reports_received(self) -> int:
        """Return how many reports have been stored."""
        return self.reports_received
@serve.deployment(ray_actor_options={"num_cpus": 0})
class TelemetryReceiver:
    """Serve deployment that forwards each request's JSON body to the storage actor."""

    def __init__(self):
        self.storage = ray.get_actor(name=STORAGE_ACTOR_NAME, namespace=SERVE_NAMESPACE)

    async def __call__(self, request: Request) -> bool:
        report = await request.json()
        # Await the ObjectRef rather than calling the blocking `ray.get`, which
        # would stall this replica's event loop until the storage actor replies.
        await self.storage.store_report.remote(report)
        return True


# Bound application for the telemetry receiver.
receiver_app = TelemetryReceiver.bind()
def start_telemetry_app():
    """Start the telemetry receiver Serve app together with its storage actor.

    Ray must already be initialized before calling this.

    NOTE: The receiver app is itself counted in telemetry. For example, if you
    deploy one Serve app in addition to the receiver, the number of apps in
    the cluster is 2, not 1, because the receiver is also running.

    Returns:
        A handle to the ``TelemetryStorage`` actor, which exposes the latest
        telemetry report.
    """
    storage_handle = TelemetryStorage.remote()
    serve.run(receiver_app, name="telemetry", route_prefix=TELEMETRY_ROUTE_PREFIX)
    return storage_handle
def check_telemetry(
    tag: ServeUsageTag, expected: Any, storage_actor_name: str = STORAGE_ACTOR_NAME
):
    """Assert that the latest stored telemetry report has ``expected`` for ``tag``."""
    storage = ray.get_actor(storage_actor_name, namespace=SERVE_NAMESPACE)
    latest_report = ray.get(storage.get_report.remote())
    # Printed so the tags are visible in the test output when the check fails.
    print(latest_report["extra_usage_tags"])
    actual = tag.get_value_from_report(latest_report)
    assert actual == expected
    return True
def ping_grpc_list_applications(channel, app_names, test_draining=False):
    """Call ``ListApplications`` over ``channel`` and validate the outcome.

    Args:
        channel: A gRPC channel.
        app_names: Application names expected in a successful response.
        test_draining: If True, expect an UNAVAILABLE error carrying the
            draining message instead of a response.

    Returns:
        True when all checks pass.
    """
    import pytest

    stub = serve_pb2_grpc.RayServeAPIServiceStub(channel)
    request = serve_pb2.ListApplicationsRequest()

    if not test_draining:
        response, call = stub.ListApplications.with_call(request=request)
        assert call.code() == grpc.StatusCode.OK
        assert response.application_names == app_names
        return True

    with pytest.raises(grpc.RpcError) as exc_info:
        _, _ = stub.ListApplications.with_call(request=request)
    error = exc_info.value
    assert error.code() == grpc.StatusCode.UNAVAILABLE
    assert error.details() == DRAINING_MESSAGE
    return True
def ping_grpc_healthz(channel, test_draining=False):
    """Call ``Healthz`` over ``channel`` and validate the outcome.

    Args:
        channel: A gRPC channel.
        test_draining: If True, expect an UNAVAILABLE error carrying the
            draining message instead of a "success" response.

    Returns:
        True when all checks pass, like ``ping_grpc_list_applications``, so
        the function can serve as a truthy polling predicate.
    """
    import pytest

    stub = serve_pb2_grpc.RayServeAPIServiceStub(channel)
    request = serve_pb2.HealthzRequest()
    if test_draining:
        with pytest.raises(grpc.RpcError) as exception_info:
            _, _ = stub.Healthz.with_call(request=request)
        rpc_error = exception_info.value
        assert rpc_error.code() == grpc.StatusCode.UNAVAILABLE
        assert rpc_error.details() == DRAINING_MESSAGE
    else:
        response, call = stub.Healthz.with_call(request=request)
        assert call.code() == grpc.StatusCode.OK
        assert response.message == "success"
    return True
def ping_grpc_call_method(channel, app_name, test_not_found=False):
    """Invoke ``UserDefinedService.__call__`` for ``app_name`` and validate it.

    Args:
        channel: A gRPC channel.
        app_name: Value sent in the ``application`` metadata entry.
        test_not_found: If True, expect a NOT_FOUND error naming the app
            instead of the greeting.
    """
    import pytest

    stub = serve_pb2_grpc.UserDefinedServiceStub(channel)
    request = serve_pb2.UserDefinedMessage(name="foo", num=30, foo="bar")
    metadata = (("application", app_name),)

    if not test_not_found:
        response, call = stub.__call__.with_call(request=request, metadata=metadata)
        assert call.code() == grpc.StatusCode.OK, call.code()
        assert response.greeting == "Hello foo from bar", response.greeting
        return

    with pytest.raises(grpc.RpcError) as exc_info:
        _, _ = stub.__call__.with_call(request=request, metadata=metadata)
    error = exc_info.value
    assert error.code() == grpc.StatusCode.NOT_FOUND, error.code()
    assert f"Application '{app_name}' not found." in error.details()
def ping_grpc_another_method(channel, app_name):
    """Call ``UserDefinedService.Method1`` for ``app_name`` and check the greeting."""
    stub = serve_pb2_grpc.UserDefinedServiceStub(channel)
    response = stub.Method1(
        request=serve_pb2.UserDefinedMessage(name="foo", num=30, foo="bar"),
        metadata=(("application", app_name),),
    )
    assert response.greeting == "Hello foo from method1"
def ping_grpc_model_multiplexing(channel, app_name):
    """Call ``Method2`` with a multiplexed model ID and check the greeting."""
    model_id = "999"
    stub = serve_pb2_grpc.UserDefinedServiceStub(channel)
    request = serve_pb2.UserDefinedMessage(name="foo", num=30, foo="bar")
    metadata = (
        ("application", app_name),
        ("multiplexed_model_id", model_id),
    )
    response = stub.Method2(request=request, metadata=metadata)
    expected_greeting = f"Method2 called model, loading model: {model_id}"
    assert response.greeting == expected_greeting
def ping_grpc_streaming(channel, app_name):
    """Call the streaming ``Streaming`` RPC and check each greeting is numbered in order."""
    stub = serve_pb2_grpc.UserDefinedServiceStub(channel)
    request = serve_pb2.UserDefinedMessage(name="foo", num=30, foo="bar")
    metadata = (("application", app_name),)
    stream = stub.Streaming(request=request, metadata=metadata)
    position = 0
    for message in stream:
        assert message.greeting == f"{position}: Hello foo from bar"
        position += 1
def ping_fruit_stand(channel, app_name):
    """Call ``FruitService.FruitStand`` with 4 oranges and 8 apples; expect costs of 32."""
    stub = serve_pb2_grpc.FruitServiceStub(channel)
    response = stub.FruitStand(
        request=serve_pb2.FruitAmounts(orange=4, apple=8),
        metadata=(("application", app_name),),
    )
    assert response.costs == 32
@asynccontextmanager
async def send_signal_on_cancellation(signal_actor: ActorHandle):
    """Async context manager that signals ``signal_actor`` when cancelled.

    After the wrapped block finishes, this keeps waiting (``asyncio.sleep(100)``)
    for a cancellation. On ``asyncio.CancelledError`` it awaits
    ``signal_actor.send.remote()`` and does not re-raise. If no cancellation
    arrives, ``RuntimeError`` is raised.
    """
    try:
        yield
        await asyncio.sleep(100)
    except asyncio.CancelledError:
        # Clear the context var to avoid Ray recursively cancelling this method call.
        ray._raylet.async_task_id.set(None)
        await signal_actor.send.remote()
    else:
        raise RuntimeError(
            "CancelledError wasn't raised during `send_signal_on_cancellation` block"
        )
class FakeGrpcContext:
    """In-memory stand-in for a gRPC servicer context.

    Code, details and trailing metadata can be set and read back; the auth and
    peer accessors return fixed placeholder values.
    """

    def __init__(self):
        self._auth_context = {"key": "value"}
        self._peer = "peer"
        self._peer_identities = b"peer_identities"
        self._peer_identity_key = "peer_identity_key"
        self._code = None
        self._details = None
        self._trailing_metadata = []
        # Starts empty (a previous duplicate assignment of a placeholder value
        # was always overwritten by this one).
        self._invocation_metadata = []

    def auth_context(self):
        return self._auth_context

    def code(self):
        return self._code

    def details(self):
        return self._details

    def peer(self):
        return self._peer

    def peer_identities(self):
        return self._peer_identities

    def peer_identity_key(self):
        return self._peer_identity_key

    def trailing_metadata(self):
        return self._trailing_metadata

    def set_code(self, code):
        self._code = code

    def set_details(self, details):
        self._details = details

    def set_trailing_metadata(self, trailing_metadata):
        self._trailing_metadata = trailing_metadata

    def invocation_metadata(self):
        return self._invocation_metadata
class FakeGauge:
    """In-memory stand-in for a gauge metric.

    The latest value per tag combination is kept in ``values`` as nested
    dicts, one level per entry of ``tag_keys`` (in order). A gauge without tag
    keys keeps its single value under the key ``()``.
    """

    def __init__(self, name: str = None, tag_keys: Tuple[str] = None):
        self.name = name
        self.values = dict()
        self.tags = tag_keys or ()
        self.default_tags = dict()

    def set_default_tags(self, tags: Dict[str, str]):
        """Merge ``tags`` into the defaults applied by every ``set`` call.

        Each key must be one of the gauge's tag keys.
        """
        for key, tag in tags.items():
            assert key in self.tags
            self.default_tags[key] = tag

    def set(self, value: Union[int, float], tags: Dict[str, str] = None):
        """Record ``value`` for the default tags overridden by ``tags``.

        The merged tags must cover exactly the gauge's tag keys.
        """
        merged_tags = self.default_tags.copy()
        merged_tags.update(tags or {})
        assert set(merged_tags.keys()) == set(self.tags)

        if not self.tags:
            # Untagged gauge: there is no last tag key to index by below.
            self.values[()] = value
            return

        d = self.values
        for tag in self.tags[:-1]:
            d = d.setdefault(merged_tags[tag], {})
        d[merged_tags[self.tags[-1]]] = value

    def get_value(self, tags: Optional[Dict[str, str]] = None):
        """Return the value recorded for ``tags``, or None if never set.

        ``tags`` must contain every tag key; it may be omitted for an untagged
        gauge.
        """
        if not self.tags:
            return self.values.get(())
        value = self.values
        for tag in self.tags:
            value = value.get(tags[tag])
            if value is None:
                return None
        return value
class FakeCounter:
    """In-memory stand-in for a counter metric.

    Running totals per tag combination are kept in ``counts`` as nested dicts,
    one level per entry of ``tag_keys`` (in order). A counter without tag keys
    keeps its single total under the key ``()``.
    """

    def __init__(self, name: str = None, tag_keys: Tuple[str] = None):
        self.name = name
        self.counts = dict()
        self.tags = tag_keys or ()
        self.default_tags = dict()

    def set_default_tags(self, tags: Dict[str, str]):
        """Merge ``tags`` into the defaults applied by every ``inc`` call.

        Each key must be one of the counter's tag keys.
        """
        for key, tag in tags.items():
            assert key in self.tags
            self.default_tags[key] = tag

    def inc(self, value: Union[int, float] = 1.0, tags: Dict[str, str] = None):
        """Add ``value`` to the total for the default tags overridden by ``tags``.

        The merged tags must cover exactly the counter's tag keys.
        """
        merged_tags = self.default_tags.copy()
        merged_tags.update(tags or {})
        assert set(merged_tags.keys()) == set(self.tags)

        if not self.tags:
            # Untagged counter: there is no last tag key to index by below.
            self.counts[()] = self.counts.get((), 0) + value
            return

        d = self.counts
        for tag in self.tags[:-1]:
            d = d.setdefault(merged_tags[tag], {})
        key = merged_tags[self.tags[-1]]
        d[key] = d.get(key, 0) + value

    def get_count(
        self, tags: Optional[Dict[str, str]] = None
    ) -> Optional[Union[int, float]]:
        """Return the total for ``tags``, or None if never incremented.

        ``tags`` must contain every tag key; it may be omitted for an untagged
        counter.
        """
        if not self.tags:
            return self.counts.get(())
        value = self.counts
        for tag in self.tags:
            value = value.get(tags[tag])
            if value is None:
                return None
        return value

    def get_tags(self):
        return self.tags
@ray.remote
def get_node_id():
    """Remote task returning the ID of the node it executes on."""
    runtime_context = ray.get_runtime_context()
    return runtime_context.get_node_id()
def check_num_alive_nodes(target: int):
    """Assert that exactly ``target`` entries of ``ray.nodes()`` are alive."""
    num_alive = sum(1 for node in ray.nodes() if node["Alive"])
    assert num_alive == target
    return True
def get_deployment_details(
    deployment_name: str,
    app_name: str = SERVE_DEFAULT_APP_NAME,
    _client: Optional[ServeControllerClient] = None,
):
    """Return the Serve-details entry for one deployment.

    Uses ``_client`` when given (and truthy), otherwise the global Serve client.
    """
    client = _client if _client else _get_global_client()
    applications = client.get_serve_details()["applications"]
    return applications[app_name]["deployments"][deployment_name]
@ray.remote
class Counter:
    """Actor whose ``wait`` returns once ``inc`` has been called ``target`` times."""

    def __init__(self, target: int):
        self.count = 0
        self.target = target
        self.ready_event = asyncio.Event()
        # A non-positive target is already met; without this, `wait` would
        # block forever since `inc` only ever moves the count upward.
        if self.count >= self.target:
            self.ready_event.set()

    def inc(self):
        """Increment the count and release waiters once the target is reached."""
        self.count += 1
        if self.count >= self.target:
            self.ready_event.set()

    async def wait(self):
        """Block until the target count has been reached."""
        await self.ready_event.wait()
def tlog(s: str, level: str = "INFO"):
    """Print ``s`` prefixed by ``[level]`` and an HH:MM:SS.mmm local timestamp."""
    # `%f` yields microseconds; trimming 3 digits leaves milliseconds.
    timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3]
    print("[{}] {} {}".format(level, timestamp, s))
def check_target_groups_ready(
    client: ServeControllerClient,
    app_name: str,
    protocol: Union[str, RequestProtocol] = RequestProtocol.HTTP,
):
    """Return whether ``app_name`` has at least one target for ``protocol``.

    Target groups may not be populated right away (for example while the
    controller is recovering from a crash), so callers typically poll this.
    """
    target_groups = ray.get(client._controller.get_target_groups.remote(app_name))
    for group in target_groups:
        if group.protocol == protocol and len(group.targets) > 0:
            return True
    return False
def get_application_urls(
    protocol: Union[str, RequestProtocol] = RequestProtocol.HTTP,
    app_name: str = SERVE_DEFAULT_APP_NAME,
    use_localhost: bool = True,
    is_websocket: bool = False,
    exclude_route_prefix: bool = False,
    from_proxy_manager: bool = False,
) -> List[str]:
    """Return every URL at which the application is reachable.

    Args:
        protocol: The protocol to use for the application (HTTP or gRPC).
        app_name: The name of the application.
        use_localhost: Whether to use localhost instead of the target IP.
            Set to True if Serve deployments are not exposed publicly or
            for low latency benchmarking.
        is_websocket: Whether HTTP URLs should use the ``ws`` scheme.
        exclude_route_prefix: Whether to leave the route prefix out of the URL.
        from_proxy_manager: Whether the caller is a proxy manager; forwarded to
            the controller's ``get_target_groups``.

    Returns:
        One URL per target of every matching target group, with any trailing
        "/" removed.

    Raises:
        AssertionError: If ``app_name`` is not present in the Serve details.
        ValueError: If no target group matches ``protocol``, if
            ``is_websocket`` is combined with gRPC, or if the protocol is
            neither HTTP nor gRPC.
    """

    def _target_url(host: str, port, proto, prefix: str) -> str:
        # Render a single target as a URL for `proto`, without a trailing "/".
        if proto == RequestProtocol.HTTP:
            scheme = "ws" if is_websocket else "http"
            url = f"{scheme}://{build_address(host, port)}{prefix}"
        elif proto == RequestProtocol.GRPC:
            if is_websocket:
                raise ValueError(
                    "is_websocket=True is not supported with gRPC protocol."
                )
            url = build_address(host, port)
        else:
            raise ValueError(f"Unsupported protocol: {proto}")
        return url.rstrip("/")

    client = _get_global_client(_health_check_controller=True)
    applications = client.get_serve_details()["applications"]
    assert (
        app_name in applications
    ), f"App {app_name} not found in serve details. Use this method only when the app is known to be running."

    route_prefix = applications[app_name]["route_prefix"]
    # A None route_prefix means the app's config explicitly set it to None.
    if route_prefix is None or exclude_route_prefix:
        route_prefix = ""

    if isinstance(protocol, str):
        protocol = RequestProtocol(protocol)

    all_groups: List[TargetGroup] = ray.get(
        client._controller.get_target_groups.remote(app_name, from_proxy_manager)
    )
    matching_groups = [group for group in all_groups if group.protocol == protocol]
    if not matching_groups:
        raise ValueError(
            f"No target group found for app {app_name} with protocol {protocol} and route prefix {route_prefix}"
        )

    urls = []
    for group in matching_groups:
        for target in group.targets:
            host = "localhost" if use_localhost else target.ip
            urls.append(_target_url(host, target.port, protocol, route_prefix))
    return urls
def get_application_url(
    protocol: Union[str, RequestProtocol] = RequestProtocol.HTTP,
    app_name: str = SERVE_DEFAULT_APP_NAME,
    use_localhost: bool = True,
    is_websocket: bool = False,
    exclude_route_prefix: bool = False,
    from_proxy_manager: bool = False,
) -> str:
    """Return one URL of the application, chosen at random when there are several.

    Args:
        protocol: The protocol to use for the application.
        app_name: The name of the application.
        use_localhost: Whether to use localhost instead of the IP address.
            Set to True if Serve deployments are not exposed publicly or
            for low latency benchmarking.
        is_websocket: Whether the url should be served as a websocket.
        exclude_route_prefix: Whether to leave the route prefix out of the URL.
        from_proxy_manager: Whether the caller is a proxy manager.

    Returns:
        A single URL from ``get_application_urls``. ``random.choice`` raises
        IndexError if that list is empty.
    """
    candidate_urls = get_application_urls(
        protocol=protocol,
        app_name=app_name,
        use_localhost=use_localhost,
        is_websocket=is_websocket,
        exclude_route_prefix=exclude_route_prefix,
        from_proxy_manager=from_proxy_manager,
    )
    return random.choice(candidate_urls)
def check_running(app_name: str = SERVE_DEFAULT_APP_NAME):
    """Assert that ``app_name`` reports RUNNING; return True on success."""
    app_status = serve.status().applications[app_name].status
    assert app_status == ApplicationStatus.RUNNING
    return True
def request_with_retries(timeout=30, app_name=SERVE_DEFAULT_APP_NAME):
    """GET the application's HTTP URL, retrying until a request succeeds.

    ``httpx.RequestError`` and ``IndexError`` (no URL available yet) count as
    a failed attempt. ``timeout`` bounds both the overall retry window and
    each individual request.

    Returns:
        The first successful ``httpx`` response.

    Raises:
        TimeoutError: If no attempt succeeded before the retry window expired.
    """
    response = None

    def _try_request() -> bool:
        nonlocal response
        try:
            url = get_application_url("HTTP", app_name=app_name)
            response = httpx.get(url, timeout=timeout)
        except (httpx.RequestError, IndexError):
            return False
        return True

    try:
        wait_for_condition(_try_request, timeout=timeout)
    except RuntimeError as e:
        # Preserve previous API by raising TimeoutError on expiry
        raise TimeoutError from e
    return response
# Metrics test utilities

# localhost port scraped for Prometheus samples by the metric helpers.
TEST_METRICS_EXPORT_PORT: int = 9999
def get_metric_float(
    metric: str,
    expected_tags: Optional[Dict[str, str]],
    timeseries: Optional[PrometheusTimeseries] = None,
) -> float:
    """Return the value of the first ``metric`` sample whose labels include ``expected_tags``.

    Args:
        metric: Prometheus metric name.
        expected_tags: Labels a sample must carry to match. ``None`` (or an
            empty dict) matches any sample.
        timeseries: ``PrometheusTimeseries`` to fetch into; a new one is
            created when omitted.

    Returns:
        The matching sample's value, or -1 if no matching sample is available.
    """
    if timeseries is None:
        timeseries = PrometheusTimeseries()
    # Treat None like {} instead of crashing on `.items()`.
    required_labels = (expected_tags or {}).items()
    samples = fetch_prometheus_metric_timeseries(
        [f"localhost:{TEST_METRICS_EXPORT_PORT}"], timeseries
    ).get(metric, [])
    for sample in samples:
        # Subset check: every required label must appear with the same value.
        if required_labels <= sample.labels.items():
            return sample.value
    return -1
def check_metric_float_eq(
    metric: str,
    expected: float,
    expected_tags: Optional[Dict[str, str]],
    timeseries: Optional[PrometheusTimeseries] = None,
) -> bool:
    """Assert that ``metric`` (filtered by ``expected_tags``) equals ``expected``.

    Returns True on success (usable as a polling predicate).
    """
    actual = get_metric_float(metric, expected_tags, timeseries=timeseries)
    assert float(actual) == expected
    return True
def get_metric_dictionaries(
    name: str, timeout: float = 20, timeseries: Optional[PrometheusTimeseries] = None
) -> List[Dict]:
    """Wait until metric ``name`` is exported, then return each sample's label dict.

    Args:
        name: Prometheus metric name.
        timeout: Seconds to wait for the metric to appear.
        timeseries: ``PrometheusTimeseries`` to fetch into; a new one is
            created when omitted.
    """
    if timeseries is None:
        timeseries = PrometheusTimeseries()

    def _metric_exported() -> bool:
        scraped = fetch_prometheus_metric_timeseries(
            [f"localhost:{TEST_METRICS_EXPORT_PORT}"], timeseries
        )
        assert name in scraped
        return True

    wait_for_condition(_metric_exported, retry_interval_ms=1000, timeout=timeout)
    return [
        sample.labels
        for sample in timeseries.metric_samples.values()
        if sample.name == name
    ]
|