transcription.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. # Copyright 2026 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Handler for the /v1/audio/transcriptions endpoint.
  16. """
  17. import io
  18. from typing import TYPE_CHECKING
  19. from ...utils import logging
  20. from ...utils.import_utils import is_serve_available
  21. if is_serve_available():
  22. from fastapi import HTTPException, Request
  23. from fastapi.responses import JSONResponse, StreamingResponse
  24. from openai.types.audio.transcription_create_params import TranscriptionCreateParamsBase
  25. from .model_manager import ModelManager
  26. from .utils import DirectStreamer, GenerateManager, GenerationState, _StreamError
  27. if TYPE_CHECKING:
  28. from transformers import PreTrainedModel, ProcessorMixin
  29. logger = logging.get_logger(__name__)
  30. class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False):
  31. stream: bool
  32. UNUSED_TRANSCRIPTION_FIELDS = {
  33. "chunking_strategy",
  34. "include",
  35. "language",
  36. "prompt",
  37. "response_format",
  38. "temperature",
  39. "timestamp_granularities",
  40. }
  41. class TranscriptionHandler:
  42. """Handler for ``POST /v1/audio/transcriptions``.
  43. Accepts a multipart/form-data request with an audio file and model name,
  44. runs speech-to-text, and returns an OpenAI-compatible Transcription response.
  45. Standalone (does not extend :class:`BaseHandler`) because audio requests use
  46. multipart form data, not JSON bodies, and don't need generation config or
  47. validation. Shares the :class:`GenerationState` for thread safety.
  48. """
  49. def __init__(self, model_manager: ModelManager, generation_state: GenerationState):
  50. """
  51. Args:
  52. model_manager (`ModelManager`): Handles model loading, caching, and lifecycle.
  53. generation_state (`GenerationState`): Shared generation state for thread safety.
  54. """
  55. self.model_manager = model_manager
  56. self.generation_state = generation_state
  57. def _validate_request(self, form_keys: set[str]) -> None:
  58. """Validate transcription request fields."""
  59. unexpected = form_keys - TransformersTranscriptionCreateParams.__mutable_keys__
  60. if unexpected:
  61. raise HTTPException(status_code=422, detail=f"Unexpected fields in the request: {unexpected}")
  62. unused = form_keys & UNUSED_TRANSCRIPTION_FIELDS
  63. if unused:
  64. logger.warning_once(f"Ignoring unsupported fields in the request: {unused}")
  65. async def handle_request(self, request: Request) -> JSONResponse | StreamingResponse:
  66. """Parse multipart form, run transcription, return result.
  67. Args:
  68. request (`Request`): FastAPI request containing multipart form data with
  69. ``file`` (audio bytes), ``model`` (model ID), and optional ``stream`` flag.
  70. Returns:
  71. `JSONResponse | StreamingResponse`: Transcription result or SSE stream.
  72. """
  73. from transformers.utils.import_utils import is_librosa_available, is_multipart_available
  74. if not is_librosa_available():
  75. raise ImportError("Missing librosa dependency for audio transcription. Install with `pip install librosa`")
  76. if not is_multipart_available():
  77. raise ImportError(
  78. "Missing python-multipart dependency for file uploads. Install with `pip install python-multipart`"
  79. )
  80. async with request.form() as form:
  81. self._validate_request(set(form.keys()))
  82. file_bytes = await form["file"].read()
  83. model = form["model"]
  84. stream = str(form.get("stream", "false")).lower() == "true"
  85. model_id_and_revision = self.model_manager.process_model_name(model)
  86. audio_model, audio_processor = self.model_manager.load_model_and_processor(model_id_and_revision)
  87. gen_manager = self.generation_state.get_manager(model_id_and_revision)
  88. audio_inputs = self._prepare_audio_inputs(file_bytes, audio_processor, audio_model)
  89. if stream:
  90. return self._streaming(gen_manager, audio_model, audio_processor, audio_inputs)
  91. return await self._non_streaming(gen_manager, audio_model, audio_processor, audio_inputs)
  92. @staticmethod
  93. def _prepare_audio_inputs(
  94. file_bytes: bytes, audio_processor: "ProcessorMixin", audio_model: "PreTrainedModel"
  95. ) -> dict:
  96. """Load audio bytes and convert to model inputs."""
  97. import librosa
  98. sampling_rate = audio_processor.feature_extractor.sampling_rate
  99. audio_array, _ = librosa.load(io.BytesIO(file_bytes), sr=sampling_rate, mono=True)
  100. audio_inputs = audio_processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt").to(
  101. audio_model.device
  102. )
  103. audio_inputs["input_features"] = audio_inputs["input_features"].to(audio_model.dtype)
  104. return audio_inputs
  105. async def _non_streaming(
  106. self,
  107. gen_manager: GenerateManager,
  108. audio_model: "PreTrainedModel",
  109. audio_processor: "ProcessorMixin",
  110. audio_inputs: dict,
  111. ) -> JSONResponse:
  112. # Audio models have different inputs (input_features) and decode (batch_decode)
  113. # than text models, so we use async_submit() directly instead of
  114. # generate_non_streaming()
  115. from openai.types.audio import Transcription
  116. generated_ids = await gen_manager.async_submit(audio_model.generate, **audio_inputs)
  117. text = audio_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  118. return JSONResponse(Transcription(text=text).model_dump(exclude_none=True))
  119. def _streaming(
  120. self,
  121. gen_manager: GenerateManager,
  122. audio_model: "PreTrainedModel",
  123. audio_processor: "ProcessorMixin",
  124. audio_inputs: dict,
  125. ) -> StreamingResponse:
  126. # Same as _non_streaming — uses submit() directly because audio inputs
  127. # differ from text.
  128. import asyncio
  129. tokenizer = audio_processor.tokenizer if hasattr(audio_processor, "tokenizer") else audio_processor
  130. loop = asyncio.get_running_loop()
  131. queue: asyncio.Queue = asyncio.Queue()
  132. streamer = DirectStreamer(tokenizer._tokenizer, loop, queue, skip_special_tokens=True)
  133. gen_kwargs = {**audio_inputs, "streamer": streamer}
  134. def _run():
  135. try:
  136. audio_model.generate(**gen_kwargs)
  137. except Exception as e:
  138. loop.call_soon_threadsafe(queue.put_nowait, _StreamError(str(e)))
  139. gen_manager.submit(_run)
  140. async def sse_gen():
  141. while True:
  142. text = await queue.get()
  143. if text is None:
  144. break
  145. if isinstance(text, _StreamError):
  146. yield f'data: {{"error": "{text.msg}"}}\n\n'
  147. return
  148. yield f"data: {text}\n\n"
  149. return StreamingResponse(sse_gen(), media_type="text/event-stream")