common.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. import asyncio
  2. import inspect
  3. import logging
  4. import random
  5. import string
  6. import time
  7. from functools import partial
  8. from typing import Any, Callable, Coroutine, List, Optional, Tuple
  9. import aiohttp
  10. import aiohttp.client_exceptions
  11. import grpc
  12. import numpy as np
  13. import pandas as pd
  14. from starlette.responses import StreamingResponse
  15. from tqdm import tqdm
  16. from ray import serve
  17. from ray.serve.generated import serve_pb2, serve_pb2_grpc
  18. from ray.serve.handle import DeploymentHandle
  19. async def run_latency_benchmark(
  20. f: Callable, num_requests: int, *, num_warmup_requests: int = 100
  21. ) -> pd.Series:
  22. if inspect.iscoroutinefunction(f):
  23. to_call = f
  24. else:
  25. async def to_call():
  26. f()
  27. latencies = []
  28. for i in tqdm(range(num_requests + num_warmup_requests)):
  29. start = time.perf_counter()
  30. await to_call()
  31. end = time.perf_counter()
  32. # Don't include warm-up requests.
  33. if i >= num_warmup_requests:
  34. latencies.append(1000 * (end - start))
  35. return pd.Series(latencies)
  36. async def run_throughput_benchmark(
  37. fn: Callable[[], List[float]],
  38. multiplier: int = 1,
  39. num_trials: int = 10,
  40. trial_runtime: float = 1,
  41. ) -> Tuple[float, float, pd.Series]:
  42. """Benchmarks throughput of a function.
  43. Args:
  44. fn: The function to benchmark. If this returns anything, it must
  45. return a list of latencies.
  46. multiplier: The number of requests or tokens (or whatever unit
  47. is appropriate for this throughput benchmark) that is
  48. completed in one call to `fn`.
  49. num_trials: The number of trials to run.
  50. trial_runtime: How long each trial should run for. During the
  51. duration of one trial, `fn` will be repeatedly called.
  52. Returns (mean, stddev, latencies).
  53. """
  54. # Warmup
  55. start = time.time()
  56. while time.time() - start < 0.1:
  57. await fn()
  58. # Benchmark
  59. stats = []
  60. latencies = []
  61. for _ in tqdm(range(num_trials)):
  62. start = time.perf_counter()
  63. count = 0
  64. while time.perf_counter() - start < trial_runtime:
  65. res = await fn()
  66. if res:
  67. latencies.extend(res)
  68. count += 1
  69. end = time.perf_counter()
  70. stats.append(multiplier * count / (end - start))
  71. return round(np.mean(stats), 2), round(np.std(stats), 2), pd.Series(latencies)
  72. async def do_single_http_batch(
  73. *,
  74. batch_size: int = 100,
  75. url: str = "http://localhost:8000",
  76. stream: bool = False,
  77. ) -> List[float]:
  78. """Sends a batch of http requests and returns e2e latencies."""
  79. # By default, aiohttp limits the number of client connections to 100.
  80. # We need to use TCPConnector to configure the limit if batch size
  81. # is greater than 100.
  82. connector = aiohttp.TCPConnector(limit=batch_size)
  83. async with aiohttp.ClientSession(
  84. connector=connector, raise_for_status=True
  85. ) as session:
  86. async def do_query():
  87. start = time.perf_counter()
  88. try:
  89. async with session.get(url) as r:
  90. if stream:
  91. async for chunk, _ in r.content.iter_chunks():
  92. pass
  93. else:
  94. # Read the response to ensure it's consumed
  95. await r.read()
  96. except aiohttp.client_exceptions.ClientConnectionError:
  97. pass
  98. end = time.perf_counter()
  99. return 1000 * (end - start)
  100. return await asyncio.gather(*[do_query() for _ in range(batch_size)])
  101. async def do_single_grpc_batch(
  102. *, batch_size: int = 100, target: str = "localhost:9000"
  103. ):
  104. channel = grpc.aio.insecure_channel(target)
  105. stub = serve_pb2_grpc.RayServeBenchmarkServiceStub(channel)
  106. payload = serve_pb2.StringData(data="")
  107. async def do_query():
  108. start = time.perf_counter()
  109. await stub.grpc_call(payload)
  110. end = time.perf_counter()
  111. return 1000 * (end - start)
  112. return await asyncio.gather(*[do_query() for _ in range(batch_size)])
  113. async def collect_profile_events(coro: Coroutine):
  114. """Collects profiling events using Viztracer"""
  115. from viztracer import VizTracer
  116. tracer = VizTracer()
  117. tracer.start()
  118. await coro
  119. tracer.stop()
  120. tracer.save()
  121. def generate_payload(size: int = 100, chars=string.ascii_uppercase + string.digits):
  122. return "".join(random.choice(chars) for _ in range(size))
  123. class Blackhole:
  124. def sink(self, o):
  125. pass
  126. @serve.deployment
  127. class Noop:
  128. def __init__(self):
  129. logging.getLogger("ray.serve").setLevel(logging.WARNING)
  130. def __call__(self, *args, **kwargs):
  131. return b""
  132. @serve.deployment
  133. class ModelComp:
  134. def __init__(self, child):
  135. logging.getLogger("ray.serve").setLevel(logging.WARNING)
  136. self._child = child
  137. async def __call__(self, *args, **kwargs):
  138. return await self._child.remote()
  139. @serve.deployment
  140. class GrpcDeployment:
  141. def __init__(self):
  142. logging.getLogger("ray.serve").setLevel(logging.WARNING)
  143. async def grpc_call(self, user_message):
  144. return serve_pb2.ModelOutput(output=9)
  145. async def call_with_string(self, user_message):
  146. return serve_pb2.ModelOutput(output=9)
  147. @serve.deployment
  148. class GrpcModelComp:
  149. def __init__(self, child):
  150. logging.getLogger("ray.serve").setLevel(logging.WARNING)
  151. self._child = child
  152. async def grpc_call(self, user_message):
  153. await self._child.remote()
  154. return serve_pb2.ModelOutput(output=9)
  155. async def call_with_string(self, user_message):
  156. await self._child.remote()
  157. return serve_pb2.ModelOutput(output=9)
  158. @serve.deployment
  159. class Streamer:
  160. def __init__(self, tokens_per_request: int, inter_token_delay_ms: int = 10):
  161. logging.getLogger("ray.serve").setLevel(logging.WARNING)
  162. self._tokens_per_request = tokens_per_request
  163. self._inter_token_delay_s = inter_token_delay_ms / 1000
  164. async def stream(self):
  165. for _ in range(self._tokens_per_request):
  166. await asyncio.sleep(self._inter_token_delay_s)
  167. yield b"hi"
  168. async def __call__(self):
  169. return StreamingResponse(self.stream())
  170. @serve.deployment
  171. class IntermediateRouter:
  172. def __init__(self, handle: DeploymentHandle):
  173. logging.getLogger("ray.serve").setLevel(logging.WARNING)
  174. self._handle = handle.options(stream=True)
  175. async def stream(self):
  176. async for token in self._handle.stream.remote():
  177. yield token
  178. def __call__(self):
  179. return StreamingResponse(self.stream())
  180. @serve.deployment
  181. class Benchmarker:
  182. def __init__(
  183. self,
  184. handle: DeploymentHandle,
  185. stream: bool = False,
  186. ):
  187. logging.getLogger("ray.serve").setLevel(logging.WARNING)
  188. self._handle = handle.options(stream=stream)
  189. self._stream = stream
  190. async def do_single_request(self, payload: Any = None) -> float:
  191. """Completes a single unary request. Returns e2e latency in ms."""
  192. start = time.perf_counter()
  193. if payload is None:
  194. await self._handle.remote()
  195. else:
  196. await self._handle.remote(payload)
  197. end = time.perf_counter()
  198. return 1000 * (end - start)
  199. async def _do_single_stream(self) -> float:
  200. """Consumes a single streaming request. Returns e2e latency in ms."""
  201. start = time.perf_counter()
  202. async for r in self._handle.stream.remote():
  203. pass
  204. end = time.perf_counter()
  205. return 1000 * (end - start)
  206. async def _do_single_batch(self, batch_size: int) -> List[float]:
  207. if self._stream:
  208. return await asyncio.gather(
  209. *[self._do_single_stream() for _ in range(batch_size)]
  210. )
  211. else:
  212. return await asyncio.gather(
  213. *[self.do_single_request() for _ in range(batch_size)]
  214. )
  215. async def run_latency_benchmark(
  216. self, *, num_requests: int, payload: Any = None
  217. ) -> pd.Series:
  218. async def f():
  219. await self.do_single_request(payload)
  220. return await run_latency_benchmark(f, num_requests=num_requests)
  221. async def run_throughput_benchmark(
  222. self,
  223. *,
  224. batch_size: int,
  225. num_trials: int,
  226. trial_runtime: float,
  227. tokens_per_request: Optional[float] = None,
  228. ) -> Tuple[float, float]:
  229. if self._stream:
  230. assert tokens_per_request
  231. multiplier = tokens_per_request * batch_size
  232. else:
  233. multiplier = batch_size
  234. return await run_throughput_benchmark(
  235. fn=partial(
  236. self._do_single_batch,
  237. batch_size=batch_size,
  238. ),
  239. multiplier=multiplier,
  240. num_trials=num_trials,
  241. trial_runtime=trial_runtime,
  242. )