chat_completion.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. # Copyright 2026 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Handler for the /v1/chat/completions endpoint.
  16. Supports streaming (SSE via DirectStreamer) and non-streaming (JSON) responses.
  17. """
  18. import asyncio
  19. import time
  20. from collections.abc import AsyncGenerator
  21. from typing import TYPE_CHECKING
  22. from ...utils import logging
  23. from ...utils.import_utils import is_serve_available
  24. if is_serve_available():
  25. from fastapi.responses import JSONResponse, StreamingResponse
  26. from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall
  27. from openai.types.chat.chat_completion import Choice
  28. from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, ChoiceDeltaToolCall
  29. from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk
  30. from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming
  31. from openai.types.completion_usage import CompletionUsage
  32. from .utils import (
  33. BaseGenerateManager,
  34. BaseHandler,
  35. ToolCallParser,
  36. _StreamError,
  37. detect_tool_format,
  38. )
  39. if TYPE_CHECKING:
  40. from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin
  41. class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False):
  42. generation_config: str
  43. seed: int
  44. # Fields accepted by the OpenAI schema but not yet supported.
  45. # Receiving these raises an error to avoid silent misbehaviour.
  46. # NOTE: "stop" is NOT in this set — we map it to stop_strings.
  47. UNUSED_CHAT_COMPLETION_FIELDS = {
  48. "audio",
  49. "function_call",
  50. "functions",
  51. "logprobs",
  52. "max_completion_tokens",
  53. "metadata",
  54. "modalities",
  55. "n",
  56. "parallel_tool_calls",
  57. "prediction",
  58. "presence_penalty",
  59. "reasoning_effort",
  60. "response_format",
  61. "service_tier",
  62. "store",
  63. "stream_options",
  64. "tool_choice",
  65. "top_logprobs",
  66. "user",
  67. "web_search_options",
  68. }
  69. logger = logging.get_logger(__name__)
  70. class ChatCompletionHandler(BaseHandler):
  71. """Handler for the `/v1/chat/completions` endpoint.
  72. Supports both streaming (SSE) and non-streaming (JSON) responses.
  73. """
  74. _valid_params_class = TransformersCompletionCreateParamsStreaming
  75. _unused_fields = UNUSED_CHAT_COMPLETION_FIELDS
  76. async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse:
  77. """Validate the request, load the model, and dispatch to streaming or non-streaming.
  78. Args:
  79. body (`dict`): The raw JSON request body (OpenAI chat completion format).
  80. request_id (`str`): Unique request identifier (from header or auto-generated).
  81. Returns:
  82. `StreamingResponse | JSONResponse`: SSE stream or JSON depending on ``body["stream"]``.
  83. """
  84. self._validate_request(body)
  85. model_id, model, processor = self._resolve_model(body)
  86. modality = self.model_manager.get_model_modality(model, processor=processor)
  87. use_cb = self.generation_state.use_continuous_batching(model, modality)
  88. logger.warning(f"[Request received] Model: {model_id}, CB: {use_cb}")
  89. gen_manager = self.generation_state.get_manager(model_id, use_cb=use_cb)
  90. processor_inputs = self.get_processor_inputs_from_messages(body["messages"], modality)
  91. inputs = processor.apply_chat_template(
  92. processor_inputs,
  93. add_generation_prompt=True,
  94. tools=body.get("tools"),
  95. return_tensors=None if use_cb else "pt",
  96. return_dict=True,
  97. tokenize=True,
  98. )
  99. if not use_cb:
  100. inputs = inputs.to(model.device)
  101. gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb)
  102. # TODO: remove when CB supports per-request generation config
  103. if use_cb:
  104. gen_manager.init_cb(model, gen_config)
  105. # Detect tool support for the loaded model
  106. # TODO: after tool_call start token, use constrained generation to:
  107. # 1. force generation to pick from the available tool names
  108. # 2. force generation to produce valid JSON matching the tool's parameter schema
  109. tool_format = detect_tool_format(model) if body.get("tools") else None
  110. streaming = body.get("stream")
  111. if streaming:
  112. return self._streaming(
  113. request_id,
  114. model,
  115. processor,
  116. model_id,
  117. inputs,
  118. gen_config,
  119. gen_manager=gen_manager,
  120. tool_format=tool_format,
  121. )
  122. else:
  123. return await self._non_streaming(
  124. request_id,
  125. model,
  126. processor,
  127. model_id,
  128. inputs,
  129. gen_config,
  130. gen_manager=gen_manager,
  131. tool_format=tool_format,
  132. )
  133. # ----- streaming -----
  134. def _streaming(
  135. self,
  136. request_id: str,
  137. model: "PreTrainedModel",
  138. processor: "ProcessorMixin | PreTrainedTokenizerFast",
  139. model_id: str,
  140. inputs: dict,
  141. gen_config: "GenerationConfig",
  142. gen_manager: BaseGenerateManager,
  143. tool_format: dict | None = None,
  144. ) -> StreamingResponse:
  145. """Stream tokens as SSE via DirectStreamer."""
  146. queue, streamer = gen_manager.generate_streaming(model, processor, inputs, gen_config, request_id=request_id)
  147. input_ids = inputs["input_ids"]
  148. # CB returns plain lists, regular path returns tensors
  149. input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1]
  150. parser = ToolCallParser(tool_format) if tool_format else None
  151. async def sse_gen() -> AsyncGenerator[str, None]:
  152. has_tool_calls = False
  153. try:
  154. yield self._build_chunk_sse(request_id, role="assistant", model=model_id)
  155. done = False
  156. while not done:
  157. text = await queue.get()
  158. batch = [text]
  159. try:
  160. while True:
  161. batch.append(queue.get_nowait())
  162. except asyncio.QueueEmpty:
  163. pass
  164. sse_parts: list[str] = []
  165. for text in batch:
  166. if text is None:
  167. done = True
  168. break
  169. if isinstance(text, _StreamError):
  170. sse_parts.append(f'data: {{"error": "{text.msg}"}}\n\n')
  171. yield "".join(sse_parts)
  172. return
  173. # Tool call parsing: None = normal text, CONSUMED = buffering, else = tool call dict
  174. chunk_kwargs = {"content": text}
  175. if parser is not None and (result := parser.feed(text)) is not None:
  176. if result is ToolCallParser.CONSUMED:
  177. continue
  178. has_tool_calls = True
  179. chunk_kwargs = {
  180. "tool_calls": [
  181. ChoiceDeltaToolCall(
  182. index=0,
  183. type="function",
  184. id=f"{request_id}_tool_call",
  185. function={"name": result["name"], "arguments": result["arguments"]},
  186. )
  187. ]
  188. }
  189. sse_parts.append(self._build_chunk_sse(request_id, model=model_id, **chunk_kwargs))
  190. if sse_parts:
  191. yield "".join(sse_parts)
  192. hit_max = gen_config.max_new_tokens is not None and streamer.total_tokens >= gen_config.max_new_tokens
  193. if has_tool_calls:
  194. finish_reason = "tool_calls"
  195. elif hit_max:
  196. finish_reason = "length"
  197. else:
  198. finish_reason = "stop"
  199. usage = CompletionUsage(
  200. prompt_tokens=input_len,
  201. completion_tokens=streamer.total_tokens,
  202. total_tokens=input_len + streamer.total_tokens,
  203. )
  204. yield self._build_chunk_sse(
  205. request_id,
  206. finish_reason=finish_reason,
  207. model=model_id,
  208. usage=usage,
  209. )
  210. except (GeneratorExit, asyncio.CancelledError):
  211. # Client disconnected — abort generation to free GPU.
  212. # Re-raise is mandatory: Python raises RuntimeError if GeneratorExit is swallowed.
  213. streamer.cancel()
  214. raise
  215. return StreamingResponse(sse_gen(), media_type="text/event-stream")
  216. # ----- non-streaming -----
  217. async def _non_streaming(
  218. self,
  219. request_id: str,
  220. model: "PreTrainedModel",
  221. processor: "ProcessorMixin | PreTrainedTokenizerFast",
  222. model_id: str,
  223. inputs: dict,
  224. gen_config: "GenerationConfig",
  225. gen_manager: BaseGenerateManager,
  226. tool_format: dict | None = None,
  227. ) -> JSONResponse:
  228. """Run generation and return a JSONResponse."""
  229. content, input_len, generated_ids = await gen_manager.generate_non_streaming(
  230. model, processor, inputs, gen_config, request_id=request_id
  231. )
  232. hit_max = gen_config.max_new_tokens is not None and len(generated_ids) >= gen_config.max_new_tokens
  233. completion_tokens = len(generated_ids)
  234. usage = CompletionUsage(
  235. prompt_tokens=input_len,
  236. completion_tokens=completion_tokens,
  237. total_tokens=input_len + completion_tokens,
  238. )
  239. # Parse tool calls from the generated text
  240. tool_calls = None
  241. if tool_format is not None:
  242. parsed = ToolCallParser.parse(content, tool_format)
  243. if parsed is not None:
  244. tool_calls = [
  245. ChatCompletionMessageToolCall(
  246. id=f"{request_id}_tool_call",
  247. type="function",
  248. function={"name": tc["name"], "arguments": tc["arguments"]},
  249. )
  250. for tc in parsed
  251. ]
  252. if tool_calls is not None:
  253. finish_reason = "tool_calls"
  254. elif hit_max:
  255. finish_reason = "length"
  256. else:
  257. finish_reason = "stop"
  258. return JSONResponse(
  259. self._build_completion(
  260. request_id,
  261. content,
  262. model_id,
  263. finish_reason=finish_reason,
  264. usage=usage,
  265. tool_calls=tool_calls,
  266. ),
  267. media_type="application/json",
  268. )
  269. # ----- helpers -----
  270. def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False):
  271. """Apply Chat Completions params (``max_tokens``, ``frequency_penalty``, ``logit_bias``,
  272. ``stop``) on top of the base generation config."""
  273. generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb)
  274. if body.get("max_tokens") is not None:
  275. generation_config.max_new_tokens = int(body["max_tokens"])
  276. if body.get("frequency_penalty") is not None:
  277. generation_config.repetition_penalty = 1.0 + float(body["frequency_penalty"])
  278. if body.get("logit_bias") is not None:
  279. generation_config.sequence_bias = {(int(k),): v for k, v in body["logit_bias"].items()}
  280. if body.get("stop") is not None:
  281. generation_config.stop_strings = body["stop"]
  282. return generation_config
  283. # ----- response builders -----
  284. def _build_completion(
  285. self,
  286. request_id: str,
  287. content: str,
  288. model_id: str,
  289. finish_reason: str,
  290. usage: CompletionUsage | None = None,
  291. tool_calls: list[dict] | None = None,
  292. ) -> dict:
  293. """Build a non-streaming ChatCompletion response dict.
  294. Args:
  295. request_id (`str`): Unique request identifier.
  296. content (`str`): The generated text.
  297. model_id (`str`): Model ID to include in the response.
  298. finish_reason (`str`): Why generation stopped (``"stop"``, ``"length"``, ``"tool_calls"``).
  299. usage (`CompletionUsage`, *optional*): Token usage statistics.
  300. tool_calls (`list[dict]`, *optional*): Parsed tool calls, if any.
  301. Returns:
  302. `dict`: Serialized ``ChatCompletion`` ready for JSON response.
  303. """
  304. message = ChatCompletionMessage(content=content, role="assistant", tool_calls=tool_calls)
  305. result = ChatCompletion(
  306. id=request_id,
  307. created=int(time.time()),
  308. object="chat.completion",
  309. model=model_id,
  310. choices=[
  311. Choice(
  312. index=0,
  313. message=message,
  314. finish_reason=finish_reason,
  315. )
  316. ],
  317. usage=usage,
  318. )
  319. return result.model_dump(exclude_none=True)
  320. def _build_chunk_sse(
  321. self,
  322. request_id: str = "",
  323. content: str | None = None,
  324. model: str | None = None,
  325. role: str | None = None,
  326. finish_reason: str | None = None,
  327. tool_calls: list | None = None,
  328. usage: CompletionUsage | None = None,
  329. ) -> str:
  330. """Build a streaming ``ChatCompletionChunk`` and format it as an SSE ``data:`` line.
  331. Args:
  332. request_id (`str`): Unique request identifier.
  333. content (`str`, *optional*): Text content delta.
  334. model (`str`, *optional*): Model ID.
  335. role (`str`, *optional*): Role (only sent in the first chunk).
  336. finish_reason (`str`, *optional*): Set on the final chunk.
  337. tool_calls (`list`, *optional*): Tool call deltas.
  338. usage (`CompletionUsage`, *optional*): Token usage (sent with the final chunk).
  339. Returns:
  340. `str`: A formatted SSE event string.
  341. """
  342. chunk = ChatCompletionChunk(
  343. id=request_id,
  344. created=int(time.time()),
  345. model=model,
  346. choices=[
  347. ChoiceChunk(
  348. delta=ChoiceDelta(content=content, role=role, tool_calls=tool_calls),
  349. index=0,
  350. finish_reason=finish_reason,
  351. )
  352. ],
  353. usage=usage,
  354. system_fingerprint="",
  355. object="chat.completion.chunk",
  356. )
  357. return self.chunk_to_sse(chunk)