response.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. # Copyright 2026 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Handler for the /v1/responses endpoint (OpenAI Responses API).
  16. Supports streaming (SSE) and non-streaming (JSON) responses.
  17. """
  18. import asyncio
  19. import time
  20. from collections.abc import AsyncGenerator
  21. from typing import TYPE_CHECKING
  22. from ...utils import logging
  23. from ...utils.import_utils import is_serve_available
  24. if is_serve_available():
  25. from fastapi import HTTPException
  26. from fastapi.responses import JSONResponse, StreamingResponse
  27. from openai.types.responses import (
  28. Response,
  29. ResponseCompletedEvent,
  30. ResponseContentPartAddedEvent,
  31. ResponseContentPartDoneEvent,
  32. ResponseCreatedEvent,
  33. ResponseError,
  34. ResponseErrorEvent,
  35. ResponseFailedEvent,
  36. ResponseFunctionCallArgumentsDoneEvent,
  37. ResponseFunctionToolCall,
  38. ResponseInProgressEvent,
  39. ResponseOutputItemAddedEvent,
  40. ResponseOutputItemDoneEvent,
  41. ResponseOutputMessage,
  42. ResponseOutputText,
  43. ResponseTextDeltaEvent,
  44. ResponseTextDoneEvent,
  45. )
  46. from openai.types.responses.response_create_params import ResponseCreateParamsStreaming
  47. from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails, ResponseUsage
  48. from .utils import (
  49. BaseGenerateManager,
  50. BaseHandler,
  51. ToolCallParser,
  52. _StreamError,
  53. detect_tool_format,
  54. )
  55. if TYPE_CHECKING:
  56. from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin
  57. logger = logging.get_logger(__name__)
  58. class TransformersResponseCreateParamsStreaming(ResponseCreateParamsStreaming, total=False):
  59. generation_config: str
  60. seed: int
  61. UNUSED_RESPONSE_FIELDS = {
  62. "background",
  63. "include",
  64. "max_tool_calls",
  65. "previous_response_id",
  66. "prompt",
  67. "reasoning",
  68. "service_tier",
  69. "store",
  70. "text",
  71. "tool_choice",
  72. "top_logprobs",
  73. "truncation",
  74. "user",
  75. }
  76. class ResponseHandler(BaseHandler):
  77. """Handler for the ``/v1/responses`` endpoint."""
  78. _valid_params_class = TransformersResponseCreateParamsStreaming
  79. _unused_fields = UNUSED_RESPONSE_FIELDS
  80. async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse:
  81. """Validate, load model, dispatch to streaming or non-streaming.
  82. Args:
  83. body (`dict`): The raw JSON request body (OpenAI Responses API format).
  84. request_id (`str`): Unique request identifier (from header or auto-generated).
  85. Returns:
  86. `StreamingResponse | JSONResponse`: SSE stream or JSON depending on ``body["stream"]``.
  87. """
  88. self._validate_request(body)
  89. model_id, model, processor = self._resolve_model(body)
  90. modality = self.model_manager.get_model_modality(model, processor=processor)
  91. use_cb = self.generation_state.use_continuous_batching(model, modality)
  92. logger.warning(f"[Request received] Model: {model_id}, CB: {use_cb}")
  93. gen_manager = self.generation_state.get_manager(model_id, use_cb=use_cb)
  94. # Two-step input conversion (chat completions skips step 1 since messages are already standard):
  95. # 1. Normalize Responses API input (string/list/dict + instructions) → standard messages list
  96. # 2. Transform message content for the HF processor (VLM image handling, text joining, etc.)
  97. messages = self._input_to_messages(body)
  98. processor_inputs = self.get_processor_inputs_from_messages(messages, modality)
  99. inputs = processor.apply_chat_template(
  100. processor_inputs,
  101. add_generation_prompt=True,
  102. tools=body.get("tools"),
  103. return_tensors=None if use_cb else "pt",
  104. return_dict=True,
  105. tokenize=True,
  106. )
  107. if not use_cb:
  108. inputs = inputs.to(model.device)
  109. gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb)
  110. # TODO: remove when CB supports per-request generation config
  111. if use_cb:
  112. gen_manager.init_cb(model, gen_config)
  113. tool_format = detect_tool_format(model) if body.get("tools") else None
  114. streaming = body.get("stream", True)
  115. if streaming:
  116. return self._streaming(
  117. request_id,
  118. model,
  119. processor,
  120. model_id,
  121. body,
  122. inputs,
  123. gen_config,
  124. gen_manager=gen_manager,
  125. tool_format=tool_format,
  126. )
  127. else:
  128. return await self._non_streaming(
  129. request_id,
  130. model,
  131. processor,
  132. model_id,
  133. body,
  134. inputs,
  135. gen_config,
  136. gen_manager=gen_manager,
  137. tool_format=tool_format,
  138. )
  139. # ----- input conversion -----
  140. @staticmethod
  141. def _input_to_messages(body: dict) -> list[dict]:
  142. """Convert the Responses API ``input`` field to a list of chat messages.
  143. Handles string, list, and dict inputs. If ``instructions`` is provided, it is
  144. prepended as a system message (or replaces an existing one).
  145. Args:
  146. body (`dict`): The raw request body containing ``input`` and optionally ``instructions``.
  147. Returns:
  148. `list[dict]`: Standard OpenAI-format chat messages.
  149. """
  150. inp = body["input"]
  151. instructions = body.get("instructions")
  152. if isinstance(inp, str):
  153. messages = [{"role": "system", "content": instructions}] if instructions else []
  154. messages.append({"role": "user", "content": inp})
  155. elif isinstance(inp, list):
  156. if instructions:
  157. if inp[0]["role"] != "system":
  158. messages = [{"role": "system", "content": instructions}, *inp]
  159. else:
  160. messages = list(inp)
  161. messages[0]["content"] = instructions
  162. else:
  163. messages = inp
  164. elif isinstance(inp, dict):
  165. messages = [{"role": "system", "content": instructions}] if instructions else []
  166. messages.append(inp)
  167. else:
  168. raise HTTPException(status_code=422, detail="'input' must be a string, list, or dict")
  169. return messages
  170. # ----- streaming -----
  171. def _streaming(
  172. self,
  173. request_id: str,
  174. model: "PreTrainedModel",
  175. processor: "ProcessorMixin | PreTrainedTokenizerFast",
  176. model_id: str,
  177. body: dict,
  178. inputs: dict,
  179. gen_config: "GenerationConfig",
  180. gen_manager: BaseGenerateManager,
  181. tool_format: dict | None = None,
  182. ) -> StreamingResponse:
  183. """Generate a streaming Responses API reply (SSE) using DirectStreamer."""
  184. queue, streamer = gen_manager.generate_streaming(model, processor, inputs, gen_config, request_id=request_id)
  185. input_ids = inputs["input_ids"]
  186. # CB returns plain lists, regular path returns tensors
  187. input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1]
  188. parser = ToolCallParser(tool_format) if tool_format else None
  189. seq = 0
  190. output_index = 0
  191. created_at = time.time()
  192. resp_id = f"resp_{request_id}"
  193. msg_id = f"msg_{request_id}"
  194. response_base = {
  195. "id": resp_id,
  196. "created_at": created_at,
  197. "model": model_id,
  198. "object": "response",
  199. # Required by pydantic but not used — echo request config back
  200. "tools": [],
  201. "parallel_tool_calls": body.get("parallel_tool_calls", False),
  202. "tool_choice": "auto",
  203. }
  204. async def event_stream() -> AsyncGenerator[str, None]:
  205. nonlocal seq, output_index
  206. try:
  207. # 1. Created + In progress
  208. yield self.chunk_to_sse(
  209. ResponseCreatedEvent(
  210. type="response.created",
  211. sequence_number=seq,
  212. response=Response(**response_base, status="queued", output=[]),
  213. )
  214. )
  215. seq += 1
  216. yield self.chunk_to_sse(
  217. ResponseInProgressEvent(
  218. type="response.in_progress",
  219. sequence_number=seq,
  220. response=Response(**response_base, status="in_progress", output=[]),
  221. )
  222. )
  223. seq += 1
  224. # 2. Output item added (message)
  225. yield self.chunk_to_sse(
  226. ResponseOutputItemAddedEvent(
  227. type="response.output_item.added",
  228. sequence_number=seq,
  229. output_index=output_index,
  230. item=ResponseOutputMessage(
  231. id=msg_id,
  232. type="message",
  233. status="in_progress",
  234. role="assistant",
  235. content=[],
  236. ),
  237. )
  238. )
  239. seq += 1
  240. # 3. Content part added
  241. yield self.chunk_to_sse(
  242. ResponseContentPartAddedEvent(
  243. type="response.content_part.added",
  244. item_id=msg_id,
  245. sequence_number=seq,
  246. output_index=output_index,
  247. content_index=0,
  248. part=ResponseOutputText(type="output_text", text="", annotations=[]),
  249. )
  250. )
  251. seq += 1
  252. # 4. Stream tokens — drain queue to batch HTTP writes
  253. full_text = ""
  254. tool_calls = []
  255. done = False
  256. while not done:
  257. text = await queue.get()
  258. # Drain all available tokens for one batched HTTP write
  259. batch = [text]
  260. try:
  261. while True:
  262. batch.append(queue.get_nowait())
  263. except asyncio.QueueEmpty:
  264. pass
  265. sse_parts: list[str] = []
  266. for text in batch:
  267. if text is None:
  268. done = True
  269. break
  270. if isinstance(text, _StreamError):
  271. logger.error(f"Exception in response generation: {text.msg}")
  272. sse_parts.append(
  273. self.chunk_to_sse(
  274. ResponseErrorEvent(type="error", sequence_number=seq, message=text.msg)
  275. )
  276. )
  277. seq += 1
  278. sse_parts.append(
  279. self.chunk_to_sse(
  280. ResponseFailedEvent(
  281. type="response.failed",
  282. sequence_number=seq,
  283. response=Response(
  284. **response_base,
  285. status="failed",
  286. output=[],
  287. error=ResponseError(code="server_error", message=text.msg),
  288. ),
  289. )
  290. )
  291. )
  292. yield "".join(sse_parts)
  293. return
  294. # Tool call parsing
  295. if parser is not None and (result := parser.feed(text)) is not None:
  296. if result is not ToolCallParser.CONSUMED:
  297. tc_id = f"{request_id}_tool_call"
  298. name = result["name"]
  299. arguments = result["arguments"]
  300. tc_item = ResponseFunctionToolCall(
  301. id=tc_id,
  302. call_id=tc_id,
  303. type="function_call",
  304. name=name,
  305. arguments=arguments,
  306. status="completed",
  307. )
  308. tool_calls.append(tc_item)
  309. output_index += 1
  310. sse_parts.append(
  311. self.chunk_to_sse(
  312. ResponseOutputItemAddedEvent(
  313. type="response.output_item.added",
  314. sequence_number=seq,
  315. output_index=output_index,
  316. item=tc_item,
  317. )
  318. )
  319. )
  320. seq += 1
  321. sse_parts.append(
  322. self.chunk_to_sse(
  323. ResponseFunctionCallArgumentsDoneEvent(
  324. type="response.function_call_arguments.done",
  325. sequence_number=seq,
  326. item_id=tc_id,
  327. output_index=output_index,
  328. arguments=arguments,
  329. name=name,
  330. )
  331. )
  332. )
  333. seq += 1
  334. sse_parts.append(
  335. self.chunk_to_sse(
  336. ResponseOutputItemDoneEvent(
  337. type="response.output_item.done",
  338. sequence_number=seq,
  339. output_index=output_index,
  340. item=tc_item,
  341. )
  342. )
  343. )
  344. seq += 1
  345. continue
  346. full_text += text
  347. sse_parts.append(
  348. self.chunk_to_sse(
  349. ResponseTextDeltaEvent(
  350. type="response.output_text.delta",
  351. item_id=msg_id,
  352. sequence_number=seq,
  353. output_index=0,
  354. content_index=0,
  355. delta=text,
  356. logprobs=[],
  357. )
  358. )
  359. )
  360. seq += 1
  361. if sse_parts:
  362. yield "".join(sse_parts)
  363. # 5. Close text output
  364. output_text_part = ResponseOutputText(type="output_text", text=full_text, annotations=[])
  365. yield self.chunk_to_sse(
  366. ResponseTextDoneEvent(
  367. type="response.output_text.done",
  368. item_id=msg_id,
  369. sequence_number=seq,
  370. output_index=0,
  371. content_index=0,
  372. text=full_text,
  373. logprobs=[],
  374. )
  375. )
  376. seq += 1
  377. yield self.chunk_to_sse(
  378. ResponseContentPartDoneEvent(
  379. type="response.content_part.done",
  380. item_id=msg_id,
  381. sequence_number=seq,
  382. output_index=0,
  383. content_index=0,
  384. part=output_text_part,
  385. )
  386. )
  387. seq += 1
  388. msg_item = ResponseOutputMessage(
  389. id=msg_id,
  390. type="message",
  391. status="completed",
  392. role="assistant",
  393. content=[output_text_part],
  394. annotations=[],
  395. )
  396. yield self.chunk_to_sse(
  397. ResponseOutputItemDoneEvent(
  398. type="response.output_item.done",
  399. sequence_number=seq,
  400. output_index=0,
  401. item=msg_item,
  402. )
  403. )
  404. seq += 1
  405. # 6. Completed
  406. all_output = [msg_item] + list(tool_calls)
  407. usage = compute_usage(input_len, streamer.total_tokens)
  408. yield self.chunk_to_sse(
  409. ResponseCompletedEvent(
  410. type="response.completed",
  411. sequence_number=seq,
  412. response=Response(**response_base, status="completed", output=all_output, usage=usage),
  413. )
  414. )
  415. seq += 1
  416. except (GeneratorExit, asyncio.CancelledError):
  417. # Client disconnected — abort generation to free GPU.
  418. # Re-raise is mandatory: Python raises RuntimeError if GeneratorExit is swallowed.
  419. streamer.cancel()
  420. raise
  421. return StreamingResponse(event_stream(), media_type="text/event-stream")
  422. # ----- non-streaming -----
  423. async def _non_streaming(
  424. self,
  425. request_id: str,
  426. model: "PreTrainedModel",
  427. processor: "ProcessorMixin | PreTrainedTokenizerFast",
  428. model_id: str,
  429. body: dict,
  430. inputs: dict,
  431. gen_config: "GenerationConfig",
  432. gen_manager: BaseGenerateManager,
  433. tool_format: dict | None = None,
  434. ) -> JSONResponse:
  435. """Generate a non-streaming Responses API reply (single JSON)."""
  436. full_text, input_len, generated_ids = await gen_manager.generate_non_streaming(
  437. model, processor, inputs, gen_config, request_id=request_id
  438. )
  439. output_items = [
  440. ResponseOutputMessage(
  441. id=f"msg_{request_id}",
  442. type="message",
  443. status="completed",
  444. role="assistant",
  445. content=[ResponseOutputText(type="output_text", text=full_text, annotations=[])],
  446. annotations=[],
  447. )
  448. ]
  449. # Parse tool calls from the generated text
  450. if tool_format is not None:
  451. parsed_calls = ToolCallParser.parse(full_text, tool_format)
  452. if parsed_calls is not None:
  453. for i, tc in enumerate(parsed_calls):
  454. tc_id = f"{request_id}_tool_call"
  455. output_items.append(
  456. ResponseFunctionToolCall(
  457. id=tc_id,
  458. call_id=tc_id,
  459. type="function_call",
  460. name=tc["name"],
  461. arguments=tc["arguments"],
  462. status="completed",
  463. )
  464. )
  465. usage = compute_usage(input_len, len(generated_ids))
  466. response = Response(
  467. id=f"resp_{request_id}",
  468. created_at=time.time(),
  469. status="completed",
  470. model=model_id,
  471. output=output_items,
  472. object="response",
  473. usage=usage,
  474. # Required by pydantic but not used — echo request config back
  475. tools=[],
  476. parallel_tool_calls=body.get("parallel_tool_calls", False),
  477. tool_choice="auto",
  478. )
  479. return JSONResponse(response.model_dump(exclude_none=True))
  480. # ----- helpers -----
  481. def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False):
  482. """Apply Responses API params (``max_output_tokens``) on top of the base generation config."""
  483. generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb)
  484. if body.get("max_output_tokens") is not None:
  485. generation_config.max_new_tokens = int(body["max_output_tokens"])
  486. return generation_config
  487. def compute_usage(input_tokens: int, output_tokens: int) -> ResponseUsage:
  488. """Build a ``ResponseUsage`` object for a Responses API reply.
  489. Args:
  490. input_tokens (`int`): Number of prompt tokens.
  491. output_tokens (`int`): Number of generated tokens.
  492. Returns:
  493. `ResponseUsage`: Usage statistics with zero-filled detail fields.
  494. """
  495. return ResponseUsage(
  496. input_tokens=input_tokens,
  497. output_tokens=output_tokens,
  498. total_tokens=input_tokens + output_tokens,
  499. input_tokens_details=InputTokensDetails(cached_tokens=0),
  500. output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
  501. )