| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151 |
- # Copyright 2018 The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- Utilities to convert slow tokenizers into their fast tokenizer counterparts.
- All the conversions are grouped here to keep SentencePiece dependencies out of the fast tokenizer files and
- to make our dependency on SentencePiece optional.
- """
- import warnings
- from collections.abc import Collection
- from functools import lru_cache
- from packaging import version
- from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
- from tokenizers.models import BPE, Unigram, WordPiece
- from tqdm import tqdm
- from .utils import is_protobuf_available, is_sentencepiece_available, logging, requires_backends
- from .utils.import_utils import PROTOBUF_IMPORT_ERROR
- logger = logging.get_logger(__name__)
# Language codes of the original 25-language mBART checkpoint.
MBART_LANGUAGES = [
    "ar_AR",
    "cs_CZ",
    "de_DE",
    "en_XX",
    "es_XX",
    "et_EE",
    "fi_FI",
    "fr_XX",
    "gu_IN",
    "hi_IN",
    "it_IT",
    "ja_XX",
    "kk_KZ",
    "ko_KR",
    "lt_LT",
    "lv_LV",
    "my_MM",
    "ne_NP",
    "nl_XX",
    "ro_RO",
    "ru_RU",
    "si_LK",
    "tr_TR",
    "vi_VN",
    "zh_CN",
]

# mBART-50 keeps the 25 codes above (same order) and appends 27 more.
MBART50_LANGUAGES = [
    *MBART_LANGUAGES,
    "af_ZA",
    "az_AZ",
    "bn_IN",
    "fa_IR",
    "he_IL",
    "hr_HR",
    "id_ID",
    "ka_GE",
    "km_KH",
    "mk_MK",
    "ml_IN",
    "mn_MN",
    "mr_IN",
    "pl_PL",
    "ps_AF",
    "pt_XX",
    "sv_SE",
    "sw_KE",
    "ta_IN",
    "te_IN",
    "th_TH",
    "tl_XX",
    "uk_UA",
    "ur_PK",
    "xh_ZA",
    "gl_ES",
    "sl_SI",
]
def import_protobuf(error_message=""):
    """
    Return a `sentencepiece_model_pb2` module usable to parse SentencePiece model files.

    The module shipped with `sentencepiece` is preferred. Otherwise one of the copies bundled with
    `transformers.utils` is picked according to the installed protobuf version.

    Args:
        error_message: Text interpolated into `PROTOBUF_IMPORT_ERROR` when nothing can be imported.

    Raises:
        ImportError: If neither sentencepiece nor protobuf is available.
    """
    if is_sentencepiece_available():
        from sentencepiece import sentencepiece_model_pb2

        return sentencepiece_model_pb2

    if not is_protobuf_available():
        raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))

    import google.protobuf

    # protobuf >= 4.0.0 gets the "_new" bundled module; older releases get the legacy one.
    installed = version.parse(google.protobuf.__version__)
    if installed >= version.parse("4.0.0"):
        from transformers.utils import sentencepiece_model_pb2_new as sentencepiece_model_pb2
    else:
        from transformers.utils import sentencepiece_model_pb2
    return sentencepiece_model_pb2
- def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
- if add_prefix_space:
- prepend_scheme = "always"
- if not getattr(original_tokenizer, "legacy", True):
- prepend_scheme = "first"
- else:
- prepend_scheme = "never"
- return prepend_scheme
- def generate_merges(vocab, vocab_scores, skip_tokens: Collection[str] | None = None):
- skip_tokens = set(skip_tokens) if skip_tokens is not None else set()
- reverse = vocab_scores is not None
- vocab_scores = dict(vocab_scores) if reverse else vocab
- merges = []
- for merge, piece_score in vocab_scores.items():
- if merge in skip_tokens:
- continue
- local = []
- for index in range(1, len(merge)):
- piece_l, piece_r = merge[:index], merge[index:]
- if piece_l in skip_tokens or piece_r in skip_tokens:
- continue
- if piece_l in vocab and piece_r in vocab:
- local.append((piece_l, piece_r, piece_score))
- local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
- merges.extend(local)
- merges = sorted(merges, key=lambda val: (val[2], len(val[0]), len(val[1])), reverse=reverse)
- merges = [(val[0], val[1]) for val in merges]
- return merges
class SentencePieceExtractor:
    """
    Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece
    """

    def __init__(self, model: str):
        """
        Parse the serialized SentencePiece model at path `model` and keep it as `self.proto`.
        """
        requires_backends(self, "sentencepiece")
        requires_backends(self, "protobuf")

        model_pb2 = import_protobuf()
        proto = model_pb2.ModelProto()
        with open(model, "rb") as f:
            proto.ParseFromString(f.read())
        self.proto = proto

    def extract(self, model_type, **kwargs) -> dict:
        """
        Collect the keyword arguments needed to build a fast tokenizer from the parsed model.

        Args:
            model_type: Tokenizer model class (selected through its `__name__`). When `None`, it is
                inferred from the proto: `Unigram` if `trainer_spec.model_type == 1`, else `BPE`.
            **kwargs: Extra entries that are passed through and returned alongside the extracted ones.

        Returns:
            The `kwargs` dict, updated with:
            - `vocab`: a list of `(piece, score)` pairs for non-BPE models, or a
              `{piece: id}` dict for BPE
            - `unk_id` (non-BPE models only)
            - `merges` (BPE only)
            - `additional_special_tokens`
            - `_spm_precompiled_charsmap`
        """
        proto = self.proto
        if model_type is None:
            model_type = Unigram if proto.trainer_spec.model_type == 1 else BPE

        scored_vocab = [(piece.piece, piece.score) for piece in proto.pieces]
        if model_type.__name__ != "BPE":
            kwargs["unk_id"] = proto.trainer_spec.unk_id
            kwargs["vocab"] = scored_vocab
        else:
            # This local import deliberately shadows the module-level `generate_merges`, which
            # cannot be called with a vocab alone.
            # NOTE(review): presumably this variant accepts a single-argument call - confirm.
            from .tokenization_utils_base import generate_merges

            vocab = {word: index for index, (word, _score) in enumerate(scored_vocab)}
            kwargs["vocab"] = vocab
            kwargs["merges"] = generate_merges(vocab)

        # control tokens are special
        # user defined symbols are not
        # both user and control tokens are AddedTokens
        # Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
        spm_added_tokens = [
            (index, piece.piece, piece.type == 3)
            for index, piece in enumerate(proto.pieces)
            if piece.type in [3, 4]
        ]
        kwargs["additional_special_tokens"] = [
            AddedToken(token, normalized=False, special=special)
            for _index, token, special in sorted(spm_added_tokens, key=lambda entry: entry[0])
        ]
        kwargs["_spm_precompiled_charsmap"] = getattr(proto.normalizer_spec, "precompiled_charsmap", None)
        return kwargs
class GemmaSentencePieceExtractor(SentencePieceExtractor):
    """SentencePiece extractor that maps a missing tab piece onto its byte-fallback piece."""

    def extract(self, vocab_scores=None) -> tuple[dict[str, int], list[tuple]]:
        """
        By default will return vocab and merges with respect to their order, by sending `vocab_scores` we're going to
        order the merges with respect to the piece scores instead.
        """
        # The parent `__init__` only stores the parsed proto. Use a `sp` processor when an
        # instance provides one, otherwise read the pieces from the proto in id order.
        sp = getattr(self, "sp", None)
        if sp is not None:
            vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
        else:
            vocab = {piece.piece: index for index, piece in enumerate(self.proto.pieces)}

        # If "\t" is missing in the vocab, we have to do this to support merges
        # "<0x09>" is the bytefallback for `\t`
        # Only alias it when the byte-fallback piece exists, so no `None` id enters the vocab.
        tab_fallback_id = vocab.get("<0x09>")
        if "\t" not in vocab and tab_fallback_id is not None:
            vocab["\t"] = tab_fallback_id

        merges = generate_merges(vocab, vocab_scores)
        return vocab, merges
def check_number_comma(piece: str) -> bool:
    """Return False only when `piece` ends with a digit immediately followed by a comma."""
    if len(piece) < 2:
        return True
    ends_with_digit_comma = piece[-1] == "," and piece[-2].isdigit()
    return not ends_with_digit_comma
class Converter:
    """Base class for converters; subclasses implement `converted`."""

    def __init__(self, original_tokenizer):
        # Slow tokenizer instance that subclasses read their settings from.
        self.original_tokenizer = original_tokenizer

    def converted(self) -> Tokenizer:
        """Build the fast tokenizer. Must be overridden."""
        raise NotImplementedError
class BertConverter(Converter):
    """Converter building a WordPiece tokenizer with a `[CLS] A [SEP] B [SEP]`-style template."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer
        fast = Tokenizer(WordPiece(slow.vocab, unk_token=str(slow.unk_token)))

        # Normalization flags stay False unless the slow tokenizer has a `basic_tokenizer`.
        handle_chinese_chars = strip_accents = lowercase = False
        if hasattr(slow, "basic_tokenizer"):
            basic = slow.basic_tokenizer
            handle_chinese_chars = basic.tokenize_chinese_chars
            strip_accents = basic.strip_accents
            lowercase = basic.do_lower_case

        fast.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=handle_chinese_chars,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        fast.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls, sep = str(slow.cls_token), str(slow.sep_token)
        fast.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=[
                (cls, slow.cls_token_id),
                (sep, slow.sep_token_id),
            ],
        )
        fast.decoder = decoders.WordPiece(prefix="##")
        return fast
class SplinterConverter(Converter):
    """Converter building a WordPiece tokenizer whose pair template inserts a question token and a dot."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer
        fast = Tokenizer(WordPiece(slow.vocab, unk_token=str(slow.unk_token)))

        # Normalization flags stay False unless the slow tokenizer has a `basic_tokenizer`.
        handle_chinese_chars = strip_accents = lowercase = False
        if hasattr(slow, "basic_tokenizer"):
            basic = slow.basic_tokenizer
            handle_chinese_chars = basic.tokenize_chinese_chars
            strip_accents = basic.strip_accents
            lowercase = basic.do_lower_case

        fast.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=handle_chinese_chars,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        fast.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls, sep = str(slow.cls_token), str(slow.sep_token)
        question = str(slow.question_token)
        dot = "."

        # The question token and dot follow the first segment for right padding,
        # and the second segment otherwise.
        if slow.padding_side == "right":
            pair = f"{cls}:0 $A:0 {question} {dot} {sep}:0 $B:1 {sep}:1"
        else:
            pair = f"{cls}:0 $A:0 {sep}:0 $B:1 {question} {dot} {sep}:1"

        fast.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=pair,
            special_tokens=[
                (cls, slow.cls_token_id),
                (sep, slow.sep_token_id),
                (question, slow.question_token_id),
                (dot, slow.convert_tokens_to_ids(".")),
            ],
        )
        fast.decoder = decoders.WordPiece(prefix="##")
        return fast
class FunnelConverter(Converter):
    """Converter building a WordPiece tokenizer whose leading cls token gets type id 2."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer
        fast = Tokenizer(WordPiece(slow.vocab, unk_token=str(slow.unk_token)))

        # Normalization flags stay False unless the slow tokenizer has a `basic_tokenizer`.
        handle_chinese_chars = strip_accents = lowercase = False
        if hasattr(slow, "basic_tokenizer"):
            basic = slow.basic_tokenizer
            handle_chinese_chars = basic.tokenize_chinese_chars
            strip_accents = basic.strip_accents
            lowercase = basic.do_lower_case

        fast.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=handle_chinese_chars,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        fast.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls, sep = str(slow.cls_token), str(slow.sep_token)
        fast.post_processor = processors.TemplateProcessing(
            single=f"{cls}:2 $A:0 {sep}:0",  # token_type_id is 2 for Funnel transformer
            pair=f"{cls}:2 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=[
                (cls, slow.cls_token_id),
                (sep, slow.sep_token_id),
            ],
        )
        fast.decoder = decoders.WordPiece(prefix="##")
        return fast
class MPNetConverter(Converter):
    """Converter building a WordPiece tokenizer whose pair template separates segments with two sep tokens."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer
        fast = Tokenizer(WordPiece(slow.vocab, unk_token=str(slow.unk_token)))

        # Normalization flags stay False unless the slow tokenizer has a `basic_tokenizer`.
        handle_chinese_chars = strip_accents = lowercase = False
        if hasattr(slow, "basic_tokenizer"):
            basic = slow.basic_tokenizer
            handle_chinese_chars = basic.tokenize_chinese_chars
            strip_accents = basic.strip_accents
            lowercase = basic.do_lower_case

        fast.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=handle_chinese_chars,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        fast.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls, sep = str(slow.cls_token), str(slow.sep_token)
        fast.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=f"{cls}:0 $A:0 {sep}:0 {sep}:0 $B:1 {sep}:1",  # MPNet uses two [SEP] tokens
            special_tokens=[
                (cls, slow.cls_token_id),
                (sep, slow.sep_token_id),
            ],
        )
        fast.decoder = decoders.WordPiece(prefix="##")
        return fast
class OpenAIGPTConverter(Converter):
    """Converter building a BPE tokenizer that uses the `</w>` end-of-word suffix."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer
        vocab = slow.encoder
        merges = list(slow.bpe_ranks.keys())
        unk = str(slow.unk_token)

        model = BPE(
            vocab=vocab,
            merges=merges,
            dropout=None,
            unk_token=unk,
            end_of_word_suffix="</w>",
            fuse_unk=False,
        )
        fast = Tokenizer(model)

        # Register the unknown token as special only when the vocabulary contains it.
        if fast.token_to_id(unk) is not None:
            fast.add_special_tokens([unk])

        fast.normalizer = normalizers.BertNormalizer(lowercase=True)
        fast.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
        fast.decoder = decoders.BPEDecoder(suffix="</w>")
        return fast
class GPT2Converter(Converter):
    """Converter building a byte-level BPE tokenizer, optionally prepending a BOS token."""

    def converted(self, vocab: dict[str, int] | None = None, merges: list[tuple[str, str]] | None = None) -> Tokenizer:
        """
        Build the fast tokenizer.

        Args:
            vocab: Token-to-id mapping; falls back to the slow tokenizer's `encoder` when empty or `None`.
            merges: BPE merges; fall back to the keys of the slow tokenizer's `bpe_ranks` when empty or `None`.
        """
        slow = self.original_tokenizer
        vocab = vocab or slow.encoder
        merges = merges or list(slow.bpe_ranks)

        model = BPE(
            vocab=vocab,
            merges=merges,
            dropout=None,
            continuing_subword_prefix="",
            end_of_word_suffix="",
            fuse_unk=False,
        )
        fast = Tokenizer(model)

        fast.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=getattr(slow, "add_prefix_space", False))
        fast.decoder = decoders.ByteLevel()

        if not getattr(slow, "add_bos_token", False):
            # XXX trim_offsets=False actually means this post_processor doesn't
            # really do anything.
            fast.post_processor = processors.ByteLevel(trim_offsets=False)
            return fast

        bos = slow.bos_token
        fast.post_processor = processors.TemplateProcessing(
            single=f"{bos}:0 $A:0",
            pair=f"{bos}:0 $A:0 $B:1",
            special_tokens=[
                (bos, slow.bos_token_id),
            ],
        )
        return fast
class HerbertConverter(Converter):
    """Converter building a `</w>`-suffixed BPE tokenizer with BERT-style post-processing."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer
        version_marker = "#version:"
        suffix = "</w>"

        vocab = slow.encoder
        merges = list(slow.bpe_ranks.keys())
        # Drop the first merge when its left part carries the "#version:" marker.
        if version_marker in merges[0][0]:
            merges = merges[1:]

        model = BPE(
            vocab,
            merges,
            dropout=None,
            unk_token=slow.unk_token,
            end_of_word_suffix=suffix,
        )
        fast = Tokenizer(model)

        fast.normalizer = normalizers.BertNormalizer(lowercase=False, strip_accents=False)
        fast.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
        fast.decoder = decoders.BPEDecoder(suffix=suffix)
        fast.post_processor = processors.BertProcessing(
            sep=(slow.sep_token, slow.sep_token_id),
            cls=(slow.cls_token, slow.cls_token_id),
        )
        return fast
class Qwen2Converter(Converter):
    """Converter building a byte-level BPE tokenizer with NFC normalization and a regex pre-split."""

    def converted(self, vocab: dict[str, int] | None = None, merges: list[tuple[str, str]] | None = None) -> Tokenizer:
        """
        Build the fast tokenizer.

        Args:
            vocab: Token-to-id mapping; falls back to the slow tokenizer's `encoder` when empty or `None`.
            merges: BPE merges; fall back to the keys of the slow tokenizer's `bpe_ranks` when empty or `None`.
        """
        slow = self.original_tokenizer
        vocab = vocab or slow.encoder
        merges = merges or list(slow.bpe_ranks.keys())

        model = BPE(
            vocab=vocab,
            merges=merges,
            dropout=None,
            unk_token=None,
            continuing_subword_prefix="",
            end_of_word_suffix="",
            fuse_unk=False,
            byte_fallback=False,
        )
        fast = Tokenizer(model)
        fast.normalizer = normalizers.NFC()

        # The regex split runs first; ByteLevel then runs with its own regex disabled.
        regex_split = pre_tokenizers.Split(
            Regex(
                r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
            ),
            behavior="isolated",
            invert=False,
        )
        byte_level = pre_tokenizers.ByteLevel(
            add_prefix_space=getattr(slow, "add_prefix_space", False),
            use_regex=False,
        )
        fast.pre_tokenizer = pre_tokenizers.Sequence([regex_split, byte_level])

        fast.decoder = decoders.ByteLevel()
        fast.post_processor = processors.ByteLevel(trim_offsets=False)
        return fast
class RobertaConverter(Converter):
    """Converter building a byte-level BPE tokenizer with RoBERTa post-processing."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer

        model = BPE(
            vocab=slow.encoder,
            merges=list(slow.bpe_ranks.keys()),
            dropout=None,
            continuing_subword_prefix="",
            end_of_word_suffix="",
            fuse_unk=False,
        )
        fast = Tokenizer(model)

        fast.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=slow.add_prefix_space)
        fast.decoder = decoders.ByteLevel()
        fast.post_processor = processors.RobertaProcessing(
            sep=(slow.sep_token, slow.sep_token_id),
            cls=(slow.cls_token, slow.cls_token_id),
            add_prefix_space=slow.add_prefix_space,
            trim_offsets=True,  # True by default on Roberta (historical)
        )
        return fast
class RoFormerConverter(Converter):
    """Converter building a WordPiece tokenizer that pre-tokenizes with a custom `JiebaPreTokenizer`."""

    def converted(self) -> Tokenizer:
        from .models.roformer.tokenization_utils import JiebaPreTokenizer

        slow = self.original_tokenizer
        vocab = slow.vocab
        fast = Tokenizer(WordPiece(vocab, unk_token=str(slow.unk_token)))

        # Normalization flags stay False unless the slow tokenizer has a `basic_tokenizer`.
        strip_accents = lowercase = False
        if hasattr(slow, "basic_tokenizer"):
            basic = slow.basic_tokenizer
            strip_accents = basic.strip_accents
            lowercase = basic.do_lower_case

        fast.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=False,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        fast.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JiebaPreTokenizer(vocab))

        cls, sep = str(slow.cls_token), str(slow.sep_token)
        fast.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=[
                (cls, slow.cls_token_id),
                (sep, slow.sep_token_id),
            ],
        )
        fast.decoder = decoders.WordPiece(prefix="##")
        return fast
class DebertaConverter(Converter):
    """Converter building a byte-level BPE tokenizer with a literal `[CLS]`/`[SEP]` template."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer

        model = BPE(
            vocab=slow.encoder,
            merges=list(slow.bpe_ranks.keys()),
            dropout=None,
            continuing_subword_prefix="",
            end_of_word_suffix="",
            fuse_unk=False,
        )
        fast = Tokenizer(model)

        fast.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=slow.add_prefix_space)
        fast.decoder = decoders.ByteLevel()

        # The template uses the literal token strings; their ids come from the slow tokenizer.
        special_tokens = [
            ("[CLS]", slow.convert_tokens_to_ids("[CLS]")),
            ("[SEP]", slow.convert_tokens_to_ids("[SEP]")),
        ]
        fast.post_processor = processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=special_tokens,
        )
        return fast
class SpmConverter(Converter):
    """
    Base converter for slow tokenizers backed by a SentencePiece model file.

    Subclasses customize the result by overriding the hooks `vocab`, `unk_id`, `tokenizer`,
    `normalizer`, `pre_tokenizer`, `post_processor` and `decoder`; `converted` assembles them.
    """

    # Whether the converted model should enable byte fallback.
    handle_byte_fallback = False
    # Extractor class used to derive BPE merges from the SentencePiece file.
    SpmExtractor = SentencePieceExtractor
    # User-defined pieces listed here are additionally flagged as special tokens.
    special_tokens = {}

    @staticmethod
    def build_tokenizer_from_spm_proto(proto, vocab, merges=None):
        """
        Similar to convert_from_spm method, but used only when there is no `model_type` class, i.e. there is no matching class in `TOKENIZERS_MAPPING` and we just create a tokenizer instead of extracting stuff from the sentencepiece file

        Args:
            proto: Parsed SentencePiece model proto.
            vocab: A `{token: id}` dict (builds a BPE model) or a non-empty list of
                `(token, score)` pairs (builds a Unigram model).
            merges: BPE merges; an empty list is used when omitted.

        Returns:
            The assembled `Tokenizer`, or `None` when `vocab` has neither supported shape.
        """
        byte_fallback = proto.trainer_spec.byte_fallback
        unk_piece = proto.trainer_spec.unk_piece
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap

        # model
        if isinstance(vocab, dict):
            tokenizer = Tokenizer(
                BPE(
                    vocab=vocab,
                    merges=merges or [],
                    unk_token=unk_piece,
                    fuse_unk=True,
                    byte_fallback=byte_fallback,
                    dropout=None,
                )
            )
        elif isinstance(vocab, list) and vocab and isinstance(vocab[0], (tuple, list)):
            tokenizer = Tokenizer(
                Unigram(
                    vocab=vocab,
                    unk_id=proto.trainer_spec.unk_id,
                    byte_fallback=byte_fallback,
                )
            )
        else:
            return None

        # normalizer
        _normalizers = [normalizers.Replace(" ", "▁")]
        if precompiled_charsmap:
            _normalizers.insert(0, normalizers.Precompiled(precompiled_charsmap))
        tokenizer.normalizer = normalizers.Sequence(_normalizers)

        # decoder
        if byte_fallback:
            tokenizer.decoder = decoders.Sequence(
                [decoders.Replace("▁", " "), decoders.ByteFallback(), decoders.Fuse()]
            )
        else:
            tokenizer.decoder = decoders.Sequence([decoders.Replace("▁", " ")])

        return tokenizer

    @classmethod
    def convert_from_spm(cls, vocab=None, **kwargs):
        """
        Hook used when converting directly from a SentencePiece model without a slow tokenizer instance.
        By default, return kwargs unchanged.
        """
        if vocab is not None:
            kwargs["vocab"] = vocab
        return kwargs

    def __init__(self, *args):
        """
        Parse the SentencePiece file at `original_tokenizer.vocab_file` into `self.proto`.

        A warning is emitted when the model uses byte fallback but this converter does not
        handle it.
        """
        requires_backends(self, "protobuf")

        super().__init__(*args)

        model_pb2 = import_protobuf()

        m = model_pb2.ModelProto()
        with open(self.original_tokenizer.vocab_file, "rb") as f:
            m.ParseFromString(f.read())
        self.proto = m

        if self.proto.trainer_spec.byte_fallback and not self.handle_byte_fallback:
            warnings.warn(
                "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
                " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
                " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
                "unknown tokens into a sequence of byte tokens matching the original piece of text."
            )

    def vocab(self, proto):
        """Return the `(piece, score)` pairs of `proto`, in piece order."""
        return [(piece.piece, piece.score) for piece in proto.pieces]

    def unk_id(self, proto):
        """Return the unknown-token id recorded in the trainer spec."""
        return proto.trainer_spec.unk_id

    def _bpe_merges(self, bpe_vocab, vocab_scores):
        """Return the BPE merges for `bpe_vocab`, ordered using `vocab_scores`."""
        # The stock extractor's `extract(model_type, **kwargs)` cannot be called with piece
        # scores and does not return a `(vocab, merges)` pair, so compute the merges directly.
        # Extractors overriding `extract` keep being called the legacy way.
        if self.SpmExtractor.extract is SentencePieceExtractor.extract:
            return generate_merges(bpe_vocab, vocab_scores)
        _, merges = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
        return merges

    def tokenizer(self, proto):
        """
        Build the `Tokenizer` model from `proto` and register its control / user-defined pieces.

        `trainer_spec.model_type` 1 yields a Unigram model and 2 a BPE model.

        Raises:
            Exception: If the model type is neither 1 nor 2.
        """
        model_type = proto.trainer_spec.model_type
        vocab_scores = self.vocab(proto)

        if model_type == 1:
            tokenizer = Tokenizer(
                Unigram(
                    vocab_scores,
                    unk_id=self.unk_id(proto),
                    byte_fallback=self.handle_byte_fallback,
                )
            )
        elif model_type == 2:
            bpe_vocab = {word: index for index, (word, _score) in enumerate(vocab_scores)}
            merges = self._bpe_merges(bpe_vocab, vocab_scores)
            tokenizer = Tokenizer(
                BPE(
                    bpe_vocab,
                    merges,
                    unk_token=proto.trainer_spec.unk_piece,
                    fuse_unk=True,
                    byte_fallback=self.handle_byte_fallback,
                    dropout=None,
                )
            )
        else:
            raise Exception(
                "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
            )

        # control tokens are special
        # user defined symbols are not
        # both user and control tokens are AddedTokens
        # Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
        spm_added_tokens = [
            (index, piece.piece, piece.type == 3 or piece.piece in self.special_tokens)
            for index, piece in enumerate(proto.pieces)
            if piece.type in [3, 4]
        ]
        tokenizer.add_tokens(
            [
                AddedToken(token, normalized=False, special=special)
                for _index, token, special in sorted(spm_added_tokens, key=lambda entry: entry[0])
            ]
        )

        return tokenizer

    def normalizer(self, proto):
        """
        Return the normalizer sequence: right-strip plus a replacement of runs of two or more
        spaces, preceded by the model's precompiled charsmap when it has one.
        """
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        _normalizers = [
            normalizers.Strip(left=False, right=True),  # stripping is important
            normalizers.Replace(Regex(" {2,}"), "▁"),
        ]
        if not precompiled_charsmap:
            return normalizers.Sequence(_normalizers)
        else:
            return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)

    def pre_tokenizer(self, replacement, add_prefix_space):
        """Return a Metaspace pre-tokenizer configured from the slow tokenizer's prefix settings."""
        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)

    def post_processor(self):
        """Return the post-processor to install; `None` (the default) installs none."""
        return None

    def decoder(self, replacement, add_prefix_space):
        """Return a Metaspace decoder mirroring the pre-tokenizer configuration."""
        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)

    def converted(self) -> Tokenizer:
        """Assemble the fast tokenizer from the model, normalizer, pre-tokenizer, decoder and post-processor hooks."""
        tokenizer = self.tokenizer(self.proto)

        # Tokenizer assemble
        normalizer = self.normalizer(self.proto)
        if normalizer is not None:
            tokenizer.normalizer = normalizer

        replacement = "▁"
        # A prefix space is added unless the slow tokenizer says otherwise.
        add_prefix_space = True
        if hasattr(self.original_tokenizer, "add_prefix_space"):
            add_prefix_space = self.original_tokenizer.add_prefix_space

        pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
        if pre_tokenizer is not None:
            tokenizer.pre_tokenizer = pre_tokenizer

        tokenizer.decoder = self.decoder(replacement, add_prefix_space)
        post_processor = self.post_processor()
        if post_processor:
            tokenizer.post_processor = post_processor

        return tokenizer
class AlbertConverter(SpmConverter):
    """SentencePiece converter with quote/accent/case normalization and a [CLS]/[SEP] template."""

    def vocab(self, proto):
        """Return ``(piece, score)`` pairs for every piece of ``proto``.

        Pieces for which ``check_number_comma`` returns a falsy value get their
        score lowered by 100.
        """
        entries = []
        for piece in proto.pieces:
            score = piece.score if check_number_comma(piece.piece) else piece.score - 100
            entries.append((piece.piece, score))
        return entries

    def normalizer(self, proto):
        """Build the normalizer sequence from the original tokenizer's flags."""
        original = self.original_tokenizer
        steps = [
            normalizers.Replace("``", '"'),
            normalizers.Replace("''", '"'),
        ]
        if not original.keep_accents:
            steps += [normalizers.NFKD(), normalizers.StripAccents()]
        if original.do_lower_case:
            steps.append(normalizers.Lowercase())

        charsmap = proto.normalizer_spec.precompiled_charsmap
        if charsmap:
            steps.append(normalizers.Precompiled(charsmap))

        # Collapse runs of spaces last, after every other step has run.
        steps.append(normalizers.Replace(Regex(" {2,}"), " "))
        return normalizers.Sequence(steps)

    def post_processor(self):
        """Return the ``[CLS] A [SEP]`` / ``[CLS] A [SEP] B [SEP]`` template."""
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[(token, to_id(token)) for token in ("[CLS]", "[SEP]")],
        )
class BarthezConverter(SpmConverter):
    """SentencePiece converter using a fixed unknown id and a ``<s> ... </s>`` template."""

    def unk_id(self, proto):
        """Return the unknown-token id, which is fixed at 3."""
        return 3

    def post_processor(self):
        """Return the ``<s> A </s>`` / ``<s> A </s> </s> B </s>`` template."""
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> </s> $B </s>",
            special_tokens=[(token, to_id(token)) for token in ("<s>", "</s>")],
        )
class CamembertConverter(SpmConverter):
    """SentencePiece converter that prepends five fixed entries and appends ``<mask>``."""

    def vocab(self, proto):
        """Return the vocabulary: fixed head, ``proto.pieces[1:]``, then ``<mask>``."""
        head = [
            ("<s>NOTUSED", 0.0),
            ("<pad>", 0.0),
            ("</s>NOTUSED", 0.0),
            ("<unk>", 0.0),
            # The original SentencePiece unknown is down-graded by -100 so the
            # added "<unk>" token above is used instead.
            ("<unk>NOTUSED", -100),
        ]
        pieces = [(piece.piece, piece.score) for piece in proto.pieces[1:]]
        return head + pieces + [("<mask>", 0.0)]

    def unk_id(self, proto):
        """Return 3, the position of ``<unk>`` in the list built by ``vocab``."""
        return 3

    def post_processor(self):
        """Return the ``<s> A </s>`` / ``<s> A </s> </s> B </s>`` template."""
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> </s> $B </s>",
            special_tokens=[(token, to_id(token)) for token in ("<s>", "</s>")],
        )

    @classmethod
    def convert_from_spm(cls, vocab=None, **kwargs):
        """Store the reordered vocabulary under ``kwargs["vocab"]`` and return ``kwargs``.

        The pad, unk and mask tokens are read from ``kwargs`` (stringified) with
        defaults ``<pad>``, ``<unk>`` and ``<mask>``. When ``vocab`` is given,
        its first entry is dropped and the rest is inserted before the mask.
        """
        pad_token = str(kwargs.get("pad_token", "<pad>"))
        unk_token = str(kwargs.get("unk_token", "<unk>"))
        mask_token = str(kwargs.get("mask_token", "<mask>"))

        head = [
            ("<s>NOTUSED", 0.0),
            (pad_token, 0.0),
            ("</s>NOTUSED", 0.0),
            (unk_token, 0.0),
            ("<unk>NOTUSED", -100.0),
        ]
        body = list(vocab)[1:] if vocab is not None else []
        kwargs["vocab"] = head + body + [(mask_token, 0.0)]
        return kwargs
class DebertaV2Converter(SpmConverter):
    """SentencePiece converter with optional punctuation splitting and a [CLS]/[SEP] template."""

    def pre_tokenizer(self, replacement, add_prefix_space):
        """Return a pre-tokenizer sequence ending with ``Metaspace``.

        An isolating ``Punctuation`` step is placed first when the original
        tokenizer has ``split_by_punct`` set.
        """
        steps = []
        if self.original_tokenizer.split_by_punct:
            steps.append(pre_tokenizers.Punctuation(behavior="isolated"))
        scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        steps.append(pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=scheme))
        return pre_tokenizers.Sequence(steps)

    def normalizer(self, proto):
        """Return the normalizer sequence: optional lowercase, strip, optional charsmap, space collapse."""
        steps = []
        if self.original_tokenizer.do_lower_case:
            steps.append(normalizers.Lowercase())
        steps.append(normalizers.Strip())

        charsmap = proto.normalizer_spec.precompiled_charsmap
        if charsmap:
            steps.append(normalizers.Precompiled(charsmap))
        steps.append(normalizers.Replace(Regex(" {2,}"), " "))
        return normalizers.Sequence(steps)

    def post_processor(self):
        """Return the ``[CLS] A [SEP]`` / ``[CLS] A [SEP] B [SEP]`` template."""
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[(token, to_id(token)) for token in ("[CLS]", "[SEP]")],
        )
class MBartConverter(SpmConverter):
    """SentencePiece converter that appends 25 language codes and ``<mask>`` to the vocabulary."""

    def vocab(self, proto):
        """Return the vocabulary: four specials, ``proto.pieces[3:]``, language codes, ``<mask>``."""
        language_codes = (
            "ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX",
            "et_EE", "fi_FI", "fr_XX", "gu_IN", "hi_IN",
            "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT",
            "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO",
            "ru_RU", "si_LK", "tr_TR", "vi_VN", "zh_CN",
        )  # fmt: skip
        entries = [(token, 0.0) for token in ("<s>", "<pad>", "</s>", "<unk>")]
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[3:])
        entries.extend((code, 0.0) for code in language_codes)
        entries.append(("<mask>", 0.0))
        return entries

    def unk_id(self, proto):
        """Return 3, the position of ``<unk>`` in the list built by ``vocab``."""
        return 3

    def post_processor(self):
        """Return the template that appends ``</s> en_XX`` after the sequence(s)."""
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="$A </s> en_XX",
            pair="$A $B </s> en_XX",
            special_tokens=[(token, to_id(token)) for token in ("en_XX", "</s>")],
        )

    @classmethod
    def convert_from_spm(cls, vocab=None, **kwargs):
        """Store the reordered vocabulary under ``kwargs["vocab"]`` and return ``kwargs``.

        The result is: bos, pad, eos and unk tokens (read from ``kwargs`` with
        defaults), ``vocab`` minus its first three entries when given, one entry
        per code in ``MBART_LANGUAGES``, and finally the mask token.
        """
        special_defaults = (
            ("bos_token", "<s>"),
            ("pad_token", "<pad>"),
            ("eos_token", "</s>"),
            ("unk_token", "<unk>"),
        )
        entries = [(str(kwargs.get(key, fallback)), 0.0) for key, fallback in special_defaults]
        mask_token = str(kwargs.get("mask_token", "<mask>"))

        if vocab is not None:
            entries += list(vocab)[3:]
        entries += [(lang_code, 0.0) for lang_code in MBART_LANGUAGES]
        entries.append((mask_token, 0.0))

        kwargs["vocab"] = entries
        return kwargs
class MBart50Converter(SpmConverter):
    """SentencePiece converter that appends 52 language codes and ``<mask>`` to the vocabulary."""

    def vocab(self, proto):
        """Return the vocabulary: four specials, ``proto.pieces[3:]``, language codes, ``<mask>``."""
        language_codes = (
            "ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX",
            "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV",
            "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN",
            "zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID",
            "ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF",
            "pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA",
            "ur_PK", "xh_ZA", "gl_ES", "sl_SI",
        )  # fmt: skip
        entries = [(token, 0.0) for token in ("<s>", "<pad>", "</s>", "<unk>")]
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[3:])
        entries.extend((code, 0.0) for code in language_codes)
        entries.append(("<mask>", 0.0))
        return entries

    def unk_id(self, proto):
        """Return 3, the position of ``<unk>`` in the list built by ``vocab``."""
        return 3

    def post_processor(self):
        """Return the template that wraps the sequence(s) as ``en_XX ... </s>``."""
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="en_XX $A </s>",
            pair="en_XX $A $B </s>",
            special_tokens=[(token, to_id(token)) for token in ("en_XX", "</s>")],
        )

    @classmethod
    def convert_from_spm(cls, vocab=None, **kwargs):
        """Store the reordered vocabulary under ``kwargs["vocab"]`` and return ``kwargs``.

        The result is: cls, pad, eos and unk tokens (read from ``kwargs`` with
        defaults), ``vocab`` minus its first three entries when given, one entry
        per code in ``MBART50_LANGUAGES``, and finally the mask token.
        """
        special_defaults = (
            ("cls_token", "<s>"),
            ("pad_token", "<pad>"),
            ("eos_token", "</s>"),
            ("unk_token", "<unk>"),
        )
        entries = [(str(kwargs.get(key, fallback)), 0.0) for key, fallback in special_defaults]
        mask_token = str(kwargs.get("mask_token", "<mask>"))

        if vocab is not None:
            entries += list(vocab)[3:]
        entries += [(lang_code, 0.0) for lang_code in MBART50_LANGUAGES]
        entries.append((mask_token, 0.0))

        kwargs["vocab"] = entries
        return kwargs
class NllbConverter(SpmConverter):
    """SentencePiece converter with four leading specials and an ``eng_Latn ... </s>`` template."""

    def vocab(self, proto):
        """Return the vocabulary: four specials followed by ``proto.pieces[3:]``."""
        entries = [(token, 0.0) for token in ("<s>", "<pad>", "</s>", "<unk>")]
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[3:])
        return entries

    def unk_id(self, proto):
        """Return 3, the position of ``<unk>`` in the list built by ``vocab``."""
        return 3

    def post_processor(self):
        """Return the template that wraps the sequence(s) as ``eng_Latn ... </s>``."""
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="eng_Latn $A </s>",
            pair="eng_Latn $A $B </s>",
            special_tokens=[(token, to_id(token)) for token in ("eng_Latn", "</s>")],
        )

    @classmethod
    def convert_from_spm(cls, vocab=None, **kwargs):
        """Store a token-to-id mapping under ``kwargs["vocab"]`` and return ``kwargs``.

        The bos, pad, eos and unk tokens (read from ``kwargs`` with defaults)
        receive ids 0 to 3. Every other token of ``vocab`` (its keys when it is
        a dict, otherwise the first item of each pair) is then numbered in
        iteration order; tokens already present keep their id.
        """
        special_defaults = (
            ("bos_token", "<s>"),
            ("pad_token", "<pad>"),
            ("eos_token", "</s>"),
            ("unk_token", "<unk>"),
        )
        token_to_id = {}
        for index, (key, fallback) in enumerate(special_defaults):
            token_to_id[str(kwargs.get(key, fallback))] = index

        if vocab is not None:
            if isinstance(vocab, dict):
                tokens = vocab.keys()
            else:
                tokens = [tok for tok, _ in vocab]
            for token in tokens:
                if token not in token_to_id:
                    token_to_id[token] = len(token_to_id)

        kwargs["vocab"] = token_to_id
        return kwargs
class SeamlessM4TConverter(SpmConverter):
    """SentencePiece converter with a pad-first special ordering and an ``__eng__ ... </s>`` template."""

    def vocab(self, proto):
        """Return the vocabulary: ``<pad>``, ``<unk>``, ``<s>``, ``</s>``, then ``proto.pieces[3:]``."""
        entries = [(token, 0.0) for token in ("<pad>", "<unk>", "<s>", "</s>")]
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[3:])
        return entries

    def unk_id(self, proto):
        """Return the unknown-token id reported by the original tokenizer."""
        return self.original_tokenizer.unk_token_id

    def post_processor(self):
        """Return the template that wraps the sequence(s) as ``__eng__ ... </s>``."""
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="__eng__ $A </s>",
            pair="__eng__ $A $B </s>",
            special_tokens=[(token, to_id(token)) for token in ("__eng__", "</s>")],
        )
class XLMRobertaConverter(SpmConverter):
    """SentencePiece converter with four leading specials, a trailing ``<mask>`` and a ``<s> ... </s>`` template."""

    def vocab(self, proto):
        """Return the vocabulary: four specials, ``proto.pieces[3:]``, then ``<mask>``."""
        entries = [(token, 0.0) for token in ("<s>", "<pad>", "</s>", "<unk>")]
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[3:])
        entries.append(("<mask>", 0.0))
        return entries

    def unk_id(self, proto):
        """Return 3, the position of ``<unk>`` in the list built by ``vocab``."""
        return 3

    def post_processor(self):
        """Return the ``<s> A </s>`` / ``<s> A </s> </s> B </s>`` template."""
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> </s> $B </s>",
            special_tokens=[(token, to_id(token)) for token in ("<s>", "</s>")],
        )

    @classmethod
    def convert_from_spm(cls, vocab=None, **kwargs):
        """Store the reordered vocabulary under ``kwargs["vocab"]`` and return ``kwargs``.

        The result is: bos, pad, eos and unk tokens (read from ``kwargs`` with
        defaults), ``vocab`` minus its first three entries when given, and
        finally the mask token.
        """
        special_defaults = (
            ("bos_token", "<s>"),
            ("pad_token", "<pad>"),
            ("eos_token", "</s>"),
            ("unk_token", "<unk>"),
        )
        entries = [(str(kwargs.get(key, fallback)), 0.0) for key, fallback in special_defaults]
        mask_token = str(kwargs.get("mask_token", "<mask>"))

        if vocab is not None:
            entries += list(vocab)[3:]
        entries.append((mask_token, 0.0))

        kwargs["vocab"] = entries
        return kwargs
class XLNetConverter(SpmConverter):
    """SentencePiece converter with quote/accent/case normalization and a trailing ``<sep> <cls>`` template."""

    def vocab(self, proto):
        """Return ``(piece, score)`` pairs for every piece of ``proto``.

        Pieces for which ``check_number_comma`` returns a falsy value get their
        score lowered by 100.
        """
        entries = []
        for piece in proto.pieces:
            score = piece.score if check_number_comma(piece.piece) else piece.score - 100
            entries.append((piece.piece, score))
        return entries

    def normalizer(self, proto):
        """Build the normalizer sequence from the original tokenizer's flags."""
        original = self.original_tokenizer
        steps = [
            normalizers.Replace("``", '"'),
            normalizers.Replace("''", '"'),
        ]
        if not original.keep_accents:
            steps += [normalizers.NFKD(), normalizers.StripAccents()]
        if original.do_lower_case:
            steps.append(normalizers.Lowercase())

        charsmap = proto.normalizer_spec.precompiled_charsmap
        if charsmap:
            steps.append(normalizers.Precompiled(charsmap))

        # Collapse runs of spaces last, after every other step has run.
        steps.append(normalizers.Replace(Regex(" {2,}"), " "))
        return normalizers.Sequence(steps)

    def post_processor(self):
        """Return the template that appends ``<sep>`` after each sequence and ``<cls>`` (type id 2) last."""
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="$A:0 <sep>:0 <cls>:2",
            pair="$A:0 <sep>:0 $B:1 <sep>:1 <cls>:2",
            special_tokens=[(token, to_id(token)) for token in ("<sep>", "<cls>")],
        )
class ReformerConverter(SpmConverter):
    """Converter that uses the ``SpmConverter`` behaviour unchanged."""
class RemBertConverter(SpmConverter):
    """SentencePiece converter inspired by ``AlbertConverter``; collapses spaces before the other steps."""

    def normalizer(self, proto):
        """Build the normalizer sequence from the original tokenizer's flags.

        Unlike the Albert-style sequence, the space-collapsing replacement runs
        right after the quote replacements rather than last.
        """
        original = self.original_tokenizer
        steps = [
            normalizers.Replace("``", '"'),
            normalizers.Replace("''", '"'),
            normalizers.Replace(Regex(" {2,}"), " "),
        ]
        if not original.keep_accents:
            steps += [normalizers.NFKD(), normalizers.StripAccents()]
        if original.do_lower_case:
            steps.append(normalizers.Lowercase())

        charsmap = proto.normalizer_spec.precompiled_charsmap
        if charsmap:
            steps.append(normalizers.Precompiled(charsmap))

        return normalizers.Sequence(steps)

    def post_processor(self):
        """Return the ``[CLS] A [SEP]`` / ``[CLS] A [SEP] B [SEP]`` template."""
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[(token, to_id(token)) for token in ("[CLS]", "[SEP]")],
        )
class BertGenerationConverter(SpmConverter):
    """Converter that uses the ``SpmConverter`` behaviour unchanged."""
class PegasusConverter(SpmConverter):
    """SentencePiece converter that reserves leading ids for pad/eos/mask tokens and ``<unk_i>`` placeholders."""

    def vocab(self, proto):
        """Return the vocabulary built from the original tokenizer and ``proto``.

        Order: pad, eos, the sentence mask (when set), the word mask (when set
        and its id is below ``offset``), ``<unk_2>`` up to ``<unk_{offset-1}>``
        with score -100, then ``proto.pieces[2:]``.
        """
        original = self.original_tokenizer
        entries = [
            (original.pad_token, 0.0),
            (original.eos_token, 0.0),
        ]

        if original.mask_token_sent is not None:
            entries.append((original.mask_token_sent, 0.0))

        if original.mask_token is not None and original.mask_token_id < original.offset:
            entries.append((original.mask_token, 0.0))

        entries.extend((f"<unk_{i}>", -100.0) for i in range(2, original.offset))
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[2:])
        return entries

    @classmethod
    def convert_from_spm(cls, vocab=None, **kwargs):
        """Store the reordered vocabulary under ``kwargs["vocab"]`` and return ``kwargs``.

        Tokens are read from ``kwargs`` and stringified, so a value of ``None``
        shows up as the string ``"None"`` and is skipped for the two mask
        tokens. ``offset`` defaults to 103.
        """
        pad_token = str(kwargs.get("pad_token", "<pad>"))
        eos_token = str(kwargs.get("eos_token", "</s>"))
        mask_token = str(kwargs.get("mask_token", "<mask_1>"))
        mask_token_sent = str(kwargs.get("mask_token_sent", "<mask_2>"))

        entries = [(pad_token, 0.0), (eos_token, 0.0)]
        if mask_token != "None":
            entries.append((mask_token, 0.0))
        # The sentence mask is skipped when it duplicates the word mask.
        if mask_token_sent not in ("None", mask_token):
            entries.append((mask_token_sent, 0.0))

        entries += [(f"<unk_{i}>", -100.0) for i in range(2, kwargs.get("offset", 103))]
        if vocab is not None:
            entries += list(vocab)[2:]

        kwargs["vocab"] = entries
        return kwargs

    def unk_id(self, proto):
        """Return the SentencePiece unknown id shifted by the original tokenizer's ``offset``."""
        return proto.trainer_spec.unk_id + self.original_tokenizer.offset

    def pre_tokenizer(self, replacement, add_prefix_space):
        """Return a ``WhitespaceSplit`` followed by ``Metaspace`` pre-tokenizer sequence."""
        scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        metaspace = pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=scheme)
        return pre_tokenizers.Sequence([pre_tokenizers.WhitespaceSplit(), metaspace])

    def post_processor(self):
        """Return the template that appends the eos token after the sequence(s)."""
        eos = self.original_tokenizer.eos_token
        eos_id = self.original_tokenizer.eos_token_id
        return processors.TemplateProcessing(
            single=["$A", eos],
            pair=["$A", "$B", eos],
            special_tokens=[(eos, eos_id)],
        )
class T5Converter(SpmConverter):
    """SentencePiece converter that appends ``<extra_id_i>`` sentinels in descending order."""

    def vocab(self, proto):
        """Return every piece of ``proto`` followed by the ``<extra_id_i>`` tokens.

        The sentinels run from ``_extra_ids - 1`` down to 0, each with score 0.0.
        """
        extra_count = self.original_tokenizer._extra_ids
        entries = [(piece.piece, piece.score) for piece in proto.pieces]
        entries.extend((f"<extra_id_{i}>", 0.0) for i in reversed(range(extra_count)))
        return entries

    def post_processor(self):
        """Return the template that appends ``</s>`` after each sequence."""
        eos_id = self.original_tokenizer.convert_tokens_to_ids("</s>")
        return processors.TemplateProcessing(
            single=["$A", "</s>"],
            pair=["$A", "</s>", "$B", "</s>"],
            special_tokens=[("</s>", eos_id)],
        )

    @classmethod
    def convert_from_spm(cls, vocab=None, **kwargs):
        """Store the extended vocabulary under ``kwargs["vocab"]`` and return ``kwargs``.

        ``extra_ids`` (default 100) sentinels are appended in descending order.
        ``additional_special_tokens`` is set to the sentinel list only when the
        caller did not provide it.
        """
        sentinel_count = kwargs.get("extra_ids", 100)
        sentinels = [f"<extra_id_{i}>" for i in reversed(range(sentinel_count))]

        entries = [] if vocab is None else list(vocab)
        entries += [(token, 0.0) for token in sentinels]

        kwargs.setdefault("additional_special_tokens", sentinels)
        kwargs["vocab"] = entries
        return kwargs
class UdopConverter(SpmConverter):
    """SentencePiece converter whose only override is a template appending ``</s>``."""

    def post_processor(self):
        """Return the template that appends ``</s>`` after each sequence."""
        eos_id = self.original_tokenizer.convert_tokens_to_ids("</s>")
        return processors.TemplateProcessing(
            single=["$A", "</s>"],
            pair=["$A", "</s>", "$B", "</s>"],
            special_tokens=[("</s>", eos_id)],
        )
class WhisperConverter(Converter):
    """Byte-level BPE converter whose template prepends the original tokenizer's prefix tokens."""

    def converted(self) -> Tokenizer:
        """Build the fast tokenizer from the original encoder, merges and prefix tokens.

        Returns:
            A ``Tokenizer`` with a BPE model, ByteLevel pre-tokenizer/decoder and
            a template of the form ``<prefixes> A [B] <eos>``.
        """
        original = self.original_tokenizer

        bpe_model = BPE(
            vocab=original.encoder,
            merges=list(original.bpe_ranks.keys()),
            dropout=None,
            continuing_subword_prefix="",
            end_of_word_suffix="",
            fuse_unk=False,
        )
        tokenizer = Tokenizer(bpe_model)
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=original.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()

        prefix_ids = original.prefix_tokens
        prefix_tokens = original.convert_ids_to_tokens(prefix_ids)
        eos = original.eos_token
        eos_id = original.eos_token_id

        # Every prefix token is emitted with type id 0 ahead of the first sequence.
        prefix_template = " ".join(f"{token}:0" for token in prefix_tokens)
        special_tokens = [(eos, eos_id)]
        special_tokens.extend(zip(prefix_tokens, prefix_ids))

        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{prefix_template} $A:0 {eos}:0",
            pair=f"{prefix_template} $A:0 $B:1 {eos}:1",
            special_tokens=special_tokens,
        )
        return tokenizer
class BigBirdConverter(SpmConverter):
    """SentencePiece converter whose only override is a [CLS]/[SEP] template."""

    def post_processor(self):
        """Return the ``[CLS] A [SEP]`` / ``[CLS] A [SEP] B [SEP]`` template."""
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[(token, to_id(token)) for token in ("[CLS]", "[SEP]")],
        )
class CLIPConverter(Converter):
    """BPE converter using a ``</w>`` end-of-word suffix, lowercasing and Roberta-style post-processing."""

    def converted(self) -> Tokenizer:
        """Build the fast tokenizer from the original encoder and merges.

        Returns:
            A ``Tokenizer`` with an NFC/whitespace/lowercase normalizer, a
            regex split followed by ByteLevel pre-tokenization, a ByteLevel
            decoder and a ``RobertaProcessing`` post-processor.
        """
        original = self.original_tokenizer

        bpe_model = BPE(
            vocab=original.encoder,
            merges=list(original.bpe_ranks.keys()),
            dropout=None,
            continuing_subword_prefix="",
            end_of_word_suffix="</w>",
            fuse_unk=False,
            unk_token=str(original.unk_token),
        )
        tokenizer = Tokenizer(bpe_model)

        tokenizer.normalizer = normalizers.Sequence(
            [
                normalizers.NFC(),
                normalizers.Replace(Regex(r"\s+"), " "),
                normalizers.Lowercase(),
            ]
        )

        # With invert=True and behavior="removed", the text matched by the
        # pattern is kept as pieces and the rest is dropped.
        word_split = pre_tokenizers.Split(
            Regex(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+"""),
            behavior="removed",
            invert=True,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [word_split, pre_tokenizers.ByteLevel(add_prefix_space=False)]
        )
        tokenizer.decoder = decoders.ByteLevel()

        # Hack to have a ByteLevel and TemplateProcessor
        tokenizer.post_processor = processors.RobertaProcessing(
            sep=(original.eos_token, original.eos_token_id),
            cls=(original.bos_token, original.bos_token_id),
            add_prefix_space=False,
            trim_offsets=False,
        )
        return tokenizer
class LayoutLMv2Converter(Converter):
    """WordPiece converter with a BERT normalizer/pre-tokenizer and a cls/sep template."""

    def converted(self) -> Tokenizer:
        """Build the fast tokenizer from the original WordPiece vocabulary.

        Normalizer options come from ``original_tokenizer.basic_tokenizer`` when
        that attribute exists; otherwise Chinese-character handling and accent
        stripping are off and lowercasing is on.
        """
        original = self.original_tokenizer
        tokenizer = Tokenizer(WordPiece(original.vocab, unk_token=str(original.unk_token)))

        if hasattr(original, "basic_tokenizer"):
            basic = original.basic_tokenizer
            handle_chinese_chars = basic.tokenize_chinese_chars
            strip_accents = basic.strip_accents
            lowercase = basic.do_lower_case
        else:
            handle_chinese_chars = False
            strip_accents = False
            lowercase = True

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=handle_chinese_chars,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls_token = str(original.cls_token)
        sep_token = str(original.sep_token)
        special_tokens = [
            (cls_token, original.cls_token_id),
            (sep_token, original.sep_token_id),
        ]
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls_token}:0 $A:0 {sep_token}:0",
            pair=f"{cls_token}:0 $A:0 {sep_token}:0 $B:1 {sep_token}:1",
            special_tokens=special_tokens,
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")
        return tokenizer
class BlenderbotConverter(Converter):
    """Byte-level BPE converter whose template appends the eos token to a single sequence."""

    def converted(self) -> Tokenizer:
        """Build the fast tokenizer from the original encoder and merges.

        Only a ``single`` template is configured for the post-processor.
        """
        original = self.original_tokenizer

        bpe_model = BPE(
            vocab=original.encoder,
            merges=list(original.bpe_ranks.keys()),
            dropout=None,
            continuing_subword_prefix="",
            end_of_word_suffix="",
            fuse_unk=False,
        )
        tokenizer = Tokenizer(bpe_model)
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=original.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"$A:0 {original.eos_token}:0",
            special_tokens=[(original.eos_token, original.eos_token_id)],
        )
        return tokenizer
class XGLMConverter(SpmConverter):
    """SentencePiece converter that appends seven ``<madeupwordN>`` tokens and prefixes ``</s>``."""

    def vocab(self, proto):
        """Return the vocabulary: four specials, ``proto.pieces[3:]``, then ``<madeupword0>``..``<madeupword6>``."""
        entries = [(token, 0.0) for token in ("<s>", "<pad>", "</s>", "<unk>")]
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[3:])
        entries.extend((f"<madeupword{i}>", 0.0) for i in range(7))
        return entries

    def unk_id(self, proto):
        """Return 3, the position of ``<unk>`` in the list built by ``vocab``."""
        return 3

    def post_processor(self):
        """Return the template that starts each input with ``</s>``.

        ``<s>`` is registered as a special token even though the templates do
        not reference it.
        """
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="</s> $A",
            pair="</s> $A </s> </s> $B",
            special_tokens=[(token, to_id(token)) for token in ("<s>", "</s>")],
        )
class GemmaConverter(SpmConverter):
    """SentencePiece converter with byte fallback, space-to-``▁`` normalization and a fused decoder."""

    handle_byte_fallback = True
    SpmExtractor = GemmaSentencePieceExtractor
    # start and end of turn tokens must be marked as special
    special_tokens = {"<start_of_turn>", "<end_of_turn>"}

    # SentencePiece trainer settings noted for this tokenizer:
    #   split_by_unicode_script: true
    #   split_by_number: true
    #   split_by_whitespace: true
    #   treat_whitespace_as_suffix: false
    #   allow_whitespace_only_pieces: true
    #   split_digits: true
    #   byte_fallback: true

    def normalizer(self, proto):
        """Return a normalizer that replaces every space with ``▁``."""
        return normalizers.Replace(" ", "▁")

    def vocab(self, proto):
        """Return the vocabulary: pad, eos and bos tokens, then ``proto.pieces[3:]``.

        When no entry is a literal tab, the first ``<0x09>`` entry (if any) is
        replaced by ``("\\t", 0.0)``.
        """
        original = self.original_tokenizer
        entries = [
            (original.pad_token, 0.0),
            (original.eos_token, 0.0),
            (original.bos_token, 0.0),
        ]
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[3:])

        # Older gemma tokenizers had a missing tab token, so we fix that here
        if not any(entry[0] == "\t" for entry in entries):
            for index, entry in enumerate(entries):
                if entry[0] == "<0x09>":
                    entries[index] = ("\t", 0.0)
                    break
        return entries

    def pre_tokenizer(self, replacement, add_prefix_space):
        """Return a pre-tokenizer that splits on spaces, merging each with the previous piece."""
        return pre_tokenizers.Split(" ", "merged_with_previous")

    def unk_id(self, proto):
        """Return the unknown-token id, which is fixed at 3."""
        return 3

    def decoder(self, replacement, add_prefix_space):
        """Return the decoder sequence: ``▁`` to space, byte fallback, then fuse."""
        steps = [
            decoders.Replace("▁", " "),
            decoders.ByteFallback(),
            decoders.Fuse(),
        ]
        return decoders.Sequence(steps)
class LlamaConverter(SpmConverter):
    """SentencePiece converter with byte fallback and separate legacy / non-legacy whitespace handling."""

    handle_byte_fallback = True

    def vocab(self, proto):
        """Return the vocabulary: the original tokens for ids 0-2, then ``proto.pieces[3:]``."""
        id_to_token = self.original_tokenizer.convert_ids_to_tokens
        entries = [(id_to_token(token_id), 0.0) for token_id in (0, 1, 2)]
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[3:])
        return entries

    def unk_id(self, proto):
        """Return the unknown-token id, which is fixed at 0."""
        return 0

    def decoder(self, replacement, add_prefix_space):
        """Return the decoder sequence: ``▁`` to space, byte fallback, fuse.

        With ``add_prefix_space`` a final step strips one leading space.
        """
        steps = [
            decoders.Replace("▁", " "),
            decoders.ByteFallback(),
            decoders.Fuse(),
        ]
        if add_prefix_space:
            steps.append(decoders.Strip(content=" ", left=1))
        return decoders.Sequence(steps)

    def normalizer(self, proto):
        """Return the legacy normalizer, or ``None`` for non-legacy tokenizers.

        ``legacy`` and ``add_prefix_space`` are both treated as ``True`` when
        the original tokenizer does not define them.
        """
        original = self.original_tokenizer
        if not getattr(original, "legacy", True):
            # non-legacy, no normalizer
            return None

        steps = []
        if getattr(original, "add_prefix_space", True):
            steps.append(normalizers.Prepend(prepend="▁"))
        steps.append(normalizers.Replace(pattern=" ", content="▁"))
        return normalizers.Sequence(steps)

    def pre_tokenizer(self, replacement, add_prefix_space):
        """Return a non-splitting ``Metaspace`` for non-legacy tokenizers, else ``None``."""
        if getattr(self.original_tokenizer, "legacy", True):
            return None
        # non-legacy, we need a replace
        scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=scheme, split=False)

    def post_processor(self):
        """Return ``None``; the processor is defined in the LlamaTokenizerFast class."""
        return None
class MarkupLMConverter(Converter):
    """Byte-level BPE converter with an unknown token and a cls/sep template."""

    def converted(self) -> Tokenizer:
        """Build the fast tokenizer from the original encoder and merges.

        Returns:
            A ``Tokenizer`` with a BPE model, ByteLevel pre-tokenizer/decoder and
            a ``cls A sep`` / ``cls A sep B sep`` template.
        """
        original = self.original_tokenizer

        bpe_model = BPE(
            vocab=original.encoder,
            merges=list(original.bpe_ranks.keys()),
            dropout=None,
            continuing_subword_prefix="",
            end_of_word_suffix="",
            fuse_unk=False,
            unk_token=original.unk_token,
        )
        tokenizer = Tokenizer(bpe_model)
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=original.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()

        cls_token = str(original.cls_token)
        sep_token = str(original.sep_token)
        special_tokens = [
            (cls_token, original.cls_token_id),
            (sep_token, original.sep_token_id),
        ]
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls_token} $A {sep_token}",
            pair=f"{cls_token} $A {sep_token} $B {sep_token}",
            special_tokens=special_tokens,
        )
        return tokenizer
class MoshiConverter(SpmConverter):
    """SentencePiece converter that loads the model proto straight from a vocabulary file."""

    handle_byte_fallback = True

    def __init__(self, vocab_file, **kwargs):
        """Parse ``vocab_file`` as a SentencePiece ``ModelProto`` and keep it in ``self.proto``.

        ``SpmConverter.__init__`` is bypassed; ``Converter.__init__`` is called
        with ``vocab_file`` instead. Extra keyword arguments are ignored.
        """
        requires_backends(self, "protobuf")

        Converter.__init__(self, vocab_file)

        model_proto = import_protobuf().ModelProto()
        with open(vocab_file, "rb") as model_file:
            model_proto.ParseFromString(model_file.read())
        self.proto = model_proto

    def normalizer(self, proto):
        """Return a normalizer replacing spaces with ``▁``, preceded by the charsmap step when present."""
        steps = [normalizers.Replace(" ", "▁")]
        charsmap = proto.normalizer_spec.precompiled_charsmap
        if charsmap:
            steps.insert(0, normalizers.Precompiled(charsmap))
        return normalizers.Sequence(steps)

    def decoder(self, replacement, add_prefix_space):
        """Return the decoder sequence: ``▁`` to space, byte fallback, fuse.

        With ``add_prefix_space`` a final step strips one leading space.
        """
        steps = [
            decoders.Replace("▁", " "),
            decoders.ByteFallback(),
            decoders.Fuse(),
        ]
        if add_prefix_space:
            steps.append(decoders.Strip(content=" ", left=1))
        return decoders.Sequence(steps)

    def pre_tokenizer(self, replacement, add_prefix_space):
        """Return a non-splitting ``Metaspace`` that always uses the ``"first"`` prepend scheme."""
        return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme="first", split=False)
class HeliumConverter(SpmConverter):
    """Unigram converter that loads its proto from a file and treats the newline as a regular token."""

    handle_byte_fallback = True

    def __init__(self, vocab_file=None, **kwargs):
        """Parse ``vocab_file`` as a SentencePiece ``ModelProto`` and keep it in ``self.proto``.

        ``SpmConverter.__init__`` is bypassed; ``Converter.__init__`` is called
        with ``vocab_file`` instead. Extra keyword arguments are ignored.
        """
        requires_backends(self, "protobuf")

        Converter.__init__(self, vocab_file)

        model_proto = import_protobuf().ModelProto()
        with open(vocab_file, "rb") as model_file:
            model_proto.ParseFromString(model_file.read())
        self.proto = model_proto

    def tokenizer(self, proto):
        """Build the Unigram tokenizer, register added tokens and enable padding.

        Returns:
            A ``Tokenizer`` wrapping a ``Unigram`` model, with the proto's
            type-3 and type-4 pieces plus a newline token added, and padding
            enabled with ``<pad>`` at id 3.
        """
        unigram = Unigram(
            self.vocab(proto),
            unk_id=self.unk_id(proto),
            byte_fallback=self.handle_byte_fallback,
        )
        tokenizer = Tokenizer(unigram)

        # control tokens are special
        # user defined symbols are not
        # both user and control tokens are AddedTokens
        # Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
        candidates = []
        for index, piece in enumerate(proto.pieces):
            if piece.type in [3, 4]:
                is_special = piece.type == 3 or piece.piece in self.special_tokens
                candidates.append((index, piece.piece, is_special))
        candidates.sort(key=lambda candidate: candidate[0])

        tokenizer.add_tokens(
            [
                AddedToken(token, normalized=False, special=is_special, single_word=True)
                for _, token, is_special in candidates
            ]
        )
        tokenizer.add_tokens([AddedToken("\n", normalized=False, special=False)])
        tokenizer.enable_padding(pad_token="<pad>", pad_id=3)
        return tokenizer

    def vocab(self, proto):
        """Return ``(piece, score)`` pairs, with the ``<0x0A>`` piece renamed to a newline."""
        return [
            ("\n" if piece.piece == "<0x0A>" else piece.piece, piece.score)
            for piece in proto.pieces
        ]

    def unk_id(self, proto):
        """Return the unknown-token id, which is fixed at 0."""
        return 0

    def decoder(self, replacement, add_prefix_space):
        """Return the decoder sequence: ``▁`` to space, byte fallback, fuse, strip one leading space.

        The strip step is always included, regardless of ``add_prefix_space``.
        """
        steps = [
            decoders.Replace("▁", " "),
            decoders.ByteFallback(),
            decoders.Fuse(),
            decoders.Strip(content=" ", left=1),
        ]
        return decoders.Sequence(steps)

    def normalizer(self, proto):
        """Return a normalizer that prepends a space and then replaces spaces with ``▁``."""
        steps = [normalizers.Prepend(" "), normalizers.Replace(r" ", "▁")]
        return normalizers.Sequence(steps)

    def pre_tokenizer(self, replacement, add_prefix_space):
        """Return a pre-tokenizer sequence that splits on newlines with ``"contiguous"`` behaviour."""
        newline_split = pre_tokenizers.Split("\n", "contiguous")
        return pre_tokenizers.Sequence([newline_split])

    def post_processor(self):
        """Return the template that puts ``<s>`` (id 1) in front of every sequence."""
        return processors.TemplateProcessing(
            single=["<s>", "$A"],
            pair=["<s>", "$A", "<s>", "$B"],
            special_tokens=[("<s>", 1)],
        )
class ParakeetConverter(SpmConverter):
    """Converter that loads a SentencePiece proto from a file and builds a BPE tokenizer from it."""

    handle_byte_fallback = True

    def __init__(self, vocab_file=None, *args):
        """Remember ``vocab_file`` and parse it as a SentencePiece ``ModelProto``.

        ``SpmConverter.__init__`` is bypassed; ``Converter.__init__`` is called
        with ``vocab_file`` instead. Extra positional arguments are ignored.
        """
        self.vocab_file = vocab_file

        requires_backends(self, "protobuf")

        Converter.__init__(self, vocab_file)

        model_proto = import_protobuf().ModelProto()
        with open(vocab_file, "rb") as model_file:
            model_proto.ParseFromString(model_file.read())
        self.proto = model_proto

    def tokenizer(self, proto):
        """Build the BPE tokenizer and register the proto's type-3 and type-4 pieces.

        Merges come from ``self.SpmExtractor``; token ids follow the order of
        the list returned by ``self.vocab``.
        """
        scored_pieces = self.vocab(proto)
        _, merges = self.SpmExtractor(self.vocab_file).extract(scored_pieces)
        token_ids = {token: index for index, (token, _) in enumerate(scored_pieces)}

        bpe_model = BPE(
            token_ids,
            merges,
            unk_token=proto.trainer_spec.unk_piece,
            fuse_unk=True,
            byte_fallback=self.handle_byte_fallback,
            dropout=None,
        )
        tokenizer = Tokenizer(bpe_model)

        # Add user defined symbols and control tokens from sentencepiece model
        candidates = []
        for index, piece in enumerate(proto.pieces):
            if piece.type in [3, 4]:
                is_special = piece.type == 3 or piece.piece in self.special_tokens
                candidates.append((index, piece.piece, is_special))
        candidates.sort(key=lambda candidate: candidate[0])

        tokenizer.add_tokens(
            [AddedToken(token, normalized=False, special=is_special) for _, token, is_special in candidates]
        )
        return tokenizer
def bytes_to_unicode():
    """
    Return a dict mapping each of the 256 byte values to a single unicode character.

    Bytes in the ranges ``!``..``~``, ``¡``..``¬`` and ``®``..``ÿ`` map to the character with the same code point.
    Every other byte (whitespace/control characters the bpe code barfs on) is mapped, in ascending order, to the next
    unused code point starting at 256.

    The reversible bpe codes work on unicode strings, so a lookup table between utf-8 bytes and unicode strings
    avoids spending a large share of the vocabulary on raw unicode coverage.
    """
    kept_as_is = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    mapping = {byte: chr(byte) for byte in kept_as_is}

    shifted = 0
    for byte in range(2**8):
        if byte in mapping:
            continue
        # Remaining bytes are relocated above the single-byte range.
        mapping[byte] = chr(2**8 + shifted)
        shifted += 1
    return mapping
class TikTokenConverter:
    """
    A general tiktoken converter.

    Reads a tiktoken BPE rank file, rebuilds a vocabulary and merge list from it, and assembles a byte-level BPE
    `Tokenizer`.
    """

    def __init__(
        self,
        vocab_file=None,
        pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
        add_prefix_space=False,
        extra_special_tokens=None,
        **kwargs,
    ):
        """
        Args:
            vocab_file: Path or URL handed to `load_tiktoken_bpe` when the tokenizer is built.
            pattern: Regular expression used by the `Split` pre-tokenizer.
            add_prefix_space: Forwarded to the `ByteLevel` pre-tokenizer.
            extra_special_tokens: Optional tokens to register as special. When a dict is given, only its keys are
                kept.
            **kwargs: Accepted and ignored.
        """
        self.vocab_file = vocab_file
        self.pattern = pattern
        self.add_prefix_space = add_prefix_space
        self.extra_special_tokens = (
            extra_special_tokens.keys() if isinstance(extra_special_tokens, dict) else extra_special_tokens
        )

    def extract_vocab_merges_from_model(self, tiktoken_url: str):
        """
        Load the tiktoken ranks and derive a `(vocab, merges)` pair from them.

        Args:
            tiktoken_url: Location passed to `tiktoken.load.load_tiktoken_bpe`.

        Returns:
            A tuple `(vocab, merges)`: `vocab` maps each byte-encoded token string to its rank, and `merges` is a
            list of `(left, right)` string pairs ordered by the rank of the token they produce.

        Raises:
            ValueError: If `tiktoken` cannot be imported; the original error is attached as the cause.
        """
        try:
            from tiktoken.load import load_tiktoken_bpe
        except Exception as err:
            raise ValueError(
                "`tiktoken` is required to read a `tiktoken` file. Install it with `pip install tiktoken`."
            ) from err

        bpe_ranks = load_tiktoken_bpe(tiktoken_url)
        byte_encoder = bytes_to_unicode()

        def token_bytes_to_string(b):
            # Iterating over bytes yields ints, which are the keys of `byte_encoder`.
            return "".join(byte_encoder[byte] for byte in b)

        vocab = {}
        merges = []
        for token, rank in bpe_ranks.items():
            vocab[token_bytes_to_string(token)] = rank
            if len(token) == 1:
                continue
            # Every way of cutting the token into two known pieces is a candidate merge; the concatenation is
            # the token itself, so only the two halves need to be looked up.
            candidates = []
            for index in range(1, len(token)):
                piece_l, piece_r = token[:index], token[index:]
                if piece_l in bpe_ranks and piece_r in bpe_ranks:
                    candidates.append((piece_l, piece_r, rank))
            candidates.sort(key=lambda candidate: (bpe_ranks[candidate[0]], bpe_ranks[candidate[1]]))
            merges.extend(candidates)

        # Stable sort: candidates of one token keep the order established above.
        merges.sort(key=lambda merge: merge[2])
        merges = [(token_bytes_to_string(left), token_bytes_to_string(right)) for left, right, _ in merges]
        return vocab, merges

    def tokenizer(self):
        """Return a `Tokenizer` wrapping a BPE model built from `self.vocab_file`.

        `ignore_merges` is switched on when the model exposes that attribute.
        """
        vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab_file)
        tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
        if hasattr(tokenizer.model, "ignore_merges"):
            tokenizer.model.ignore_merges = True
        return tokenizer

    def converted(self) -> Tokenizer:
        """Assemble the full tokenizer: regex split + ByteLevel pre-tokenizer, decoder, special tokens, post-processor."""
        tokenizer = self.tokenizer()
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
                pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
            ]
        )
        tokenizer.decoder = decoders.ByteLevel()

        if self.extra_special_tokens is not None:
            tokenizer.add_special_tokens(
                [AddedToken(token, normalized=False, special=True) for token in self.extra_special_tokens]
            )

        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
        return tokenizer
class MistralConverter:
    """
    Converter for Mistral `tekken.json` tokenizer files.

    Reads the JSON file, rebuilds a byte-level BPE vocabulary and merge list from its base64-encoded token bytes,
    and assembles a `Tokenizer` with a regex `Split` + `ByteLevel` pre-tokenizer, a `ByteLevel` decoder and a
    `ByteLevel` post-processor.
    """

    def __init__(
        self,
        vocab_file=None,
        pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
        add_prefix_space=False,
        additional_special_tokens=None,
        **kwargs,
    ):
        """
        Args:
            vocab_file: Path to the `tekken.json` file.
            pattern: Initial pre-tokenization regex; replaced by the file's own pattern once the file is read.
            add_prefix_space: Forwarded to the `ByteLevel` pre-tokenizer.
            additional_special_tokens: Optional iterable of tokens, or a dict whose keys are the tokens; replaced
                by the file's own special tokens once the file is read.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        self.vocab_file = vocab_file
        self.pattern = pattern
        self.add_prefix_space = add_prefix_space
        self.additional_special_tokens = (
            additional_special_tokens.keys()
            if isinstance(additional_special_tokens, dict)
            else additional_special_tokens
        )

    def extract_vocab_merges_from_model(self, tiktoken_url: str):
        """
        Derive a BPE vocabulary and merge list from a `tekken.json` file.

        Side effects: `self.pattern` is overwritten with the file's `config.pattern`, and
        `self.additional_special_tokens` with `AddedToken`s built from the file's `special_tokens` entries.

        Args:
            tiktoken_url: Path of the JSON file to read. Falls back to `self.vocab_file` when empty or `None`.

        Returns:
            `(vocab, merges)` where `vocab` maps token strings to ids and `merges` is a list of `(left, right)`
            byte-level-encoded string pairs ordered by the rank of the token they produce.
        """
        import base64
        import json

        # Honor the explicit argument instead of silently ignoring it; in-class callers pass `self.vocab_file`.
        source_path = tiktoken_url or self.vocab_file
        with open(source_path, "r", encoding="utf-8") as f:
            untyped = json.load(f)

        self.pattern = untyped["config"]["pattern"]
        self.additional_special_tokens = [
            AddedToken(k["token_str"], special=k["is_control"]) for k in untyped["special_tokens"]
        ]
        byte_encoder = bytes_to_unicode()

        @lru_cache
        def token_bytes_to_string(b):
            # Iterating over `bytes` yields ints, which is exactly what the byte table is keyed on.
            return "".join(byte_encoder[byte] for byte in b)

        merges = []
        vocab = {}
        # NOTE(review): special tokens are numbered from 0 here and the BPE tokens below are also numbered from
        # 0 in the same dict, so their ids overlap -- confirm this is intended.
        for idx, token in enumerate(self.additional_special_tokens):
            vocab[token.content] = idx

        bpe_ranks = [base64.b64decode(k["token_bytes"]) for k in untyped["vocab"]]
        # Doubles as the membership set for "is this byte string a known token?".
        token_to_rank = {token: rank for rank, token in enumerate(bpe_ranks)}

        for rank, token in enumerate(tqdm(bpe_ranks, desc="Converting tekken.json to tokenizer.json")):
            vocab[token_bytes_to_string(token)] = rank
            if len(token) == 1:
                continue
            # Every split of the token whose two halves are themselves tokens is a candidate merge.
            local = [
                (token[:index], token[index:], rank)
                for index in range(1, len(token))
                if token[:index] in token_to_rank and token[index:] in token_to_rank
            ]
            local.sort(key=lambda x: (token_to_rank[x[0]], token_to_rank[x[1]]))
            merges.extend(local)

        # Stable sort: merges are ordered by resulting rank, keeping the per-token ordering from above.
        merges.sort(key=lambda val: val[2])
        merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
        return vocab, merges

    def tokenizer(self):
        """
        Build the bare BPE `Tokenizer` (model only) from `self.vocab_file`.

        `ignore_merges` is switched on when the BPE model exposes that attribute.
        """
        vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab_file)
        tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
        if hasattr(tokenizer.model, "ignore_merges"):
            tokenizer.model.ignore_merges = True
        return tokenizer

    def converted(self) -> Tokenizer:
        """
        Return the fully configured fast tokenizer.

        `self.tokenizer()` must run first: it loads the file, which sets the `self.pattern` and
        `self.additional_special_tokens` values used below.
        """
        tokenizer = self.tokenizer()
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
                pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
            ]
        )
        tokenizer.decoder = decoders.ByteLevel()
        tokenizer.add_tokens(self.additional_special_tokens)

        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
        return tokenizer
# Registry used by `convert_slow_tokenizer`: maps the class name of a slow tokenizer to the converter class
# that is instantiated with the slow tokenizer instance and whose `converted()` result becomes the fast backend.
# Several tokenizer classes share one converter. The key order is also the order shown in the error message
# raised when conversion fails.
SLOW_TO_FAST_CONVERTERS = {
    "AlbertTokenizer": AlbertConverter,
    "BartTokenizer": RobertaConverter,
    "BarthezTokenizer": BarthezConverter,
    "BertTokenizer": BertConverter,
    "BigBirdTokenizer": BigBirdConverter,
    "BlenderbotTokenizer": BlenderbotConverter,
    "CamembertTokenizer": CamembertConverter,
    "CLIPTokenizer": CLIPConverter,
    "CodeGenTokenizer": GPT2Converter,
    "ConvBertTokenizer": BertConverter,
    "DebertaTokenizer": DebertaConverter,
    "DebertaV2Tokenizer": DebertaV2Converter,
    "DistilBertTokenizer": BertConverter,
    "DPRReaderTokenizer": BertConverter,
    "DPRQuestionEncoderTokenizer": BertConverter,
    "DPRContextEncoderTokenizer": BertConverter,
    "ElectraTokenizer": BertConverter,
    "FNetTokenizer": AlbertConverter,
    "FunnelTokenizer": FunnelConverter,
    "GPT2Tokenizer": GPT2Converter,
    "HerbertTokenizer": HerbertConverter,
    "LayoutLMTokenizer": BertConverter,
    "LayoutLMv2Tokenizer": BertConverter,
    "LayoutLMv3Tokenizer": RobertaConverter,
    "LayoutXLMTokenizer": XLMRobertaConverter,
    "LongformerTokenizer": RobertaConverter,
    "LEDTokenizer": RobertaConverter,
    "LxmertTokenizer": BertConverter,
    "MarkupLMTokenizer": MarkupLMConverter,
    "MBartTokenizer": MBartConverter,
    "MBart50Tokenizer": MBart50Converter,
    "MPNetTokenizer": MPNetConverter,
    "MobileBertTokenizer": BertConverter,
    "MvpTokenizer": RobertaConverter,
    "NllbTokenizer": NllbConverter,
    "OpenAIGPTTokenizer": OpenAIGPTConverter,
    "PegasusTokenizer": PegasusConverter,
    "Qwen2Tokenizer": Qwen2Converter,
    "ReformerTokenizer": ReformerConverter,
    "RemBertTokenizer": RemBertConverter,
    "RobertaTokenizer": RobertaConverter,
    "RoFormerTokenizer": RoFormerConverter,
    "SeamlessM4TTokenizer": SeamlessM4TConverter,
    "SqueezeBertTokenizer": BertConverter,
    "T5Tokenizer": T5Converter,
    "UdopTokenizer": UdopConverter,
    "WhisperTokenizer": WhisperConverter,
    "XLMRobertaTokenizer": XLMRobertaConverter,
    "XLNetTokenizer": XLNetConverter,
    "SplinterTokenizer": SplinterConverter,
    "XGLMTokenizer": XGLMConverter,
    "LlamaTokenizer": LlamaConverter,
    "CodeLlamaTokenizer": LlamaConverter,
    "GemmaTokenizer": GemmaConverter,
    "Phi3Tokenizer": LlamaConverter,
}
def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokenizer:
    """
    Utilities to convert a slow tokenizer instance in a fast tokenizer instance.

    Conversion is attempted in this order:

    1. If the tokenizer's class name is registered in `SLOW_TO_FAST_CONVERTERS` and `from_tiktoken` is false, the
       registered converter is used.
    2. Otherwise, if the tokenizer's `vocab_file` ends with `tekken.json`, `MistralConverter` is used.
    3. Otherwise `TikTokenConverter` is tried with the tokenizer's `vocab_file` and `extra_special_tokens`.

    Args:
        transformer_tokenizer ([`~tokenization_utils_base.PreTrainedTokenizer`]):
            Instance of a slow tokenizer to convert in the backend tokenizer for
            [`~tokenization_utils_base.PreTrainedTokenizerFast`].
        from_tiktoken (bool, optional): Whether to use the `tiktoken` library to convert the tokenizer instead of
            sentencepiece. Defaults to False.

    Return:
        An instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a
        [`~tokenization_utils_base.PreTrainedTokenizerFast`]

    Raises:
        ValueError: If the tiktoken fallback fails for any reason; the underlying error is attached as the cause.
    """
    tokenizer_class_name = transformer_tokenizer.__class__.__name__
    if tokenizer_class_name in SLOW_TO_FAST_CONVERTERS and not from_tiktoken:
        converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
        return converter_class(transformer_tokenizer).converted()

    # A missing or `None` vocab file must not crash on `.endswith`; it falls through to the tiktoken attempt,
    # whose failure is reported with the descriptive `ValueError` below.
    vocab_file = getattr(transformer_tokenizer, "vocab_file", None)
    if vocab_file is not None and str(vocab_file).endswith("tekken.json"):
        transformer_tokenizer.original_tokenizer = transformer_tokenizer
        logger.info("Converting from Mistral tekken.json")
        return MistralConverter(vocab_file).converted()

    try:
        logger.info("Converting from Tiktoken")
        return TikTokenConverter(
            vocab_file=vocab_file,
            extra_special_tokens=transformer_tokenizer.extra_special_tokens,
        ).converted()
    except Exception as err:
        # Chain the original error so the real reason (missing `tiktoken`, unreadable file, ...) stays visible.
        raise ValueError(
            "Converting from SentencePiece and Tiktoken failed, if a converter for SentencePiece is available, provide a model path "
            "with a SentencePiece tokenizer.model file. "
            f"Currently available slow->fast converters: {list(SLOW_TO_FAST_CONVERTERS.keys())}"
        ) from err
|