| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185 |
- # Copyright 2026 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- Handler for the /v1/audio/transcriptions endpoint.
- """
- import io
- from typing import TYPE_CHECKING
- from ...utils import logging
- from ...utils.import_utils import is_serve_available
- if is_serve_available():
- from fastapi import HTTPException, Request
- from fastapi.responses import JSONResponse, StreamingResponse
- from openai.types.audio.transcription_create_params import TranscriptionCreateParamsBase
- from .model_manager import ModelManager
- from .utils import DirectStreamer, GenerateManager, GenerationState, _StreamError
- if TYPE_CHECKING:
- from transformers import PreTrainedModel, ProcessorMixin
- logger = logging.get_logger(__name__)
class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False):
    """Request fields accepted by the transcription endpoint.

    Extends the OpenAI transcription parameters with an optional ``stream`` flag.
    The key set of this TypedDict is what ``TranscriptionHandler._validate_request``
    checks incoming form fields against.
    """

    # Optional flag; the handler returns an SSE stream when it is "true".
    stream: bool
# Form fields that are accepted but not acted upon: ``_validate_request`` lets
# them through and only logs a warning that they are being ignored.
UNUSED_TRANSCRIPTION_FIELDS: set[str] = {
    "chunking_strategy",
    "include",
    "language",
    "prompt",
    "response_format",
    "temperature",
    "timestamp_granularities",
}
class TranscriptionHandler:
    """Handler for ``POST /v1/audio/transcriptions``.

    Accepts a multipart/form-data request with an audio file and model name,
    runs speech-to-text, and returns an OpenAI-compatible Transcription response.

    Standalone (does not extend :class:`BaseHandler`) because audio requests use
    multipart form data, not JSON bodies, and don't need generation config or
    validation. Shares the :class:`GenerationState` for thread safety.
    """

    def __init__(self, model_manager: ModelManager, generation_state: GenerationState):
        """
        Args:
            model_manager (`ModelManager`): Handles model loading, caching, and lifecycle.
            generation_state (`GenerationState`): Shared generation state for thread safety.
        """
        self.model_manager = model_manager
        self.generation_state = generation_state

    def _validate_request(self, form_keys: set[str]) -> None:
        """Validate the names of the submitted form fields.

        Args:
            form_keys (`set[str]`): Names of all fields present in the multipart form.

        Raises:
            `HTTPException`: With status 422 when a field is not a key of
                `TransformersTranscriptionCreateParams`.
        """
        unexpected = form_keys - TransformersTranscriptionCreateParams.__mutable_keys__
        if unexpected:
            raise HTTPException(status_code=422, detail=f"Unexpected fields in the request: {unexpected}")

        # Known-but-unsupported fields are tolerated: warn instead of rejecting the request.
        unused = form_keys & UNUSED_TRANSCRIPTION_FIELDS
        if unused:
            logger.warning_once(f"Ignoring unsupported fields in the request: {unused}")

    async def handle_request(self, request: Request) -> JSONResponse | StreamingResponse:
        """Parse multipart form, run transcription, return result.

        Args:
            request (`Request`): FastAPI request containing multipart form data with
                ``file`` (audio bytes), ``model`` (model ID), and optional ``stream`` flag.

        Returns:
            `JSONResponse | StreamingResponse`: Transcription result or SSE stream.

        Raises:
            `ImportError`: If `librosa` or `python-multipart` is not installed.
            `HTTPException`: With status 422 when the form has unexpected fields, or when
                ``file`` is not an uploaded file or ``model`` is missing or empty.
        """
        from transformers.utils.import_utils import is_librosa_available, is_multipart_available

        if not is_librosa_available():
            raise ImportError("Missing librosa dependency for audio transcription. Install with `pip install librosa`")
        if not is_multipart_available():
            raise ImportError(
                "Missing python-multipart dependency for file uploads. Install with `pip install python-multipart`"
            )

        # Pull everything needed out of the form while it is open; the model work
        # below only needs the extracted values.
        async with request.form() as form:
            self._validate_request(set(form.keys()))

            upload = form.get("file")
            # A plain text field named ``file`` arrives as ``str``, which has no ``read``.
            # Reject it (and a missing field) as a client error instead of failing
            # later with KeyError/AttributeError.
            if upload is None or not hasattr(upload, "read"):
                raise HTTPException(status_code=422, detail="Missing required file upload in field `file`.")

            model = form.get("model")
            if not isinstance(model, str) or not model:
                raise HTTPException(status_code=422, detail="Missing required field `model`.")

            file_bytes = await upload.read()
            # Form values are strings, so the flag is compared textually.
            stream = str(form.get("stream", "false")).lower() == "true"

        model_id_and_revision = self.model_manager.process_model_name(model)
        audio_model, audio_processor = self.model_manager.load_model_and_processor(model_id_and_revision)
        gen_manager = self.generation_state.get_manager(model_id_and_revision)
        audio_inputs = self._prepare_audio_inputs(file_bytes, audio_processor, audio_model)

        if stream:
            return self._streaming(gen_manager, audio_model, audio_processor, audio_inputs)
        return await self._non_streaming(gen_manager, audio_model, audio_processor, audio_inputs)

    @staticmethod
    def _prepare_audio_inputs(
        file_bytes: bytes, audio_processor: "ProcessorMixin", audio_model: "PreTrainedModel"
    ) -> dict:
        """Load audio bytes and convert to model inputs.

        Args:
            file_bytes (`bytes`): Raw content of the uploaded audio file.
            audio_processor (`ProcessorMixin`): Processor whose feature extractor defines the
                target sampling rate and which builds the model inputs.
            audio_model (`PreTrainedModel`): Model whose device and dtype the inputs must match.

        Returns:
            `dict`: Processor output placed on the model's device, with ``input_features``
            cast to the model's dtype.
        """
        import librosa

        sampling_rate = audio_processor.feature_extractor.sampling_rate
        # Decode from memory, resampling to the processor's rate and downmixing to mono.
        audio_array, _ = librosa.load(io.BytesIO(file_bytes), sr=sampling_rate, mono=True)
        audio_inputs = audio_processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt").to(
            audio_model.device
        )
        # NOTE(review): assumes the processor emits an ``input_features`` entry — confirm
        # for processors that use a different input key.
        audio_inputs["input_features"] = audio_inputs["input_features"].to(audio_model.dtype)
        return audio_inputs

    async def _non_streaming(
        self,
        gen_manager: GenerateManager,
        audio_model: "PreTrainedModel",
        audio_processor: "ProcessorMixin",
        audio_inputs: dict,
    ) -> JSONResponse:
        """Run generation to completion and return the full transcription.

        Args:
            gen_manager (`GenerateManager`): Manager through which the generation call is submitted.
            audio_model (`PreTrainedModel`): Speech-to-text model.
            audio_processor (`ProcessorMixin`): Processor used to decode the generated ids.
            audio_inputs (`dict`): Model inputs built by `_prepare_audio_inputs`.

        Returns:
            `JSONResponse`: Serialized OpenAI `Transcription` with ``None`` fields omitted.
        """
        # Audio models have different inputs (input_features) and decode (batch_decode)
        # than text models, so we use async_submit() directly instead of
        # generate_non_streaming()
        from openai.types.audio import Transcription

        generated_ids = await gen_manager.async_submit(audio_model.generate, **audio_inputs)
        # A single audio file is submitted, so only the first decoded sequence is used.
        text = audio_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return JSONResponse(Transcription(text=text).model_dump(exclude_none=True))

    def _streaming(
        self,
        gen_manager: GenerateManager,
        audio_model: "PreTrainedModel",
        audio_processor: "ProcessorMixin",
        audio_inputs: dict,
    ) -> StreamingResponse:
        """Start generation and return an SSE response fed from a queue.

        Args:
            gen_manager (`GenerateManager`): Manager through which the generation call is submitted.
            audio_model (`PreTrainedModel`): Speech-to-text model.
            audio_processor (`ProcessorMixin`): Processor (or tokenizer) used by the streamer.
            audio_inputs (`dict`): Model inputs built by `_prepare_audio_inputs`.

        Returns:
            `StreamingResponse`: ``text/event-stream`` response. The stream ends when ``None``
            is read from the queue, or after a single JSON ``{"error": ...}`` event if
            generation raised.
        """
        # Same as _non_streaming — uses submit() directly because audio inputs
        # differ from text.
        import asyncio
        import json

        # Some processors are themselves the tokenizer; fall back to the processor then.
        tokenizer = getattr(audio_processor, "tokenizer", audio_processor)
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()
        streamer = DirectStreamer(tokenizer._tokenizer, loop, queue, skip_special_tokens=True)
        gen_kwargs = {**audio_inputs, "streamer": streamer}

        def _run():
            try:
                audio_model.generate(**gen_kwargs)
            except Exception as e:
                # Deliberately broad: any failure must reach the client as a stream
                # event, otherwise the consumer below would wait on the queue forever.
                # Presumably this runs off the event-loop thread, hence the
                # thread-safe hand-off — confirm against GenerateManager.submit.
                loop.call_soon_threadsafe(queue.put_nowait, _StreamError(str(e)))

        gen_manager.submit(_run)

        async def sse_gen():
            while True:
                text = await queue.get()
                if text is None:
                    # End-of-stream sentinel.
                    break
                if isinstance(text, _StreamError):
                    # Serialize with json.dumps so quotes, backslashes and newlines in the
                    # message cannot produce invalid JSON or break the SSE framing.
                    payload = json.dumps({"error": text.msg})
                    yield f"data: {payload}\n\n"
                    return
                # NOTE(review): assumes items from DirectStreamer are already safe to
                # place on one SSE ``data:`` line — confirm.
                yield f"data: {text}\n\n"

        return StreamingResponse(sse_gen(), media_type="text/event-stream")
|