tokenization_utils_sentencepiece.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. # Copyright 2020 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. SentencePiece-based tokenization class for loading from sentencepiece.model files.
  16. """
  17. import os
  18. from shutil import copyfile
  19. try:
  20. import sentencepiece as spm
  21. except ImportError:
  22. spm = None
  23. from .convert_slow_tokenizer import import_protobuf
  24. from .tokenization_python import PreTrainedTokenizer
  25. from .tokenization_utils_base import (
  26. INIT_TOKENIZER_DOCSTRING,
  27. AddedToken,
  28. generate_merges,
  29. )
  30. from .utils import add_end_docstrings, logging, requires_backends
  31. logger = logging.get_logger(__name__)
  32. VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
  33. SPIECE_UNDERLINE = "▁"
  34. @add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
  35. class SentencePieceBackend(PreTrainedTokenizer):
  36. """
  37. Base class for SentencePiece-based tokenizers that load from sentencepiece.model files.
  38. Inherits from [`~tokenization_utils.PreTrainedTokenizer`].
  39. Handle all the shared methods for tokenization and special tokens as well as methods downloading/caching/loading
  40. pretrained tokenizers as well as adding tokens to the vocabulary.
  41. This class also contain the added tokens in a unified way on top of all tokenizers so we don't have to handle the
  42. specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
  43. """
  44. vocab_files_names = VOCAB_FILES_NAMES
  45. def __init__(self, **kwargs):
  46. # Ensure optional dependency is available before loading
  47. requires_backends(self, "sentencepiece")
  48. # Extract sentencepiece-specific parameters
  49. self.vocab_file = kwargs.get("vocab_file")
  50. self.legacy = kwargs.get("legacy", True)
  51. self.sp_model_kwargs = kwargs.pop("sp_model_kwargs", {})
  52. # Set backend to "sentencepiece" if not already set
  53. if "backend" not in kwargs:
  54. kwargs["backend"] = "sentencepiece"
  55. # Load the SentencePiece model before calling parent __init__
  56. # This is needed because parent __init__ may call methods that depend on sp_model
  57. tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
  58. tokenizer.Load(self.vocab_file)
  59. if not self.legacy:
  60. model_pb2 = import_protobuf()
  61. proto = model_pb2.ModelProto.FromString(tokenizer.serialized_model_proto())
  62. if proto.normalizer_spec.add_dummy_prefix:
  63. proto.normalizer_spec.add_dummy_prefix = False
  64. tokenizer.LoadFromSerializedProto(proto.SerializeToString())
  65. self.sp_model = tokenizer
  66. # Initialize total_vocab_size before parent __init__ (which may call _add_tokens -> len(self))
  67. self.total_vocab_size = self.sp_model.get_piece_size()
  68. # Add sp_model_kwargs back to kwargs so it gets stored in init_kwargs
  69. kwargs["sp_model_kwargs"] = self.sp_model_kwargs
  70. # Call parent class __init__ (PreTrainedTokenizer)
  71. # This handles tokens_trie, _added_tokens_decoder, _added_tokens_encoder,
  72. # token_type_ids_pattern, special_tokens_pattern, and adds special tokens
  73. super().__init__(**kwargs)
  74. self._update_trie()
  75. @property
  76. def vocab_size(self) -> int:
  77. """Returns vocab size"""
  78. return self.sp_model.get_piece_size()
  79. def get_vocab(self):
  80. """Returns vocab as a dict"""
  81. vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
  82. vocab.update(self.added_tokens_encoder)
  83. return vocab
  84. def _add_tokens(self, new_tokens: list[str] | list[AddedToken], special_tokens: bool = False) -> int:
  85. """
  86. Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to
  87. it with indices starting from length of the current vocabulary. Special tokens are sometimes already in the
  88. vocab which is why they have to be handled specifically.
  89. Args:
  90. new_tokens (`list[str]`or `list[tokenizers.AddedToken]`):
  91. Token(s) to add in vocabulary. A token is counted as added if it's not already in the vocabulary
  92. (tested by checking if the tokenizer assign the index of the `unk_token` to them). If a token is part
  93. of the vocabulary then we simply mark this token as an `AddedToken` which allows to control the
  94. stripping and normalization of this token. This is NOT possible in `tokenizers`.
  95. special_tokens (`bool`, *optional*, defaults to `False`):
  96. Whether or not the tokens should be added as special tokens.
  97. Returns:
  98. `int`: The number of tokens actually added to the vocabulary.
  99. Examples:
  100. ```python
  101. # Let's see how to increase the vocabulary of Bert model and tokenizer
  102. tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
  103. model = BertModel.from_pretrained("google-bert/bert-base-uncased")
  104. num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"])
  105. print("We have added", num_added_toks, "tokens")
  106. # Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
  107. model.resize_token_embeddings(len(tokenizer))
  108. ```"""
  109. if not new_tokens:
  110. return 0
  111. next_index = len(self) # total size (base + added)
  112. num_added = 0
  113. for token in new_tokens:
  114. if not isinstance(token, (str, AddedToken)):
  115. raise TypeError(f"Token {token} is not a string but a {type(token)}.")
  116. if str(token) == "":
  117. continue
  118. if isinstance(token, str):
  119. if token in self._added_tokens_encoder:
  120. continue
  121. is_special = token in self.all_special_tokens or special_tokens
  122. token = AddedToken(token, rstrip=False, lstrip=False, normalized=not is_special, special=is_special)
  123. elif special_tokens:
  124. # doing token.special=True changes the normalization! will fix in rust
  125. # this is important and the only reason why the AddedTokens in each class are normalized by default
  126. token.__setstate__({"special": True, "normalized": token.normalized})
  127. if token in self._added_tokens_decoder.values():
  128. continue
  129. if not token.special and token.normalized and getattr(self, "do_lower_case", False):
  130. token.content = token.content.lower()
  131. # Check if token already exists in the SentencePiece base vocab
  132. tok_id = self.sp_model.piece_to_id(token.content)
  133. in_base_vocab = (
  134. tok_id < self.sp_model.get_piece_size() and self.sp_model.IdToPiece(tok_id) == token.content
  135. )
  136. if in_base_vocab:
  137. token_index = tok_id
  138. else:
  139. token_index = next_index
  140. next_index += 1
  141. num_added += 1
  142. if token.special and str(token) not in self.all_special_tokens:
  143. self._extra_special_tokens.append(token)
  144. # the setter automatically updates the reverse map
  145. self._added_tokens_decoder[token_index] = token
  146. self._added_tokens_encoder[token.content] = token_index
  147. if self.verbose:
  148. logger.info(f"Adding {token} to the vocabulary")
  149. self._update_trie()
  150. self._update_total_vocab_size()
  151. return num_added
  152. def _update_trie(self, unique_no_split_tokens: list[str] | None = None):
  153. # Add all added tokens
  154. for token in self._added_tokens_decoder.values():
  155. if token.content not in self.tokens_trie._tokens:
  156. self.tokens_trie.add(token.content)
  157. # Also add all special tokens (even if they're in base vocab) so they get split during tokenization
  158. for token in self.all_special_tokens:
  159. if token not in self.tokens_trie._tokens:
  160. self.tokens_trie.add(token)
  161. # Add any additional no-split tokens
  162. for token in unique_no_split_tokens or []:
  163. if token not in self.tokens_trie._tokens:
  164. self.tokens_trie.add(token)
  165. def _tokenize(self, text, **kwargs):
  166. """
  167. Returns a tokenized string.
  168. We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
  169. SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
  170. `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
  171. `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
  172. `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
  173. """
  174. if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
  175. return self.sp_model.encode(text, out_type=str)
  176. # 1. Encode string + prefix ex: "<unk> Hey"
  177. tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
  178. # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
  179. unk_token_length = len(self.sp_model.encode(str(self.unk_token)))
  180. return tokens[unk_token_length:] if len(tokens) >= unk_token_length else tokens
  181. def _convert_token_to_id(self, token):
  182. """Converts a token (str) to an id using the vocab."""
  183. return self.sp_model.piece_to_id(token)
  184. def _convert_id_to_token(self, index):
  185. """Converts an index (integer) in a token (str) using the vocab."""
  186. token = self.sp_model.IdToPiece(index)
  187. return token
  188. def convert_tokens_to_string(self, tokens: list[str]) -> str:
  189. """Converts a sequence of tokens (string) in a single string."""
  190. out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
  191. return out_string
  192. def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]:
  193. """
  194. Save the sentencepiece vocabulary (copy original file) to a directory.
  195. Args:
  196. save_directory (`str`):
  197. The directory in which to save the vocabulary.
  198. filename_prefix (`str`, *optional*):
  199. An optional prefix to add to the named of the saved files.
  200. Returns:
  201. `tuple(str)`: Paths to the files saved.
  202. """
  203. if not os.path.isdir(save_directory):
  204. logger.error(f"Vocabulary path ({save_directory}) should be a directory")
  205. return
  206. out_vocab_file = os.path.join(
  207. save_directory, (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"]
  208. )
  209. if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
  210. copyfile(self.vocab_file, out_vocab_file)
  211. elif not os.path.isfile(self.vocab_file):
  212. with open(out_vocab_file, "wb") as fi:
  213. content_spiece_model = self.sp_model.serialized_model_proto()
  214. fi.write(content_spiece_model)
  215. return (out_vocab_file,)
  216. def _decode(
  217. self,
  218. token_ids: int | list[int],
  219. skip_special_tokens: bool = False,
  220. clean_up_tokenization_spaces: bool | None = None,
  221. spaces_between_special_tokens: bool = False,
  222. **kwargs,
  223. ) -> str:
  224. """
  225. Decode token ids to string.
  226. Uses the generic decode path from PreTrainedTokenizer which works for all vocabularies,
  227. including custom vocabularies that override _convert_id_to_token.
  228. """
  229. # Use parent class's generic decode method - it's simpler and works for all cases
  230. return super()._decode(
  231. token_ids=token_ids,
  232. skip_special_tokens=skip_special_tokens,
  233. clean_up_tokenization_spaces=clean_up_tokenization_spaces,
  234. **kwargs,
  235. )
  236. class SentencePieceExtractor:
  237. """
  238. Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece
  239. """
  240. def __init__(self, model: str):
  241. requires_backends(self, "sentencepiece")
  242. from sentencepiece import SentencePieceProcessor
  243. self.sp = SentencePieceProcessor()
  244. self.sp.Load(model)
  245. def extract(self, vocab_scores=None) -> tuple[dict[str, int], list[tuple[str, float]], list[tuple]]:
  246. """
  247. By default will return vocab and merges with respect to their order, by sending `vocab_scores` we're going to
  248. order the merges with respect to the piece scores instead.
  249. """
  250. sp = self.sp
  251. vocab_ids = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
  252. vocab_scores_dict = {sp.id_to_piece(i): sp.get_score(i) for i in range(sp.GetPieceSize())}
  253. merges = generate_merges(vocab_ids, vocab_scores_dict)
  254. vocab_scores_list = [(sp.id_to_piece(i), sp.get_score(i)) for i in range(sp.GetPieceSize())]
  255. return vocab_ids, vocab_scores_list, merges