multiplex.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. import asyncio
  2. import inspect
  3. import logging
  4. import time
  5. from collections import OrderedDict
  6. from typing import Any, Callable, List, Set
  7. from ray.serve import metrics
  8. from ray.serve._private.common import ReplicaID, RequestRoutingInfo
  9. from ray.serve._private.constants import (
  10. MODEL_LOAD_LATENCY_BUCKETS_MS,
  11. PUSH_MULTIPLEXED_MODEL_IDS_INTERVAL_S,
  12. SERVE_LOGGER_NAME,
  13. )
  14. from ray.serve._private.metrics_utils import MetricsPusher
  15. from ray.serve._private.usage import ServeUsageTag
  16. from ray.serve.context import _get_global_client, _get_internal_replica_context
  17. logger = logging.getLogger(SERVE_LOGGER_NAME)
  18. class _ModelMultiplexWrapper:
  19. """A wrapper class that wraps the model load function and
  20. provides the LRU caching functionality.
  21. The model multiplexer is a wrapper class that wraps the model load function
  22. and provides the LRU caching functionality, and the model load function should
  23. be a coroutine function that takes the model ID as the first argument and
  24. returns the user-constructed model object.
  25. The model multiplexer will also ensure that the number of models on the current
  26. replica does not exceed the specified limit.
  27. The model will be unloaded in the LRU order, the model multiplexer will call the
  28. model's __del__ attribute if it exists to clean up the model resources eagerly.
  29. """
  30. _PUSH_MULTIPLEXED_MODEL_IDS_TASK_NAME = "push_multiplexed_model_ids"
  31. def __init__(
  32. self,
  33. model_load_func: Callable[[str], Any],
  34. self_arg: Any,
  35. max_num_models_per_replica: int,
  36. ):
  37. """Initialize the model multiplexer.
  38. Args:
  39. model_load_func: the model load async function.
  40. self_arg: self argument when model_load_func is class method.
  41. max_num_models_per_replica: the maximum number of models to be loaded on the
  42. current replica. If it is -1, there is no limit for the number of models
  43. per replica.
  44. """
  45. ServeUsageTag.MULTIPLEXED_API_USED.record("1")
  46. self.models = OrderedDict()
  47. self._func: Callable = model_load_func
  48. self.self_arg: Any = self_arg
  49. self.max_num_models_per_replica: int = max_num_models_per_replica
  50. # log MODEL_LOAD_LATENCY_BUCKET_MS
  51. logger.debug(f"MODEL_LOAD_LATENCY_BUCKET_MS: {MODEL_LOAD_LATENCY_BUCKETS_MS}")
  52. self.model_load_latency_ms = metrics.Histogram(
  53. "serve_multiplexed_model_load_latency_ms",
  54. description="The time it takes to load a model.",
  55. boundaries=MODEL_LOAD_LATENCY_BUCKETS_MS,
  56. )
  57. self.model_unload_latency_ms = metrics.Histogram(
  58. "serve_multiplexed_model_unload_latency_ms",
  59. description="The time it takes to unload a model.",
  60. boundaries=MODEL_LOAD_LATENCY_BUCKETS_MS,
  61. )
  62. self.num_models_gauge = metrics.Gauge(
  63. "serve_num_multiplexed_models",
  64. description="The number of models loaded on the current replica.",
  65. )
  66. self.registered_model_gauge = metrics.Gauge(
  67. "serve_registered_multiplexed_model_id",
  68. description="The model id registered on the current replica.",
  69. tag_keys=("model_id",),
  70. )
  71. self.get_model_requests_counter = metrics.Counter(
  72. "serve_multiplexed_get_model_requests_counter",
  73. description="The counter for get model requests on the current replica.",
  74. )
  75. self.models_unload_counter = metrics.Counter(
  76. "serve_multiplexed_models_unload_counter",
  77. description="The counter for unloaded models on the current replica.",
  78. )
  79. self.models_load_counter = metrics.Counter(
  80. "serve_multiplexed_models_load_counter",
  81. description="The counter for loaded models on the current replica.",
  82. )
  83. context = _get_internal_replica_context()
  84. if context is None:
  85. raise RuntimeError(
  86. "`@serve.multiplex` can only be used within a deployment "
  87. "(failed to retrieve Serve replica context)."
  88. )
  89. self._app_name: str = context.app_name
  90. self._deployment_name: str = context.deployment
  91. self._replica_id: ReplicaID = context.replica_id
  92. # Whether to push the multiplexed replica info to the controller.
  93. self._push_multiplexed_replica_info: bool = False
  94. # Model cache lock to ensure that only one model is loading/unloading at a time.
  95. self._model_cache_lock = asyncio.Lock()
  96. # The set of model IDs that are being loaded. This is used to early push
  97. # model ids info to the controller. The tasks will be added when there is cache
  98. # miss, and will be removed when the model is loaded successfully or
  99. # failed to load.
  100. self._model_load_tasks: Set[str] = set()
  101. self.metrics_pusher = MetricsPusher()
  102. self.metrics_pusher.register_or_update_task(
  103. self._PUSH_MULTIPLEXED_MODEL_IDS_TASK_NAME,
  104. self._push_model_ids_info,
  105. PUSH_MULTIPLEXED_MODEL_IDS_INTERVAL_S,
  106. )
  107. self.metrics_pusher.start()
  108. def _get_loading_and_loaded_model_ids(self) -> List[str]:
  109. """Get the model IDs of the loaded models & loading models in the replica.
  110. This is to push the model id information early to the controller, so that
  111. requests can be routed to the replica.
  112. """
  113. models_list = set(self.models.keys())
  114. models_list.update(self._model_load_tasks)
  115. return list(models_list)
  116. def _push_model_ids_info(self):
  117. """Push the multiplexed replica info to the controller."""
  118. try:
  119. self.num_models_gauge.set(len(self.models))
  120. for model_id in self.models:
  121. self.registered_model_gauge.set(1, tags={"model_id": model_id})
  122. if self._push_multiplexed_replica_info:
  123. _get_global_client().record_request_routing_info(
  124. RequestRoutingInfo(
  125. replica_id=self._replica_id,
  126. multiplexed_model_ids=self._get_loading_and_loaded_model_ids(),
  127. )
  128. )
  129. self._push_multiplexed_replica_info = False
  130. except Exception as e:
  131. logger.warning(
  132. "Failed to push the multiplexed replica info "
  133. f"to the controller. Error: {e}"
  134. )
  135. async def shutdown(self):
  136. """Unload all the models when the model multiplexer is deleted."""
  137. while len(self.models) > 0:
  138. try:
  139. await self.unload_model_lru()
  140. except Exception as e:
  141. logger.exception(
  142. f"Failed to unload model. Error: {e}",
  143. )
  144. async def load_model(self, model_id: str) -> Any:
  145. """Load the model if it is not loaded yet, and return
  146. the user-constructed model object.
  147. Args:
  148. model_id: the model ID.
  149. Returns:
  150. The user-constructed model object.
  151. """
  152. if type(model_id) is not str:
  153. raise TypeError("The model ID must be a string.")
  154. if not model_id:
  155. raise ValueError("The model ID cannot be empty.")
  156. self.get_model_requests_counter.inc()
  157. if model_id in self.models:
  158. # Move the model to the end of the OrderedDict to ensure LRU caching.
  159. model = self.models.pop(model_id)
  160. self.models[model_id] = model
  161. return self.models[model_id]
  162. else:
  163. # Set the flag to push the multiplexed replica info to the controller
  164. # before loading the model. This is to make sure we can push the model
  165. # id info to the controller/router early, so that requests can be routed to
  166. # the replica.
  167. self._push_multiplexed_replica_info = True
  168. self._model_load_tasks.add(model_id)
  169. async with self._model_cache_lock:
  170. # Check if the model has been loaded by another request.
  171. if model_id in self.models:
  172. return self.models[model_id]
  173. try:
  174. # If the number of models per replica is specified, check
  175. # if the number of models on the current replica has
  176. # reached the limit.
  177. if (
  178. self.max_num_models_per_replica > 0
  179. and len(self.models) >= self.max_num_models_per_replica
  180. ):
  181. # Unload the least recently used model.
  182. await self.unload_model_lru()
  183. self._push_multiplexed_replica_info = True
  184. # Load the model.
  185. logger.info(f"Loading model '{model_id}'.")
  186. self.models_load_counter.inc()
  187. load_start_time = time.time()
  188. if self.self_arg is None:
  189. self.models[model_id] = await self._func(model_id)
  190. else:
  191. self.models[model_id] = await self._func(
  192. self.self_arg, model_id
  193. )
  194. load_latency_ms = (time.time() - load_start_time) * 1000.0
  195. logger.info(
  196. f"Successfully loaded model '{model_id}' in "
  197. f"{load_latency_ms:.1f}ms."
  198. )
  199. self._model_load_tasks.discard(model_id)
  200. self.model_load_latency_ms.observe(load_latency_ms)
  201. return self.models[model_id]
  202. except Exception as e:
  203. logger.error(
  204. f"Failed to load model '{model_id}'. Error: {e}",
  205. )
  206. self._model_load_tasks.discard(model_id)
  207. raise e
  208. async def unload_model_lru(self) -> None:
  209. """Unload the least recently used model."""
  210. self.models_unload_counter.inc()
  211. unload_start_time = time.time()
  212. model_id, model = self.models.popitem(last=False)
  213. logger.info(f"Unloading model '{model_id}'.")
  214. # If the model has __del__ attribute, call it.
  215. # This is to clean up the model resources eagerly.
  216. if hasattr(model, "__del__"):
  217. if not inspect.iscoroutinefunction(model.__del__):
  218. await asyncio.get_running_loop().run_in_executor(None, model.__del__)
  219. else:
  220. await model.__del__()
  221. model.__del__ = lambda _: None
  222. unload_latency_ms = (time.time() - unload_start_time) * 1000.0
  223. self.model_unload_latency_ms.observe(unload_latency_ms)
  224. logger.info(
  225. f"Successfully unloaded model '{model_id}' in {unload_latency_ms:.1f}ms."
  226. )
  227. self.registered_model_gauge.set(0, tags={"model_id": model_id})