tokenization_layoutlmv3.py 42 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862
  1. # Copyright The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Tokenization class for LayoutLMv3. Same as LayoutLMv2, but RoBERTa-like BPE tokenization instead of WordPiece."""
  15. from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors
  16. from ...tokenization_utils_base import (
  17. BatchEncoding,
  18. EncodedInput,
  19. PaddingStrategy,
  20. PreTokenizedInput,
  21. TensorType,
  22. TextInput,
  23. TextInputPair,
  24. TruncationStrategy,
  25. )
  26. from ...tokenization_utils_tokenizers import TokenizersBackend
  27. from ...utils import add_end_docstrings, logging
  28. logger = logging.get_logger(__name__)
  29. VOCAB_FILES_NAMES = {
  30. "vocab_file": "vocab.json",
  31. "merges_file": "merges.txt",
  32. "tokenizer_file": "tokenizer.json",
  33. }
  34. # Docstring constants for encode methods
  35. LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING = r"""
  36. add_special_tokens (`bool`, *optional*, defaults to `True`):
  37. Whether or not to encode the sequences with the special tokens relative to their model.
  38. padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
  39. Activates and controls padding. Accepts the following values:
  40. - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
  41. sequence if provided).
  42. - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
  43. acceptable input length for the model if that argument is not provided.
  44. - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
  45. lengths).
  46. truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
  47. Activates and controls truncation. Accepts the following values:
  48. - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
  49. to the maximum acceptable input length for the model if that argument is not provided. This will
  50. truncate token by token, removing a token from the longest sequence in the pair if a pair of
  51. sequences (or a batch of pairs) is provided.
  52. - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
  53. maximum acceptable input length for the model if that argument is not provided. This will only
  54. truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
  55. - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
  56. maximum acceptable input length for the model if that argument is not provided. This will only
  57. truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
  58. - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
  59. greater than the model maximum admissible input size).
  60. max_length (`int`, *optional*):
  61. Controls the maximum length to use by one of the truncation/padding parameters.
  62. If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
  63. is required by one of the truncation/padding parameters. If the model has no specific maximum input
  64. length (like XLNet) truncation/padding to a maximum length will be deactivated.
  65. stride (`int`, *optional*, defaults to 0):
  66. If set to a number along with `max_length`, the overflowing tokens returned when
  67. `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
  68. returned to provide some overlap between truncated and overflowing sequences. The value of this
  69. argument defines the number of overlapping tokens.
  70. pad_to_multiple_of (`int`, *optional*):
  71. If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
  72. the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
  73. """
  74. LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
  75. return_token_type_ids (`bool`, *optional*):
  76. Whether to return token type IDs. If left to the default, will return the token type IDs according to
  77. the specific tokenizer's default, defined by the `return_outputs` attribute.
  78. [What are token type IDs?](../glossary#token-type-ids)
  79. return_attention_mask (`bool`, *optional*):
  80. Whether to return the attention mask. If left to the default, will return the attention mask according
  81. to the specific tokenizer's default, defined by the `return_outputs` attribute.
  82. [What are attention masks?](../glossary#attention-mask)
  83. return_overflowing_tokens (`bool`, *optional*, defaults to `False`):
  84. Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch
  85. of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead
  86. of returning overflowing tokens.
  87. return_special_tokens_mask (`bool`, *optional*, defaults to `False`):
  88. Whether or not to return special tokens mask information.
  89. return_offsets_mapping (`bool`, *optional*, defaults to `False`):
  90. Whether or not to return `(char_start, char_end)` for each token.
  91. This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`], if using
  92. Python's tokenizer, this method will raise `NotImplementedError`.
  93. return_length (`bool`, *optional*, defaults to `False`):
  94. Whether or not to return the lengths of the encoded inputs.
  95. verbose (`bool`, *optional*, defaults to `True`):
  96. Whether or not to print more information and warnings.
  97. **kwargs: passed to the `self.tokenize()` method
  98. """
  99. class LayoutLMv3Tokenizer(TokenizersBackend):
  100. r"""
  101. Construct a LayoutLMv3 tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level BPE.
  102. This tokenizer inherits from [`TokenizersBackend`] which contains most of the main methods. Users should
  103. refer to this superclass for more information regarding those methods.
  104. Args:
  105. errors (`str`, *optional*, defaults to `"replace"`):
  106. Paradigm to follow when decoding bytes to UTF-8. See
  107. [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
  108. bos_token (`str`, *optional*, defaults to `"<s>"`):
  109. The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
  110. eos_token (`str`, *optional*, defaults to `"</s>"`):
  111. The end of sequence token.
  112. sep_token (`str`, *optional*, defaults to `"</s>"`):
  113. The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
  114. sequence classification or for a text and a question for question answering. It is also used as the last
  115. token of a sequence built with special tokens.
  116. cls_token (`str`, *optional*, defaults to `"<s>"`):
  117. The classifier token which is used when doing sequence classification (classification of the whole sequence
  118. instead of per-token classification). It is the first token of the sequence when built with special tokens.
  119. unk_token (`str`, *optional*, defaults to `"<unk>"`):
  120. The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
  121. token instead.
  122. pad_token (`str`, *optional*, defaults to `"<pad>"`):
  123. The token used for padding, for example when batching sequences of different lengths.
  124. mask_token (`str`, *optional*, defaults to `"<mask>"`):
  125. The token used for masking values. This is the token used when training this model with masked language
  126. modeling. This is the token which the model will try to predict.
  127. add_prefix_space (`bool`, *optional*, defaults to `True`):
  128. Whether or not to add an initial space to the input. This allows to treat the leading word just as any
  129. other word.
  130. cls_token_box (`list[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
  131. The bounding box to use for the special [CLS] token.
  132. sep_token_box (`list[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
  133. The bounding box to use for the special [SEP] token.
  134. pad_token_box (`list[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
  135. The bounding box to use for the special [PAD] token.
  136. pad_token_label (`int`, *optional*, defaults to -100):
  137. The label to use for padding tokens. Defaults to -100, which is the `ignore_index` of PyTorch's
  138. CrossEntropyLoss.
  139. only_label_first_subword (`bool`, *optional*, defaults to `True`):
  140. Whether or not to only label the first subword, in case word labels are provided.
  141. vocab (`str` or `dict[str, int]`, *optional*):
  142. Custom vocabulary dictionary. If not provided, vocabulary is loaded from `vocab_file` when using
  143. `from_pretrained`.
  144. merges (`str` or `list[str]`, *optional*):
  145. Custom merges list. If not provided, merges are loaded from `merges_file` when using `from_pretrained`.
  146. """
  147. vocab_files_names = VOCAB_FILES_NAMES
  148. model_input_names = ["input_ids", "attention_mask", "bbox"]
  149. model = models.BPE
  150. def __init__(
  151. self,
  152. errors="replace",
  153. bos_token="<s>",
  154. eos_token="</s>",
  155. sep_token="</s>",
  156. cls_token="<s>",
  157. unk_token="<unk>",
  158. pad_token="<pad>",
  159. mask_token="<mask>",
  160. add_prefix_space=True,
  161. cls_token_box=[0, 0, 0, 0],
  162. sep_token_box=[0, 0, 0, 0],
  163. pad_token_box=[0, 0, 0, 0],
  164. pad_token_label=-100,
  165. only_label_first_subword=True,
  166. vocab: str | dict[str, int] | None = None,
  167. merges: str | list[str] | None = None,
  168. **kwargs,
  169. ):
  170. self.add_prefix_space = add_prefix_space
  171. self._vocab = vocab or {}
  172. self._merges = merges or []
  173. self._tokenizer = Tokenizer(
  174. models.BPE(
  175. vocab=self._vocab,
  176. merges=self._merges,
  177. dropout=None,
  178. continuing_subword_prefix="",
  179. end_of_word_suffix="",
  180. fuse_unk=False,
  181. )
  182. )
  183. self._tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
  184. self._tokenizer.decoder = decoders.ByteLevel()
  185. super().__init__(
  186. errors=errors,
  187. bos_token=bos_token,
  188. eos_token=eos_token,
  189. sep_token=sep_token,
  190. cls_token=cls_token,
  191. unk_token=unk_token,
  192. pad_token=pad_token,
  193. mask_token=mask_token,
  194. add_prefix_space=add_prefix_space,
  195. cls_token_box=cls_token_box,
  196. sep_token_box=sep_token_box,
  197. pad_token_box=pad_token_box,
  198. pad_token_label=pad_token_label,
  199. only_label_first_subword=only_label_first_subword,
  200. **kwargs,
  201. )
  202. # Now set post_processor with actual token IDs (RoBERTa-style)
  203. cls = str(self.cls_token)
  204. sep = str(self.sep_token)
  205. cls_token_id = self.cls_token_id
  206. sep_token_id = self.sep_token_id
  207. self._tokenizer.post_processor = processors.RobertaProcessing(
  208. sep=(sep, sep_token_id),
  209. cls=(cls, cls_token_id),
  210. add_prefix_space=add_prefix_space,
  211. trim_offsets=True,
  212. )
  213. self.cls_token_box = cls_token_box
  214. self.sep_token_box = sep_token_box
  215. self.pad_token_box = pad_token_box
  216. self.pad_token_label = pad_token_label
  217. self.only_label_first_subword = only_label_first_subword
  218. @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
  219. def __call__(
  220. self,
  221. text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput],
  222. text_pair: PreTokenizedInput | list[PreTokenizedInput] | None = None,
  223. boxes: list[list[int]] | list[list[list[int]]] | None = None,
  224. word_labels: list[int] | list[list[int]] | None = None,
  225. add_special_tokens: bool = True,
  226. padding: bool | str | PaddingStrategy = False,
  227. truncation: bool | str | TruncationStrategy = None,
  228. max_length: int | None = None,
  229. stride: int = 0,
  230. pad_to_multiple_of: int | None = None,
  231. padding_side: str | None = None,
  232. return_tensors: str | TensorType | None = None,
  233. return_token_type_ids: bool | None = None,
  234. return_attention_mask: bool | None = None,
  235. return_overflowing_tokens: bool = False,
  236. return_special_tokens_mask: bool = False,
  237. return_offsets_mapping: bool = False,
  238. return_length: bool = False,
  239. verbose: bool = True,
  240. **kwargs,
  241. ) -> BatchEncoding:
  242. """
  243. Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
  244. sequences with word-level normalized bounding boxes and optional labels.
  245. Args:
  246. text (`str`, `List[str]`, `List[List[str]]`):
  247. The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings
  248. (words of a single example or questions of a batch of examples) or a list of list of strings (batch of
  249. words).
  250. text_pair (`List[str]`, `List[List[str]]`):
  251. The sequence or batch of sequences to be encoded. Each sequence should be a list of strings
  252. (pretokenized string).
  253. boxes (`List[List[int]]`, `List[List[List[int]]]`):
  254. Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale.
  255. word_labels (`List[int]`, `List[List[int]]`, *optional*):
  256. Word-level integer labels (for token classification tasks such as FUNSD, CORD).
  257. """
  258. # Input type checking for clearer error
  259. def _is_valid_text_input(t):
  260. if isinstance(t, str):
  261. # Strings are fine
  262. return True
  263. elif isinstance(t, (list, tuple)):
  264. # List are fine as long as they are...
  265. if len(t) == 0:
  266. # ... empty
  267. return True
  268. elif isinstance(t[0], str):
  269. # ... list of strings
  270. return True
  271. elif isinstance(t[0], (list, tuple)):
  272. # ... list with an empty list or with a list of strings
  273. return len(t[0]) == 0 or isinstance(t[0][0], str)
  274. else:
  275. return False
  276. else:
  277. return False
  278. if text_pair is not None:
  279. # in case text + text_pair are provided, text = questions, text_pair = words
  280. if not _is_valid_text_input(text):
  281. raise ValueError("text input must of type `str` (single example) or `List[str]` (batch of examples). ")
  282. if not isinstance(text_pair, (list, tuple)):
  283. raise ValueError(
  284. "Words must be of type `List[str]` (single pretokenized example), "
  285. "or `List[List[str]]` (batch of pretokenized examples)."
  286. )
  287. else:
  288. # in case only text is provided => must be words
  289. if not isinstance(text, (list, tuple)):
  290. raise ValueError(
  291. "Words must be of type `List[str]` (single pretokenized example), "
  292. "or `List[List[str]]` (batch of pretokenized examples)."
  293. )
  294. if text_pair is not None:
  295. is_batched = isinstance(text, (list, tuple))
  296. else:
  297. is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
  298. words = text if text_pair is None else text_pair
  299. if boxes is None:
  300. raise ValueError("You must provide corresponding bounding boxes")
  301. if is_batched:
  302. if len(words) != len(boxes):
  303. raise ValueError("You must provide words and boxes for an equal amount of examples")
  304. for words_example, boxes_example in zip(words, boxes):
  305. if len(words_example) != len(boxes_example):
  306. raise ValueError("You must provide as many words as there are bounding boxes")
  307. else:
  308. if len(words) != len(boxes):
  309. raise ValueError("You must provide as many words as there are bounding boxes")
  310. if is_batched:
  311. if text_pair is not None and len(text) != len(text_pair):
  312. raise ValueError(
  313. f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
  314. f" {len(text_pair)}."
  315. )
  316. batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
  317. is_pair = bool(text_pair is not None)
  318. return self.batch_encode_plus(
  319. batch_text_or_text_pairs=batch_text_or_text_pairs,
  320. is_pair=is_pair,
  321. boxes=boxes,
  322. word_labels=word_labels,
  323. add_special_tokens=add_special_tokens,
  324. padding=padding,
  325. truncation=truncation,
  326. max_length=max_length,
  327. stride=stride,
  328. pad_to_multiple_of=pad_to_multiple_of,
  329. padding_side=padding_side,
  330. return_tensors=return_tensors,
  331. return_token_type_ids=return_token_type_ids,
  332. return_attention_mask=return_attention_mask,
  333. return_overflowing_tokens=return_overflowing_tokens,
  334. return_special_tokens_mask=return_special_tokens_mask,
  335. return_offsets_mapping=return_offsets_mapping,
  336. return_length=return_length,
  337. verbose=verbose,
  338. **kwargs,
  339. )
  340. else:
  341. return self.encode_plus(
  342. text=text,
  343. text_pair=text_pair,
  344. boxes=boxes,
  345. word_labels=word_labels,
  346. add_special_tokens=add_special_tokens,
  347. padding=padding,
  348. truncation=truncation,
  349. max_length=max_length,
  350. stride=stride,
  351. pad_to_multiple_of=pad_to_multiple_of,
  352. padding_side=padding_side,
  353. return_tensors=return_tensors,
  354. return_token_type_ids=return_token_type_ids,
  355. return_attention_mask=return_attention_mask,
  356. return_overflowing_tokens=return_overflowing_tokens,
  357. return_special_tokens_mask=return_special_tokens_mask,
  358. return_offsets_mapping=return_offsets_mapping,
  359. return_length=return_length,
  360. verbose=verbose,
  361. **kwargs,
  362. )
  363. @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
  364. def batch_encode_plus(
  365. self,
  366. batch_text_or_text_pairs: list[TextInput] | list[TextInputPair] | list[PreTokenizedInput],
  367. is_pair: bool | None = None,
  368. boxes: list[list[list[int]]] | None = None,
  369. word_labels: list[int] | list[list[int]] | None = None,
  370. add_special_tokens: bool = True,
  371. padding: bool | str | PaddingStrategy = False,
  372. truncation: bool | str | TruncationStrategy = None,
  373. max_length: int | None = None,
  374. stride: int = 0,
  375. pad_to_multiple_of: int | None = None,
  376. padding_side: str | None = None,
  377. return_tensors: str | TensorType | None = None,
  378. return_token_type_ids: bool | None = None,
  379. return_attention_mask: bool | None = None,
  380. return_overflowing_tokens: bool = False,
  381. return_special_tokens_mask: bool = False,
  382. return_offsets_mapping: bool = False,
  383. return_length: bool = False,
  384. verbose: bool = True,
  385. **kwargs,
  386. ) -> BatchEncoding:
  387. # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
  388. padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
  389. padding=padding,
  390. truncation=truncation,
  391. max_length=max_length,
  392. pad_to_multiple_of=pad_to_multiple_of,
  393. verbose=verbose,
  394. **kwargs,
  395. )
  396. return self._batch_encode_plus(
  397. batch_text_or_text_pairs=batch_text_or_text_pairs,
  398. is_pair=is_pair,
  399. boxes=boxes,
  400. word_labels=word_labels,
  401. add_special_tokens=add_special_tokens,
  402. padding_strategy=padding_strategy,
  403. truncation_strategy=truncation_strategy,
  404. max_length=max_length,
  405. stride=stride,
  406. pad_to_multiple_of=pad_to_multiple_of,
  407. padding_side=padding_side,
  408. return_tensors=return_tensors,
  409. return_token_type_ids=return_token_type_ids,
  410. return_attention_mask=return_attention_mask,
  411. return_overflowing_tokens=return_overflowing_tokens,
  412. return_special_tokens_mask=return_special_tokens_mask,
  413. return_offsets_mapping=return_offsets_mapping,
  414. return_length=return_length,
  415. verbose=verbose,
  416. **kwargs,
  417. )
  418. def tokenize(self, text: str, pair: str | None = None, add_special_tokens: bool = False, **kwargs) -> list[str]:
  419. batched_input = [(text, pair)] if pair else [text]
  420. encodings = self._tokenizer.encode_batch(
  421. batched_input, add_special_tokens=add_special_tokens, is_pretokenized=False, **kwargs
  422. )
  423. return encodings[0].tokens if encodings else []
  424. @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
  425. def encode_plus(
  426. self,
  427. text: TextInput | PreTokenizedInput,
  428. text_pair: PreTokenizedInput | None = None,
  429. boxes: list[list[int]] | None = None,
  430. word_labels: list[int] | None = None,
  431. add_special_tokens: bool = True,
  432. padding: bool | str | PaddingStrategy = False,
  433. truncation: bool | str | TruncationStrategy = None,
  434. max_length: int | None = None,
  435. stride: int = 0,
  436. pad_to_multiple_of: int | None = None,
  437. padding_side: str | None = None,
  438. return_tensors: str | TensorType | None = None,
  439. return_token_type_ids: bool | None = None,
  440. return_attention_mask: bool | None = None,
  441. return_overflowing_tokens: bool = False,
  442. return_special_tokens_mask: bool = False,
  443. return_offsets_mapping: bool = False,
  444. return_length: bool = False,
  445. verbose: bool = True,
  446. **kwargs,
  447. ) -> BatchEncoding:
  448. """
  449. Tokenize and prepare for the model a sequence or a pair of sequences. .. warning:: This method is deprecated,
  450. `__call__` should be used instead.
  451. Args:
  452. text (`str`, `List[str]`, `List[List[str]]`):
  453. The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings.
  454. text_pair (`List[str]` or `List[int]`, *optional*):
  455. Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a
  456. list of list of strings (words of a batch of examples).
  457. """
  458. # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
  459. padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
  460. padding=padding,
  461. truncation=truncation,
  462. max_length=max_length,
  463. pad_to_multiple_of=pad_to_multiple_of,
  464. verbose=verbose,
  465. **kwargs,
  466. )
  467. return self._encode_plus(
  468. text=text,
  469. boxes=boxes,
  470. text_pair=text_pair,
  471. word_labels=word_labels,
  472. add_special_tokens=add_special_tokens,
  473. padding_strategy=padding_strategy,
  474. truncation_strategy=truncation_strategy,
  475. max_length=max_length,
  476. stride=stride,
  477. pad_to_multiple_of=pad_to_multiple_of,
  478. padding_side=padding_side,
  479. return_tensors=return_tensors,
  480. return_token_type_ids=return_token_type_ids,
  481. return_attention_mask=return_attention_mask,
  482. return_overflowing_tokens=return_overflowing_tokens,
  483. return_special_tokens_mask=return_special_tokens_mask,
  484. return_offsets_mapping=return_offsets_mapping,
  485. return_length=return_length,
  486. verbose=verbose,
  487. **kwargs,
  488. )
  489. def _batch_encode_plus(
  490. self,
  491. batch_text_or_text_pairs: list[TextInput] | list[TextInputPair] | list[PreTokenizedInput],
  492. is_pair: bool | None = None,
  493. boxes: list[list[list[int]]] | None = None,
  494. word_labels: list[list[int]] | None = None,
  495. add_special_tokens: bool = True,
  496. padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
  497. truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
  498. max_length: int | None = None,
  499. stride: int = 0,
  500. pad_to_multiple_of: int | None = None,
  501. padding_side: str | None = None,
  502. return_tensors: str | None = None,
  503. return_token_type_ids: bool | None = None,
  504. return_attention_mask: bool | None = None,
  505. return_overflowing_tokens: bool = False,
  506. return_special_tokens_mask: bool = False,
  507. return_offsets_mapping: bool = False,
  508. return_length: bool = False,
  509. verbose: bool = True,
  510. ) -> BatchEncoding:
  511. if not isinstance(batch_text_or_text_pairs, list):
  512. raise TypeError(f"batch_text_or_text_pairs has to be a list (got {type(batch_text_or_text_pairs)})")
  513. # Set the truncation and padding strategy and restore the initial configuration
  514. self.set_truncation_and_padding(
  515. padding_strategy=padding_strategy,
  516. truncation_strategy=truncation_strategy,
  517. max_length=max_length,
  518. stride=stride,
  519. pad_to_multiple_of=pad_to_multiple_of,
  520. padding_side=padding_side,
  521. )
  522. if is_pair:
  523. batch_text_or_text_pairs = [
  524. (text.split() if isinstance(text, str) else text, text_pair)
  525. for text, text_pair in batch_text_or_text_pairs
  526. ]
  527. encodings = self._tokenizer.encode_batch(
  528. batch_text_or_text_pairs,
  529. add_special_tokens=add_special_tokens,
  530. is_pretokenized=True, # we set this to True as LayoutLMv3 always expects pretokenized inputs
  531. )
  532. # Convert encoding to dict
  533. tokens_and_encodings = [
  534. self._convert_encoding(
  535. encoding=encoding,
  536. return_token_type_ids=return_token_type_ids,
  537. return_attention_mask=return_attention_mask,
  538. return_overflowing_tokens=return_overflowing_tokens,
  539. return_special_tokens_mask=return_special_tokens_mask,
  540. return_offsets_mapping=True
  541. if word_labels is not None
  542. else return_offsets_mapping, # we use offsets to create the labels
  543. return_length=return_length,
  544. verbose=verbose,
  545. )
  546. for encoding in encodings
  547. ]
  548. # Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension
  549. sanitized_tokens = {}
  550. for key in tokens_and_encodings[0][0]:
  551. stack = [e for item, _ in tokens_and_encodings for e in item[key]]
  552. sanitized_tokens[key] = stack
  553. sanitized_encodings = [e for _, item in tokens_and_encodings for e in item]
  554. # If returning overflowing tokens, we need to return a mapping
  555. # from the batch idx to the original sample
  556. if return_overflowing_tokens:
  557. overflow_to_sample_mapping = []
  558. for i, (toks, _) in enumerate(tokens_and_encodings):
  559. overflow_to_sample_mapping += [i] * len(toks["input_ids"])
  560. sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping
  561. for input_ids in sanitized_tokens["input_ids"]:
  562. self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose)
  563. # create the token boxes
  564. token_boxes = []
  565. for batch_index in range(len(sanitized_tokens["input_ids"])):
  566. if return_overflowing_tokens:
  567. original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index]
  568. else:
  569. original_index = batch_index
  570. token_boxes_example = []
  571. for id, sequence_id, word_id in zip(
  572. sanitized_tokens["input_ids"][batch_index],
  573. sanitized_encodings[batch_index].sequence_ids,
  574. sanitized_encodings[batch_index].word_ids,
  575. ):
  576. if word_id is not None:
  577. if is_pair and sequence_id == 0:
  578. token_boxes_example.append(self.pad_token_box)
  579. else:
  580. token_boxes_example.append(boxes[original_index][word_id])
  581. else:
  582. if id == self.cls_token_id:
  583. token_boxes_example.append(self.cls_token_box)
  584. elif id == self.sep_token_id:
  585. token_boxes_example.append(self.sep_token_box)
  586. elif id == self.pad_token_id:
  587. token_boxes_example.append(self.pad_token_box)
  588. else:
  589. # For RoBERTa, there might be additional sep tokens
  590. token_boxes_example.append(self.sep_token_box)
  591. token_boxes.append(token_boxes_example)
  592. sanitized_tokens["bbox"] = token_boxes
  593. # optionally, create the labels
  594. if word_labels is not None:
  595. labels = []
  596. for batch_index in range(len(sanitized_tokens["input_ids"])):
  597. if return_overflowing_tokens:
  598. original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index]
  599. else:
  600. original_index = batch_index
  601. labels_example = []
  602. previous_word_id = None
  603. for id, offset, word_id in zip(
  604. sanitized_tokens["input_ids"][batch_index],
  605. sanitized_tokens["offset_mapping"][batch_index],
  606. sanitized_encodings[batch_index].word_ids,
  607. ):
  608. if word_id is not None:
  609. if self.only_label_first_subword:
  610. # Check if this is the first token of the word (word_id changed or is first occurrence)
  611. if word_id != previous_word_id:
  612. # Use the real label id for the first token of the word, and padding ids for the remaining tokens
  613. labels_example.append(word_labels[original_index][word_id])
  614. else:
  615. labels_example.append(self.pad_token_label)
  616. else:
  617. labels_example.append(word_labels[original_index][word_id])
  618. previous_word_id = word_id
  619. else:
  620. labels_example.append(self.pad_token_label)
  621. previous_word_id = None
  622. labels.append(labels_example)
  623. sanitized_tokens["labels"] = labels
  624. # finally, remove offsets if the user didn't want them
  625. if not return_offsets_mapping:
  626. del sanitized_tokens["offset_mapping"]
  627. return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors)
  628. def _encode_plus(
  629. self,
  630. text: TextInput | PreTokenizedInput,
  631. text_pair: PreTokenizedInput | None = None,
  632. boxes: list[list[int]] | None = None,
  633. word_labels: list[int] | None = None,
  634. add_special_tokens: bool = True,
  635. padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
  636. truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
  637. max_length: int | None = None,
  638. stride: int = 0,
  639. pad_to_multiple_of: int | None = None,
  640. padding_side: str | None = None,
  641. return_tensors: bool | None = None,
  642. return_token_type_ids: bool | None = None,
  643. return_attention_mask: bool | None = None,
  644. return_overflowing_tokens: bool = False,
  645. return_special_tokens_mask: bool = False,
  646. return_offsets_mapping: bool = False,
  647. return_length: bool = False,
  648. verbose: bool = True,
  649. **kwargs,
  650. ) -> BatchEncoding:
  651. # make it a batched input
  652. # 2 options:
  653. # 1) only text, in case text must be a list of str
  654. # 2) text + text_pair, in which case text = str and text_pair a list of str
  655. batched_input = [(text, text_pair)] if text_pair else [text]
  656. batched_boxes = [boxes]
  657. batched_word_labels = [word_labels] if word_labels is not None else None
  658. batched_output = self._batch_encode_plus(
  659. batched_input,
  660. is_pair=bool(text_pair is not None),
  661. boxes=batched_boxes,
  662. word_labels=batched_word_labels,
  663. add_special_tokens=add_special_tokens,
  664. padding_strategy=padding_strategy,
  665. truncation_strategy=truncation_strategy,
  666. max_length=max_length,
  667. stride=stride,
  668. pad_to_multiple_of=pad_to_multiple_of,
  669. padding_side=padding_side,
  670. return_tensors=return_tensors,
  671. return_token_type_ids=return_token_type_ids,
  672. return_attention_mask=return_attention_mask,
  673. return_overflowing_tokens=return_overflowing_tokens,
  674. return_special_tokens_mask=return_special_tokens_mask,
  675. return_offsets_mapping=return_offsets_mapping,
  676. return_length=return_length,
  677. verbose=verbose,
  678. **kwargs,
  679. )
  680. # Return tensor is None, then we can remove the leading batch axis
  681. # Overflowing tokens are returned as a batch of output so we keep them in this case
  682. if return_tensors is None and not return_overflowing_tokens:
  683. batched_output = BatchEncoding(
  684. {
  685. key: value[0] if len(value) > 0 and isinstance(value[0], list) else value
  686. for key, value in batched_output.items()
  687. },
  688. batched_output.encodings,
  689. )
  690. self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose)
  691. return batched_output
  692. def _pad(
  693. self,
  694. encoded_inputs: dict[str, EncodedInput] | BatchEncoding,
  695. max_length: int | None = None,
  696. padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
  697. pad_to_multiple_of: int | None = None,
  698. padding_side: str | None = None,
  699. return_attention_mask: bool | None = None,
  700. ) -> dict:
  701. """
  702. Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
  703. Args:
  704. encoded_inputs:
  705. Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
  706. max_length: maximum length of the returned list and optionally padding length (see below).
  707. Will truncate by taking into account the special tokens.
  708. padding_strategy: PaddingStrategy to use for padding.
  709. - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
  710. - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
  711. - PaddingStrategy.DO_NOT_PAD: Do not pad
  712. The tokenizer padding sides are defined in self.padding_side:
  713. - 'left': pads on the left of the sequences
  714. - 'right': pads on the right of the sequences
  715. pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
  716. This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
  717. `>= 7.5` (Volta).
  718. padding_side:
  719. The side on which the model should have padding applied. Should be selected between ['right', 'left'].
  720. Default value is picked from the class attribute of the same name.
  721. return_attention_mask:
  722. (optional) Set to False to avoid returning attention mask (default: set to model specifics)
  723. """
  724. # Load from model defaults
  725. if return_attention_mask is None:
  726. return_attention_mask = "attention_mask" in self.model_input_names
  727. required_input = encoded_inputs[self.model_input_names[0]]
  728. if padding_strategy == PaddingStrategy.LONGEST:
  729. max_length = len(required_input)
  730. if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
  731. max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
  732. needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
  733. # Initialize attention mask if not present.
  734. if return_attention_mask and "attention_mask" not in encoded_inputs:
  735. encoded_inputs["attention_mask"] = [1] * len(required_input)
  736. if needs_to_be_padded:
  737. difference = max_length - len(required_input)
  738. padding_side = padding_side if padding_side is not None else self.padding_side
  739. if padding_side == "right":
  740. if return_attention_mask:
  741. encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
  742. if "token_type_ids" in encoded_inputs:
  743. encoded_inputs["token_type_ids"] = (
  744. encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
  745. )
  746. if "bbox" in encoded_inputs:
  747. encoded_inputs["bbox"] = encoded_inputs["bbox"] + [self.pad_token_box] * difference
  748. if "labels" in encoded_inputs:
  749. encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference
  750. if "special_tokens_mask" in encoded_inputs:
  751. encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
  752. encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
  753. elif padding_side == "left":
  754. if return_attention_mask:
  755. encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
  756. if "token_type_ids" in encoded_inputs:
  757. encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
  758. "token_type_ids"
  759. ]
  760. if "bbox" in encoded_inputs:
  761. encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"]
  762. if "labels" in encoded_inputs:
  763. encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"]
  764. if "special_tokens_mask" in encoded_inputs:
  765. encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
  766. encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
  767. else:
  768. raise ValueError("Invalid padding strategy:" + str(padding_side))
  769. return encoded_inputs
  770. def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
  771. """
  772. Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
  773. adding special tokens. A RoBERTa sequence has the following format:
  774. - single sequence: `<s> X </s>`
  775. - pair of sequences: `<s> A </s></s> B </s>`
  776. Args:
  777. token_ids_0 (`List[int]`):
  778. List of IDs to which the special tokens will be added.
  779. token_ids_1 (`List[int]`, *optional*):
  780. Optional second list of IDs for sequence pairs.
  781. Returns:
  782. `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
  783. """
  784. if token_ids_1 is None:
  785. return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
  786. cls = [self.cls_token_id]
  787. sep = [self.sep_token_id]
  788. return cls + token_ids_0 + sep + sep + token_ids_1 + sep
  789. __all__ = ["LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast"]
  790. # Backward alias
  791. LayoutLMv3TokenizerFast = LayoutLMv3Tokenizer