| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410 |
- # Copyright 2026 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- Handler for the /v1/chat/completions endpoint.
- Supports streaming (SSE via DirectStreamer) and non-streaming (JSON) responses.
- """
import asyncio
import json
import time
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING
- from ...utils import logging
- from ...utils.import_utils import is_serve_available
- if is_serve_available():
- from fastapi.responses import JSONResponse, StreamingResponse
- from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall
- from openai.types.chat.chat_completion import Choice
- from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, ChoiceDeltaToolCall
- from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk
- from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming
- from openai.types.completion_usage import CompletionUsage
- from .utils import (
- BaseGenerateManager,
- BaseHandler,
- ToolCallParser,
- _StreamError,
- detect_tool_format,
- )
- if TYPE_CHECKING:
- from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin
class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False):
    """Chat completion request schema extended with Transformers-specific fields.

    Inherits every key of the OpenAI streaming chat-completion parameters and adds
    the optional keys declared below (``total=False`` makes them non-required).
    """

    # NOTE(review): presumably a serialized generation config supplied by the client;
    # confirm the expected format against the base handler that consumes it.
    generation_config: str
    # Optional integer seed supplied with the request.
    seed: int
# Request fields that exist in the OpenAI chat-completions schema but that this
# server does not implement yet. A request carrying any of them is rejected with
# an error rather than being silently served with the field ignored.
# NOTE: "stop" is deliberately absent from this set because it is supported
# (it is mapped to ``stop_strings`` on the generation config).
UNUSED_CHAT_COMPLETION_FIELDS = {
    # Function / tool-calling controls
    "function_call",
    "functions",
    "parallel_tool_calls",
    "tool_choice",
    # Output shape and modalities
    "audio",
    "modalities",
    "n",
    "prediction",
    "response_format",
    "stream_options",
    # Sampling, log-probabilities and length/reasoning controls
    "logprobs",
    "top_logprobs",
    "presence_penalty",
    "max_completion_tokens",
    "reasoning_effort",
    # Platform bookkeeping and hosted features
    "metadata",
    "service_tier",
    "store",
    "user",
    "web_search_options",
}
# Module-level logger obtained from the package's own logging utilities.
logger = logging.get_logger(__name__)
class ChatCompletionHandler(BaseHandler):
    """Handler for the `/v1/chat/completions` endpoint.

    Supports both streaming (SSE) and non-streaming (JSON) responses.
    """

    _valid_params_class = TransformersCompletionCreateParamsStreaming
    _unused_fields = UNUSED_CHAT_COMPLETION_FIELDS

    async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse:
        """Validate the request, load the model, and dispatch to streaming or non-streaming.

        Args:
            body (`dict`): The raw JSON request body (OpenAI chat completion format).
            request_id (`str`): Unique request identifier (from header or auto-generated).

        Returns:
            `StreamingResponse | JSONResponse`: SSE stream or JSON depending on ``body["stream"]``.
        """
        self._validate_request(body)

        model_id, model, processor = self._resolve_model(body)
        modality = self.model_manager.get_model_modality(model, processor=processor)
        use_cb = self.generation_state.use_continuous_batching(model, modality)
        logger.warning(f"[Request received] Model: {model_id}, CB: {use_cb}")
        gen_manager = self.generation_state.get_manager(model_id, use_cb=use_cb)

        processor_inputs = self.get_processor_inputs_from_messages(body["messages"], modality)
        inputs = processor.apply_chat_template(
            processor_inputs,
            add_generation_prompt=True,
            tools=body.get("tools"),
            # Continuous batching receives plain Python lists; the regular path gets tensors.
            return_tensors=None if use_cb else "pt",
            return_dict=True,
            tokenize=True,
        )
        if not use_cb:
            inputs = inputs.to(model.device)

        gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb)
        # TODO: remove when CB supports per-request generation config
        if use_cb:
            gen_manager.init_cb(model, gen_config)

        # Detect tool support for the loaded model (only when the request declares tools).
        # TODO: after tool_call start token, use constrained generation to:
        # 1. force generation to pick from the available tool names
        # 2. force generation to produce valid JSON matching the tool's parameter schema
        tool_format = detect_tool_format(model) if body.get("tools") else None

        if body.get("stream"):
            # Not awaited: `_streaming` returns the response object immediately and the
            # tokens are produced lazily by its async generator.
            return self._streaming(
                request_id,
                model,
                processor,
                model_id,
                inputs,
                gen_config,
                gen_manager=gen_manager,
                tool_format=tool_format,
            )
        return await self._non_streaming(
            request_id,
            model,
            processor,
            model_id,
            inputs,
            gen_config,
            gen_manager=gen_manager,
            tool_format=tool_format,
        )

    # ----- streaming -----

    def _streaming(
        self,
        request_id: str,
        model: "PreTrainedModel",
        processor: "ProcessorMixin | PreTrainedTokenizerFast",
        model_id: str,
        inputs: dict,
        gen_config: "GenerationConfig",
        gen_manager: BaseGenerateManager,
        tool_format: dict | None = None,
    ) -> StreamingResponse:
        """Start generation and stream the produced text as Server-Sent Events.

        Args:
            request_id (`str`): Unique request identifier, reused as the chunk id.
            model (`PreTrainedModel`): The model to generate with.
            processor (`ProcessorMixin | PreTrainedTokenizerFast`): Processor/tokenizer handed to the manager.
            model_id (`str`): Model ID reported in every chunk.
            inputs (`dict`): Templated model inputs; must contain ``"input_ids"``.
            gen_config (`GenerationConfig`): Generation configuration for this request.
            gen_manager (`BaseGenerateManager`): Manager that runs the generation.
            tool_format (`dict`, *optional*): Tool-call format; enables tool-call parsing when set.

        Returns:
            `StreamingResponse`: A ``text/event-stream`` response backed by an async generator.
        """
        queue, streamer = gen_manager.generate_streaming(model, processor, inputs, gen_config, request_id=request_id)

        input_ids = inputs["input_ids"]
        # CB returns plain lists, regular path returns tensors
        input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1]

        parser = ToolCallParser(tool_format) if tool_format else None

        async def sse_gen() -> AsyncGenerator[str, None]:
            # Number of tool calls emitted so far; doubles as the index of the next one.
            tool_calls_emitted = 0
            try:
                # The first chunk only announces the assistant role.
                yield self._build_chunk_sse(request_id, role="assistant", model=model_id)

                done = False
                while not done:
                    # Wait for one item, then drain whatever else is already queued so
                    # several deltas are written to the client in a single yield.
                    batch = [await queue.get()]
                    try:
                        while True:
                            batch.append(queue.get_nowait())
                    except asyncio.QueueEmpty:
                        pass

                    sse_parts: list[str] = []
                    for text in batch:
                        if text is None:
                            # `None` is the end-of-stream sentinel.
                            done = True
                            break
                        if isinstance(text, _StreamError):
                            # Serialize with json.dumps so quotes, backslashes and newlines in
                            # the message cannot produce invalid JSON or split the SSE frame.
                            payload = json.dumps({"error": str(text.msg)}, ensure_ascii=False)
                            sse_parts.append(f"data: {payload}\n\n")
                            yield "".join(sse_parts)
                            return

                        # Tool call parsing: None = normal text, CONSUMED = buffering, else = tool call dict
                        delta_kwargs = {"content": text}
                        if parser is not None and (result := parser.feed(text)) is not None:
                            if result is ToolCallParser.CONSUMED:
                                continue
                            # Each tool call gets its own index and id: clients merge deltas
                            # that share an index and match tool results by id.
                            delta_kwargs = {
                                "tool_calls": [
                                    ChoiceDeltaToolCall(
                                        index=tool_calls_emitted,
                                        type="function",
                                        id=self._tool_call_id(request_id, tool_calls_emitted),
                                        function={"name": result["name"], "arguments": result["arguments"]},
                                    )
                                ]
                            }
                            tool_calls_emitted += 1
                        sse_parts.append(self._build_chunk_sse(request_id, model=model_id, **delta_kwargs))

                    if sse_parts:
                        yield "".join(sse_parts)

                hit_max = gen_config.max_new_tokens is not None and streamer.total_tokens >= gen_config.max_new_tokens
                if tool_calls_emitted:
                    finish_reason = "tool_calls"
                elif hit_max:
                    finish_reason = "length"
                else:
                    finish_reason = "stop"

                usage = CompletionUsage(
                    prompt_tokens=input_len,
                    completion_tokens=streamer.total_tokens,
                    total_tokens=input_len + streamer.total_tokens,
                )
                # The final chunk carries the finish reason and the usage statistics.
                yield self._build_chunk_sse(
                    request_id,
                    finish_reason=finish_reason,
                    model=model_id,
                    usage=usage,
                )
            except (GeneratorExit, asyncio.CancelledError):
                # Client disconnected — abort generation to free GPU.
                # Re-raise is mandatory: Python raises RuntimeError if GeneratorExit is swallowed.
                streamer.cancel()
                raise

        return StreamingResponse(sse_gen(), media_type="text/event-stream")

    # ----- non-streaming -----

    async def _non_streaming(
        self,
        request_id: str,
        model: "PreTrainedModel",
        processor: "ProcessorMixin | PreTrainedTokenizerFast",
        model_id: str,
        inputs: dict,
        gen_config: "GenerationConfig",
        gen_manager: BaseGenerateManager,
        tool_format: dict | None = None,
    ) -> JSONResponse:
        """Run generation to completion and return a JSONResponse.

        Args:
            request_id (`str`): Unique request identifier, reused as the completion id.
            model (`PreTrainedModel`): The model to generate with.
            processor (`ProcessorMixin | PreTrainedTokenizerFast`): Processor/tokenizer handed to the manager.
            model_id (`str`): Model ID reported in the response.
            inputs (`dict`): Templated model inputs.
            gen_config (`GenerationConfig`): Generation configuration for this request.
            gen_manager (`BaseGenerateManager`): Manager that runs the generation.
            tool_format (`dict`, *optional*): Tool-call format; enables tool-call parsing when set.

        Returns:
            `JSONResponse`: The serialized chat completion.
        """
        content, input_len, generated_ids = await gen_manager.generate_non_streaming(
            model, processor, inputs, gen_config, request_id=request_id
        )

        completion_tokens = len(generated_ids)
        hit_max = gen_config.max_new_tokens is not None and completion_tokens >= gen_config.max_new_tokens
        usage = CompletionUsage(
            prompt_tokens=input_len,
            completion_tokens=completion_tokens,
            total_tokens=input_len + completion_tokens,
        )

        # Parse tool calls from the generated text
        tool_calls = None
        if tool_format is not None:
            parsed = ToolCallParser.parse(content, tool_format)
            # Truthiness check: an empty parse result must not be reported as tool calls.
            if parsed:
                tool_calls = [
                    ChatCompletionMessageToolCall(
                        # Unique id per call so clients can match tool results to calls.
                        id=self._tool_call_id(request_id, position),
                        type="function",
                        function={"name": tc["name"], "arguments": tc["arguments"]},
                    )
                    for position, tc in enumerate(parsed)
                ]

        if tool_calls is not None:
            finish_reason = "tool_calls"
        elif hit_max:
            finish_reason = "length"
        else:
            finish_reason = "stop"

        return JSONResponse(
            self._build_completion(
                request_id,
                content,
                model_id,
                finish_reason=finish_reason,
                usage=usage,
                tool_calls=tool_calls,
            ),
            media_type="application/json",
        )

    # ----- helpers -----

    @staticmethod
    def _tool_call_id(request_id: str, index: int) -> str:
        """Return the id of the ``index``-th tool call of a request.

        The first call keeps the historical ``"{request_id}_tool_call"`` id; later
        calls get a numeric suffix so that ids are unique within one response.

        Args:
            request_id (`str`): Unique request identifier.
            index (`int`): Zero-based position of the tool call in the response.

        Returns:
            `str`: The tool call id.
        """
        base = f"{request_id}_tool_call"
        return base if index == 0 else f"{base}_{index}"

    def _build_generation_config(
        self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False
    ) -> "GenerationConfig":
        """Apply Chat Completions params (``max_tokens``, ``frequency_penalty``, ``logit_bias``,
        ``stop``) on top of the base generation config.

        Args:
            body (`dict`): The raw JSON request body.
            model_generation_config (`GenerationConfig`): The model's default generation config.
            use_cb (`bool`, *optional*, defaults to `False`): Whether continuous batching is used.

        Returns:
            `GenerationConfig`: The config returned by the base handler with the overrides applied.
        """
        generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb)

        if body.get("max_tokens") is not None:
            generation_config.max_new_tokens = int(body["max_tokens"])
        if body.get("frequency_penalty") is not None:
            # NOTE(review): a frequency_penalty <= -1 yields a non-positive repetition_penalty;
            # confirm whether such values are rejected further down the generation stack.
            generation_config.repetition_penalty = 1.0 + float(body["frequency_penalty"])
        if body.get("logit_bias") is not None:
            # JSON object keys are strings; sequence_bias is keyed by tuples of int token ids.
            generation_config.sequence_bias = {(int(k),): v for k, v in body["logit_bias"].items()}
        if body.get("stop") is not None:
            generation_config.stop_strings = body["stop"]

        return generation_config

    # ----- response builders -----

    def _build_completion(
        self,
        request_id: str,
        content: str,
        model_id: str,
        finish_reason: str,
        usage: CompletionUsage | None = None,
        tool_calls: list[ChatCompletionMessageToolCall] | None = None,
    ) -> dict:
        """Build a non-streaming ChatCompletion response dict.

        Args:
            request_id (`str`): Unique request identifier.
            content (`str`): The generated text.
            model_id (`str`): Model ID to include in the response.
            finish_reason (`str`): Why generation stopped (``"stop"``, ``"length"``, ``"tool_calls"``).
            usage (`CompletionUsage`, *optional*): Token usage statistics.
            tool_calls (`list[ChatCompletionMessageToolCall]`, *optional*): Parsed tool calls, if any.

        Returns:
            `dict`: Serialized ``ChatCompletion`` ready for JSON response.
        """
        message = ChatCompletionMessage(content=content, role="assistant", tool_calls=tool_calls)
        result = ChatCompletion(
            id=request_id,
            created=int(time.time()),
            object="chat.completion",
            model=model_id,
            choices=[
                Choice(
                    index=0,
                    message=message,
                    finish_reason=finish_reason,
                )
            ],
            usage=usage,
        )
        # Unset optional fields are dropped from the payload instead of being sent as null.
        return result.model_dump(exclude_none=True)

    def _build_chunk_sse(
        self,
        request_id: str = "",
        content: str | None = None,
        model: str | None = None,
        role: str | None = None,
        finish_reason: str | None = None,
        tool_calls: list | None = None,
        usage: CompletionUsage | None = None,
    ) -> str:
        """Build a streaming ``ChatCompletionChunk`` and format it as an SSE ``data:`` line.

        Args:
            request_id (`str`): Unique request identifier.
            content (`str`, *optional*): Text content delta.
            model (`str`, *optional*): Model ID.
            role (`str`, *optional*): Role (only sent in the first chunk).
            finish_reason (`str`, *optional*): Set on the final chunk.
            tool_calls (`list`, *optional*): Tool call deltas.
            usage (`CompletionUsage`, *optional*): Token usage (sent with the final chunk).

        Returns:
            `str`: A formatted SSE event string.
        """
        chunk = ChatCompletionChunk(
            id=request_id,
            created=int(time.time()),
            model=model,
            choices=[
                ChoiceChunk(
                    delta=ChoiceDelta(content=content, role=role, tool_calls=tool_calls),
                    index=0,
                    finish_reason=finish_reason,
                )
            ],
            usage=usage,
            system_fingerprint="",
            object="chat.completion.chunk",
        )
        return self.chunk_to_sse(chunk)
|