generation_whisper.py 107 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073
  1. # Copyright 2024 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import math
  16. import zlib
  17. from collections.abc import Callable, Iterator
  18. import numpy as np
  19. import torch
  20. import torch.nn.functional as F
  21. from torch import nn
  22. from transformers.cache_utils import EncoderDecoderCache
  23. from ...generation import GenerationConfig, GenerationMixin
  24. from ...generation.logits_process import (
  25. LogitsProcessorList,
  26. SuppressTokensAtBeginLogitsProcessor,
  27. SuppressTokensLogitsProcessor,
  28. WhisperNoSpeechDetection,
  29. WhisperTimeStampLogitsProcessor,
  30. )
  31. from ...generation.stopping_criteria import StoppingCriteriaList
  32. from ...modeling_outputs import BaseModelOutput
  33. from ...utils import logging
  34. from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
# Module-level logger for this file, obtained from the library's `logging` utilities imported above.
logger = logging.get_logger(__name__)
  36. def _median_filter(inputs: torch.Tensor, filter_width: int) -> torch.Tensor:
  37. """
  38. Applies a median filter of width `filter_width` along the last dimension of the input.
  39. The `inputs` tensor is assumed to be 3- or 4-dimensional.
  40. """
  41. if filter_width <= 0 or filter_width % 2 != 1:
  42. raise ValueError("`filter_width` should be an odd number")
  43. pad_width = filter_width // 2
  44. if inputs.shape[-1] <= pad_width:
  45. return inputs
  46. # Pad the left and right edges.
  47. inputs = nn.functional.pad(inputs, (pad_width, pad_width, 0, 0), mode="reflect")
  48. # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
  49. result = inputs.unfold(-1, filter_width, 1).sort()[0][..., pad_width]
  50. return result
  51. def _dynamic_time_warping(matrix: np.ndarray):
  52. """
  53. Measures similarity between two temporal sequences: the input audio and the output tokens. Used to generate
  54. token-level timestamps.
  55. """
  56. output_length, input_length = matrix.shape
  57. cost = np.ones((output_length + 1, input_length + 1), dtype=np.float32) * np.inf
  58. trace = -np.ones((output_length + 1, input_length + 1), dtype=np.float32)
  59. cost[0, 0] = 0
  60. for j in range(1, input_length + 1):
  61. for i in range(1, output_length + 1):
  62. c0 = cost[i - 1, j - 1]
  63. c1 = cost[i - 1, j]
  64. c2 = cost[i, j - 1]
  65. if c0 < c1 and c0 < c2:
  66. c, t = c0, 0
  67. elif c1 < c0 and c1 < c2:
  68. c, t = c1, 1
  69. else:
  70. c, t = c2, 2
  71. cost[i, j] = matrix[i - 1, j - 1] + c
  72. trace[i, j] = t
  73. # backtrace
  74. i = trace.shape[0] - 1
  75. j = trace.shape[1] - 1
  76. trace[0, :] = 2
  77. trace[:, 0] = 1
  78. text_indices = []
  79. time_indices = []
  80. while i > 0 or j > 0:
  81. text_indices.append(i - 1)
  82. time_indices.append(j - 1)
  83. if trace[i, j] == 0:
  84. i -= 1
  85. j -= 1
  86. elif trace[i, j] == 1:
  87. i -= 1
  88. elif trace[i, j] == 2:
  89. j -= 1
  90. else:
  91. raise RuntimeError(
  92. f"Internal error in dynamic time warping. Unexpected trace[{i}, {j}]. Please file a bug report."
  93. )
  94. text_indices = np.array(text_indices)[::-1]
  95. time_indices = np.array(time_indices)[::-1]
  96. return text_indices, time_indices
  97. def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name):
  98. if logits_processor is not None:
  99. logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None)
  100. if logit_processor:
  101. return getattr(logit_processor, attribute_name, None)
  102. return None
  103. def _pad_to_max_length(
  104. current_segments,
  105. pad_token_id,
  106. device,
  107. padding_side="right",
  108. padding="longest",
  109. bos_token_tensor=None,
  110. cut_off_length=None,
  111. return_token_timestamps=False,
  112. force_unique_generate_call=False,
  113. skip_ending_double_timestamps=False,
  114. timestamp_begin=None,
  115. ):
  116. """
  117. skip_ending_double_timestamps: when the segment ended with two timestamp tokens, whether to ignore the last timestamp token
  118. see https://github.com/huggingface/transformers/pull/35750
  119. _pad_to_max_length is used in different contexts:
  120. 1. At the end of generation: we need to keep both ending timestamp tokens in the segment (see https://github.com/huggingface/transformers/pull/34537).
  121. 2. In the middle of generation, e.g. when condition_on_prev_tokens=True and we want to use the last generated tokens as decoder_input_ids:
  122. we must skip one of the double ending timestamp tokens (see https://github.com/huggingface/transformers/pull/35750).
  123. """
  124. max_total_length = 0
  125. sequences = []
  126. token_timestamps_list = []
  127. if padding_side not in ["right", "left"]:
  128. raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
  129. if padding not in ["longest", "max_length"]:
  130. raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}")
  131. elif padding == "max_length" and cut_off_length is None:
  132. raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
  133. if force_unique_generate_call:
  134. sequences_list = []
  135. timestamps_list = []
  136. for segments in current_segments:
  137. result = segments[0]["result"]
  138. sequences_list.append(result if isinstance(result, torch.Tensor) else result["sequences"])
  139. if return_token_timestamps:
  140. timestamps_list.append(result["token_timestamps"])
  141. sequences = torch.stack(sequences_list, dim=0)
  142. if return_token_timestamps:
  143. token_timestamps = torch.stack(timestamps_list, dim=0)
  144. return sequences, token_timestamps
  145. return sequences
  146. for current_segment_list in current_segments:
  147. if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
  148. sequences_list = []
  149. for d in current_segment_list:
  150. if skip_ending_double_timestamps and len(d["tokens"]) > 2 and d["tokens"][-2] >= timestamp_begin:
  151. # the segment finishes with two timestamp tokens
  152. # we need to ignore the last timestamp token
  153. # see https://github.com/huggingface/transformers/pull/34537
  154. sequences_list.append(d["tokens"][:-1])
  155. else:
  156. sequences_list.append(d["tokens"])
  157. sequence = torch.cat(sequences_list, dim=-1)
  158. if return_token_timestamps:
  159. token_timestamps = torch.cat(
  160. [d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]] for d in current_segment_list],
  161. dim=-1,
  162. )
  163. if cut_off_length is not None:
  164. sequence = sequence[-cut_off_length:]
  165. if return_token_timestamps:
  166. token_timestamps = token_timestamps[-cut_off_length:]
  167. if bos_token_tensor is not None:
  168. sequence = torch.cat([bos_token_tensor, sequence])
  169. if return_token_timestamps:
  170. token_timestamps = torch.cat(
  171. [torch.ones_like(bos_token_tensor, device=device) * 0.0, token_timestamps]
  172. )
  173. sequences.append(sequence)
  174. if return_token_timestamps:
  175. token_timestamps_list.append(token_timestamps)
  176. max_total_length = max(max_total_length, len(sequences[-1]))
  177. elif bos_token_tensor is not None:
  178. sequences.append(bos_token_tensor)
  179. if return_token_timestamps:
  180. token_timestamps_list.append(torch.ones_like(bos_token_tensor, device=device) * 0.0)
  181. else:
  182. sequences.append(torch.tensor([], device=device))
  183. if return_token_timestamps:
  184. token_timestamps_list.append(torch.tensor([], device=device))
  185. max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
  186. for i in range(len(current_segments)):
  187. pad_length = max_total_length - len(sequences[i])
  188. pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
  189. sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
  190. if return_token_timestamps:
  191. token_timestamps_list[i] = F.pad(
  192. token_timestamps_list[i],
  193. pad=pad,
  194. value=token_timestamps_list[i][-1] if len(token_timestamps_list[i]) > 0 else 0.0,
  195. )
  196. sequences = torch.stack(sequences, dim=0)
  197. if return_token_timestamps:
  198. token_timestamps = torch.stack(token_timestamps_list, dim=0)
  199. return sequences, token_timestamps
  200. else:
  201. return sequences
  202. class WhisperGenerationMixin(GenerationMixin):
  203. def _extract_token_timestamps(
  204. self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None, num_input_ids=None
  205. ):
  206. """
  207. Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
  208. map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder
  209. cross-attentions will be cropped before applying DTW.
  210. Returns:
  211. tensor containing the timestamps in seconds for each predicted token
  212. """
  213. # Create a list with `decoder_layers` elements, each a tensor of shape
  214. # (batch size * num beams, attention_heads, output length, input length).
  215. cross_attentions = []
  216. for i in range(self.config.decoder_layers):
  217. cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))
  218. # Select specific cross-attention layers and heads. This is a tensor
  219. # of shape (batch size * num beams, num selected heads, output length, input length).
  220. weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
  221. weights = weights.permute([1, 0, 2, 3])
  222. weight_length = None
  223. if "beam_indices" in generate_outputs:
  224. # If beam search was used, the sequence length of the outputs may not be the real sequence length:
  225. # beam search may end up returning a sequence that finished a few steps earlier while decoding.
  226. # In that case, the `cross_attentions` weights are too long and we have to make sure that they have
  227. # the right `output_length`
  228. # get the real sequence length of the longest sequence, crop the beam_indices to the real length
  229. weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
  230. beam_indices = generate_outputs.beam_indices[:, :weight_length]
  231. # The first forward pass (prefill) may have processed more than one token and, therefore, contain
  232. # cross-attention weights for several tokens.
  233. # Let's unroll the first `beam_indices` accordingly, so we can use it to gather the weights.
  234. if num_input_ids is not None and num_input_ids > 1:
  235. # `-1`: `beam_indices` can be used as-is to gather the weights when `num_input_ids` is 1
  236. weight_length += num_input_ids - 1
  237. beam_indices_first_step_unrolled = (
  238. torch.ones(beam_indices.shape[0], num_input_ids - 1, device=beam_indices.device, dtype=torch.long)
  239. * (beam_indices[:, 0:1])
  240. )
  241. unrolled_beam_indices = torch.cat([beam_indices_first_step_unrolled, beam_indices], dim=-1)
  242. else:
  243. unrolled_beam_indices = beam_indices
  244. # If beam index is still -1, it means that the associated token id is EOS
  245. # We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
  246. unrolled_beam_indices = unrolled_beam_indices.masked_fill(unrolled_beam_indices == -1, 0)
  247. # Select the cross attention from the right beam for each output sequence, up to the real sequence
  248. # length (`weight_length`)
  249. weights = torch.stack(
  250. [
  251. torch.index_select(weights[:, :, i, :], dim=0, index=unrolled_beam_indices[:, i])
  252. for i in range(unrolled_beam_indices.shape[1])
  253. ],
  254. dim=2,
  255. )
  256. # make sure timestamps are as long as weights
  257. input_length = weight_length or cross_attentions[0].shape[2]
  258. batch_size = generate_outputs.sequences.shape[0]
  259. timestamps = torch.zeros(
  260. (batch_size, input_length + 1), dtype=torch.float32, device=generate_outputs.sequences.device
  261. )
  262. if num_frames is not None:
  263. # two cases:
  264. # 1. num_frames is the same for each sample -> compute the DTW matrix for each sample in parallel
  265. # 2. num_frames is different, compute the DTW matrix for each sample sequentially
  266. # we're using np.unique because num_frames can be int/list/tuple
  267. if isinstance(num_frames, int):
  268. weights = weights[..., : num_frames // 2]
  269. elif isinstance(num_frames, (list, tuple, np.ndarray)) and len(np.unique(num_frames)) == 1:
  270. weights = weights[..., : num_frames[0] // 2]
  271. elif isinstance(num_frames, (torch.Tensor)) and len(torch.unique(num_frames)) == 1:
  272. weights = weights[..., : num_frames[0] // 2]
  273. else:
  274. # num_frames is of shape (batch_size,) whereas batch_size is truly batch_size*num_return_sequences
  275. repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames)
  276. num_frames = num_frames.cpu() if isinstance(num_frames, (torch.Tensor)) else num_frames
  277. num_frames = np.repeat(num_frames, repeat_time)
  278. # let's ignore decoder_input_ids that can negatively impact the DTW while we know they have timestamps 0.0s
  279. # (they are not taken into account for the DTW in OAI implementation)
  280. if num_input_ids is not None:
  281. weights = weights[:, :, num_input_ids:, :]
  282. # Since we ignore `decoder_input_ids` in the DTW and in the case where we generated only one token (for which we don't have cross attentions, see below comments),
  283. # the DTW sequence length is 0 and we should return only 0.0s for the token timestamps
  284. if weights.shape[2] == 0:
  285. return timestamps
  286. if num_frames is None or isinstance(num_frames, int):
  287. # Normalize and smoothen the weights.
  288. std = torch.std(weights, dim=-2, keepdim=True, unbiased=False)
  289. mean = torch.mean(weights, dim=-2, keepdim=True)
  290. weights = (weights - mean) / std
  291. weights = _median_filter(weights, self.config.median_filter_width)
  292. # Average the different cross-attention heads.
  293. weights = weights.mean(dim=1)
  294. # Perform dynamic time warping on each element of the batch.
  295. for batch_idx in range(batch_size):
  296. if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray, torch.Tensor)):
  297. matrix = weights[batch_idx, ..., : num_frames[batch_idx] // 2]
  298. # Normalize and smoothen the weights.
  299. std = torch.std(matrix, dim=-2, keepdim=True, unbiased=False)
  300. mean = torch.mean(matrix, dim=-2, keepdim=True)
  301. matrix = (matrix - mean) / std
  302. matrix = _median_filter(matrix, self.config.median_filter_width)
  303. # Average the different cross-attention heads.
  304. matrix = matrix.mean(dim=0)
  305. else:
  306. matrix = weights[batch_idx]
  307. text_indices, time_indices = _dynamic_time_warping(-matrix.cpu().double().numpy())
  308. jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
  309. jump_times = time_indices[jumps] * time_precision
  310. # each predicted token has a corresponding timestamp, expect the eos token (or last predicted token) for which we don't retrieve cross attentions
  311. # (indeed contrary to OAI that re-run a full forward to retrieve cross attentions for each token and therefore also the last one predicted, we retrieve
  312. # cross attentions directly from the auto-regressive generation, so we don't have cross attentiosn for the token at the end of the sequence. Nevertheless,
  313. # that is not important since we expect this last token to be the eos token)
  314. # 1. for decoder_input_ids, we set the timestamps to 0.0
  315. # 2. for the eos token (or last predicted token), we simply duplicate the timestamp of the last non-eos token
  316. timestamps[batch_idx] = torch.cat(
  317. [torch.zeros(num_input_ids), torch.tensor(jump_times), torch.tensor([jump_times[-1]])]
  318. )
  319. return timestamps
  320. def generate(
  321. self,
  322. input_features: torch.Tensor | None = None,
  323. generation_config: GenerationConfig | None = None,
  324. logits_processor: LogitsProcessorList | None = None,
  325. stopping_criteria: StoppingCriteriaList | None = None,
  326. prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]] | None = None,
  327. synced_gpus: bool = False,
  328. return_timestamps: bool | None = None,
  329. task: str | None = None,
  330. language: str | list[str] | None = None,
  331. is_multilingual: bool | None = None,
  332. prompt_ids: torch.Tensor | None = None,
  333. prompt_condition_type: str | None = None, # first-segment, all-segments
  334. condition_on_prev_tokens: bool | None = None,
  335. temperature: float | tuple[float, ...] | None = None,
  336. compression_ratio_threshold: float | None = None,
  337. logprob_threshold: float | None = None,
  338. no_speech_threshold: float | None = None,
  339. num_segment_frames: int | None = None,
  340. attention_mask: torch.Tensor | None = None,
  341. time_precision: float = 0.02,
  342. time_precision_features: float = 0.01,
  343. return_token_timestamps: bool | None = None,
  344. return_segments: bool = False,
  345. return_dict_in_generate: bool | None = None,
  346. force_unique_generate_call: bool | None = None,
  347. monitor_progress: Callable[[torch.Tensor], None] | None = None,
  348. **kwargs,
  349. ):
  350. """
  351. Transcribes or translates log-mel input features to a sequence of auto-regressively generated token ids.
  352. <Tip warning={true}>
  353. Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
  354. model's default generation configuration. You can override any `generation_config` by passing the corresponding
  355. parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
  356. For an overview of generation strategies and code examples, check out the [following
  357. guide](../generation_strategies).
  358. </Tip>
  359. Parameters:
  360. input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
  361. Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
  362. loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`,
  363. *e.g.* via the torchcodec library (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
  364. To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel
  365. features, padding and conversion into a tensor of type `torch.FloatTensor`.
  366. See [`~WhisperFeatureExtractor.__call__`] for details.
  367. generation_config ([`~generation.GenerationConfig`], *optional*):
  368. The generation configuration to be used as base parametrization for the generation call. `**kwargs`
  369. passed to generate matching the attributes of `generation_config` will override them. If
  370. `generation_config` is not provided, the default will be used, which had the following loading
  371. priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
  372. configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
  373. default values, whose documentation should be checked to parameterize generation.
  374. logits_processor (`LogitsProcessorList`, *optional*):
  375. Custom logits processors that complement the default logits processors built from arguments and
  376. generation config. If a logit processor is passed that is already created with the arguments or a
  377. generation config an error is thrown. This feature is intended for advanced users.
  378. stopping_criteria (`StoppingCriteriaList`, *optional*):
  379. Custom stopping criteria that complement the default stopping criteria built from arguments and a
  380. generation config. If a stopping criteria is passed that is already created with the arguments or a
  381. generation config an error is thrown. This feature is intended for advanced users.
  382. prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
  383. If provided, this function constraints the beam search to allowed tokens only at each step. If not
  384. provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
  385. `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
  386. on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
  387. for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
  388. Retrieval](https://huggingface.co/papers/2010.00904).
  389. synced_gpus (`bool`, *optional*, defaults to `False`):
  390. Whether to continue running the while loop until max_length (needed to avoid deadlocking with
  391. `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
  392. return_timestamps (`bool`, *optional*):
  393. Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
  394. For audios longer than 30 seconds, it is necessary to set `return_timestamps=True`.
  395. task (`str`, *optional*):
  396. Task to use for generation, either "translate" or "transcribe".
  397. language (`str` or list of `str`, *optional*):
  398. Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. For
  399. batched generation, a list of language tokens can be passed. You can find all the possible language
  400. tokens in the `model.generation_config.lang_to_id` dictionary.
  401. is_multilingual (`bool`, *optional*):
  402. Whether or not the model is multilingual.
  403. prompt_ids (`torch.Tensor`, *optional*):
  404. Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is
  405. provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
  406. transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
  407. correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
  408. prompt_condition_type (`str`, *optional*):
  409. Only relevant for long-form transcription. Condition type of `prompt_ids`. 'first-segment' means only the first segment is conditioned on `prompt_ids`. 'all-segments' means each segment is conditioned on `prompt_ids`. Make sure to enable `condition_on_prev_tokens` for 'all-segments'.
  410. Defaults to 'first-segment'. For short-term transcription only 'first-segment' is possible.
  411. condition_on_prev_tokens (`bool`, *optional*):
  412. Only relevant for long-form transcription. Whether to condition each segment on the previous segment.
  413. As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
  414. performance.
  415. temperature (`float` or list of `float`, *optional*):
  416. The temperature to be used for generation. Passing a single `float` value and `do_sample=True` activates
  417. generation using sampling. For long-form transcription, temperature fallback can be activated by passing
  418. a list of float values such as (0.0, 0.2, 0.4, 0.6, 0.8, 1.0). As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
  419. performance.
  420. compression_ratio_threshold (`float`, *optional*):
  421. Only relevant for long-form transcription. If defined, the zlib compression rate of each segment will be computed. If the compression rate of
  422. a segment is higher than `compression_ratio_threshold`, temperature fallback is activated: the generated segment is discarded and the generation is
  423. repeated using a higher temperature. The intuition behind this feature is that segments with very high compression rates
  424. suffer from a lot of repetition. The unwanted repetition can be reduced by injecting more randomness by increasing the temperature. If `compression_ratio_threshold` is defined
  425. make sure that `temperature` is a list of values. A common value for `compression_ratio_threshold` is 1.35.
  426. As shown in [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
  427. performance.
  428. logprob_threshold (`float`, *optional*):
  429. Only relevant for long-form transcription. If defined, the average log-probability of each segment will be computed. If the log-probability of
  430. a given segment is lower than `logprob_threshold`, temperature fallback is activated: the generated segment is discarded and the generation is
  431. repeated using a higher temperature. The intuition behind this feature is that segments of low log-probability
  432. can be improved by injecting more randomness by increasing the temperature. If `logprob_threshold` is defined
  433. make sure that `temperature` is a list of values. A common value for `logprob_threshold` is -1.0.
  434. As shown in [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
  435. performance.
  436. no_speech_threshold (`float`, *optional*):
  437. Only relevant for long-form transcription. If defined, the "no-speech" token combined with the `logprob_threshold`
  438. is used to determine whether a segment contains only silence. In this case, the transcription for this segment
  439. is skipped.
  440. As shown in [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
  441. performance.
  442. num_segment_frames (`int`, *optional*):
  443. The number of frames a single segment is made of. If not defined, `num_segment_frames` defaults to the model's stride
  444. times the maximum input length.
  445. attention_mask (`torch.Tensor`, *optional*):
  446. `attention_mask` needs to be passed when doing long-form transcription using a batch size > 1.
  447. time_precision (`float`, *optional*, defaults to 0.02):
  448. The duration of an output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
  449. for 20 ms.
  450. time_precision_features (`float`, *optional*, defaults to 0.01):
  451. The duration represented by a feature frame in seconds.
  452. return_token_timestamps (`bool`, *optional*):
  453. Whether to return token-level timestamps with the text. This can be used with or without the
  454. `return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
  455. words.
  456. return_segments (`bool`, *optional*, defaults to `False`):
  457. Whether to additionally return a list of all segments. Note that this option can only be enabled
  458. when doing long-form transcription.
  459. return_dict_in_generate (`bool`, *optional*, defaults to `False`):
  460. Whether or not to return a [`~utils.ModelOutput`] instead of just returning the generated tokens.
  461. Note that when doing long-form transcription, `return_dict_in_generate` can only be enabled when
  462. `return_segments` is set to `True`. In this case the generation outputs of each segment are added to each
  463. segment.
  464. force_unique_generate_call (`bool`, *optional*):
  465. Whether to force a unique call to the underlying GenerationMixin's [`~generation.GenerationMixin.generate`] method. This is useful for assisted decoding and testing purposes to ensure
  466. that only one call to [`~generation.GenerationMixin.generate`] is made and therefore decoder input token ids and eos token ids are returned.
  467. monitor_progress (`Callable[[torch.Tensor], None]`, *optional*):
  468. If provided, this function can be called to report the progress of the audio transcription. The function
  469. takes a tensor argument `p` of shape `(n, 2)`, where `n` is the batch size. `p[i, 0]` contains the
  470. index of the audio frame that is currently being transcribed for batch item `i`. `p[i, 1]` contains
  471. the total number of frames for batch item `i`. No return value is expected.
  472. kwargs (`dict[str, Any]`, *optional*):
  473. Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
  474. forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
  475. specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
  476. Return:
  477. [`~utils.ModelOutput`] or `dict[str, Any]` or `torch.LongTensor`:
  478. One of the following:
  479. - [`~utils.ModelOutput`] when `return_dict_in_generate=True` and (`return_timestamps=False` or `force_unique_generate_call=True`), including the decoder input ids and end of sequence id.
  480. - `dict[str, Any]` when (`return_dict_in_generate=True` and `return_timestamps=True`) or `return_segments=True` or `return_token_timestamps=True`.
  481. - `torch.LongTensor` in all other cases, excluding the decoder input ids and end of sequence id.
  482. The possible [`~utils.ModelOutput`] types are:
  483. - [`~generation.GenerateEncoderDecoderOutput`]
  484. - [`~generation.GenerateBeamEncoderDecoderOutput`]
  485. `segments` is a list of lists (one list per batch element) of `segment`.
  486. A `segment` is a dictionary with keys `start`, `end`, `tokens`, `idxs`, and `result`.
  487. - `start`: the start timestamp of the segment.
  488. - `end`: the end timestamp of the segment.
  489. - `tokens`: the tokens of the segment, excluding the decoder input ids and end of sequence id.
  490. - `idxs`: the start (included) and end (excluded) indices of the `tokens` of the segment in the underlying call to GenerationMixin's [`~generation.GenerationMixin.generate`] (present in `result`).
  491. - `result`: the result of the underlying call to GenerationMixin's [`~generation.GenerationMixin.generate`].
  492. When `return_timestamps=True`, `return_dict_in_generate=True` applies to each call of the underlying GenerationMixin's [`~generation.GenerationMixin.generate`], with outputs stored in `result` of each `segment`.
  493. Example:
  494. - *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate. It is necessary to set `return_timestamps=True`.
  495. Indeed, long-form transcription uses a sequential algorithm based on timestamp predictions, with heuristics like compression ratio threshold, log probability threshold and temperature fallback. This algorithm is described in [the original Whisper paper](https://cdn.openai.com/papers/whisper.pdf), section *3.8. Long-form Transcription*.
  496. ```python
  497. >>> import torch
  498. >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
  499. >>> from datasets import load_dataset, Audio
  500. >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
  501. >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
  502. >>> model.cuda() # doctest: +IGNORE_RESULT
  503. >>> # load audios > 30 seconds
  504. >>> ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
  505. >>> # resample to 16kHz
  506. >>> ds = ds.cast_column("audio", Audio(sampling_rate=16000))
  507. >>> # take first 8 audios and retrieve array
  508. >>> audio = ds[:8]["audio"]
  509. >>> audio = [x["array"] for x in audio]
  510. >>> # make sure to NOT truncate the input audio, to return the `attention_mask` and to pad to the longest audio
  511. >>> inputs = processor(audio, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, sampling_rate=16_000)
  512. >>> inputs = inputs.to("cuda", torch.float32)
  513. >>> # transcribe audio to ids
  514. >>> generated_ids = model.generate(**inputs, return_timestamps=True)
  515. >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
  516. >>> transcription[0]
  517. " Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile."
  518. ```
  519. The `monitor_progress` callback can be used to monitor the progress of the transcription:
  520. ```python
  521. >>> from tqdm import tqdm
  522. >>> # prepare inputs like above
  523. >>> # define a callback to monitor the progress of the transcription.
  524. >>> with tqdm(desc="Progress") as pbar:
  525. >>> def monitor_progress(p_batch):
  526. >>> i = torch.argmax(p_batch[:, 1])
  527. >>> p = p_batch[i].detach().cpu()
  528. >>> pbar.total = int(p[1])
  529. >>> pbar.n = int(p[0])
  530. >>> pbar.update()
  531. >>> # transcribe audio to ids
  532. >>> generated_ids = model.generate(**inputs, return_timestamps=True, monitor_progress=monitor_progress)
  533. >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
  534. >>> transcription[0]
  535. Progress: 95%|█████████████████████████████████████████████████████████████████████████████████████████████████▎ | 8497/8901 [00:04<00:00, 2052.79it/s]
  536. " Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile."
  537. ```
  538. - *Shortform transcription*: If passed mel input features are <= 30 seconds, there are two possibilities:
  539. - `return_timestamps=False`: the whole audio will be transcribed with a single call to GenerationMixin's [`~generation.GenerationMixin.generate`].
  540. - `return_timestamps=True`: the audio will be transcribed using the same logic as long-form transcription.
  541. ```python
  542. >>> import torch
  543. >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
  544. >>> from datasets import load_dataset
  545. >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
  546. >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
  547. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  548. >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
  549. >>> input_features = inputs.input_features
  550. >>> generated_ids = model.generate(inputs=input_features)
  551. >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  552. >>> transcription
  553. ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
  554. ```
  555. """
  556. # 1. prepare generation config
  557. generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
  558. # 2. set global generate variables
  559. input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
  560. num_segment_frames = input_stride * self.config.max_source_positions
  561. batch_size, total_input_frames = self._retrieve_total_input_frames(
  562. input_features=input_features, input_stride=input_stride, kwargs=kwargs
  563. )
  564. is_shortform = total_input_frames <= num_segment_frames
  565. # 3. Make sure generation config is correctly set
  566. # Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
  567. return_dict_in_generate = self._set_return_outputs(
  568. return_dict_in_generate=return_dict_in_generate,
  569. return_token_timestamps=return_token_timestamps,
  570. logprob_threshold=logprob_threshold,
  571. generation_config=generation_config,
  572. )
  573. timestamp_begin = self._set_return_timestamps(
  574. return_timestamps=return_timestamps, is_shortform=is_shortform, generation_config=generation_config
  575. )
  576. self._set_language_and_task(
  577. language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
  578. )
  579. self._set_num_frames(
  580. return_token_timestamps=return_token_timestamps,
  581. generation_config=generation_config,
  582. attention_mask=attention_mask,
  583. kwargs=kwargs,
  584. )
  585. self._set_thresholds_and_condition(
  586. generation_config=generation_config,
  587. logprob_threshold=logprob_threshold,
  588. compression_ratio_threshold=compression_ratio_threshold,
  589. no_speech_threshold=no_speech_threshold,
  590. condition_on_prev_tokens=condition_on_prev_tokens,
  591. )
  592. self._set_prompt_condition_type(
  593. generation_config=generation_config,
  594. prompt_condition_type=prompt_condition_type,
  595. )
  596. # pass self.config for backward compatibility
  597. init_tokens = self._retrieve_init_tokens(
  598. input_features,
  599. batch_size=batch_size,
  600. generation_config=generation_config,
  601. config=self.config,
  602. num_segment_frames=num_segment_frames,
  603. kwargs=kwargs,
  604. )
  605. # passing `decoder_input_ids` is deprecated - the only exception is for assisted generation
  606. # where the input ids are handled explicitly by the generate method
  607. self._check_decoder_input_ids(kwargs=kwargs)
  608. # `output_attentions` is deprecated - we force eager attention if this feature is
  609. # indirectly requested, e.g. through return_token_timestamps
  610. if return_token_timestamps:
  611. self.model.config._attn_implementation = "eager"
  612. # 3. Retrieve logits processors
  613. device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
  614. begin_index = init_tokens.shape[1]
  615. num_beams = kwargs.get(
  616. "num_beams",
  617. generation_config.num_beams
  618. if hasattr(generation_config, "num_beams") and generation_config.num_beams is not None
  619. else 1,
  620. )
  621. if "assistant_model" in kwargs:
  622. # speculative decoding: the model should be able to return eos token
  623. generation_config.begin_suppress_tokens = None
  624. logits_processor = self._retrieve_logit_processors(
  625. generation_config=generation_config,
  626. logits_processor=logits_processor,
  627. begin_index=begin_index, # begin index is index of first generated decoder token
  628. num_beams=num_beams,
  629. device=device,
  630. )
  631. # 4 Set and retrieve global generation variables
  632. self._set_condition_on_prev_tokens(
  633. condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config
  634. )
  635. temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature
  636. temperature = temperatures[0]
  637. max_frames, seek = self._retrieve_max_frames_and_seek(
  638. batch_size=batch_size,
  639. attention_mask=attention_mask,
  640. total_input_frames=total_input_frames,
  641. is_shortform=is_shortform,
  642. )
  643. # 5 Prepare running variables, list for generation
  644. num_return_sequences = generation_config.num_return_sequences
  645. (
  646. batch_idx_map,
  647. cur_bsz,
  648. input_features,
  649. seek,
  650. max_frames,
  651. init_tokens,
  652. do_condition_on_prev_tokens,
  653. ) = self._expand_variables_for_generation(
  654. input_features=input_features,
  655. seek=seek,
  656. max_frames=max_frames,
  657. init_tokens=init_tokens,
  658. batch_size=batch_size,
  659. condition_on_prev_tokens=condition_on_prev_tokens,
  660. generation_config=generation_config,
  661. )
  662. current_segments = self._prepare_segments(
  663. prompt_ids=prompt_ids,
  664. batch_size=cur_bsz,
  665. generation_config=generation_config,
  666. )
  667. # 5bis speculative decoding: ensure the assistant model does only one call to generate and therefore returns decoder input token ids and eos token id
  668. # we set a flag in the generation config to force the model to make only one call to generate and return the decoder input token ids and eos token id
  669. if "assistant_model" in kwargs:
  670. assistant_model = kwargs["assistant_model"]
  671. assistant_model.generation_config.force_unique_generate_call = True
  672. if force_unique_generate_call is None:
  673. if hasattr(generation_config, "force_unique_generate_call"):
  674. force_unique_generate_call = generation_config.force_unique_generate_call
  675. elif hasattr(self.generation_config, "force_unique_generate_call"):
  676. force_unique_generate_call = self.generation_config.force_unique_generate_call
  677. else:
  678. force_unique_generate_call = False
  679. # 6 Transcribe audio until we reach the end of all input audios
  680. while (seek < max_frames).any():
  681. if monitor_progress is not None:
  682. monitor_progress(torch.stack((seek, max_frames), dim=1))
  683. # 6.1 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop
  684. # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order
  685. # to know which original audio is being decoded
  686. # Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk
  687. input_features, cur_bsz, batch_idx_map = self._maybe_reduce_batch(
  688. input_features=input_features,
  689. seek=seek,
  690. max_frames=max_frames,
  691. cur_bsz=cur_bsz,
  692. batch_idx_map=batch_idx_map,
  693. )
  694. time_offset = (
  695. seek.to(torch.float32 if device.type == "mps" else torch.float64) * time_precision / input_stride
  696. )
  697. seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
  698. # 6.2 cut out next 30s segment from input features
  699. segment_input = self._get_input_segment(
  700. input_features=input_features,
  701. seek=seek,
  702. seek_num_frames=seek_num_frames,
  703. num_segment_frames=num_segment_frames,
  704. cur_bsz=cur_bsz,
  705. batch_idx_map=batch_idx_map,
  706. )
  707. # 6.3 prepare decoder input ids
  708. suppress_tokens = _get_attr_from_logit_processors(
  709. logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens"
  710. )
  711. decoder_input_ids, kwargs = self._prepare_decoder_input_ids(
  712. cur_bsz=cur_bsz,
  713. init_tokens=init_tokens,
  714. current_segments=current_segments,
  715. batch_idx_map=batch_idx_map,
  716. do_condition_on_prev_tokens=do_condition_on_prev_tokens,
  717. prompt_ids=prompt_ids,
  718. generation_config=generation_config,
  719. config=self.config,
  720. device=init_tokens.device,
  721. suppress_tokens=suppress_tokens,
  722. timestamp_begin=timestamp_begin,
  723. kwargs=kwargs,
  724. )
  725. # 6.4 set max new tokens or max length
  726. self._set_max_new_tokens_and_length(
  727. config=self.config,
  728. decoder_input_ids=decoder_input_ids,
  729. generation_config=generation_config,
  730. )
  731. # 6.5 Set current `begin_index` for all logit processors
  732. if logits_processor is not None:
  733. for proc in logits_processor:
  734. if hasattr(proc, "set_begin_index"):
  735. proc.set_begin_index(decoder_input_ids.shape[-1])
  736. # 6.6 Run generate with fallback
  737. (
  738. seek_sequences,
  739. seek_outputs,
  740. should_skip,
  741. do_condition_on_prev_tokens,
  742. model_output_type,
  743. ) = self.generate_with_fallback(
  744. segment_input=segment_input,
  745. decoder_input_ids=decoder_input_ids,
  746. cur_bsz=cur_bsz,
  747. seek=seek,
  748. batch_idx_map=batch_idx_map,
  749. temperatures=temperatures,
  750. generation_config=generation_config,
  751. logits_processor=logits_processor,
  752. stopping_criteria=stopping_criteria,
  753. prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
  754. synced_gpus=synced_gpus,
  755. return_token_timestamps=return_token_timestamps,
  756. do_condition_on_prev_tokens=do_condition_on_prev_tokens,
  757. is_shortform=is_shortform,
  758. batch_size=batch_size,
  759. attention_mask=attention_mask,
  760. kwargs=kwargs,
  761. )
  762. # 6.7 In every generated sequence, split by timestamp tokens and extract segments
  763. for i, seek_sequence in enumerate(seek_sequences):
  764. prev_i = batch_idx_map[i]
  765. if should_skip[i]:
  766. seek[prev_i] += seek_num_frames[prev_i]
  767. continue
  768. segments, segment_offset = self._retrieve_segment(
  769. seek_sequence=seek_sequence,
  770. seek_outputs=seek_outputs,
  771. time_offset=time_offset,
  772. timestamp_begin=timestamp_begin,
  773. seek_num_frames=seek_num_frames,
  774. time_precision=time_precision,
  775. time_precision_features=time_precision_features,
  776. input_stride=input_stride,
  777. prev_idx=prev_i,
  778. idx=i,
  779. return_token_timestamps=return_token_timestamps,
  780. decoder_input_ids=decoder_input_ids,
  781. )
  782. seek[prev_i] += segment_offset
  783. current_segments[prev_i] += segments
  784. if force_unique_generate_call:
  785. break
  786. # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
  787. # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
  788. final_segments = (
  789. [x[1:] for x in current_segments]
  790. if (prompt_ids is not None and generation_config.prompt_condition_type == "first-segment")
  791. else current_segments
  792. )
  793. # if return_dict_in_generate=True and we forced a unique call to generate or return_timestamps=False, meaning we are sure only one call to generate has been made,
  794. # -> we can return a ModelOutput
  795. # otherwise, return_dict_in_generate is applied in the 'result' of each segment in final_segments
  796. if (
  797. return_dict_in_generate
  798. and generation_config.return_dict_in_generate
  799. and (force_unique_generate_call or not return_timestamps)
  800. ):
  801. # only one call to generate_with_fallback, we can return a ModelOutput
  802. outputs = self._stack_split_outputs(seek_outputs, model_output_type, self.device, kwargs)
  803. if num_return_sequences > 1:
  804. if hasattr(outputs, "encoder_attentions") and outputs.encoder_attentions is not None:
  805. outputs.encoder_attentions = tuple(
  806. outputs.encoder_attentions[i][::num_return_sequences]
  807. for i in range(len(outputs.encoder_attentions))
  808. )
  809. if hasattr(outputs, "encoder_hidden_states") and outputs.encoder_hidden_states is not None:
  810. outputs.encoder_hidden_states = tuple(
  811. outputs.encoder_hidden_states[i][::num_return_sequences]
  812. for i in range(len(outputs.encoder_hidden_states))
  813. )
  814. return outputs
  815. padded_outputs = _pad_to_max_length(
  816. current_segments=final_segments,
  817. pad_token_id=generation_config.pad_token_id,
  818. device=self.device,
  819. padding_side="right",
  820. return_token_timestamps=return_token_timestamps,
  821. force_unique_generate_call=force_unique_generate_call,
  822. )
  823. if return_dict_in_generate and generation_config.return_dict_in_generate:
  824. logger.warning_once(
  825. "You have passed `return_dict_in_generate=True` and `return_timestamps=True`, this automatically sets `return_segments=True` to access the results of the underlying calls to GenerationMixin's generate in the returned `segments`."
  826. )
  827. return_segments = True
  828. elif not return_segments and not return_token_timestamps:
  829. return padded_outputs
  830. if return_token_timestamps:
  831. sequences, token_timestamps = padded_outputs
  832. outputs = {
  833. "sequences": sequences,
  834. "token_timestamps": token_timestamps,
  835. }
  836. else:
  837. sequences = padded_outputs
  838. outputs = {
  839. "sequences": sequences,
  840. }
  841. if return_segments:
  842. outputs["segments"] = final_segments
  843. return outputs
  844. def generate_with_fallback(
  845. self,
  846. segment_input,
  847. decoder_input_ids,
  848. cur_bsz,
  849. seek,
  850. batch_idx_map,
  851. temperatures,
  852. generation_config,
  853. logits_processor,
  854. stopping_criteria,
  855. prefix_allowed_tokens_fn,
  856. synced_gpus,
  857. return_token_timestamps,
  858. do_condition_on_prev_tokens,
  859. is_shortform,
  860. batch_size,
  861. attention_mask,
  862. kwargs,
  863. ):
  864. kwargs = copy.copy(kwargs)
  865. # 6.6 Batch generate current chunk
  866. seek_sequence_list = [None for _ in range(cur_bsz)]
  867. seek_outputs_list = [None for _ in range(cur_bsz)]
  868. needs_fallback = [False for _ in range(cur_bsz)]
  869. should_skip = [False for _ in range(cur_bsz)]
  870. fallback_index_map = list(range(cur_bsz))
  871. if generation_config.no_speech_threshold is not None:
  872. self._setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs)
  873. for fallback_idx, temperature in enumerate(temperatures):
  874. generation_config.do_sample = temperature is not None and temperature > 0.0
  875. generation_config.temperature = temperature if generation_config.do_sample else 1.0
  876. if generation_config.do_sample:
  877. generation_config.num_beams = 1
  878. generate_kwargs = copy.copy(kwargs)
  879. for key in ["do_sample", "temperature", "num_beams"]:
  880. if key in generate_kwargs:
  881. del generate_kwargs[key]
  882. cur_bsz = decoder_input_ids.shape[0]
  883. if generation_config.cache_implementation == "static" and cur_bsz < batch_size:
  884. segment_input = F.pad(segment_input, (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0)
  885. decoder_input_ids = F.pad(
  886. decoder_input_ids, (0, 0, 0, batch_size - cur_bsz), value=generation_config.pad_token_id
  887. )
  888. if generate_kwargs.get("decoder_attention_mask") is not None:
  889. generate_kwargs["decoder_attention_mask"] = F.pad(
  890. generate_kwargs["decoder_attention_mask"], (0, 0, 0, batch_size - cur_bsz), value=True
  891. )
  892. if generate_kwargs.get("encoder_outputs") is not None:
  893. generate_kwargs["encoder_outputs"] = F.pad(
  894. generate_kwargs["encoder_outputs"], (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0
  895. )
  896. seek_outputs = super().generate(
  897. segment_input,
  898. generation_config=generation_config,
  899. logits_processor=logits_processor,
  900. stopping_criteria=stopping_criteria,
  901. prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
  902. synced_gpus=synced_gpus,
  903. decoder_input_ids=decoder_input_ids,
  904. attention_mask=attention_mask,
  905. **generate_kwargs,
  906. )
  907. model_output_type = type(seek_outputs)
  908. # post-process sequence tokens and outputs to be in list form
  909. seek_sequences, seek_outputs = self._postprocess_outputs(
  910. seek_outputs=seek_outputs,
  911. decoder_input_ids=decoder_input_ids,
  912. return_token_timestamps=return_token_timestamps,
  913. generation_config=generation_config,
  914. is_shortform=is_shortform,
  915. seek=seek,
  916. batch_idx_map=batch_idx_map,
  917. )
  918. if cur_bsz < batch_size:
  919. seek_sequences = seek_sequences[:cur_bsz]
  920. seek_outputs = seek_outputs[:cur_bsz]
  921. # 6.7 Extract cut sequences from every sequence and check if fallback should be applied
  922. # Loop over each decoded audio individually as each decoding can be of a different length
  923. new_fallback_index_map = []
  924. new_segment_input = []
  925. new_decoder_input_ids = []
  926. new_decoder_attention_mask = []
  927. for i, seek_sequence in enumerate(seek_sequences):
  928. # remove all padding tokens, except for the eos token
  929. if seek_sequence[-1] == generation_config.pad_token_id:
  930. num_paddings = (seek_sequence == generation_config.pad_token_id).sum()
  931. if generation_config.pad_token_id == generation_config.eos_token_id:
  932. # we do not remove the eos token id since it is needed for avg logprob calculation in _need_fallback
  933. num_paddings -= 1
  934. if num_paddings != 0:
  935. seek_sequence = seek_sequence[:-num_paddings]
  936. # check which sequences in batch need fallback & which should be skipped
  937. needs_fallback[i], should_skip[i] = self._need_fallback(
  938. seek_sequence,
  939. seek_outputs,
  940. i,
  941. logits_processor,
  942. generation_config,
  943. self.config.vocab_size,
  944. temperature,
  945. )
  946. # remove eos token
  947. if seek_sequence[-1] == generation_config.eos_token_id:
  948. seek_sequence = seek_sequence[:-1]
  949. seek_sequence_list[fallback_index_map[i]] = seek_sequence
  950. seek_outputs_list[fallback_index_map[i]] = seek_outputs[i]
  951. is_low_temperature = temperature is None or temperature < 0.5
  952. do_condition_on_prev_tokens[fallback_index_map[i]] = (
  953. generation_config.condition_on_prev_tokens and is_low_temperature
  954. )
  955. if needs_fallback[i]:
  956. new_fallback_index_map.append(fallback_index_map[i])
  957. new_segment_input.append(segment_input[i])
  958. new_decoder_input_ids.append(decoder_input_ids[i])
  959. if "decoder_attention_mask" in kwargs:
  960. new_decoder_attention_mask.append(kwargs["decoder_attention_mask"][i])
  961. fallback_index_map = new_fallback_index_map
  962. # if no sequence needs to be run with temperature fallback, we're finished
  963. if len(fallback_index_map) == 0 or fallback_idx == len(temperatures) - 1:
  964. seek_sequences = seek_sequence_list
  965. seek_outputs = seek_outputs_list
  966. break
  967. # if we're still in the loop, make sure that decoder_input_ids and segment inputs are tensors
  968. decoder_input_ids = torch.stack(new_decoder_input_ids)
  969. segment_input = torch.stack(new_segment_input)
  970. if "decoder_attention_mask" in kwargs:
  971. kwargs["decoder_attention_mask"] = torch.stack(new_decoder_attention_mask)
  972. return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, model_output_type
  973. @staticmethod
  974. def _prepare_segments(prompt_ids, batch_size, generation_config):
  975. if prompt_ids is not None and generation_config.prompt_condition_type == "first-segment":
  976. prev_sot_token_id = getattr(generation_config, "prev_sot_token_id", None)
  977. prompt_ids = prompt_ids[1:] if prompt_ids[0] == prev_sot_token_id else prompt_ids
  978. current_segments = [[{"tokens": prompt_ids}] for _ in range(batch_size)]
  979. else:
  980. current_segments = [[] for _ in range(batch_size)]
  981. return current_segments
  def _postprocess_outputs(
      self,
      seek_outputs,
      decoder_input_ids,
      return_token_timestamps,
      generation_config,
      is_shortform,
      seek,
      batch_idx_map,
  ):
      """Strip the decoder prompt from generated sequences and split a batched output into per-sample dicts.

      Args:
          seek_outputs: Either a plain ``torch.Tensor`` of generated ids, or a dict-like generate output
              (accessed below with ``["sequences"]``, ``.items()`` and ``.get("beam_indices")``).
          decoder_input_ids: Decoder prompt that was fed to ``generate``. Its last dimension is the number of
              leading positions removed from each generated sequence.
          return_token_timestamps: When truthy and ``generation_config`` has ``alignment_heads``, token
              timestamps are computed and stored under ``seek_outputs["token_timestamps"]``.
          generation_config: Read for ``alignment_heads`` and ``num_frames``.
          is_shortform: Forwarded to the splitter. ``past_key_values`` are only kept when it is truthy.
          seek: Per-sample offsets subtracted from ``num_frames`` before timestamp extraction.
          batch_idx_map: Indices selecting the active samples' ``num_frames``.

      Returns:
          ``(sequence_tokens, seek_outputs)``. For tensor input, ``seek_outputs`` is returned unchanged.
          Otherwise it becomes a list with one dict per batch row, whose values are moved to CPU.
      """
      # remove all previously passed decoder input ids
      # should happen only if it is the first generated segment
      start_idx = decoder_input_ids.shape[-1]

      if isinstance(seek_outputs, torch.Tensor):
          return seek_outputs[:, start_idx:], seek_outputs

      if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
          num_frames = getattr(generation_config, "num_frames")

          if num_frames is not None:
              # Frames remaining after the current seek position, restricted to the active samples.
              num_frames = num_frames - seek
              num_frames = num_frames[batch_idx_map]

          seek_outputs["token_timestamps"] = self._extract_token_timestamps(
              seek_outputs,
              generation_config.alignment_heads,
              num_frames=num_frames,
              num_input_ids=decoder_input_ids.shape[-1],
          )

      def split_by_batch_index(values, key, batch_idx, is_shortform, beam_indices=None):
          # Return the part of `values` (the output stored under `key`) that belongs to batch row
          # `batch_idx`, moved to CPU. Each output type is nested differently, hence the branches.
          if beam_indices is not None and key == "scores":
              # Per step, select the score row given by `beam_indices[batch_idx]` for that step.
              return [v[beam_idx].cpu() for (v, beam_idx) in zip(values, beam_indices[batch_idx][: len(values)])]
          if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
              return [v[batch_idx].cpu() for v in values]
          if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
              # Doubly nested tuples; `[None]` re-adds a leading dimension of size 1.
              return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
          elif key == "past_key_values":
              if not is_shortform:
                  # we don't save `past_key_values` as this is too costly for longform
                  return None
              all_past_key_values = []
              for layer_idx in range(self.config.decoder_layers):
                  layer_cache = (
                      values.self_attention_cache.layers[layer_idx].keys[batch_idx][None].cpu(),
                      values.self_attention_cache.layers[layer_idx].values[batch_idx][None].cpu(),
                      values.cross_attention_cache.layers[layer_idx].keys[batch_idx][None].cpu(),
                      values.cross_attention_cache.layers[layer_idx].values[batch_idx][None].cpu(),
                  )
                  all_past_key_values.append(layer_cache)
              return EncoderDecoderCache(all_past_key_values)

          # Any other key (e.g. "sequences") is indexed directly by batch row.
          return values[batch_idx].cpu()

      sequence_tokens = seek_outputs["sequences"][:, start_idx:]
      seek_outputs = [
          {
              k: split_by_batch_index(v, k, i, is_shortform, beam_indices=seek_outputs.get("beam_indices"))
              for k, v in seek_outputs.items()
          }
          for i in range(sequence_tokens.shape[0])
      ]

      return sequence_tokens, seek_outputs
  def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs):
      """Re-assemble a list of per-sample output dicts into one batched output object.

      Only the keys handled below are carried over. Any other key of ``seek_outputs[0]`` is dropped.

      Args:
          seek_outputs: List of dicts, one per sample, all sharing the keys of ``seek_outputs[0]``.
          model_output_type: Class instantiated as ``model_output_type(**outputs)``.
          device: Device every stacked tensor is moved to.
          kwargs: Not used by this method.

      Returns:
          ``model_output_type(**outputs)``, or a plain ``dict`` when ``token_timestamps`` are present.
      """
      # Stack back seek_outputs tensors after splitting them with the split_by_batch_index method
      outputs = {}
      for key in seek_outputs[0]:
          if key in ["sequences", "beam_indices", "token_timestamps"]:
              outputs[key] = torch.stack([v[key] for v in seek_outputs], dim=0).to(device)
          elif key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
              # One stacked tensor per position of the per-sample lists.
              outputs[key] = tuple(
                  torch.stack([v[key][i] for v in seek_outputs]).to(device) for i in range(len(seek_outputs[0][key]))
              )
          elif key == "sequences_scores":
              outputs[key] = torch.stack([v[key] for v in seek_outputs], dim=0).to(device)
          elif key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
              # Doubly nested (outer i, inner j). `squeeze(1)` removes a size-1 dim that each per-sample
              # tensor presumably carries from the split step.
              outputs[key] = tuple(
                  tuple(
                      torch.stack([v[key][i][j] for v in seek_outputs]).squeeze(1).to(device)
                      for j in range(len(seek_outputs[0][key][0]))
                  )
                  for i in range(len(seek_outputs[0][key]))
              )
          elif key == "past_key_values":
              if seek_outputs[0][key] is not None:
                  all_past_key_values = []
                  for layer_idx in range(len(seek_outputs[0][key])):
                      # The generator yields, in order: self-attn keys, self-attn values, cross-attn keys,
                      # cross-attn values. This matches the unpacking targets on the left.
                      self_attention_k, self_attention_v, cross_attention_k, cross_attention_v = (
                          torch.stack(
                              [
                                  getattr(getattr(sub_output[key], sub_cache).layers[layer_idx], sub_key)
                                  for sub_output in seek_outputs
                              ]
                          )
                          .squeeze(1)
                          .to(device)
                          for sub_cache in ["self_attention_cache", "cross_attention_cache"]
                          for sub_key in ["keys", "values"]
                      )
                      all_past_key_values.append(
                          (self_attention_k, self_attention_v, cross_attention_k, cross_attention_v)
                      )
                  outputs[key] = EncoderDecoderCache(tuple(all_past_key_values))
              else:
                  # No per-sample cache was kept, so none is rebuilt.
                  outputs[key] = None

      token_timestamps = outputs.get("token_timestamps")
      if token_timestamps is not None:
          # NOTE(review): presumably `model_output_type` has no `token_timestamps` field, hence the plain
          # dict - confirm.
          model_output_type = dict

      return model_output_type(**outputs)
  1085. def _need_fallback(
  1086. self,
  1087. seek_sequence,
  1088. seek_outputs,
  1089. index,
  1090. logits_processor,
  1091. generation_config,
  1092. vocab_size,
  1093. temperature,
  1094. ):
  1095. needs_fallback = False
  1096. should_skip = False
  1097. if generation_config.compression_ratio_threshold is not None:
  1098. compression_ratio = self._retrieve_compression_ratio(seek_sequence, vocab_size)
  1099. if compression_ratio > generation_config.compression_ratio_threshold:
  1100. needs_fallback = True
  1101. if generation_config.logprob_threshold is not None:
  1102. if hasattr(seek_outputs[0], "sequences_scores"):
  1103. logprobs = [s["sequences_scores"] for s in seek_outputs][index]
  1104. else:
  1105. scores = seek_outputs[index]["scores"]
  1106. logprobs = self._retrieve_avg_logprobs(
  1107. scores,
  1108. seek_sequence,
  1109. temperature,
  1110. )
  1111. if logprobs < generation_config.logprob_threshold:
  1112. needs_fallback = True
  1113. if generation_config.no_speech_threshold is not None:
  1114. no_speech_prob = _get_attr_from_logit_processors(
  1115. logits_processor, WhisperNoSpeechDetection, "no_speech_prob"
  1116. )
  1117. if (
  1118. logprobs < generation_config.logprob_threshold
  1119. and no_speech_prob[index] > generation_config.no_speech_threshold
  1120. ):
  1121. needs_fallback = False
  1122. should_skip = True
  1123. return needs_fallback, should_skip
  1124. def _expand_variables_for_generation(
  1125. self, input_features, seek, max_frames, init_tokens, batch_size, condition_on_prev_tokens, generation_config
  1126. ):
  1127. if generation_config.num_return_sequences is not None and generation_config.num_return_sequences > 1:
  1128. batch_idx_map = list(range(batch_size * generation_config.num_return_sequences))
  1129. cur_bsz = len(batch_idx_map)
  1130. do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(len(batch_idx_map))]
  1131. input_features = input_features.repeat_interleave(generation_config.num_return_sequences, dim=0)
  1132. seek = seek.repeat_interleave(generation_config.num_return_sequences, dim=0)
  1133. max_frames = max_frames.repeat_interleave(generation_config.num_return_sequences, dim=0)
  1134. init_tokens = init_tokens.repeat_interleave(generation_config.num_return_sequences, dim=0)
  1135. generation_config.num_return_sequences = 1
  1136. else:
  1137. cur_bsz = batch_size
  1138. batch_idx_map = list(range(cur_bsz))
  1139. do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(cur_bsz)]
  1140. return (
  1141. batch_idx_map,
  1142. cur_bsz,
  1143. input_features,
  1144. seek,
  1145. max_frames,
  1146. init_tokens,
  1147. do_condition_on_prev_tokens,
  1148. )
  @staticmethod
  def _setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs):
      """Pass the current segment's inputs to the ``WhisperNoSpeechDetection`` processor.

      The processor's ``set_inputs`` receives a dict with:

      * ``"inputs"``: ``segment_input``;
      * ``"input_ids"``: ``decoder_input_ids``;
      * every tensor-valued entry of ``kwargs``, which overrides the two keys above on a name clash.
      """
      set_inputs = _get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "set_inputs")
      model_inputs = {"inputs": segment_input, "input_ids": decoder_input_ids}
      model_inputs.update((name, value) for name, value in kwargs.items() if torch.is_tensor(value))
      set_inputs(model_inputs)
  1154. @staticmethod
  1155. def _retrieve_total_input_frames(input_features, input_stride, kwargs):
  1156. if input_features is not None:
  1157. return input_features.shape[0], input_features.shape[-1]
  1158. if "encoder_outputs" in kwargs:
  1159. encoder_outputs_shape = (
  1160. kwargs["encoder_outputs"][0].shape
  1161. if isinstance(kwargs["encoder_outputs"], BaseModelOutput)
  1162. else kwargs["encoder_outputs"].shape
  1163. )
  1164. return encoder_outputs_shape[0], encoder_outputs_shape[1] * input_stride
  1165. raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.")
  @staticmethod
  def _maybe_warn_unused_inputs(
      condition_on_prev_tokens,
      temperature,
      compression_ratio_threshold,
      logprob_threshold,
      no_speech_threshold,
      total_input_frames,
  ):
      """Warn about options that are set although short-form transcription ignores them.

      One warning is logged for each of ``condition_on_prev_tokens``, ``compression_ratio_threshold``,
      ``logprob_threshold`` and ``no_speech_threshold`` that is not ``None``, in that order.
      ``temperature`` is accepted but not checked.
      """
      # The f-string part is interpolated immediately. The plain `{}` is filled per option via `.format`.
      # The original text lacked a space after "activated." and a unit after the frame count.
      warning_prefix = (
          f"Audio input consists of only {total_input_frames} frames. "
          "Short-form transcription is activated. "
          "{}, but will be ignored."
      )
      unused_options = (
          ("condition_on_prev_tokens", condition_on_prev_tokens),
          ("compression_ratio_threshold", compression_ratio_threshold),
          ("logprob_threshold", logprob_threshold),
          ("no_speech_threshold", no_speech_threshold),
      )
      for name, value in unused_options:
          if value is not None:
              logger.warning(warning_prefix.format(f"{name} is set to {value}"))
  1190. @staticmethod
  1191. def _set_return_outputs(return_dict_in_generate, return_token_timestamps, logprob_threshold, generation_config):
  1192. if return_dict_in_generate is None:
  1193. return_dict_in_generate = generation_config.return_dict_in_generate
  1194. else:
  1195. generation_config.return_dict_in_generate = return_dict_in_generate
  1196. generation_config.return_token_timestamps = return_token_timestamps
  1197. if return_token_timestamps:
  1198. generation_config.return_dict_in_generate = True
  1199. generation_config.output_attentions = True
  1200. generation_config.output_scores = True
  1201. if logprob_threshold is not None:
  1202. generation_config.return_dict_in_generate = True
  1203. generation_config.output_scores = True
  1204. return return_dict_in_generate
  1205. def _set_return_timestamps(self, return_timestamps, is_shortform, generation_config):
  1206. if return_timestamps is None and hasattr(generation_config, "return_timestamps"):
  1207. return_timestamps = generation_config.return_timestamps
  1208. if not is_shortform:
  1209. if return_timestamps is False:
  1210. raise ValueError(
  1211. "You have passed more than 3000 mel input features (> 30 seconds) which automatically "
  1212. "enables long-form generation which requires the model to predict timestamp tokens. Please "
  1213. "either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
  1214. )
  1215. logger.info("Setting `return_timestamps=True` for long-form generation.")
  1216. return_timestamps = True
  1217. if return_timestamps and not hasattr(generation_config, "no_timestamps_token_id"):
  1218. raise ValueError(
  1219. "You are trying to return timestamps, but the generation config is not properly set. "
  1220. "Make sure to initialize the generation config with the correct attributes that are needed such as "
  1221. "`no_timestamps_token_id`. For more details on how to generate the approtiate config, refer to "
  1222. "https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
  1223. )
  1224. generation_config.return_timestamps = return_timestamps
  1225. if hasattr(generation_config, "no_timestamps_token_id"):
  1226. timestamp_begin = generation_config.no_timestamps_token_id + 1
  1227. else:
  1228. # BC for models missing the `no_timestamps_token_id` in the generation config when generating short-form
  1229. # with no timestamps. We set the timestamp begin token larger than the vocab size, such that the
  1230. # timestamp condition is never met in the decoding loop
  1231. timestamp_begin = self.config.vocab_size + 1
  1232. return timestamp_begin
  1233. @staticmethod
  1234. def _set_language_and_task(language, task, is_multilingual, generation_config):
  1235. if is_multilingual is not None:
  1236. if not hasattr(generation_config, "is_multilingual"):
  1237. raise ValueError(
  1238. "The generation config is outdated and is thus not compatible with the `is_multilingual` argument "
  1239. "to `generate`. Please update the generation config as per the instructions "
  1240. "https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
  1241. )
  1242. generation_config.is_multilingual = is_multilingual
  1243. if hasattr(generation_config, "is_multilingual") and not generation_config.is_multilingual:
  1244. if task is not None or language is not None:
  1245. raise ValueError(
  1246. "Cannot specify `task` or `language` for an English-only model. If the model is intended to be "
  1247. "multilingual, pass `is_multilingual=True` to generate, or update the generation config."
  1248. )
  1249. if language is not None:
  1250. if not hasattr(generation_config, "lang_to_id"):
  1251. raise ValueError(
  1252. "The generation config is outdated and is thus not compatible with the `language` argument "
  1253. "to `generate`. Please update the generation config as per the instructions "
  1254. "https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
  1255. )
  1256. generation_config.language = language
  1257. if task is not None:
  1258. if not hasattr(generation_config, "task_to_id"):
  1259. raise ValueError(
  1260. "The generation config is outdated and is thus not compatible with the `task` argument "
  1261. "to `generate`. Please update the generation config as per the instructions "
  1262. "https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
  1263. )
  1264. generation_config.task = task
  1265. def _retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs):
  1266. def replace_or_add(lst: list[int], num: int, itr: Iterator[int]):
  1267. """short function to replace num with a itr in lst"""
  1268. found = any(i in lst for i in itr)
  1269. if found:
  1270. lst = [num if i in itr else i for i in lst]
  1271. else:
  1272. lst.append(num)
  1273. return lst
  1274. def language_to_id(language: str) -> int:
  1275. language = language.lower()
  1276. if language in generation_config.lang_to_id:
  1277. language_token = language
  1278. elif language in TO_LANGUAGE_CODE:
  1279. language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
  1280. elif language in TO_LANGUAGE_CODE.values():
  1281. language_token = f"<|{language}|>"
  1282. else:
  1283. is_language_code = len(language) == 2
  1284. raise ValueError(
  1285. f"Unsupported language: {language}. Language should be one of:"
  1286. f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
  1287. )
  1288. if language_token not in generation_config.lang_to_id:
  1289. raise ValueError(
  1290. f"{language_token} is not supported by this specific model as it is not in the "
  1291. "`generation_config.lang_to_id`. (You should just add it to the generation config)"
  1292. )
  1293. return generation_config.lang_to_id[language_token]
  1294. task = getattr(generation_config, "task", None)
  1295. language = getattr(generation_config, "language", None)
  1296. init_tokens = [generation_config.decoder_start_token_id]
  1297. # TL;DR we silently ignore `forced_decoder_ids` (old flag) when `task` or `language` (new flags) are set.
  1298. # `forced_decoder_ids` is an old generation config attribute that is now deprecated in favor of `task` and
  1299. # `language` (see https://github.com/huggingface/transformers/pull/28687). Nevertheless, keep in mind that
  1300. # the original checkpoints all contain this attribute, and thus we should maintain backwards compatibility.
  1301. if task is None and language is None:
  1302. forced_decoder_ids = getattr(generation_config, "forced_decoder_ids", None)
  1303. # fallback: check the model config for forced_decoder_ids
  1304. if forced_decoder_ids is None and getattr(config, "forced_decoder_ids", None) is not None:
  1305. forced_decoder_ids = config.forced_decoder_ids
  1306. if forced_decoder_ids is not None:
  1307. logger.warning_once(
  1308. "Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of "
  1309. "the `task` and `language` flags/config options."
  1310. )
  1311. if forced_decoder_ids is not None and forced_decoder_ids[0][1] is None:
  1312. logger.warning_once(
  1313. "Transcription using a multilingual Whisper will default to language detection followed by "
  1314. "transcription instead of translation to English. This might be a breaking change for your "
  1315. "use case. If you want to instead always translate your audio to English, make sure to pass "
  1316. "`language='en'`. See https://github.com/huggingface/transformers/pull/28687 for more details."
  1317. )
  1318. if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1:
  1319. i = 1
  1320. while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i:
  1321. init_tokens += [forced_decoder_ids[0][1]]
  1322. forced_decoder_ids = forced_decoder_ids[1:]
  1323. i += 1
  1324. if len(forced_decoder_ids) > 0:
  1325. raise ValueError(
  1326. f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow "
  1327. f"the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all "
  1328. f"indices >= 1 and < {forced_decoder_ids[0][0]}.",
  1329. )
  1330. is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)
  1331. # Make sure language is a list of strings of the correct length
  1332. if isinstance(language, (list, tuple)):
  1333. if any(l is None for l in language):
  1334. raise TypeError(
  1335. "Expected `language` to be `None`, a single string (e.g. `'en'`), or a list of strings with "
  1336. "length equal to the batch size (e.g. `('en', 'fr')` for a batch size of 2). Got a list "
  1337. "containing `None`."
  1338. )
  1339. if len(language) != batch_size:
  1340. raise ValueError(
  1341. "When passing a list of languages, the length of the list must match the batch size. "
  1342. f"Expected length of {batch_size}, but got {len(language)} languages."
  1343. )
  1344. languages = language
  1345. elif language is None:
  1346. # Language will be detected for each item in batch
  1347. languages = [None] * batch_size
  1348. else:
  1349. languages = [language] # Use a length-1 list now, broadcast later
  1350. # Separate init_tokens for each language
  1351. init_tokens = [copy.copy(init_tokens) for _ in languages]
  1352. # Update init_tokens with languages
  1353. lang_ids = None
  1354. if language is not None:
  1355. lang_ids = [language_to_id(l) for l in languages]
  1356. elif hasattr(generation_config, "lang_to_id") and is_lang_id_undefined:
  1357. # language is not defined or intentionally set to `None` to trigger language detection
  1358. lang_ids = self.detect_language(
  1359. input_features=input_features,
  1360. encoder_outputs=kwargs.get("encoder_outputs", None),
  1361. generation_config=generation_config,
  1362. num_segment_frames=num_segment_frames,
  1363. ).tolist()
  1364. if lang_ids is not None:
  1365. # append or replace lang_ids to init_tokens
  1366. for i in range(len(init_tokens)):
  1367. if len(init_tokens[i]) > 1:
  1368. init_tokens[i][1] = lang_ids[i]
  1369. else:
  1370. init_tokens[i].append(lang_ids[i])
  1371. del languages
  1372. # Update init_tokens with task
  1373. for i in range(len(init_tokens)):
  1374. if task is not None:
  1375. if task in TASK_IDS:
  1376. init_tokens[i].append(generation_config.task_to_id[generation_config.task])
  1377. task_id = generation_config.task_to_id[generation_config.task]
  1378. # if task is defined it'll overwrite task ids that might have already been defined via the generation_config
  1379. replace_or_add(init_tokens[i], task_id, generation_config.task_to_id.values())
  1380. else:
  1381. raise ValueError(f"The `{task}` task is not supported. The task should be one of `{TASK_IDS}`")
  1382. elif language is not None and hasattr(generation_config, "task_to_id"):
  1383. # if language is defined, but no task id is in `init_tokens`, default to transcribe
  1384. if not any(ti in init_tokens[i] for ti in generation_config.task_to_id.values()):
  1385. init_tokens[i].append(generation_config.task_to_id["transcribe"])
  1386. if (
  1387. not generation_config.return_timestamps
  1388. and hasattr(generation_config, "no_timestamps_token_id")
  1389. and init_tokens[i][-1] != generation_config.no_timestamps_token_id
  1390. ):
  1391. init_tokens[i].append(generation_config.no_timestamps_token_id)
  1392. elif (
  1393. generation_config.return_timestamps and init_tokens[i][-1] == generation_config.no_timestamps_token_id
  1394. ):
  1395. logger.info(
  1396. "<|notimestamps|> prompt token is removed from generation_config since `return_timestamps` is set to `'True'`."
  1397. )
  1398. init_tokens[i] = init_tokens[i][:-1]
  1399. # let's make sure we don't pass `None` tokens as prompt tokens
  1400. init_tokens[i] = [t for t in init_tokens[i] if t is not None]
  1401. return torch.as_tensor(init_tokens, dtype=torch.long, device=self.device).expand(batch_size, -1)
  def detect_language(
      self,
      input_features: torch.FloatTensor | None = None,
      encoder_outputs: torch.FloatTensor | BaseModelOutput | None = None,
      generation_config: GenerationConfig | None = None,
      num_segment_frames: int = 3000,
  ) -> torch.Tensor:
      """
      Detects language from log-mel input features or encoder_outputs

      Parameters:
          input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
              Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
              loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via
              the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
              [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
              tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
              Only the first `num_segment_frames` frames are used.
          encoder_outputs (`BaseModelOutput`, `tuple(torch.FloatTensor)` or `torch.FloatTensor`, *optional*):
              Either a `BaseModelOutput` / tuple whose first element is `last_hidden_state`, or `last_hidden_state`
              itself. `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of
              hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
          generation_config (`~generation.GenerationConfig`, *optional*):
              The generation configuration to be used as base parametrization for the generation call. `**kwargs`
              passed to generate matching the attributes of `generation_config` will override them. If
              `generation_config` is not provided, the default will be used, which has the following loading
              priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
              configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
              default values, whose documentation should be checked to parameterize generation.
          num_segment_frames (`int`, *optional*, defaults to 3000):
              The number of log-mel frames the model expects

      Return:
          A `torch.LongTensor` representing the detected language ids.
      """
      if input_features is None and encoder_outputs is None:
          raise ValueError("You have to specify either `input_features` or `encoder_outputs`")
      elif input_features is not None and encoder_outputs is not None:
          raise ValueError("Make sure to specify only one of `input_features` or `encoder_outputs` - not both!")
      elif input_features is not None:
          inputs = {"input_features": input_features[:, :, :num_segment_frames]}
          batch_size = input_features.shape[0]
      elif encoder_outputs is not None:
          inputs = {"encoder_outputs": encoder_outputs}
          # The batch size is the leading dim of the encoder hidden states. For a `BaseModelOutput` or a
          # tuple/list these are `encoder_outputs[0]`; a bare tensor is the hidden states itself.
          # Previously a bare tensor yielded `encoder_outputs[0]` (a tensor slice) instead of an int.
          # NOTE(review): whether the forward pass accepts a bare tensor is decided by `forward` - confirm.
          hidden_states = (
              encoder_outputs[0] if isinstance(encoder_outputs, (BaseModelOutput, tuple, list)) else encoder_outputs
          )
          batch_size = hidden_states.shape[0]

      generation_config = generation_config or self.generation_config
      # A single decoder step from the start token; its logits are read to pick the language token.
      decoder_input_ids = (
          torch.ones((batch_size, 1), device=self.device, dtype=torch.long)
          * generation_config.decoder_start_token_id
      )

      with torch.no_grad():
          logits = self(**inputs, decoder_input_ids=decoder_input_ids, use_cache=False).logits[:, -1]

      # Restrict the argmax to language tokens by masking every other vocabulary entry.
      non_lang_mask = torch.ones_like(logits[0], dtype=torch.bool)
      non_lang_mask[list(generation_config.lang_to_id.values())] = False

      logits[:, non_lang_mask] = -np.inf

      lang_ids = logits.argmax(-1)

      return lang_ids
  1458. @staticmethod
  1459. def _check_decoder_input_ids(kwargs):
  1460. decoder_input_ids = kwargs.get("decoder_input_ids", None)
  1461. assistant_model = kwargs.get("assistant_model", None)
  1462. if decoder_input_ids is not None and assistant_model is not None:
  1463. raise ValueError(
  1464. "Passing `decoder_input_ids` is deprecated. Consider passing `prompt_ids` instead.",
  1465. )
  1466. @staticmethod
  1467. def _set_num_frames(return_token_timestamps, generation_config, attention_mask, kwargs):
  1468. if return_token_timestamps:
  1469. if getattr(generation_config, "task", None) == "translate":
  1470. logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
  1471. if not hasattr(generation_config, "alignment_heads"):
  1472. raise ValueError(
  1473. "Model generation config has no `alignment_heads`, token-level timestamps not available. "
  1474. "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
  1475. )
  1476. if attention_mask is not None:
  1477. generation_config.num_frames = attention_mask.sum(-1).cpu()
  1478. else:
  1479. logger.warning_once(
  1480. "When setting `return_token_timestamps` to `True`, make sure to pass an `attention_mask` to get precise token-level timestamps. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` "
  1481. )
  1482. generation_config.num_frames = None
  1483. @staticmethod
  1484. def _set_thresholds_and_condition(
  1485. generation_config,
  1486. logprob_threshold,
  1487. compression_ratio_threshold,
  1488. no_speech_threshold,
  1489. condition_on_prev_tokens,
  1490. ):
  1491. generation_config.logprob_threshold = (
  1492. logprob_threshold
  1493. if logprob_threshold is not None
  1494. else getattr(generation_config, "logprob_threshold", None)
  1495. )
  1496. generation_config.compression_ratio_threshold = (
  1497. compression_ratio_threshold
  1498. if compression_ratio_threshold is not None
  1499. else getattr(generation_config, "compression_ratio_threshold", None)
  1500. )
  1501. generation_config.no_speech_threshold = (
  1502. no_speech_threshold
  1503. if no_speech_threshold is not None
  1504. else getattr(generation_config, "no_speech_threshold", None)
  1505. )
  1506. generation_config.condition_on_prev_tokens = (
  1507. condition_on_prev_tokens
  1508. if condition_on_prev_tokens is not None
  1509. else getattr(generation_config, "condition_on_prev_tokens", None)
  1510. )
  1511. @staticmethod
  1512. def _set_prompt_condition_type(generation_config, prompt_condition_type):
  1513. allowed_cond_types = ["first-segment", "all-segments"]
  1514. # default to "first-segment"
  1515. prompt_condition_type = prompt_condition_type or allowed_cond_types[0]
  1516. if prompt_condition_type not in allowed_cond_types:
  1517. raise ValueError(
  1518. f"`prompt_condition_type={prompt_condition_type} does not exist. Make sure to set `prompt_condition_type` to one of {', '.join(allowed_cond_types)}"
  1519. )
  1520. if generation_config.condition_on_prev_tokens is not True and prompt_condition_type == "all-segments":
  1521. raise ValueError(
  1522. "Make sure to set `condition_on_prev_tokens=True` when setting `prompt_condition_type='all-segments'`."
  1523. )
  1524. generation_config.prompt_condition_type = prompt_condition_type
  1525. @staticmethod
  1526. def _set_condition_on_prev_tokens(condition_on_prev_tokens, generation_config):
  1527. condition_on_prev_tokens = (
  1528. condition_on_prev_tokens
  1529. if condition_on_prev_tokens is not None
  1530. else getattr(generation_config, "condition_on_prev_tokens", False)
  1531. )
  1532. generation_config.condition_on_prev_tokens = condition_on_prev_tokens
  1533. @staticmethod
  1534. def _retrieve_max_frames_and_seek(batch_size, attention_mask, total_input_frames, is_shortform):
  1535. if batch_size > 1 and not is_shortform and attention_mask is None:
  1536. raise ValueError(
  1537. "When doing batched long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` "
  1538. )
  1539. elif batch_size > 1 and not is_shortform:
  1540. max_frames = attention_mask.sum(-1).cpu().to(torch.long)
  1541. seek = torch.zeros((batch_size,), dtype=torch.long)
  1542. else:
  1543. max_frames = torch.ones((batch_size,), dtype=torch.long) * total_input_frames
  1544. seek = torch.zeros((batch_size,), dtype=torch.long)
  1545. return max_frames, seek
  def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, num_beams, device):
      """Prepend the Whisper-specific logits processors implied by ``generation_config``.

      The processors are created in this order:

      * ``WhisperTimeStampLogitsProcessor``, if ``return_timestamps is True``;
      * ``SuppressTokensLogitsProcessor``, if ``suppress_tokens`` is set;
      * ``SuppressTokensAtBeginLogitsProcessor``, if ``begin_suppress_tokens`` is set;
      * ``WhisperNoSpeechDetection``, if ``no_speech_threshold`` is set.

      Each one is placed in front of the previous ones, so the result reads
      ``[no_speech, begin_suppress, suppress, timestamp, *logits_processor]``, with absent entries skipped.
      ``suppress_tokens`` and ``begin_suppress_tokens`` are reset to ``None`` on the config once they have
      been turned into processors.

      Returns:
          A new list, or ``logits_processor`` itself (possibly ``None``) when nothing was added.
      """
      prepended = []

      if generation_config.return_timestamps is True:
          prepended.insert(0, WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index))

      if generation_config.suppress_tokens is not None:
          prepended.insert(0, SuppressTokensLogitsProcessor(generation_config.suppress_tokens, device=device))
          generation_config.suppress_tokens = None

      if generation_config.begin_suppress_tokens is not None:
          begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor(
              generation_config.begin_suppress_tokens, begin_index=begin_index, device=device
          )
          prepended.insert(0, begin_suppress_processor)
          generation_config.begin_suppress_tokens = None

      if generation_config.no_speech_threshold is not None:
          no_speech_detector = WhisperNoSpeechDetection(
              no_speech_token=generation_config.no_timestamps_token_id - 1,
              begin_index=begin_index,
              scores_is_logprobs=num_beams > 1,
          )
          prepended.insert(0, no_speech_detector)
          no_speech_detector.set_model(self)

      if not prepended:
          return logits_processor
      return prepended + ([] if logits_processor is None else logits_processor)
  1581. @staticmethod
  1582. def _maybe_reduce_batch(input_features, seek, max_frames, cur_bsz, batch_idx_map):
  1583. prev_bsz = cur_bsz
  1584. new_batch_idx_map = []
  1585. for i in range(prev_bsz):
  1586. prev_i = batch_idx_map[i]
  1587. if seek[prev_i] >= max_frames[prev_i]:
  1588. cut_index = i + (cur_bsz - prev_bsz)
  1589. cur_bsz -= 1
  1590. input_features = torch.cat([input_features[:cut_index], input_features[cut_index + 1 :]], dim=0)
  1591. else:
  1592. # cut out index that goes away
  1593. new_batch_idx_map.append(prev_i)
  1594. return input_features, cur_bsz, new_batch_idx_map
  1595. @staticmethod
  1596. def _get_input_segment(input_features, seek, seek_num_frames, num_segment_frames, cur_bsz, batch_idx_map):
  1597. if input_features is None:
  1598. return None
  1599. segment_input = []
  1600. for i in range(cur_bsz):
  1601. prev_i = batch_idx_map[i]
  1602. segment_input_slice = input_features[i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i]]
  1603. if segment_input_slice.shape[-1] < num_segment_frames:
  1604. # pad to 3000 if necessary
  1605. segment_input_slice = F.pad(
  1606. segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1])
  1607. )
  1608. segment_input.append(segment_input_slice)
  1609. segment_input = torch.cat(segment_input, dim=0)
  1610. return segment_input
  1611. @staticmethod
  1612. def _prepare_decoder_input_ids(
  1613. cur_bsz,
  1614. init_tokens,
  1615. current_segments,
  1616. batch_idx_map,
  1617. do_condition_on_prev_tokens,
  1618. prompt_ids,
  1619. generation_config,
  1620. config,
  1621. device,
  1622. suppress_tokens,
  1623. timestamp_begin,
  1624. kwargs,
  1625. ):
  1626. if "decoder_input_ids" in kwargs:
  1627. decoder_input_ids = kwargs.pop("decoder_input_ids")
  1628. return decoder_input_ids, kwargs
  1629. cut_off_length = config.max_target_positions // 2 - 1
  1630. decoder_input_ids = init_tokens[batch_idx_map]
  1631. prev_start_of_text = getattr(generation_config, "prev_sot_token_id", None)
  1632. if prev_start_of_text is None:
  1633. if suppress_tokens is not None and len(suppress_tokens) >= 2:
  1634. prev_start_of_text = suppress_tokens[-2]
  1635. else:
  1636. prev_start_of_text = None
  1637. if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0:
  1638. # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609
  1639. active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map]
  1640. if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments":
  1641. prev_ids = prompt_ids
  1642. else:
  1643. one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long)
  1644. prev_ids = prev_start_of_text * one_tensor[0] if prev_start_of_text is not None else None
  1645. padding = "max_length" if generation_config.cache_implementation == "static" else "longest"
  1646. prev_tokens = _pad_to_max_length(
  1647. active_segments,
  1648. generation_config.pad_token_id,
  1649. device=device,
  1650. padding_side="left",
  1651. padding=padding,
  1652. bos_token_tensor=prev_ids,
  1653. cut_off_length=cut_off_length,
  1654. skip_ending_double_timestamps=True,
  1655. timestamp_begin=timestamp_begin,
  1656. )
  1657. decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)
  1658. kwargs["decoder_attention_mask"] = decoder_input_ids != generation_config.pad_token_id
  1659. elif prompt_ids is not None:
  1660. prev_tokens = prompt_ids[None].repeat(decoder_input_ids.shape[0], 1)
  1661. decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)
  1662. # make sure `"decoder_attention_mask"` is not passed to forward
  1663. kwargs.pop("decoder_attention_mask", None)
  1664. else:
  1665. # make sure `"decoder_attention_mask"` is not passed to forward
  1666. kwargs.pop("decoder_attention_mask", None)
  1667. return decoder_input_ids, kwargs
  1668. def _set_max_new_tokens_and_length(self, config, decoder_input_ids, generation_config):
  1669. max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0
  1670. if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions:
  1671. raise ValueError(
  1672. f"The length of `decoder_input_ids`, including special start tokens, prompt tokens, and previous tokens, is {decoder_input_ids.shape[-1]}, "
  1673. f" and `max_new_tokens` is {max_new_tokens}. Thus, the combined length of "
  1674. f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the "
  1675. f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. "
  1676. "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
  1677. f"so that their combined length is less than {self.config.max_target_positions}."
  1678. )
  1679. num_initial_tokens = min(config.max_target_positions // 2 - 1, decoder_input_ids.shape[-1] - 1)
  1680. # Make sure we don't get larger than `max_length`
  1681. if generation_config.max_length is not None and generation_config.max_new_tokens is None:
  1682. max_length = min(generation_config.max_length + num_initial_tokens, config.max_target_positions)
  1683. logger.info(
  1684. f"Increase max_length from {generation_config.max_length} to {max_length} since input is conditioned on previous segment."
  1685. )
  1686. elif (
  1687. generation_config.max_new_tokens is not None
  1688. and generation_config.max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions
  1689. ):
  1690. max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1]
  1691. generation_config.max_new_tokens = max_new_tokens
  1692. @staticmethod
  1693. def _retrieve_compression_ratio(tokens, vocab_size):
  1694. """Compute byte length of zlib compressed token bytes vs. byte length of raw token bytes"""
  1695. length = int(math.log2(vocab_size) / 8) + 1
  1696. token_bytes = b"".join([t.to_bytes(length, "little") for t in tokens.tolist()])
  1697. compression_ratio = len(token_bytes) / len(zlib.compress(token_bytes))
  1698. return compression_ratio
  1699. @staticmethod
  1700. def _retrieve_avg_logprobs(scores, tokens, temperature):
  1701. rescale_temperature = temperature if temperature > 0.0 else 1
  1702. scores = torch.stack(scores).to(tokens.device)
  1703. if scores.shape[0] > tokens.shape[0]:
  1704. scores = scores[: tokens.shape[0]]
  1705. else:
  1706. tokens = tokens[-scores.shape[0] :]
  1707. logprobs = F.log_softmax((scores * rescale_temperature).float(), dim=-1).to(scores.dtype)
  1708. # retrieve logprob of selected tokens and sum
  1709. # don't remove the eos token logprob! it counts in avg_logprob calculation in the original implementation
  1710. sum_logprobs = sum(logprobs[i][tokens[i]] for i in range(logprobs.shape[0]))
  1711. avg_logprobs = sum_logprobs / len(tokens)
  1712. return avg_logprobs
  1713. @staticmethod
  1714. def _retrieve_segment(
  1715. seek_sequence,
  1716. seek_outputs,
  1717. time_offset,
  1718. timestamp_begin,
  1719. seek_num_frames,
  1720. time_precision,
  1721. time_precision_features,
  1722. input_stride,
  1723. prev_idx,
  1724. idx,
  1725. return_token_timestamps,
  1726. decoder_input_ids,
  1727. ):
  1728. # find the predicted "end of segment" predictions of Whisper
  1729. # "end of segment" predictions occur whenever Whisper predicts a timestamp token
  1730. timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
  1731. single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
  1732. timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
  1733. timestamp_segment_indices.add_(1)
  1734. token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
  1735. idx_offset = decoder_input_ids.shape[-1]
  1736. device = seek_sequence.device
  1737. # If whisper predicted a "end of segment" via a timestep token, let's go ever each
  1738. # "end of segment" prediction and slice the decoding into segments accordingly
  1739. if len(timestamp_segment_indices) > 0:
  1740. # if the output contains two consecutive timestamp tokens
  1741. slices = timestamp_segment_indices.tolist()
  1742. segments = []
  1743. if single_timestamp_ending:
  1744. slices.append(len(seek_sequence))
  1745. else:
  1746. # we want to include the last timestamp token in the last segment to know it was no single ending
  1747. slices[-1] += 1
  1748. last_slice = 0
  1749. # Add each segment to list of all segments
  1750. for i, current_slice in enumerate(slices):
  1751. is_last_slice = i == len(slices) - 1
  1752. sliced_tokens = seek_sequence[last_slice:current_slice]
  1753. start_timestamp_pos = sliced_tokens[0] - timestamp_begin
  1754. idx_sliced_tokens = -1 if not is_last_slice or single_timestamp_ending else -2
  1755. end_timestamp_pos = sliced_tokens[idx_sliced_tokens] - timestamp_begin
  1756. segments.append(
  1757. {
  1758. "start": time_offset[prev_idx]
  1759. + start_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
  1760. * time_precision,
  1761. "end": time_offset[prev_idx]
  1762. + end_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
  1763. * time_precision,
  1764. "tokens": sliced_tokens,
  1765. "idxs": (idx_offset + last_slice, idx_offset + current_slice),
  1766. "result": seek_outputs[idx],
  1767. }
  1768. )
  1769. if return_token_timestamps:
  1770. segments[-1]["token_timestamps"] = (
  1771. token_timestamps[idx_offset + last_slice : idx_offset + current_slice] + time_offset[prev_idx]
  1772. )
  1773. last_slice = current_slice
  1774. if single_timestamp_ending:
  1775. # single timestamp at the end means no speech after the last timestamp.
  1776. segment_offset = seek_num_frames[prev_idx]
  1777. else:
  1778. # otherwise, ignore the unfinished segment and seek to the last timestamp
  1779. # here we throw away all predictions after the last predicted "end of segment"
  1780. # since we are cutting right in the middle of an audio
  1781. last_timestamp_pos = seek_sequence[last_slice - 2].item() - timestamp_begin
  1782. segment_offset = last_timestamp_pos * input_stride
  1783. else:
  1784. # If whisper does not predict any "end of segment" token, then
  1785. # the whole decoding is considered a segment and we add it to the list of segments
  1786. timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
  1787. last_timestamp_pos = int(seek_num_frames[prev_idx] * time_precision_features / time_precision)
  1788. if timestamps.numel() > 0 and timestamps[-1] != timestamp_begin:
  1789. # no consecutive timestamps but it has a timestamp; use the last one.
  1790. last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(
  1791. torch.float32 if device.type == "mps" else torch.float64
  1792. )
  1793. segments = [
  1794. {
  1795. "start": time_offset[prev_idx],
  1796. "end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
  1797. "tokens": seek_sequence,
  1798. "idxs": (idx_offset, idx_offset + len(seek_sequence)),
  1799. "result": seek_outputs[idx],
  1800. }
  1801. ]
  1802. if return_token_timestamps:
  1803. segments[-1]["token_timestamps"] = (
  1804. token_timestamps[idx_offset : idx_offset + len(seek_sequence)] + time_offset[prev_idx]
  1805. )
  1806. segment_offset = seek_num_frames[prev_idx]
  1807. return segments, segment_offset