serve.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. CLI entry point for `transformers serve`.
  16. """
  17. import asyncio
  18. import threading
  19. from typing import Annotated
  20. import typer
  21. from transformers.utils import logging
  22. from transformers.utils.import_utils import is_serve_available
  23. from .serving.utils import set_torch_seed
# Module-level logger, named after this module, obtained through transformers' logging utilities.
logger = logging.get_logger(__name__)
  25. class Serve:
  26. def __init__(
  27. self,
  28. force_model: Annotated[str | None, typer.Argument(help="Model to preload and use for all requests.")] = None,
  29. # Model options
  30. continuous_batching: Annotated[
  31. bool,
  32. typer.Option(help="Enable continuous batching with paged attention. Configure with --cb-* flags."),
  33. ] = False,
  34. cb_block_size: Annotated[
  35. int | None, typer.Option(help="KV cache block size in tokens for continuous batching.")
  36. ] = None,
  37. cb_num_blocks: Annotated[
  38. int | None, typer.Option(help="Number of KV cache blocks for continuous batching.")
  39. ] = None,
  40. cb_max_batch_tokens: Annotated[
  41. int | None, typer.Option(help="Maximum tokens per batch for continuous batching.")
  42. ] = None,
  43. cb_max_memory_percent: Annotated[
  44. float | None, typer.Option(help="Max GPU memory fraction for KV cache (0.0-1.0).")
  45. ] = None,
  46. cb_use_cuda_graph: Annotated[
  47. bool | None, typer.Option(help="Enable CUDA graphs for continuous batching.")
  48. ] = None,
  49. attn_implementation: Annotated[
  50. str | None, typer.Option(help="Attention implementation (e.g. flash_attention_2).")
  51. ] = None,
  52. compile: Annotated[bool, typer.Option(help="Enable torch.compile for faster inference.")] = False,
  53. quantization: Annotated[
  54. str | None, typer.Option(help="Quantization method: 'bnb-4bit' or 'bnb-8bit'.")
  55. ] = None,
  56. device: Annotated[str, typer.Option(help="Device for inference (e.g. 'auto', 'cuda:0', 'cpu').")] = "auto",
  57. dtype: Annotated[str | None, typer.Option(help="Override model dtype. 'auto' derives from weights.")] = "auto",
  58. trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code when loading.")] = False,
  59. model_timeout: Annotated[
  60. int, typer.Option(help="Seconds before idle model is unloaded. Ignored when force_model is set.")
  61. ] = 300,
  62. # Server options
  63. host: Annotated[str, typer.Option(help="Server listen address.")] = "localhost",
  64. port: Annotated[int, typer.Option(help="Server listen port.")] = 8000,
  65. enable_cors: Annotated[bool, typer.Option(help="Enable permissive CORS.")] = False,
  66. log_level: Annotated[str, typer.Option(help="Logging level (e.g. 'info', 'warning').")] = "warning",
  67. default_seed: Annotated[int | None, typer.Option(help="Default torch seed.")] = None,
  68. non_blocking: Annotated[
  69. bool, typer.Option(hidden=True, help="Run server in a background thread. Used by tests.")
  70. ] = False,
  71. ) -> None:
  72. if not is_serve_available():
  73. raise ImportError("Missing dependencies for serving. Install with `pip install transformers[serving]`")
  74. import uvicorn
  75. from .serving.chat_completion import ChatCompletionHandler
  76. from .serving.model_manager import ModelManager
  77. from .serving.response import ResponseHandler
  78. from .serving.server import build_server
  79. from .serving.transcription import TranscriptionHandler
  80. from .serving.utils import GenerationState
  81. # Seed
  82. if default_seed is not None:
  83. set_torch_seed(default_seed)
  84. # Logging
  85. transformers_logger = logging.get_logger("transformers")
  86. transformers_logger.setLevel(logging.log_levels[log_level.lower()])
  87. self._model_manager = ModelManager(
  88. device=device,
  89. dtype=dtype,
  90. trust_remote_code=trust_remote_code,
  91. attn_implementation=attn_implementation,
  92. quantization=quantization,
  93. model_timeout=model_timeout,
  94. force_model=force_model,
  95. )
  96. from transformers import ContinuousBatchingConfig
  97. cb_kwargs = {
  98. k: v
  99. for k, v in {
  100. "block_size": cb_block_size,
  101. "num_blocks": cb_num_blocks,
  102. "max_batch_tokens": cb_max_batch_tokens,
  103. "max_memory_percent": cb_max_memory_percent,
  104. "use_cuda_graph": cb_use_cuda_graph,
  105. }.items()
  106. if v is not None
  107. }
  108. cb_config = ContinuousBatchingConfig(**cb_kwargs) if cb_kwargs else None
  109. self._generation_state = GenerationState(
  110. continuous_batching=continuous_batching,
  111. compile=compile,
  112. cb_config=cb_config,
  113. )
  114. self._chat_handler = ChatCompletionHandler(
  115. model_manager=self._model_manager,
  116. generation_state=self._generation_state,
  117. )
  118. self._response_handler = ResponseHandler(
  119. model_manager=self._model_manager,
  120. generation_state=self._generation_state,
  121. )
  122. self._transcription_handler = TranscriptionHandler(self._model_manager, self._generation_state)
  123. app = build_server(
  124. self._model_manager,
  125. self._chat_handler,
  126. response_handler=self._response_handler,
  127. transcription_handler=self._transcription_handler,
  128. enable_cors=enable_cors,
  129. )
  130. config = uvicorn.Config(app, host=host, port=port, log_level="info")
  131. self.server = uvicorn.Server(config)
  132. if non_blocking:
  133. self.start_server()
  134. else:
  135. self.server.run()
  136. def start_server(self):
  137. def _run():
  138. loop = asyncio.new_event_loop()
  139. asyncio.set_event_loop(loop)
  140. loop.run_until_complete(self.server.serve())
  141. self._thread = threading.Thread(target=_run, name="uvicorn-thread", daemon=False)
  142. self._thread.start()
  143. def reset_loaded_models(self):
  144. """Clear all loaded models from memory."""
  145. self._model_manager.shutdown()
  146. def kill_server(self):
  147. self._generation_state.shutdown()
  148. self._model_manager.shutdown()
  149. if not self._thread or not self._thread.is_alive():
  150. return
  151. self.server.should_exit = True
  152. self._thread.join(timeout=2)
# Help text for the `transformers serve` command. It is attached to the class
# after the class body and is presumably what Typer shows for `--help`.
# The lone `\b` line is Click's "do not rewrap" marker: it keeps the endpoint
# list below from being reflowed into a single paragraph in help output.
# NOTE(review): build_server() is also given a response handler and a
# transcription handler, so the endpoint list below may be incomplete.
# Confirm against the routes registered in `.serving.server` and update the text.
Serve.__doc__ = """
Run a FastAPI server to serve models on-demand with an OpenAI compatible API.
Models will be loaded and unloaded automatically based on usage and a timeout.
\b
Endpoints:
POST /v1/chat/completions — Chat completions (streaming + non-streaming).
GET /v1/models — Lists available models.
GET /health — Health check.
Requires FastAPI and Uvicorn: pip install transformers[serving]
"""