microbenchmark.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. # Runs several scenarios with varying max batch size, max concurrent queries,
  2. # number of replicas, and with intermediate serve handles (to simulate ensemble
  3. # models) either on or off.
  4. import asyncio
  5. import logging
  6. from pprint import pprint
  7. from typing import Dict, Union
  8. import aiohttp
  9. from starlette.requests import Request
  10. import ray
  11. from ray import serve
  12. from ray.serve._private.benchmarks.common import run_throughput_benchmark
  13. from ray.serve.handle import DeploymentHandle
# Number of Client actors created for the multi-client measurement.
NUM_CLIENTS = 8
# Number of HTTP requests each client sends per benchmark invocation; also
# passed to run_throughput_benchmark as (part of) the `multiplier` argument.
CALLS_PER_BATCH = 100
  16. async def fetch(session, data):
  17. async with session.get("http://localhost:8000/", data=data) as response:
  18. response = await response.text()
  19. assert response == "ok", response
  20. @ray.remote
  21. class Client:
  22. def ready(self):
  23. return "ok"
  24. async def do_queries(self, num, data):
  25. async with aiohttp.ClientSession() as session:
  26. for _ in range(num):
  27. await fetch(session, data)
  28. def build_app(
  29. intermediate_handles: bool,
  30. num_replicas: int,
  31. max_batch_size: int,
  32. max_ongoing_requests: int,
  33. ):
  34. @serve.deployment(max_ongoing_requests=1000)
  35. class Upstream:
  36. def __init__(self, handle: DeploymentHandle):
  37. self._handle = handle
  38. # Turn off access log.
  39. logging.getLogger("ray.serve").setLevel(logging.WARNING)
  40. async def __call__(self, req: Request):
  41. return await self._handle.remote(await req.body())
  42. @serve.deployment(
  43. num_replicas=num_replicas,
  44. max_ongoing_requests=max_ongoing_requests,
  45. )
  46. class Downstream:
  47. def __init__(self):
  48. # Turn off access log.
  49. logging.getLogger("ray.serve").setLevel(logging.WARNING)
  50. @serve.batch(max_batch_size=max_batch_size)
  51. async def batch(self, reqs):
  52. return [b"ok"] * len(reqs)
  53. async def __call__(self, req: Union[bytes, Request]):
  54. if max_batch_size > 1:
  55. return await self.batch(req)
  56. else:
  57. return b"ok"
  58. if intermediate_handles:
  59. return Upstream.bind(Downstream.bind())
  60. else:
  61. return Downstream.bind()
  62. async def trial(
  63. intermediate_handles: bool,
  64. num_replicas: int,
  65. max_batch_size: int,
  66. max_ongoing_requests: int,
  67. data_size: str,
  68. ) -> Dict[str, float]:
  69. results = {}
  70. trial_key_base = (
  71. f"replica:{num_replicas}/batch_size:{max_batch_size}/"
  72. f"concurrent_queries:{max_ongoing_requests}/"
  73. f"data_size:{data_size}/intermediate_handle:{intermediate_handles}"
  74. )
  75. print(
  76. f"intermediate_handles={intermediate_handles},"
  77. f"num_replicas={num_replicas},"
  78. f"max_batch_size={max_batch_size},"
  79. f"max_ongoing_requests={max_ongoing_requests},"
  80. f"data_size={data_size}"
  81. )
  82. app = build_app(
  83. intermediate_handles, num_replicas, max_batch_size, max_ongoing_requests
  84. )
  85. serve.run(app)
  86. if data_size == "small":
  87. data = None
  88. elif data_size == "large":
  89. data = b"a" * 1024 * 1024
  90. else:
  91. raise ValueError("data_size should be 'small' or 'large'.")
  92. async with aiohttp.ClientSession() as session:
  93. async def single_client():
  94. for _ in range(CALLS_PER_BATCH):
  95. await fetch(session, data)
  96. single_client_avg_tps, single_client_std_tps = await run_throughput_benchmark(
  97. single_client,
  98. multiplier=CALLS_PER_BATCH,
  99. )
  100. print(
  101. "\t{} {} +- {} requests/s".format(
  102. "single client {} data".format(data_size),
  103. single_client_avg_tps,
  104. single_client_std_tps,
  105. )
  106. )
  107. key = f"num_client:1/{trial_key_base}"
  108. results[key] = single_client_avg_tps
  109. clients = [Client.remote() for _ in range(NUM_CLIENTS)]
  110. ray.get([client.ready.remote() for client in clients])
  111. async def many_clients():
  112. ray.get([a.do_queries.remote(CALLS_PER_BATCH, data) for a in clients])
  113. multi_client_avg_tps, _ = await run_throughput_benchmark(
  114. many_clients,
  115. multiplier=CALLS_PER_BATCH * len(clients),
  116. )
  117. results[f"num_client:{len(clients)}/{trial_key_base}"] = multi_client_avg_tps
  118. return results
  119. async def main():
  120. results = {}
  121. for intermediate_handles in [False, True]:
  122. for num_replicas in [1, 8]:
  123. for max_batch_size, max_ongoing_requests in [
  124. (1, 1),
  125. (1, 10000),
  126. (10000, 10000),
  127. ]:
  128. # TODO(edoakes): large data causes broken pipe errors.
  129. for data_size in ["small"]:
  130. results.update(
  131. await trial(
  132. intermediate_handles,
  133. num_replicas,
  134. max_batch_size,
  135. max_ongoing_requests,
  136. data_size,
  137. )
  138. )
  139. print("Results from all conditions:")
  140. pprint(results)
  141. return results
  142. if __name__ == "__main__":
  143. ray.init()
  144. serve.start()
  145. loop = asyncio.new_event_loop()
  146. asyncio.set_event_loop(loop)
  147. loop.run_until_complete(main())