| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151 |
- from typing import Dict, Iterator, List, Optional, Union
- from tokenizers import AddedToken, Tokenizer, decoders, trainers
- from tokenizers.models import WordPiece
- from tokenizers.normalizers import BertNormalizer
- from tokenizers.pre_tokenizers import BertPreTokenizer
- from tokenizers.processors import BertProcessing
- from .base_tokenizer import BaseTokenizer
class BertWordPieceTokenizer(BaseTokenizer):
    """Bert WordPiece Tokenizer.

    Wraps a ``Tokenizer`` built around a ``WordPiece`` model with a
    ``BertNormalizer``, a ``BertPreTokenizer``, a ``WordPiece`` decoder and,
    when a vocabulary is supplied, a ``BertProcessing`` post-processor.
    Provides helpers to train the model from files or from an iterator.
    """

    # Special tokens handed to the trainer when the caller does not pass any.
    # Stored as an immutable tuple and copied into a fresh list on each use,
    # so no call can observe mutations made during another call.
    _DEFAULT_SPECIAL_TOKENS = ("[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]")

    def __init__(
        self,
        vocab: Optional[Union[str, Dict[str, int]]] = None,
        unk_token: Union[str, AddedToken] = "[UNK]",
        sep_token: Union[str, AddedToken] = "[SEP]",
        cls_token: Union[str, AddedToken] = "[CLS]",
        pad_token: Union[str, AddedToken] = "[PAD]",
        mask_token: Union[str, AddedToken] = "[MASK]",
        clean_text: bool = True,
        handle_chinese_chars: bool = True,
        strip_accents: Optional[bool] = None,
        lowercase: bool = True,
        wordpieces_prefix: str = "##",
    ):
        """Build the BERT WordPiece tokenization pipeline.

        Args:
            vocab: Vocabulary passed straight to the ``WordPiece`` model.
                When ``None`` an empty model is created (e.g. to be trained
                later) and no post-processor is installed.
            unk_token: Unknown token used by the ``WordPiece`` model.
            sep_token: Separator token. Required in ``vocab`` when a vocab
                is given, because the post-processor needs its id.
            cls_token: Classification token. Required in ``vocab`` when a
                vocab is given, for the same reason.
            pad_token: Padding token.
            mask_token: Mask token.
            clean_text: Forwarded to ``BertNormalizer``.
            handle_chinese_chars: Forwarded to ``BertNormalizer``.
            strip_accents: Forwarded to ``BertNormalizer``.
            lowercase: Forwarded to ``BertNormalizer``.
            wordpieces_prefix: Prefix marking word-continuation pieces. Given
                to both the ``WordPiece`` model and the ``WordPiece`` decoder
                so that encoding and decoding agree.

        Each of the five special tokens that already exists in the vocabulary
        is registered on the tokenizer as a special token.

        Raises:
            TypeError: If ``vocab`` is given but ``sep_token`` or
                ``cls_token`` is not part of it.
        """
        # The model must use the same continuation prefix as the decoder;
        # otherwise a custom ``wordpieces_prefix`` would only affect decoding.
        model_kwargs = {
            "unk_token": str(unk_token),
            "continuing_subword_prefix": wordpieces_prefix,
        }
        if vocab is not None:
            tokenizer = Tokenizer(WordPiece(vocab, **model_kwargs))
        else:
            tokenizer = Tokenizer(WordPiece(**model_kwargs))

        # Let the tokenizer know about special tokens if they are part of the
        # vocab (same order as the individual checks this loop replaces).
        for special_token in (unk_token, sep_token, cls_token, pad_token, mask_token):
            if tokenizer.token_to_id(str(special_token)) is not None:
                tokenizer.add_special_tokens([str(special_token)])

        tokenizer.normalizer = BertNormalizer(
            clean_text=clean_text,
            handle_chinese_chars=handle_chinese_chars,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        tokenizer.pre_tokenizer = BertPreTokenizer()

        # The post-processor needs the ids of [SEP]/[CLS], which only exist
        # once a vocabulary has been provided.
        if vocab is not None:
            sep_token_id = tokenizer.token_to_id(str(sep_token))
            if sep_token_id is None:
                raise TypeError("sep_token not found in the vocabulary")
            cls_token_id = tokenizer.token_to_id(str(cls_token))
            if cls_token_id is None:
                raise TypeError("cls_token not found in the vocabulary")

            tokenizer.post_processor = BertProcessing(
                (str(sep_token), sep_token_id),
                (str(cls_token), cls_token_id),
            )
        tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix)

        parameters = {
            "model": "BertWordPiece",
            "unk_token": unk_token,
            "sep_token": sep_token,
            "cls_token": cls_token,
            "pad_token": pad_token,
            "mask_token": mask_token,
            "clean_text": clean_text,
            "handle_chinese_chars": handle_chinese_chars,
            "strip_accents": strip_accents,
            "lowercase": lowercase,
            "wordpieces_prefix": wordpieces_prefix,
        }

        super().__init__(tokenizer, parameters)

    @classmethod
    def from_file(cls, vocab: str, **kwargs):
        """Create a tokenizer from a WordPiece vocabulary file.

        Args:
            vocab: Path to the vocabulary file, read with
                ``WordPiece.read_file``.
            **kwargs: Forwarded unchanged to the constructor.

        Returns:
            A new instance of ``cls`` built from the loaded vocabulary.
        """
        vocab_map = WordPiece.read_file(vocab)
        return cls(vocab_map, **kwargs)

    def _make_wordpiece_trainer(
        self,
        vocab_size: int,
        min_frequency: int,
        limit_alphabet: int,
        initial_alphabet: Optional[List[str]],
        special_tokens: Optional[List[Union[str, AddedToken]]],
        show_progress: bool,
        wordpieces_prefix: str,
    ):
        """Return a ``WordPieceTrainer`` configured from the shared train args.

        ``None`` for ``initial_alphabet`` means a fresh empty list. ``None``
        for ``special_tokens`` means a fresh copy of
        ``_DEFAULT_SPECIAL_TOKENS``.
        """
        if initial_alphabet is None:
            initial_alphabet = []
        if special_tokens is None:
            special_tokens = list(self._DEFAULT_SPECIAL_TOKENS)
        return trainers.WordPieceTrainer(
            vocab_size=vocab_size,
            min_frequency=min_frequency,
            limit_alphabet=limit_alphabet,
            initial_alphabet=initial_alphabet,
            special_tokens=special_tokens,
            show_progress=show_progress,
            continuing_subword_prefix=wordpieces_prefix,
        )

    def train(
        self,
        files: Union[str, List[str]],
        vocab_size: int = 30000,
        min_frequency: int = 2,
        limit_alphabet: int = 1000,
        initial_alphabet: Optional[List[str]] = None,
        special_tokens: Optional[List[Union[str, AddedToken]]] = None,
        show_progress: bool = True,
        wordpieces_prefix: str = "##",
    ):
        """Train the model using the given files.

        Args:
            files: A single path or a list of paths to train on.
            vocab_size: Forwarded to ``trainers.WordPieceTrainer``.
            min_frequency: Forwarded to ``trainers.WordPieceTrainer``.
            limit_alphabet: Forwarded to ``trainers.WordPieceTrainer``.
            initial_alphabet: Forwarded to ``trainers.WordPieceTrainer``;
                ``None`` means an empty list.
            special_tokens: Forwarded to ``trainers.WordPieceTrainer``;
                ``None`` means ``["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]``.
            show_progress: Forwarded to ``trainers.WordPieceTrainer``.
            wordpieces_prefix: Given to the trainer as
                ``continuing_subword_prefix``.

        NOTE(review): this method does not install a ``BertProcessing``
        post-processor; only construction with a ``vocab`` does.
        """
        trainer = self._make_wordpiece_trainer(
            vocab_size,
            min_frequency,
            limit_alphabet,
            initial_alphabet,
            special_tokens,
            show_progress,
            wordpieces_prefix,
        )
        if isinstance(files, str):
            files = [files]
        self._tokenizer.train(files, trainer=trainer)

    def train_from_iterator(
        self,
        iterator: Union[Iterator[str], Iterator[Iterator[str]]],
        vocab_size: int = 30000,
        min_frequency: int = 2,
        limit_alphabet: int = 1000,
        initial_alphabet: Optional[List[str]] = None,
        special_tokens: Optional[List[Union[str, AddedToken]]] = None,
        show_progress: bool = True,
        wordpieces_prefix: str = "##",
        length: Optional[int] = None,
    ):
        """Train the model using the given iterator.

        Args:
            iterator: Iterator of strings, or of iterators of strings, to
                train on.
            vocab_size: Forwarded to ``trainers.WordPieceTrainer``.
            min_frequency: Forwarded to ``trainers.WordPieceTrainer``.
            limit_alphabet: Forwarded to ``trainers.WordPieceTrainer``.
            initial_alphabet: Forwarded to ``trainers.WordPieceTrainer``;
                ``None`` means an empty list.
            special_tokens: Forwarded to ``trainers.WordPieceTrainer``;
                ``None`` means ``["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]``.
            show_progress: Forwarded to ``trainers.WordPieceTrainer``.
            wordpieces_prefix: Given to the trainer as
                ``continuing_subword_prefix``.
            length: Optional total number of items in ``iterator``, forwarded
                as ``length`` to the underlying tokenizer.

        NOTE(review): this method does not install a ``BertProcessing``
        post-processor; only construction with a ``vocab`` does.
        """
        trainer = self._make_wordpiece_trainer(
            vocab_size,
            min_frequency,
            limit_alphabet,
            initial_alphabet,
            special_tokens,
            show_progress,
            wordpieces_prefix,
        )
        self._tokenizer.train_from_iterator(
            iterator,
            trainer=trainer,
            length=length,
        )
|