tokenization_nllb.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. # Copyright 2022 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
  15. from tokenizers.models import BPE
  16. from ...tokenization_python import AddedToken, BatchEncoding
  17. from ...tokenization_utils_tokenizers import TokenizersBackend
  18. from ...utils import logging
# Module-level logger from the package's own logging utilities (`...utils.logging`), named after this module.
# NOTE(review): not referenced anywhere else in this module as shown; kept for parity with sibling tokenizer modules.
logger = logging.get_logger(__name__)
  20. VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
  21. FAIRSEQ_LANGUAGE_CODES = ['ace_Arab', 'ace_Latn', 'acm_Arab', 'acq_Arab', 'aeb_Arab', 'afr_Latn', 'ajp_Arab', 'aka_Latn', 'amh_Ethi', 'apc_Arab', 'arb_Arab', 'ars_Arab', 'ary_Arab', 'arz_Arab', 'asm_Beng', 'ast_Latn', 'awa_Deva', 'ayr_Latn', 'azb_Arab', 'azj_Latn', 'bak_Cyrl', 'bam_Latn', 'ban_Latn', 'bel_Cyrl', 'bem_Latn', 'ben_Beng', 'bho_Deva', 'bjn_Arab', 'bjn_Latn', 'bod_Tibt', 'bos_Latn', 'bug_Latn', 'bul_Cyrl', 'cat_Latn', 'ceb_Latn', 'ces_Latn', 'cjk_Latn', 'ckb_Arab', 'crh_Latn', 'cym_Latn', 'dan_Latn', 'deu_Latn', 'dik_Latn', 'dyu_Latn', 'dzo_Tibt', 'ell_Grek', 'eng_Latn', 'epo_Latn', 'est_Latn', 'eus_Latn', 'ewe_Latn', 'fao_Latn', 'pes_Arab', 'fij_Latn', 'fin_Latn', 'fon_Latn', 'fra_Latn', 'fur_Latn', 'fuv_Latn', 'gla_Latn', 'gle_Latn', 'glg_Latn', 'grn_Latn', 'guj_Gujr', 'hat_Latn', 'hau_Latn', 'heb_Hebr', 'hin_Deva', 'hne_Deva', 'hrv_Latn', 'hun_Latn', 'hye_Armn', 'ibo_Latn', 'ilo_Latn', 'ind_Latn', 'isl_Latn', 'ita_Latn', 'jav_Latn', 'jpn_Jpan', 'kab_Latn', 'kac_Latn', 'kam_Latn', 'kan_Knda', 'kas_Arab', 'kas_Deva', 'kat_Geor', 'knc_Arab', 'knc_Latn', 'kaz_Cyrl', 'kbp_Latn', 'kea_Latn', 'khm_Khmr', 'kik_Latn', 'kin_Latn', 'kir_Cyrl', 'kmb_Latn', 'kon_Latn', 'kor_Hang', 'kmr_Latn', 'lao_Laoo', 'lvs_Latn', 'lij_Latn', 'lim_Latn', 'lin_Latn', 'lit_Latn', 'lmo_Latn', 'ltg_Latn', 'ltz_Latn', 'lua_Latn', 'lug_Latn', 'luo_Latn', 'lus_Latn', 'mag_Deva', 'mai_Deva', 'mal_Mlym', 'mar_Deva', 'min_Latn', 'mkd_Cyrl', 'plt_Latn', 'mlt_Latn', 'mni_Beng', 'khk_Cyrl', 'mos_Latn', 'mri_Latn', 'zsm_Latn', 'mya_Mymr', 'nld_Latn', 'nno_Latn', 'nob_Latn', 'npi_Deva', 'nso_Latn', 'nus_Latn', 'nya_Latn', 'oci_Latn', 'gaz_Latn', 'ory_Orya', 'pag_Latn', 'pan_Guru', 'pap_Latn', 'pol_Latn', 'por_Latn', 'prs_Arab', 'pbt_Arab', 'quy_Latn', 'ron_Latn', 'run_Latn', 'rus_Cyrl', 'sag_Latn', 'san_Deva', 'sat_Beng', 'scn_Latn', 'shn_Mymr', 'sin_Sinh', 'slk_Latn', 'slv_Latn', 'smo_Latn', 'sna_Latn', 'snd_Arab', 'som_Latn', 'sot_Latn', 'spa_Latn', 'als_Latn', 'srd_Latn', 'srp_Cyrl', 
'ssw_Latn', 'sun_Latn', 'swe_Latn', 'swh_Latn', 'szl_Latn', 'tam_Taml', 'tat_Cyrl', 'tel_Telu', 'tgk_Cyrl', 'tgl_Latn', 'tha_Thai', 'tir_Ethi', 'taq_Latn', 'taq_Tfng', 'tpi_Latn', 'tsn_Latn', 'tso_Latn', 'tuk_Latn', 'tum_Latn', 'tur_Latn', 'twi_Latn', 'tzm_Tfng', 'uig_Arab', 'ukr_Cyrl', 'umb_Latn', 'urd_Arab', 'uzn_Latn', 'vec_Latn', 'vie_Latn', 'war_Latn', 'wol_Latn', 'xho_Latn', 'ydd_Hebr', 'yor_Latn', 'yue_Hant', 'zho_Hans', 'zho_Hant', 'zul_Latn'] # fmt: skip
  22. class NllbTokenizer(TokenizersBackend):
  23. """
  24. Construct an NLLB tokenizer (backed by HuggingFace's *tokenizers* library). Based on
  25. [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models).
  26. This tokenizer inherits from [`TokenizersBackend`] which contains most of the main methods. Users should
  27. refer to this superclass for more information regarding those methods.
  28. The tokenization method is `<tokens> <eos> <language code>` for source language documents, and `<language code>
  29. <tokens> <eos>` for target language documents.
  30. Examples:
  31. ```python
  32. >>> from transformers import NllbTokenizer
  33. >>> tokenizer = NllbTokenizer.from_pretrained(
  34. ... "facebook/nllb-200-distilled-600M", src_lang="eng_Latn", tgt_lang="fra_Latn"
  35. ... )
  36. >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
  37. >>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie."
  38. >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt")
  39. ```
  40. Args:
  41. vocab_file (`str`, *optional*):
  42. Path to the vocabulary file.
  43. bos_token (`str`, *optional*, defaults to `"<s>"`):
  44. The beginning of sequence token that was used during pretraining.
  45. eos_token (`str`, *optional*, defaults to `"</s>"`):
  46. The end of sequence token.
  47. sep_token (`str`, *optional*, defaults to `"</s>"`):
  48. The separator token.
  49. cls_token (`str`, *optional*, defaults to `"<s>"`):
  50. The classifier token.
  51. unk_token (`str`, *optional*, defaults to `"<unk>"`):
  52. The unknown token.
  53. pad_token (`str`, *optional*, defaults to `"<pad>"`):
  54. The token used for padding.
  55. mask_token (`str`, *optional*, defaults to `"<mask>"`):
  56. The token used for masking values.
  57. src_lang (`str`, *optional*):
  58. The language to use as source language for translation.
  59. tgt_lang (`str`, *optional*):
  60. The language to use as target language for translation.
  61. legacy_behaviour (`bool`, *optional*, defaults to `False`):
  62. Whether to use legacy behaviour (suffix pattern) or new behaviour (prefix pattern).
  63. """
  64. vocab_files_names = VOCAB_FILES_NAMES
  65. model_input_names = ["input_ids", "attention_mask"]
  66. model = BPE
  67. prefix_tokens: list[int] = []
  68. suffix_tokens: list[int] = []
  69. def __init__(
  70. self,
  71. vocab: str | dict[str, int] | None = None,
  72. merges: str | list[str] | None = None,
  73. bos_token="<s>",
  74. eos_token="</s>",
  75. sep_token="</s>",
  76. cls_token="<s>",
  77. unk_token="<unk>",
  78. pad_token="<pad>",
  79. mask_token="<mask>",
  80. src_lang=None,
  81. tgt_lang=None,
  82. _spm_precompiled_charsmap: str | None = None,
  83. additional_special_tokens=None,
  84. extra_special_tokens=None,
  85. legacy_behaviour=False,
  86. **kwargs,
  87. ):
  88. # V5: extra_special_tokens takes precedence over additional_special_tokens (deprecated)
  89. # Handle case where both are passed (ie. from config and user override)
  90. if extra_special_tokens is not None:
  91. additional_special_tokens = extra_special_tokens
  92. elif additional_special_tokens is None:
  93. additional_special_tokens = FAIRSEQ_LANGUAGE_CODES
  94. mask_token = (
  95. AddedToken(mask_token, normalized=True, lstrip=True, special=True)
  96. if isinstance(mask_token, str)
  97. else mask_token
  98. )
  99. self.legacy_behaviour = legacy_behaviour
  100. if vocab is None:
  101. vocab = {
  102. str(bos_token): 0,
  103. str(pad_token): 1,
  104. str(eos_token): 2,
  105. str(unk_token): 3,
  106. }
  107. self._vocab = vocab
  108. self._merges = merges or []
  109. self._tokenizer = Tokenizer(
  110. BPE(
  111. vocab=self._vocab,
  112. merges=self._merges,
  113. dropout=None,
  114. unk_token=str(unk_token),
  115. fuse_unk=True,
  116. byte_fallback=False,
  117. )
  118. )
  119. if _spm_precompiled_charsmap is not None:
  120. self._tokenizer.normalizer = normalizers.Sequence(
  121. [
  122. normalizers.Precompiled(_spm_precompiled_charsmap),
  123. normalizers.Replace(Regex(r" {2,}"), " "),
  124. ]
  125. )
  126. self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True)
  127. self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)
  128. super().__init__(
  129. bos_token=bos_token,
  130. eos_token=eos_token,
  131. sep_token=sep_token,
  132. cls_token=cls_token,
  133. unk_token=unk_token,
  134. pad_token=pad_token,
  135. src_lang=src_lang,
  136. tgt_lang=tgt_lang,
  137. mask_token=mask_token,
  138. extra_special_tokens=additional_special_tokens,
  139. legacy_behaviour=legacy_behaviour,
  140. **kwargs,
  141. )
  142. # Build fairseq mappings for backward compatibility
  143. self.fairseq_offset = 1
  144. self.fairseq_tokens_to_ids = {
  145. "<s>": 0,
  146. "<pad>": 1,
  147. "</s>": 2,
  148. "<unk>": 3,
  149. }
  150. self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
  151. self._src_lang = src_lang if src_lang is not None else "eng_Latn"
  152. self.cur_lang_code = self.convert_tokens_to_ids(self._src_lang)
  153. self.tgt_lang = tgt_lang
  154. self.set_src_lang_special_tokens(self._src_lang)
  155. @property
  156. def src_lang(self) -> str:
  157. return self._src_lang
  158. @src_lang.setter
  159. def src_lang(self, new_src_lang: str) -> None:
  160. self._src_lang = new_src_lang
  161. self.set_src_lang_special_tokens(self._src_lang)
  162. def _build_translation_inputs(
  163. self, raw_inputs, return_tensors: str, src_lang: str | None, tgt_lang: str | None, **extra_kwargs
  164. ):
  165. """Used by translation pipeline, to prepare inputs for the generate function"""
  166. if src_lang is None or tgt_lang is None:
  167. raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
  168. self.src_lang = src_lang
  169. inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
  170. tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
  171. inputs["forced_bos_token_id"] = tgt_lang_id
  172. return inputs
  173. def prepare_seq2seq_batch(
  174. self,
  175. src_texts: list[str],
  176. src_lang: str = "eng_Latn",
  177. tgt_texts: list[str] | None = None,
  178. tgt_lang: str = "fra_Latn",
  179. max_length: int | None = None,
  180. max_target_length: int | None = None,
  181. padding: str = "longest",
  182. return_tensors: str | None = None,
  183. truncation: bool = True,
  184. **kwargs,
  185. ) -> BatchEncoding:
  186. self.src_lang = src_lang
  187. self.tgt_lang = tgt_lang
  188. if max_length is None:
  189. max_length = self.model_max_length
  190. model_inputs = self(
  191. src_texts,
  192. add_special_tokens=True,
  193. return_tensors=return_tensors,
  194. max_length=max_length,
  195. padding=padding,
  196. truncation=truncation,
  197. **kwargs,
  198. )
  199. if tgt_texts is None:
  200. return model_inputs
  201. # Process tgt_texts
  202. if max_target_length is None:
  203. max_target_length = max_length
  204. # Switch to target mode to set the right special tokens
  205. self._switch_to_target_mode()
  206. labels = self(
  207. tgt_texts,
  208. add_special_tokens=True,
  209. return_tensors=return_tensors,
  210. padding=padding,
  211. max_length=max_target_length,
  212. truncation=truncation,
  213. **kwargs,
  214. )
  215. model_inputs["labels"] = labels["input_ids"]
  216. # Switch back to input mode
  217. self._switch_to_input_mode()
  218. return model_inputs
  219. def _switch_to_input_mode(self):
  220. return self.set_src_lang_special_tokens(self.src_lang)
  221. def _switch_to_target_mode(self):
  222. if self.tgt_lang is None:
  223. self.tgt_lang = self._src_lang
  224. return self.set_tgt_lang_special_tokens(self.tgt_lang)
  225. def set_src_lang_special_tokens(self, src_lang) -> None:
  226. """Reset the special tokens to the source lang setting.
  227. - In legacy mode: No prefix and suffix=[eos, src_lang_code].
  228. - In default mode: Prefix=[src_lang_code], suffix = [eos]
  229. """
  230. self.cur_lang_code = self.convert_tokens_to_ids(src_lang)
  231. lang_code_token = src_lang
  232. if self.legacy_behaviour:
  233. self.prefix_tokens = []
  234. self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
  235. self._tokenizer.post_processor = processors.TemplateProcessing(
  236. single=["$A", self.eos_token, lang_code_token],
  237. pair=["$A", "$B", self.eos_token, lang_code_token],
  238. special_tokens=[(self.eos_token, self.eos_token_id), (lang_code_token, self.cur_lang_code)],
  239. )
  240. else:
  241. self.prefix_tokens = [self.cur_lang_code]
  242. self.suffix_tokens = [self.eos_token_id]
  243. self._tokenizer.post_processor = processors.TemplateProcessing(
  244. single=[lang_code_token, "$A", self.eos_token],
  245. pair=[lang_code_token, "$A", "$B", self.eos_token],
  246. special_tokens=[(self.eos_token, self.eos_token_id), (lang_code_token, self.cur_lang_code)],
  247. )
  248. def set_tgt_lang_special_tokens(self, lang: str) -> None:
  249. """Reset the special tokens to the target lang setting.
  250. - In legacy mode: No prefix and suffix=[eos, tgt_lang_code].
  251. - In default mode: Prefix=[tgt_lang_code], suffix = [eos]
  252. """
  253. self.cur_lang_code = self.convert_tokens_to_ids(lang)
  254. lang_code_token = lang
  255. if self.legacy_behaviour:
  256. self.prefix_tokens = []
  257. self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
  258. self._tokenizer.post_processor = processors.TemplateProcessing(
  259. single=["$A", self.eos_token, lang_code_token],
  260. pair=["$A", "$B", self.eos_token, lang_code_token],
  261. special_tokens=[(self.eos_token, self.eos_token_id), (lang_code_token, self.cur_lang_code)],
  262. )
  263. else:
  264. self.prefix_tokens = [self.cur_lang_code]
  265. self.suffix_tokens = [self.eos_token_id]
  266. self._tokenizer.post_processor = processors.TemplateProcessing(
  267. single=[lang_code_token, "$A", self.eos_token],
  268. pair=[lang_code_token, "$A", "$B", self.eos_token],
  269. special_tokens=[(self.eos_token, self.eos_token_id), (lang_code_token, self.cur_lang_code)],
  270. )
  271. __all__ = ["NllbTokenizer"]