test_utils.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930
  1. import asyncio
  2. import datetime
  3. import os
  4. import random
  5. import threading
  6. import time
  7. from contextlib import asynccontextmanager
  8. from copy import copy, deepcopy
  9. from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  10. import grpc
  11. import httpx
  12. import requests
  13. from starlette.requests import Request
  14. import ray
  15. from ray import serve
  16. from ray._common.network_utils import build_address
  17. from ray._common.test_utils import wait_for_condition
  18. from ray._common.utils import TimerBase
  19. from ray._private.test_utils import (
  20. PrometheusTimeseries,
  21. fetch_prometheus_metric_timeseries,
  22. )
  23. from ray.actor import ActorHandle
  24. from ray.serve._private.client import ServeControllerClient
  25. from ray.serve._private.common import (
  26. CreatePlacementGroupRequest,
  27. DeploymentID,
  28. DeploymentStatus,
  29. RequestProtocol,
  30. )
  31. from ray.serve._private.constants import (
  32. SERVE_DEFAULT_APP_NAME,
  33. SERVE_NAMESPACE,
  34. )
  35. from ray.serve._private.deployment_state import ALL_REPLICA_STATES, ReplicaState
  36. from ray.serve._private.proxy import DRAINING_MESSAGE
  37. from ray.serve._private.usage import ServeUsageTag
  38. from ray.serve.context import _get_global_client
  39. from ray.serve.generated import serve_pb2, serve_pb2_grpc
  40. from ray.serve.schema import ApplicationStatus, TargetGroup
  41. from ray.util.state import list_actors
# Route prefix passed to `serve.run` when the telemetry receiver app is started.
TELEMETRY_ROUTE_PREFIX = "/telemetry"
# Name given to the `TelemetryStorage` actor and used to look it up again.
STORAGE_ACTOR_NAME = "storage"
  44. class MockTimer(TimerBase):
  45. def __init__(self, start_time: Optional[float] = None):
  46. self._lock = threading.Lock()
  47. self.reset(start_time=start_time)
  48. def reset(self, start_time: Optional[float] = None):
  49. if start_time is None:
  50. start_time = time.time()
  51. self._curr = start_time
  52. def time(self) -> float:
  53. return self._curr
  54. def advance(self, by: float):
  55. with self._lock:
  56. self._curr += by
  57. def realistic_sleep(self, amt: float):
  58. with self._lock:
  59. self._curr += amt + 0.001
  60. class MockAsyncTimer:
  61. def __init__(self, start_time: Optional[float] = 0):
  62. self.reset(start_time=start_time)
  63. self._num_sleepers = 0
  64. def reset(self, start_time: 0):
  65. self._curr = start_time
  66. def time(self) -> float:
  67. return self._curr
  68. async def sleep(self, amt: float):
  69. self._num_sleepers += 1
  70. end = self._curr + amt
  71. # Give up the event loop
  72. while self._curr < end:
  73. await asyncio.sleep(0)
  74. self._num_sleepers -= 1
  75. def advance(self, amt: float):
  76. self._curr += amt
  77. def num_sleepers(self):
  78. return self._num_sleepers
  79. class MockKVStore:
  80. def __init__(self):
  81. self.store = dict()
  82. def put(self, key: str, val: Any) -> bool:
  83. if not isinstance(key, str):
  84. raise TypeError("key must be a string, got: {}.".format(type(key)))
  85. self.store[key] = val
  86. return True
  87. def get(self, key: str) -> Any:
  88. if not isinstance(key, str):
  89. raise TypeError("key must be a string, got: {}.".format(type(key)))
  90. return self.store.get(key, None)
  91. def delete(self, key: str) -> bool:
  92. if not isinstance(key, str):
  93. raise TypeError("key must be a string, got: {}.".format(type(key)))
  94. if key in self.store:
  95. del self.store[key]
  96. return True
  97. return False
  98. class MockClusterNodeInfoCache:
  99. def __init__(self):
  100. self.alive_node_ids = set()
  101. self.total_resources_per_node = dict()
  102. self.available_resources_per_node = dict()
  103. self.draining_nodes = dict()
  104. self.node_labels = dict()
  105. def get_alive_node_ids(self):
  106. return self.alive_node_ids
  107. def get_draining_nodes(self):
  108. return self.draining_nodes
  109. def get_active_node_ids(self):
  110. return self.alive_node_ids - set(self.draining_nodes)
  111. def get_node_az(self, node_id):
  112. return None
  113. def get_available_resources_per_node(self):
  114. return self.available_resources_per_node
  115. def get_total_resources_per_node(self):
  116. return self.total_resources_per_node
  117. def add_node(self, node_id: str, resources: Dict = None, labels: Dict = None):
  118. self.alive_node_ids.add(node_id)
  119. self.total_resources_per_node[node_id] = deepcopy(resources) or {}
  120. self.available_resources_per_node[node_id] = deepcopy(resources) or {}
  121. self.node_labels[node_id] = labels or {}
  122. def set_available_resources_per_node(self, node_id: str, resources: Dict):
  123. self.available_resources_per_node[node_id] = deepcopy(resources)
  124. def get_node_labels(self, node_id: str):
  125. return self.node_labels.get(node_id, {})
  126. class FakeRemoteFunction:
  127. def remote(self):
  128. pass
  129. class MockActorHandle:
  130. def __init__(self, **kwargs):
  131. self._options = kwargs
  132. self._actor_id = "fake_id"
  133. self.initialize_and_get_metadata_called = False
  134. self.is_allocated_called = False
  135. @property
  136. def initialize_and_get_metadata(self):
  137. self.initialize_and_get_metadata_called = True
  138. # return a mock object so that we can call `remote()` on it.
  139. return FakeRemoteFunction()
  140. @property
  141. def is_allocated(self):
  142. self.is_allocated_called = True
  143. return FakeRemoteFunction()
  144. class MockActorClass:
  145. def __init__(self):
  146. self._init_args = ()
  147. self._options = dict()
  148. def options(self, **kwargs):
  149. res = copy(self)
  150. for k, v in kwargs.items():
  151. res._options[k] = v
  152. return res
  153. def remote(self, *args) -> MockActorHandle:
  154. return MockActorHandle(init_args=args, **self._options)
  155. class MockPlacementGroup:
  156. def __init__(self, request: CreatePlacementGroupRequest):
  157. self._bundles = request.bundles
  158. self._strategy = request.strategy
  159. self._soft_target_node_id = request.target_node_id
  160. self._name = request.name
  161. self._lifetime = "detached"
  162. class MockDeploymentHandle:
  163. def __init__(self, deployment_name: str, app_name: str = SERVE_DEFAULT_APP_NAME):
  164. self._deployment_name = deployment_name
  165. self._app_name = app_name
  166. self._protocol = RequestProtocol.UNDEFINED
  167. self._running_replicas_populated = False
  168. self._initialized = False
  169. def is_initialized(self):
  170. return self._initialized
  171. def _init(self):
  172. if self._initialized:
  173. raise RuntimeError("already initialized")
  174. self._initialized = True
  175. def options(self, *args, **kwargs):
  176. return self
  177. def __eq__(self, dep: Tuple[str]):
  178. other_deployment_name, other_app_name = dep
  179. return (
  180. self._deployment_name == other_deployment_name
  181. and self._app_name == other_app_name
  182. )
  183. def _set_request_protocol(self, protocol: RequestProtocol):
  184. self._protocol = protocol
  185. def _get_or_create_router(self):
  186. pass
  187. def running_replicas_populated(self) -> bool:
  188. return self._running_replicas_populated
  189. def set_running_replicas_populated(self, val: bool):
  190. self._running_replicas_populated = val
  191. @serve.deployment
  192. class GetPID:
  193. def __call__(self):
  194. return os.getpid()
  195. get_pid_entrypoint = GetPID.bind()
  196. def check_ray_stopped():
  197. try:
  198. requests.get("http://localhost:8265/api/ray/version")
  199. return False
  200. except Exception:
  201. return True
  202. def check_ray_started():
  203. return requests.get("http://localhost:8265/api/ray/version").status_code == 200
  204. def check_deployment_status(
  205. name: str, expected_status: DeploymentStatus, app_name=SERVE_DEFAULT_APP_NAME
  206. ) -> bool:
  207. app_status = serve.status().applications[app_name]
  208. assert app_status.deployments[name].status == expected_status
  209. return True
  210. def get_num_alive_replicas(
  211. deployment_name: str, app_name: str = SERVE_DEFAULT_APP_NAME
  212. ) -> int:
  213. """Get the replicas currently running for the given deployment."""
  214. dep_id = DeploymentID(name=deployment_name, app_name=app_name)
  215. actors = list_actors(
  216. filters=[
  217. ("class_name", "=", dep_id.to_replica_actor_class_name()),
  218. ("state", "=", "ALIVE"),
  219. ]
  220. )
  221. return len(actors)
  222. def check_num_replicas_gte(
  223. name: str, target: int, app_name: str = SERVE_DEFAULT_APP_NAME
  224. ) -> int:
  225. """Check if num replicas is >= target."""
  226. assert get_num_alive_replicas(name, app_name) >= target
  227. return True
  228. def check_num_replicas_eq(
  229. name: str,
  230. target: int,
  231. app_name: str = SERVE_DEFAULT_APP_NAME,
  232. use_controller: bool = False,
  233. ) -> bool:
  234. """Check if num replicas is == target."""
  235. if use_controller:
  236. dep = serve.status().applications[app_name].deployments[name]
  237. num_running_replicas = dep.replica_states.get(ReplicaState.RUNNING, 0)
  238. assert num_running_replicas == target
  239. else:
  240. assert get_num_alive_replicas(name, app_name) == target
  241. return True
  242. def check_num_replicas_lte(
  243. name: str, target: int, app_name: str = SERVE_DEFAULT_APP_NAME
  244. ) -> int:
  245. """Check if num replicas is <= target."""
  246. assert get_num_alive_replicas(name, app_name) <= target
  247. return True
  248. def check_apps_running(apps: List):
  249. status = serve.status()
  250. for app_name in apps:
  251. assert status.applications[app_name].status == ApplicationStatus.RUNNING
  252. return True
  253. def check_replica_counts(
  254. controller: ActorHandle,
  255. deployment_id: DeploymentID,
  256. total: Optional[int] = None,
  257. by_state: Optional[List[Tuple[ReplicaState, int, Callable]]] = None,
  258. ):
  259. """Uses _dump_replica_states_for_testing to check replica counts.
  260. Args:
  261. controller: A handle to the Serve controller.
  262. deployment_id: The deployment to check replica counts for.
  263. total: The total number of expected replicas for the deployment.
  264. by_state: A list of tuples of the form
  265. (replica state, number of replicas, filter function).
  266. Used for more fine grained checks.
  267. """
  268. replicas = ray.get(
  269. controller._dump_replica_states_for_testing.remote(deployment_id)
  270. )
  271. if total is not None:
  272. replica_counts = {
  273. state: len(replicas.get([state]))
  274. for state in ALL_REPLICA_STATES
  275. if replicas.get([state])
  276. }
  277. assert replicas.count() == total, replica_counts
  278. if by_state is not None:
  279. for state, count, check in by_state:
  280. assert isinstance(state, ReplicaState)
  281. assert isinstance(count, int) and count >= 0
  282. if check:
  283. filtered = {r for r in replicas.get(states=[state]) if check(r)}
  284. curr_count = len(filtered)
  285. else:
  286. curr_count = replicas.count(states=[state])
  287. msg = f"Expected {count} for state {state} but got {curr_count}."
  288. assert curr_count == count, msg
  289. return True
  290. @ray.remote(name=STORAGE_ACTOR_NAME, namespace=SERVE_NAMESPACE, num_cpus=0)
  291. class TelemetryStorage:
  292. def __init__(self):
  293. self.reports_received = 0
  294. self.current_report = dict()
  295. def store_report(self, report: Dict) -> None:
  296. self.reports_received += 1
  297. self.current_report = report
  298. def get_report(self) -> Dict:
  299. return self.current_report
  300. def get_reports_received(self) -> int:
  301. return self.reports_received
  302. @serve.deployment(ray_actor_options={"num_cpus": 0})
  303. class TelemetryReceiver:
  304. def __init__(self):
  305. self.storage = ray.get_actor(name=STORAGE_ACTOR_NAME, namespace=SERVE_NAMESPACE)
  306. async def __call__(self, request: Request) -> bool:
  307. report = await request.json()
  308. ray.get(self.storage.store_report.remote(report))
  309. return True
  310. receiver_app = TelemetryReceiver.bind()
  311. def start_telemetry_app():
  312. """Start a telemetry Serve app.
  313. Ray should be initialized before calling this method.
  314. NOTE: If you're running the TelemetryReceiver Serve app to check telemetry,
  315. remember that the receiver itself is counted in the telemetry. E.g. if you
  316. deploy a Serve app other than the receiver, the number of apps in the
  317. cluster is 2- not 1– since the receiver is also running.
  318. Returns a handle to a TelemetryStorage actor. You can use this actor
  319. to access the latest telemetry reports.
  320. """
  321. storage = TelemetryStorage.remote()
  322. serve.run(receiver_app, name="telemetry", route_prefix=TELEMETRY_ROUTE_PREFIX)
  323. return storage
  324. def check_telemetry(
  325. tag: ServeUsageTag, expected: Any, storage_actor_name: str = STORAGE_ACTOR_NAME
  326. ):
  327. storage_handle = ray.get_actor(storage_actor_name, namespace=SERVE_NAMESPACE)
  328. report = ray.get(storage_handle.get_report.remote())
  329. print(report["extra_usage_tags"])
  330. assert tag.get_value_from_report(report) == expected
  331. return True
  332. def ping_grpc_list_applications(channel, app_names, test_draining=False):
  333. import pytest
  334. stub = serve_pb2_grpc.RayServeAPIServiceStub(channel)
  335. request = serve_pb2.ListApplicationsRequest()
  336. if test_draining:
  337. with pytest.raises(grpc.RpcError) as exception_info:
  338. _, _ = stub.ListApplications.with_call(request=request)
  339. rpc_error = exception_info.value
  340. assert rpc_error.code() == grpc.StatusCode.UNAVAILABLE
  341. assert rpc_error.details() == DRAINING_MESSAGE
  342. else:
  343. response, call = stub.ListApplications.with_call(request=request)
  344. assert call.code() == grpc.StatusCode.OK
  345. assert response.application_names == app_names
  346. return True
  347. def ping_grpc_healthz(channel, test_draining=False):
  348. import pytest
  349. stub = serve_pb2_grpc.RayServeAPIServiceStub(channel)
  350. request = serve_pb2.HealthzRequest()
  351. if test_draining:
  352. with pytest.raises(grpc.RpcError) as exception_info:
  353. _, _ = stub.Healthz.with_call(request=request)
  354. rpc_error = exception_info.value
  355. assert rpc_error.code() == grpc.StatusCode.UNAVAILABLE
  356. assert rpc_error.details() == DRAINING_MESSAGE
  357. else:
  358. response, call = stub.Healthz.with_call(request=request)
  359. assert call.code() == grpc.StatusCode.OK
  360. assert response.message == "success"
  361. def ping_grpc_call_method(channel, app_name, test_not_found=False):
  362. import pytest
  363. stub = serve_pb2_grpc.UserDefinedServiceStub(channel)
  364. request = serve_pb2.UserDefinedMessage(name="foo", num=30, foo="bar")
  365. metadata = (("application", app_name),)
  366. if test_not_found:
  367. with pytest.raises(grpc.RpcError) as exception_info:
  368. _, _ = stub.__call__.with_call(request=request, metadata=metadata)
  369. rpc_error = exception_info.value
  370. assert rpc_error.code() == grpc.StatusCode.NOT_FOUND, rpc_error.code()
  371. assert f"Application '{app_name}' not found." in rpc_error.details()
  372. else:
  373. response, call = stub.__call__.with_call(request=request, metadata=metadata)
  374. assert call.code() == grpc.StatusCode.OK, call.code()
  375. assert response.greeting == "Hello foo from bar", response.greeting
  376. def ping_grpc_another_method(channel, app_name):
  377. stub = serve_pb2_grpc.UserDefinedServiceStub(channel)
  378. request = serve_pb2.UserDefinedMessage(name="foo", num=30, foo="bar")
  379. metadata = (("application", app_name),)
  380. response = stub.Method1(request=request, metadata=metadata)
  381. assert response.greeting == "Hello foo from method1"
  382. def ping_grpc_model_multiplexing(channel, app_name):
  383. stub = serve_pb2_grpc.UserDefinedServiceStub(channel)
  384. request = serve_pb2.UserDefinedMessage(name="foo", num=30, foo="bar")
  385. multiplexed_model_id = "999"
  386. metadata = (
  387. ("application", app_name),
  388. ("multiplexed_model_id", multiplexed_model_id),
  389. )
  390. response = stub.Method2(request=request, metadata=metadata)
  391. assert (
  392. response.greeting
  393. == f"Method2 called model, loading model: {multiplexed_model_id}"
  394. )
  395. def ping_grpc_streaming(channel, app_name):
  396. stub = serve_pb2_grpc.UserDefinedServiceStub(channel)
  397. request = serve_pb2.UserDefinedMessage(name="foo", num=30, foo="bar")
  398. metadata = (("application", app_name),)
  399. responses = stub.Streaming(request=request, metadata=metadata)
  400. for idx, response in enumerate(responses):
  401. assert response.greeting == f"{idx}: Hello foo from bar"
  402. def ping_fruit_stand(channel, app_name):
  403. stub = serve_pb2_grpc.FruitServiceStub(channel)
  404. request = serve_pb2.FruitAmounts(orange=4, apple=8)
  405. metadata = (("application", app_name),)
  406. response = stub.FruitStand(request=request, metadata=metadata)
  407. assert response.costs == 32
  408. @asynccontextmanager
  409. async def send_signal_on_cancellation(signal_actor: ActorHandle):
  410. cancelled = False
  411. try:
  412. yield
  413. await asyncio.sleep(100)
  414. except asyncio.CancelledError:
  415. cancelled = True
  416. # Clear the context var to avoid Ray recursively cancelling this method call.
  417. ray._raylet.async_task_id.set(None)
  418. await signal_actor.send.remote()
  419. if not cancelled:
  420. raise RuntimeError(
  421. "CancelledError wasn't raised during `send_signal_on_cancellation` block"
  422. )
  423. class FakeGrpcContext:
  424. def __init__(self):
  425. self._auth_context = {"key": "value"}
  426. self._invocation_metadata = [("key", "value")]
  427. self._peer = "peer"
  428. self._peer_identities = b"peer_identities"
  429. self._peer_identity_key = "peer_identity_key"
  430. self._code = None
  431. self._details = None
  432. self._trailing_metadata = []
  433. self._invocation_metadata = []
  434. def auth_context(self):
  435. return self._auth_context
  436. def code(self):
  437. return self._code
  438. def details(self):
  439. return self._details
  440. def peer(self):
  441. return self._peer
  442. def peer_identities(self):
  443. return self._peer_identities
  444. def peer_identity_key(self):
  445. return self._peer_identity_key
  446. def trailing_metadata(self):
  447. return self._trailing_metadata
  448. def set_code(self, code):
  449. self._code = code
  450. def set_details(self, details):
  451. self._details = details
  452. def set_trailing_metadata(self, trailing_metadata):
  453. self._trailing_metadata = trailing_metadata
  454. def invocation_metadata(self):
  455. return self._invocation_metadata
  456. class FakeGauge:
  457. def __init__(self, name: str = None, tag_keys: Tuple[str] = None):
  458. self.name = name
  459. self.values = dict()
  460. self.tags = tag_keys or ()
  461. self.default_tags = dict()
  462. def set_default_tags(self, tags: Dict[str, str]):
  463. for key, tag in tags.items():
  464. assert key in self.tags
  465. self.default_tags[key] = tag
  466. def set(self, value: Union[int, float], tags: Dict[str, str] = None):
  467. merged_tags = self.default_tags.copy()
  468. merged_tags.update(tags or {})
  469. assert set(merged_tags.keys()) == set(self.tags)
  470. d = self.values
  471. for tag in self.tags[:-1]:
  472. tag_value = merged_tags[tag]
  473. if tag_value not in d:
  474. d[tag_value] = dict()
  475. d = d[tag_value]
  476. d[merged_tags[self.tags[-1]]] = value
  477. def get_value(self, tags: Dict[str, str]):
  478. value = self.values
  479. for tag in self.tags:
  480. tag_value = tags[tag]
  481. value = value.get(tag_value)
  482. if value is None:
  483. return
  484. return value
  485. class FakeCounter:
  486. def __init__(self, name: str = None, tag_keys: Tuple[str] = None):
  487. self.name = name
  488. self.counts = dict()
  489. self.tags = tag_keys or ()
  490. self.default_tags = dict()
  491. def set_default_tags(self, tags: Dict[str, str]):
  492. for key, tag in tags.items():
  493. assert key in self.tags
  494. self.default_tags[key] = tag
  495. def inc(self, value: Union[int, float] = 1.0, tags: Dict[str, str] = None):
  496. merged_tags = self.default_tags.copy()
  497. merged_tags.update(tags or {})
  498. assert set(merged_tags.keys()) == set(self.tags)
  499. d = self.counts
  500. for tag in self.tags[:-1]:
  501. tag_value = merged_tags[tag]
  502. if tag_value not in d:
  503. d[tag_value] = dict()
  504. d = d[tag_value]
  505. key = merged_tags[self.tags[-1]]
  506. d[key] = d.get(key, 0) + value
  507. def get_count(self, tags: Dict[str, str]) -> int:
  508. value = self.counts
  509. for tag in self.tags:
  510. tag_value = tags[tag]
  511. value = value.get(tag_value)
  512. if value is None:
  513. return
  514. return value
  515. def get_tags(self):
  516. return self.tags
  517. @ray.remote
  518. def get_node_id():
  519. return ray.get_runtime_context().get_node_id()
  520. def check_num_alive_nodes(target: int):
  521. alive_nodes = [node for node in ray.nodes() if node["Alive"]]
  522. assert len(alive_nodes) == target
  523. return True
  524. def get_deployment_details(
  525. deployment_name: str,
  526. app_name: str = SERVE_DEFAULT_APP_NAME,
  527. _client: ServeControllerClient = None,
  528. ):
  529. client = _client or _get_global_client()
  530. details = client.get_serve_details()
  531. return details["applications"][app_name]["deployments"][deployment_name]
  532. @ray.remote
  533. class Counter:
  534. def __init__(self, target: int):
  535. self.count = 0
  536. self.target = target
  537. self.ready_event = asyncio.Event()
  538. def inc(self):
  539. self.count += 1
  540. if self.count == self.target:
  541. self.ready_event.set()
  542. async def wait(self):
  543. await self.ready_event.wait()
  544. def tlog(s: str, level: str = "INFO"):
  545. """Convenient logging method for testing."""
  546. now = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3]
  547. print(f"[{level}] {now} {s}")
  548. def check_target_groups_ready(
  549. client: ServeControllerClient,
  550. app_name: str,
  551. protocol: Union[str, RequestProtocol] = RequestProtocol.HTTP,
  552. ):
  553. """Wait for target groups to be ready for the given app and protocol.
  554. Target groups are ready when there are at least one target for the given protocol. And it's
  555. possible that target groups are not ready immediately. An example is when the controller
  556. is recovering from a crash.
  557. """
  558. target_groups = ray.get(client._controller.get_target_groups.remote(app_name))
  559. target_groups = [
  560. target_group
  561. for target_group in target_groups
  562. if target_group.protocol == protocol
  563. ]
  564. all_targets = [
  565. target for target_group in target_groups for target in target_group.targets
  566. ]
  567. return len(all_targets) > 0
  568. def get_application_urls(
  569. protocol: Union[str, RequestProtocol] = RequestProtocol.HTTP,
  570. app_name: str = SERVE_DEFAULT_APP_NAME,
  571. use_localhost: bool = True,
  572. is_websocket: bool = False,
  573. exclude_route_prefix: bool = False,
  574. from_proxy_manager: bool = False,
  575. ) -> List[str]:
  576. """Get the URL of the application.
  577. Args:
  578. protocol: The protocol to use for the application.
  579. app_name: The name of the application.
  580. use_localhost: Whether to use localhost instead of the IP address.
  581. Set to True if Serve deployments are not exposed publicly or
  582. for low latency benchmarking.
  583. is_websocket: Whether the url should be served as a websocket.
  584. exclude_route_prefix: The route prefix to exclude from the application.
  585. from_proxy_manager: Whether the caller is a proxy manager.
  586. Returns:
  587. The URLs of the application.
  588. """
  589. client = _get_global_client(_health_check_controller=True)
  590. serve_details = client.get_serve_details()
  591. assert (
  592. app_name in serve_details["applications"]
  593. ), f"App {app_name} not found in serve details. Use this method only when the app is known to be running."
  594. route_prefix = serve_details["applications"][app_name]["route_prefix"]
  595. # route_prefix is set to None when route_prefix value is specifically set to None
  596. # in the config used to deploy the app.
  597. if exclude_route_prefix or route_prefix is None:
  598. route_prefix = ""
  599. if isinstance(protocol, str):
  600. protocol = RequestProtocol(protocol)
  601. target_groups: List[TargetGroup] = ray.get(
  602. client._controller.get_target_groups.remote(app_name, from_proxy_manager)
  603. )
  604. target_groups = [
  605. target_group
  606. for target_group in target_groups
  607. if target_group.protocol == protocol
  608. ]
  609. if len(target_groups) == 0:
  610. raise ValueError(
  611. f"No target group found for app {app_name} with protocol {protocol} and route prefix {route_prefix}"
  612. )
  613. urls = []
  614. for target_group in target_groups:
  615. for target in target_group.targets:
  616. ip = "localhost" if use_localhost else target.ip
  617. if protocol == RequestProtocol.HTTP:
  618. scheme = "ws" if is_websocket else "http"
  619. url = f"{scheme}://{build_address(ip, target.port)}{route_prefix}"
  620. elif protocol == RequestProtocol.GRPC:
  621. if is_websocket:
  622. raise ValueError(
  623. "is_websocket=True is not supported with gRPC protocol."
  624. )
  625. url = build_address(ip, target.port)
  626. else:
  627. raise ValueError(f"Unsupported protocol: {protocol}")
  628. url = url.rstrip("/")
  629. urls.append(url)
  630. return urls
  631. def get_application_url(
  632. protocol: Union[str, RequestProtocol] = RequestProtocol.HTTP,
  633. app_name: str = SERVE_DEFAULT_APP_NAME,
  634. use_localhost: bool = True,
  635. is_websocket: bool = False,
  636. exclude_route_prefix: bool = False,
  637. from_proxy_manager: bool = False,
  638. ) -> str:
  639. """Get the URL of the application.
  640. Args:
  641. protocol: The protocol to use for the application.
  642. app_name: The name of the application.
  643. use_localhost: Whether to use localhost instead of the IP address.
  644. Set to True if Serve deployments are not exposed publicly or
  645. for low latency benchmarking.
  646. is_websocket: Whether the url should be served as a websocket.
  647. exclude_route_prefix: The route prefix to exclude from the application.
  648. from_proxy_manager: Whether the caller is a proxy manager.
  649. Returns:
  650. The URL of the application. If there are multiple URLs, a random one is returned.
  651. """
  652. return random.choice(
  653. get_application_urls(
  654. protocol,
  655. app_name,
  656. use_localhost,
  657. is_websocket,
  658. exclude_route_prefix,
  659. from_proxy_manager,
  660. )
  661. )
  662. def check_running(app_name: str = SERVE_DEFAULT_APP_NAME):
  663. assert serve.status().applications[app_name].status == ApplicationStatus.RUNNING
  664. return True
  665. def request_with_retries(timeout=30, app_name=SERVE_DEFAULT_APP_NAME):
  666. result_holder = {"resp": None}
  667. def _attempt() -> bool:
  668. try:
  669. url = get_application_url("HTTP", app_name=app_name)
  670. result_holder["resp"] = httpx.get(url, timeout=timeout)
  671. return True
  672. except (httpx.RequestError, IndexError):
  673. return False
  674. try:
  675. wait_for_condition(_attempt, timeout=timeout)
  676. return result_holder["resp"]
  677. except RuntimeError as e:
  678. # Preserve previous API by raising TimeoutError on expiry
  679. raise TimeoutError from e
# Metrics test utilities
# Port on localhost that the metric helpers below fetch Prometheus metrics from.
TEST_METRICS_EXPORT_PORT = 9999
  682. def get_metric_float(
  683. metric: str,
  684. expected_tags: Optional[Dict[str, str]],
  685. timeseries: Optional[PrometheusTimeseries] = None,
  686. ) -> float:
  687. """Gets the float value of metric.
  688. If tags is specified, searched for metric with matching tags.
  689. Returns -1 if the metric isn't available.
  690. """
  691. if timeseries is None:
  692. timeseries = PrometheusTimeseries()
  693. samples = fetch_prometheus_metric_timeseries(
  694. [f"localhost:{TEST_METRICS_EXPORT_PORT}"], timeseries
  695. ).get(metric, [])
  696. for sample in samples:
  697. if expected_tags.items() <= sample.labels.items():
  698. return sample.value
  699. return -1
  700. def check_metric_float_eq(
  701. metric: str,
  702. expected: float,
  703. expected_tags: Optional[Dict[str, str]],
  704. timeseries: Optional[PrometheusTimeseries] = None,
  705. ) -> bool:
  706. """Check if a metric's float value equals the expected value."""
  707. metric_value = get_metric_float(metric, expected_tags, timeseries)
  708. assert float(metric_value) == expected
  709. return True
  710. def get_metric_dictionaries(
  711. name: str, timeout: float = 20, timeseries: Optional[PrometheusTimeseries] = None
  712. ) -> List[Dict]:
  713. if timeseries is None:
  714. timeseries = PrometheusTimeseries()
  715. def metric_available() -> bool:
  716. assert name in fetch_prometheus_metric_timeseries(
  717. [f"localhost:{TEST_METRICS_EXPORT_PORT}"], timeseries
  718. )
  719. return True
  720. wait_for_condition(metric_available, retry_interval_ms=1000, timeout=timeout)
  721. metric_dicts = []
  722. for sample in timeseries.metric_samples.values():
  723. if sample.name == name:
  724. metric_dicts.append(sample.labels)
  725. return metric_dicts