tokenization_layoutlmv2.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868
  1. # Copyright 2021 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Tokenization class for LayoutLMv2. Based on WordPiece.
  16. """
  17. from tokenizers import Tokenizer, decoders, models, normalizers, pre_tokenizers, processors
  18. from ...tokenization_utils_base import (
  19. BatchEncoding,
  20. EncodedInput,
  21. PaddingStrategy,
  22. PreTokenizedInput,
  23. TensorType,
  24. TextInput,
  25. TextInputPair,
  26. TruncationStrategy,
  27. )
  28. from ...tokenization_utils_tokenizers import TokenizersBackend
  29. from ...utils import add_end_docstrings, logging
  30. logger = logging.get_logger(__name__)
  31. VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
  32. # Docstring constants for encode methods
  33. LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING = r"""
  34. add_special_tokens (`bool`, *optional*, defaults to `True`):
  35. Whether or not to encode the sequences with the special tokens relative to their model.
  36. padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
  37. Activates and controls padding. Accepts the following values:
  38. - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
  39. sequence if provided).
  40. - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
  41. acceptable input length for the model if that argument is not provided.
  42. - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
  43. lengths).
  44. truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
  45. Activates and controls truncation. Accepts the following values:
  46. - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
  47. to the maximum acceptable input length for the model if that argument is not provided. This will
  48. truncate token by token, removing a token from the longest sequence in the pair if a pair of
  49. sequences (or a batch of pairs) is provided.
  50. - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
  51. maximum acceptable input length for the model if that argument is not provided. This will only
  52. truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
  53. - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
  54. maximum acceptable input length for the model if that argument is not provided. This will only
  55. truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
  56. - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
  57. greater than the model maximum admissible input size).
  58. max_length (`int`, *optional*):
  59. Controls the maximum length to use by one of the truncation/padding parameters.
  60. If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
  61. is required by one of the truncation/padding parameters. If the model has no specific maximum input
  62. length (like XLNet) truncation/padding to a maximum length will be deactivated.
  63. stride (`int`, *optional*, defaults to 0):
  64. If set to a number along with `max_length`, the overflowing tokens returned when
  65. `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
  66. returned to provide some overlap between truncated and overflowing sequences. The value of this
  67. argument defines the number of overlapping tokens.
  68. pad_to_multiple_of (`int`, *optional*):
  69. If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
  70. the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
  71. """
  72. LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
  73. return_token_type_ids (`bool`, *optional*):
  74. Whether to return token type IDs. If left to the default, will return the token type IDs according to
  75. the specific tokenizer's default, defined by the `return_outputs` attribute.
  76. [What are token type IDs?](../glossary#token-type-ids)
  77. return_attention_mask (`bool`, *optional*):
  78. Whether to return the attention mask. If left to the default, will return the attention mask according
  79. to the specific tokenizer's default, defined by the `return_outputs` attribute.
  80. [What are attention masks?](../glossary#attention-mask)
  81. return_overflowing_tokens (`bool`, *optional*, defaults to `False`):
  82. Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch
  83. of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead
  84. of returning overflowing tokens.
  85. return_special_tokens_mask (`bool`, *optional*, defaults to `False`):
  86. Whether or not to return special tokens mask information.
  87. return_offsets_mapping (`bool`, *optional*, defaults to `False`):
  88. Whether or not to return `(char_start, char_end)` for each token.
  89. This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`], if using
  90. Python's tokenizer, this method will raise `NotImplementedError`.
  91. return_length (`bool`, *optional*, defaults to `False`):
  92. Whether or not to return the lengths of the encoded inputs.
  93. verbose (`bool`, *optional*, defaults to `True`):
  94. Whether or not to print more information and warnings.
  95. **kwargs: passed to the `self.tokenize()` method
  96. """
  97. class LayoutLMv2Tokenizer(TokenizersBackend):
  98. r"""
  99. Construct a "fast" LayoutLMv2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece.
  100. This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
  101. refer to this superclass for more information regarding those methods.
  102. Args:
  103. vocab_file (`str`):
  104. File containing the vocabulary.
  105. do_lower_case (`bool`, *optional*, defaults to `True`):
  106. Whether or not to lowercase the input when tokenizing.
  107. unk_token (`str`, *optional*, defaults to `"[UNK]"`):
  108. The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
  109. token instead.
  110. sep_token (`str`, *optional*, defaults to `"[SEP]"`):
  111. The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
  112. sequence classification or for a text and a question for question answering. It is also used as the last
  113. token of a sequence built with special tokens.
  114. pad_token (`str`, *optional*, defaults to `"[PAD]"`):
  115. The token used for padding, for example when batching sequences of different lengths.
  116. cls_token (`str`, *optional*, defaults to `"[CLS]"`):
  117. The classifier token which is used when doing sequence classification (classification of the whole sequence
  118. instead of per-token classification). It is the first token of the sequence when built with special tokens.
  119. mask_token (`str`, *optional*, defaults to `"[MASK]"`):
  120. The token used for masking values. This is the token used when training this model with masked language
  121. modeling. This is the token which the model will try to predict.
  122. cls_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
  123. The bounding box to use for the special [CLS] token.
  124. sep_token_box (`List[int]`, *optional*, defaults to `[1000, 1000, 1000, 1000]`):
  125. The bounding box to use for the special [SEP] token.
  126. pad_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
  127. The bounding box to use for the special [PAD] token.
  128. pad_token_label (`int`, *optional*, defaults to -100):
  129. The label to use for padding tokens. Defaults to -100, which is the `ignore_index` of PyTorch's
  130. CrossEntropyLoss.
  131. only_label_first_subword (`bool`, *optional*, defaults to `True`):
  132. Whether or not to only label the first subword, in case word labels are provided.
  133. tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
  134. Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this
  135. issue](https://github.com/huggingface/transformers/issues/328)).
  136. strip_accents (`bool`, *optional*):
  137. Whether or not to strip all accents. If this option is not specified, then it will be determined by the
  138. value for `lowercase` (as in the original LayoutLMv2).
  139. """
  140. vocab_files_names = VOCAB_FILES_NAMES
  141. model = models.WordPiece
  142. model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
  143. def __init__(
  144. self,
  145. vocab: str | dict[str, int] | None = None,
  146. do_lower_case=True,
  147. unk_token="[UNK]",
  148. sep_token="[SEP]",
  149. pad_token="[PAD]",
  150. cls_token="[CLS]",
  151. mask_token="[MASK]",
  152. cls_token_box=[0, 0, 0, 0],
  153. sep_token_box=[1000, 1000, 1000, 1000],
  154. pad_token_box=[0, 0, 0, 0],
  155. pad_token_label=-100,
  156. only_label_first_subword=True,
  157. tokenize_chinese_chars=True,
  158. strip_accents=None,
  159. model_max_length=512,
  160. **kwargs,
  161. ):
  162. self.do_lower_case = do_lower_case
  163. if vocab is not None:
  164. self._vocab = vocab
  165. else:
  166. self._vocab = {
  167. str(pad_token): 0,
  168. str(unk_token): 1,
  169. str(cls_token): 2,
  170. str(sep_token): 3,
  171. str(mask_token): 4,
  172. }
  173. self._tokenizer = Tokenizer(models.WordPiece(vocab=self._vocab, unk_token=str(unk_token)))
  174. self._tokenizer.normalizer = normalizers.BertNormalizer(
  175. clean_text=True,
  176. handle_chinese_chars=tokenize_chinese_chars,
  177. strip_accents=strip_accents,
  178. lowercase=do_lower_case,
  179. )
  180. self._tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
  181. self._tokenizer.decoder = decoders.WordPiece(prefix="##")
  182. super().__init__(
  183. do_lower_case=do_lower_case,
  184. unk_token=unk_token,
  185. sep_token=sep_token,
  186. pad_token=pad_token,
  187. cls_token=cls_token,
  188. mask_token=mask_token,
  189. cls_token_box=cls_token_box,
  190. sep_token_box=sep_token_box,
  191. pad_token_box=pad_token_box,
  192. pad_token_label=pad_token_label,
  193. only_label_first_subword=only_label_first_subword,
  194. tokenize_chinese_chars=tokenize_chinese_chars,
  195. strip_accents=strip_accents,
  196. model_max_length=model_max_length,
  197. **kwargs,
  198. )
  199. self.cls_token_box = cls_token_box
  200. self.sep_token_box = sep_token_box
  201. self.pad_token_box = pad_token_box
  202. self.pad_token_label = pad_token_label
  203. self.only_label_first_subword = only_label_first_subword
  204. # Now set post_processor with actual token IDs
  205. cls = str(self.cls_token)
  206. sep = str(self.sep_token)
  207. cls_token_id = self.cls_token_id
  208. sep_token_id = self.sep_token_id
  209. self._tokenizer.post_processor = processors.TemplateProcessing(
  210. single=f"{cls}:0 $A:0 {sep}:0",
  211. pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
  212. special_tokens=[
  213. (cls, cls_token_id),
  214. (sep, sep_token_id),
  215. ],
  216. )
  217. @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
  218. def __call__(
  219. self,
  220. text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput],
  221. text_pair: PreTokenizedInput | list[PreTokenizedInput] | None = None,
  222. boxes: list[list[int]] | list[list[list[int]]] | None = None,
  223. word_labels: list[int] | list[list[int]] | None = None,
  224. add_special_tokens: bool = True,
  225. padding: bool | str | PaddingStrategy = False,
  226. truncation: bool | str | TruncationStrategy = None,
  227. max_length: int | None = None,
  228. stride: int = 0,
  229. pad_to_multiple_of: int | None = None,
  230. padding_side: str | None = None,
  231. return_tensors: str | TensorType | None = None,
  232. return_token_type_ids: bool | None = None,
  233. return_attention_mask: bool | None = None,
  234. return_overflowing_tokens: bool = False,
  235. return_special_tokens_mask: bool = False,
  236. return_offsets_mapping: bool = False,
  237. return_length: bool = False,
  238. verbose: bool = True,
  239. **kwargs,
  240. ) -> BatchEncoding:
  241. """
  242. Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
  243. sequences with word-level normalized bounding boxes and optional labels.
  244. Args:
  245. text (`str`, `List[str]`, `List[List[str]]`):
  246. The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings
  247. (words of a single example or questions of a batch of examples) or a list of list of strings (batch of
  248. words).
  249. text_pair (`List[str]`, `List[List[str]]`):
  250. The sequence or batch of sequences to be encoded. Each sequence should be a list of strings
  251. (pretokenized string).
  252. boxes (`List[List[int]]`, `List[List[List[int]]]`):
  253. Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale.
  254. word_labels (`List[int]`, `List[List[int]]`, *optional*):
  255. Word-level integer labels (for token classification tasks such as FUNSD, CORD).
  256. """
  257. # Input type checking for clearer error
  258. def _is_valid_text_input(t):
  259. if isinstance(t, str):
  260. # Strings are fine
  261. return True
  262. elif isinstance(t, (list, tuple)):
  263. # List are fine as long as they are...
  264. if len(t) == 0:
  265. # ... empty
  266. return True
  267. elif isinstance(t[0], str):
  268. # ... list of strings
  269. return True
  270. elif isinstance(t[0], (list, tuple)):
  271. # ... list with an empty list or with a list of strings
  272. return len(t[0]) == 0 or isinstance(t[0][0], str)
  273. else:
  274. return False
  275. else:
  276. return False
  277. if text_pair is not None:
  278. # in case text + text_pair are provided, text = questions, text_pair = words
  279. if not _is_valid_text_input(text):
  280. raise ValueError("text input must of type `str` (single example) or `List[str]` (batch of examples). ")
  281. if not isinstance(text_pair, (list, tuple)):
  282. raise ValueError(
  283. "Words must be of type `List[str]` (single pretokenized example), "
  284. "or `List[List[str]]` (batch of pretokenized examples)."
  285. )
  286. else:
  287. # in case only text is provided => must be words
  288. if not isinstance(text, (list, tuple)):
  289. raise ValueError(
  290. "Words must be of type `List[str]` (single pretokenized example), "
  291. "or `List[List[str]]` (batch of pretokenized examples)."
  292. )
  293. if text_pair is not None:
  294. is_batched = isinstance(text, (list, tuple))
  295. else:
  296. is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
  297. words = text if text_pair is None else text_pair
  298. if boxes is None:
  299. raise ValueError("You must provide corresponding bounding boxes")
  300. if is_batched:
  301. if len(words) != len(boxes):
  302. raise ValueError("You must provide words and boxes for an equal amount of examples")
  303. for words_example, boxes_example in zip(words, boxes):
  304. if len(words_example) != len(boxes_example):
  305. raise ValueError("You must provide as many words as there are bounding boxes")
  306. else:
  307. if len(words) != len(boxes):
  308. raise ValueError("You must provide as many words as there are bounding boxes")
  309. if is_batched:
  310. if text_pair is not None and len(text) != len(text_pair):
  311. raise ValueError(
  312. f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
  313. f" {len(text_pair)}."
  314. )
  315. batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
  316. is_pair = bool(text_pair is not None)
  317. return self.batch_encode_plus(
  318. batch_text_or_text_pairs=batch_text_or_text_pairs,
  319. is_pair=is_pair,
  320. boxes=boxes,
  321. word_labels=word_labels,
  322. add_special_tokens=add_special_tokens,
  323. padding=padding,
  324. truncation=truncation,
  325. max_length=max_length,
  326. stride=stride,
  327. pad_to_multiple_of=pad_to_multiple_of,
  328. padding_side=padding_side,
  329. return_tensors=return_tensors,
  330. return_token_type_ids=return_token_type_ids,
  331. return_attention_mask=return_attention_mask,
  332. return_overflowing_tokens=return_overflowing_tokens,
  333. return_special_tokens_mask=return_special_tokens_mask,
  334. return_offsets_mapping=return_offsets_mapping,
  335. return_length=return_length,
  336. verbose=verbose,
  337. **kwargs,
  338. )
  339. else:
  340. return self.encode_plus(
  341. text=text,
  342. text_pair=text_pair,
  343. boxes=boxes,
  344. word_labels=word_labels,
  345. add_special_tokens=add_special_tokens,
  346. padding=padding,
  347. truncation=truncation,
  348. max_length=max_length,
  349. stride=stride,
  350. pad_to_multiple_of=pad_to_multiple_of,
  351. padding_side=padding_side,
  352. return_tensors=return_tensors,
  353. return_token_type_ids=return_token_type_ids,
  354. return_attention_mask=return_attention_mask,
  355. return_overflowing_tokens=return_overflowing_tokens,
  356. return_special_tokens_mask=return_special_tokens_mask,
  357. return_offsets_mapping=return_offsets_mapping,
  358. return_length=return_length,
  359. verbose=verbose,
  360. **kwargs,
  361. )
  362. @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
  363. def batch_encode_plus(
  364. self,
  365. batch_text_or_text_pairs: list[TextInput] | list[TextInputPair] | list[PreTokenizedInput],
  366. is_pair: bool | None = None,
  367. boxes: list[list[list[int]]] | None = None,
  368. word_labels: list[int] | list[list[int]] | None = None,
  369. add_special_tokens: bool = True,
  370. padding: bool | str | PaddingStrategy = False,
  371. truncation: bool | str | TruncationStrategy = None,
  372. max_length: int | None = None,
  373. stride: int = 0,
  374. pad_to_multiple_of: int | None = None,
  375. padding_side: str | None = None,
  376. return_tensors: str | TensorType | None = None,
  377. return_token_type_ids: bool | None = None,
  378. return_attention_mask: bool | None = None,
  379. return_overflowing_tokens: bool = False,
  380. return_special_tokens_mask: bool = False,
  381. return_offsets_mapping: bool = False,
  382. return_length: bool = False,
  383. verbose: bool = True,
  384. **kwargs,
  385. ) -> BatchEncoding:
  386. # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
  387. padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
  388. padding=padding,
  389. truncation=truncation,
  390. max_length=max_length,
  391. pad_to_multiple_of=pad_to_multiple_of,
  392. verbose=verbose,
  393. **kwargs,
  394. )
  395. return self._batch_encode_plus(
  396. batch_text_or_text_pairs=batch_text_or_text_pairs,
  397. is_pair=is_pair,
  398. boxes=boxes,
  399. word_labels=word_labels,
  400. add_special_tokens=add_special_tokens,
  401. padding_strategy=padding_strategy,
  402. truncation_strategy=truncation_strategy,
  403. max_length=max_length,
  404. stride=stride,
  405. pad_to_multiple_of=pad_to_multiple_of,
  406. padding_side=padding_side,
  407. return_tensors=return_tensors,
  408. return_token_type_ids=return_token_type_ids,
  409. return_attention_mask=return_attention_mask,
  410. return_overflowing_tokens=return_overflowing_tokens,
  411. return_special_tokens_mask=return_special_tokens_mask,
  412. return_offsets_mapping=return_offsets_mapping,
  413. return_length=return_length,
  414. verbose=verbose,
  415. **kwargs,
  416. )
  417. def tokenize(self, text: str, pair: str | None = None, add_special_tokens: bool = False, **kwargs) -> list[str]:
  418. batched_input = [(text, pair)] if pair else [text]
  419. encodings = self._tokenizer.encode_batch(
  420. batched_input, add_special_tokens=add_special_tokens, is_pretokenized=False, **kwargs
  421. )
  422. return encodings[0].tokens if encodings else []
  423. @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
  424. def encode_plus(
  425. self,
  426. text: TextInput | PreTokenizedInput,
  427. text_pair: PreTokenizedInput | None = None,
  428. boxes: list[list[int]] | None = None,
  429. word_labels: list[int] | None = None,
  430. add_special_tokens: bool = True,
  431. padding: bool | str | PaddingStrategy = False,
  432. truncation: bool | str | TruncationStrategy = None,
  433. max_length: int | None = None,
  434. stride: int = 0,
  435. pad_to_multiple_of: int | None = None,
  436. padding_side: str | None = None,
  437. return_tensors: str | TensorType | None = None,
  438. return_token_type_ids: bool | None = None,
  439. return_attention_mask: bool | None = None,
  440. return_overflowing_tokens: bool = False,
  441. return_special_tokens_mask: bool = False,
  442. return_offsets_mapping: bool = False,
  443. return_length: bool = False,
  444. verbose: bool = True,
  445. **kwargs,
  446. ) -> BatchEncoding:
  447. """
  448. Tokenize and prepare for the model a sequence or a pair of sequences. .. warning:: This method is deprecated,
  449. `__call__` should be used instead.
  450. Args:
  451. text (`str`, `List[str]`, `List[List[str]]`):
  452. The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings.
  453. text_pair (`List[str]` or `List[int]`, *optional*):
  454. Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a
  455. list of list of strings (words of a batch of examples).
  456. """
  457. # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
  458. padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
  459. padding=padding,
  460. truncation=truncation,
  461. max_length=max_length,
  462. pad_to_multiple_of=pad_to_multiple_of,
  463. verbose=verbose,
  464. **kwargs,
  465. )
  466. return self._encode_plus(
  467. text=text,
  468. boxes=boxes,
  469. text_pair=text_pair,
  470. word_labels=word_labels,
  471. add_special_tokens=add_special_tokens,
  472. padding_strategy=padding_strategy,
  473. truncation_strategy=truncation_strategy,
  474. max_length=max_length,
  475. stride=stride,
  476. pad_to_multiple_of=pad_to_multiple_of,
  477. padding_side=padding_side,
  478. return_tensors=return_tensors,
  479. return_token_type_ids=return_token_type_ids,
  480. return_attention_mask=return_attention_mask,
  481. return_overflowing_tokens=return_overflowing_tokens,
  482. return_special_tokens_mask=return_special_tokens_mask,
  483. return_offsets_mapping=return_offsets_mapping,
  484. return_length=return_length,
  485. verbose=verbose,
  486. **kwargs,
  487. )
  488. def _batch_encode_plus(
  489. self,
  490. batch_text_or_text_pairs: list[TextInput] | list[TextInputPair] | list[PreTokenizedInput],
  491. is_pair: bool | None = None,
  492. boxes: list[list[list[int]]] | None = None,
  493. word_labels: list[list[int]] | None = None,
  494. add_special_tokens: bool = True,
  495. padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
  496. truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
  497. max_length: int | None = None,
  498. stride: int = 0,
  499. pad_to_multiple_of: int | None = None,
  500. padding_side: str | None = None,
  501. return_tensors: str | None = None,
  502. return_token_type_ids: bool | None = None,
  503. return_attention_mask: bool | None = None,
  504. return_overflowing_tokens: bool = False,
  505. return_special_tokens_mask: bool = False,
  506. return_offsets_mapping: bool = False,
  507. return_length: bool = False,
  508. verbose: bool = True,
  509. ) -> BatchEncoding:
  510. if not isinstance(batch_text_or_text_pairs, list):
  511. raise TypeError(f"batch_text_or_text_pairs has to be a list (got {type(batch_text_or_text_pairs)})")
  512. # Set the truncation and padding strategy and restore the initial configuration
  513. self.set_truncation_and_padding(
  514. padding_strategy=padding_strategy,
  515. truncation_strategy=truncation_strategy,
  516. max_length=max_length,
  517. stride=stride,
  518. pad_to_multiple_of=pad_to_multiple_of,
  519. padding_side=padding_side,
  520. )
  521. if is_pair:
  522. batch_text_or_text_pairs = [(text.split(), text_pair) for text, text_pair in batch_text_or_text_pairs]
  523. encodings = self._tokenizer.encode_batch(
  524. batch_text_or_text_pairs,
  525. add_special_tokens=add_special_tokens,
  526. is_pretokenized=True, # we set this to True as LayoutLMv2 always expects pretokenized inputs
  527. )
  528. # Convert encoding to dict
  529. # `Tokens` has type: Tuple[
  530. # List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]],
  531. # List[EncodingFast]
  532. # ]
  533. # with nested dimensions corresponding to batch, overflows, sequence length
  534. tokens_and_encodings = [
  535. self._convert_encoding(
  536. encoding=encoding,
  537. return_token_type_ids=return_token_type_ids,
  538. return_attention_mask=return_attention_mask,
  539. return_overflowing_tokens=return_overflowing_tokens,
  540. return_special_tokens_mask=return_special_tokens_mask,
  541. return_offsets_mapping=True
  542. if word_labels is not None
  543. else return_offsets_mapping, # we use offsets to create the labels
  544. return_length=return_length,
  545. verbose=verbose,
  546. )
  547. for encoding in encodings
  548. ]
  549. # Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension
  550. # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length)
  551. # (we say ~ because the number of overflow varies with the example in the batch)
  552. #
  553. # To match each overflowing sample with the original sample in the batch
  554. # we add an overflow_to_sample_mapping array (see below)
  555. sanitized_tokens = {}
  556. for key in tokens_and_encodings[0][0]:
  557. stack = [e for item, _ in tokens_and_encodings for e in item[key]]
  558. sanitized_tokens[key] = stack
  559. sanitized_encodings = [e for _, item in tokens_and_encodings for e in item]
  560. # If returning overflowing tokens, we need to return a mapping
  561. # from the batch idx to the original sample
  562. if return_overflowing_tokens:
  563. overflow_to_sample_mapping = []
  564. for i, (toks, _) in enumerate(tokens_and_encodings):
  565. overflow_to_sample_mapping += [i] * len(toks["input_ids"])
  566. sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping
  567. for input_ids in sanitized_tokens["input_ids"]:
  568. self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose)
  569. # create the token boxes
  570. token_boxes = []
  571. for batch_index in range(len(sanitized_tokens["input_ids"])):
  572. if return_overflowing_tokens:
  573. original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index]
  574. else:
  575. original_index = batch_index
  576. token_boxes_example = []
  577. for id, sequence_id, word_id in zip(
  578. sanitized_tokens["input_ids"][batch_index],
  579. sanitized_encodings[batch_index].sequence_ids,
  580. sanitized_encodings[batch_index].word_ids,
  581. ):
  582. if word_id is not None:
  583. if is_pair and sequence_id == 0:
  584. token_boxes_example.append(self.pad_token_box)
  585. else:
  586. token_boxes_example.append(boxes[original_index][word_id])
  587. else:
  588. if id == self.cls_token_id:
  589. token_boxes_example.append(self.cls_token_box)
  590. elif id == self.sep_token_id:
  591. token_boxes_example.append(self.sep_token_box)
  592. elif id == self.pad_token_id:
  593. token_boxes_example.append(self.pad_token_box)
  594. else:
  595. raise ValueError("Id not recognized")
  596. token_boxes.append(token_boxes_example)
  597. sanitized_tokens["bbox"] = token_boxes
  598. # optionally, create the labels
  599. if word_labels is not None:
  600. labels = []
  601. for batch_index in range(len(sanitized_tokens["input_ids"])):
  602. if return_overflowing_tokens:
  603. original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index]
  604. else:
  605. original_index = batch_index
  606. labels_example = []
  607. for id, offset, word_id in zip(
  608. sanitized_tokens["input_ids"][batch_index],
  609. sanitized_tokens["offset_mapping"][batch_index],
  610. sanitized_encodings[batch_index].word_ids,
  611. ):
  612. if word_id is not None:
  613. if self.only_label_first_subword:
  614. if offset[0] == 0:
  615. # Use the real label id for the first token of the word, and padding ids for the remaining tokens
  616. labels_example.append(word_labels[original_index][word_id])
  617. else:
  618. labels_example.append(self.pad_token_label)
  619. else:
  620. labels_example.append(word_labels[original_index][word_id])
  621. else:
  622. labels_example.append(self.pad_token_label)
  623. labels.append(labels_example)
  624. sanitized_tokens["labels"] = labels
  625. # finally, remove offsets if the user didn't want them
  626. if not return_offsets_mapping:
  627. del sanitized_tokens["offset_mapping"]
  628. return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors)
  def _encode_plus(
      self,
      text: TextInput | PreTokenizedInput,
      text_pair: PreTokenizedInput | None = None,
      boxes: list[list[int]] | None = None,
      word_labels: list[int] | None = None,
      add_special_tokens: bool = True,
      padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
      truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
      max_length: int | None = None,
      stride: int = 0,
      pad_to_multiple_of: int | None = None,
      padding_side: str | None = None,
      return_tensors: str | None = None,
      return_token_type_ids: bool | None = None,
      return_attention_mask: bool | None = None,
      return_overflowing_tokens: bool = False,
      return_special_tokens_mask: bool = False,
      return_offsets_mapping: bool = False,
      return_length: bool = False,
      verbose: bool = True,
      **kwargs,
  ) -> BatchEncoding:
      """
      Encode one example by wrapping it into a batch of one and delegating to `_batch_encode_plus`.

      Two input layouts are supported:
          1) only `text`, in which case `text` is expected to be a list of words;
          2) `text` + `text_pair`, in which case `text` is a string and `text_pair` a list of words.

      Args:
          text: The first (or only) sequence.
          text_pair: Optional second sequence. Any non-`None` value, including an empty list, makes this a
              pair encoding.
          boxes: Boxes for this example; wrapped in a one-element list before delegation.
          word_labels: Optional labels for this example; wrapped in a one-element list when given.
          All remaining arguments are forwarded unchanged to `_batch_encode_plus`.

      Returns:
          `BatchEncoding`: if `return_tensors` is `None` and `return_overflowing_tokens` is false, every value
          whose first element is a list is unwrapped (the leading batch axis is removed). Otherwise the batched
          output is returned unchanged.
      """
      # Decide pair-ness once so the input layout and the `is_pair` flag always agree. Previously an empty
      # `text_pair` was sent as a single input while `is_pair` was True.
      is_pair = text_pair is not None
      batched_input = [(text, text_pair)] if is_pair else [text]
      batched_boxes = [boxes]
      batched_word_labels = [word_labels] if word_labels is not None else None

      batched_output = self._batch_encode_plus(
          batched_input,
          is_pair=is_pair,
          boxes=batched_boxes,
          word_labels=batched_word_labels,
          add_special_tokens=add_special_tokens,
          padding_strategy=padding_strategy,
          truncation_strategy=truncation_strategy,
          max_length=max_length,
          stride=stride,
          pad_to_multiple_of=pad_to_multiple_of,
          padding_side=padding_side,
          return_tensors=return_tensors,
          return_token_type_ids=return_token_type_ids,
          return_attention_mask=return_attention_mask,
          return_overflowing_tokens=return_overflowing_tokens,
          return_special_tokens_mask=return_special_tokens_mask,
          return_offsets_mapping=return_offsets_mapping,
          return_length=return_length,
          verbose=verbose,
          **kwargs,
      )

      # Without tensors, drop the leading batch axis. Overflowing tokens yield several rows for this single
      # example, so the batch axis is kept in that case.
      if return_tensors is None and not return_overflowing_tokens:
          batched_output = BatchEncoding(
              {
                  key: value[0] if len(value) > 0 and isinstance(value[0], list) else value
                  for key, value in batched_output.items()
              },
              batched_output.encodings,
          )

      self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose)
      return batched_output
  def _pad(
      self,
      encoded_inputs: dict[str, EncodedInput] | BatchEncoding,
      max_length: int | None = None,
      padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
      pad_to_multiple_of: int | None = None,
      padding_side: str | None = None,
      return_attention_mask: bool | None = None,
  ) -> dict:
      """
      Pad a single encoded example on the left or right.

      The target length is `max_length` (or the example's own length for `PaddingStrategy.LONGEST`), optionally
      rounded up to a multiple of `pad_to_multiple_of`. Besides the main model input, the following fields are
      padded so that every per-token field keeps the same length:
          - `attention_mask`, when requested;
          - `token_type_ids`, `bbox`, `labels` and `special_tokens_mask`, when present.

      Args:
          encoded_inputs:
              Dictionary of tokenized inputs (`List[int]` values) for one example. It is modified in place.
          max_length:
              Target length used with `PaddingStrategy.MAX_LENGTH`. This method only pads; it never truncates, and a
              sequence already longer than the target is left unchanged.
          padding_strategy: PaddingStrategy to use for padding.
              - PaddingStrategy.LONGEST: pad to this example's own length (only `pad_to_multiple_of` can add padding)
              - PaddingStrategy.MAX_LENGTH: pad to `max_length`
              - PaddingStrategy.DO_NOT_PAD: do not pad (the default)
          pad_to_multiple_of: (optional) If set, round the target length up to a multiple of this value
              (useful for hardware that prefers aligned sequence lengths).
          padding_side:
              Either "right" or "left". Defaults to `self.padding_side`.
          return_attention_mask:
              (optional) Whether to create and pad `attention_mask`. Defaults to whether `"attention_mask"` is in
              `self.model_input_names`.

      Returns:
          `dict`: `encoded_inputs`, with the padded values.

      Raises:
          ValueError: If padding is needed and the resolved padding side is neither "right" nor "left".
      """
      # Load from model defaults
      if return_attention_mask is None:
          return_attention_mask = "attention_mask" in self.model_input_names

      main_input_name = self.model_input_names[0]
      required_input = encoded_inputs[main_input_name]

      if padding_strategy == PaddingStrategy.LONGEST:
          max_length = len(required_input)
      if max_length is not None and pad_to_multiple_of is not None and max_length % pad_to_multiple_of != 0:
          max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

      needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length

      # Initialize the attention mask if not present, even when no padding turns out to be needed.
      if return_attention_mask and "attention_mask" not in encoded_inputs:
          encoded_inputs["attention_mask"] = [1] * len(required_input)

      if not needs_to_be_padded:
          return encoded_inputs

      difference = max_length - len(required_input)
      padding_side = padding_side if padding_side is not None else self.padding_side
      if padding_side not in ("right", "left"):
          raise ValueError(f"Invalid padding side: {padding_side}")
      pad_on_right = padding_side == "right"

      # (field, pad value) pairs for every auxiliary field that must stay aligned with the main input.
      fields_to_pad = []
      if return_attention_mask:
          fields_to_pad.append(("attention_mask", 0))
      if "token_type_ids" in encoded_inputs:
          fields_to_pad.append(("token_type_ids", self.pad_token_type_id))
      if "bbox" in encoded_inputs:
          fields_to_pad.append(("bbox", self.pad_token_box))
      if "labels" in encoded_inputs:
          fields_to_pad.append(("labels", self.pad_token_label))
      if "special_tokens_mask" in encoded_inputs:
          # Padding positions count as special tokens.
          fields_to_pad.append(("special_tokens_mask", 1))

      for key, pad_value in fields_to_pad:
          padding = [pad_value] * difference
          encoded_inputs[key] = encoded_inputs[key] + padding if pad_on_right else padding + encoded_inputs[key]

      main_padding = [self.pad_token_id] * difference
      encoded_inputs[main_input_name] = (
          required_input + main_padding if pad_on_right else main_padding + required_input
      )
      return encoded_inputs
  def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
      """
      Build model inputs from a sequence or a pair of sequences by adding special tokens.

      A LayoutLMv2 (BERT-style) sequence has the following format:

      - single sequence: `[CLS] X [SEP]`
      - pair of sequences: `[CLS] A [SEP] B [SEP]`

      A pair is recognized by `token_ids_1 is not None`, so an empty second sequence still yields
      `[CLS] A [SEP] [SEP]`. This keeps special-token counts for `([], [])` probes correct.

      Args:
          token_ids_0 (`List[int]`):
              List of IDs to which the special tokens will be added.
          token_ids_1 (`List[int]`, *optional*):
              Optional second list of IDs for sequence pairs.

      Returns:
          `List[int]`: A new list of [input IDs](../glossary#input-ids) with the appropriate special tokens. The
          input lists are not modified.
      """
      output = [self.cls_token_id, *token_ids_0, self.sep_token_id]
      if token_ids_1 is not None:
          output += [*token_ids_1, self.sep_token_id]
      return output
  __all__ = ["LayoutLMv2Tokenizer"]

  # Backward-compatible alias: the former `LayoutLMv2TokenizerFast` name refers to the same class object.
  # NOTE(review): the alias is not listed in `__all__`, so `from ... import *` does not export it — confirm intended.
  LayoutLMv2TokenizerFast = LayoutLMv2Tokenizer