tokenization_wav2vec2.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653
  1. # Copyright 2021 The Facebook Inc. and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Tokenization class for Wav2Vec2."""
  15. import json
  16. import os
  17. from dataclasses import dataclass
  18. from itertools import groupby
  19. from typing import TYPE_CHECKING, Union
  20. import numpy as np
  21. from ...tokenization_python import PreTrainedTokenizer
  22. from ...tokenization_utils_base import AddedToken
  23. from ...utils import (
  24. ModelOutput,
  25. logging,
  26. to_py_obj,
  27. )
# Module-level logger created through the library's `logging` utilities imported above.
logger = logging.get_logger(__name__)

# `torch` is only imported while type checking (TYPE_CHECKING is False at runtime); it is referenced
# by string annotations such as `"torch.Tensor"` in `decode` / `batch_decode`.
if TYPE_CHECKING:
    import torch
  31. VOCAB_FILES_NAMES = {
  32. "vocab_file": "vocab.json",
  33. "tokenizer_config_file": "tokenizer_config.json",
  34. }
  35. # Wav2Vec2 has no max input length
  36. WAV2VEC2_KWARGS_DOCSTRING = r"""
  37. padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
  38. Activates and controls padding. Accepts the following values:
  39. - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
  40. sequence if provided).
  41. - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
  42. acceptable input length for the model if that argument is not provided.
  43. - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
  44. lengths).
  45. max_length (`int`, *optional*):
  46. Controls the maximum length to use by one of the truncation/padding parameters.
  47. If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
  48. is required by one of the truncation/padding parameters. If the model has no specific maximum input
  49. length (like XLNet) truncation/padding to a maximum length will be deactivated.
  50. pad_to_multiple_of (`int`, *optional*):
  51. If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
  52. the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
  53. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  54. If set, will return tensors instead of list of python integers. Acceptable values are:
  55. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  56. - `'np'`: Return Numpy `np.ndarray` objects.
  57. verbose (`bool`, *optional*, defaults to `True`):
  58. Whether or not to print more information and warnings.
  59. """
  60. ListOfDict = list[dict[str, int | str]]
  61. @dataclass
  62. class Wav2Vec2CTCTokenizerOutput(ModelOutput):
  63. """
  64. Output type of [` Wav2Vec2CTCTokenizer`], with transcription.
  65. Args:
  66. text (list of `str` or `str`):
  67. Decoded logits in text from. Usually the speech transcription.
  68. char_offsets (list of `list[dict[str, Union[int, str]]]` or `list[dict[str, Union[int, str]]]`):
  69. Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char
  70. offsets can be used to compute time stamps for each character. Total logit score of the beam associated with
  71. produced text.
  72. word_offsets (list of `list[dict[str, Union[int, str]]]` or `list[dict[str, Union[int, str]]]`):
  73. Offsets of the decoded words. In combination with sampling rate and model downsampling rate word offsets
  74. can be used to compute time stamps for each word.
  75. """
  76. text: list[str] | str
  77. char_offsets: list[ListOfDict] | ListOfDict = None
  78. word_offsets: list[ListOfDict] | ListOfDict = None
  79. class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
  80. """
  81. Constructs a Wav2Vec2CTC tokenizer.
  82. This tokenizer inherits from [`PreTrainedTokenizer`] which contains some of the main methods. Users should refer to
  83. the superclass for more information regarding such methods.
  84. Args:
  85. vocab_file (`str`):
  86. File containing the vocabulary.
  87. bos_token (`str`, *optional*, defaults to `"<s>"`):
  88. The beginning of sentence token.
  89. eos_token (`str`, *optional*, defaults to `"</s>"`):
  90. The end of sentence token.
  91. unk_token (`str`, *optional*, defaults to `"<unk>"`):
  92. The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
  93. token instead.
  94. pad_token (`str`, *optional*, defaults to `"<pad>"`):
  95. The token used for padding, for example when batching sequences of different lengths.
  96. word_delimiter_token (`str`, *optional*, defaults to `"|"`):
  97. The token used for defining the end of a word.
  98. do_lower_case (`bool`, *optional*, defaults to `False`):
  99. Whether or not to accept lowercase input and lowercase the output when decoding.
  100. target_lang (`str`, *optional*):
  101. A target language the tokenizer should set by default. `target_lang` has to be defined for multi-lingual,
  102. nested vocabulary such as [facebook/mms-1b-all](https://huggingface.co/facebook/mms-1b-all).
  103. **kwargs
  104. Additional keyword arguments passed along to [`PreTrainedTokenizer`]
  105. """
  106. vocab_files_names = VOCAB_FILES_NAMES
  107. model_input_names = ["input_ids", "attention_mask"]
  108. def __init__(
  109. self,
  110. vocab_file,
  111. bos_token="<s>",
  112. eos_token="</s>",
  113. unk_token="<unk>",
  114. pad_token="<pad>",
  115. word_delimiter_token="|",
  116. replace_word_delimiter_char=" ",
  117. do_lower_case=False,
  118. target_lang=None,
  119. **kwargs,
  120. ):
  121. self._word_delimiter_token = word_delimiter_token
  122. self.do_lower_case = do_lower_case
  123. self.replace_word_delimiter_char = replace_word_delimiter_char
  124. self.target_lang = target_lang
  125. with open(vocab_file, encoding="utf-8") as vocab_handle:
  126. self.vocab = json.load(vocab_handle)
  127. # if target lang is defined vocab must be a nested dict
  128. # with each target lang being one vocabulary
  129. if target_lang is not None:
  130. self.encoder = self.vocab[target_lang]
  131. else:
  132. self.encoder = self.vocab
  133. self.decoder = {v: k for k, v in self.encoder.items()}
  134. super().__init__(
  135. unk_token=unk_token,
  136. bos_token=bos_token,
  137. eos_token=eos_token,
  138. pad_token=pad_token,
  139. do_lower_case=do_lower_case,
  140. word_delimiter_token=word_delimiter_token,
  141. replace_word_delimiter_char=replace_word_delimiter_char,
  142. target_lang=target_lang,
  143. special_tokens_pattern="none",
  144. **kwargs,
  145. )
  146. # make sure that tokens made of several
  147. # characters are not split at tokenization
  148. for token in self.encoder:
  149. if len(token) > 1:
  150. self.add_tokens(AddedToken(token, rstrip=True, lstrip=True, normalized=False))
  151. def set_target_lang(self, target_lang: str):
  152. """
  153. Set the target language of a nested multi-lingual dictionary
  154. """
  155. if self.vocab == self.encoder:
  156. raise ValueError(f"{self.vocab} is not a multi-lingual, nested tokenizer. Cannot set target language.")
  157. if target_lang not in self.vocab:
  158. raise ValueError(f"{target_lang} does not exist. Choose one of {', '.join(self.vocab.keys())}.")
  159. self.target_lang = target_lang
  160. self.init_kwargs["target_lang"] = target_lang
  161. self.encoder = self.vocab[target_lang]
  162. self.decoder = {v: k for k, v in self.encoder.items()}
  163. # Remove conflicting entries from _added_tokens_decoder so vocabulary tokens take precedence
  164. for token_id in list(self._added_tokens_decoder.keys()):
  165. if token_id in self.decoder:
  166. del self._added_tokens_decoder[token_id]
  167. # make sure that tokens made of several
  168. # characters are not split at tokenization
  169. for token in self.encoder:
  170. if len(token) > 1:
  171. self.add_tokens(AddedToken(token, rstrip=True, lstrip=True, normalized=False))
  172. @property
  173. def word_delimiter_token(self) -> str:
  174. """
  175. `str`: Word delimiter token. Log an error if used while not having been set.
  176. """
  177. if self._word_delimiter_token is None and self.verbose:
  178. logger.error("Using word_delimiter_token, but it is not set yet.")
  179. return None
  180. return str(self._word_delimiter_token)
  181. @property
  182. def word_delimiter_token_id(self) -> int | None:
  183. """
  184. `Optional[int]`: Id of the word_delimiter_token in the vocabulary. Returns `None` if the token has not been
  185. set.
  186. """
  187. if self._word_delimiter_token is None:
  188. return None
  189. return self.convert_tokens_to_ids(self.word_delimiter_token)
  190. @word_delimiter_token.setter
  191. def word_delimiter_token(self, value):
  192. self._word_delimiter_token = value
  193. @word_delimiter_token_id.setter
  194. def word_delimiter_token_id(self, value):
  195. self._word_delimiter_token = self.convert_tokens_to_ids(value)
  196. @property
  197. def vocab_size(self) -> int:
  198. return len(self.decoder)
  199. def get_vocab(self) -> dict:
  200. vocab = dict(self.encoder)
  201. vocab.update(self.added_tokens_encoder)
  202. return vocab
  203. def _add_tokens(self, new_tokens: list[str] | list[AddedToken], special_tokens: bool = False) -> int:
  204. # Overwritten to never strip!
  205. to_add = []
  206. for token in new_tokens:
  207. if isinstance(token, str):
  208. to_add.append(AddedToken(token, rstrip=False, lstrip=False, normalized=False))
  209. else:
  210. to_add.append(token)
  211. return super()._add_tokens(to_add, special_tokens)
  212. def _tokenize(self, text, **kwargs):
  213. """
  214. Converts a string into a sequence of tokens (string), using the tokenizer.
  215. """
  216. if self.do_lower_case:
  217. text = text.upper()
  218. return list(text.replace(" ", self.word_delimiter_token))
  219. def _convert_token_to_id(self, token: str) -> int:
  220. """Converts a token (str) in an index (integer) using the vocab."""
  221. return self.encoder.get(token, self.encoder.get(self.unk_token))
  222. def _convert_id_to_token(self, index: int) -> str:
  223. """Converts an index (integer) in a token (str) using the vocab."""
  224. result = self.decoder.get(index, self.unk_token)
  225. return result
  226. def convert_ids_to_tokens(self, ids: int | list[int], skip_special_tokens: bool = False) -> str | list[str]:
  227. """Overridden to prioritize vocabulary tokens over added tokens for nested vocabularies."""
  228. if isinstance(ids, int):
  229. if ids in self.decoder:
  230. return self.decoder[ids]
  231. return self._added_tokens_decoder[ids].content if ids in self._added_tokens_decoder else self.unk_token
  232. tokens = []
  233. for index in ids:
  234. index = int(index)
  235. if skip_special_tokens and index in self.all_special_ids:
  236. continue
  237. if index in self.decoder:
  238. tokens.append(self.decoder[index])
  239. elif index in self._added_tokens_decoder:
  240. tokens.append(self._added_tokens_decoder[index].content)
  241. else:
  242. tokens.append(self.unk_token)
  243. return tokens
  244. def convert_tokens_to_string(
  245. self,
  246. tokens: list[str],
  247. group_tokens: bool = True,
  248. spaces_between_special_tokens: bool = False,
  249. output_char_offsets: bool = False,
  250. output_word_offsets: bool = False,
  251. ) -> dict[str, str | float]:
  252. """
  253. Converts a connectionist-temporal-classification (CTC) output tokens into a single string.
  254. """
  255. if len(tokens) == 0:
  256. return {"text": "", "char_offsets": [], "word_offsets": []}
  257. # group same tokens into non-repeating tokens in CTC style decoding
  258. if group_tokens:
  259. chars, char_repetitions = zip(*((token, len(list(group_iter))) for token, group_iter in groupby(tokens)))
  260. else:
  261. chars = tokens
  262. char_repetitions = len(tokens) * [1]
  263. # filter self.pad_token which is used as CTC-blank token
  264. processed_chars = list(filter(lambda char: char != self.pad_token, chars))
  265. # replace delimiter token
  266. processed_chars = [
  267. self.replace_word_delimiter_char if char == self.word_delimiter_token else char for char in processed_chars
  268. ]
  269. # retrieve offsets
  270. char_offsets = word_offsets = None
  271. if output_char_offsets or output_word_offsets:
  272. char_offsets = self._compute_offsets(char_repetitions, chars, self.pad_token)
  273. if len(char_offsets) != len(processed_chars):
  274. raise ValueError(
  275. f"`char_offsets`: {char_offsets} and `processed_tokens`: {processed_chars}"
  276. " have to be of the same length, but are: "
  277. f"`len(offsets)`: {len(char_offsets)} and `len(processed_tokens)`:"
  278. f" {len(processed_chars)}"
  279. )
  280. # set tokens to correct processed token
  281. for i, char in enumerate(processed_chars):
  282. char_offsets[i]["char"] = char
  283. # retrieve word offsets from character offsets
  284. word_offsets = None
  285. if output_word_offsets:
  286. word_offsets = self._get_word_offsets(char_offsets, self.replace_word_delimiter_char)
  287. # don't output chars if not set to True
  288. if not output_char_offsets:
  289. char_offsets = None
  290. # join to string
  291. join_char = " " if spaces_between_special_tokens else ""
  292. string = join_char.join(processed_chars).strip()
  293. if self.do_lower_case:
  294. string = string.lower()
  295. return {"text": string, "char_offsets": char_offsets, "word_offsets": word_offsets}
  296. @staticmethod
  297. def _compute_offsets(char_repetitions: list[int], chars: list[str], ctc_token: int) -> list[dict[str, str | int]]:
  298. end_indices = np.asarray(char_repetitions).cumsum()
  299. start_indices = np.concatenate(([0], end_indices[:-1]))
  300. offsets = [
  301. {"char": t, "start_offset": s, "end_offset": e} for t, s, e in zip(chars, start_indices, end_indices)
  302. ]
  303. # filter out CTC token
  304. offsets = list(filter(lambda offsets: offsets["char"] != ctc_token, offsets))
  305. return offsets
  306. @staticmethod
  307. def _get_word_offsets(offsets: dict[str, str | float], word_delimiter_char: str = " ") -> dict[str, str | float]:
  308. word_offsets = []
  309. last_state = "SPACE"
  310. word = ""
  311. start_offset = 0
  312. end_offset = 0
  313. for i, offset in enumerate(offsets):
  314. char = offset["char"]
  315. state = "SPACE" if char == word_delimiter_char else "WORD"
  316. if state == last_state:
  317. # If we are in the same state as before, we simply repeat what we've done before
  318. end_offset = offset["end_offset"]
  319. word += char
  320. else:
  321. # Switching state
  322. if state == "SPACE":
  323. # Finishing a word
  324. word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
  325. else:
  326. # Starting a new word
  327. start_offset = offset["start_offset"]
  328. end_offset = offset["end_offset"]
  329. word = char
  330. last_state = state
  331. if last_state == "WORD":
  332. word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
  333. return word_offsets
  334. def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
  335. if is_split_into_words:
  336. text = " " + text
  337. return (text, kwargs)
  338. def _decode(
  339. self,
  340. token_ids: list[int],
  341. skip_special_tokens: bool = False,
  342. clean_up_tokenization_spaces: bool | None = None,
  343. group_tokens: bool = True,
  344. spaces_between_special_tokens: bool = False,
  345. output_word_offsets: bool | None = False,
  346. output_char_offsets: bool | None = False,
  347. ) -> str:
  348. """
  349. special _decode function is needed because added tokens should be treated exactly the
  350. same as tokens of the base vocabulary and therefore the function `convert_tokens_to_string` has to be called on
  351. the whole token list and not individually on added tokens
  352. """
  353. # Don't skip special tokens in convert_ids_to_tokens so we can handle word_delimiter_token specially
  354. filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=False)
  355. result = []
  356. for token in filtered_tokens:
  357. if skip_special_tokens and token in self.all_special_tokens and token != self.word_delimiter_token:
  358. continue
  359. result.append(token)
  360. string_output = self.convert_tokens_to_string(
  361. result,
  362. group_tokens=group_tokens,
  363. spaces_between_special_tokens=spaces_between_special_tokens,
  364. output_word_offsets=output_word_offsets,
  365. output_char_offsets=output_char_offsets,
  366. )
  367. text = string_output["text"]
  368. clean_up_tokenization_spaces = (
  369. clean_up_tokenization_spaces
  370. if clean_up_tokenization_spaces is not None
  371. else self.clean_up_tokenization_spaces
  372. )
  373. if clean_up_tokenization_spaces:
  374. text = self.clean_up_tokenization(text)
  375. if output_word_offsets or output_char_offsets:
  376. return Wav2Vec2CTCTokenizerOutput(
  377. text=text,
  378. char_offsets=string_output["char_offsets"],
  379. word_offsets=string_output["word_offsets"],
  380. )
  381. else:
  382. return text
  383. # overwritten from `tokenization_utils_base.py` because tokenizer can output
  384. # `ModelOutput` which should not be a list for batched output and
  385. # because we need docs for `output_char_offsets` here
  386. def batch_decode(
  387. self,
  388. sequences: Union[list[int], list[list[int]], np.ndarray, "torch.Tensor"],
  389. skip_special_tokens: bool = False,
  390. clean_up_tokenization_spaces: bool | None = None,
  391. output_char_offsets: bool = False,
  392. output_word_offsets: bool = False,
  393. **kwargs,
  394. ) -> list[str]:
  395. """
  396. Convert a list of lists of token ids into a list of strings by calling decode.
  397. Args:
  398. sequences (`Union[list[int], list[list[int]], np.ndarray, torch.Tensor]`):
  399. List of tokenized input ids. Can be obtained using the `__call__` method.
  400. skip_special_tokens (`bool`, *optional*, defaults to `False`):
  401. Whether or not to remove special tokens in the decoding.
  402. clean_up_tokenization_spaces (`bool`, *optional*):
  403. Whether or not to clean up the tokenization spaces.
  404. output_char_offsets (`bool`, *optional*, defaults to `False`):
  405. Whether or not to output character offsets. Character offsets can be used in combination with the
  406. sampling rate and model downsampling rate to compute the time-stamps of transcribed characters.
  407. <Tip>
  408. Please take a look at the Example of [`~Wav2Vec2CTCTokenizer.decode`] to better understand how to make
  409. use of `output_char_offsets`. [`~Wav2Vec2CTCTokenizer.batch_decode`] works the same way with batched
  410. output.
  411. </Tip>
  412. output_word_offsets (`bool`, *optional*, defaults to `False`):
  413. Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate
  414. and model downsampling rate to compute the time-stamps of transcribed words.
  415. <Tip>
  416. Please take a look at the Example of [`~Wav2Vec2CTCTokenizer.decode`] to better understand how to make
  417. use of `output_word_offsets`. [`~Wav2Vec2CTCTokenizer.batch_decode`] works the same way with batched
  418. output.
  419. </Tip>
  420. kwargs (additional keyword arguments, *optional*):
  421. Will be passed to the underlying model specific decode method.
  422. Returns:
  423. `list[str]` or [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`]: The list of decoded
  424. sentences. Will be a [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`] when
  425. `output_char_offsets == True` or `output_word_offsets == True`.
  426. """
  427. batch_decoded = [
  428. self.decode(
  429. seq,
  430. skip_special_tokens=skip_special_tokens,
  431. clean_up_tokenization_spaces=clean_up_tokenization_spaces,
  432. output_char_offsets=output_char_offsets,
  433. output_word_offsets=output_word_offsets,
  434. **kwargs,
  435. )
  436. for seq in sequences
  437. ]
  438. if output_char_offsets or output_word_offsets:
  439. # transform list of dicts to dict of lists
  440. return Wav2Vec2CTCTokenizerOutput({k: [d[k] for d in batch_decoded] for k in batch_decoded[0]})
  441. return batch_decoded
  442. # overwritten from `tokenization_utils_base.py` because we need docs for `output_char_offsets`
  443. # and `output_word_offsets` here
  444. def decode(
  445. self,
  446. token_ids: Union[int, list[int], np.ndarray, "torch.Tensor"],
  447. skip_special_tokens: bool = False,
  448. clean_up_tokenization_spaces: bool | None = None,
  449. output_char_offsets: bool = False,
  450. output_word_offsets: bool = False,
  451. **kwargs,
  452. ) -> str:
  453. """
  454. Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
  455. tokens and clean up tokenization spaces.
  456. Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
  457. Args:
  458. token_ids (`Union[int, list[int], np.ndarray, torch.Tensor]`):
  459. List of tokenized input ids. Can be obtained using the `__call__` method.
  460. skip_special_tokens (`bool`, *optional*, defaults to `False`):
  461. Whether or not to remove special tokens in the decoding.
  462. clean_up_tokenization_spaces (`bool`, *optional*):
  463. Whether or not to clean up the tokenization spaces.
  464. output_char_offsets (`bool`, *optional*, defaults to `False`):
  465. Whether or not to output character offsets. Character offsets can be used in combination with the
  466. sampling rate and model downsampling rate to compute the time-stamps of transcribed characters.
  467. <Tip>
  468. Please take a look at the example below to better understand how to make use of `output_char_offsets`.
  469. </Tip>
  470. output_word_offsets (`bool`, *optional*, defaults to `False`):
  471. Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate
  472. and model downsampling rate to compute the time-stamps of transcribed words.
  473. <Tip>
  474. Please take a look at the example below to better understand how to make use of `output_word_offsets`.
  475. </Tip>
  476. kwargs (additional keyword arguments, *optional*):
  477. Will be passed to the underlying model specific decode method.
  478. Returns:
  479. `str` or [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`]: The list of decoded
  480. sentences. Will be a [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`] when
  481. `output_char_offsets == True` or `output_word_offsets == True`.
  482. Example:
  483. ```python
  484. >>> # Let's see how to retrieve time steps for a model
  485. >>> from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC
  486. >>> from datasets import load_dataset
  487. >>> import datasets
  488. >>> import torch
  489. >>> # import model, feature extractor, tokenizer
  490. >>> model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
  491. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
  492. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
  493. >>> # load first sample of English common_voice
  494. >>> dataset = load_dataset("mozilla-foundation/common_voice_11_0", "en", split="train", streaming=True)
  495. >>> dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000))
  496. >>> dataset_iter = iter(dataset)
  497. >>> sample = next(dataset_iter)
  498. >>> # forward sample through model to get greedily predicted transcription ids
  499. >>> input_values = feature_extractor(sample["audio"]["array"], return_tensors="pt").input_values
  500. >>> logits = model(input_values).logits[0]
  501. >>> pred_ids = torch.argmax(logits, axis=-1)
  502. >>> # retrieve word stamps (analogous commands for `output_char_offsets`)
  503. >>> outputs = tokenizer.decode(pred_ids, output_word_offsets=True)
  504. >>> # compute `time_offset` in seconds as product of downsampling ratio and sampling_rate
  505. >>> time_offset = model.config.inputs_to_logits_ratio / feature_extractor.sampling_rate
  506. >>> word_offsets = [
  507. ... {
  508. ... "word": d["word"],
  509. ... "start_time": round(d["start_offset"] * time_offset, 2),
  510. ... "end_time": round(d["end_offset"] * time_offset, 2),
  511. ... }
  512. ... for d in outputs.word_offsets
  513. ... ]
  514. >>> # compare word offsets with audio `en_train_0/common_voice_en_19121553.mp3` online on the dataset viewer:
  515. >>> # https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0/viewer/en
  516. >>> word_offsets[:3]
  517. [{'word': 'THE', 'start_time': 0.7, 'end_time': 0.78}, {'word': 'TRICK', 'start_time': 0.88, 'end_time': 1.08}, {'word': 'APPEARS', 'start_time': 1.2, 'end_time': 1.64}]
  518. ```"""
  519. # Convert inputs to python lists
  520. token_ids = to_py_obj(token_ids)
  521. return self._decode(
  522. token_ids=token_ids,
  523. skip_special_tokens=skip_special_tokens,
  524. clean_up_tokenization_spaces=clean_up_tokenization_spaces,
  525. output_char_offsets=output_char_offsets,
  526. output_word_offsets=output_word_offsets,
  527. **kwargs,
  528. )
  529. def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]:
  530. if not os.path.isdir(save_directory):
  531. logger.error(f"Vocabulary path ({save_directory}) should be a directory")
  532. return
  533. vocab_file = os.path.join(
  534. save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
  535. )
  536. with open(vocab_file, "w", encoding="utf-8") as f:
  537. f.write(json.dumps(self.vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
  538. return (vocab_file,)
# Public API of this module.
__all__ = ["Wav2Vec2CTCTokenizer"]