| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656 |
- # Copyright 2021 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from collections import defaultdict
- from typing import TYPE_CHECKING, Any, Union
- import httpx
- import numpy as np
- from ..generation import GenerationConfig
- from ..tokenization_python import PreTrainedTokenizer
- from ..utils import is_torch_available, is_torchaudio_available, is_torchcodec_available, logging
- from .audio_utils import ffmpeg_read
- from .base import ChunkPipeline
if TYPE_CHECKING:
    # Only needed for the string type annotations used in `AutomaticSpeechRecognitionPipeline.__init__`;
    # never imported at runtime.
    from pyctcdecode import BeamSearchDecoderCTC

    from ..feature_extraction_sequence_utils import SequenceFeatureExtractor
    from ..modeling_utils import PreTrainedModel

logger = logging.get_logger(__name__)

if is_torch_available():
    import torch

    # Used by the pipeline constructor to recognise (non-Whisper) sequence-to-sequence speech models.
    from ..models.auto.modeling_auto import MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
def rescale_stride(stride, ratio):
    """
    Convert strides expressed in audio samples into strides expressed in model outputs (tokens/logits).

    Each `(n_samples, left_samples, right_samples)` triplet is mapped to `(n_tokens, left_tokens, right_tokens)`,
    where `n_tokens = round(n_samples * ratio)` and both side strides keep the same proportion of the total length,
    e.g. `(160_000, 16_000, 16_000) -> (2000, 200, 200)` with `ratio = 1 / 80`.

    Args:
        stride: iterable of `(total, left, right)` triplets, in samples.
        ratio (`float`): number of model outputs per input sample.

    Returns:
        `list[tuple[int, int, int]]`: the rescaled triplets, in the same order.
    """

    def _rescale_one(triplet):
        n_samples, left_samples, right_samples = triplet
        n_tokens = int(round(n_samples * ratio))
        # Keep each side stride proportional to the chunk length once expressed in tokens.
        return (
            n_tokens,
            int(round(left_samples / n_samples * n_tokens)),
            int(round(right_samples / n_samples * n_tokens)),
        )

    return [_rescale_one(triplet) for triplet in stride]
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, dtype=None):
    """
    Split a waveform into overlapping chunks and run the feature extractor on each of them.

    Consecutive chunks start `chunk_len - stride_left - stride_right` samples apart, so each chunk overlaps its
    neighbours by the stride amounts. The first chunk reports a left stride of 0 and the last chunk a right stride
    of 0, since there is no neighbour on that side.

    Args:
        inputs: audio samples indexed along the first axis (anything exposing `.shape[0]` and slicing, e.g. a
            `np.ndarray`).
        feature_extractor: callable with a `sampling_rate` attribute; it is called on every chunk with
            `return_tensors="pt"` and `return_attention_mask=True`.
        chunk_len (`int`): number of samples per chunk.
        stride_left (`int`): samples at the start of a chunk that overlap with the previous chunk.
        stride_right (`int`): samples at the end of a chunk that overlap with the next chunk.
        dtype (*optional*): when not `None`, the extractor output is converted with `.to(dtype=dtype)`.

    Yields:
        `dict`: `{"is_last": bool, "stride": (chunk_length, left, right), **features}` for each chunk, where
        `chunk_length` is the real number of samples in the chunk (the final chunk may be shorter than `chunk_len`).

    Raises:
        ValueError: if `chunk_len` is not strictly larger than `stride_left + stride_right`. The step between chunks
            would then be zero (an error in `range`) or negative (no chunk at all).
    """
    inputs_len = inputs.shape[0]
    step = chunk_len - stride_left - stride_right
    if step <= 0:
        raise ValueError(
            f"Chunk length ({chunk_len}) must be strictly larger than the sum of the strides "
            f"({stride_left} + {stride_right})."
        )
    for chunk_start_idx in range(0, inputs_len, step):
        chunk_end_idx = chunk_start_idx + chunk_len
        chunk = inputs[chunk_start_idx:chunk_end_idx]
        processed = feature_extractor(
            chunk,
            sampling_rate=feature_extractor.sampling_rate,
            return_tensors="pt",
            return_attention_mask=True,
        )
        if dtype is not None:
            processed = processed.to(dtype=dtype)
        _stride_left = 0 if chunk_start_idx == 0 else stride_left
        is_last = chunk_end_idx >= inputs_len
        _stride_right = 0 if is_last else stride_right
        # Report the actual chunk size (not `chunk_len`): the last chunk can be truncated by the end of the audio.
        stride = (chunk.shape[0], _stride_left, _stride_right)
        # A chunk no longer than its left stride holds only already-covered samples, so it is not emitted.
        if chunk.shape[0] > _stride_left:
            yield {"is_last": is_last, "stride": stride, **processed}
        if is_last:
            break
- def _find_longest_common_sequence(sequences, tokenizer):
- # TODO Use a faster algorithm this can probably be done in O(n)
- # using suffix array.
- # It might be tedious to do because of fault tolerance.
- # We actually have a really good property which is that the total sequence
- # MUST be those subsequences in order.
- # Also the algorithm should be more tolerant to errors.
- sequence = [tok_id for tok_id in sequences[0][0].tolist() if tok_id not in tokenizer.all_special_ids]
- for new_seq in sequences[1:]:
- new_sequence = [tok_id for tok_id in new_seq[0].tolist() if tok_id not in tokenizer.all_special_ids]
- index = 0
- max_ = 0.0
- for i in range(1, len(new_sequence) + 1):
- # epsilon to favor long perfect matches
- eps = i / 10000.0
- matches = np.sum(np.array(sequence[-i:]) == np.array(new_sequence[:i]))
- matching = matches / i + eps
- if matches > 1 and matching > max_:
- index = i
- max_ = matching
- sequence.extend(new_sequence[index:])
- return np.array(sequence)
class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
    """
    Pipeline that aims at extracting spoken text contained within some audio.

    The input can be either a raw waveform or an audio file. In case of the audio file, ffmpeg should be installed to
    support multiple audio formats.

    Unless the model you're using explicitly sets these generation parameters in its configuration files
    (`generation_config.json`), the following default values will be used:
    - max_new_tokens: 256
    - num_beams: 5

    Example:

    ```python
    >>> from transformers import pipeline

    >>> transcriber = pipeline(model="openai/whisper-base")
    >>> transcriber("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac")
    {'text': ' He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered flour-fatten sauce.'}
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    Arguments:
        model ([`PreTrainedModel`]):
            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
            [`PreTrainedModel`].
        feature_extractor ([`SequenceFeatureExtractor`], *optional*):
            The feature extractor that will be used by the pipeline to encode waveform for the model.
        tokenizer ([`PreTrainedTokenizer`], *optional*):
            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
            [`PreTrainedTokenizer`].
        decoder (`pyctcdecode.BeamSearchDecoderCTC`, *optional*):
            [PyCTCDecode's
            BeamSearchDecoderCTC](https://github.com/kensho-technologies/pyctcdecode/blob/2fd33dc37c4111417e08d89ccd23d28e9b308d19/pyctcdecode/decoder.py#L180)
            can be passed for language model boosted decoding. See [`Wav2Vec2ProcessorWithLM`] for more information.
        device (Union[`int`, `torch.device`], *optional*):
            Device ordinal for CPU/GPU supports. Setting this to `None` will leverage CPU, a positive will run the
            model on the associated CUDA device id.
    """

    _pipeline_calls_generate = True
    _load_processor = False
    _load_image_processor = False
    _load_feature_extractor = True
    _load_tokenizer = True
    # Make sure the docstring is updated when the default generation config is changed
    _default_generation_config = GenerationConfig(
        max_new_tokens=256,
        num_beams=5,  # follows openai's whisper implementation
    )

    def __init__(
        self,
        model: "PreTrainedModel",
        feature_extractor: Union["SequenceFeatureExtractor", str] | None = None,
        tokenizer: PreTrainedTokenizer | None = None,
        decoder: Union["BeamSearchDecoderCTC", str] | None = None,
        device: Union[int, "torch.device"] | None = None,
        **kwargs,
    ):
        # Set the model family (`seq2seq_whisper`, `seq2seq`, `ctc_with_lm` or `ctc`) so we can check we have the
        # right pre- and post-processing parameters.
        if model.config.model_type == "whisper":
            self.type = "seq2seq_whisper"
        elif model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
            self.type = "seq2seq"
        elif decoder is not None:
            self.decoder = decoder
            self.type = "ctc_with_lm"
        else:
            self.type = "ctc"

        super().__init__(model, tokenizer, feature_extractor, device=device, **kwargs)

    def __call__(self, inputs: np.ndarray | bytes | str | dict, **kwargs: Any) -> list[dict[str, Any]]:
        """
        Transcribe the audio sequence(s) given as inputs to text. See the [`AutomaticSpeechRecognitionPipeline`]
        documentation for more information.

        Args:
            inputs (`np.ndarray` or `bytes` or `str` or `dict`):
                The inputs is either :
                    - `str` that is either the filename of a local audio file, or a public URL address to download the
                      audio file. The file will be read at the correct sampling rate to get the waveform using
                      *ffmpeg*. This requires *ffmpeg* to be installed on the system.
                    - `bytes` it is supposed to be the content of an audio file and is interpreted by *ffmpeg* in the
                      same way.
                    - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
                        Raw audio at the correct sampling rate (no further check will be done)
                    - `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
                      pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "raw":
                      np.array}` with optionally a `"stride": (left: int, right: int)` that can ask the pipeline to
                      treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
                      inference to provide more context to the model). Only use `stride` with CTC models.
            return_timestamps (*optional*, `str` or `bool`):
                Only available for pure CTC models (Wav2Vec2, HuBERT, etc) and the Whisper model. Not available for
                other sequence-to-sequence models.

                For CTC models, timestamps can take one of two formats:
                    - `"char"`: the pipeline will return timestamps along the text for every character in the text. For
                        instance, if you get `[{"text": "h", "timestamp": (0.5, 0.6)}, {"text": "i", "timestamp": (0.7,
                        0.9)}]`, then it means the model predicts that the letter "h" was spoken after `0.5` and before
                        `0.6` seconds.
                    - `"word"`: the pipeline will return timestamps along the text for every word in the text. For
                        instance, if you get `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text": "there", "timestamp":
                        (1.0, 1.5)}]`, then it means the model predicts that the word "hi" was spoken after `0.5` and
                        before `0.9` seconds.

                For the Whisper model, timestamps can take one of two formats:
                    - `"word"`: same as above for word-level CTC timestamps. Word-level timestamps are predicted
                        through the *dynamic-time warping (DTW)* algorithm, an approximation to word-level timestamps
                        by inspecting the cross-attention weights.
                    - `True`: the pipeline will return timestamps along the text for *segments* of words in the text.
                        For instance, if you get `[{"text": " Hi there!", "timestamp": (0.5, 1.5)}]`, then it means the
                        model predicts that the segment "Hi there!" was spoken after `0.5` and before `1.5` seconds.
                        Note that a segment of text refers to a sequence of one or more words, rather than individual
                        words as with word-level timestamps.

                A falsy value (e.g. `False`) disables timestamps for every model type.
            generate_kwargs (`dict`, *optional*):
                The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
                complete overview of generate, check the [following
                guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).

        Return:
            `Dict`: A dictionary with the following keys:
                - **text** (`str`): The recognized text.
                - **chunks** (*optional*, `list[Dict]`)
                    When using `return_timestamps`, the `chunks` will become a list containing all the various text
                    chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text":
                    "there", "timestamp": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
                    `"".join(chunk["text"] for chunk in output["chunks"])`.
        """
        return super().__call__(inputs, **kwargs)

    def _sanitize_parameters(
        self,
        chunk_length_s=None,
        stride_length_s=None,
        ignore_warning=None,
        decoder_kwargs=None,
        return_timestamps=None,
        return_language=None,
        **generate_kwargs,
    ):
        """
        Split call-time keyword arguments into `(preprocess_params, forward_params, postprocess_params)`.

        Unrecognised keyword arguments (and the legacy `generate_kwargs` dict) are treated as generation kwargs and
        forwarded to `_forward`. Unsupported `return_timestamps` / `return_language` settings for the current model
        type raise `ValueError` here, before any forward pass.
        """
        preprocess_params = {}
        forward_params = {}
        postprocess_params = {}

        # Preprocess params
        if chunk_length_s is not None:
            if self.type in ["seq2seq", "seq2seq_whisper"] and not ignore_warning:
                type_warning = (
                    "Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
                    " be entirely accurate and will have caveats. More information:"
                    " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
                    " ignore_warning=True)."
                )
                if self.type == "seq2seq_whisper":
                    type_warning += (
                        " To use Whisper for long-form transcription, use rather the model's `generate` method directly "
                        "as the model relies on it's own chunking mechanism (cf. Whisper original paper, section 3.8. "
                        "Long-form Transcription)."
                    )
                logger.warning(type_warning)
            preprocess_params["chunk_length_s"] = chunk_length_s
        if stride_length_s is not None:
            preprocess_params["stride_length_s"] = stride_length_s

        # Forward params
        # BC: accept a dictionary of generation kwargs (as opposed to **generate_kwargs)
        if "generate_kwargs" in generate_kwargs:
            forward_params.update(generate_kwargs.pop("generate_kwargs"))
        # Default use for kwargs: they are generation-time kwargs
        forward_params.update(generate_kwargs)

        if getattr(self, "assistant_model", None) is not None:
            forward_params["assistant_model"] = self.assistant_model
        if getattr(self, "assistant_tokenizer", None) is not None:
            forward_params["tokenizer"] = self.tokenizer
            forward_params["assistant_tokenizer"] = self.assistant_tokenizer

        # Postprocess params
        if decoder_kwargs is not None:
            postprocess_params["decoder_kwargs"] = decoder_kwargs
        if return_language is not None:
            if self.type != "seq2seq_whisper":
                raise ValueError("Only Whisper can return language for now.")
            postprocess_params["return_language"] = return_language

        # Parameter used in more than one place
        # in some models like whisper, the generation config has a `return_timestamps` key
        if hasattr(self, "generation_config") and hasattr(self.generation_config, "return_timestamps"):
            return_timestamps = return_timestamps or self.generation_config.return_timestamps

        if return_timestamps is not None:
            # Check whether we have a valid setting for return_timestamps and throw an error before we perform a
            # forward pass. A falsy value only disables timestamps, so it is accepted for every model type.
            if self.type == "seq2seq" and return_timestamps:
                raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!")
            if self.type == "ctc_with_lm" and return_timestamps and return_timestamps != "word":
                raise ValueError("CTC with LM can only predict word level timestamps, set `return_timestamps='word'`")
            if self.type == "ctc" and return_timestamps and return_timestamps not in ["char", "word"]:
                raise ValueError(
                    "CTC can either predict character level timestamps, or word level timestamps. "
                    "Set `return_timestamps='char'` or `return_timestamps='word'` as required."
                )
            if self.type == "seq2seq_whisper" and return_timestamps == "char":
                raise ValueError(
                    "Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
                    "Use `return_timestamps='word'` or `return_timestamps=True` respectively."
                )
            forward_params["return_timestamps"] = return_timestamps
            postprocess_params["return_timestamps"] = return_timestamps

        return preprocess_params, forward_params, postprocess_params

    @property
    def _align_to(self):
        """Sample stride per output."""
        # XXX: Carefully, this variable will not exist in `seq2seq` setting.
        # Currently chunking is not possible at this level for `seq2seq` so
        # it's ok.
        align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
        if self.model.config.model_type == "lasr_ctc":
            # TODO: find a standard for that but not easy because input length -> mel length depends on the feature extractor
            # specific way of doing it
            # means the model take mel features as input, we align according to the hop length
            align_to *= self.feature_extractor.hop_length
        return align_to

    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
        """
        Turn one raw input into one or more feature dicts ready for `_forward`.

        Accepted inputs: a local path or http(s) URL (read as bytes), raw audio-file bytes (decoded with ffmpeg), a
        torch tensor, a torchcodec `AudioDecoder`, a numpy array, or a dict holding `"sampling_rate"` and
        `"raw"`/`"array"` (plus an optional `"stride"`; any remaining keys are forwarded as extra outputs). Audio at
        another sampling rate is resampled with torchaudio, and multi-dimensional audio is averaged over axis 0.

        When `chunk_length_s` is truthy the audio is cut into overlapping chunks (see `chunk_iter`); otherwise a
        single item with `is_last=True` is yielded.
        """
        if isinstance(inputs, str):
            if inputs.startswith("http://") or inputs.startswith("https://"):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                inputs = httpx.get(inputs, follow_redirects=True).content
            else:
                with open(inputs, "rb") as f:
                    inputs = f.read()

        if isinstance(inputs, bytes):
            inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)

        stride = None
        extra = {}
        # Factor converting sample counts at the caller's sampling rate into the feature extractor's rate.
        ratio = 1

        if is_torch_available():
            import torch

            if isinstance(inputs, torch.Tensor):
                inputs = inputs.cpu().numpy()

        if is_torchcodec_available():
            import torchcodec

            if isinstance(inputs, torchcodec.decoders.AudioDecoder):
                _audio_samples = inputs.get_all_samples()

                # torchcodec always returns (num_channels, num_samples)
                # while before (datasets < 4.0) we had (2, num_samples) if stereo, (num_samples,) if mono
                _array = _audio_samples.data
                _array = _array[0] if _array.ndim == 2 and _array.shape[0] == 1 else _array
                inputs = {"array": _array, "sampling_rate": _audio_samples.sample_rate}

        if isinstance(inputs, dict):
            # Work on a shallow copy: the `pop` calls below would otherwise empty the caller's dict, so passing the
            # same dict a second time would fail validation.
            inputs = dict(inputs)
            stride = inputs.pop("stride", None)
            # Accepting `"array"` which is the key defined in `datasets` for
            # better integration
            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                raise ValueError(
                    "When passing a dictionary to AutomaticSpeechRecognitionPipeline, the dict needs to contain a "
                    '"raw" key containing the numpy array or torch tensor representing the audio and a "sampling_rate" key, '
                    "containing the sampling_rate associated with that array"
                )

            _inputs = inputs.pop("raw", None)
            if _inputs is None:
                # Remove path which will not be used from `datasets`.
                inputs.pop("path", None)
                _inputs = inputs.pop("array", None)
            in_sampling_rate = inputs.pop("sampling_rate")
            extra = inputs
            inputs = _inputs
            if in_sampling_rate != self.feature_extractor.sampling_rate:
                if is_torchaudio_available():
                    from torchaudio import functional as F
                else:
                    raise ImportError(
                        "torchaudio is required to resample audio samples in AutomaticSpeechRecognitionPipeline. "
                        "The torchaudio package can be installed through: `pip install torchaudio`."
                    )

                inputs = F.resample(
                    torch.from_numpy(inputs) if isinstance(inputs, np.ndarray) else inputs,
                    in_sampling_rate,
                    self.feature_extractor.sampling_rate,
                ).numpy()
                ratio = self.feature_extractor.sampling_rate / in_sampling_rate

        if not isinstance(inputs, (np.ndarray, torch.Tensor)):
            raise TypeError(f"We expect a numpy ndarray or torch tensor as input, got `{type(inputs)}`")
        if inputs.ndim != 1:
            logger.warning(
                f"We expect a single channel audio input for AutomaticSpeechRecognitionPipeline, got {inputs.ndim}. Taking the mean of the channels for mono conversion."
            )
            inputs = inputs.mean(axis=0)

        if stride is not None:
            # `stride` is given in samples at the caller's sampling rate. Rescale it first, and compare it with the
            # length of the final (resampled, mono) signal rather than the raw array, whose first axis may be the
            # channel axis.
            stride_left_n = int(round(stride[0] * ratio))
            stride_right_n = int(round(stride[1] * ratio))
            if stride_left_n + stride_right_n > inputs.shape[0]:
                raise ValueError("Stride is too large for input")

            # Stride needs to get the chunk length here, it's going to get
            # swallowed by the `feature_extractor` later, and then batching
            # can add extra data in the inputs, so we need to keep track
            # of the original length in the stride so we can cut properly.
            stride = (inputs.shape[0], stride_left_n, stride_right_n)

        if chunk_length_s:
            if stride_length_s is None:
                stride_length_s = chunk_length_s / 6

            if isinstance(stride_length_s, (int, float)):
                stride_length_s = [stride_length_s, stride_length_s]

            # Round every length to a multiple of `_align_to` (samples per model output) so chunk boundaries fall on
            # whole model outputs.
            align_to = self._align_to
            chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
            stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
            stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)

            # Consecutive chunks start `chunk_len - stride_left - stride_right` samples apart: that step must be
            # strictly positive, so equality is rejected too.
            if chunk_len <= stride_left + stride_right:
                raise ValueError("Chunk length must be superior to stride length")

            for item in chunk_iter(inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.dtype):
                yield {**item, **extra}
        else:
            if self.type == "seq2seq_whisper" and inputs.shape[0] > self.feature_extractor.n_samples:
                # Whisper input longer than the extractor's `n_samples`: keep all of it (no truncation). Presumably
                # this lets `generate` run its own long-form decoding.
                processed = self.feature_extractor(
                    inputs,
                    sampling_rate=self.feature_extractor.sampling_rate,
                    truncation=False,
                    padding="longest",
                    return_tensors="pt",
                    return_attention_mask=True,
                )
            else:
                processed = self.feature_extractor(
                    inputs,
                    sampling_rate=self.feature_extractor.sampling_rate,
                    return_tensors="pt",
                    return_attention_mask=True,
                )
            if self.dtype is not None:
                processed = processed.to(dtype=self.dtype)
            if stride is not None:
                if self.type == "seq2seq":
                    raise ValueError("Stride is only usable with CTC models, try removing it !")

                processed["stride"] = stride
            yield {"is_last": True, **processed, **extra}

    def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
        """
        Run the model on one preprocessed item.

        Seq2seq models (including Whisper) go through `self.model.generate`. CTC models run a plain forward pass and
        return either raw logits (`ctc_with_lm`) or their argmax (`ctc`). For CTC, any stride is rescaled from
        samples to logit frames. Keys of `model_inputs` that are not consumed here are passed through unchanged.
        """
        attention_mask = model_inputs.pop("attention_mask", None)
        stride = model_inputs.pop("stride", None)
        num_frames = model_inputs.pop("num_frames", None)
        is_last = model_inputs.pop("is_last")

        if stride is not None and num_frames is not None:
            raise ValueError("num_frames must be used only when stride is None")

        if self.type in {"seq2seq", "seq2seq_whisper"}:
            # Consume values so we can let extra information flow freely through
            # the pipeline (important for `partial` in microphone)
            if "input_features" in model_inputs:
                inputs = model_inputs.pop("input_features")
            elif "input_values" in model_inputs:
                inputs = model_inputs.pop("input_values")
            else:
                raise ValueError(
                    "Seq2Seq speech recognition model requires either a "
                    f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
                )

            # custom processing for Whisper timestamps and word-level timestamps
            return_timestamps = return_timestamps or getattr(self.generation_config, "return_timestamps", False)
            if return_timestamps and self.type == "seq2seq_whisper":
                generate_kwargs["return_timestamps"] = bool(return_timestamps)
                if return_timestamps == "word":
                    generate_kwargs["return_token_timestamps"] = True
                    generate_kwargs["return_segments"] = True

            # User-defined `generation_config` passed to the pipeline call take precedence
            if "generation_config" not in generate_kwargs:
                generate_kwargs["generation_config"] = self.generation_config

            main_input_name = self.model.main_input_name if hasattr(self.model, "main_input_name") else "inputs"
            generate_kwargs = {
                main_input_name: inputs,
                "attention_mask": attention_mask,
                **generate_kwargs,
            }
            tokens = self.model.generate(**generate_kwargs)

            # whisper longform generation stores timestamps in "segments"
            if return_timestamps == "word" and self.type == "seq2seq_whisper":
                if "segments" not in tokens:
                    out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
                else:
                    token_timestamps = [
                        torch.cat([segment["token_timestamps"] for segment in segment_list])
                        for segment_list in tokens["segments"]
                    ]
                    out = {"tokens": tokens["sequences"], "token_timestamps": token_timestamps}
            else:
                out = {"tokens": tokens}
            if self.type == "seq2seq_whisper":
                if stride is not None:
                    out["stride"] = stride

        else:
            inputs = {
                self.model.main_input_name: model_inputs.pop(self.model.main_input_name),
                "attention_mask": attention_mask,
            }
            outputs = self.model(**inputs)
            logits = outputs.logits

            if self.type == "ctc_with_lm":
                out = {"logits": logits}
            else:
                out = {"tokens": logits.argmax(dim=-1)}
            if stride is not None:
                # Send stride to `postprocess`.
                # it needs to be handled there where
                # the pieces are to be concatenated.
                ratio = 1 / self._align_to
                if isinstance(stride, tuple):
                    out["stride"] = rescale_stride([stride], ratio)[0]
                else:
                    out["stride"] = rescale_stride(stride, ratio)
        # Leftover
        extra = model_inputs
        return {"is_last": is_last, **out, **extra}

    def postprocess(
        self, model_outputs, decoder_kwargs: dict | None = None, return_timestamps=None, return_language=None
    ):
        """
        Merge the outputs of every chunk of one input into the final result.

        Returns a dict with `"text"`, optional `"chunks"` (timestamps) or other Whisper-specific extras, and any
        pass-through keys collected into lists (one entry per chunk).
        """
        # Optional return types
        optional = {}

        final_items = []
        key = "logits" if self.type == "ctc_with_lm" else "tokens"
        stride = None
        for outputs in model_outputs:
            if outputs[key].dtype in (torch.bfloat16, torch.float16):
                items = outputs[key].to(torch.float32).numpy()
            else:
                items = outputs[key].numpy()
            stride = outputs.get("stride", None)
            if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
                total_n, left, right = stride
                # Total_n might be < logits.shape[1]
                # because of padding, that's why
                # we need to reconstruct this information
                # This won't work with left padding (which doesn't exist right now)
                right_n = total_n - right
                items = items[:, left:right_n]
            final_items.append(items)

        if stride and self.type == "seq2seq":
            items = _find_longest_common_sequence(final_items, self.tokenizer)
        elif self.type == "seq2seq_whisper":
            time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
            # Send the chunking back to seconds, it's easier to handle in whisper
            sampling_rate = self.feature_extractor.sampling_rate
            for output in model_outputs:
                if "stride" in output:
                    chunk_len, stride_left, stride_right = output["stride"]
                    # Go back in seconds
                    chunk_len /= sampling_rate
                    stride_left /= sampling_rate
                    stride_right /= sampling_rate
                    output["stride"] = chunk_len, stride_left, stride_right

            text, optional = self.tokenizer._decode_asr(
                model_outputs,
                return_timestamps=return_timestamps,
                return_language=return_language,
                time_precision=time_precision,
            )
        else:
            items = np.concatenate(final_items, axis=1)
            items = items.squeeze(0)

        if self.type == "ctc_with_lm":
            if decoder_kwargs is None:
                decoder_kwargs = {}
            beams = self.decoder.decode_beams(items, **decoder_kwargs)
            text = beams[0][0]
            if return_timestamps:
                # Simply cast from pyctcdecode format to wav2vec2 format to leverage
                # pre-existing code later
                chunk_offset = beams[0][2]
                offsets = []
                for word, (start_offset, end_offset) in chunk_offset:
                    offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
        elif self.type != "seq2seq_whisper":
            skip_special_tokens = self.type != "ctc"
            text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
            if return_timestamps:
                offsets = self.tokenizer.decode(
                    items, skip_special_tokens=skip_special_tokens, output_char_offsets=True
                )["char_offsets"]
                if return_timestamps == "word":
                    offsets = self.tokenizer._get_word_offsets(offsets, self.tokenizer.replace_word_delimiter_char)

        if return_timestamps and self.type not in {"seq2seq", "seq2seq_whisper"}:
            # Offsets are in model-output frames: multiply by samples per frame, then divide by the sampling rate.
            chunks = []
            align_to = self._align_to
            for item in offsets:
                start = item["start_offset"] * align_to
                start /= self.feature_extractor.sampling_rate

                stop = item["end_offset"] * align_to
                stop /= self.feature_extractor.sampling_rate

                chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
            optional["chunks"] = chunks

        extra = defaultdict(list)
        for output in model_outputs:
            output.pop("tokens", None)
            output.pop("logits", None)
            output.pop("is_last", None)
            output.pop("stride", None)
            output.pop("token_timestamps", None)
            for k, v in output.items():
                extra[k].append(v)
        return {"text": text, **optional, **extra}
|