tokenization_lasr.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/lasr/modular_lasr.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_lasr.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The HuggingFace Inc. team and Google LLC. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import itertools
  21. import re
  22. from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
  23. from tokenizers.models import Unigram
  24. from ...tokenization_utils_tokenizers import TokenizersBackend
# Maps each tokenizer resource key to its file name; exposed on `LasrTokenizer` as `vocab_files_names`.
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}
  26. class LasrTokenizer(TokenizersBackend):
  27. """
  28. Construct a LASR tokenizer (backed by HuggingFace's *tokenizers* library). Based on
  29. [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models).
  30. This tokenizer inherits from [`TokenizersBackend`] which contains most of the main methods. Users should
  31. refer to this superclass for more information regarding those methods.
  32. Args:
  33. vocab_file (`str`, *optional*):
  34. [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
  35. contains the vocabulary necessary to instantiate a tokenizer.
  36. eos_token (`str`, *optional*, defaults to `"</s>"`):
  37. The end of sequence token.
  38. <Tip>
  39. When building a sequence using special tokens, this is not the token that is used for the end of sequence.
  40. The token used is the `sep_token`.
  41. </Tip>
  42. unk_token (`str`, *optional*, defaults to `"<unk>"`):
  43. The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
  44. token instead.
  45. pad_token (`str`, *optional*, defaults to `"<pad>"`):
  46. The token used for padding, for example when batching sequences of different lengths.
  47. extra_ids (`int`, *optional*, defaults to 100):
  48. Add a number of extra ids added to the vocabulary for use as sentinels. These tokens are accessible as
  49. "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. These tokens can be retrieved by
  50. calling get_sentinel_tokens method and token ids can be by calling get_sentinel_token_ids method
  51. additional_special_tokens (`list[str]`, *optional*):
  52. Additional special tokens used by the tokenizer.
  53. vocab (`str`, `dict` or `list`, *optional*):
  54. Custom vocabulary dict. If not provided, a minimal vocabulary is created using the special tokens.
  55. """
  56. vocab_files_names = VOCAB_FILES_NAMES
  57. model_input_names = ["input_ids", "attention_mask"]
  58. model = Unigram
  59. def __init__(
  60. self,
  61. eos_token="</s>",
  62. unk_token="<unk>",
  63. pad_token="<pad>",
  64. _spm_precompiled_charsmap=None,
  65. extra_ids=100,
  66. additional_special_tokens=None,
  67. vocab=None,
  68. vocab_file=None,
  69. **kwargs,
  70. ):
  71. self._extra_ids = extra_ids
  72. # Handle extra_ids and additional_special_tokens
  73. if additional_special_tokens is not None:
  74. extra_tokens = [x for x in additional_special_tokens if "<extra_id_" in str(x)]
  75. if len(extra_tokens) < 1:
  76. additional_special_tokens += [f"<extra_id_{i}>" for i in range(extra_ids)]
  77. elif extra_ids > 0 and extra_ids != len(extra_tokens):
  78. raise ValueError(
  79. f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
  80. " provided to LasrTokenizer. In this case the additional_special_tokens must include the extra_ids"
  81. " tokens"
  82. )
  83. else:
  84. extra_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
  85. additional_special_tokens = extra_tokens
  86. # LASR vocab structure: <pad>=0, </s>=1, <unk>=2, then regular vocab, then extra_ids in reverse
  87. if vocab is not None:
  88. self._vocab_scores = vocab
  89. else:
  90. self._vocab_scores = [
  91. (str(pad_token), 0.0),
  92. (str(eos_token), 0.0),
  93. (str(unk_token), 0.0),
  94. ("▁", -2.0), # Space token
  95. ]
  96. for i in range(extra_ids - 1, -1, -1):
  97. self._vocab_scores.append((f"<extra_id_{i}>", 0.0))
  98. self._tokenizer = Tokenizer(
  99. Unigram(
  100. self._vocab_scores,
  101. unk_id=3,
  102. byte_fallback=False,
  103. )
  104. )
  105. if _spm_precompiled_charsmap is not None:
  106. self._tokenizer.normalizer = normalizers.Precompiled(_spm_precompiled_charsmap)
  107. self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
  108. [
  109. pre_tokenizers.WhitespaceSplit(),
  110. pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True),
  111. ]
  112. )
  113. self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)
  114. super().__init__(
  115. eos_token=eos_token,
  116. unk_token=unk_token,
  117. pad_token=pad_token,
  118. extra_ids=extra_ids,
  119. additional_special_tokens=additional_special_tokens,
  120. **kwargs,
  121. )
  122. self._tokenizer.post_processor = processors.TemplateProcessing(
  123. single=["$A", "</s>"],
  124. pair=["$A", "</s>", "$B", "</s>"],
  125. special_tokens=[
  126. ("</s>", self.eos_token_id),
  127. ],
  128. )
  129. def get_sentinel_tokens(self):
  130. """Get the list of sentinel tokens (extra_id tokens) from additional_special_tokens."""
  131. return list(
  132. set(filter(lambda x: bool(re.search(r"<extra_id_\d+>", x)) is not None, self.additional_special_tokens))
  133. )
  134. def get_sentinel_token_ids(self):
  135. """Get the token IDs for sentinel tokens."""
  136. return [self.convert_tokens_to_ids(token) for token in self.get_sentinel_tokens()]
  137. def _decode(
  138. self,
  139. token_ids: int | list[int],
  140. skip_special_tokens: bool = False,
  141. clean_up_tokenization_spaces: bool | None = None,
  142. group_tokens: bool = True,
  143. **kwargs,
  144. ) -> str:
  145. if isinstance(token_ids, int):
  146. token_ids = [token_ids]
  147. if group_tokens:
  148. token_ids = [token_group[0] for token_group in itertools.groupby(token_ids)]
  149. # for CTC we filter out the blank token, which is the pad token
  150. token_ids = [token for token in token_ids if token != self.pad_token_id]
  151. return super()._decode(
  152. token_ids=token_ids,
  153. skip_special_tokens=skip_special_tokens,
  154. clean_up_tokenization_spaces=clean_up_tokenization_spaces,
  155. **kwargs,
  156. )
# Public names exported by this module.
__all__ = ["LasrTokenizer"]