| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564 |
- # Copyright 2026 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- Handler for the /v1/responses endpoint (OpenAI Responses API).
- Supports streaming (SSE) and non-streaming (JSON) responses.
- """
- import asyncio
- import time
- from collections.abc import AsyncGenerator
- from typing import TYPE_CHECKING
- from ...utils import logging
- from ...utils.import_utils import is_serve_available
- if is_serve_available():
- from fastapi import HTTPException
- from fastapi.responses import JSONResponse, StreamingResponse
- from openai.types.responses import (
- Response,
- ResponseCompletedEvent,
- ResponseContentPartAddedEvent,
- ResponseContentPartDoneEvent,
- ResponseCreatedEvent,
- ResponseError,
- ResponseErrorEvent,
- ResponseFailedEvent,
- ResponseFunctionCallArgumentsDoneEvent,
- ResponseFunctionToolCall,
- ResponseInProgressEvent,
- ResponseOutputItemAddedEvent,
- ResponseOutputItemDoneEvent,
- ResponseOutputMessage,
- ResponseOutputText,
- ResponseTextDeltaEvent,
- ResponseTextDoneEvent,
- )
- from openai.types.responses.response_create_params import ResponseCreateParamsStreaming
- from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails, ResponseUsage
- from .utils import (
- BaseGenerateManager,
- BaseHandler,
- ToolCallParser,
- _StreamError,
- detect_tool_format,
- )
- if TYPE_CHECKING:
- from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin
# Module-level logger obtained from the transformers logging utilities (``...utils.logging``).
logger = logging.get_logger(__name__)
class TransformersResponseCreateParamsStreaming(ResponseCreateParamsStreaming, total=False):
    """OpenAI ``ResponseCreateParamsStreaming`` extended with transformers-specific request fields.

    Used as ``ResponseHandler._valid_params_class``. ``total=False`` makes the keys declared here optional.
    """

    # NOTE(review): presumably a serialized GenerationConfig or a reference to one — confirm
    # against BaseHandler._build_generation_config.
    generation_config: str
    # NOTE(review): presumably the RNG seed used for generation — confirm in BaseHandler.
    seed: int
# Responses API request fields that ``ResponseHandler`` registers as ``_unused_fields``.
# The server does not act on them; how BaseHandler reports them (warning, silent drop, ...)
# is decided outside this module.
UNUSED_RESPONSE_FIELDS = set(
    (
        "background",
        "include",
        "max_tool_calls",
        "previous_response_id",
        "prompt",
        "reasoning",
        "service_tier",
        "store",
        "text",
        "tool_choice",
        "top_logprobs",
        "truncation",
        "user",
    )
)
class ResponseHandler(BaseHandler):
    """Handler for the ``/v1/responses`` endpoint.

    Validates the request, resolves the model, converts the Responses API ``input`` into chat
    messages, tokenizes them with the processor's chat template and dispatches to either a
    streaming (SSE) or a non-streaming (single JSON) reply.
    """

    _valid_params_class = TransformersResponseCreateParamsStreaming
    _unused_fields = UNUSED_RESPONSE_FIELDS

    async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse:
        """Validate, load model, dispatch to streaming or non-streaming.

        Args:
            body (`dict`): The raw JSON request body (OpenAI Responses API format).
            request_id (`str`): Unique request identifier (from header or auto-generated).

        Returns:
            `StreamingResponse | JSONResponse`: SSE stream or JSON depending on ``body["stream"]``.
        """
        self._validate_request(body)

        model_id, model, processor = self._resolve_model(body)
        modality = self.model_manager.get_model_modality(model, processor=processor)
        use_cb = self.generation_state.use_continuous_batching(model, modality)
        logger.warning(f"[Request received] Model: {model_id}, CB: {use_cb}")
        gen_manager = self.generation_state.get_manager(model_id, use_cb=use_cb)

        # Two-step input conversion (chat completions skips step 1 since messages are already standard):
        # 1. Normalize Responses API input (string/list/dict + instructions) → standard messages list
        # 2. Transform message content for the HF processor (VLM image handling, text joining, etc.)
        messages = self._input_to_messages(body)
        processor_inputs = self.get_processor_inputs_from_messages(messages, modality)

        inputs = processor.apply_chat_template(
            processor_inputs,
            add_generation_prompt=True,
            tools=body.get("tools"),
            # Continuous batching consumes plain lists; the regular path needs tensors on the model device
            return_tensors=None if use_cb else "pt",
            return_dict=True,
            tokenize=True,
        )
        if not use_cb:
            inputs = inputs.to(model.device)

        gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb)
        # TODO: remove when CB supports per-request generation config
        if use_cb:
            gen_manager.init_cb(model, gen_config)

        tool_format = detect_tool_format(model) if body.get("tools") else None

        # NOTE(review): streaming is the default when ``stream`` is absent, whereas the OpenAI API
        # defaults to non-streaming — confirm this is intended.
        streaming = body.get("stream", True)
        if streaming:
            return self._streaming(
                request_id,
                model,
                processor,
                model_id,
                body,
                inputs,
                gen_config,
                gen_manager=gen_manager,
                tool_format=tool_format,
            )
        return await self._non_streaming(
            request_id,
            model,
            processor,
            model_id,
            body,
            inputs,
            gen_config,
            gen_manager=gen_manager,
            tool_format=tool_format,
        )

    # ----- input conversion -----

    @staticmethod
    def _input_to_messages(body: dict) -> list[dict]:
        """Convert the Responses API ``input`` field to a list of chat messages.

        Handles string, list, and dict inputs. If ``instructions`` is provided, it is prepended as
        a system message, or replaces the content of a leading system message. The request body
        itself is never mutated: a replaced system message is a copy.

        Args:
            body (`dict`): The raw request body containing ``input`` and optionally ``instructions``.

        Returns:
            `list[dict]`: Standard OpenAI-format chat messages.

        Raises:
            `HTTPException`: 422 if ``input`` is missing or not a string, list, or dict.
        """
        inp = body.get("input")
        instructions = body.get("instructions")
        system_prefix = [{"role": "system", "content": instructions}] if instructions else []

        if isinstance(inp, str):
            return [*system_prefix, {"role": "user", "content": inp}]
        if isinstance(inp, dict):
            return [*system_prefix, inp]
        if isinstance(inp, list):
            if not instructions:
                return list(inp)
            first = inp[0] if inp else None
            # Items such as function_call_output carry no "role", so look it up defensively
            if isinstance(first, dict) and first.get("role") == "system":
                # Override the existing system prompt on a copy, leaving the caller's dict intact
                return [{**first, "content": instructions}, *inp[1:]]
            return [*system_prefix, *inp]
        raise HTTPException(status_code=422, detail="'input' must be a string, list, or dict")

    # ----- streaming -----

    def _streaming(
        self,
        request_id: str,
        model: "PreTrainedModel",
        processor: "ProcessorMixin | PreTrainedTokenizerFast",
        model_id: str,
        body: dict,
        inputs: dict,
        gen_config: "GenerationConfig",
        gen_manager: BaseGenerateManager,
        tool_format: dict | None = None,
    ) -> StreamingResponse:
        """Generate a streaming Responses API reply (SSE).

        Event order on success: ``response.created`` → ``response.in_progress`` →
        ``response.output_item.added`` (message, index 0) → ``response.content_part.added`` →
        ``response.output_text.delta`` events, interleaved with function-call item events
        (indices 1, 2, ...) when tool calls are parsed → ``response.output_text.done`` →
        ``response.content_part.done`` → ``response.output_item.done`` (message) →
        ``response.completed``. A generation error instead ends the stream with ``error`` +
        ``response.failed``. If the client disconnects, generation is cancelled.
        """
        queue, streamer = gen_manager.generate_streaming(model, processor, inputs, gen_config, request_id=request_id)

        input_ids = inputs["input_ids"]
        # CB returns plain lists, regular path returns tensors
        input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1]

        parser = ToolCallParser(tool_format) if tool_format else None
        msg_id = f"msg_{request_id}"
        response_base = self._response_fields(request_id, model_id, body, time.time())

        async def event_stream() -> AsyncGenerator[str, None]:
            seq = 0

            def emit(event_cls, **fields) -> str:
                """Build ``event_cls`` with the next sequence number and serialize it as SSE."""
                nonlocal seq
                chunk = self.chunk_to_sse(event_cls(sequence_number=seq, **fields))
                seq += 1
                return chunk

            try:
                # 1. Created + In progress
                yield emit(
                    ResponseCreatedEvent,
                    type="response.created",
                    response=Response(**response_base, status="queued", output=[]),
                )
                yield emit(
                    ResponseInProgressEvent,
                    type="response.in_progress",
                    response=Response(**response_base, status="in_progress", output=[]),
                )

                # 2. Output item added — the assistant message always sits at output index 0
                yield emit(
                    ResponseOutputItemAddedEvent,
                    type="response.output_item.added",
                    output_index=0,
                    item=ResponseOutputMessage(
                        id=msg_id,
                        type="message",
                        status="in_progress",
                        role="assistant",
                        content=[],
                    ),
                )

                # 3. Content part added
                yield emit(
                    ResponseContentPartAddedEvent,
                    type="response.content_part.added",
                    item_id=msg_id,
                    output_index=0,
                    content_index=0,
                    part=ResponseOutputText(type="output_text", text="", annotations=[]),
                )

                # 4. Stream tokens — drain the queue so each HTTP write carries a batch of tokens
                text_pieces: list[str] = []
                tool_calls: list[ResponseFunctionToolCall] = []
                done = False
                while not done:
                    batch = [await queue.get()]
                    try:
                        while True:
                            batch.append(queue.get_nowait())
                    except asyncio.QueueEmpty:
                        pass

                    sse_parts: list[str] = []
                    for text in batch:
                        if text is None:  # end-of-generation sentinel
                            done = True
                            break

                        if isinstance(text, _StreamError):
                            logger.error(f"Exception in response generation: {text.msg}")
                            sse_parts.append(emit(ResponseErrorEvent, type="error", message=text.msg))
                            sse_parts.append(
                                emit(
                                    ResponseFailedEvent,
                                    type="response.failed",
                                    response=Response(
                                        **response_base,
                                        status="failed",
                                        output=[],
                                        error=ResponseError(code="server_error", message=text.msg),
                                    ),
                                )
                            )
                            yield "".join(sse_parts)
                            return

                        # Tool call parsing: a non-None result means the token belongs to a tool call
                        # (either swallowed as CONSUMED or completing a call), so it is not emitted as text.
                        if parser is not None and (result := parser.feed(text)) is not None:
                            if result is not ToolCallParser.CONSUMED:
                                name = result["name"]
                                arguments = result["arguments"]
                                tc_item = self._tool_call_item(request_id, len(tool_calls), name, arguments)
                                tool_calls.append(tc_item)
                                # Tool calls follow the message: output indices 1, 2, ...
                                tc_index = len(tool_calls)
                                sse_parts.append(
                                    emit(
                                        ResponseOutputItemAddedEvent,
                                        type="response.output_item.added",
                                        output_index=tc_index,
                                        item=tc_item,
                                    )
                                )
                                sse_parts.append(
                                    emit(
                                        ResponseFunctionCallArgumentsDoneEvent,
                                        type="response.function_call_arguments.done",
                                        item_id=tc_item.id,
                                        output_index=tc_index,
                                        arguments=arguments,
                                        name=name,
                                    )
                                )
                                sse_parts.append(
                                    emit(
                                        ResponseOutputItemDoneEvent,
                                        type="response.output_item.done",
                                        output_index=tc_index,
                                        item=tc_item,
                                    )
                                )
                            continue

                        text_pieces.append(text)
                        sse_parts.append(
                            emit(
                                ResponseTextDeltaEvent,
                                type="response.output_text.delta",
                                item_id=msg_id,
                                output_index=0,
                                content_index=0,
                                delta=text,
                                logprobs=[],
                            )
                        )

                    if sse_parts:
                        yield "".join(sse_parts)

                # 5. Close text output, content part and message item
                full_text = "".join(text_pieces)
                output_text_part = ResponseOutputText(type="output_text", text=full_text, annotations=[])
                yield emit(
                    ResponseTextDoneEvent,
                    type="response.output_text.done",
                    item_id=msg_id,
                    output_index=0,
                    content_index=0,
                    text=full_text,
                    logprobs=[],
                )
                yield emit(
                    ResponseContentPartDoneEvent,
                    type="response.content_part.done",
                    item_id=msg_id,
                    output_index=0,
                    content_index=0,
                    part=output_text_part,
                )
                msg_item = ResponseOutputMessage(
                    id=msg_id,
                    type="message",
                    status="completed",
                    role="assistant",
                    content=[output_text_part],
                    annotations=[],
                )
                yield emit(
                    ResponseOutputItemDoneEvent,
                    type="response.output_item.done",
                    output_index=0,
                    item=msg_item,
                )

                # 6. Completed
                usage = compute_usage(input_len, streamer.total_tokens)
                yield emit(
                    ResponseCompletedEvent,
                    type="response.completed",
                    response=Response(
                        **response_base, status="completed", output=[msg_item, *tool_calls], usage=usage
                    ),
                )
            except (GeneratorExit, asyncio.CancelledError):
                # Client disconnected — abort generation to free GPU.
                # Re-raise is mandatory: Python raises RuntimeError if GeneratorExit is swallowed.
                streamer.cancel()
                raise

        return StreamingResponse(event_stream(), media_type="text/event-stream")

    # ----- non-streaming -----

    async def _non_streaming(
        self,
        request_id: str,
        model: "PreTrainedModel",
        processor: "ProcessorMixin | PreTrainedTokenizerFast",
        model_id: str,
        body: dict,
        inputs: dict,
        gen_config: "GenerationConfig",
        gen_manager: BaseGenerateManager,
        tool_format: dict | None = None,
    ) -> JSONResponse:
        """Generate a non-streaming Responses API reply (single JSON).

        The output list holds the assistant message first, followed by one ``function_call``
        item per tool call parsed from the generated text.
        """
        full_text, input_len, generated_ids = await gen_manager.generate_non_streaming(
            model, processor, inputs, gen_config, request_id=request_id
        )

        output_items = [
            ResponseOutputMessage(
                id=f"msg_{request_id}",
                type="message",
                status="completed",
                role="assistant",
                # NOTE(review): unlike the streaming path, the full generated text is kept here, so it
                # presumably still includes any tool-call markup — confirm whether it should be stripped.
                content=[ResponseOutputText(type="output_text", text=full_text, annotations=[])],
                annotations=[],
            )
        ]

        # Parse tool calls from the generated text
        if tool_format is not None:
            parsed_calls = ToolCallParser.parse(full_text, tool_format)
            for index, tc in enumerate(parsed_calls or []):
                output_items.append(self._tool_call_item(request_id, index, tc["name"], tc["arguments"]))

        response = Response(
            **self._response_fields(request_id, model_id, body, time.time()),
            status="completed",
            output=output_items,
            usage=compute_usage(input_len, len(generated_ids)),
        )
        return JSONResponse(response.model_dump(exclude_none=True))

    # ----- helpers -----

    @staticmethod
    def _response_fields(request_id: str, model_id: str, body: dict, created_at: float) -> dict:
        """Return the ``Response`` constructor fields shared by every response object of one request."""
        return {
            "id": f"resp_{request_id}",
            "created_at": created_at,
            "model": model_id,
            "object": "response",
            # Required by the pydantic model. ``tools`` and ``tool_choice`` are fixed placeholders;
            # only ``parallel_tool_calls`` is taken from the request.
            "tools": [],
            "parallel_tool_calls": body.get("parallel_tool_calls", False),
            "tool_choice": "auto",
        }

    @staticmethod
    def _tool_call_item(request_id: str, index: int, name: str, arguments: str) -> ResponseFunctionToolCall:
        """Build a completed ``function_call`` output item.

        ``index`` is the position of the call within the request and is appended to the id so that
        several calls in one reply get distinct ``id`` / ``call_id`` values.
        """
        tc_id = f"{request_id}_tool_call_{index}"
        return ResponseFunctionToolCall(
            id=tc_id,
            call_id=tc_id,
            type="function_call",
            name=name,
            arguments=arguments,
            status="completed",
        )

    def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False):
        """Apply Responses API params (``max_output_tokens``) on top of the base generation config."""
        generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb)
        if body.get("max_output_tokens") is not None:
            generation_config.max_new_tokens = int(body["max_output_tokens"])
        return generation_config
def compute_usage(input_tokens: int, output_tokens: int) -> ResponseUsage:
    """Package token counts as a Responses API ``ResponseUsage``.

    The total is the sum of prompt and generated tokens. Cached-token and reasoning-token
    details are always reported as 0.

    Args:
        input_tokens (`int`): Number of prompt tokens.
        output_tokens (`int`): Number of generated tokens.

    Returns:
        `ResponseUsage`: Usage statistics with zero-filled detail fields.
    """
    grand_total = input_tokens + output_tokens
    zero_cached = InputTokensDetails(cached_tokens=0)
    zero_reasoning = OutputTokensDetails(reasoning_tokens=0)
    return ResponseUsage(
        input_tokens=input_tokens,
        input_tokens_details=zero_cached,
        output_tokens=output_tokens,
        output_tokens_details=zero_reasoning,
        total_tokens=grand_total,
    )
|