| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956 |
- # Copyright 2026 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- Shared types, constants, and utilities for the serving layer.
- """
- import asyncio
- import base64
- import copy
- import enum
- import json
- import re
- import tempfile
- import threading
- from abc import ABC, abstractmethod
- from collections.abc import Callable
- from concurrent.futures import Future
- from io import BytesIO
- from queue import Queue
- from typing import TYPE_CHECKING
- from transformers.utils import logging
- if TYPE_CHECKING:
- import pydantic
- import tokenizers
- import torch
- from transformers import (
- ContinuousBatchingConfig,
- GenerationConfig,
- PreTrainedModel,
- PreTrainedTokenizerFast,
- ProcessorMixin,
- )
- from transformers.generation.continuous_batching.continuous_api import ContinuousBatchingManager
- from transformers.generation.continuous_batching.requests import GenerationOutput
- from transformers.generation.continuous_batching.scheduler import Scheduler
- from .model_manager import ModelManager
# Module-level logger, created through the transformers logging utilities.
logger = logging.get_logger(__name__)

# Lowercase key "x-request-id"; presumably the HTTP header used to carry the
# per-request identifier — confirm against the code that reads it.
X_REQUEST_ID = "x-request-id"
class Modality(enum.Enum):
    """Model modality tag used by the serving layer.

    Every member's value is the same string as its name, so
    ``Modality("LLM") is Modality.LLM``.
    """

    # Text-only language model.
    LLM = "LLM"
    # Vision-language model (messages may carry image parts).
    VLM = "VLM"
    # Presumably speech-to-text — confirm against the handlers that use it.
    STT = "STT"
    # Presumably text-to-speech — confirm against the handlers that use it.
    TTS = "TTS"
- class _StreamError:
- """Sentinel to signal an error from the generate thread."""
- def __init__(self, msg: str):
- self.msg = msg
- class _GenerationCancelled(Exception):
- """Raised inside ``DirectStreamer.put()`` to abort ``model.generate()``."""
- # Model-specific tokens that mark the start/end of a tool call block.
- # TODO: extract these from the chat template at runtime instead of hardcoding.
- # Qwen/Hermes use <tool_call>/<tool_call>, Mistral uses [TOOL_CALLS], etc.
- # The markers are defined in each model's Jinja chat template.
- _TOOL_CALL_TOKENS = {
- "qwen": {
- "start": "<tool_call>",
- "end": "</tool_call>",
- },
- }
- def detect_tool_format(model: "PreTrainedModel") -> dict | None:
- """Return the tool call token format for a model, if supported.
- Args:
- model (`PreTrainedModel`): The loaded model.
- Returns:
- `dict | None`: A dict ``{"start": str, "end": str}`` with the model's tool call
- delimiters, or ``None`` if the model family is not recognized.
- """
- architecture = model.config.architectures[0].lower()
- for family in _TOOL_CALL_TOKENS:
- if family in architecture:
- return _TOOL_CALL_TOKENS[family]
- return None
- class ToolCallParser:
- """Parses tool calls from model output.
- The model emits tool calls as structured text between start/end tokens
- (e.g. ``<tool_call>{"name": "fn", "arguments": {...}}</tool_call>``).
- **Streaming** (``feed``): buffers tokens between start/end markers, parses
- the complete block when the end marker is seen, returns a ``ChoiceDeltaToolCall``.
- **Non-streaming** (``parse``): extracts all tool call blocks from complete text.
- Usage::
- parser = ToolCallParser(tool_format={"start": ..., "end": ...})
- for text_chunk in streamer:
- result = parser.feed(text_chunk)
- if result is None:
- # Normal text — emit as content
- elif result is ToolCallParser.CONSUMED:
- # Buffering — skip
- else:
- # result is a ChoiceDeltaToolCall — emit it
- """
- def __init__(self, tool_format: dict):
- self._tokens = tool_format
- self._inside = False
- self._buffer = ""
- # Sentinel: token was consumed by the parser but produced no output.
- CONSUMED = object()
- def feed(self, text: str) -> object | dict | None:
- """Feed a text chunk (streaming).
- Returns:
- - ``None`` — normal text, not a tool token. Emit as content.
- - ``CONSUMED`` — token consumed internally (buffering/markers). Skip.
- - A ``ChoiceDeltaToolCall`` — emit as a tool call delta.
- """
- if text.strip() == self._tokens["start"]:
- self._inside = True
- self._buffer = ""
- return self.CONSUMED
- if text.strip() == self._tokens["end"]:
- self._inside = False
- block = self._buffer.strip()
- self._buffer = ""
- return self._parse_block(block) or self.CONSUMED
- if self._inside:
- self._buffer += text
- return self.CONSUMED
- return None
- @staticmethod
- def _extract_name_and_args(block: str) -> tuple[str, str] | None:
- """Extract (name, arguments_json) from a tool call block, or None if invalid."""
- if not block:
- return None
- parsed = json.loads(block)
- name = parsed.get("name")
- if name is None:
- return None
- arguments = parsed.get("arguments", {})
- return name, json.dumps(arguments)
- @staticmethod
- def parse(text: str, tool_format: dict) -> list[dict] | None:
- """Parse tool calls from complete text.
- Returns a list of ``{"name": str, "arguments": str}`` dicts, or ``None`` if none found.
- """
- start, end = tool_format["start"], tool_format["end"]
- tool_calls = []
- pos = 0
- while True:
- s = text.find(start, pos)
- if s < 0:
- break
- e = text.find(end, s + len(start))
- if e < 0:
- break
- result = ToolCallParser._extract_name_and_args(text[s + len(start) : e].strip())
- if result is not None:
- tool_calls.append({"name": result[0], "arguments": result[1]})
- pos = e + len(end)
- return tool_calls if tool_calls else None
- def _parse_block(self, block: str) -> dict | None:
- """Parse a buffered tool call block. Returns ``{"name": str, "arguments": str}`` or None."""
- result = self._extract_name_and_args(block)
- if result is None:
- return None
- return {"name": result[0], "arguments": result[1]}
- class DownloadAggregator:
- """Aggregates byte-progress across multiple concurrent download tqdm bars.
- huggingface_hub opens one tqdm bar per file shard. This class tracks them all and emits
- a single aggregate ``{"stage": "download", "progress": {...}}`` event whenever any updates.
- """
- def __init__(self, enqueue: Callable, model_id: str):
- self.enqueue = enqueue
- self.model = model_id
- self.bars: dict[int, tuple[int, int | None]] = {}
- self.last_emitted_current: int | None = None
- def register(self, bar_id: int, total: int | None) -> None:
- """Register a new download bar with its total byte count."""
- self.bars[bar_id] = (0, total)
- self._emit()
- def update(self, bar_id: int, current: int, total: int | None) -> None:
- """Update a bar's current byte count and emit aggregate progress."""
- self.bars[bar_id] = (current, total)
- self._emit()
- def close(self, bar_id: int) -> None:
- pass # keep the bar so totals remain correct
- def _emit(self) -> None:
- agg_current = sum(c for c, _ in self.bars.values())
- if agg_current == self.last_emitted_current:
- return
- self.last_emitted_current = agg_current
- totals = [t for _, t in self.bars.values() if t is not None]
- agg_total = sum(totals) if totals else None
- self.enqueue(
- {
- "status": "loading",
- "model": self.model,
- "stage": "download",
- "progress": {"current": agg_current, "total": agg_total},
- }
- )
def make_progress_tqdm_class(callback: Callable, model_id: str) -> type:
    """Build a tqdm subclass that forwards progress to *callback* instead of drawing a bar.

    Bars created with ``unit="B"`` are treated as download bars and are folded
    together through a shared ``DownloadAggregator``. Every other bar
    (e.g. "Loading weights") reports ``weights`` stage events directly.

    Args:
        callback (`callable`): Called with a dict payload
            ``{"status": "loading", "model": ..., "stage": ..., "progress": ...}``.
        model_id (`str`): The model ID (included in progress payloads).

    Returns:
        A tqdm subclass that forwards progress to *callback*.
    """
    from tqdm.auto import tqdm as base_tqdm

    # One aggregator shared by every bar instantiated from the returned class.
    download_aggregator = DownloadAggregator(callback, model_id)

    class ProgressTqdm(base_tqdm):
        def __init__(self, *args, **kwargs):
            self.sse_unit = kwargs.get("unit") or "it"
            # Never render to the terminal; progress goes to the callback only.
            kwargs["disable"] = True
            super().__init__(*args, **kwargs)
            self.n = 0
            self.last_emitted = -1
            if self.sse_unit == "B":
                self._bar_id = id(self)
                download_aggregator.register(self._bar_id, self.total)

        def _emit_progress(self):
            """Report the current counter through the route matching this bar's unit."""
            if self.sse_unit == "B":
                download_aggregator.update(self._bar_id, self.n, self.total)
                return
            if self.n == self.last_emitted:
                return
            self.last_emitted = self.n
            callback(
                {
                    "status": "loading",
                    "model": model_id,
                    "stage": "weights",
                    "progress": {"current": self.n, "total": self.total},
                }
            )

        def update(self, n=1):
            # A ``None`` increment is treated as a single step.
            self.n += 1 if n is None else n
            self._emit_progress()

        def __iter__(self):
            for item in self.iterable:
                self.n += 1
                self._emit_progress()
                yield item

        def close(self):
            if self.sse_unit == "B":
                download_aggregator.close(self._bar_id)
            super().close()

    return ProgressTqdm
class DirectStreamer:
    """Streamer handed to ``model.generate()`` (used by :class:`GenerateManager`).

    ``model.generate()`` drives the ``put``/``end`` protocol: it calls
    ``put(token_tensor)`` after each decode step and ``end()`` once generation
    is finished. Token ids are decoded incrementally with the Rust
    ``DecodeStream`` and each decoded text piece is scheduled onto an
    ``asyncio.Queue`` through the owning event loop.
    """

    def __init__(
        self,
        tokenizer: "tokenizers.Tokenizer",
        loop: asyncio.AbstractEventLoop,
        queue: asyncio.Queue,
        skip_special_tokens: bool = True,
    ):
        """
        Args:
            tokenizer: The Rust tokenizer (``tokenizer._tokenizer``).
            loop (`asyncio.AbstractEventLoop`): The event loop to push decoded text to.
            queue (`asyncio.Queue`): The queue that receives decoded text chunks.
            skip_special_tokens (`bool`, *optional*, defaults to `True`):
                Whether to strip special tokens during decoding.
        """
        from tokenizers.decoders import DecodeStream

        self._tokenizer = tokenizer
        self._loop = loop
        self._queue = queue
        self._decode_stream = DecodeStream([], skip_special_tokens)
        # True until the first put() (the prompt) has been seen.
        self._first = True
        self._cancelled = threading.Event()
        # Number of generated token ids seen so far.
        self.total_tokens = 0

    def put(self, value: "torch.Tensor") -> None:
        """Receive new token(s) from ``model.generate()`` after a decode step.

        Raises:
            _GenerationCancelled: If :meth:`cancel` was called beforehand.
        """
        if self._cancelled.is_set():
            raise _GenerationCancelled()

        # The very first call carries the prompt tokens; only generated tokens are streamed.
        if self._first:
            self._first = False
            return

        for token_id in value.tolist():
            self.total_tokens += 1
            piece = self._decode_stream.step(self._tokenizer, token_id)
            # The decode stream may return None for a token; nothing is queued then.
            if piece is None:
                continue
            self._loop.call_soon_threadsafe(self._queue.put_nowait, piece)

    def end(self) -> None:
        """Mark the end of generation by queueing ``None``."""
        self._loop.call_soon_threadsafe(self._queue.put_nowait, None)

    def cancel(self) -> None:
        """Request cancellation; the next ``put()`` raises and aborts ``model.generate()``."""
        self._cancelled.set()
class CBStreamer:
    """Streamer for continuous batching (used by :class:`CBGenerateManager`).

    Exposes the same ``put``/``end`` protocol as :class:`DirectStreamer`, but it
    is driven manually by :class:`CBGenerateManager` rather than by
    ``model.generate()``: ``put(output)`` takes a CB ``GenerationOutput``,
    decodes the tokens not seen before and places the text on the
    ``asyncio.Queue``; ``end()`` marks the stream as complete.
    """

    def __init__(
        self,
        cb_manager: "ContinuousBatchingManager",
        request_id: str,
        tokenizer: "tokenizers.Tokenizer",
        loop: asyncio.AbstractEventLoop,
        queue: asyncio.Queue,
    ):
        """
        Args:
            cb_manager (`ContinuousBatchingManager`): The CB manager instance.
            request_id (`str`): The request ID to track in the CB scheduler.
            tokenizer: The Rust tokenizer (``tokenizer._tokenizer``).
            loop (`asyncio.AbstractEventLoop`): The event loop to push decoded text to.
            queue (`asyncio.Queue`): The queue that receives decoded text chunks.
        """
        from tokenizers.decoders import DecodeStream

        self._cb = cb_manager
        self._request_id = request_id
        self._loop = loop
        self._queue = queue
        self._tokenizer = tokenizer
        self._decode_stream = DecodeStream([], True)
        # How many generated tokens of the request were already decoded.
        self._prev_len = 0
        self.total_tokens = 0

    def put(self, output: "GenerationOutput") -> None:
        """Decode the not-yet-seen tokens of a CB ``GenerationOutput`` and queue their text."""
        # ``generated_tokens`` is cumulative; slice off what earlier calls handled.
        unseen = output.generated_tokens[self._prev_len :]
        self._prev_len = len(output.generated_tokens)
        for token_id in unseen:
            self.total_tokens += 1
            piece = self._decode_stream.step(self._tokenizer, token_id)
            if piece is None:
                continue
            # Queued directly (not via the loop's thread-safe scheduling).
            self._queue.put_nowait(piece)

    def end(self) -> None:
        """Signal end of stream by queueing ``None``."""
        self._queue.put_nowait(None)

    def cancel(self) -> None:
        """Cancel the CB request this streamer is attached to."""
        self._cb.cancel_request(self._request_id)
def set_torch_seed(seed: int) -> None:
    """Seed PyTorch's global random number generator for reproducible generation.

    Args:
        seed (`int`): The seed passed to ``torch.manual_seed``.
    """
    # Imported lazily so importing this module does not require torch.
    from torch import manual_seed

    manual_seed(seed)
def reset_torch_cache() -> None:
    """Release cached CUDA memory; does nothing when CUDA is unavailable."""
    import torch

    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
class InferenceThread:
    """Persistent thread for ``model.generate()`` calls.

    ``torch.compile`` with CUDA graphs stores state in thread-local storage.
    All inference must run on the same thread to avoid corrupted graph state.

    Work items are executed one at a time, in submission order, by a single
    daemon thread. The worker is written so that a cancelled future or a closed
    event loop can never terminate it: losing this thread would make every
    later submission hang forever.
    """

    def __init__(self):
        # Items are ``(fn, args, kwargs, future, loop)``; ``loop`` is None for
        # blocking ``concurrent.futures.Future`` submissions.
        self._queue: Queue = Queue()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    @staticmethod
    def _resolve(future, result, error) -> None:
        """Complete *future* with *error* or *result*, unless it is already done/cancelled."""
        if future.done():
            return
        if error is not None:
            future.set_exception(error)
        else:
            future.set_result(result)

    def _run(self) -> None:
        """Worker loop: run queued callables forever and deliver their outcome."""
        while True:
            fn, args, kwargs, future, loop = self._queue.get()

            # Standard executor protocol for blocking futures: a future that was
            # cancelled while still queued must not run, and must not be completed.
            if loop is None and not future.set_running_or_notify_cancel():
                continue

            result, error = None, None
            try:
                result = fn(*args, **kwargs)
            except Exception as e:
                error = e

            if loop is None:
                self._resolve(future, result, error)
                continue

            try:
                # asyncio futures must be completed from their own loop; the
                # done() check there covers awaiters that were cancelled meanwhile.
                loop.call_soon_threadsafe(self._resolve, future, result, error)
            except RuntimeError:
                # The loop was closed: nobody is left to receive the outcome.
                pass

    def submit(self, fn, *args, **kwargs) -> Future:
        """Submit a callable to the inference thread. Returns a blocking Future."""
        future: Future = Future()
        self._queue.put((fn, args, kwargs, future, None))
        return future

    def async_submit(self, fn, *args, **kwargs) -> asyncio.Future:
        """Submit a callable to the inference thread. Returns an awaitable asyncio.Future.

        Must be called from a running event loop.
        """
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        self._queue.put((fn, args, kwargs, future, loop))
        return future
class BaseGenerateManager(ABC):
    """Base class for generation managers.

    Subclasses:

    - :class:`GenerateManager` — sequential ``model.generate()`` on a persistent thread.
    - :class:`CBGenerateManager` — continuous batching with paged attention.

    ``generate_streaming`` is a regular method that must be called from a
    running event loop; ``generate_non_streaming`` is a coroutine, matching
    both concrete implementations.
    """

    @abstractmethod
    def generate_streaming(
        self,
        model: "PreTrainedModel",
        processor: "ProcessorMixin | PreTrainedTokenizerFast",
        inputs: dict,
        gen_config: "GenerationConfig",
        request_id: str,
    ) -> tuple[asyncio.Queue, "DirectStreamer | CBStreamer"]:
        """Start streaming generation.

        Args:
            model (`PreTrainedModel`): The loaded model.
            processor: The processor or tokenizer for decoding.
            inputs (`dict`): Tokenized inputs (tensors for sequential, lists for CB).
            gen_config (`GenerationConfig`): Generation parameters.
            request_id (`str`): Unique request identifier.

        Returns:
            `tuple[asyncio.Queue, DirectStreamer | CBStreamer]`: A ``(queue, streamer)`` pair
            where *queue* yields ``str | _StreamError | None`` and *streamer* exposes
            ``.total_tokens`` and ``.cancel()``.
        """

    @abstractmethod
    async def generate_non_streaming(
        self,
        model: "PreTrainedModel",
        processor: "ProcessorMixin | PreTrainedTokenizerFast",
        inputs: dict,
        gen_config: "GenerationConfig",
        request_id: str,
    ) -> tuple[str, int, list[int]]:
        """Run generation to completion (coroutine).

        Args:
            model (`PreTrainedModel`): The loaded model.
            processor: The processor or tokenizer for decoding.
            inputs (`dict`): Tokenized inputs (tensors for sequential, lists for CB).
            gen_config (`GenerationConfig`): Generation parameters.
            request_id (`str`): Unique request identifier.

        Returns:
            `tuple[str, int, list[int]]`: ``(text, input_len, generated_ids)``.
            The sequential manager returns ``generated_ids`` as a tensor slice
            rather than a list.
        """

    @abstractmethod
    def stop(self) -> None:
        """Stop the generation manager and free resources."""
class GenerateManager(BaseGenerateManager):
    """Runs ``model.generate()`` sequentially on its own persistent :class:`InferenceThread`."""

    def __init__(self):
        self._thread = InferenceThread()

    def generate_streaming(
        self,
        model: "PreTrainedModel",
        processor: "ProcessorMixin | PreTrainedTokenizerFast",
        inputs: dict,
        gen_config: "GenerationConfig",
        request_id: str,
    ) -> tuple[asyncio.Queue, DirectStreamer]:
        """Kick off ``model.generate()`` on the inference thread and stream its text.

        Must be called from a running event loop. Returns the ``(queue, streamer)``
        pair; the queue receives text chunks, then ``None`` at the end (or on
        cancellation), or a ``_StreamError`` if generation raised.
        """
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()
        streamer = DirectStreamer(processor._tokenizer, loop, queue, skip_special_tokens=True)
        gen_kwargs = {**inputs, "streamer": streamer, "generation_config": gen_config, "tokenizer": processor}

        def _run() -> None:
            try:
                model.generate(**gen_kwargs)
            except _GenerationCancelled:
                # Cancelled mid-generation: close the stream like a normal end.
                outcome = None
            except Exception as exc:
                outcome = _StreamError(str(exc))
            else:
                # Normal completion: nothing more is queued from here.
                return
            loop.call_soon_threadsafe(queue.put_nowait, outcome)

        self.submit(_run)
        return queue, streamer

    async def generate_non_streaming(
        self,
        model: "PreTrainedModel",
        processor: "ProcessorMixin | PreTrainedTokenizerFast",
        inputs: dict,
        gen_config: "GenerationConfig",
        request_id: str,
    ) -> tuple[str, int, "torch.Tensor"]:
        """Await a full ``model.generate()`` run on the inference thread.

        Returns ``(text, input_len, generated_ids)`` where ``generated_ids`` is the
        slice of the first returned sequence after the prompt.
        """
        sequences = await self.async_submit(
            model.generate, **inputs, generation_config=gen_config, tokenizer=processor
        )
        prompt_len = inputs["input_ids"].shape[-1]
        # Only the first sequence is used; drop the prompt prefix.
        completion_ids = sequences[0, prompt_len:]
        decoded = processor.decode(completion_ids, skip_special_tokens=True)
        return decoded, prompt_len, completion_ids

    def submit(self, fn: Callable, *args, **kwargs) -> Future:
        """Queue *fn* on the inference thread; the returned Future blocks on ``result()``."""
        return self._thread.submit(fn, *args, **kwargs)

    def async_submit(self, fn: Callable, *args, **kwargs) -> asyncio.Future:
        """Queue *fn* on the inference thread; the returned asyncio.Future is awaitable."""
        return self._thread.async_submit(fn, *args, **kwargs)

    def stop(self) -> None:
        """No-op: the inference thread is a daemon and is not joined."""
class CBGenerateManager(BaseGenerateManager):
    """Continuous batching generation via paged attention.

    Translates between the handler's text-level asyncio.Queue and CB's
    token-level interface. Per-request: ``max_new_tokens``, ``eos_token_id``.

    The CB manager is initialized lazily on the first request via
    :meth:`init_cb`, using that request's ``gen_config`` for shared
    sampling params (temperature, top_p, do_sample).

    ``self._cb`` stays ``None`` until :meth:`init_cb` has run, so the caller
    must invoke :meth:`init_cb` before either ``generate_*`` method or the
    :attr:`scheduler` property.

    .. todo:: Remove :meth:`init_cb` when CB supports per-request
        generation config. At that point, ``gen_config`` can be passed directly
        to ``add_request`` and the CB manager no longer needs a shared config.
    """

    def __init__(self, cb_config: "ContinuousBatchingConfig | None" = None):
        """
        Args:
            cb_config (`ContinuousBatchingConfig`, *optional*):
                Forwarded to ``model.init_continuous_batching`` by :meth:`init_cb`.
        """
        # Underlying CB manager; created by init_cb().
        self._cb = None
        self._cb_config = cb_config

    def init_cb(self, model: "PreTrainedModel", gen_config: "GenerationConfig") -> None:
        """Initialize the CB manager on first call with the request's generation config.

        Subsequent calls return immediately, so the ``gen_config`` of later
        requests is ignored here.

        .. todo:: Remove when CB supports per-request generation config.

        Args:
            model (`PreTrainedModel`): The loaded model (must support ``init_continuous_batching``).
            gen_config (`GenerationConfig`): Generation config used for shared sampling params.
        """
        if self._cb is not None:
            return
        self._cb = model.init_continuous_batching(
            generation_config=gen_config, continuous_batching_config=self._cb_config
        )
        self._cb.start()

    def generate_streaming(
        self,
        model: "PreTrainedModel",
        processor: "ProcessorMixin | PreTrainedTokenizerFast",
        inputs: dict,
        gen_config: "GenerationConfig",
        request_id: str,
    ) -> tuple[asyncio.Queue, CBStreamer]:
        """Start streaming CB generation. Registers a per-request output handler.

        Must be called from a running event loop. ``model`` is not used here.

        Returns:
            `tuple[asyncio.Queue, CBStreamer]`: The text queue and the streamer
            feeding it. The queue receives text chunks, ``None`` when the output
            reports it is finished, or a ``_StreamError`` if decoding raised.
        """
        loop = asyncio.get_running_loop()
        text_queue: asyncio.Queue = asyncio.Queue()

        input_ids = inputs["input_ids"]
        # The ID returned by add_request is the one used from here on.
        request_id = self._cb.add_request(
            input_ids,
            request_id=request_id,
            streaming=True,
            max_new_tokens=gen_config.max_new_tokens,
            eos_token_id=gen_config.eos_token_id,
        )

        streamer = CBStreamer(self._cb, request_id, processor._tokenizer, loop, text_queue)

        # Register a direct callback: the dispatcher calls this on the event loop with each GenerationOutput.
        # This decodes tokens and pushes text straight to the SSE text_queue
        def _on_output(output):
            try:
                streamer.put(output)
                if output.is_finished():
                    streamer.end()
            except Exception as e:
                # Surface the failure to the consumer instead of raising in the dispatcher.
                text_queue.put_nowait(_StreamError(str(e)))

        # NOTE(review): unlike generate_non_streaming, the handler is registered
        # after add_request — confirm the CB dispatcher cannot deliver output in between.
        self._cb.register_result_handler(request_id, _on_output)
        return text_queue, streamer

    async def generate_non_streaming(
        self,
        model: "PreTrainedModel",
        processor: "ProcessorMixin | PreTrainedTokenizerFast",
        inputs: dict,
        gen_config: "GenerationConfig",
        request_id: str,
    ) -> tuple[str, int, list[int]]:
        """Run non-streaming CB generation. Registers a handler that resolves an asyncio.Future on completion.

        Returns:
            `tuple[str, int, list[int]]`: ``(text, input_len, generated_ids)`` where
            ``input_len`` is ``len(inputs["input_ids"])``.

        Raises:
            RuntimeError: If the handler is invoked with ``None`` instead of a result.
        """
        input_ids = inputs["input_ids"]
        input_len = len(input_ids)

        # Register future BEFORE add_request to avoid race with fast completion
        loop = asyncio.get_running_loop()
        future = loop.create_future()

        def _on_result(result):
            # Guard against a second delivery or an already-cancelled future.
            if not future.done():
                future.set_result(result)

        self._cb.register_result_handler(request_id, _on_result)
        self._cb.add_request(
            input_ids,
            request_id=request_id,
            max_new_tokens=gen_config.max_new_tokens,
            streaming=False,
            eos_token_id=gen_config.eos_token_id,
        )

        result = await future
        if result is None:
            raise RuntimeError(f"CB manager stopped before producing a result for {request_id}")
        generated_ids = result.generated_tokens
        text = processor.decode(generated_ids, skip_special_tokens=True)
        return text, input_len, generated_ids

    @property
    def scheduler(self) -> "Scheduler":
        """The CB scheduler (for testing/monitoring)."""
        return self._cb.batch_processor.scheduler

    def stop(self) -> None:
        """Stop the CB manager if it was initialized (blocking, 2 second timeout)."""
        if self._cb is not None:
            self._cb.stop(block=True, timeout=2)
class GenerationState:
    """Generation state shared by all handlers.

    Holds one :class:`GenerateManager` per model ID (each owning its own
    :class:`InferenceThread`, so different models can run concurrently while
    ``torch.compile`` / CUDA graphs keep same-model-same-thread) plus at most
    one :class:`CBGenerateManager` for continuous batching.

    Args:
        continuous_batching (`bool`, *optional*, defaults to `False`):
            Whether to use continuous batching with paged attention instead of
            sequential ``model.generate()`` calls.
        compile (`bool`, *optional*, defaults to `False`):
            Stored as ``_compile``; read by the handlers when building a generation config.
        cb_config (`ContinuousBatchingConfig`, *optional*):
            Passed to every :class:`CBGenerateManager` created here.
    """

    def __init__(
        self,
        continuous_batching: bool = False,
        compile: bool = False,
        cb_config: "ContinuousBatchingConfig | None" = None,
    ):
        self._continuous_batching = continuous_batching
        self._compile = compile
        self._cb_config = cb_config
        # Sequential managers, keyed by model ID.
        self._generate_managers: dict[str, GenerateManager] = {}
        # The single CB manager and the model ID it was created for.
        self._cb_manager: CBGenerateManager | None = None
        self._cb_model_id: str | None = None

    def use_continuous_batching(self, model: "PreTrainedModel", modality: Modality) -> bool:
        """Decide whether continuous batching applies to this model and modality.

        Args:
            model (`PreTrainedModel`): The loaded model.
            modality (`Modality`): The detected model modality (LLM, VLM, etc.).

        Returns:
            `bool`: ``True`` if CB is enabled, the model exposes
            ``init_continuous_batching`` and the modality is LLM; ``False`` otherwise
            (with a one-time warning when CB was enabled but cannot be used).
        """
        if not self._continuous_batching:
            return False

        supported = hasattr(model, "init_continuous_batching") and modality == Modality.LLM
        if supported:
            return supported

        logger.warning_once(
            f"{model.__class__.__name__} does not support continuous batching. "
            "Falling back to sequential generation."
        )
        return supported

    def get_manager(self, model_id: str, use_cb: bool = False) -> BaseGenerateManager:
        """Return the generation manager for a model, creating it on first use.

        Args:
            model_id (`str`): The model ID in ``'model_id@revision'`` format.
            use_cb (`bool`): Whether to return a CB manager or a sequential one.

        Returns:
            `BaseGenerateManager`: Either a `GenerateManager` or `CBGenerateManager`.
        """
        if not use_cb:
            manager = self._generate_managers.get(model_id)
            if manager is None:
                manager = GenerateManager()
                self._generate_managers[model_id] = manager
            return manager

        # Only one CB manager exists at a time: switching models stops the old one.
        if self._cb_model_id != model_id and self._cb_manager is not None:
            self._cb_manager.stop()
            self._cb_manager = None

        if self._cb_manager is None:
            self._cb_manager = CBGenerateManager(cb_config=self._cb_config)
            self._cb_model_id = model_id
        return self._cb_manager

    def shutdown(self) -> None:
        """Stop and drop the active CB manager, if any (sequential managers are left as is)."""
        if self._cb_manager is None:
            return
        self._cb_manager.stop()
        self._cb_manager = None
class BaseHandler:
    """Shared logic for chat completion and responses handlers.

    Provides model resolution, generation config building, and SSE formatting.
    Generation is delegated to the shared :class:`GenerationState`.

    Subclasses may set ``_valid_params_class`` (a class exposing
    ``__mutable_keys__``) to reject unknown request fields, and
    ``_unused_fields`` to warn about accepted-but-ignored fields.

    Args:
        model_manager (`ModelManager`):
            Handles model loading, caching, and lifecycle.
        generation_state (`GenerationState`):
            Shared state managing per-model generation managers.
    """

    _valid_params_class: type | None = None
    _unused_fields: set[str] = set()

    def __init__(
        self,
        model_manager: "ModelManager",
        generation_state: GenerationState,
    ):
        self.model_manager = model_manager
        self.generation_state = generation_state

    def _validate_request(self, body: dict) -> None:
        """Validate request fields against the handler's params class and unused fields.

        Raises:
            fastapi.HTTPException: 422 if the body contains keys outside
                ``_valid_params_class.__mutable_keys__``.
        """
        from fastapi import HTTPException

        input_keys = set(body.keys())
        if self._valid_params_class is not None:
            unexpected = input_keys - self._valid_params_class.__mutable_keys__
            if unexpected:
                raise HTTPException(status_code=422, detail=f"Unexpected fields in the request: {unexpected}")
        unused = input_keys & self._unused_fields
        if unused:
            logger.warning_once(f"Ignoring unsupported fields in the request: {unused}")

    @staticmethod
    def chunk_to_sse(chunk: "str | pydantic.BaseModel") -> str:
        """Format a pydantic model or string as an SSE ``data:`` line.

        Strings that already start with ``"data: "`` are returned unchanged.
        """
        if isinstance(chunk, str):
            return chunk if chunk.startswith("data: ") else f"data: {chunk}\n\n"
        return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n"

    def _resolve_model(self, body: dict) -> tuple[str, "PreTrainedModel", "ProcessorMixin | PreTrainedTokenizerFast"]:
        """Apply force_model, load model + processor.

        Note that ``body["model"]`` is overwritten in place when the model
        manager has a ``force_model`` set.

        Returns ``(model_id, model, processor)``.
        """
        if self.model_manager.force_model is not None:
            body["model"] = self.model_manager.force_model

        model_id = self.model_manager.process_model_name(body["model"])
        model, processor = self.model_manager.load_model_and_processor(model_id)
        return model_id, model, processor

    def _build_generation_config(
        self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False
    ) -> "GenerationConfig":
        """Build a GenerationConfig from shared params (temperature, top_p, seed, generation_config JSON).

        Subclasses should call ``super()._build_generation_config(...)`` then apply
        endpoint-specific params (``max_tokens``, ``max_output_tokens``, etc.).

        Args:
            body (`dict`):
                The raw request body.
            model_generation_config (`GenerationConfig`):
                The model's default generation config (will be deep-copied).
            use_cb (`bool`, *optional*, defaults to `False`):
                Whether continuous batching is active. If ``True``, disables the model's
                internal KV cache (CB manages its own paged cache).

        Returns:
            `GenerationConfig`: A new config with request-specific overrides applied.
        """
        from transformers import GenerationConfig

        if body.get("generation_config") is not None:
            # A JSON string in the request replaces the model defaults entirely.
            generation_config = GenerationConfig(**json.loads(body["generation_config"]))
        else:
            generation_config = copy.deepcopy(model_generation_config)

        # Enforce a floor of 1024 new tokens (also applied when unset).
        if generation_config.max_new_tokens is None or generation_config.max_new_tokens < 1024:
            generation_config.max_new_tokens = 1024

        if body.get("temperature") is not None:
            temperature = float(body["temperature"])
            generation_config.temperature = temperature
            # A temperature of zero means greedy decoding.
            if temperature == 0.0:
                generation_config.do_sample = False
        if body.get("top_p") is not None:
            generation_config.top_p = float(body["top_p"])
        if body.get("seed") is not None:
            # Seeds the global torch RNG, not just this request.
            set_torch_seed(body["seed"])

        # --compile flag: use static cache + torch.compile for faster decode
        if self.generation_state._compile and generation_config.cache_implementation is None:
            generation_config.cache_implementation = "static"

        # CB manages its own paged KV cache
        if use_cb:
            generation_config.use_cache = False

        # TODO: add prefix caching for the non-CB path (reuse KV cache across multi-turn conversations)
        return generation_config

    @staticmethod
    def _resolve_image_url(url: str) -> str:
        """Return a location the processor can load an image from.

        ``data:image/...;base64,`` URLs are decoded and written to a temporary
        PNG file whose path is returned; any other URL is returned unchanged.
        """
        data_prefix = re.match(r"data:image/.+;base64,", url)
        if data_prefix is None:
            return url

        from PIL import Image

        image = Image.open(BytesIO(base64.b64decode(url[data_prefix.end() :])))
        # The handle is closed on leaving the block; the file itself is kept
        # (delete=False) because the processor reads it later by path.
        # NOTE(review): nothing visible here removes the file afterwards.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as file:
            image.save(file, format="PNG")
        return file.name

    @staticmethod
    def get_processor_inputs_from_messages(messages: list[dict], modality: Modality) -> list[dict]:
        """Convert OpenAI-format messages to the format expected by HF processors.

        For LLMs, collapses list content blocks into plain text (text parts joined
        with a single space). For VLMs, converts ``image_url`` content parts
        (including base64 data URLs) into ``{"type": "image", "url": ...}``
        entries that HF processors understand; content parts of any other type
        are dropped. For other modalities the content is left as an empty list.

        Args:
            messages (`list[dict]`): OpenAI-format chat messages.
            modality (`Modality`): Whether the model is an LLM or VLM.

        Returns:
            `list[dict]`: Processor-compatible messages.
        """
        processor_inputs = []

        for message in messages:
            parsed = {"role": message["role"], "content": []}
            content = message["content"]

            if modality == Modality.LLM:
                if isinstance(content, str):
                    parsed["content"] = content
                elif isinstance(content, list):
                    parsed["content"] = " ".join(part["text"] for part in content if part["type"] == "text")

            elif modality == Modality.VLM:
                if isinstance(content, str):
                    parsed["content"].append({"type": "text", "text": content})
                else:
                    for part in content:
                        if part["type"] == "text":
                            parsed["content"].append(part)
                        elif part["type"] == "image_url":
                            url = BaseHandler._resolve_image_url(part["image_url"]["url"])
                            parsed["content"].append({"type": "image", "url": url})

            processor_inputs.append(parsed)
        return processor_inputs
|