tokenization_plbart.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. # Copyright 2022, UCLA NLP, The Facebook AI Research Team Authors and The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any
  15. from ...tokenization_python import BatchEncoding
  16. from ...tokenization_utils_base import AddedToken
  17. from ...tokenization_utils_sentencepiece import SentencePieceBackend
  18. from ...utils import logging
  19. from ...utils.import_utils import requires
# Module-level logger obtained from the library's logging utilities, keyed by this module's name.
logger = logging.get_logger(__name__)
  21. SPIECE_UNDERLINE = "▁"
  22. VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
  23. FAIRSEQ_LANGUAGE_CODES = {
  24. "base": ["__java__", "__python__", "__en_XX__"],
  25. "multi": ["__java__", "__python__", "__en_XX__", "__javascript__", "__php__", "__ruby__", "__go__"],
  26. }
  27. FAIRSEQ_LANGUAGE_CODES_MAP = {
  28. "java": "__java__",
  29. "python": "__python__",
  30. "en_XX": "__en_XX__",
  31. "javascript": "__javascript__",
  32. "php": "__php__",
  33. "ruby": "__ruby__",
  34. "go": "__go__",
  35. }
  36. @requires(backends=("sentencepiece",))
  37. class PLBartTokenizer(SentencePieceBackend):
  38. """
  39. Construct an PLBART tokenizer.
  40. Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
  41. [SentencePiece](https://github.com/google/sentencepiece).
  42. The tokenization method is `<tokens> <eos> <language code>` for source language documents, and `<language code>
  43. <tokens> <eos>` for target language documents.
  44. Args:
  45. vocab_file (`str`):
  46. Path to the vocabulary file.
  47. src_lang (`str`, *optional*):
  48. A string representing the source language.
  49. tgt_lang (`str`, *optional*):
  50. A string representing the target language.
  51. bos_token (`str`, *optional*, defaults to `"<s>"`):
  52. The start of sequence token.
  53. eos_token (`str`, *optional*, defaults to `"</s>"`):
  54. The end of sequence token.
  55. sep_token (`str`, *optional*, defaults to `"</s>"`):
  56. The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
  57. sequence classification or for a text and a question for question answering. It is also used as the last
  58. token of a sequence built with special tokens.
  59. cls_token (`str`, *optional*, defaults to `"<s>"`):
  60. The cls token, which is a special token used as the first token for all tasks.
  61. unk_token (`str`, *optional*, defaults to `"<unk>"`):
  62. The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
  63. token instead.
  64. pad_token (`str`, *optional*, defaults to `"<pad>"`):
  65. The token used for padding, for example when batching sequences of different lengths.
  66. mask_token(`str`, *optional*, defaults to `"<mask>"`):
  67. The token used for masking values. This is the token used when training this model with masking tasks. This
  68. is only used in the `"base"` tokenizer type. For `"multi"` tokenizer, masking is never done for the
  69. downstream tasks.
  70. language_codes (`str`, *optional*, defaults to `"base"`):
  71. What language codes to use. Should be one of `"base"` or `"multi"`.
  72. sp_model_kwargs (`dict`, *optional*):
  73. Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
  74. SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
  75. to set:
  76. - `enable_sampling`: Enable subword regularization.
  77. - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
  78. - `nbest_size = {0,1}`: No sampling is performed.
  79. - `nbest_size > 1`: samples from the nbest_size results.
  80. - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
  81. using forward-filtering-and-backward-sampling algorithm.
  82. - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
  83. BPE-dropout.
  84. Examples:
  85. ```python
  86. >>> from transformers import PLBartTokenizer
  87. >>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-python-en_XX", src_lang="python", tgt_lang="en_XX")
  88. >>> example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])"
  89. >>> expected_translation_english = "Returns the maximum value of a b c."
  90. >>> inputs = tokenizer(example_python_phrase, text_target=expected_translation_english, return_tensors="pt")
  91. ```"""
  92. vocab_files_names = VOCAB_FILES_NAMES
  93. model_input_names = ["input_ids", "attention_mask"]
  94. prefix_tokens: list[int] = []
  95. suffix_tokens: list[int] = []
  96. def __init__(
  97. self,
  98. vocab_file,
  99. bos_token="<s>",
  100. eos_token="</s>",
  101. sep_token="</s>",
  102. cls_token="<s>",
  103. unk_token="<unk>",
  104. pad_token="<pad>",
  105. mask_token="<mask>",
  106. language_codes="base",
  107. src_lang=None,
  108. tgt_lang=None,
  109. sp_model_kwargs: dict[str, Any] | None = None,
  110. additional_special_tokens=None,
  111. clean_up_tokenization_spaces=True,
  112. **kwargs,
  113. ):
  114. # Mask token behave like a normal word, i.e. include the space before it
  115. mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
  116. self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
  117. src_lang = self._convert_lang_code_special_format(src_lang)
  118. tgt_lang = self._convert_lang_code_special_format(tgt_lang)
  119. self.language_codes = language_codes
  120. fairseq_language_codes = FAIRSEQ_LANGUAGE_CODES[self.language_codes]
  121. # Original fairseq vocab and spm vocab must be "aligned":
  122. # Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
  123. # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
  124. # fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's' | '▁de' | '-'
  125. # spm | '<unk>' | '<s>' | '</s>' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a'
  126. # Mimic fairseq token-to-id alignment for the first 4 token
  127. self.vocab_file = vocab_file
  128. self.lang_code_to_id = {}
  129. self.id_to_lang_code = {}
  130. self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
  131. self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
  132. self.fairseq_offset = 1
  133. _additional_special_tokens = list(fairseq_language_codes)
  134. if additional_special_tokens is not None:
  135. _additional_special_tokens.extend(
  136. [t for t in additional_special_tokens if t not in _additional_special_tokens]
  137. )
  138. super().__init__(
  139. vocab_file=vocab_file,
  140. bos_token=bos_token,
  141. eos_token=eos_token,
  142. unk_token=unk_token,
  143. sep_token=sep_token,
  144. cls_token=cls_token,
  145. pad_token=pad_token,
  146. mask_token=mask_token,
  147. src_lang=src_lang,
  148. tgt_lang=tgt_lang,
  149. additional_special_tokens=_additional_special_tokens,
  150. sp_model_kwargs=self.sp_model_kwargs,
  151. clean_up_tokenization_spaces=clean_up_tokenization_spaces,
  152. language_codes=language_codes,
  153. special_tokens_pattern="prefix_suffix",
  154. token_type_ids_pattern="all_zeros",
  155. **kwargs,
  156. )
  157. # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
  158. self.sp_model_size = len(self.sp_model)
  159. self.lang_code_to_id = {
  160. code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(fairseq_language_codes)
  161. }
  162. self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
  163. self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
  164. if self.language_codes == "base":
  165. self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
  166. self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
  167. self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
  168. reserved_tokens = {"<s>", "<pad>", "</s>", "<unk>", "<mask>"}
  169. reserved_tokens.update(FAIRSEQ_LANGUAGE_CODES[self.language_codes])
  170. removed = False
  171. for token in reserved_tokens:
  172. idx = self._added_tokens_encoder.pop(token, None)
  173. if idx is not None:
  174. self._added_tokens_decoder.pop(idx, None)
  175. removed = True
  176. if removed:
  177. self._update_trie()
  178. self._update_total_vocab_size()
  179. synced = False
  180. for token, idx in self._added_tokens_encoder.items():
  181. if idx in self._added_tokens_decoder:
  182. continue
  183. self._added_tokens_decoder[idx] = AddedToken(
  184. token, special=True, normalized=False, lstrip=False, rstrip=False
  185. )
  186. synced = True
  187. if synced:
  188. self._update_trie()
  189. self._update_total_vocab_size()
  190. if self.language_codes == "base":
  191. self._src_lang = src_lang
  192. self.cur_lang_code_id = (
  193. self.lang_code_to_id[self._src_lang] if self._src_lang is not None else self._src_lang
  194. )
  195. else:
  196. self._src_lang = src_lang if src_lang is not None else "__en_XX__"
  197. self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
  198. self.tgt_lang = tgt_lang
  199. self.set_src_lang_special_tokens(self._src_lang)
  200. @property
  201. def vocab_size(self):
  202. lang_code_count = len(getattr(self, "lang_code_to_id", {}))
  203. fairseq_offset = getattr(self, "fairseq_offset", 1)
  204. base_vocab = len(self.sp_model) if hasattr(self, "sp_model") else 0
  205. if getattr(self, "language_codes", "base") == "base":
  206. return base_vocab + lang_code_count + fairseq_offset + 1 # +1 for mask token
  207. return base_vocab + lang_code_count + fairseq_offset
  208. def get_vocab(self):
  209. """Override to use fairseq vocabulary structure"""
  210. vocab = self.fairseq_tokens_to_ids.copy()
  211. for i in range(self.sp_model.get_piece_size()):
  212. sp_token = self.sp_model.IdToPiece(i)
  213. # Map SP token to fairseq ID: SP ID 0 maps to unk_token_id, others map to SP_ID + fairseq_offset
  214. vocab_id = self.unk_token_id if i == 0 else (i + self.fairseq_offset)
  215. if sp_token not in vocab:
  216. vocab[sp_token] = vocab_id
  217. # Add any additional tokens
  218. vocab.update({token: idx for token, idx in self._added_tokens_encoder.items() if token not in vocab})
  219. return vocab
  220. @property
  221. def src_lang(self) -> str:
  222. return self._src_lang
  223. @src_lang.setter
  224. def src_lang(self, new_src_lang: str) -> None:
  225. new_src_lang = self._convert_lang_code_special_format(new_src_lang)
  226. self._src_lang = new_src_lang
  227. self.set_src_lang_special_tokens(self._src_lang)
  228. def _build_translation_inputs(
  229. self, raw_inputs, return_tensors: str, src_lang: str | None, tgt_lang: str | None, **extra_kwargs
  230. ):
  231. """Used by translation pipeline, to prepare inputs for the generate function"""
  232. if src_lang is None or tgt_lang is None:
  233. raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
  234. self.src_lang = self._convert_lang_code_special_format(src_lang)
  235. self.tgt_lang = self._convert_lang_code_special_format(tgt_lang)
  236. inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
  237. tgt_lang_id = self.convert_tokens_to_ids(self.tgt_lang)
  238. inputs["forced_bos_token_id"] = tgt_lang_id
  239. return inputs
  240. def _convert_token_to_id(self, token):
  241. """Converts a token (str) in an id using the vocab."""
  242. if token in self.fairseq_tokens_to_ids:
  243. return self.fairseq_tokens_to_ids[token]
  244. spm_id = self.sp_model.PieceToId(token)
  245. # Need to return unknown token if the SP model returned 0
  246. return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
  247. def _convert_id_to_token(self, index):
  248. """Converts an index (integer) in a token (str) using the vocab."""
  249. if index in self.fairseq_ids_to_tokens:
  250. return self.fairseq_ids_to_tokens[index]
  251. return self.sp_model.IdToPiece(index - self.fairseq_offset)
  252. def prepare_seq2seq_batch(
  253. self,
  254. src_texts: list[str],
  255. src_lang: str = "en_XX",
  256. tgt_texts: list[str] | None = None,
  257. tgt_lang: str = "python",
  258. **kwargs,
  259. ) -> BatchEncoding:
  260. self.src_lang = self._convert_lang_code_special_format(src_lang)
  261. self.tgt_lang = self._convert_lang_code_special_format(tgt_lang)
  262. return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
  263. def _switch_to_input_mode(self):
  264. return self.set_src_lang_special_tokens(self.src_lang)
  265. def _switch_to_target_mode(self):
  266. return self.set_tgt_lang_special_tokens(self.tgt_lang)
  267. def set_src_lang_special_tokens(self, src_lang) -> None:
  268. """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
  269. src_lang = self._convert_lang_code_special_format(src_lang)
  270. self.cur_lang_code = self.lang_code_to_id[src_lang] if src_lang is not None else None
  271. self.prefix_tokens = []
  272. if self.cur_lang_code is not None:
  273. self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
  274. else:
  275. self.suffix_tokens = [self.eos_token_id]
  276. def set_tgt_lang_special_tokens(self, lang: str) -> None:
  277. """Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
  278. lang = self._convert_lang_code_special_format(lang)
  279. self.cur_lang_code = self.lang_code_to_id[lang] if lang is not None else None
  280. self.prefix_tokens = []
  281. if self.cur_lang_code is not None:
  282. self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
  283. else:
  284. self.suffix_tokens = [self.eos_token_id]
  285. def _convert_lang_code_special_format(self, lang: str) -> str:
  286. """Convert Language Codes to format tokenizer uses if required"""
  287. lang = FAIRSEQ_LANGUAGE_CODES_MAP.get(lang, lang)
  288. return lang
  289. def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=None, **kwargs):
  290. """Override to use self.clean_up_tokenization_spaces as default for batched input."""
  291. return super().decode(
  292. token_ids=token_ids,
  293. skip_special_tokens=skip_special_tokens,
  294. clean_up_tokenization_spaces=self.clean_up_tokenization_spaces,
  295. **kwargs,
  296. )
# Public API of this module: only the tokenizer class is exported.
__all__ = ["PLBartTokenizer"]