tokenization_utils_tokenizers.py 64 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388
  1. # Copyright 2020 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Tokenization classes for fast tokenizers (provided by HuggingFace's tokenizers library). For slow (python) tokenizers
  16. see tokenization_utils.py
  17. """
  18. import copy
  19. import json
  20. import os
  21. from collections import defaultdict
  22. from collections.abc import Iterable
  23. from shutil import copyfile
  24. from typing import Any
  25. import tokenizers.pre_tokenizers as pre_tokenizers_fast
  26. from huggingface_hub import is_offline_mode
  27. from tokenizers import AddedToken, processors
  28. from tokenizers import Encoding as EncodingFast
  29. from tokenizers import Tokenizer as TokenizerFast
  30. from tokenizers.decoders import Decoder as DecoderFast
  31. from tokenizers.models import BPE, Unigram
  32. from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordLevelTrainer, WordPieceTrainer
  33. from transformers.utils.hub import cached_file
  34. from .convert_slow_tokenizer import SpmConverter
  35. from .integrations.ggml import convert_gguf_tokenizer
  36. from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
  37. from .tokenization_utils_base import (
  38. INIT_TOKENIZER_DOCSTRING,
  39. BatchEncoding,
  40. PreTokenizedInput,
  41. PreTrainedTokenizerBase,
  42. TextInput,
  43. TruncationStrategy,
  44. generate_merges,
  45. )
  46. from .utils import PaddingStrategy, add_end_docstrings, logging
# Module-level logger, obtained from the library's own logging helper.
logger = logging.get_logger(__name__)

# Fast tokenizers (provided by HuggingFace's tokenizers library) can be saved in a single file
TOKENIZER_FILE = "tokenizer.json"
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
TIKTOKEN_VOCAB_FILE = "tokenizer.model"

# Slow tokenizers have an additional added-tokens file
ADDED_TOKENS_FILE = "added_tokens.json"
# Extend the shared init docstring with the two arguments that only the `tokenizers`-backed classes accept.
# NOTE(review): the indentation inside this literal is reconstructed to line up with the argument entries of the
# base docstring — confirm against the original file.
INIT_TOKENIZER_DOCSTRING += """
        tokenizer_object ([`tokenizers.Tokenizer`]):
            A [`tokenizers.Tokenizer`] object from 🤗 tokenizers to instantiate from. See [Using tokenizers from 🤗
            tokenizers](../fast_tokenizers) for more information.
        tokenizer_file ([`str`]):
            A path to a local JSON file representing a previously serialized [`tokenizers.Tokenizer`] object from 🤗
            tokenizers.
"""
# Maps a `tokenizers` model type name to the trainer class imported for it from `tokenizers.trainers`.
MODEL_TO_TRAINER_MAPPING = {
    "BPE": BpeTrainer,
    "Unigram": UnigramTrainer,
    "WordLevel": WordLevelTrainer,
    "WordPiece": WordPieceTrainer,
}

# File names this backend looks for: the serialized fast tokenizer and a `.model` vocabulary file.
VOCAB_FILES_NAMES = {"tokenizer_file": TOKENIZER_FILE, "vocab_file": TIKTOKEN_VOCAB_FILE}
  70. @add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
  71. class TokenizersBackend(PreTrainedTokenizerBase):
  72. """
  73. Base class for all fast tokenizers (wrapping HuggingFace tokenizers library).
  74. Inherits from [`~tokenization_utils_base.PreTrainedTokenizerBase`].
  75. Handles all the shared methods for tokenization and special tokens, as well as methods for
  76. downloading/caching/loading pretrained tokenizers, as well as adding tokens to the vocabulary.
  77. This class also contains the added tokens in a unified way on top of all tokenizers so we don't have to handle the
  78. specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
  79. """
  80. vocab_files_names = VOCAB_FILES_NAMES
  81. model = None
  82. _tokenizer = None
    @classmethod
    def convert_to_native_format(cls, trust_remote_code=False, **kwargs):
        """
        Build a `tokenizers.Tokenizer` backend from the available serialization files (tokenizer.json, sentencepiece
        models, tekken.json, vocab/merges).

        Sources are tried in this order; the first one that applies decides what is returned:

        1. `tokenizer_file` exists and `cls` is `TokenizersBackend`, or `cls` does not define its own `__init__`, or
           `trust_remote_code` is set: the file is loaded as-is and returned under `"tokenizer_object"`.
        2. `tokenizer_file` exists but `cls` defines its own `__init__`: vocab, merges, post-processor, padding,
           truncation and (when present) the precompiled charsmap are extracted and returned as separate entries.
        3. `vocab_file` ends with `tekken.json`: vocab and merges come from `MistralConverter`.
        4. `vocab_file` ends with `.model`: SentencePiece extraction, falling back to `TikTokenConverter` on error.
        5. Otherwise existing `vocab_file` / `merges_file` paths are forwarded as `vocab` / `merges`, and merges are
           generated for BPE classes when `vocab` is a dict and no merges were supplied.

        Args:
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                When set, an existing `tokenizer_file` is loaded directly even if `cls` defines its own `__init__`.
            **kwargs:
                Candidate sources (`tokenizer_file`, `vocab_file`, `merges_file`, `vocab`, `merges`, special tokens,
                ...). The caller's dict is not modified; a shallow copy is used.

        Returns:
            `dict`: The copied kwargs, without `tokenizer_file`, augmented with whatever could be extracted.
        """
        # Preserve kwargs for possible downstream use
        local_kwargs = dict(kwargs)
        fast_tokenizer_file = local_kwargs.pop("tokenizer_file", None)

        if (
            fast_tokenizer_file is not None
            and os.path.isfile(fast_tokenizer_file)
            and (cls is TokenizersBackend or "__init__" not in cls.__dict__ or trust_remote_code)
        ):
            local_kwargs["tokenizer_object"] = TokenizerFast.from_file(fast_tokenizer_file)
            return local_kwargs
        elif fast_tokenizer_file is not None and os.path.isfile(fast_tokenizer_file):
            # we extract vocab/merges and pass decoder/pre_tokenizer/post_processor
            # from the file so the reconstructed tokenizer matches the tokenizer.json
            with open(fast_tokenizer_file, encoding="utf-8") as tokenizer_handle:
                tokenizer_json = json.load(tokenizer_handle)

            # Build a minimal tokenizer (empty vocab/merges) to cheaply extract post_processor,
            # padding and truncation as Rust objects — avoids parsing the full vocab via from_file.
            # This optimization applies to BPE, WordPiece, and WordLevel only:
            # - Unigram (SentencePiece) requires a non-empty vocab to initialize correctly in Rust
            #   (e.g. AlbertTokenizer, CamembertTokenizer, LlamaTokenizer, T5Tokenizer); passing an
            #   empty vocab causes "Unable to load vocab EmptyVocabulary". TODO: investigate if keeping
            #   just the UNK token is sufficient to make Unigram work with a minimal vocab.
            # - Older tokenizer.json formats (e.g. XLNetTokenizer, DistilBertTokenizer) omit the
            #   "type" field in the "model" section, so we cannot determine the model type from JSON.
            # In both cases we fall back to the original from_file path (no performance improvement).
            model_type = tokenizer_json.get("model", {}).get("type")
            if model_type not in (None, "Unigram"):
                # Shallow copies: only the top-level dict and the "model" dict are replaced, so the parsed
                # `tokenizer_json` (still needed below for vocab/merges) is left untouched.
                minimal_tokenizer_json = dict(tokenizer_json)
                minimal_model = dict(tokenizer_json["model"])
                minimal_model["vocab"] = {}
                if model_type == "BPE":
                    minimal_model["merges"] = []
                minimal_tokenizer_json["model"] = minimal_model
                minimal_tokenizer_json["added_tokens"] = []
                tok_from_file = TokenizerFast.from_str(json.dumps(minimal_tokenizer_json))
            else:
                tok_from_file = TokenizerFast.from_file(fast_tokenizer_file)

            local_kwargs["post_processor"] = tok_from_file.post_processor
            local_kwargs["tokenizer_padding"] = tok_from_file.padding
            local_kwargs["tokenizer_truncation"] = tok_from_file.truncation
            # Preserve truncation and padding baked into tokenizer.json so that classes
            # with a custom __init__ that rebuild the backend tokenizer from scratch
            # can still access these settings.
            if tok_from_file.truncation is not None:
                local_kwargs["_json_truncation"] = tok_from_file.truncation
            if tok_from_file.padding is not None:
                local_kwargs["_json_padding"] = tok_from_file.padding

            # Extract precompiled SentencePiece charsmap from tokenizer.json normalizer
            # when present (e.g. T5 tokenizers converted with SentencePiece >= 2.x).
            normalizer_config = tokenizer_json.get("normalizer")
            if normalizer_config:
                # Normalize the three accepted shapes (Sequence wrapper, single dict, list) to a list of dicts.
                if normalizer_config.get("type", None) == "Sequence":
                    normalizer_config = normalizer_config["normalizers"]
                elif not isinstance(normalizer_config, list):
                    normalizer_config = [normalizer_config]
                for normalizer in normalizer_config:
                    if normalizer.get("type") == "Precompiled" and "precompiled_charsmap" in normalizer:
                        import base64

                        local_kwargs["_spm_precompiled_charsmap"] = base64.b64decode(
                            normalizer["precompiled_charsmap"]
                        )
                        # Only the first Precompiled normalizer is used.
                        break

            # Convert the JSON vocab into the container shape used for the class-level `model`.
            vocab = tokenizer_json.get("model", {}).get("vocab", None)
            if cls.model is None:
                if isinstance(vocab, list):
                    vocab = list(map(tuple, vocab))  # TODO just for now
            elif cls.model.__name__ == "Unigram":
                if isinstance(vocab, list) and vocab and isinstance(vocab[0], (list, tuple)):
                    vocab = [tuple(item) for item in vocab]
            elif cls.model.__name__ == "WordLevel":
                # NOTE(review): ids are assigned by iteration order here. If `vocab` is a dict, the ids stored in
                # the JSON are discarded, and a missing vocab (`None`) would raise — confirm this is intended.
                vocab = {token: i for i, token in enumerate(vocab)}
            elif cls.model.__name__ == "BPE" or cls.model.__name__ == "WordPiece":
                if isinstance(vocab, list):
                    vocab = {token[0] if isinstance(token, list) else token: i for i, token in enumerate(vocab)}
            local_kwargs["vocab"] = vocab

            # From here on `model_type` holds the class-level model (a class or None), not the JSON type string.
            model_type = getattr(cls, "model", None)
            if "merges" in tokenizer_json.get("model", {}) and (model_type and model_type.__name__ == "BPE"):
                merges = tokenizer_json["model"]["merges"]
                # Merges may be serialized either as "a b" strings or as [a, b] pairs; normalize to tuples.
                merges = [tuple(merge.split(" ")) if isinstance(merge, str) else tuple(merge) for merge in merges]
                local_kwargs["merges"] = merges

            return local_kwargs

        vocab_file = local_kwargs.get("vocab_file")
        merges_file = local_kwargs.get("merges_file")
        vocab = local_kwargs.get("vocab")
        merges = local_kwargs.get("merges")

        # Tekken converter (Mistral)
        if isinstance(vocab_file, str) and vocab_file.endswith("tekken.json") and os.path.isfile(vocab_file):
            from .convert_slow_tokenizer import MistralConverter

            local_kwargs["vocab"], local_kwargs["merges"] = MistralConverter(
                vocab_file=vocab_file
            ).extract_vocab_merges_from_model(vocab_file)
            return local_kwargs

        # SentencePiece model (with TikToken fallback)
        if isinstance(vocab_file, str) and os.path.isfile(vocab_file) and vocab_file.endswith(".model"):
            try:
                from .convert_slow_tokenizer import SentencePieceExtractor

                # 1. Extract vocab, merges, and spm_precompiled from the .model proto
                extractor = SentencePieceExtractor(vocab_file)
                local_kwargs = extractor.extract(cls.model, **local_kwargs)

                # 2. If a model-specific converter exists, use it.
                try:
                    from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS

                    converter_class = SLOW_TO_FAST_CONVERTERS.get(cls.__name__)
                    if converter_class is not None and hasattr(converter_class, "convert_from_spm"):
                        local_kwargs = converter_class.convert_from_spm(**local_kwargs)
                except Exception as e:
                    # Best effort: a failing converter only costs the reordering, not the extraction.
                    logger.warning(
                        f"Could not reorder vocab using converter for {cls.__name__} due to {e}. Falling back to raw SentencePiece extraction."
                    )
                if hasattr(cls, "convert_from_spm_model"):
                    local_kwargs = cls.convert_from_spm_model(**local_kwargs)

                # 3. For non-model specific tokenizers (e.g. TokenizersBackend used
                #    for MODELS_WITH_INCORRECT_HUB_TOKENIZER_CLASS), build a _tokenizer
                #    from the proto so normalizer/decoder are configured correctly.
                if "tokenizer_object" not in local_kwargs and (
                    cls is TokenizersBackend or "__init__" not in cls.__dict__
                ):
                    vocab = local_kwargs.pop("vocab", None)
                    merges = local_kwargs.pop("merges", None)

                    # Replace placeholder tokens as specified in added_tokens_decoder
                    added_tokens_decoder = local_kwargs.get("added_tokens_decoder") or {}
                    if vocab is not None and added_tokens_decoder:
                        id_to_token = {token_id: token for token, token_id in vocab.items()}
                        for token_id, new_token in added_tokens_decoder.items():
                            token_id = int(token_id)
                            new_token = str(new_token)
                            current_token = id_to_token.get(token_id)
                            # Rename in place: the new token takes over the id of the placeholder it replaces.
                            if current_token and current_token != new_token and new_token not in vocab:
                                vocab[new_token] = vocab.pop(current_token)
                                id_to_token[token_id] = new_token

                    tokenizer_object = SpmConverter.build_tokenizer_from_spm_proto(
                        proto=extractor.proto,
                        vocab=vocab,
                        merges=merges,
                    )
                    if tokenizer_object is not None:
                        local_kwargs["tokenizer_object"] = tokenizer_object

                        # Set bos/eos tokens from proto spec if available. This is needed when
                        # building a tokenizer_object directly from a .model file because the
                        # tokenizer_object does not have bos/eos set.
                        # NOTE(review): the nesting of this section was reconstructed from a flattened listing;
                        # confirm it belongs under `if tokenizer_object is not None`.
                        proto_spec = extractor.proto.trainer_spec
                        if proto_spec.bos_id >= 0:
                            local_kwargs.setdefault("bos_token", proto_spec.bos_piece or "<s>")
                        if proto_spec.eos_id >= 0:
                            local_kwargs.setdefault("eos_token", proto_spec.eos_piece or "</s>")
                        if proto_spec.unk_id >= 0:
                            local_kwargs.setdefault("unk_token", proto_spec.unk_piece or "<unk>")
            except Exception as e:  # TODO only catch deserialization error here!
                # Any failure above (not just a parse error) is treated as "this is not a SentencePiece file".
                logger.warning(
                    f"Could not extract SentencePiece model from {vocab_file} using sentencepiece library due to {e}. "
                    "Falling back to TikToken extractor."
                )
                from .convert_slow_tokenizer import TikTokenConverter

                converter = TikTokenConverter(
                    vocab_file=vocab_file, extra_special_tokens=local_kwargs.get("extra_special_tokens")
                )
                local_kwargs["tokenizer_object"] = converter.converted()
            return local_kwargs

        # Fallback to standard vocab/merges files if they existed!
        if vocab is None and isinstance(vocab_file, str) and os.path.isfile(vocab_file):
            # The file path itself is forwarded as `vocab`; it is not parsed here.
            local_kwargs["vocab"] = vocab_file
            vocab = local_kwargs["vocab"]
        if merges is None and isinstance(merges_file, str) and os.path.isfile(merges_file):
            local_kwargs["merges"] = merges_file
            merges = local_kwargs["merges"]

        # Generate merges automatically when not provided for BPE tokenizers
        if merges is None and cls.model is not None and cls.model.__name__ == "BPE" and isinstance(vocab, dict):
            # Gather special tokens from kwargs to skip in merge generation
            def _iter_special_tokens(values: Iterable[Any]) -> list[str]:
                # Flattens nested lists/tuples, drops `None`, and stringifies everything else.
                collected: list[str] = []
                for val in values:
                    if val is None:
                        continue
                    if isinstance(val, (list, tuple)):
                        collected.extend(_iter_special_tokens(val))
                    else:
                        collected.append(str(val))
                return collected

            special_tokens_keys = [
                "pad_token",
                "unk_token",
                "bos_token",
                "eos_token",
                "sep_token",
                "cls_token",
                "mask_token",
                "additional_special_tokens",
                "extra_special_tokens",
            ]
            skip_tokens: set[str] = set()
            for key in special_tokens_keys:
                if key in local_kwargs:
                    skip_tokens.update(_iter_special_tokens([local_kwargs[key]]))

            merges = generate_merges(vocab, skip_tokens=skip_tokens)
            local_kwargs["merges"] = merges
        return local_kwargs
    def __init__(self, *args, **kwargs):
        """
        Resolve the backend `tokenizers.Tokenizer`, apply truncation/padding/post-processor settings to it, run the
        parent initializer, then register any special tokens the backend does not know yet.

        The backend is taken from the first available source: `tokenizer_object` (deep-copied), `tokenizer_file`,
        `gguf_file`, or `vocab`/`merges`; otherwise a class-level `_tokenizer` must already exist. Positional
        arguments are accepted but not used in this method.

        Raises:
            ValueError: If no backend tokenizer could be built and the class does not provide one.
        """
        # Truncation/padding dicts extracted from tokenizer.json by convert_to_native_format
        # when a class with a custom __init__ rebuilds the backend tokenizer from scratch.
        _json_truncation = kwargs.pop("_json_truncation", None)
        _json_padding = kwargs.pop("_json_padding", None)
        # Precompiled SentencePiece charsmap is already used by model-specific tokenizers
        # (before calling super().__init__) and should not be stored in `init_kwargs` to keep the tokenizer serializable.
        kwargs.pop("_spm_precompiled_charsmap", None)

        tokenizer_object = kwargs.pop("tokenizer_object", None)
        gguf_file = kwargs.pop("gguf_file", None)
        fast_tokenizer_file = kwargs.pop("tokenizer_file", None)
        # Note: added_tokens_decoder is NOT popped - it's passed to super().__init__() for processing
        added_tokens_decoder = kwargs.get("added_tokens_decoder", {})
        # Store add_prefix_space before super().__init__() to ensure it's not overridden
        add_prefix_space = kwargs.get("add_prefix_space", False)
        vocab_file = kwargs.get("vocab_file")

        vocab = kwargs.get("vocab")
        merges = kwargs.get("merges")

        fast_tokenizer = None
        if tokenizer_object is not None:
            # Deep copy so later mutations (truncation, padding, added tokens) do not touch the caller's object.
            fast_tokenizer = copy.deepcopy(tokenizer_object)
        elif fast_tokenizer_file is not None and os.path.isfile(fast_tokenizer_file):
            # We have a serialization from tokenizers which let us directly build the backend
            fast_tokenizer = TokenizerFast.from_file(fast_tokenizer_file)
        elif gguf_file is not None:
            # We need to convert a slow tokenizer to build the backend
            gguf_path = cached_file(kwargs.get("name_or_path", ""), gguf_file, **kwargs)
            gguf_param = load_gguf_checkpoint(gguf_path)
            architecture = gguf_param["config"]["model_type"]
            tokenizer_dict = gguf_param["tokenizer"]
            tokenizer_config = gguf_param["tokenizer_config"]
            fast_tokenizer, additional_kwargs = convert_gguf_tokenizer(architecture, tokenizer_dict)
            # Values read from the GGUF checkpoint override what the caller passed.
            kwargs.update(tokenizer_config)
            if len(additional_kwargs) > 0:
                kwargs.update(additional_kwargs)
        elif self._tokenizer is None and vocab is not None:
            # Build from vocab/merges extracted by convert_to_native_format
            if merges is not None:
                vocab_dict = vocab if isinstance(vocab, dict) else {w: i for i, (w, _) in enumerate(vocab)}
                fast_tokenizer = TokenizerFast(BPE(vocab=vocab_dict, merges=merges, fuse_unk=True, dropout=None))
            elif isinstance(vocab, dict):
                fast_tokenizer = TokenizerFast(BPE(vocab=vocab, merges=[], fuse_unk=True, dropout=None))
            elif isinstance(vocab, list) and vocab and isinstance(vocab[0], (tuple, list)):
                fast_tokenizer = TokenizerFast(Unigram(vocab=vocab, unk_id=kwargs.get("unk_id", 0)))
            # Any other vocab shape leaves `fast_tokenizer` as None and is caught by the check further down.
        elif self._tokenizer is None:
            raise ValueError(
                "Couldn't instantiate the backend tokenizer from one of: \n"
                "(1) a `tokenizers` library serialization file, \n"
                "(2) a slow tokenizer instance to convert or \n"
                "(3) an equivalent slow tokenizer class to instantiate and convert. \n"
                "You need to have sentencepiece or tiktoken installed to convert a slow tokenizer to a fast one."
            )

        # Only set defaults when creating TokenizersBackend from scratch
        if fast_tokenizer_file is None and tokenizer_object is None and self._tokenizer is None:
            kwargs.setdefault("bos_token", "<s>")
            kwargs.setdefault("eos_token", "</s>")

        if fast_tokenizer is not None:
            self._tokenizer = fast_tokenizer

        if self._tokenizer is None:
            raise ValueError("The backend tokenizer is not correctly initialized.")

        # Precedence: explicit kwarg, then what the backend already carries, then the tokenizer.json leftovers.
        _truncation = kwargs.pop("tokenizer_truncation", None) or self._tokenizer.truncation or _json_truncation
        if _truncation is not None:
            self._tokenizer.enable_truncation(**_truncation)
            kwargs.setdefault("max_length", _truncation["max_length"])
            kwargs.setdefault("truncation_side", _truncation["direction"])
            kwargs.setdefault("stride", _truncation["stride"])
            kwargs.setdefault("truncation_strategy", _truncation["strategy"])
        else:
            self._tokenizer.no_truncation()

        _padding = kwargs.pop("tokenizer_padding", None) or self._tokenizer.padding or _json_padding
        if _padding is not None:
            self._tokenizer.enable_padding(**_padding)
            kwargs.setdefault("pad_token", _padding["pad_token"])
            kwargs.setdefault("pad_token_type_id", _padding["pad_type_id"])
            kwargs.setdefault("padding_side", _padding["direction"])
            # `setdefault`: a `max_length` already set (e.g. from truncation above) is kept.
            kwargs.setdefault("max_length", _padding["length"])
            kwargs.setdefault("pad_to_multiple_of", _padding["pad_to_multiple_of"])

        # Set backend to "tokenizers" if not already set
        if "backend" not in kwargs:
            kwargs["backend"] = "tokenizers"

        explicit_bos_eos_in_kwargs = "add_bos_token" in kwargs or "add_eos_token" in kwargs
        self._add_bos_token = kwargs.get("add_bos_token", False)
        self._add_eos_token = kwargs.get("add_eos_token", False)
        if post_processor := kwargs.pop("post_processor", None):  # most reliable way to get the post-processor
            self._tokenizer.post_processor = post_processor
        self._should_update_post_processor = explicit_bos_eos_in_kwargs or self._tokenizer.post_processor is None
        # We call this after having initialized the backend tokenizer because we update it.
        super().__init__(**kwargs)

        if vocab_file is not None:
            self.vocab_file = vocab_file
        # Ensure add_prefix_space is set correctly after parent init
        self.add_prefix_space = add_prefix_space
        self._tokenizer.encode_special_tokens = self.split_special_tokens

        # NOTE(review): iterating `self.added_tokens_decoder` yields its keys (token ids), so this set holds
        # hashes of id reprs while the filter below hashes token reprs — check whether `.values()` was intended.
        added_tokens_decoder_hash = {hash(repr(token)) for token in self.added_tokens_decoder}
        tokens_to_add = [
            token
            for index, token in sorted(added_tokens_decoder.items(), key=lambda x: x[0])
            if hash(repr(token)) not in added_tokens_decoder_hash
        ]
        encoder = list(self.added_tokens_encoder.keys()) + [str(token) for token in tokens_to_add]
        # if some of the special tokens are not already in the tokenizer, add them
        # V5: Check both named special tokens and extra special tokens
        # Iterate over _special_tokens_map to preserve AddedToken properties (lstrip, rstrip, etc.)
        for special_token_value in self._special_tokens_map.values():
            if special_token_value is None:
                continue
            if str(special_token_value) not in encoder and special_token_value not in tokens_to_add:
                tokens_to_add.append(special_token_value)

        # Also check extra special tokens
        for token in self._extra_special_tokens:
            if str(token) not in encoder and token not in tokens_to_add:
                tokens_to_add.append(token)

        if len(tokens_to_add) > 0:
            tokens = []
            all_named_tokens = [str(t) for t in self._special_tokens_map.values() if t]
            for token in tokens_to_add:
                if isinstance(token, str):
                    # Convert string to AddedToken, assuming it's special
                    token = AddedToken(token, special=True)
                elif isinstance(token, AddedToken):
                    # Ensure the special flag is set correctly for special tokens
                    if not token.special and str(token) in all_named_tokens:
                        token.special = True
                tokens.append(token)
            if tokens:
                # These tokens are from the special tokens map
                self.add_tokens(tokens)

        try:
            vocab_size = self._tokenizer.get_vocab_size()
        except NotImplementedError:
            # Treated as "unknown size", which also skips the regex patch below.
            vocab_size = 0

        # Optionally patches mistral tokenizers with wrong regex
        if vocab_size > 100000 and getattr(self._tokenizer, "pre_tokenizer", None) is not None:
            kwargs.pop("tokenizer", None)
            self._tokenizer = self._patch_mistral_regex(
                self._tokenizer,
                self.init_kwargs.get("name_or_path", None),
                init_kwargs=self.init_kwargs,
                fix_mistral_regex=kwargs.pop("fix_mistral_regex", None),
                **kwargs,
            )

        # Re-check: the backend may have been replaced above.
        # NOTE(review): nesting reconstructed from a flattened listing — confirm this runs unconditionally
        # rather than only inside the patch branch.
        self._should_update_post_processor = (
            self._should_update_post_processor or self._tokenizer.post_processor is None
        )
        if self._should_update_post_processor:
            self.update_post_processor()
  431. @property
  432. def is_fast(self) -> bool:
  433. return True
  434. @property
  435. def can_save_slow_tokenizer(self) -> bool:
  436. """
  437. `bool`: Whether or not the slow tokenizer can be saved. For a sentencepiece based slow tokenizer, this
  438. can only be `True` if the original `"sentencepiece.model"` was not deleted.
  439. """
  440. if "vocab_file" in self.vocab_files_names and self.vocab_files_names["vocab_file"].endswith(".model"):
  441. if hasattr(self, "vocab_file") and self.vocab_file:
  442. # If the vocab file is a sentencepiece model, we can save it
  443. return os.path.isfile(self.vocab_file)
  444. return False
  445. else:
  446. return True
  447. def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]:
  448. if not os.path.isdir(save_directory):
  449. logger.error(f"Vocabulary path ({save_directory}) should be a directory")
  450. return
  451. out_vocab_file = os.path.join(
  452. save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
  453. )
  454. if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
  455. copyfile(self.vocab_file, out_vocab_file)
  456. return (out_vocab_file,)
  457. def update_post_processor(self):
  458. """
  459. Updates the underlying post processor with the current `bos_token` and `eos_token`.
  460. """
  461. bos = self.bos_token
  462. bos_token_id = self.bos_token_id
  463. if bos is None and self.add_bos_token:
  464. self.add_bos_token = False
  465. eos = self.eos_token
  466. eos_token_id = self.eos_token_id
  467. if eos is None and self.add_eos_token:
  468. self.add_eos_token = False
  469. single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
  470. pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"
  471. special_tokens = []
  472. if self.add_bos_token:
  473. special_tokens.append((bos, bos_token_id))
  474. if self.add_eos_token:
  475. special_tokens.append((eos, eos_token_id))
  476. self._tokenizer.post_processor = processors.TemplateProcessing(
  477. single=single, pair=pair, special_tokens=special_tokens
  478. )
  479. @property
  480. def add_eos_token(self):
  481. return getattr(self, "_add_eos_token", False)
  482. @property
  483. def add_bos_token(self):
  484. return getattr(self, "_add_bos_token", False)
  485. @add_eos_token.setter
  486. def add_eos_token(self, value):
  487. object.__setattr__(self, "_add_eos_token", value)
  488. self.update_post_processor()
  489. @add_bos_token.setter
  490. def add_bos_token(self, value):
  491. object.__setattr__(self, "_add_bos_token", value)
  492. self.update_post_processor()
  493. def _post_init(self):
  494. """
  495. Post-initialization hook that runs after the tokenizer is fully set up.
  496. This is called by from_pretrained() after loading the tokenizer, which allows
  497. us to add any special tokens that may have been passed as AddedToken objects.
  498. Child classes should call super()._post_init() if they override this method.
  499. """
  500. tokens_to_add = []
  501. # V5: Check named special tokens
  502. for token_value in self._special_tokens_map.values():
  503. if token_value is None:
  504. continue
  505. if isinstance(token_value, AddedToken):
  506. tokens_to_add.append(token_value)
  507. elif isinstance(token_value, str):
  508. tokens_to_add.append(AddedToken(token_value, special=True, normalized=False))
  509. # V5: Check extra special tokens
  510. for token in self._extra_special_tokens:
  511. if isinstance(token, AddedToken):
  512. tokens_to_add.append(token)
  513. elif isinstance(token, str):
  514. tokens_to_add.append(AddedToken(token, special=True, normalized=False))
  515. if tokens_to_add:
  516. # Ensure special tokens are added as such to the backend
  517. self.add_tokens(tokens_to_add, special_tokens=True)
  518. if getattr(self, "_should_update_post_processor", True) or self._tokenizer.post_processor is None:
  519. self.update_post_processor()
  520. @property
  521. def vocab_size(self) -> int:
  522. """
  523. `int`: Size of the base vocabulary (without the added tokens).
  524. """
  525. return self._tokenizer.get_vocab_size(with_added_tokens=False)
  def get_vocab(self) -> dict[str, int]:
      """
      Return the backend vocabulary, added tokens included.

      Returns:
          `dict[str, int]`: The mapping returned by the backend's `get_vocab(with_added_tokens=True)`.
      """
      backend = self._tokenizer
      return backend.get_vocab(with_added_tokens=True)
  @property
  def vocab(self) -> dict[str, int]:
      """
      `dict[str, int]`: Property shortcut for `get_vocab()`.
      """
      full_vocabulary = self.get_vocab()
      return full_vocabulary
  @property
  def added_tokens_encoder(self) -> dict[str, int]:
      """
      Mapping from added-token content to its index, built from `added_tokens_decoder` and inserted in
      ascending index order.

      Returns:
          `dict[str, int]`: The added tokens, keyed by their `content`.
      """
      decoder = self.added_tokens_decoder
      encoder = {}
      for token_id in sorted(decoder):
          encoder[decoder[token_id].content] = token_id
      return encoder
  @property
  def added_tokens_decoder(self) -> dict[int, AddedToken]:
      """
      Returns the added tokens in the vocabulary as a dictionary of index to AddedToken.

      Returns:
          `dict[int, AddedToken]`: The added tokens, as returned by the backend's `get_added_tokens_decoder()`.
      """
      backend = self._tokenizer
      return backend.get_added_tokens_decoder()
  # BC v5: expose ``_added_tokens_encoder`` / ``_added_tokens_decoder`` attrs for custom tokenizers that expect
  # them from slow tokenizers. Only supports read, not write (won't sync to Rust backend, use add_tokens() instead).
  # Both names are bound to the very same property objects as their public counterparts.
  _added_tokens_encoder = added_tokens_encoder
  _added_tokens_decoder = added_tokens_decoder
  def get_added_vocab(self) -> dict[str, int]:
      """
      Returns the added tokens in the vocabulary as a dictionary of token to index.

      Entries are inserted in ascending index order, using each added token's `content` as the key.

      Returns:
          `dict[str, int]`: The added tokens.
      """
      decoder = self.added_tokens_decoder
      added_vocab = {}
      for token_id in sorted(decoder):
          added_vocab[decoder[token_id].content] = token_id
      return added_vocab
  def __bool__(self) -> bool:
      """
      Always truthy.

      Without this method Python would fall back to `__len__` for truth testing, so a plain `assert tokenizer`
      or `if tokenizer:` would query the backend for its vocabulary size.
      """
      return True
  def __len__(self) -> int:
      """
      Size of the full vocabulary, as reported by the backend with `with_added_tokens=True`.
      """
      backend = self._tokenizer
      return backend.get_vocab_size(with_added_tokens=True)
  @property
  def backend_tokenizer(self) -> TokenizerFast:
      """
      The backend tokenizer object held in `self._tokenizer`, returned as is (no copy is made).
      """
      backend = self._tokenizer
      return backend
  @property
  def decoder(self) -> DecoderFast:
      """
      The `decoder` attribute of the backend tokenizer held in `self._tokenizer`.
      """
      backend = self._tokenizer
      return backend.decoder
  def _convert_encoding(
      self,
      encoding: EncodingFast,
      return_token_type_ids: bool | None = None,
      return_attention_mask: bool | None = None,
      return_overflowing_tokens: bool = False,
      return_special_tokens_mask: bool = False,
      return_offsets_mapping: bool = False,
      return_length: bool = False,
      verbose: bool = True,
  ) -> tuple[dict[str, Any], list[EncodingFast]]:
      """
      Turn one backend encoding into a dict of lists plus the list of encodings it was built from.

      When `return_overflowing_tokens` is set and the encoding carries overflowing pieces, each piece becomes an
      extra row, so every value of the returned dict is a list with one entry per (overflow) row.

      Output shape: (overflows, sequence length)

      Args:
          encoding: The backend encoding to convert.
          return_token_type_ids (`bool`, *optional*):
              Whether to emit `"token_type_ids"`. `None` means "only if listed in `self.model_input_names`".
          return_attention_mask (`bool`, *optional*):
              Whether to emit `"attention_mask"`. `None` means "only if listed in `self.model_input_names`".
          return_overflowing_tokens (`bool`, defaults to `False`):
              Whether the overflowing encodings are appended as additional rows.
          return_special_tokens_mask (`bool`, defaults to `False`):
              Whether to emit `"special_tokens_mask"`.
          return_offsets_mapping (`bool`, defaults to `False`):
              Whether to emit `"offset_mapping"`.
          return_length (`bool`, defaults to `False`):
              Whether to emit `"length"` (the number of ids of each row).
          verbose (`bool`, defaults to `True`):
              Accepted for signature compatibility; not used by this method.

      Returns:
          A `(fields, encodings)` tuple: `fields` is a `defaultdict(list)` keyed by output name and `encodings`
          is the list of encodings that produced the rows, in the same order.
      """
      if return_token_type_ids is None:
          return_token_type_ids = "token_type_ids" in self.model_input_names
      if return_attention_mask is None:
          return_attention_mask = "attention_mask" in self.model_input_names

      rows = [encoding]
      if return_overflowing_tokens and encoding.overflowing is not None:
          rows = rows + encoding.overflowing

      # (output key, extractor) pairs in the order the keys must appear in the result.
      selected_fields = [("input_ids", lambda row: row.ids)]
      if return_token_type_ids:
          selected_fields.append(("token_type_ids", lambda row: row.type_ids))
      if return_attention_mask:
          selected_fields.append(("attention_mask", lambda row: row.attention_mask))
      if return_special_tokens_mask:
          selected_fields.append(("special_tokens_mask", lambda row: row.special_tokens_mask))
      if return_offsets_mapping:
          selected_fields.append(("offset_mapping", lambda row: row.offsets))
      if return_length:
          selected_fields.append(("length", lambda row: len(row.ids)))

      collected = defaultdict(list)
      for row in rows:
          for field_name, extract in selected_fields:
              collected[field_name].append(extract(row))

      return collected, rows
  def _convert_token_to_id_with_added_voc(self, token: str) -> int:
      """
      Look `token` up in the backend and fall back to `self.unk_token_id` when the backend does not know it.

      Args:
          token (`str`): The token to convert.

      Returns:
          `int`: The backend id of `token`, or `self.unk_token_id` if the backend returned `None`.
      """
      token_id = self._tokenizer.token_to_id(token)
      return self.unk_token_id if token_id is None else token_id
  624. def _convert_id_to_token(self, index: int) -> str | None:
  625. return self._tokenizer.id_to_token(int(index))
  def _add_tokens(self, new_tokens: list[str | AddedToken], special_tokens=False) -> int:
      """
      Forward `new_tokens` to the backend, as special tokens or as regular added tokens.

      Args:
          new_tokens (`list[str | AddedToken]`): The tokens to register.
          special_tokens (`bool`, defaults to `False`):
              When truthy the backend's `add_special_tokens` is used, otherwise its `add_tokens`.

      Returns:
          `int`: The value returned by the backend call.
      """
      backend = self._tokenizer
      register = backend.add_special_tokens if special_tokens else backend.add_tokens
      return register(new_tokens)
  def num_special_tokens_to_add(self, pair: bool = False) -> int:
      """
      Returns the number of added tokens when encoding a sequence with special tokens.

      The computation is delegated to the backend tokenizer's `num_special_tokens_to_add`.

      Args:
          pair (`bool`, *optional*, defaults to `False`):
              Whether the number of added tokens should be computed in the case of a sequence pair or a single
              sequence.

      Returns:
          `int`: Number of special tokens added to sequences.
      """
      backend = self._tokenizer
      return backend.num_special_tokens_to_add(pair)
  645. def convert_ids_to_tokens(self, ids: int | list[int], skip_special_tokens: bool = False) -> str | list[str]:
  646. """
  647. Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and
  648. added tokens.
  649. Args:
  650. ids (`int` or `list[int]`):
  651. The token id (or token ids) to convert to tokens.
  652. skip_special_tokens (`bool`, *optional*, defaults to `False`):
  653. Whether or not to remove special tokens in the decoding.
  654. Returns:
  655. `str` or `list[str]`: The decoded token(s).
  656. """
  657. if isinstance(ids, int):
  658. return self._tokenizer.id_to_token(ids)
  659. tokens = []
  660. # self.all_special_ids is an @property which may be slow, so only compute it once before the loop
  661. ids_to_skip = set(self.all_special_ids) if skip_special_tokens else set()
  662. for index in ids:
  663. index = int(index)
  664. if index in ids_to_skip:
  665. continue
  666. tokens.append(self._tokenizer.id_to_token(index))
  667. return tokens
  668. def tokenize(self, text: str, pair: str | None = None, add_special_tokens: bool = False, **kwargs) -> list[str]:
  669. return self._encode_plus(text=text, text_pair=pair, add_special_tokens=add_special_tokens, **kwargs).tokens()
  def set_truncation_and_padding(
      self,
      padding_strategy: PaddingStrategy,
      truncation_strategy: TruncationStrategy,
      max_length: int,
      stride: int,
      pad_to_multiple_of: int | None,
      padding_side: str | None,
  ):
      """
      Bring the backend tokenizer's truncation and padding configuration in line with the requested strategies.

      The backend is only modified when needed: truncation / padding is switched off when the strategy asks for
      none while the backend has one configured, and it is (re-)enabled when the requested parameters differ from
      the ones currently active. The resulting configuration stays on the backend after this method returns.

      Args:
          padding_strategy ([`~utils.PaddingStrategy`]):
              The kind of padding that will be applied to the input
          truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`]):
              The kind of truncation that will be applied to the input
          max_length (`int`):
              The maximum size of a sequence. Used as truncation length, and as padding length when
              `padding_strategy` is `PaddingStrategy.MAX_LENGTH`.
          stride (`int`):
              The stride to use when handling overflow.
          pad_to_multiple_of (`int`, *optional*):
              If set will pad the sequence to a multiple of the provided value.
          padding_side (`str`, *optional*):
              The side on which the model should have padding applied. When `None`, `self.padding_side` is used.
      """
      backend = self._tokenizer
      # Snapshot both settings before changing anything.
      active_truncation = backend.truncation
      active_padding = backend.padding

      # ---- truncation ----
      truncation_disabled = truncation_strategy == TruncationStrategy.DO_NOT_TRUNCATE
      if truncation_disabled:
          if active_truncation is not None:
              backend.no_truncation()
      else:
          wanted_truncation = {
              "max_length": max_length,
              "stride": stride,
              "strategy": truncation_strategy.value,
              "direction": self.truncation_side,
          }
          # The backend may report more keys than the ones handled here, so the comparison is restricted to
          # the keys that are about to be passed to `enable_truncation`.
          if active_truncation is None:
              comparable = None
          else:
              comparable = {name: active_truncation.get(name, None) for name in wanted_truncation}
          if comparable != wanted_truncation:
              backend.enable_truncation(**wanted_truncation)

      # ---- padding ----
      padding_disabled = padding_strategy == PaddingStrategy.DO_NOT_PAD
      if padding_disabled:
          if active_padding is not None:
              backend.no_padding()
      else:
          # A fixed length is only requested for MAX_LENGTH; `None` is passed for every other strategy.
          fixed_length = max_length if padding_strategy == PaddingStrategy.MAX_LENGTH else None
          wanted_padding = {
              "length": fixed_length,
              "direction": padding_side if padding_side is not None else self.padding_side,
              "pad_id": self.pad_token_id,
              "pad_token": self.pad_token,
              "pad_type_id": self.pad_token_type_id,
              "pad_to_multiple_of": pad_to_multiple_of,
          }
          if active_padding != wanted_padding:
              backend.enable_padding(**wanted_padding)
  def _encode_plus(
      self,
      text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput],
      text_pair: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
      add_special_tokens: bool = True,
      padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
      truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
      max_length: int | None = None,
      stride: int = 0,
      is_split_into_words: bool = False,
      pad_to_multiple_of: int | None = None,
      padding_side: str | None = None,
      return_tensors: bool | None = None,
      return_token_type_ids: bool | None = None,
      return_attention_mask: bool | None = None,
      return_overflowing_tokens: bool = False,
      return_special_tokens_mask: bool = False,
      return_offsets_mapping: bool = False,
      return_length: bool = False,
      verbose: bool = True,
      split_special_tokens: bool | None = None,
      **kwargs,
  ) -> BatchEncoding:
      """
      Validate the inputs, encode them with the backend tokenizer and package the result as a `BatchEncoding`.

      Steps, in order: type-check `text` / `text_pair`; decide whether the call is batched; normalise the input
      to a batch; configure truncation, padding and `encode_special_tokens` on the backend; call the backend's
      `encode_batch`; convert every encoding with `_convert_encoding`; flatten the per-sample results into one
      dict of lists; and finally drop the batch dimension again for a single, non-tensor, non-overflowing input.

      Args:
          text: A string, a list of strings, or nested lists/tuples of strings (pre-tokenized input or batches).
          text_pair: Optional second sequence(s), with the same accepted shapes as `text`.
          add_special_tokens (`bool`): Forwarded to the backend's `encode_batch`.
          padding_strategy, truncation_strategy, max_length, stride, pad_to_multiple_of, padding_side:
              Forwarded to `set_truncation_and_padding`.
          is_split_into_words (`bool`):
              Forwarded to the backend as `is_pretokenized`. It also changes batch detection: a flat list is
              then a single example and only a list of lists/tuples is a batch.
          return_tensors: Passed to `BatchEncoding` as `tensor_type`.
          return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask,
          return_offsets_mapping, return_length, verbose:
              Forwarded to `_convert_encoding`. `verbose` and `max_length` are also passed to
              `_eventual_warn_about_too_long_sequence`.
          split_special_tokens (`bool`, *optional*): When `None`, `self.split_special_tokens` is used.
          kwargs: Accepted but not used by this method.

      Returns:
          [`BatchEncoding`]: The encoded inputs.

      Raises:
          ValueError: If `text` or `text_pair` has an unsupported type, or if their batch lengths differ.
          TypeError: If `text` is a batch while `text_pair` is a single string.
      """
      invalid_input_message = (
          "text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) "
          "or `list[list[str]]` (batch of pretokenized examples) or `list[tuple[list[str], list[str]]]` (batch of pretokenized sequence pairs)."
      )

      def _looks_like_text(candidate):
          # Accepts a string, or up to three levels of list/tuple nesting whose first leaf is a string.
          # Only the first element of each level is inspected; empty containers are accepted.
          if isinstance(candidate, str):
              return True
          if not isinstance(candidate, (list, tuple)):
              return False
          if len(candidate) == 0:
              return True
          first = candidate[0]
          if isinstance(first, str):
              return True
          if not isinstance(first, (list, tuple)):
              return False
          if len(first) == 0:
              return True
          inner = first[0]
          if isinstance(inner, str):
              return True
          if not isinstance(inner, (list, tuple)):
              return False
          return len(inner) == 0 or isinstance(inner[0], str)

      if not _looks_like_text(text):
          raise ValueError(invalid_input_message)
      if text_pair is not None and not _looks_like_text(text_pair):
          raise ValueError(invalid_input_message)

      # Batch detection: with pre-split words, a flat list of strings is ONE example.
      text_is_sequence = isinstance(text, (list, tuple))
      if is_split_into_words:
          is_batched = text_is_sequence and text and isinstance(text[0], (list, tuple))
      else:
          is_batched = text_is_sequence

      if not is_batched:
          # Single input - convert to batch format
          batch_inputs = [(text, text_pair)] if text_pair else [text]
      else:
          if isinstance(text_pair, str):
              raise TypeError(
                  "when tokenizing batches of text, `text_pair` must be a list or tuple with the same length as"
                  " `text`."
              )
          if text_pair is not None and len(text) != len(text_pair):
              raise ValueError(
                  f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
                  f" {len(text_pair)}."
              )
          batch_inputs = list(zip(text, text_pair)) if text_pair is not None else text

      if not isinstance(batch_inputs, (tuple, list)):
          raise TypeError(f"batch_text_or_text_pairs has to be a list or a tuple (got {type(batch_inputs)})")

      # Configure the backend before encoding.
      self.set_truncation_and_padding(
          padding_strategy=padding_strategy,
          truncation_strategy=truncation_strategy,
          max_length=max_length,
          stride=stride,
          pad_to_multiple_of=pad_to_multiple_of,
          padding_side=padding_side,
      )

      # Use self.split_special_tokens as default if not explicitly provided
      if split_special_tokens is None:
          split_special_tokens = self.split_special_tokens
      # Only write to the backend when the value actually changes.
      if self._tokenizer.encode_special_tokens != split_special_tokens:
          self._tokenizer.encode_special_tokens = split_special_tokens

      # Direct rust backend call
      raw_encodings = self._tokenizer.encode_batch(
          batch_inputs,
          add_special_tokens=add_special_tokens,
          is_pretokenized=is_split_into_words,
      )

      # One (fields, encodings) pair per input sample.
      converted = [
          self._convert_encoding(
              encoding=raw_encoding,
              return_token_type_ids=return_token_type_ids,
              return_attention_mask=return_attention_mask,
              return_overflowing_tokens=return_overflowing_tokens,
              return_special_tokens_mask=return_special_tokens_mask,
              return_offsets_mapping=return_offsets_mapping,
              return_length=return_length,
              verbose=verbose,
          )
          for raw_encoding in raw_encodings
      ]

      # Convert the output to have dict[list] from list[dict]; the first sample dictates the set of keys.
      flat_fields = {
          field_name: [row for fields, _ in converted for row in fields[field_name]]
          for field_name in converted[0][0]
      }
      flat_encodings = [piece for _, pieces in converted for piece in pieces]

      # If returning overflowing tokens, we need to return a mapping from each row to its source sample
      if return_overflowing_tokens:
          flat_fields["overflow_to_sample_mapping"] = [
              sample_index
              for sample_index, (fields, _) in enumerate(converted)
              for _row in range(len(fields["input_ids"]))
          ]

      for row_ids in flat_fields["input_ids"]:
          self._eventual_warn_about_too_long_sequence(row_ids, max_length, verbose)

      result = BatchEncoding(flat_fields, flat_encodings, tensor_type=return_tensors)

      # If single input, remove the batch dimension (unless returning overflowing tokens)
      if not is_batched and return_tensors is None and not return_overflowing_tokens:
          result = BatchEncoding(
              {
                  name: (values[0] if len(values) > 0 and isinstance(values[0], list) else values)
                  for name, values in result.items()
              },
              result.encodings,
          )

      return result
  def convert_tokens_to_string(self, tokens: list[str]) -> str:
      """
      Join a list of tokens back into a single string.

      Args:
          tokens (`list[str]`): The tokens to join.

      Returns:
          `str`: The output of the backend decoder's `decode(tokens)` when the backend has a decoder, otherwise
          the tokens joined with single spaces.
      """
      token_decoder = self.backend_tokenizer.decoder
      if token_decoder is None:
          return " ".join(tokens)
      return token_decoder.decode(tokens)
  882. def _decode(
  883. self,
  884. token_ids: int | list[int],
  885. skip_special_tokens: bool = False,
  886. clean_up_tokenization_spaces: bool | None = None,
  887. **kwargs,
  888. ) -> str:
  889. # Removed: use_source_tokenizer parameter (unused)
  890. kwargs.pop("use_source_tokenizer", None) # Pop if present to avoid errors
  891. if isinstance(token_ids, int):
  892. token_ids = [token_ids]
  893. if isinstance(token_ids, dict):
  894. token_ids = token_ids["input_ids"]
  895. text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
  896. clean_up_tokenization_spaces = (
  897. clean_up_tokenization_spaces
  898. if clean_up_tokenization_spaces is not None
  899. else self.clean_up_tokenization_spaces
  900. )
  901. if clean_up_tokenization_spaces:
  902. text = self.clean_up_tokenization(text)
  903. return text
  def _save_pretrained(
      self,
      save_directory: str | os.PathLike,
      file_names: tuple[str, ...],
      legacy_format: bool | None = None,
      filename_prefix: str | None = None,
  ) -> tuple[str, ...]:
      """
      Write the backend tokenizer to `save_directory` and report the file that was written.

      Args:
          save_directory (`str` or `os.PathLike`): Target directory; converted with `str()` before use.
          file_names (`tuple[str, ...]`): Paths saved so far.
          legacy_format (`bool`, *optional*): Accepted for signature compatibility; not used by this method.
          filename_prefix (`str`, *optional*):
              When truthy, the file name becomes `"<filename_prefix>-" + TOKENIZER_FILE`.

      Returns:
          `tuple[str, ...]`: `file_names` followed by the path of the saved tokenizer file.
      """
      target_directory = str(save_directory)
      name_prefix = filename_prefix + "-" if filename_prefix else ""
      tokenizer_path = os.path.join(target_directory, name_prefix + TOKENIZER_FILE)

      self.backend_tokenizer.save(tokenizer_path)

      return file_names + (tokenizer_path,)
  def train_new_from_iterator(
      self,
      text_iterator,
      vocab_size,
      length=None,
      new_special_tokens=None,
      special_tokens_map=None,
      **kwargs,
  ):
      """
      Trains a tokenizer on a new corpus with the same defaults (in terms of special tokens or tokenization pipeline)
      as the current one.

      Args:
          text_iterator (generator of `list[str]`):
              The training corpus. Should be a generator of batches of texts, for instance a list of lists of texts
              if you have everything in memory.
          vocab_size (`int`):
              The size of the vocabulary you want for your tokenizer.
          length (`int`, *optional*):
              The total number of sequences in the iterator. Forwarded to the backend's `train_from_iterator`.
          new_special_tokens (list of `str` or `AddedToken`, *optional*):
              A list of new special tokens to add to the tokenizer you are training.
          special_tokens_map (`dict[str, str]`, *optional*):
              If you want to rename some of the special tokens this tokenizer uses, pass along a mapping old special
              token name to new special token name in this argument.
          kwargs (`dict[str, Any]`, *optional*):
              Additional keyword arguments passed along to the trainer.

      Returns:
          A new tokenizer of the same class as the current one, trained on `text_iterator`.

      Raises:
          ValueError: If the backend model type is not one of BPE, Unigram, WordLevel or WordPiece, or if a token
              referenced by the post-processor does not exist in the retrained tokenizer.
      """
      serialized = json.loads(self._tokenizer.to_str())
      # Added tokens and the post-processor refer to token ids, which change with retraining, so both are taken
      # out of the serialized tokenizer and handled separately further down.
      previous_added_tokens = serialized.pop("added_tokens")
      post_processor = serialized.pop("post_processor")

      model_config = serialized["model"]
      model_type = model_config["type"]

      unk_token = None
      # Remove vocab
      if model_type == "BPE":
          model_config["vocab"] = {}
          model_config["merges"] = []
      elif model_type == "Unigram":
          if model_config["unk_id"] is not None:
              # Keep only the (possibly renamed) unknown token, moved to id 0.
              unk_token = model_config["vocab"][model_config["unk_id"]][0]
              if special_tokens_map is not None and unk_token in special_tokens_map:
                  unk_token = special_tokens_map[unk_token]
              model_config["unk_id"] = 0
              model_config["vocab"] = [[unk_token, 0.0]]
      elif model_type in ["WordLevel", "WordPiece"]:
          model_config["vocab"] = {}
      else:
          raise ValueError(
              f"This method does not support this type of tokenizer (found {model_type}) "
              "only BPE, Unigram, WordLevel and WordPiece."
          )

      if (
          special_tokens_map is not None
          and "unk_token" in model_config
          and model_config["unk_token"] in special_tokens_map
      ):
          model_config["unk_token"] = special_tokens_map[model_config["unk_token"]]

      tokenizer = TokenizerFast.from_str(json.dumps(serialized))

      # Rebuild the special tokens of the current tokenizer (non-special added tokens are kept for Unigram only).
      keep_non_special = model_type == "Unigram"
      special_tokens = []
      for token_spec in previous_added_tokens:
          is_special = token_spec.pop("special", None)
          # The old id is meaningless for the retrained vocabulary.
          token_spec.pop("id", None)
          if not keep_non_special and not is_special:
              continue
          if special_tokens_map is not None and token_spec["content"] in special_tokens_map:
              token_spec["content"] = special_tokens_map[token_spec["content"]]
          special_tokens.append(AddedToken(**token_spec))

      if new_special_tokens is not None:
          special_tokens.extend(new_special_tokens)

      # Trainer needs to know the end of word / continuing subword markers in BPE; explicit kwargs win.
      if model_type == "BPE":
          for option in ("continuing_subword_prefix", "end_of_word_suffix"):
              if option not in kwargs and model_config[option] is not None:
                  kwargs[option] = model_config[option]
      if model_type == "Unigram" and unk_token is not None:
          kwargs["unk_token"] = unk_token

      pre_tokenizer_config = serialized["pre_tokenizer"]
      if pre_tokenizer_config is not None:
          uses_byte_level = pre_tokenizer_config["type"] == "ByteLevel" or (
              pre_tokenizer_config["type"] == "Sequence"
              and "pretokenizers" in pre_tokenizer_config
              and any(member["type"] == "ByteLevel" for member in pre_tokenizer_config["pretokenizers"])
          )
          if uses_byte_level:
              kwargs["initial_alphabet"] = pre_tokenizers_fast.ByteLevel.alphabet()

      trainer_class = MODEL_TO_TRAINER_MAPPING[model_type]
      trainer = trainer_class(vocab_size=vocab_size, special_tokens=special_tokens, **kwargs)
      tokenizer.train_from_iterator(text_iterator, length=length, trainer=trainer)

      if post_processor is not None:
          retrained = json.loads(tokenizer.to_str())
          # Almost done, we just have to adjust the token IDs in the post processor
          if "special_tokens" in post_processor:
              for entry in post_processor["special_tokens"].values():
                  token_names = entry["tokens"]
                  if special_tokens_map is not None:
                      token_names = [special_tokens_map.get(name, name) for name in token_names]
                  entry["tokens"] = token_names
                  for name in token_names:
                      if tokenizer.token_to_id(name) is None:
                          raise ValueError(
                              "Attempted to set a token in the post processor that does not exist in the mapping"
                          )
                  entry["ids"] = [tokenizer.token_to_id(name) for name in token_names]

          for slot in ["cls", "sep"]:
              if slot not in post_processor:
                  continue
              slot_token, _ = post_processor[slot]
              if special_tokens_map is not None and slot_token in special_tokens_map:
                  slot_token = special_tokens_map[slot_token]
              slot_id = tokenizer.token_to_id(slot_token)
              if slot_id is None:
                  raise ValueError(
                      "Attempted to set a token in the post processor that does not exist in the mapping"
                  )
              post_processor[slot] = [slot_token, slot_id]

          retrained["post_processor"] = post_processor
          tokenizer = TokenizerFast.from_str(json.dumps(retrained))

      # From here on we assemble the constructor arguments of the new tokenizer object.
      init_kwargs = self.init_kwargs.copy()
      # V5: Map pad/cls/mask token at the Transformers level (named tokens only)
      for attribute in PreTrainedTokenizerBase.SPECIAL_TOKENS_ATTRIBUTES:
          current_value = getattr(self, attribute)
          if current_value is None:
              continue
          if special_tokens_map is not None and current_value in special_tokens_map:
              current_value = special_tokens_map[current_value]

          template = self._special_tokens_map.get(attribute, None)
          if isinstance(template, AddedToken):
              # Create an added token with the same parameters except the content
              init_kwargs[attribute] = AddedToken(
                  current_value,
                  single_word=template.single_word,
                  lstrip=template.lstrip,
                  rstrip=template.rstrip,
                  normalized=template.normalized,
                  special=True,
              )
          else:
              init_kwargs[attribute] = current_value

      # V5: Handle extra special tokens
      extra_special_tokens = self.extra_special_tokens.copy() if self.extra_special_tokens else []
      if new_special_tokens is not None:
          extra_special_tokens.extend(new_special_tokens)
      if len(extra_special_tokens) > 0:
          init_kwargs["extra_special_tokens"] = extra_special_tokens

      # Always try to pass tokenizer_object first. A class that builds its own backend tokenizer and hands it
      # explicitly to super().__init__() fails with a "multiple values" TypeError; in that case the instance is
      # created without tokenizer_object and the trained backend is attached afterwards.
      init_kwargs["tokenizer_object"] = tokenizer
      try:
          return self.__class__(**init_kwargs)
      except TypeError as error:
          if "multiple values for keyword argument 'tokenizer_object'" not in str(error):
              # Some other TypeError, re-raise it
              raise
          init_kwargs.pop("tokenizer_object", None)
          rebuilt = self.__class__(**init_kwargs)
          rebuilt._tokenizer = tokenizer
          return rebuilt
  @classmethod
  def _patch_mistral_regex(
      cls,
      tokenizer,
      pretrained_model_name_or_path,
      token=None,
      cache_dir=None,
      local_files_only=False,
      _commit_hash=None,
      is_local=False,
      init_kwargs=None,
      fix_mistral_regex=None,
      **kwargs,
  ):
      """
      Patches mistral related tokenizers with incorrect regex if detected

          1) Local file with an associated config saved next to it
              >> Model type one of the mistral models (on older versions)
          2) Remote models on the hub from official mistral models
              >> Tags including `base_model:.*mistralai`

      Args:
          tokenizer:
              The object whose `pre_tokenizer` is patched; it is modified in place and returned.
          pretrained_model_name_or_path:
              Model id or path. When `None` nothing is inspected and `tokenizer` is returned unchanged.
          token, cache_dir, local_files_only, _commit_hash:
              Forwarded to `cached_file` when looking up `config.json`.
          is_local (`bool`, defaults to `False`):
              Whether the tokenizer comes from local files. Forced to `True` in offline mode; the Hub is only
              queried when it is falsy.
          init_kwargs (`dict`, *optional*):
              If it contains `"fix_mistral_regex"`, that value is exposed on the tokenizer as `fix_mistral_regex`.
          fix_mistral_regex (`bool`, *optional*):
              `None` (and no truthy flag on the tokenizer) only logs a warning; `True` applies the corrected
              split pattern.
          kwargs: Accepted but not used by this method.

      Returns:
          The `tokenizer` that was passed in.
      """
      import re

      from huggingface_hub import model_info
      from packaging import version

      from transformers.utils.hub import cached_file

      def is_base_mistral(model_id: str) -> bool:
          model = model_info(model_id)
          if model.tags is None:
              return False
          # Match each tag on its own. Searching the concatenation of all tags would let `base_model:` from one
          # tag pair up with `mistralai` from a later, unrelated tag and report a false positive.
          return any(re.search("base_model:.*mistralai", tag) for tag in model.tags)

      if is_offline_mode():
          is_local = True

      if pretrained_model_name_or_path is None:
          return tokenizer

      # Ask the Hub at most once (and never for local files); the answer is reused further down instead of
      # issuing a second, identical request.
      hub_reports_mistral = (not is_local) and is_base_mistral(pretrained_model_name_or_path)
      if not (is_local or hub_reports_mistral):
          return tokenizer

      _config_file = cached_file(
          pretrained_model_name_or_path,
          "config.json",
          cache_dir=cache_dir,
          token=token,
          local_files_only=local_files_only,
          _raise_exceptions_for_missing_entries=False,
          _raise_exceptions_for_connection_errors=False,
          _commit_hash=_commit_hash,
      )

      # Detected using a (local) mistral tokenizer
      mistral_config_detected = False
      if _config_file is not None:
          with open(_config_file, encoding="utf-8") as f:
              _config = json.load(f)
          transformers_version = _config.get("transformers_version")
          transformers_model_type = _config.get("model_type")

          # Detect if we can skip the mistral fix by
          #   a) having a non-mistral tokenizer
          #   b) fixed version of transformers
          if transformers_version and version.parse(transformers_version) <= version.parse("4.57.2"):
              if (
                  is_local
                  and transformers_model_type is not None
                  and transformers_model_type
                  not in [
                      "mistral",
                      "mistral3",
                      "voxtral",
                      "ministral",
                      "pixtral",
                  ]
              ):
                  return tokenizer
          elif transformers_version and version.parse(transformers_version) > version.parse("4.57.3"):
              return tokenizer

          mistral_config_detected = True

      if mistral_config_detected or hub_reports_mistral:
          # Expose the `fix_mistral_regex` flag on the tokenizer when provided, even if no correction is applied.
          if init_kwargs and "fix_mistral_regex" in init_kwargs:
              setattr(tokenizer, "fix_mistral_regex", init_kwargs["fix_mistral_regex"])

          # only warn if its not explicitly passed
          if fix_mistral_regex is None and not getattr(tokenizer, "fix_mistral_regex", False):
              setattr(tokenizer, "fix_mistral_regex", False)
              logger.warning(
                  f"The tokenizer you are loading from '{pretrained_model_name_or_path}'"
                  f" with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e."
                  " This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue."
              )
          elif fix_mistral_regex is True or getattr(tokenizer, "fix_mistral_regex", False):
              setattr(tokenizer, "fix_mistral_regex", True)
              import tokenizers

              split_pretokenizer = tokenizers.pre_tokenizers.Split(
                  pattern=tokenizers.Regex(
                      r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+"
                  ),
                  behavior="isolated",
              )
              current_pretokenizer = tokenizer.pre_tokenizer
              # Check if it's already a Sequence
              if isinstance(current_pretokenizer, tokenizers.pre_tokenizers.Sequence):
                  # Replace the first element (the Split pattern)
                  tokenizer.pre_tokenizer[0] = split_pretokenizer
              else:
                  # Replace Metaspace with ByteLevel when adding Split, as Metaspace(split=False) doesn't
                  # work correctly with the Split pre-tokenizer and causes spaces to be lost during encoding
                  if isinstance(current_pretokenizer, tokenizers.pre_tokenizers.Metaspace):
                      current_pretokenizer = tokenizers.pre_tokenizers.ByteLevel(
                          add_prefix_space=False, use_regex=False
                      )
                  # Not a Sequence, so create one with Split + current pretokenizer
                  tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Sequence(
                      [
                          split_pretokenizer,
                          current_pretokenizer,
                      ]
                  )

      return tokenizer
  # Backward-compatible alias: allow referring to TokenizersBackend as PreTrainedTokenizerFast.
  # Both names are bound to the same class object.
  PreTrainedTokenizerFast = TokenizersBackend