bert_wordpiece.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. from typing import Dict, Iterator, List, Optional, Union
  2. from tokenizers import AddedToken, Tokenizer, decoders, trainers
  3. from tokenizers.models import WordPiece
  4. from tokenizers.normalizers import BertNormalizer
  5. from tokenizers.pre_tokenizers import BertPreTokenizer
  6. from tokenizers.processors import BertProcessing
  7. from .base_tokenizer import BaseTokenizer
  8. class BertWordPieceTokenizer(BaseTokenizer):
  9. """Bert WordPiece Tokenizer"""
  10. def __init__(
  11. self,
  12. vocab: Optional[Union[str, Dict[str, int]]] = None,
  13. unk_token: Union[str, AddedToken] = "[UNK]",
  14. sep_token: Union[str, AddedToken] = "[SEP]",
  15. cls_token: Union[str, AddedToken] = "[CLS]",
  16. pad_token: Union[str, AddedToken] = "[PAD]",
  17. mask_token: Union[str, AddedToken] = "[MASK]",
  18. clean_text: bool = True,
  19. handle_chinese_chars: bool = True,
  20. strip_accents: Optional[bool] = None,
  21. lowercase: bool = True,
  22. wordpieces_prefix: str = "##",
  23. ):
  24. if vocab is not None:
  25. tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(unk_token)))
  26. else:
  27. tokenizer = Tokenizer(WordPiece(unk_token=str(unk_token)))
  28. # Let the tokenizer know about special tokens if they are part of the vocab
  29. if tokenizer.token_to_id(str(unk_token)) is not None:
  30. tokenizer.add_special_tokens([str(unk_token)])
  31. if tokenizer.token_to_id(str(sep_token)) is not None:
  32. tokenizer.add_special_tokens([str(sep_token)])
  33. if tokenizer.token_to_id(str(cls_token)) is not None:
  34. tokenizer.add_special_tokens([str(cls_token)])
  35. if tokenizer.token_to_id(str(pad_token)) is not None:
  36. tokenizer.add_special_tokens([str(pad_token)])
  37. if tokenizer.token_to_id(str(mask_token)) is not None:
  38. tokenizer.add_special_tokens([str(mask_token)])
  39. tokenizer.normalizer = BertNormalizer(
  40. clean_text=clean_text,
  41. handle_chinese_chars=handle_chinese_chars,
  42. strip_accents=strip_accents,
  43. lowercase=lowercase,
  44. )
  45. tokenizer.pre_tokenizer = BertPreTokenizer()
  46. if vocab is not None:
  47. sep_token_id = tokenizer.token_to_id(str(sep_token))
  48. if sep_token_id is None:
  49. raise TypeError("sep_token not found in the vocabulary")
  50. cls_token_id = tokenizer.token_to_id(str(cls_token))
  51. if cls_token_id is None:
  52. raise TypeError("cls_token not found in the vocabulary")
  53. tokenizer.post_processor = BertProcessing((str(sep_token), sep_token_id), (str(cls_token), cls_token_id))
  54. tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix)
  55. parameters = {
  56. "model": "BertWordPiece",
  57. "unk_token": unk_token,
  58. "sep_token": sep_token,
  59. "cls_token": cls_token,
  60. "pad_token": pad_token,
  61. "mask_token": mask_token,
  62. "clean_text": clean_text,
  63. "handle_chinese_chars": handle_chinese_chars,
  64. "strip_accents": strip_accents,
  65. "lowercase": lowercase,
  66. "wordpieces_prefix": wordpieces_prefix,
  67. }
  68. super().__init__(tokenizer, parameters)
  69. @staticmethod
  70. def from_file(vocab: str, **kwargs):
  71. vocab = WordPiece.read_file(vocab)
  72. return BertWordPieceTokenizer(vocab, **kwargs)
  73. def train(
  74. self,
  75. files: Union[str, List[str]],
  76. vocab_size: int = 30000,
  77. min_frequency: int = 2,
  78. limit_alphabet: int = 1000,
  79. initial_alphabet: List[str] = [],
  80. special_tokens: List[Union[str, AddedToken]] = [
  81. "[PAD]",
  82. "[UNK]",
  83. "[CLS]",
  84. "[SEP]",
  85. "[MASK]",
  86. ],
  87. show_progress: bool = True,
  88. wordpieces_prefix: str = "##",
  89. ):
  90. """Train the model using the given files"""
  91. trainer = trainers.WordPieceTrainer(
  92. vocab_size=vocab_size,
  93. min_frequency=min_frequency,
  94. limit_alphabet=limit_alphabet,
  95. initial_alphabet=initial_alphabet,
  96. special_tokens=special_tokens,
  97. show_progress=show_progress,
  98. continuing_subword_prefix=wordpieces_prefix,
  99. )
  100. if isinstance(files, str):
  101. files = [files]
  102. self._tokenizer.train(files, trainer=trainer)
  103. def train_from_iterator(
  104. self,
  105. iterator: Union[Iterator[str], Iterator[Iterator[str]]],
  106. vocab_size: int = 30000,
  107. min_frequency: int = 2,
  108. limit_alphabet: int = 1000,
  109. initial_alphabet: List[str] = [],
  110. special_tokens: List[Union[str, AddedToken]] = [
  111. "[PAD]",
  112. "[UNK]",
  113. "[CLS]",
  114. "[SEP]",
  115. "[MASK]",
  116. ],
  117. show_progress: bool = True,
  118. wordpieces_prefix: str = "##",
  119. length: Optional[int] = None,
  120. ):
  121. """Train the model using the given iterator"""
  122. trainer = trainers.WordPieceTrainer(
  123. vocab_size=vocab_size,
  124. min_frequency=min_frequency,
  125. limit_alphabet=limit_alphabet,
  126. initial_alphabet=initial_alphabet,
  127. special_tokens=special_tokens,
  128. show_progress=show_progress,
  129. continuing_subword_prefix=wordpieces_prefix,
  130. )
  131. self._tokenizer.train_from_iterator(
  132. iterator,
  133. trainer=trainer,
  134. length=length,
  135. )