tokenization_bartpho.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. # Copyright 2021 VinAI Research and the HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Tokenization classes for BARTpho-syllable model."""
  15. import os
  16. from shutil import copyfile
  17. from typing import Any
  18. from ...tokenization_python import AddedToken
  19. from ...tokenization_utils_sentencepiece import SentencePieceBackend
  20. from ...utils import logging
  21. from ...utils.import_utils import requires
  22. logger = logging.get_logger(__name__)
  23. VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "monolingual_vocab_file": "dict.txt"}
  24. @requires(backends=("sentencepiece",))
  25. class BartphoTokenizer(SentencePieceBackend):
  26. """
  27. Adapted from [`XLMRobertaTokenizer`]. Based on [SentencePiece](https://github.com/google/sentencepiece).
  28. This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
  29. this superclass for more information regarding those methods.
  30. Args:
  31. vocab_file (`str`):
  32. Path to the vocabulary file. This vocabulary is the pre-trained SentencePiece model available from the
  33. multilingual XLM-RoBERTa, also used in mBART, consisting of 250K types.
  34. monolingual_vocab_file (`str`):
  35. Path to the monolingual vocabulary file. This monolingual vocabulary consists of Vietnamese-specialized
  36. types extracted from the multilingual vocabulary vocab_file of 250K types.
  37. bos_token (`str`, *optional*, defaults to `"<s>"`):
  38. The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
  39. <Tip>
  40. When building a sequence using special tokens, this is not the token that is used for the beginning of
  41. sequence. The token used is the `cls_token`.
  42. </Tip>
  43. eos_token (`str`, *optional*, defaults to `"</s>"`):
  44. The end of sequence token.
  45. <Tip>
  46. When building a sequence using special tokens, this is not the token that is used for the end of sequence.
  47. The token used is the `sep_token`.
  48. </Tip>
  49. sep_token (`str`, *optional*, defaults to `"</s>"`):
  50. The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
  51. sequence classification or for a text and a question for question answering. It is also used as the last
  52. token of a sequence built with special tokens.
  53. cls_token (`str`, *optional*, defaults to `"<s>"`):
  54. The classifier token which is used when doing sequence classification (classification of the whole sequence
  55. instead of per-token classification). It is the first token of the sequence when built with special tokens.
  56. unk_token (`str`, *optional*, defaults to `"<unk>"`):
  57. The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
  58. token instead.
  59. pad_token (`str`, *optional*, defaults to `"<pad>"`):
  60. The token used for padding, for example when batching sequences of different lengths.
  61. mask_token (`str`, *optional*, defaults to `"<mask>"`):
  62. The token used for masking values. This is the token used when training this model with masked language
  63. modeling. This is the token which the model will try to predict.
  64. sp_model_kwargs (`dict`, *optional*):
  65. Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
  66. SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
  67. to set:
  68. - `enable_sampling`: Enable subword regularization.
  69. - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
  70. - `nbest_size = {0,1}`: No sampling is performed.
  71. - `nbest_size > 1`: samples from the nbest_size results.
  72. - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
  73. using forward-filtering-and-backward-sampling algorithm.
  74. - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
  75. BPE-dropout.
  76. Attributes:
  77. sp_model (`SentencePieceProcessor`):
  78. The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
  79. """
  80. vocab_files_names = VOCAB_FILES_NAMES
  81. model_input_names = ["input_ids", "attention_mask"]
  82. is_fast = False
  83. def __init__(
  84. self,
  85. vocab_file,
  86. monolingual_vocab_file,
  87. bos_token="<s>",
  88. eos_token="</s>",
  89. sep_token="</s>",
  90. cls_token="<s>",
  91. unk_token="<unk>",
  92. pad_token="<pad>",
  93. mask_token="<mask>",
  94. sp_model_kwargs: dict[str, Any] | None = None,
  95. **kwargs,
  96. ) -> None:
  97. # Mask token behave like a normal word, i.e. include the space before it
  98. mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
  99. self.monolingual_vocab_file = monolingual_vocab_file
  100. # Load the reduced vocab
  101. # Keep order of special tokens for backward compatibility
  102. self.fairseq_tokens_to_ids = {}
  103. cnt = 0
  104. for token in [bos_token, pad_token, eos_token, unk_token, sep_token, cls_token]:
  105. if str(token) not in self.fairseq_tokens_to_ids:
  106. self.fairseq_tokens_to_ids[str(token)] = cnt
  107. cnt += 1
  108. with open(monolingual_vocab_file, "r", encoding="utf-8") as f:
  109. for line in f:
  110. token = line.strip().split()[0]
  111. self.fairseq_tokens_to_ids[token] = len(self.fairseq_tokens_to_ids)
  112. if str(mask_token) not in self.fairseq_tokens_to_ids:
  113. self.fairseq_tokens_to_ids[str(mask_token)] = len(self.fairseq_tokens_to_ids)
  114. self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
  115. # Prepare sp_model_kwargs for parent class
  116. if sp_model_kwargs is not None:
  117. kwargs["sp_model_kwargs"] = sp_model_kwargs
  118. # Call parent init (which will load sp_model)
  119. super().__init__(
  120. vocab_file=vocab_file,
  121. bos_token=bos_token,
  122. eos_token=eos_token,
  123. unk_token=unk_token,
  124. sep_token=sep_token,
  125. cls_token=cls_token,
  126. pad_token=pad_token,
  127. mask_token=mask_token,
  128. **kwargs,
  129. )
  130. self._align_added_tokens_with_fairseq_vocab()
  131. def build_inputs_with_special_tokens(
  132. self, token_ids_0: list[int], token_ids_1: list[int] | None = None
  133. ) -> list[int]:
  134. """
  135. Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
  136. adding special tokens. An BARTPho sequence has the following format:
  137. - single sequence: `<s> X </s>`
  138. - pair of sequences: `<s> A </s></s> B </s>`
  139. Args:
  140. token_ids_0 (`list[int]`):
  141. List of IDs to which the special tokens will be added.
  142. token_ids_1 (`list[int]`, *optional*):
  143. Optional second list of IDs for sequence pairs.
  144. Returns:
  145. `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
  146. """
  147. if token_ids_1 is None:
  148. return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
  149. cls = [self.cls_token_id]
  150. sep = [self.sep_token_id]
  151. return cls + token_ids_0 + sep + sep + token_ids_1 + sep
  152. def get_special_tokens_mask(
  153. self, token_ids_0: list[int], token_ids_1: list[int] | None = None, already_has_special_tokens: bool = False
  154. ) -> list[int]:
  155. """
  156. Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
  157. special tokens using the tokenizer `prepare_for_model` method.
  158. Args:
  159. token_ids_0 (`list[int]`):
  160. List of IDs.
  161. token_ids_1 (`list[int]`, *optional*):
  162. Optional second list of IDs for sequence pairs.
  163. already_has_special_tokens (`bool`, *optional*, defaults to `False`):
  164. Whether or not the token list is already formatted with special tokens for the model.
  165. Returns:
  166. `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
  167. """
  168. if already_has_special_tokens:
  169. return super().get_special_tokens_mask(
  170. token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
  171. )
  172. if token_ids_1 is None:
  173. return [1] + ([0] * len(token_ids_0)) + [1]
  174. return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
  175. def create_token_type_ids_from_sequences(
  176. self, token_ids_0: list[int], token_ids_1: list[int] | None = None
  177. ) -> list[int]:
  178. """
  179. Create a mask from the two sequences passed to be used in a sequence-pair classification task. BARTPho does not
  180. make use of token type ids, therefore a list of zeros is returned.
  181. Args:
  182. token_ids_0 (`list[int]`):
  183. List of IDs.
  184. token_ids_1 (`list[int]`, *optional*):
  185. Optional second list of IDs for sequence pairs.
  186. Returns:
  187. `list[int]`: List of zeros.
  188. """
  189. sep = [self.sep_token_id]
  190. cls = [self.cls_token_id]
  191. if token_ids_1 is None:
  192. return len(cls + token_ids_0 + sep) * [0]
  193. return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
  194. @property
  195. def vocab_size(self):
  196. """Override to return fairseq vocab size instead of sp_model vocab size"""
  197. return len(self.fairseq_ids_to_tokens)
  198. def get_vocab(self):
  199. """Override to use fairseq vocabulary"""
  200. vocab = dict(self.fairseq_tokens_to_ids)
  201. if hasattr(self, "_added_tokens_encoder"):
  202. for token, idx in self._added_tokens_encoder.items():
  203. if token not in vocab:
  204. vocab[token] = idx
  205. return vocab
  206. def _convert_token_to_id(self, token):
  207. """Converts a token (str) in an id using the fairseq vocab."""
  208. if token in self.fairseq_tokens_to_ids:
  209. return self.fairseq_tokens_to_ids[token]
  210. else:
  211. return self.unk_token_id
  212. def _convert_token_to_id_with_added_voc(self, token):
  213. """Override to use fairseq vocab instead of sp_model vocab."""
  214. if token is None:
  215. return None
  216. if token in self._added_tokens_encoder:
  217. return self._added_tokens_encoder[token]
  218. return self._convert_token_to_id(token)
  219. def _convert_id_to_token(self, index):
  220. """Converts an index (integer) in a token (str) using the fairseq vocab."""
  221. return self.fairseq_ids_to_tokens[index]
  222. def _align_added_tokens_with_fairseq_vocab(self):
  223. """
  224. The slow tokenizer base class populates `_added_tokens_*` using SentencePiece ids. Remap those entries so that
  225. every token present in the reduced fairseq dictionary uses the same ids everywhere, otherwise conversions and
  226. special-token setters observe two different vocabularies.
  227. """
  228. if not hasattr(self, "_added_tokens_decoder") or not hasattr(self, "_added_tokens_encoder"):
  229. return
  230. remapped_decoder: dict[int, AddedToken] = {}
  231. for original_id, token_obj in self._added_tokens_decoder.items():
  232. token = token_obj.content
  233. new_id = self.fairseq_tokens_to_ids.get(token, original_id)
  234. remapped_decoder[new_id] = token_obj
  235. self._added_tokens_decoder = remapped_decoder
  236. self._added_tokens_encoder = {token.content: idx for idx, token in remapped_decoder.items()}
  237. def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]:
  238. if not os.path.isdir(save_directory):
  239. logger.error(f"Vocabulary path ({save_directory}) should be a directory")
  240. return
  241. out_vocab_file = os.path.join(
  242. save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
  243. )
  244. out_monolingual_vocab_file = os.path.join(
  245. save_directory,
  246. (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["monolingual_vocab_file"],
  247. )
  248. if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
  249. copyfile(self.vocab_file, out_vocab_file)
  250. elif not os.path.isfile(self.vocab_file):
  251. with open(out_vocab_file, "wb") as fi:
  252. content_spiece_model = self.sp_model.serialized_model_proto()
  253. fi.write(content_spiece_model)
  254. if os.path.abspath(self.monolingual_vocab_file) != os.path.abspath(
  255. out_monolingual_vocab_file
  256. ) and os.path.isfile(self.monolingual_vocab_file):
  257. copyfile(self.monolingual_vocab_file, out_monolingual_vocab_file)
  258. elif not os.path.isfile(self.monolingual_vocab_file):
  259. with open(out_monolingual_vocab_file, "w", encoding="utf-8") as fp:
  260. for token in self.fairseq_tokens_to_ids:
  261. if token not in self.all_special_tokens:
  262. fp.write(f"{str(token)} \n")
  263. return out_vocab_file, out_monolingual_vocab_file
  264. __all__ = ["BartphoTokenizer"]