# convert_slow_tokenizer.py (76 KB)
  1. # Copyright 2018 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Utilities to convert slow tokenizers in their fast tokenizers counterparts.
  16. All the conversions are grouped here to gather SentencePiece dependencies outside of the fast tokenizers files and
  17. allow to make our dependency on SentencePiece optional.
  18. """
  19. import warnings
  20. from collections.abc import Collection
  21. from functools import lru_cache
  22. from packaging import version
  23. from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
  24. from tokenizers.models import BPE, Unigram, WordPiece
  25. from tqdm import tqdm
  26. from .utils import is_protobuf_available, is_sentencepiece_available, logging, requires_backends
  27. from .utils.import_utils import PROTOBUF_IMPORT_ERROR
# Module-level logger; `logging` here is the package's own utility module imported from `.utils`, not the stdlib.
logger = logging.get_logger(__name__)
  29. MBART_LANGUAGES = [
  30. "ar_AR",
  31. "cs_CZ",
  32. "de_DE",
  33. "en_XX",
  34. "es_XX",
  35. "et_EE",
  36. "fi_FI",
  37. "fr_XX",
  38. "gu_IN",
  39. "hi_IN",
  40. "it_IT",
  41. "ja_XX",
  42. "kk_KZ",
  43. "ko_KR",
  44. "lt_LT",
  45. "lv_LV",
  46. "my_MM",
  47. "ne_NP",
  48. "nl_XX",
  49. "ro_RO",
  50. "ru_RU",
  51. "si_LK",
  52. "tr_TR",
  53. "vi_VN",
  54. "zh_CN",
  55. ]
  56. MBART50_LANGUAGES = MBART_LANGUAGES + [
  57. "af_ZA",
  58. "az_AZ",
  59. "bn_IN",
  60. "fa_IR",
  61. "he_IL",
  62. "hr_HR",
  63. "id_ID",
  64. "ka_GE",
  65. "km_KH",
  66. "mk_MK",
  67. "ml_IN",
  68. "mn_MN",
  69. "mr_IN",
  70. "pl_PL",
  71. "ps_AF",
  72. "pt_XX",
  73. "sv_SE",
  74. "sw_KE",
  75. "ta_IN",
  76. "te_IN",
  77. "th_TH",
  78. "tl_XX",
  79. "uk_UA",
  80. "ur_PK",
  81. "xh_ZA",
  82. "gl_ES",
  83. "sl_SI",
  84. ]
  85. def import_protobuf(error_message=""):
  86. if is_sentencepiece_available():
  87. from sentencepiece import sentencepiece_model_pb2
  88. return sentencepiece_model_pb2
  89. if is_protobuf_available():
  90. import google.protobuf
  91. if version.parse(google.protobuf.__version__) < version.parse("4.0.0"):
  92. from transformers.utils import sentencepiece_model_pb2
  93. else:
  94. from transformers.utils import sentencepiece_model_pb2_new as sentencepiece_model_pb2
  95. return sentencepiece_model_pb2
  96. else:
  97. raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))
  98. def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
  99. if add_prefix_space:
  100. prepend_scheme = "always"
  101. if not getattr(original_tokenizer, "legacy", True):
  102. prepend_scheme = "first"
  103. else:
  104. prepend_scheme = "never"
  105. return prepend_scheme
  106. def generate_merges(vocab, vocab_scores, skip_tokens: Collection[str] | None = None):
  107. skip_tokens = set(skip_tokens) if skip_tokens is not None else set()
  108. reverse = vocab_scores is not None
  109. vocab_scores = dict(vocab_scores) if reverse else vocab
  110. merges = []
  111. for merge, piece_score in vocab_scores.items():
  112. if merge in skip_tokens:
  113. continue
  114. local = []
  115. for index in range(1, len(merge)):
  116. piece_l, piece_r = merge[:index], merge[index:]
  117. if piece_l in skip_tokens or piece_r in skip_tokens:
  118. continue
  119. if piece_l in vocab and piece_r in vocab:
  120. local.append((piece_l, piece_r, piece_score))
  121. local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
  122. merges.extend(local)
  123. merges = sorted(merges, key=lambda val: (val[2], len(val[0]), len(val[1])), reverse=reverse)
  124. merges = [(val[0], val[1]) for val in merges]
  125. return merges
  126. class SentencePieceExtractor:
  127. """
  128. Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece
  129. """
  130. def __init__(self, model: str):
  131. requires_backends(self, "sentencepiece")
  132. requires_backends(self, "protobuf")
  133. # from .utils import sentencepiece_model_pb2 as model_pb2
  134. model_pb2 = import_protobuf()
  135. m = model_pb2.ModelProto()
  136. with open(model, "rb") as f:
  137. m.ParseFromString(f.read())
  138. self.proto = m
  139. def extract(self, model_type, **kwargs) -> tuple[dict[str, int], list[tuple]]:
  140. """
  141. By default will return vocab and merges with respect to their order, by sending `vocab_scores` we're going to
  142. order the merges with respect to the piece scores instead.
  143. """
  144. self.proto.trainer_spec.unk_id
  145. if model_type is None:
  146. from tokenizers.models import BPE, Unigram
  147. model_type = Unigram if self.proto.trainer_spec.model_type == 1 else BPE
  148. vocab = [(piece.piece, piece.score) for piece in self.proto.pieces]
  149. if model_type.__name__ != "BPE":
  150. kwargs["unk_id"] = self.proto.trainer_spec.unk_id
  151. kwargs["vocab"] = vocab
  152. else:
  153. from .tokenization_utils_base import generate_merges
  154. vocab = {word: i for i, (word, score) in enumerate(vocab)}
  155. merges = generate_merges(vocab)
  156. kwargs["vocab"] = vocab
  157. kwargs["merges"] = merges
  158. # control tokens are special
  159. # user defined symbols are not
  160. # both user and control tokens are AddedTokens
  161. # Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
  162. spm_added_tokens = [(id, p.piece, p.type == 3) for id, p in enumerate(self.proto.pieces) if p.type in [3, 4]]
  163. kwargs["additional_special_tokens"] = [
  164. AddedToken(token, normalized=False, special=special)
  165. for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
  166. ]
  167. kwargs["_spm_precompiled_charsmap"] = getattr(self.proto.normalizer_spec, "precompiled_charsmap", None)
  168. return kwargs
  169. class GemmaSentencePieceExtractor(SentencePieceExtractor):
  170. def extract(self, vocab_scores=None) -> tuple[dict[str, int], list[tuple]]:
  171. """
  172. By default will return vocab and merges with respect to their order, by sending `vocab_scores` we're going to
  173. order the merges with respect to the piece scores instead.
  174. """
  175. sp = self.sp
  176. vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
  177. # If "\t" is missing in the vocab, we have to do this to support merges
  178. # "<0x09>" is the bytefallback for `\t`
  179. if "\t" not in vocab:
  180. vocab["\t"] = vocab.get("<0x09>")
  181. merges = generate_merges(vocab, vocab_scores)
  182. return vocab, merges
  183. def check_number_comma(piece: str) -> bool:
  184. return len(piece) < 2 or piece[-1] != "," or not piece[-2].isdigit()
  185. class Converter:
  186. def __init__(self, original_tokenizer):
  187. self.original_tokenizer = original_tokenizer
  188. def converted(self) -> Tokenizer:
  189. raise NotImplementedError()
  190. class BertConverter(Converter):
  191. def converted(self) -> Tokenizer:
  192. vocab = self.original_tokenizer.vocab
  193. tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
  194. tokenize_chinese_chars = False
  195. strip_accents = False
  196. do_lower_case = False
  197. if hasattr(self.original_tokenizer, "basic_tokenizer"):
  198. tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
  199. strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
  200. do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
  201. tokenizer.normalizer = normalizers.BertNormalizer(
  202. clean_text=True,
  203. handle_chinese_chars=tokenize_chinese_chars,
  204. strip_accents=strip_accents,
  205. lowercase=do_lower_case,
  206. )
  207. tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
  208. cls = str(self.original_tokenizer.cls_token)
  209. sep = str(self.original_tokenizer.sep_token)
  210. cls_token_id = self.original_tokenizer.cls_token_id
  211. sep_token_id = self.original_tokenizer.sep_token_id
  212. tokenizer.post_processor = processors.TemplateProcessing(
  213. single=f"{cls}:0 $A:0 {sep}:0",
  214. pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
  215. special_tokens=[
  216. (cls, cls_token_id),
  217. (sep, sep_token_id),
  218. ],
  219. )
  220. tokenizer.decoder = decoders.WordPiece(prefix="##")
  221. return tokenizer
  222. class SplinterConverter(Converter):
  223. def converted(self) -> Tokenizer:
  224. vocab = self.original_tokenizer.vocab
  225. tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
  226. tokenize_chinese_chars = False
  227. strip_accents = False
  228. do_lower_case = False
  229. if hasattr(self.original_tokenizer, "basic_tokenizer"):
  230. tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
  231. strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
  232. do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
  233. tokenizer.normalizer = normalizers.BertNormalizer(
  234. clean_text=True,
  235. handle_chinese_chars=tokenize_chinese_chars,
  236. strip_accents=strip_accents,
  237. lowercase=do_lower_case,
  238. )
  239. tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
  240. cls = str(self.original_tokenizer.cls_token)
  241. sep = str(self.original_tokenizer.sep_token)
  242. question = str(self.original_tokenizer.question_token)
  243. dot = "."
  244. cls_token_id = self.original_tokenizer.cls_token_id
  245. sep_token_id = self.original_tokenizer.sep_token_id
  246. question_token_id = self.original_tokenizer.question_token_id
  247. dot_token_id = self.original_tokenizer.convert_tokens_to_ids(".")
  248. if self.original_tokenizer.padding_side == "right":
  249. pair = f"{cls}:0 $A:0 {question} {dot} {sep}:0 $B:1 {sep}:1"
  250. else:
  251. pair = f"{cls}:0 $A:0 {sep}:0 $B:1 {question} {dot} {sep}:1"
  252. tokenizer.post_processor = processors.TemplateProcessing(
  253. single=f"{cls}:0 $A:0 {sep}:0",
  254. pair=pair,
  255. special_tokens=[
  256. (cls, cls_token_id),
  257. (sep, sep_token_id),
  258. (question, question_token_id),
  259. (dot, dot_token_id),
  260. ],
  261. )
  262. tokenizer.decoder = decoders.WordPiece(prefix="##")
  263. return tokenizer
  264. class FunnelConverter(Converter):
  265. def converted(self) -> Tokenizer:
  266. vocab = self.original_tokenizer.vocab
  267. tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
  268. tokenize_chinese_chars = False
  269. strip_accents = False
  270. do_lower_case = False
  271. if hasattr(self.original_tokenizer, "basic_tokenizer"):
  272. tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
  273. strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
  274. do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
  275. tokenizer.normalizer = normalizers.BertNormalizer(
  276. clean_text=True,
  277. handle_chinese_chars=tokenize_chinese_chars,
  278. strip_accents=strip_accents,
  279. lowercase=do_lower_case,
  280. )
  281. tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
  282. cls = str(self.original_tokenizer.cls_token)
  283. sep = str(self.original_tokenizer.sep_token)
  284. cls_token_id = self.original_tokenizer.cls_token_id
  285. sep_token_id = self.original_tokenizer.sep_token_id
  286. tokenizer.post_processor = processors.TemplateProcessing(
  287. single=f"{cls}:2 $A:0 {sep}:0", # token_type_id is 2 for Funnel transformer
  288. pair=f"{cls}:2 $A:0 {sep}:0 $B:1 {sep}:1",
  289. special_tokens=[
  290. (cls, cls_token_id),
  291. (sep, sep_token_id),
  292. ],
  293. )
  294. tokenizer.decoder = decoders.WordPiece(prefix="##")
  295. return tokenizer
  296. class MPNetConverter(Converter):
  297. def converted(self) -> Tokenizer:
  298. vocab = self.original_tokenizer.vocab
  299. tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
  300. tokenize_chinese_chars = False
  301. strip_accents = False
  302. do_lower_case = False
  303. if hasattr(self.original_tokenizer, "basic_tokenizer"):
  304. tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
  305. strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
  306. do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
  307. tokenizer.normalizer = normalizers.BertNormalizer(
  308. clean_text=True,
  309. handle_chinese_chars=tokenize_chinese_chars,
  310. strip_accents=strip_accents,
  311. lowercase=do_lower_case,
  312. )
  313. tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
  314. cls = str(self.original_tokenizer.cls_token)
  315. sep = str(self.original_tokenizer.sep_token)
  316. cls_token_id = self.original_tokenizer.cls_token_id
  317. sep_token_id = self.original_tokenizer.sep_token_id
  318. tokenizer.post_processor = processors.TemplateProcessing(
  319. single=f"{cls}:0 $A:0 {sep}:0",
  320. pair=f"{cls}:0 $A:0 {sep}:0 {sep}:0 $B:1 {sep}:1", # MPNet uses two [SEP] tokens
  321. special_tokens=[
  322. (cls, cls_token_id),
  323. (sep, sep_token_id),
  324. ],
  325. )
  326. tokenizer.decoder = decoders.WordPiece(prefix="##")
  327. return tokenizer
  328. class OpenAIGPTConverter(Converter):
  329. def converted(self) -> Tokenizer:
  330. vocab = self.original_tokenizer.encoder
  331. merges = list(self.original_tokenizer.bpe_ranks.keys())
  332. unk_token = self.original_tokenizer.unk_token
  333. tokenizer = Tokenizer(
  334. BPE(
  335. vocab=vocab,
  336. merges=merges,
  337. dropout=None,
  338. unk_token=str(unk_token),
  339. end_of_word_suffix="</w>",
  340. fuse_unk=False,
  341. )
  342. )
  343. if tokenizer.token_to_id(str(unk_token)) is not None:
  344. tokenizer.add_special_tokens([str(unk_token)])
  345. tokenizer.normalizer = normalizers.BertNormalizer(lowercase=True)
  346. tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
  347. tokenizer.decoder = decoders.BPEDecoder(suffix="</w>")
  348. return tokenizer
  349. class GPT2Converter(Converter):
  350. def converted(self, vocab: dict[str, int] | None = None, merges: list[tuple[str, str]] | None = None) -> Tokenizer:
  351. if not vocab:
  352. vocab = self.original_tokenizer.encoder
  353. if not merges:
  354. merges = list(self.original_tokenizer.bpe_ranks)
  355. tokenizer = Tokenizer(
  356. BPE(
  357. vocab=vocab,
  358. merges=merges,
  359. dropout=None,
  360. continuing_subword_prefix="",
  361. end_of_word_suffix="",
  362. fuse_unk=False,
  363. )
  364. )
  365. add_prefix_space = getattr(self.original_tokenizer, "add_prefix_space", False)
  366. tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
  367. tokenizer.decoder = decoders.ByteLevel()
  368. if getattr(self.original_tokenizer, "add_bos_token", False):
  369. bos = self.original_tokenizer.bos_token
  370. bos_token_id = self.original_tokenizer.bos_token_id
  371. tokenizer.post_processor = processors.TemplateProcessing(
  372. single=f"{bos}:0 $A:0",
  373. pair=f"{bos}:0 $A:0 $B:1",
  374. special_tokens=[
  375. (bos, bos_token_id),
  376. ],
  377. )
  378. else:
  379. # XXX trim_offsets=False actually means this post_processor doesn't
  380. # really do anything.
  381. tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
  382. return tokenizer
  383. class HerbertConverter(Converter):
  384. def converted(self) -> Tokenizer:
  385. tokenizer_info_str = "#version:"
  386. token_suffix = "</w>"
  387. vocab = self.original_tokenizer.encoder
  388. merges = list(self.original_tokenizer.bpe_ranks.keys())
  389. if tokenizer_info_str in merges[0][0]:
  390. merges = merges[1:]
  391. tokenizer = Tokenizer(
  392. BPE(
  393. vocab,
  394. merges,
  395. dropout=None,
  396. unk_token=self.original_tokenizer.unk_token,
  397. end_of_word_suffix=token_suffix,
  398. )
  399. )
  400. tokenizer.normalizer = normalizers.BertNormalizer(lowercase=False, strip_accents=False)
  401. tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
  402. tokenizer.decoder = decoders.BPEDecoder(suffix=token_suffix)
  403. tokenizer.post_processor = processors.BertProcessing(
  404. sep=(self.original_tokenizer.sep_token, self.original_tokenizer.sep_token_id),
  405. cls=(self.original_tokenizer.cls_token, self.original_tokenizer.cls_token_id),
  406. )
  407. return tokenizer
  408. class Qwen2Converter(Converter):
  409. def converted(self, vocab: dict[str, int] | None = None, merges: list[tuple[str, str]] | None = None) -> Tokenizer:
  410. if not vocab:
  411. vocab = self.original_tokenizer.encoder
  412. if not merges:
  413. merges = list(self.original_tokenizer.bpe_ranks.keys())
  414. tokenizer = Tokenizer(
  415. BPE(
  416. vocab=vocab,
  417. merges=merges,
  418. dropout=None,
  419. unk_token=None,
  420. continuing_subword_prefix="",
  421. end_of_word_suffix="",
  422. fuse_unk=False,
  423. byte_fallback=False,
  424. )
  425. )
  426. tokenizer.normalizer = normalizers.NFC()
  427. tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
  428. [
  429. pre_tokenizers.Split(
  430. Regex(
  431. r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
  432. ),
  433. behavior="isolated",
  434. invert=False,
  435. ),
  436. pre_tokenizers.ByteLevel(
  437. add_prefix_space=getattr(self.original_tokenizer, "add_prefix_space", False),
  438. use_regex=False,
  439. ),
  440. ]
  441. )
  442. tokenizer.decoder = decoders.ByteLevel()
  443. tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
  444. return tokenizer
  445. class RobertaConverter(Converter):
  446. def converted(self) -> Tokenizer:
  447. ot = self.original_tokenizer
  448. vocab = ot.encoder
  449. merges = list(ot.bpe_ranks.keys())
  450. tokenizer = Tokenizer(
  451. BPE(
  452. vocab=vocab,
  453. merges=merges,
  454. dropout=None,
  455. continuing_subword_prefix="",
  456. end_of_word_suffix="",
  457. fuse_unk=False,
  458. )
  459. )
  460. tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
  461. tokenizer.decoder = decoders.ByteLevel()
  462. tokenizer.post_processor = processors.RobertaProcessing(
  463. sep=(ot.sep_token, ot.sep_token_id),
  464. cls=(ot.cls_token, ot.cls_token_id),
  465. add_prefix_space=ot.add_prefix_space,
  466. trim_offsets=True, # True by default on Roberta (historical)
  467. )
  468. return tokenizer
  469. class RoFormerConverter(Converter):
  470. def converted(self) -> Tokenizer:
  471. from .models.roformer.tokenization_utils import JiebaPreTokenizer
  472. vocab = self.original_tokenizer.vocab
  473. tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
  474. strip_accents = False
  475. do_lower_case = False
  476. if hasattr(self.original_tokenizer, "basic_tokenizer"):
  477. strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
  478. do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
  479. tokenizer.normalizer = normalizers.BertNormalizer(
  480. clean_text=True,
  481. handle_chinese_chars=False,
  482. strip_accents=strip_accents,
  483. lowercase=do_lower_case,
  484. )
  485. tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JiebaPreTokenizer(vocab))
  486. cls = str(self.original_tokenizer.cls_token)
  487. sep = str(self.original_tokenizer.sep_token)
  488. cls_token_id = self.original_tokenizer.cls_token_id
  489. sep_token_id = self.original_tokenizer.sep_token_id
  490. tokenizer.post_processor = processors.TemplateProcessing(
  491. single=f"{cls}:0 $A:0 {sep}:0",
  492. pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
  493. special_tokens=[
  494. (cls, cls_token_id),
  495. (sep, sep_token_id),
  496. ],
  497. )
  498. tokenizer.decoder = decoders.WordPiece(prefix="##")
  499. return tokenizer
  500. class DebertaConverter(Converter):
  501. def converted(self) -> Tokenizer:
  502. ot = self.original_tokenizer
  503. vocab = ot.encoder
  504. merges = list(ot.bpe_ranks.keys())
  505. tokenizer = Tokenizer(
  506. BPE(
  507. vocab=vocab,
  508. merges=merges,
  509. dropout=None,
  510. continuing_subword_prefix="",
  511. end_of_word_suffix="",
  512. fuse_unk=False,
  513. )
  514. )
  515. tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
  516. tokenizer.decoder = decoders.ByteLevel()
  517. tokenizer.post_processor = processors.TemplateProcessing(
  518. single="[CLS]:0 $A:0 [SEP]:0",
  519. pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
  520. special_tokens=[
  521. ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
  522. ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
  523. ],
  524. )
  525. return tokenizer
  526. class SpmConverter(Converter):
    # When True, the Unigram/BPE models built in `tokenizer` enable `byte_fallback`. When False and the SentencePiece
    # model uses byte fallback, `__init__` emits a warning.
    handle_byte_fallback = False
    # Extractor class instantiated on the slow tokenizer's vocab file to obtain BPE merges in `tokenizer`.
    SpmExtractor = SentencePieceExtractor
    # Pieces listed here are marked special when `tokenizer` adds control/user-defined pieces (types 3/4) as tokens.
    special_tokens = {}
  530. @staticmethod
  531. def build_tokenizer_from_spm_proto(proto, vocab, merges=None):
  532. """
  533. Similar to convert_from_spm method, but used only when there is no `model_type` class, i.e. there is no matching class in `TOKENIZERS_MAPPING` and we just create a tokenizer instead of extracting stuff from the sentencepiece file
  534. """
  535. byte_fallback = proto.trainer_spec.byte_fallback
  536. unk_piece = proto.trainer_spec.unk_piece
  537. precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
  538. # model
  539. if isinstance(vocab, dict):
  540. tokenizer = Tokenizer(
  541. BPE(
  542. vocab=vocab,
  543. merges=merges or [],
  544. unk_token=unk_piece,
  545. fuse_unk=True,
  546. byte_fallback=byte_fallback,
  547. dropout=None,
  548. )
  549. )
  550. elif isinstance(vocab, list) and vocab and isinstance(vocab[0], (tuple, list)):
  551. tokenizer = Tokenizer(
  552. Unigram(
  553. vocab=vocab,
  554. unk_id=proto.trainer_spec.unk_id,
  555. byte_fallback=byte_fallback,
  556. )
  557. )
  558. else:
  559. return None
  560. # normalizer
  561. _normalizers = [normalizers.Replace(" ", "▁")]
  562. if precompiled_charsmap:
  563. _normalizers.insert(0, normalizers.Precompiled(precompiled_charsmap))
  564. tokenizer.normalizer = normalizers.Sequence(_normalizers)
  565. # decoder
  566. if byte_fallback:
  567. tokenizer.decoder = decoders.Sequence(
  568. [decoders.Replace("▁", " "), decoders.ByteFallback(), decoders.Fuse()]
  569. )
  570. else:
  571. tokenizer.decoder = decoders.Sequence([decoders.Replace("▁", " ")])
  572. return tokenizer
  573. @classmethod
  574. def convert_from_spm(cls, vocab=None, **kwargs):
  575. """
  576. Hook used when converting directly from a SentencePiece model without a slow tokenizer instance.
  577. By default, return kwargs unchanged.
  578. """
  579. if vocab is not None:
  580. kwargs["vocab"] = vocab
  581. return kwargs
  582. def __init__(self, *args):
  583. requires_backends(self, "protobuf")
  584. super().__init__(*args)
  585. # from .utils import sentencepiece_model_pb2 as model_pb2
  586. model_pb2 = import_protobuf()
  587. m = model_pb2.ModelProto()
  588. with open(self.original_tokenizer.vocab_file, "rb") as f:
  589. m.ParseFromString(f.read())
  590. self.proto = m
  591. if self.proto.trainer_spec.byte_fallback and not self.handle_byte_fallback:
  592. warnings.warn(
  593. "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
  594. " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
  595. " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
  596. "unknown tokens into a sequence of byte tokens matching the original piece of text."
  597. )
  598. def vocab(self, proto):
  599. return [(piece.piece, piece.score) for piece in proto.pieces]
  600. def unk_id(self, proto):
  601. return proto.trainer_spec.unk_id
  602. def tokenizer(self, proto):
  603. model_type = proto.trainer_spec.model_type
  604. vocab_scores = self.vocab(proto)
  605. if model_type == 1:
  606. tokenizer = Tokenizer(
  607. Unigram(
  608. vocab_scores,
  609. unk_id=self.unk_id(proto),
  610. byte_fallback=self.handle_byte_fallback,
  611. )
  612. )
  613. elif model_type == 2:
  614. _, merges = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
  615. bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
  616. tokenizer = Tokenizer(
  617. BPE(
  618. bpe_vocab,
  619. merges,
  620. unk_token=proto.trainer_spec.unk_piece,
  621. fuse_unk=True,
  622. byte_fallback=self.handle_byte_fallback,
  623. dropout=None,
  624. )
  625. )
  626. else:
  627. raise Exception(
  628. "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
  629. )
  630. # control tokens are special
  631. # user defined symbols are not
  632. # both user and control tokens are AddedTokens
  633. # Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
  634. spm_added_tokens = [
  635. (id, p.piece, p.type == 3 or p.piece in self.special_tokens)
  636. for id, p in enumerate(proto.pieces)
  637. if p.type in [3, 4]
  638. ]
  639. tokenizer.add_tokens(
  640. [
  641. AddedToken(token, normalized=False, special=special)
  642. for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
  643. ]
  644. )
  645. return tokenizer
  646. def normalizer(self, proto):
  647. precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
  648. _normalizers = [
  649. normalizers.Strip(left=False, right=True), # stripping is important
  650. normalizers.Replace(Regex(" {2,}"), "▁"),
  651. ]
  652. if not precompiled_charsmap:
  653. return normalizers.Sequence(_normalizers)
  654. else:
  655. return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)
  656. def pre_tokenizer(self, replacement, add_prefix_space):
  657. prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
  658. return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
  659. def post_processor(self):
  660. return None
  661. def decoder(self, replacement, add_prefix_space):
  662. prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
  663. return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
  664. def converted(self) -> Tokenizer:
  665. tokenizer = self.tokenizer(self.proto)
  666. # Tokenizer assemble
  667. normalizer = self.normalizer(self.proto)
  668. if normalizer is not None:
  669. tokenizer.normalizer = normalizer
  670. replacement = "▁"
  671. add_prefix_space = True
  672. if hasattr(self.original_tokenizer, "add_prefix_space"):
  673. add_prefix_space = self.original_tokenizer.add_prefix_space
  674. pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
  675. if pre_tokenizer is not None:
  676. tokenizer.pre_tokenizer = pre_tokenizer
  677. tokenizer.decoder = self.decoder(replacement, add_prefix_space)
  678. post_processor = self.post_processor()
  679. if post_processor:
  680. tokenizer.post_processor = post_processor
  681. return tokenizer
  682. class AlbertConverter(SpmConverter):
  683. def vocab(self, proto):
  684. return [
  685. (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
  686. for piece in proto.pieces
  687. ]
  688. def normalizer(self, proto):
  689. list_normalizers = [
  690. normalizers.Replace("``", '"'),
  691. normalizers.Replace("''", '"'),
  692. ]
  693. if not self.original_tokenizer.keep_accents:
  694. list_normalizers.append(normalizers.NFKD())
  695. list_normalizers.append(normalizers.StripAccents())
  696. if self.original_tokenizer.do_lower_case:
  697. list_normalizers.append(normalizers.Lowercase())
  698. precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
  699. if precompiled_charsmap:
  700. list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
  701. list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
  702. return normalizers.Sequence(list_normalizers)
  703. def post_processor(self):
  704. return processors.TemplateProcessing(
  705. single="[CLS]:0 $A:0 [SEP]:0",
  706. pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
  707. special_tokens=[
  708. ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
  709. ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
  710. ],
  711. )
  712. class BarthezConverter(SpmConverter):
  713. def unk_id(self, proto):
  714. unk_id = 3
  715. return unk_id
  716. def post_processor(self):
  717. return processors.TemplateProcessing(
  718. single="<s> $A </s>",
  719. pair="<s> $A </s> </s> $B </s>",
  720. special_tokens=[
  721. ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
  722. ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
  723. ],
  724. )
  725. class CamembertConverter(SpmConverter):
  726. def vocab(self, proto):
  727. vocab = [
  728. ("<s>NOTUSED", 0.0),
  729. ("<pad>", 0.0),
  730. ("</s>NOTUSED", 0.0),
  731. ("<unk>", 0.0),
  732. ("<unk>NOTUSED", -100),
  733. ]
  734. # We down-grade the original SentencePiece by -100 to avoid using it and use our added token instead
  735. vocab += [(piece.piece, piece.score) for piece in proto.pieces[1:]]
  736. vocab += [("<mask>", 0.0)]
  737. return vocab
  738. def unk_id(self, proto):
  739. # See vocab unk position
  740. return 3
  741. def post_processor(self):
  742. return processors.TemplateProcessing(
  743. single="<s> $A </s>",
  744. pair="<s> $A </s> </s> $B </s>",
  745. special_tokens=[
  746. ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
  747. ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
  748. ],
  749. )
  750. @classmethod
  751. def convert_from_spm(cls, vocab=None, **kwargs):
  752. pad_token = str(kwargs.get("pad_token", "<pad>"))
  753. unk_token = str(kwargs.get("unk_token", "<unk>"))
  754. mask_token = str(kwargs.get("mask_token", "<mask>"))
  755. vocab_list = [
  756. ("<s>NOTUSED", 0.0),
  757. (pad_token, 0.0),
  758. ("</s>NOTUSED", 0.0),
  759. (unk_token, 0.0),
  760. ("<unk>NOTUSED", -100.0),
  761. ]
  762. if vocab is not None:
  763. vocab_list.extend(list(vocab)[1:])
  764. vocab_list.append((mask_token, 0.0))
  765. kwargs["vocab"] = vocab_list
  766. return kwargs
  767. class DebertaV2Converter(SpmConverter):
  768. def pre_tokenizer(self, replacement, add_prefix_space):
  769. list_pretokenizers = []
  770. if self.original_tokenizer.split_by_punct:
  771. list_pretokenizers.append(pre_tokenizers.Punctuation(behavior="isolated"))
  772. prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
  773. list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme))
  774. return pre_tokenizers.Sequence(list_pretokenizers)
  775. def normalizer(self, proto):
  776. list_normalizers = []
  777. if self.original_tokenizer.do_lower_case:
  778. list_normalizers.append(normalizers.Lowercase())
  779. list_normalizers.append(normalizers.Strip())
  780. precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
  781. if precompiled_charsmap:
  782. list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
  783. list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
  784. return normalizers.Sequence(list_normalizers)
  785. def post_processor(self):
  786. return processors.TemplateProcessing(
  787. single="[CLS]:0 $A:0 [SEP]:0",
  788. pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
  789. special_tokens=[
  790. ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
  791. ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
  792. ],
  793. )
  794. class MBartConverter(SpmConverter):
  795. def vocab(self, proto):
  796. vocab = [
  797. ("<s>", 0.0),
  798. ("<pad>", 0.0),
  799. ("</s>", 0.0),
  800. ("<unk>", 0.0),
  801. ]
  802. vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
  803. vocab += [
  804. ("ar_AR", 0.0),
  805. ("cs_CZ", 0.0),
  806. ("de_DE", 0.0),
  807. ("en_XX", 0.0),
  808. ("es_XX", 0.0),
  809. ("et_EE", 0.0),
  810. ("fi_FI", 0.0),
  811. ("fr_XX", 0.0),
  812. ("gu_IN", 0.0),
  813. ("hi_IN", 0.0),
  814. ("it_IT", 0.0),
  815. ("ja_XX", 0.0),
  816. ("kk_KZ", 0.0),
  817. ("ko_KR", 0.0),
  818. ("lt_LT", 0.0),
  819. ("lv_LV", 0.0),
  820. ("my_MM", 0.0),
  821. ("ne_NP", 0.0),
  822. ("nl_XX", 0.0),
  823. ("ro_RO", 0.0),
  824. ("ru_RU", 0.0),
  825. ("si_LK", 0.0),
  826. ("tr_TR", 0.0),
  827. ("vi_VN", 0.0),
  828. ("zh_CN", 0.0),
  829. ]
  830. vocab += [("<mask>", 0.0)]
  831. return vocab
  832. def unk_id(self, proto):
  833. return 3
  834. def post_processor(self):
  835. return processors.TemplateProcessing(
  836. single="$A </s> en_XX",
  837. pair="$A $B </s> en_XX",
  838. special_tokens=[
  839. ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
  840. ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
  841. ],
  842. )
  843. @classmethod
  844. def convert_from_spm(cls, vocab=None, **kwargs):
  845. bos_token = str(kwargs.get("bos_token", "<s>"))
  846. pad_token = str(kwargs.get("pad_token", "<pad>"))
  847. eos_token = str(kwargs.get("eos_token", "</s>"))
  848. unk_token = str(kwargs.get("unk_token", "<unk>"))
  849. mask_token = str(kwargs.get("mask_token", "<mask>"))
  850. vocab_list = [
  851. (bos_token, 0.0),
  852. (pad_token, 0.0),
  853. (eos_token, 0.0),
  854. (unk_token, 0.0),
  855. ]
  856. if vocab is not None:
  857. vocab_list.extend(list(vocab)[3:])
  858. vocab_list.extend((lang_code, 0.0) for lang_code in MBART_LANGUAGES)
  859. vocab_list.append((mask_token, 0.0))
  860. kwargs["vocab"] = vocab_list
  861. return kwargs
  862. class MBart50Converter(SpmConverter):
  863. def vocab(self, proto):
  864. vocab = [
  865. ("<s>", 0.0),
  866. ("<pad>", 0.0),
  867. ("</s>", 0.0),
  868. ("<unk>", 0.0),
  869. ]
  870. vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
  871. vocab += [("ar_AR", 0.0), ("cs_CZ", 0.0), ("de_DE", 0.0), ("en_XX", 0.0), ("es_XX", 0.0), ("et_EE", 0.0), ("fi_FI", 0.0), ("fr_XX", 0.0), ("gu_IN", 0.0), ("hi_IN", 0.0), ("it_IT", 0.0), ("ja_XX", 0.0), ("kk_KZ", 0.0), ("ko_KR", 0.0), ("lt_LT", 0.0), ("lv_LV", 0.0), ("my_MM", 0.0), ("ne_NP", 0.0), ("nl_XX", 0.0), ("ro_RO", 0.0), ("ru_RU", 0.0), ("si_LK", 0.0), ("tr_TR", 0.0), ("vi_VN", 0.0), ("zh_CN", 0.0), ("af_ZA", 0.0), ("az_AZ", 0.0), ("bn_IN", 0.0), ("fa_IR", 0.0), ("he_IL", 0.0), ("hr_HR", 0.0), ("id_ID", 0.0), ("ka_GE", 0.0), ("km_KH", 0.0), ("mk_MK", 0.0), ("ml_IN", 0.0), ("mn_MN", 0.0), ("mr_IN", 0.0), ("pl_PL", 0.0), ("ps_AF", 0.0), ("pt_XX", 0.0), ("sv_SE", 0.0), ("sw_KE", 0.0), ("ta_IN", 0.0), ("te_IN", 0.0), ("th_TH", 0.0), ("tl_XX", 0.0), ("uk_UA", 0.0), ("ur_PK", 0.0), ("xh_ZA", 0.0), ("gl_ES", 0.0), ("sl_SI", 0.0)] # fmt: skip
  872. vocab += [("<mask>", 0.0)]
  873. return vocab
  874. def unk_id(self, proto):
  875. return 3
  876. def post_processor(self):
  877. return processors.TemplateProcessing(
  878. single="en_XX $A </s>",
  879. pair="en_XX $A $B </s>",
  880. special_tokens=[
  881. ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
  882. ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
  883. ],
  884. )
  885. @classmethod
  886. def convert_from_spm(cls, vocab=None, **kwargs):
  887. cls_token = str(kwargs.get("cls_token", "<s>"))
  888. pad_token = str(kwargs.get("pad_token", "<pad>"))
  889. eos_token = str(kwargs.get("eos_token", "</s>"))
  890. unk_token = str(kwargs.get("unk_token", "<unk>"))
  891. mask_token = str(kwargs.get("mask_token", "<mask>"))
  892. vocab_list = [
  893. (cls_token, 0.0),
  894. (pad_token, 0.0),
  895. (eos_token, 0.0),
  896. (unk_token, 0.0),
  897. ]
  898. if vocab is not None:
  899. vocab_list.extend(list(vocab)[3:])
  900. vocab_list.extend((lang_code, 0.0) for lang_code in MBART50_LANGUAGES)
  901. vocab_list.append((mask_token, 0.0))
  902. kwargs["vocab"] = vocab_list
  903. return kwargs
  904. class NllbConverter(SpmConverter):
  905. def vocab(self, proto):
  906. vocab = [
  907. ("<s>", 0.0),
  908. ("<pad>", 0.0),
  909. ("</s>", 0.0),
  910. ("<unk>", 0.0),
  911. ]
  912. vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
  913. return vocab
  914. def unk_id(self, proto):
  915. return 3
  916. def post_processor(self):
  917. return processors.TemplateProcessing(
  918. single="eng_Latn $A </s>",
  919. pair="eng_Latn $A $B </s>",
  920. special_tokens=[
  921. ("eng_Latn", self.original_tokenizer.convert_tokens_to_ids("eng_Latn")),
  922. ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
  923. ],
  924. )
  925. @classmethod
  926. def convert_from_spm(cls, vocab=None, **kwargs):
  927. bos_token = str(kwargs.get("bos_token", "<s>"))
  928. pad_token = str(kwargs.get("pad_token", "<pad>"))
  929. eos_token = str(kwargs.get("eos_token", "</s>"))
  930. unk_token = str(kwargs.get("unk_token", "<unk>"))
  931. reordered_vocab = {
  932. bos_token: 0,
  933. pad_token: 1,
  934. eos_token: 2,
  935. unk_token: 3,
  936. }
  937. if vocab is not None:
  938. tokens = vocab.keys() if isinstance(vocab, dict) else [tok for tok, _ in vocab]
  939. for token in tokens:
  940. if token in reordered_vocab:
  941. continue
  942. reordered_vocab[token] = len(reordered_vocab)
  943. kwargs["vocab"] = reordered_vocab
  944. return kwargs
  945. class SeamlessM4TConverter(SpmConverter):
  946. def vocab(self, proto):
  947. vocab = [
  948. ("<pad>", 0.0),
  949. ("<unk>", 0.0),
  950. ("<s>", 0.0),
  951. ("</s>", 0.0),
  952. ]
  953. vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
  954. return vocab
  955. def unk_id(self, proto):
  956. return self.original_tokenizer.unk_token_id
  957. def post_processor(self):
  958. return processors.TemplateProcessing(
  959. single="__eng__ $A </s>",
  960. pair="__eng__ $A $B </s>",
  961. special_tokens=[
  962. ("__eng__", self.original_tokenizer.convert_tokens_to_ids("__eng__")),
  963. ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
  964. ],
  965. )
  966. class XLMRobertaConverter(SpmConverter):
  967. def vocab(self, proto):
  968. vocab = [
  969. ("<s>", 0.0),
  970. ("<pad>", 0.0),
  971. ("</s>", 0.0),
  972. ("<unk>", 0.0),
  973. ]
  974. vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
  975. vocab += [("<mask>", 0.0)]
  976. return vocab
  977. def unk_id(self, proto):
  978. unk_id = 3
  979. return unk_id
  980. def post_processor(self):
  981. return processors.TemplateProcessing(
  982. single="<s> $A </s>",
  983. pair="<s> $A </s> </s> $B </s>",
  984. special_tokens=[
  985. ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
  986. ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
  987. ],
  988. )
  989. @classmethod
  990. def convert_from_spm(cls, vocab=None, **kwargs):
  991. bos_token = str(kwargs.get("bos_token", "<s>"))
  992. pad_token = str(kwargs.get("pad_token", "<pad>"))
  993. eos_token = str(kwargs.get("eos_token", "</s>"))
  994. unk_token = str(kwargs.get("unk_token", "<unk>"))
  995. mask_token = str(kwargs.get("mask_token", "<mask>"))
  996. vocab_list = [
  997. (bos_token, 0.0),
  998. (pad_token, 0.0),
  999. (eos_token, 0.0),
  1000. (unk_token, 0.0),
  1001. ]
  1002. if vocab is not None:
  1003. vocab_list.extend(list(vocab)[3:])
  1004. vocab_list.append((mask_token, 0.0))
  1005. kwargs["vocab"] = vocab_list
  1006. return kwargs
  1007. class XLNetConverter(SpmConverter):
  1008. def vocab(self, proto):
  1009. return [
  1010. (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
  1011. for piece in proto.pieces
  1012. ]
  1013. def normalizer(self, proto):
  1014. list_normalizers = [
  1015. normalizers.Replace("``", '"'),
  1016. normalizers.Replace("''", '"'),
  1017. ]
  1018. if not self.original_tokenizer.keep_accents:
  1019. list_normalizers.append(normalizers.NFKD())
  1020. list_normalizers.append(normalizers.StripAccents())
  1021. if self.original_tokenizer.do_lower_case:
  1022. list_normalizers.append(normalizers.Lowercase())
  1023. precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
  1024. if precompiled_charsmap:
  1025. list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
  1026. list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
  1027. return normalizers.Sequence(list_normalizers)
  1028. def post_processor(self):
  1029. return processors.TemplateProcessing(
  1030. single="$A:0 <sep>:0 <cls>:2",
  1031. pair="$A:0 <sep>:0 $B:1 <sep>:1 <cls>:2",
  1032. special_tokens=[
  1033. ("<sep>", self.original_tokenizer.convert_tokens_to_ids("<sep>")),
  1034. ("<cls>", self.original_tokenizer.convert_tokens_to_ids("<cls>")),
  1035. ],
  1036. )
  1037. class ReformerConverter(SpmConverter):
  1038. pass
  1039. class RemBertConverter(SpmConverter):
  1040. # Inspired from AlbertConverter
  1041. def normalizer(self, proto):
  1042. list_normalizers = [
  1043. normalizers.Replace("``", '"'),
  1044. normalizers.Replace("''", '"'),
  1045. normalizers.Replace(Regex(" {2,}"), " "),
  1046. ]
  1047. if not self.original_tokenizer.keep_accents:
  1048. list_normalizers.append(normalizers.NFKD())
  1049. list_normalizers.append(normalizers.StripAccents())
  1050. if self.original_tokenizer.do_lower_case:
  1051. list_normalizers.append(normalizers.Lowercase())
  1052. precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
  1053. if precompiled_charsmap:
  1054. list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
  1055. return normalizers.Sequence(list_normalizers)
  1056. def post_processor(self):
  1057. return processors.TemplateProcessing(
  1058. single="[CLS]:0 $A:0 [SEP]:0",
  1059. pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
  1060. special_tokens=[
  1061. ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
  1062. ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
  1063. ],
  1064. )
  1065. class BertGenerationConverter(SpmConverter):
  1066. pass
  1067. class PegasusConverter(SpmConverter):
  1068. def vocab(self, proto):
  1069. vocab = [
  1070. (self.original_tokenizer.pad_token, 0.0),
  1071. (self.original_tokenizer.eos_token, 0.0),
  1072. ]
  1073. if self.original_tokenizer.mask_token_sent is not None:
  1074. vocab += [(self.original_tokenizer.mask_token_sent, 0.0)]
  1075. if (
  1076. self.original_tokenizer.mask_token is not None
  1077. and self.original_tokenizer.mask_token_id < self.original_tokenizer.offset
  1078. ):
  1079. vocab += [(self.original_tokenizer.mask_token, 0.0)]
  1080. vocab += [(f"<unk_{i}>", -100.0) for i in range(2, self.original_tokenizer.offset)]
  1081. vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]]
  1082. return vocab
  1083. @classmethod
  1084. def convert_from_spm(cls, vocab=None, **kwargs):
  1085. pad_token = str(kwargs.get("pad_token", "<pad>"))
  1086. eos_token = str(kwargs.get("eos_token", "</s>"))
  1087. mask_token = str(kwargs.get("mask_token", "<mask_1>"))
  1088. mask_token_sent = str(kwargs.get("mask_token_sent", "<mask_2>"))
  1089. vocab_list = [
  1090. (pad_token, 0.0),
  1091. (eos_token, 0.0),
  1092. ]
  1093. if mask_token != "None":
  1094. vocab_list.append((mask_token, 0.0))
  1095. if mask_token_sent != "None" and mask_token_sent != mask_token:
  1096. vocab_list.append((mask_token_sent, 0.0))
  1097. vocab_list.extend([(f"<unk_{i}>", -100.0) for i in range(2, kwargs.get("offset", 103))])
  1098. if vocab is not None:
  1099. vocab_list.extend(list(vocab)[2:])
  1100. kwargs["vocab"] = vocab_list
  1101. return kwargs
  1102. def unk_id(self, proto):
  1103. return proto.trainer_spec.unk_id + self.original_tokenizer.offset
  1104. def pre_tokenizer(self, replacement, add_prefix_space):
  1105. prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
  1106. return pre_tokenizers.Sequence(
  1107. [
  1108. pre_tokenizers.WhitespaceSplit(),
  1109. pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme),
  1110. ]
  1111. )
  1112. def post_processor(self):
  1113. eos = self.original_tokenizer.eos_token
  1114. special_tokens = [
  1115. (eos, self.original_tokenizer.eos_token_id),
  1116. ]
  1117. return processors.TemplateProcessing(single=["$A", eos], pair=["$A", "$B", eos], special_tokens=special_tokens)
  1118. class T5Converter(SpmConverter):
  1119. def vocab(self, proto):
  1120. num_extra_ids = self.original_tokenizer._extra_ids
  1121. vocab = [(piece.piece, piece.score) for piece in proto.pieces]
  1122. vocab += [(f"<extra_id_{i}>", 0.0) for i in range(num_extra_ids - 1, -1, -1)]
  1123. return vocab
  1124. def post_processor(self):
  1125. return processors.TemplateProcessing(
  1126. single=["$A", "</s>"],
  1127. pair=["$A", "</s>", "$B", "</s>"],
  1128. special_tokens=[
  1129. ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
  1130. ],
  1131. )
  1132. @classmethod
  1133. def convert_from_spm(cls, vocab=None, **kwargs):
  1134. extra_ids = kwargs.get("extra_ids", 100)
  1135. extra_tokens = [f"<extra_id_{i}>" for i in range(extra_ids - 1, -1, -1)]
  1136. vocab_list = list(vocab) if vocab is not None else []
  1137. vocab_list.extend((token, 0.0) for token in extra_tokens)
  1138. kwargs.setdefault("additional_special_tokens", extra_tokens)
  1139. kwargs["vocab"] = vocab_list
  1140. return kwargs
  1141. class UdopConverter(SpmConverter):
  1142. def post_processor(self):
  1143. return processors.TemplateProcessing(
  1144. single=["$A", "</s>"],
  1145. pair=["$A", "</s>", "$B", "</s>"],
  1146. special_tokens=[
  1147. ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
  1148. ],
  1149. )
  1150. class WhisperConverter(Converter):
  1151. def converted(self) -> Tokenizer:
  1152. vocab = self.original_tokenizer.encoder
  1153. merges = list(self.original_tokenizer.bpe_ranks.keys())
  1154. tokenizer = Tokenizer(
  1155. BPE(
  1156. vocab=vocab,
  1157. merges=merges,
  1158. dropout=None,
  1159. continuing_subword_prefix="",
  1160. end_of_word_suffix="",
  1161. fuse_unk=False,
  1162. )
  1163. )
  1164. tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
  1165. tokenizer.decoder = decoders.ByteLevel()
  1166. prefix_token_ids = self.original_tokenizer.prefix_tokens
  1167. prefixes = self.original_tokenizer.convert_ids_to_tokens(prefix_token_ids)
  1168. eos = self.original_tokenizer.eos_token
  1169. eos_token_id = self.original_tokenizer.eos_token_id
  1170. prefix_template = " ".join([f"{token}:0" for token in prefixes])
  1171. tokenizer.post_processor = processors.TemplateProcessing(
  1172. single=f"{prefix_template} $A:0 {eos}:0",
  1173. pair=f"{prefix_template} $A:0 $B:1 {eos}:1",
  1174. special_tokens=[
  1175. (eos, eos_token_id),
  1176. *zip(prefixes, prefix_token_ids),
  1177. ],
  1178. )
  1179. return tokenizer
  1180. class BigBirdConverter(SpmConverter):
  1181. def post_processor(self):
  1182. return processors.TemplateProcessing(
  1183. single="[CLS]:0 $A:0 [SEP]:0",
  1184. pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
  1185. special_tokens=[
  1186. ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
  1187. ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
  1188. ],
  1189. )
  1190. class CLIPConverter(Converter):
  1191. def converted(self) -> Tokenizer:
  1192. vocab = self.original_tokenizer.encoder
  1193. merges = list(self.original_tokenizer.bpe_ranks.keys())
  1194. unk_token = self.original_tokenizer.unk_token
  1195. tokenizer = Tokenizer(
  1196. BPE(
  1197. vocab=vocab,
  1198. merges=merges,
  1199. dropout=None,
  1200. continuing_subword_prefix="",
  1201. end_of_word_suffix="</w>",
  1202. fuse_unk=False,
  1203. unk_token=str(unk_token),
  1204. )
  1205. )
  1206. tokenizer.normalizer = normalizers.Sequence(
  1207. [normalizers.NFC(), normalizers.Replace(Regex(r"\s+"), " "), normalizers.Lowercase()]
  1208. )
  1209. tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
  1210. [
  1211. pre_tokenizers.Split(
  1212. Regex(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+"""),
  1213. behavior="removed",
  1214. invert=True,
  1215. ),
  1216. pre_tokenizers.ByteLevel(add_prefix_space=False),
  1217. ]
  1218. )
  1219. tokenizer.decoder = decoders.ByteLevel()
  1220. # Hack to have a ByteLevel and TemplateProcessor
  1221. tokenizer.post_processor = processors.RobertaProcessing(
  1222. sep=(self.original_tokenizer.eos_token, self.original_tokenizer.eos_token_id),
  1223. cls=(self.original_tokenizer.bos_token, self.original_tokenizer.bos_token_id),
  1224. add_prefix_space=False,
  1225. trim_offsets=False,
  1226. )
  1227. return tokenizer
  1228. class LayoutLMv2Converter(Converter):
  1229. def converted(self) -> Tokenizer:
  1230. vocab = self.original_tokenizer.vocab
  1231. tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
  1232. tokenize_chinese_chars = False
  1233. strip_accents = False
  1234. do_lower_case = True
  1235. if hasattr(self.original_tokenizer, "basic_tokenizer"):
  1236. tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
  1237. strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
  1238. do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
  1239. tokenizer.normalizer = normalizers.BertNormalizer(
  1240. clean_text=True,
  1241. handle_chinese_chars=tokenize_chinese_chars,
  1242. strip_accents=strip_accents,
  1243. lowercase=do_lower_case,
  1244. )
  1245. tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
  1246. cls = str(self.original_tokenizer.cls_token)
  1247. sep = str(self.original_tokenizer.sep_token)
  1248. cls_token_id = self.original_tokenizer.cls_token_id
  1249. sep_token_id = self.original_tokenizer.sep_token_id
  1250. tokenizer.post_processor = processors.TemplateProcessing(
  1251. single=f"{cls}:0 $A:0 {sep}:0",
  1252. pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
  1253. special_tokens=[
  1254. (cls, cls_token_id),
  1255. (sep, sep_token_id),
  1256. ],
  1257. )
  1258. tokenizer.decoder = decoders.WordPiece(prefix="##")
  1259. return tokenizer
  1260. class BlenderbotConverter(Converter):
  1261. def converted(self) -> Tokenizer:
  1262. ot = self.original_tokenizer
  1263. vocab = ot.encoder
  1264. merges = list(ot.bpe_ranks.keys())
  1265. tokenizer = Tokenizer(
  1266. BPE(
  1267. vocab=vocab,
  1268. merges=merges,
  1269. dropout=None,
  1270. continuing_subword_prefix="",
  1271. end_of_word_suffix="",
  1272. fuse_unk=False,
  1273. )
  1274. )
  1275. tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
  1276. tokenizer.decoder = decoders.ByteLevel()
  1277. tokenizer.post_processor = processors.TemplateProcessing(
  1278. single=f"$A:0 {ot.eos_token}:0",
  1279. special_tokens=[
  1280. (ot.eos_token, ot.eos_token_id),
  1281. ],
  1282. )
  1283. return tokenizer
  1284. class XGLMConverter(SpmConverter):
  1285. def vocab(self, proto):
  1286. vocab = [
  1287. ("<s>", 0.0),
  1288. ("<pad>", 0.0),
  1289. ("</s>", 0.0),
  1290. ("<unk>", 0.0),
  1291. ]
  1292. vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
  1293. vocab += [("<madeupword0>", 0.0), ("<madeupword1>", 0.0), ("<madeupword2>", 0.0), ("<madeupword3>", 0.0), ("<madeupword4>", 0.0), ("<madeupword5>", 0.0), ("<madeupword6>", 0.0)] # fmt: skip
  1294. return vocab
  1295. def unk_id(self, proto):
  1296. unk_id = 3
  1297. return unk_id
  1298. def post_processor(self):
  1299. return processors.TemplateProcessing(
  1300. single="</s> $A",
  1301. pair="</s> $A </s> </s> $B",
  1302. special_tokens=[
  1303. ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
  1304. ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
  1305. ],
  1306. )
  1307. class GemmaConverter(SpmConverter):
  1308. handle_byte_fallback = True
  1309. SpmExtractor = GemmaSentencePieceExtractor
  1310. # start and end of turn tokens must be marked as special
  1311. special_tokens = {"<start_of_turn>", "<end_of_turn>"}
  1312. """"
  1313. split_by_unicode_script: true
  1314. split_by_number: true
  1315. split_by_whitespace: true
  1316. treat_whitespace_as_suffix: false
  1317. allow_whitespace_only_pieces: true
  1318. split_digits: true
  1319. byte_fallback: true
  1320. """
  1321. def normalizer(self, proto):
  1322. return normalizers.Replace(" ", "▁")
  1323. def vocab(self, proto):
  1324. vocab = [
  1325. (self.original_tokenizer.pad_token, 0.0),
  1326. (self.original_tokenizer.eos_token, 0.0),
  1327. (self.original_tokenizer.bos_token, 0.0),
  1328. ]
  1329. vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
  1330. # Older gemma tokenizers had a missing tab token, so we fix that here
  1331. if not any(x[0] == "\t" for x in vocab):
  1332. override_index = next((i for i, x in enumerate(vocab) if x[0] == "<0x09>"), None)
  1333. if override_index is not None:
  1334. vocab[override_index] = ("\t", 0.0)
  1335. return vocab
  1336. def pre_tokenizer(self, replacement, add_prefix_space):
  1337. return pre_tokenizers.Split(" ", "merged_with_previous")
  1338. def unk_id(self, proto):
  1339. unk_id = 3
  1340. return unk_id
  1341. def decoder(self, replacement, add_prefix_space):
  1342. return decoders.Sequence(
  1343. [
  1344. decoders.Replace("▁", " "),
  1345. decoders.ByteFallback(),
  1346. decoders.Fuse(),
  1347. ]
  1348. )
  1349. class LlamaConverter(SpmConverter):
  1350. handle_byte_fallback = True
  1351. def vocab(self, proto):
  1352. vocab = [
  1353. (self.original_tokenizer.convert_ids_to_tokens(0), 0.0),
  1354. (self.original_tokenizer.convert_ids_to_tokens(1), 0.0),
  1355. (self.original_tokenizer.convert_ids_to_tokens(2), 0.0),
  1356. ]
  1357. vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
  1358. return vocab
  1359. def unk_id(self, proto):
  1360. unk_id = 0
  1361. return unk_id
  1362. def decoder(self, replacement, add_prefix_space):
  1363. sequence = [
  1364. decoders.Replace("▁", " "),
  1365. decoders.ByteFallback(),
  1366. decoders.Fuse(),
  1367. ]
  1368. if add_prefix_space:
  1369. sequence += [decoders.Strip(content=" ", left=1)]
  1370. return decoders.Sequence(sequence)
  1371. def normalizer(self, proto):
  1372. if getattr(self.original_tokenizer, "legacy", True):
  1373. sequence = []
  1374. if getattr(self.original_tokenizer, "add_prefix_space", True):
  1375. sequence += [normalizers.Prepend(prepend="▁")]
  1376. sequence += [normalizers.Replace(pattern=" ", content="▁")]
  1377. return normalizers.Sequence(sequence)
  1378. return None # non-legacy, no normalizer
  1379. def pre_tokenizer(self, replacement, add_prefix_space):
  1380. if not getattr(self.original_tokenizer, "legacy", True): # non-legacy, we need a replace
  1381. prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
  1382. return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
  1383. return None
  1384. def post_processor(self):
  1385. # the processor is defined in the LlamaTokenizerFast class.
  1386. return None
  1387. class MarkupLMConverter(Converter):
  1388. def converted(self) -> Tokenizer:
  1389. ot = self.original_tokenizer
  1390. vocab = ot.encoder
  1391. merges = list(ot.bpe_ranks.keys())
  1392. tokenizer = Tokenizer(
  1393. BPE(
  1394. vocab=vocab,
  1395. merges=merges,
  1396. dropout=None,
  1397. continuing_subword_prefix="",
  1398. end_of_word_suffix="",
  1399. fuse_unk=False,
  1400. unk_token=self.original_tokenizer.unk_token,
  1401. )
  1402. )
  1403. tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
  1404. tokenizer.decoder = decoders.ByteLevel()
  1405. cls = str(self.original_tokenizer.cls_token)
  1406. sep = str(self.original_tokenizer.sep_token)
  1407. cls_token_id = self.original_tokenizer.cls_token_id
  1408. sep_token_id = self.original_tokenizer.sep_token_id
  1409. tokenizer.post_processor = processors.TemplateProcessing(
  1410. single=f"{cls} $A {sep}",
  1411. pair=f"{cls} $A {sep} $B {sep}",
  1412. special_tokens=[
  1413. (cls, cls_token_id),
  1414. (sep, sep_token_id),
  1415. ],
  1416. )
  1417. return tokenizer
  1418. class MoshiConverter(SpmConverter):
  1419. handle_byte_fallback = True
  1420. def __init__(self, vocab_file, **kwargs):
  1421. requires_backends(self, "protobuf")
  1422. Converter.__init__(self, vocab_file)
  1423. # from .utils import sentencepiece_model_pb2 as model_pb2
  1424. model_pb2 = import_protobuf()
  1425. m = model_pb2.ModelProto()
  1426. with open(vocab_file, "rb") as f:
  1427. m.ParseFromString(f.read())
  1428. self.proto = m
  1429. def normalizer(self, proto):
  1430. precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
  1431. _normalizers = [
  1432. normalizers.Replace(" ", "▁"),
  1433. ]
  1434. if not precompiled_charsmap:
  1435. return normalizers.Sequence(_normalizers)
  1436. else:
  1437. return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)
  1438. def decoder(self, replacement, add_prefix_space):
  1439. sequence = [
  1440. decoders.Replace("▁", " "),
  1441. decoders.ByteFallback(),
  1442. decoders.Fuse(),
  1443. ]
  1444. if add_prefix_space:
  1445. sequence += [decoders.Strip(content=" ", left=1)]
  1446. return decoders.Sequence(sequence)
  1447. def pre_tokenizer(self, replacement, add_prefix_space):
  1448. prepend_scheme = "first"
  1449. return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
  1450. class HeliumConverter(SpmConverter):
  1451. handle_byte_fallback = True
  1452. def __init__(self, vocab_file=None, **kwargs):
  1453. requires_backends(self, "protobuf")
  1454. Converter.__init__(self, vocab_file)
  1455. model_pb2 = import_protobuf()
  1456. m = model_pb2.ModelProto()
  1457. with open(vocab_file, "rb") as f:
  1458. m.ParseFromString(f.read())
  1459. self.proto = m
  1460. def tokenizer(self, proto):
  1461. vocab_scores = self.vocab(proto)
  1462. tokenizer = Tokenizer(
  1463. Unigram(
  1464. vocab_scores,
  1465. unk_id=self.unk_id(proto),
  1466. byte_fallback=self.handle_byte_fallback,
  1467. )
  1468. )
  1469. # control tokens are special
  1470. # user defined symbols are not
  1471. # both user and control tokens are AddedTokens
  1472. # Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
  1473. spm_added_tokens = [
  1474. (id, p.piece, p.type == 3 or p.piece in self.special_tokens)
  1475. for id, p in enumerate(proto.pieces)
  1476. if p.type in [3, 4]
  1477. ]
  1478. tokenizer.add_tokens(
  1479. [
  1480. AddedToken(token, normalized=False, special=special, single_word=True)
  1481. for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
  1482. ]
  1483. )
  1484. tokenizer.add_tokens([AddedToken("\n", normalized=False, special=False)])
  1485. tokenizer.enable_padding(pad_token="<pad>", pad_id=3)
  1486. return tokenizer
  1487. def vocab(self, proto):
  1488. vocab = []
  1489. for piece in proto.pieces:
  1490. if piece.piece == "<0x0A>":
  1491. vocab += [("\n", piece.score)]
  1492. else:
  1493. vocab += [(piece.piece, piece.score)]
  1494. return vocab
  1495. def unk_id(self, proto):
  1496. unk_id = 0
  1497. return unk_id
  1498. def decoder(self, replacement, add_prefix_space):
  1499. sequence = [
  1500. decoders.Replace("▁", " "),
  1501. decoders.ByteFallback(),
  1502. decoders.Fuse(),
  1503. ]
  1504. sequence += [decoders.Strip(content=" ", left=1)]
  1505. return decoders.Sequence(sequence)
  1506. def normalizer(self, proto):
  1507. return normalizers.Sequence([normalizers.Prepend(" "), normalizers.Replace(r" ", "▁")])
  1508. def pre_tokenizer(self, replacement, add_prefix_space):
  1509. return pre_tokenizers.Sequence([pre_tokenizers.Split("\n", "contiguous")])
  1510. def post_processor(self):
  1511. return processors.TemplateProcessing(
  1512. single=[
  1513. "<s>",
  1514. "$A",
  1515. ],
  1516. pair=[
  1517. "<s>",
  1518. "$A",
  1519. "<s>",
  1520. "$B",
  1521. ],
  1522. special_tokens=[
  1523. ("<s>", 1),
  1524. ],
  1525. )
  1526. class ParakeetConverter(SpmConverter):
  1527. handle_byte_fallback = True
  1528. def __init__(self, vocab_file=None, *args):
  1529. self.vocab_file = vocab_file
  1530. requires_backends(self, "protobuf")
  1531. Converter.__init__(self, vocab_file)
  1532. model_pb2 = import_protobuf()
  1533. m = model_pb2.ModelProto()
  1534. with open(vocab_file, "rb") as f:
  1535. m.ParseFromString(f.read())
  1536. self.proto = m
  1537. def tokenizer(self, proto):
  1538. vocab_scores = self.vocab(proto)
  1539. _, merges = self.SpmExtractor(self.vocab_file).extract(vocab_scores)
  1540. bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
  1541. tokenizer = Tokenizer(
  1542. BPE(
  1543. bpe_vocab,
  1544. merges,
  1545. unk_token=proto.trainer_spec.unk_piece,
  1546. fuse_unk=True,
  1547. byte_fallback=self.handle_byte_fallback,
  1548. dropout=None,
  1549. )
  1550. )
  1551. # Add user defined symbols and control tokens from sentencepiece model
  1552. spm_added_tokens = [
  1553. (id, p.piece, p.type == 3 or p.piece in self.special_tokens)
  1554. for id, p in enumerate(proto.pieces)
  1555. if p.type in [3, 4]
  1556. ]
  1557. tokenizer.add_tokens(
  1558. [
  1559. AddedToken(token, normalized=False, special=special)
  1560. for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
  1561. ]
  1562. )
  1563. return tokenizer
  1564. def bytes_to_unicode():
  1565. """
  1566. Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
  1567. characters the bpe code barfs on.
  1568. The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
  1569. if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
  1570. decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
  1571. tables between utf-8 bytes and unicode strings.
  1572. """
  1573. bs = (
  1574. list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
  1575. )
  1576. cs = bs[:]
  1577. n = 0
  1578. for b in range(2**8):
  1579. if b not in bs:
  1580. bs.append(b)
  1581. cs.append(2**8 + n)
  1582. n += 1
  1583. cs = [chr(n) for n in cs]
  1584. return dict(zip(bs, cs))
  1585. class TikTokenConverter:
  1586. """
  1587. A general tiktoken converter.
  1588. """
  1589. def __init__(
  1590. self,
  1591. vocab_file=None,
  1592. pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
  1593. add_prefix_space=False,
  1594. extra_special_tokens=None,
  1595. **kwargs,
  1596. ):
  1597. self.vocab_file = vocab_file
  1598. self.pattern = pattern
  1599. self.add_prefix_space = add_prefix_space
  1600. self.extra_special_tokens = (
  1601. extra_special_tokens.keys() if isinstance(extra_special_tokens, dict) else extra_special_tokens
  1602. )
  1603. def extract_vocab_merges_from_model(self, tiktoken_url: str):
  1604. try:
  1605. from tiktoken.load import load_tiktoken_bpe
  1606. except Exception:
  1607. raise ValueError(
  1608. "`tiktoken` is required to read a `tiktoken` file. Install it with `pip install tiktoken`."
  1609. )
  1610. bpe_ranks = load_tiktoken_bpe(tiktoken_url)
  1611. byte_encoder = bytes_to_unicode()
  1612. def token_bytes_to_string(b):
  1613. return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
  1614. merges = []
  1615. vocab = {}
  1616. for token, rank in bpe_ranks.items():
  1617. vocab[token_bytes_to_string(token)] = rank
  1618. if len(token) == 1:
  1619. continue
  1620. local = []
  1621. for index in range(1, len(token)):
  1622. piece_l, piece_r = token[:index], token[index:]
  1623. if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
  1624. local.append((piece_l, piece_r, rank))
  1625. local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False)
  1626. merges.extend(local)
  1627. merges = sorted(merges, key=lambda val: val[2], reverse=False)
  1628. merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
  1629. return vocab, merges
  1630. def tokenizer(self):
  1631. vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab_file)
  1632. tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
  1633. if hasattr(tokenizer.model, "ignore_merges"):
  1634. tokenizer.model.ignore_merges = True
  1635. return tokenizer
  1636. def converted(self) -> Tokenizer:
  1637. tokenizer = self.tokenizer()
  1638. tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
  1639. [
  1640. pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
  1641. pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
  1642. ]
  1643. )
  1644. tokenizer.decoder = decoders.ByteLevel()
  1645. if self.extra_special_tokens is not None:
  1646. tokenizer.add_special_tokens(
  1647. [AddedToken(token, normalized=False, special=True) for token in self.extra_special_tokens]
  1648. )
  1649. tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
  1650. return tokenizer
  1651. class MistralConverter:
  1652. def __init__(
  1653. self,
  1654. vocab_file=None,
  1655. pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
  1656. add_prefix_space=False,
  1657. additional_special_tokens=None,
  1658. **kwargs,
  1659. ):
  1660. self.vocab_file = vocab_file
  1661. self.pattern = pattern
  1662. self.add_prefix_space = add_prefix_space
  1663. self.additional_special_tokens = (
  1664. additional_special_tokens.keys()
  1665. if isinstance(additional_special_tokens, dict)
  1666. else additional_special_tokens
  1667. )
  1668. def extract_vocab_merges_from_model(self, tiktoken_url: str):
  1669. import base64
  1670. import json
  1671. with open(self.vocab_file, "r", encoding="utf-8") as f:
  1672. untyped = json.load(f)
  1673. self.pattern = untyped["config"]["pattern"]
  1674. self.additional_special_tokens = [
  1675. AddedToken(k["token_str"], special=k["is_control"]) for k in untyped["special_tokens"]
  1676. ]
  1677. bpe_ranks = untyped["vocab"]
  1678. byte_encoder = bytes_to_unicode()
  1679. @lru_cache
  1680. def token_bytes_to_string(b):
  1681. return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
  1682. merges = []
  1683. vocab = {}
  1684. for idx, token in enumerate(self.additional_special_tokens):
  1685. vocab[token.content] = idx
  1686. bpe_ranks = [base64.b64decode(k["token_bytes"]) for k in bpe_ranks]
  1687. rank_set = set(bpe_ranks)
  1688. token_to_rank = {token: rank for rank, token in enumerate(bpe_ranks)}
  1689. for rank, token in enumerate(tqdm(bpe_ranks, desc="Converting tekken.json to tokenizer.json")):
  1690. vocab[token_bytes_to_string(token)] = rank
  1691. if len(token) == 1:
  1692. continue
  1693. local = []
  1694. for index in range(1, len(token)):
  1695. piece_l, piece_r = token[:index], token[index:]
  1696. if piece_l in rank_set and piece_r in rank_set and (piece_l + piece_r) in rank_set:
  1697. local.append((piece_l, piece_r, rank))
  1698. local = sorted(local, key=lambda x: (token_to_rank[x[0]], token_to_rank[x[1]]), reverse=False)
  1699. merges.extend(local)
  1700. merges = sorted(merges, key=lambda val: val[2], reverse=False)
  1701. merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
  1702. return vocab, merges
  1703. def tokenizer(self):
  1704. vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab_file)
  1705. tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
  1706. if hasattr(tokenizer.model, "ignore_merges"):
  1707. tokenizer.model.ignore_merges = True
  1708. return tokenizer
  1709. def converted(self) -> Tokenizer:
  1710. tokenizer = self.tokenizer()
  1711. tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
  1712. [
  1713. pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
  1714. pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
  1715. ]
  1716. )
  1717. tokenizer.decoder = decoders.ByteLevel()
  1718. tokenizer.add_tokens(self.additional_special_tokens)
  1719. tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
  1720. return tokenizer
# Registry mapping a slow tokenizer's class name (``transformer_tokenizer.__class__.__name__``) to the
# converter class that ``convert_slow_tokenizer`` instantiates to build the fast backend.
# Several slow tokenizers share one converter. Insertion order matters:
# ``convert_slow_tokenizer`` lists these keys in its error message.
SLOW_TO_FAST_CONVERTERS = {
    "AlbertTokenizer": AlbertConverter,
    "BartTokenizer": RobertaConverter,
    "BarthezTokenizer": BarthezConverter,
    "BertTokenizer": BertConverter,
    "BigBirdTokenizer": BigBirdConverter,
    "BlenderbotTokenizer": BlenderbotConverter,
    "CamembertTokenizer": CamembertConverter,
    "CLIPTokenizer": CLIPConverter,
    "CodeGenTokenizer": GPT2Converter,
    "ConvBertTokenizer": BertConverter,
    "DebertaTokenizer": DebertaConverter,
    "DebertaV2Tokenizer": DebertaV2Converter,
    "DistilBertTokenizer": BertConverter,
    "DPRReaderTokenizer": BertConverter,
    "DPRQuestionEncoderTokenizer": BertConverter,
    "DPRContextEncoderTokenizer": BertConverter,
    "ElectraTokenizer": BertConverter,
    "FNetTokenizer": AlbertConverter,
    "FunnelTokenizer": FunnelConverter,
    "GPT2Tokenizer": GPT2Converter,
    "HerbertTokenizer": HerbertConverter,
    "LayoutLMTokenizer": BertConverter,
    "LayoutLMv2Tokenizer": BertConverter,
    "LayoutLMv3Tokenizer": RobertaConverter,
    "LayoutXLMTokenizer": XLMRobertaConverter,
    "LongformerTokenizer": RobertaConverter,
    "LEDTokenizer": RobertaConverter,
    "LxmertTokenizer": BertConverter,
    "MarkupLMTokenizer": MarkupLMConverter,
    "MBartTokenizer": MBartConverter,
    "MBart50Tokenizer": MBart50Converter,
    "MPNetTokenizer": MPNetConverter,
    "MobileBertTokenizer": BertConverter,
    "MvpTokenizer": RobertaConverter,
    "NllbTokenizer": NllbConverter,
    "OpenAIGPTTokenizer": OpenAIGPTConverter,
    "PegasusTokenizer": PegasusConverter,
    "Qwen2Tokenizer": Qwen2Converter,
    "ReformerTokenizer": ReformerConverter,
    "RemBertTokenizer": RemBertConverter,
    "RobertaTokenizer": RobertaConverter,
    "RoFormerTokenizer": RoFormerConverter,
    "SeamlessM4TTokenizer": SeamlessM4TConverter,
    "SqueezeBertTokenizer": BertConverter,
    "T5Tokenizer": T5Converter,
    "UdopTokenizer": UdopConverter,
    "WhisperTokenizer": WhisperConverter,
    "XLMRobertaTokenizer": XLMRobertaConverter,
    "XLNetTokenizer": XLNetConverter,
    "SplinterTokenizer": SplinterConverter,
    "XGLMTokenizer": XGLMConverter,
    "LlamaTokenizer": LlamaConverter,
    "CodeLlamaTokenizer": LlamaConverter,
    "GemmaTokenizer": GemmaConverter,
    "Phi3Tokenizer": LlamaConverter,
}
  1778. def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokenizer:
  1779. """
  1780. Utilities to convert a slow tokenizer instance in a fast tokenizer instance.
  1781. Args:
  1782. transformer_tokenizer ([`~tokenization_utils_base.PreTrainedTokenizer`]):
  1783. Instance of a slow tokenizer to convert in the backend tokenizer for
  1784. [`~tokenization_utils_base.PreTrainedTokenizerFast`].
  1785. from_tiktoken (bool, optional): Whether to use the `tiktoken` library to convert the tokenizer instead of sentencepiece.
  1786. Defaults to False.
  1787. Return:
  1788. A instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a
  1789. [`~tokenization_utils_base.PreTrainedTokenizerFast`]
  1790. """
  1791. tokenizer_class_name = transformer_tokenizer.__class__.__name__
  1792. if tokenizer_class_name in SLOW_TO_FAST_CONVERTERS and not from_tiktoken:
  1793. converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
  1794. return converter_class(transformer_tokenizer).converted()
  1795. elif transformer_tokenizer.vocab_file.endswith("tekken.json"):
  1796. transformer_tokenizer.original_tokenizer = transformer_tokenizer
  1797. logger.info("Converting from Mistral tekken.json")
  1798. return MistralConverter(transformer_tokenizer.vocab_file).converted()
  1799. else:
  1800. try:
  1801. logger.info("Converting from Tiktoken")
  1802. return TikTokenConverter(
  1803. vocab_file=transformer_tokenizer.vocab_file,
  1804. extra_special_tokens=transformer_tokenizer.extra_special_tokens,
  1805. ).converted()
  1806. except Exception:
  1807. raise ValueError(
  1808. f"Converting from SentencePiece and Tiktoken failed, if a converter for SentencePiece is available, provide a model path "
  1809. f"with a SentencePiece tokenizer.model file."
  1810. f"Currently available slow->fast converters: {list(SLOW_TO_FAST_CONVERTERS.keys())}"
  1811. )