tokenization_markuplm.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005
  1. # Copyright 2022 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from tokenizers import Tokenizer, decoders, pre_tokenizers, processors
  15. from tokenizers.models import BPE
  16. from ...tokenization_utils_base import (
  17. ENCODE_KWARGS_DOCSTRING,
  18. AddedToken,
  19. BatchEncoding,
  20. EncodedInput,
  21. PaddingStrategy,
  22. PreTokenizedInput,
  23. TensorType,
  24. TextInput,
  25. TextInputPair,
  26. TruncationStrategy,
  27. )
  28. from ...tokenization_utils_tokenizers import TokenizersBackend
  29. from ...utils import add_end_docstrings, logging
# Module-level logger obtained from the library's logging helper, keyed by this module's name.
logger = logging.get_logger(__name__)
# Maps each vocabulary resource key to its file name; the tokenizer class below
# exposes this mapping as its `vocab_files_names` attribute.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
  32. MARKUPLM_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
  33. add_special_tokens (`bool`, *optional*, defaults to `True`):
  34. Whether or not to encode the sequences with the special tokens relative to their model.
  35. padding (`bool`, `str` or [`~tokenization_utils_base.PaddingStrategy`], *optional*, defaults to `False`):
  36. Activates and controls padding. Accepts the following values:
  37. - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
  38. sequence if provided).
  39. - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
  40. acceptable input length for the model if that argument is not provided.
  41. - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
  42. lengths).
  43. truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
  44. Activates and controls truncation. Accepts the following values:
  45. - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
  46. to the maximum acceptable input length for the model if that argument is not provided. This will
  47. truncate token by token, removing a token from the longest sequence in the pair if a pair of
  48. sequences (or a batch of pairs) is provided.
  49. - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
  50. maximum acceptable input length for the model if that argument is not provided. This will only
  51. truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
  52. - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
  53. maximum acceptable input length for the model if that argument is not provided. This will only
  54. truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
  55. - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
  56. greater than the model maximum admissible input size).
  57. max_length (`int`, *optional*):
  58. Controls the maximum length to use by one of the truncation/padding parameters. If left unset or set to
  59. `None`, this will use the predefined model maximum length if a maximum length is required by one of the
  60. truncation/padding parameters. If the model has no specific maximum input length (like XLNet)
  61. truncation/padding to a maximum length will be deactivated.
  62. stride (`int`, *optional*, defaults to 0):
  63. If set to a number along with `max_length`, the overflowing tokens returned when
  64. `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
  65. returned to provide some overlap between truncated and overflowing sequences. The value of this
  66. argument defines the number of overlapping tokens.
  67. is_split_into_words (`bool`, *optional*, defaults to `False`):
  68. Whether or not the input is already pretokenized (e.g. split into words). Set this to `True` if you are
  69. passing pretokenized inputs to avoid additional tokenization.
  70. pad_to_multiple_of (`int`, *optional*):
  71. If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
  72. the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
  73. return_tensors (`str` or [`~tokenization_utils_base.TensorType`], *optional*):
  74. If set, will return tensors instead of list of python integers. Acceptable values are:
  75. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  76. - `'np'`: Return Numpy `np.ndarray` objects.
  77. """
  78. class MarkupLMTokenizer(TokenizersBackend):
  79. r"""
  80. Construct a MarkupLM tokenizer. Based on byte-level Byte-Pair-Encoding (BPE).
  81. [`MarkupLMTokenizer`] can be used to turn HTML strings into token-level `input_ids`, `attention_mask`,
  82. `token_type_ids`, `xpath_tags_seq` and `xpath_subs_seq`. This tokenizer inherits from [`TokenizersBackend`] which
  83. contains most of the main methods and ensures a `tokenizers` backend is always instantiated.
  84. Users should refer to this superclass for more information regarding those methods.
  85. Args:
  86. vocab (`str` or `dict[str, int]`, *optional*):
  87. Custom vocabulary dictionary. If not provided, the vocabulary is loaded from `vocab_file`.
  88. merges (`str` or `list[str]`, *optional*):
  89. Custom merges list. If not provided, merges are loaded from `merges_file`.
  90. errors (`str`, *optional*, defaults to `"replace"`):
  91. Paradigm to follow when decoding bytes to UTF-8. See
  92. [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
  93. bos_token (`str`, *optional*, defaults to `"<s>"`):
  94. The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
  95. <Tip>
  96. When building a sequence using special tokens, this is not the token that is used for the beginning of
  97. sequence. The token used is the `cls_token`.
  98. </Tip>
  99. eos_token (`str`, *optional*, defaults to `"</s>"`):
  100. The end of sequence token.
  101. <Tip>
  102. When building a sequence using special tokens, this is not the token that is used for the end of sequence.
  103. The token used is the `sep_token`.
  104. </Tip>
  105. sep_token (`str`, *optional*, defaults to `"</s>"`):
  106. The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
  107. sequence classification or for a text and a question for question answering. It is also used as the last
  108. token of a sequence built with special tokens.
  109. cls_token (`str`, *optional*, defaults to `"<s>"`):
  110. The classifier token which is used when doing sequence classification (classification of the whole sequence
  111. instead of per-token classification). It is the first token of the sequence when built with special tokens.
  112. unk_token (`str`, *optional*, defaults to `"<unk>"`):
  113. The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
  114. token instead.
  115. pad_token (`str`, *optional*, defaults to `"<pad>"`):
  116. The token used for padding, for example when batching sequences of different lengths.
  117. mask_token (`str`, *optional*, defaults to `"<mask>"`):
  118. The token used for masking values. This is the token used when training this model with masked language
  119. modeling. This is the token which the model will try to predict.
  120. add_prefix_space (`bool`, *optional*, defaults to `False`):
  121. Whether or not to add an initial space to the input. This allows to treat the leading word just as any
  122. other word. (The RoBERTa tokenizer detects the beginning of words by the preceding space).
  123. """
  124. vocab_files_names = VOCAB_FILES_NAMES
  125. model_input_names = ["input_ids", "token_type_ids", "attention_mask"]
  126. model = BPE
  127. def __init__(
  128. self,
  129. tags_dict,
  130. vocab: str | dict[str, int] | list[tuple[str, float]] | None = None,
  131. merges: str | list[str] | None = None,
  132. errors="replace",
  133. bos_token="<s>",
  134. eos_token="</s>",
  135. sep_token="</s>",
  136. cls_token="<s>",
  137. unk_token="<unk>",
  138. pad_token="<pad>",
  139. mask_token="<mask>",
  140. add_prefix_space=False,
  141. max_depth=50,
  142. max_width=1000,
  143. pad_width=1001,
  144. pad_token_label=-100,
  145. only_label_first_subword=True,
  146. trim_offsets=False,
  147. **kwargs,
  148. ):
  149. bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
  150. eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
  151. sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
  152. cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
  153. unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
  154. pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
  155. # Mask token behave like a normal word, i.e. include the space before it
  156. mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
  157. if vocab is None:
  158. vocab = {
  159. str(pad_token): 0,
  160. str(unk_token): 1,
  161. str(cls_token): 2,
  162. str(sep_token): 3,
  163. str(mask_token): 4,
  164. }
  165. merges = merges or []
  166. tokenizer = Tokenizer(
  167. BPE(
  168. vocab=vocab,
  169. merges=merges,
  170. dropout=None,
  171. continuing_subword_prefix="",
  172. end_of_word_suffix="",
  173. fuse_unk=False,
  174. )
  175. )
  176. tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
  177. tokenizer.decoder = decoders.ByteLevel()
  178. self._vocab = vocab
  179. self._merges = merges
  180. self._tokenizer = tokenizer
  181. super().__init__(
  182. tags_dict=tags_dict,
  183. errors=errors,
  184. bos_token=bos_token,
  185. eos_token=eos_token,
  186. unk_token=unk_token,
  187. sep_token=sep_token,
  188. cls_token=cls_token,
  189. pad_token=pad_token,
  190. mask_token=mask_token,
  191. add_prefix_space=add_prefix_space,
  192. trim_offsets=trim_offsets,
  193. max_depth=max_depth,
  194. max_width=max_width,
  195. pad_width=pad_width,
  196. pad_token_label=pad_token_label,
  197. only_label_first_subword=only_label_first_subword,
  198. **kwargs,
  199. )
  200. sep_token_str = str(sep_token)
  201. cls_token_str = str(cls_token)
  202. cls_token_id = self.cls_token_id
  203. sep_token_id = self.sep_token_id
  204. self._tokenizer.post_processor = processors.TemplateProcessing(
  205. single=f"{cls_token_str} $A {sep_token_str}",
  206. pair=f"{cls_token_str} $A {sep_token_str} $B {sep_token_str}",
  207. special_tokens=[
  208. (cls_token_str, cls_token_id),
  209. (sep_token_str, sep_token_id),
  210. ],
  211. )
  212. self.tags_dict = tags_dict
  213. # additional properties
  214. self.max_depth = max_depth
  215. self.max_width = max_width
  216. self.pad_width = pad_width
  217. self.unk_tag_id = len(self.tags_dict)
  218. self.pad_tag_id = self.unk_tag_id + 1
  219. self.pad_xpath_tags_seq = [self.pad_tag_id] * self.max_depth
  220. self.pad_xpath_subs_seq = [self.pad_width] * self.max_depth
  221. self.pad_token_label = pad_token_label
  222. self.only_label_first_subword = only_label_first_subword
  223. def get_xpath_seq(self, xpath):
  224. """
  225. Given the xpath expression of one particular node (like "/html/body/div/li[1]/div/span[2]"), return a list of
  226. tag IDs and corresponding subscripts, taking into account max depth.
  227. """
  228. xpath_tags_list = []
  229. xpath_subs_list = []
  230. xpath_units = xpath.split("/")
  231. for unit in xpath_units:
  232. if not unit.strip():
  233. continue
  234. name_subs = unit.strip().split("[")
  235. tag_name = name_subs[0]
  236. sub = 0 if len(name_subs) == 1 else int(name_subs[1][:-1])
  237. xpath_tags_list.append(self.tags_dict.get(tag_name, self.unk_tag_id))
  238. xpath_subs_list.append(min(self.max_width, sub))
  239. xpath_tags_list = xpath_tags_list[: self.max_depth]
  240. xpath_subs_list = xpath_subs_list[: self.max_depth]
  241. xpath_tags_list += [self.pad_tag_id] * (self.max_depth - len(xpath_tags_list))
  242. xpath_subs_list += [self.pad_width] * (self.max_depth - len(xpath_subs_list))
  243. return xpath_tags_list, xpath_subs_list
  244. @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, MARKUPLM_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
  245. def __call__(
  246. self,
  247. text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput],
  248. text_pair: PreTokenizedInput | list[PreTokenizedInput] | None = None,
  249. xpaths: list[list[int]] | list[list[list[int]]] | None = None,
  250. node_labels: list[int] | list[list[int]] | None = None,
  251. add_special_tokens: bool = True,
  252. padding: bool | str | PaddingStrategy = False,
  253. truncation: bool | str | TruncationStrategy = None,
  254. max_length: int | None = None,
  255. stride: int = 0,
  256. is_split_into_words: bool = False,
  257. pad_to_multiple_of: int | None = None,
  258. padding_side: str | None = None,
  259. return_tensors: str | TensorType | None = None,
  260. return_token_type_ids: bool | None = None,
  261. return_attention_mask: bool | None = None,
  262. return_overflowing_tokens: bool = False,
  263. return_special_tokens_mask: bool = False,
  264. return_offsets_mapping: bool = False,
  265. return_length: bool = False,
  266. verbose: bool = True,
  267. **kwargs,
  268. ) -> BatchEncoding:
  269. """
  270. Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
  271. sequences with nodes, xpaths and optional labels.
  272. Args:
  273. text (`str`, `list[str]`, `list[list[str]]`):
  274. The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings
  275. (words of a single example or questions of a batch of examples) or a list of list of strings (batch of
  276. words).
  277. text_pair (`list[str]`, `list[list[str]]`):
  278. The sequence or batch of sequences to be encoded. Each sequence should be a list of strings
  279. (pretokenized string).
  280. xpaths (`list[list[int]]`, `list[list[list[int]]]`):
  281. Node-level xpaths. Each bounding box should be normalized to be on a 0-1000 scale.
  282. node_labels (`list[int]`, `list[list[int]]`, *optional*):
  283. Node-level integer labels (for token classification tasks).
  284. is_split_into_words (`bool`, *optional*):
  285. Set to `True` if the inputs are already provided as pretokenized word lists.
  286. """
  287. placeholder_xpath = "/document/node"
  288. if isinstance(text, tuple):
  289. text = list(text)
  290. if text_pair is not None and isinstance(text_pair, tuple):
  291. text_pair = list(text_pair)
  292. if xpaths is None and not is_split_into_words:
  293. nodes_source = text if text_pair is None else text_pair
  294. if isinstance(nodes_source, tuple):
  295. nodes_source = list(nodes_source)
  296. processed_nodes = nodes_source
  297. if isinstance(nodes_source, str):
  298. processed_nodes = nodes_source.split()
  299. elif isinstance(nodes_source, list):
  300. if nodes_source and isinstance(nodes_source[0], str):
  301. requires_split = any(" " in entry for entry in nodes_source)
  302. if requires_split:
  303. processed_nodes = [entry.split() for entry in nodes_source]
  304. else:
  305. processed_nodes = nodes_source
  306. elif nodes_source and isinstance(nodes_source[0], tuple):
  307. processed_nodes = [list(sample) for sample in nodes_source]
  308. if text_pair is None:
  309. text = processed_nodes
  310. else:
  311. text_pair = processed_nodes
  312. if isinstance(processed_nodes, list) and processed_nodes and isinstance(processed_nodes[0], (list, tuple)):
  313. xpaths = [[placeholder_xpath] * len(sample) for sample in processed_nodes]
  314. else:
  315. length = len(processed_nodes) if hasattr(processed_nodes, "__len__") else 0
  316. xpaths = [placeholder_xpath] * length
  317. def _is_valid_text_input(t):
  318. if isinstance(t, str):
  319. return True
  320. if isinstance(t, (list, tuple)):
  321. if len(t) == 0:
  322. return True
  323. if isinstance(t[0], str):
  324. return True
  325. if isinstance(t[0], (list, tuple)):
  326. return len(t[0]) == 0 or isinstance(t[0][0], str)
  327. return False
  328. if text_pair is not None:
  329. # in case text + text_pair are provided, text = questions, text_pair = nodes
  330. if not _is_valid_text_input(text):
  331. raise ValueError("text input must of type `str` (single example) or `list[str]` (batch of examples). ")
  332. if not isinstance(text_pair, (list, tuple)):
  333. raise ValueError(
  334. "Nodes must be of type `list[str]` (single pretokenized example), "
  335. "or `list[list[str]]` (batch of pretokenized examples)."
  336. )
  337. is_batched = isinstance(text, (list, tuple))
  338. else:
  339. # in case only text is provided => must be nodes
  340. if not isinstance(text, (list, tuple)):
  341. raise ValueError(
  342. "Nodes must be of type `list[str]` (single pretokenized example), "
  343. "or `list[list[str]]` (batch of pretokenized examples)."
  344. )
  345. is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
  346. nodes = text if text_pair is None else text_pair
  347. assert xpaths is not None, "You must provide corresponding xpaths"
  348. if is_batched:
  349. assert len(nodes) == len(xpaths), "You must provide nodes and xpaths for an equal amount of examples"
  350. for nodes_example, xpaths_example in zip(nodes, xpaths):
  351. assert len(nodes_example) == len(xpaths_example), "You must provide as many nodes as there are xpaths"
  352. else:
  353. assert len(nodes) == len(xpaths), "You must provide as many nodes as there are xpaths"
  354. if is_batched:
  355. if text_pair is not None and len(text) != len(text_pair):
  356. raise ValueError(
  357. f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
  358. f" {len(text_pair)}."
  359. )
  360. batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
  361. is_pair = bool(text_pair is not None)
  362. return self.batch_encode_plus(
  363. batch_text_or_text_pairs=batch_text_or_text_pairs,
  364. is_pair=is_pair,
  365. xpaths=xpaths,
  366. node_labels=node_labels,
  367. add_special_tokens=add_special_tokens,
  368. padding=padding,
  369. truncation=truncation,
  370. max_length=max_length,
  371. stride=stride,
  372. pad_to_multiple_of=pad_to_multiple_of,
  373. padding_side=padding_side,
  374. return_tensors=return_tensors,
  375. return_token_type_ids=return_token_type_ids,
  376. return_attention_mask=return_attention_mask,
  377. return_overflowing_tokens=return_overflowing_tokens,
  378. return_special_tokens_mask=return_special_tokens_mask,
  379. return_offsets_mapping=return_offsets_mapping,
  380. return_length=return_length,
  381. verbose=verbose,
  382. **kwargs,
  383. )
  384. else:
  385. return self.encode_plus(
  386. text=text,
  387. text_pair=text_pair,
  388. xpaths=xpaths,
  389. node_labels=node_labels,
  390. add_special_tokens=add_special_tokens,
  391. padding=padding,
  392. truncation=truncation,
  393. max_length=max_length,
  394. stride=stride,
  395. pad_to_multiple_of=pad_to_multiple_of,
  396. padding_side=padding_side,
  397. return_tensors=return_tensors,
  398. return_token_type_ids=return_token_type_ids,
  399. return_attention_mask=return_attention_mask,
  400. return_overflowing_tokens=return_overflowing_tokens,
  401. return_special_tokens_mask=return_special_tokens_mask,
  402. return_offsets_mapping=return_offsets_mapping,
  403. return_length=return_length,
  404. verbose=verbose,
  405. **kwargs,
  406. )
  407. @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, MARKUPLM_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
  408. def batch_encode_plus(
  409. self,
  410. batch_text_or_text_pairs: list[TextInput] | list[TextInputPair] | list[PreTokenizedInput],
  411. is_pair: bool | None = None,
  412. xpaths: list[list[list[int]]] | None = None,
  413. node_labels: list[int] | list[list[int]] | None = None,
  414. add_special_tokens: bool = True,
  415. padding: bool | str | PaddingStrategy = False,
  416. truncation: bool | str | TruncationStrategy = None,
  417. max_length: int | None = None,
  418. stride: int = 0,
  419. pad_to_multiple_of: int | None = None,
  420. padding_side: str | None = None,
  421. return_tensors: str | TensorType | None = None,
  422. return_token_type_ids: bool | None = None,
  423. return_attention_mask: bool | None = None,
  424. return_overflowing_tokens: bool = False,
  425. return_special_tokens_mask: bool = False,
  426. return_offsets_mapping: bool = False,
  427. return_length: bool = False,
  428. verbose: bool = True,
  429. **kwargs,
  430. ) -> BatchEncoding:
  431. # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
  432. padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
  433. padding=padding,
  434. truncation=truncation,
  435. max_length=max_length,
  436. pad_to_multiple_of=pad_to_multiple_of,
  437. verbose=verbose,
  438. **kwargs,
  439. )
  440. return self._batch_encode_plus(
  441. batch_text_or_text_pairs=batch_text_or_text_pairs,
  442. is_pair=is_pair,
  443. xpaths=xpaths,
  444. node_labels=node_labels,
  445. add_special_tokens=add_special_tokens,
  446. padding_strategy=padding_strategy,
  447. truncation_strategy=truncation_strategy,
  448. max_length=max_length,
  449. stride=stride,
  450. pad_to_multiple_of=pad_to_multiple_of,
  451. padding_side=padding_side,
  452. return_tensors=return_tensors,
  453. return_token_type_ids=return_token_type_ids,
  454. return_attention_mask=return_attention_mask,
  455. return_overflowing_tokens=return_overflowing_tokens,
  456. return_special_tokens_mask=return_special_tokens_mask,
  457. return_offsets_mapping=return_offsets_mapping,
  458. return_length=return_length,
  459. verbose=verbose,
  460. **kwargs,
  461. )
  462. def tokenize(self, text: str, pair: str | None = None, add_special_tokens: bool = False, **kwargs) -> list[str]:
  463. batched_input = [(text, pair)] if pair else [text]
  464. encodings = self._tokenizer.encode_batch(
  465. batched_input, add_special_tokens=add_special_tokens, is_pretokenized=False, **kwargs
  466. )
  467. return encodings[0].tokens
  468. @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, MARKUPLM_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
  469. def encode_plus(
  470. self,
  471. text: TextInput | PreTokenizedInput,
  472. text_pair: PreTokenizedInput | None = None,
  473. xpaths: list[list[int]] | None = None,
  474. node_labels: list[int] | None = None,
  475. add_special_tokens: bool = True,
  476. padding: bool | str | PaddingStrategy = False,
  477. truncation: bool | str | TruncationStrategy = None,
  478. max_length: int | None = None,
  479. stride: int = 0,
  480. pad_to_multiple_of: int | None = None,
  481. padding_side: str | None = None,
  482. return_tensors: str | TensorType | None = None,
  483. return_token_type_ids: bool | None = None,
  484. return_attention_mask: bool | None = None,
  485. return_overflowing_tokens: bool = False,
  486. return_special_tokens_mask: bool = False,
  487. return_offsets_mapping: bool = False,
  488. return_length: bool = False,
  489. verbose: bool = True,
  490. **kwargs,
  491. ) -> BatchEncoding:
  492. """
  493. Tokenize and prepare for the model a sequence or a pair of sequences. .. warning:: This method is deprecated,
  494. `__call__` should be used instead.
  495. Args:
  496. text (`str`, `list[str]`, `list[list[str]]`):
  497. The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings.
  498. text_pair (`list[str]` or `list[int]`, *optional*):
  499. Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a
  500. list of list of strings (words of a batch of examples).
  501. """
  502. # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
  503. padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
  504. padding=padding,
  505. truncation=truncation,
  506. max_length=max_length,
  507. pad_to_multiple_of=pad_to_multiple_of,
  508. verbose=verbose,
  509. **kwargs,
  510. )
  511. return self._encode_plus(
  512. text=text,
  513. xpaths=xpaths,
  514. text_pair=text_pair,
  515. node_labels=node_labels,
  516. add_special_tokens=add_special_tokens,
  517. padding_strategy=padding_strategy,
  518. truncation_strategy=truncation_strategy,
  519. max_length=max_length,
  520. stride=stride,
  521. pad_to_multiple_of=pad_to_multiple_of,
  522. padding_side=padding_side,
  523. return_tensors=return_tensors,
  524. return_token_type_ids=return_token_type_ids,
  525. return_attention_mask=return_attention_mask,
  526. return_overflowing_tokens=return_overflowing_tokens,
  527. return_special_tokens_mask=return_special_tokens_mask,
  528. return_offsets_mapping=return_offsets_mapping,
  529. return_length=return_length,
  530. verbose=verbose,
  531. **kwargs,
  532. )
  533. def _batch_encode_plus(
  534. self,
  535. batch_text_or_text_pairs: list[TextInput] | list[TextInputPair] | list[PreTokenizedInput],
  536. is_pair: bool | None = None,
  537. xpaths: list[list[list[int]]] | None = None,
  538. node_labels: list[list[int]] | None = None,
  539. add_special_tokens: bool = True,
  540. padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
  541. truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
  542. max_length: int | None = None,
  543. stride: int = 0,
  544. pad_to_multiple_of: int | None = None,
  545. padding_side: str | None = None,
  546. return_tensors: str | None = None,
  547. return_token_type_ids: bool | None = None,
  548. return_attention_mask: bool | None = None,
  549. return_overflowing_tokens: bool = False,
  550. return_special_tokens_mask: bool = False,
  551. return_offsets_mapping: bool = False,
  552. return_length: bool = False,
  553. verbose: bool = True,
  554. ) -> BatchEncoding:
  555. if not isinstance(batch_text_or_text_pairs, list):
  556. raise TypeError(f"batch_text_or_text_pairs has to be a list (got {type(batch_text_or_text_pairs)})")
  557. # Set the truncation and padding strategy and restore the initial configuration
  558. self.set_truncation_and_padding(
  559. padding_strategy=padding_strategy,
  560. truncation_strategy=truncation_strategy,
  561. max_length=max_length,
  562. stride=stride,
  563. pad_to_multiple_of=pad_to_multiple_of,
  564. padding_side=padding_side,
  565. )
  566. if is_pair:
  567. processed_inputs = []
  568. for text, text_pair in batch_text_or_text_pairs:
  569. if isinstance(text, tuple):
  570. text = list(text)
  571. if isinstance(text, str):
  572. text = [text]
  573. if isinstance(text_pair, tuple):
  574. text_pair = list(text_pair)
  575. if isinstance(text_pair, str):
  576. text_pair = [text_pair]
  577. processed_inputs.append((text, text_pair))
  578. batch_text_or_text_pairs = processed_inputs
  579. else:
  580. processed_inputs = []
  581. for text in batch_text_or_text_pairs:
  582. if isinstance(text, tuple):
  583. text = list(text)
  584. if isinstance(text, str):
  585. text = [text]
  586. processed_inputs.append(text)
  587. batch_text_or_text_pairs = processed_inputs
  588. encodings = self._tokenizer.encode_batch(
  589. batch_text_or_text_pairs,
  590. add_special_tokens=add_special_tokens,
  591. is_pretokenized=True, # we set this to True as MarkupLM always expects pretokenized inputs
  592. )
  593. # Convert encoding to dict
  594. # `Tokens` is a tuple of (list[dict[str, list[list[int]]]] or list[dict[str, 2D-Tensor]],
  595. # list[EncodingFast]) with nested dimensions corresponding to batch, overflows, sequence length
  596. tokens_and_encodings = [
  597. self._convert_encoding(
  598. encoding=encoding,
  599. return_token_type_ids=return_token_type_ids,
  600. return_attention_mask=return_attention_mask,
  601. return_overflowing_tokens=return_overflowing_tokens,
  602. return_special_tokens_mask=return_special_tokens_mask,
  603. return_offsets_mapping=True
  604. if node_labels is not None
  605. else return_offsets_mapping, # we use offsets to create the labels
  606. return_length=return_length,
  607. verbose=verbose,
  608. )
  609. for encoding in encodings
  610. ]
  611. # Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension
  612. # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length)
  613. # (we say ~ because the number of overflow varies with the example in the batch)
  614. #
  615. # To match each overflowing sample with the original sample in the batch
  616. # we add an overflow_to_sample_mapping array (see below)
  617. sanitized_tokens = {}
  618. for key in tokens_and_encodings[0][0]:
  619. stack = [e for item, _ in tokens_and_encodings for e in item[key]]
  620. sanitized_tokens[key] = stack
  621. sanitized_encodings = [e for _, item in tokens_and_encodings for e in item]
  622. # If returning overflowing tokens, we need to return a mapping
  623. # from the batch idx to the original sample
  624. if return_overflowing_tokens:
  625. overflow_to_sample_mapping = []
  626. for i, (toks, _) in enumerate(tokens_and_encodings):
  627. overflow_to_sample_mapping += [i] * len(toks["input_ids"])
  628. sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping
  629. for input_ids in sanitized_tokens["input_ids"]:
  630. self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose)
  631. # create the token-level xpaths tags and subscripts
  632. xpath_tags_seq = []
  633. xpath_subs_seq = []
  634. for batch_index in range(len(sanitized_tokens["input_ids"])):
  635. if return_overflowing_tokens:
  636. original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index]
  637. else:
  638. original_index = batch_index
  639. xpath_tags_seq_example = []
  640. xpath_subs_seq_example = []
  641. for id, sequence_id, word_id in zip(
  642. sanitized_tokens["input_ids"][batch_index],
  643. sanitized_encodings[batch_index].sequence_ids,
  644. sanitized_encodings[batch_index].word_ids,
  645. ):
  646. if word_id is not None:
  647. if is_pair and sequence_id == 0:
  648. xpath_tags_seq_example.append(self.pad_xpath_tags_seq)
  649. xpath_subs_seq_example.append(self.pad_xpath_subs_seq)
  650. else:
  651. xpath_tags_list, xpath_subs_list = self.get_xpath_seq(xpaths[original_index][word_id])
  652. xpath_tags_seq_example.extend([xpath_tags_list])
  653. xpath_subs_seq_example.extend([xpath_subs_list])
  654. else:
  655. if id in [self.cls_token_id, self.sep_token_id, self.pad_token_id]:
  656. xpath_tags_seq_example.append(self.pad_xpath_tags_seq)
  657. xpath_subs_seq_example.append(self.pad_xpath_subs_seq)
  658. else:
  659. raise ValueError("Id not recognized")
  660. xpath_tags_seq.append(xpath_tags_seq_example)
  661. xpath_subs_seq.append(xpath_subs_seq_example)
  662. sanitized_tokens["xpath_tags_seq"] = xpath_tags_seq
  663. sanitized_tokens["xpath_subs_seq"] = xpath_subs_seq
  664. # optionally, create the labels
  665. if node_labels is not None:
  666. labels = []
  667. for batch_index in range(len(sanitized_tokens["input_ids"])):
  668. if return_overflowing_tokens:
  669. original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index]
  670. else:
  671. original_index = batch_index
  672. labels_example = []
  673. for id, offset, word_id in zip(
  674. sanitized_tokens["input_ids"][batch_index],
  675. sanitized_tokens["offset_mapping"][batch_index],
  676. sanitized_encodings[batch_index].word_ids,
  677. ):
  678. if word_id is not None:
  679. if self.only_label_first_subword:
  680. if offset[0] == 0:
  681. # Use the real label id for the first token of the word, and padding ids for the remaining tokens
  682. labels_example.append(node_labels[original_index][word_id])
  683. else:
  684. labels_example.append(self.pad_token_label)
  685. else:
  686. labels_example.append(node_labels[original_index][word_id])
  687. else:
  688. labels_example.append(self.pad_token_label)
  689. labels.append(labels_example)
  690. sanitized_tokens["labels"] = labels
  691. # finally, remove offsets if the user didn't want them
  692. if not return_offsets_mapping:
  693. del sanitized_tokens["offset_mapping"]
  694. return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors)
  def _encode_plus(
      self,
      text: TextInput | PreTokenizedInput,
      text_pair: PreTokenizedInput | None = None,
      xpaths: list[list[int]] | None = None,
      node_labels: list[int] | None = None,
      add_special_tokens: bool = True,
      padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
      truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
      max_length: int | None = None,
      stride: int = 0,
      pad_to_multiple_of: int | None = None,
      padding_side: str | None = None,
      return_tensors: bool | None = None,
      return_token_type_ids: bool | None = None,
      return_attention_mask: bool | None = None,
      return_overflowing_tokens: bool = False,
      return_special_tokens_mask: bool = False,
      return_offsets_mapping: bool = False,
      return_length: bool = False,
      verbose: bool = True,
      **kwargs,
  ) -> BatchEncoding:
      """
      Encode a single example by wrapping it into a batch of one and delegating to `_batch_encode_plus`.

      The nodes are taken from `text_pair` when it is given, otherwise from `text`. Nodes passed as a plain string
      are split on whitespace. When `xpaths` is `None`, one placeholder xpath (`"/document/node"`) is generated per
      node.

      Args:
          text: The nodes when `text_pair` is `None`; otherwise the first sequence of the pair.
          text_pair: The nodes when a pair is encoded. Tuples are converted to lists.
          xpaths: One xpath per node. Generated from a placeholder when `None`.
          node_labels: Optional per-node labels, forwarded as a batch of one.

      All remaining arguments are forwarded unchanged to `_batch_encode_plus`.

      Returns:
          `BatchEncoding`: The encoded example. When `return_tensors` is `None` and `return_overflowing_tokens` is
          false, the leading batch axis is removed from every value whose first element is a list.
      """
      placeholder_xpath = "/document/node"

      # Decide once whether this is a pair. Both the batch layout and the `is_pair` flag below must agree,
      # otherwise `_batch_encode_plus` would try to unpack a non-pair entry when `text_pair` is empty.
      is_pair = text_pair is not None

      if isinstance(text, tuple):
          text = list(text)
      if isinstance(text_pair, tuple):
          text_pair = list(text_pair)

      # The nodes live in `text_pair` for pair inputs and in `text` otherwise.
      nodes = text_pair if is_pair else text
      if isinstance(nodes, str):
          nodes = nodes.split()
      if is_pair:
          text_pair = nodes
      else:
          text = nodes

      if xpaths is None:
          node_count = len(nodes) if hasattr(nodes, "__len__") else 0
          xpaths = [placeholder_xpath] * node_count

      # make it a batched input
      # 2 options:
      # 1) only text, in case text must be a list of str
      # 2) text + text_pair, in which case text = str and text_pair a list of str
      batched_input = [(text, text_pair)] if is_pair else [text]
      batched_xpaths = [xpaths]
      batched_node_labels = [node_labels] if node_labels is not None else None

      batched_output = self._batch_encode_plus(
          batched_input,
          is_pair=is_pair,
          xpaths=batched_xpaths,
          node_labels=batched_node_labels,
          add_special_tokens=add_special_tokens,
          padding_strategy=padding_strategy,
          truncation_strategy=truncation_strategy,
          max_length=max_length,
          stride=stride,
          pad_to_multiple_of=pad_to_multiple_of,
          padding_side=padding_side,
          return_tensors=return_tensors,
          return_token_type_ids=return_token_type_ids,
          return_attention_mask=return_attention_mask,
          return_overflowing_tokens=return_overflowing_tokens,
          return_special_tokens_mask=return_special_tokens_mask,
          return_offsets_mapping=return_offsets_mapping,
          return_length=return_length,
          verbose=verbose,
          **kwargs,
      )

      # Return tensor is None, then we can remove the leading batch axis
      # Overflowing tokens are returned as a batch of output so we keep them in this case
      if return_tensors is None and not return_overflowing_tokens:
          batched_output = BatchEncoding(
              {
                  key: value[0] if len(value) > 0 and isinstance(value[0], list) else value
                  for key, value in batched_output.items()
              },
              batched_output.encodings,
          )

      self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose)

      return batched_output
  def _pad(
      self,
      encoded_inputs: dict[str, EncodedInput] | BatchEncoding,
      max_length: int | None = None,
      padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
      pad_to_multiple_of: int | None = None,
      padding_side: str | None = None,
      return_attention_mask: bool | None = None,
  ) -> dict:
      """
      Pad one encoded example, including its xpath and label sequences, on the left or on the right.

      Args:
          encoded_inputs:
              Mapping of feature name to sequence. The entry stored under `self.model_input_names[0]` defines the
              current length. The mapping is modified in place and returned.
          max_length:
              Target length. Replaced by the current length when `padding_strategy` is
              `PaddingStrategy.LONGEST`.
          padding_strategy:
              `PaddingStrategy.DO_NOT_PAD` disables padding; any other value pads up to `max_length`.
          pad_to_multiple_of:
              If set, `max_length` is rounded up to the next multiple of this value.
          padding_side:
              `"right"` or `"left"`. Falls back to `self.padding_side` when `None`.
          return_attention_mask:
              Whether to create and pad `attention_mask`. When `None`, it is enabled if `"attention_mask"` is
              listed in `self.model_input_names`.

      Returns:
          The same `encoded_inputs` object, padded where required.

      Raises:
          ValueError: If padding is required and the padding side is neither `"right"` nor `"left"`.
      """
      # Load from model defaults
      if return_attention_mask is None:
          return_attention_mask = "attention_mask" in self.model_input_names

      main_key = self.model_input_names[0]
      required_input = encoded_inputs[main_key]
      current_length = len(required_input)

      if padding_strategy == PaddingStrategy.LONGEST:
          max_length = current_length

      if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
          max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

      # Initialize attention mask if not present.
      if return_attention_mask and "attention_mask" not in encoded_inputs:
          encoded_inputs["attention_mask"] = [1] * current_length

      # Nothing more to do when padding is disabled or the example already has the target length.
      if padding_strategy == PaddingStrategy.DO_NOT_PAD or current_length == max_length:
          return encoded_inputs

      difference = max_length - current_length
      side = padding_side if padding_side is not None else self.padding_side
      if side not in ("right", "left"):
          raise ValueError("Invalid padding strategy:" + str(side))
      pad_on_right = side == "right"

      def _extended(values, fill):
          # Attach `difference` copies of `fill` on the selected side.
          filler = [fill] * difference
          return values + filler if pad_on_right else filler + values

      if return_attention_mask:
          encoded_inputs["attention_mask"] = _extended(encoded_inputs["attention_mask"], 0)

      # Optional features are padded only when present; the fill value is read from the
      # tokenizer attribute named in the second column.
      optional_features = (
          ("token_type_ids", "pad_token_type_id"),
          ("xpath_tags_seq", "pad_xpath_tags_seq"),
          ("xpath_subs_seq", "pad_xpath_subs_seq"),
          ("labels", "pad_token_label"),
      )
      for feature, fill_attribute in optional_features:
          if feature in encoded_inputs:
              encoded_inputs[feature] = _extended(encoded_inputs[feature], getattr(self, fill_attribute))

      if "special_tokens_mask" in encoded_inputs:
          encoded_inputs["special_tokens_mask"] = _extended(encoded_inputs["special_tokens_mask"], 1)

      encoded_inputs[main_key] = _extended(required_input, self.pad_token_id)

      return encoded_inputs
  867. def build_inputs_with_special_tokens(
  868. self, token_ids_0: list[int], token_ids_1: list[int] | None = None
  869. ) -> list[int]:
  870. """
  871. Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
  872. adding special tokens. A RoBERTa sequence has the following format:
  873. - single sequence: `<s> X </s>`
  874. - pair of sequences: `<s> A </s></s> B </s>`
  875. Args:
  876. token_ids_0 (`list[int]`):
  877. List of IDs to which the special tokens will be added.
  878. token_ids_1 (`list[int]`, *optional*):
  879. Optional second list of IDs for sequence pairs.
  880. Returns:
  881. `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
  882. """
  883. if token_ids_1 is None:
  884. return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
  885. cls = [self.cls_token_id]
  886. sep = [self.sep_token_id]
  887. return cls + token_ids_0 + sep + token_ids_1 + sep
  888. def create_token_type_ids_from_sequences(
  889. self, token_ids_0: list[int], token_ids_1: list[int] | None = None
  890. ) -> list[int]:
  891. """
  892. Create a mask from the two sequences passed to be used in a sequence-pair classification task. RoBERTa does not
  893. make use of token type ids, therefore a list of zeros is returned.
  894. Args:
  895. token_ids_0 (`list[int]`):
  896. List of IDs.
  897. token_ids_1 (`list[int]`, *optional*):
  898. Optional second list of IDs for sequence pairs.
  899. Returns:
  900. `list[int]`: List of zeros.
  901. """
  902. sep = [self.sep_token_id]
  903. cls = [self.cls_token_id]
  904. if token_ids_1 is None:
  905. return len(cls + token_ids_0 + sep) * [0]
  906. return len(cls + token_ids_0 + sep + token_ids_1 + sep) * [0]
  907. def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]:
  908. files = self._tokenizer.model.save(save_directory, name=filename_prefix)
  909. return tuple(files)
  # `MarkupLMTokenizerFast` is bound to the very same class object as `MarkupLMTokenizer`
  # (presumably so that code referring to the "Fast" name keeps working — TODO confirm).
  MarkupLMTokenizerFast = MarkupLMTokenizer
  # Public names exported by a star-import of this module.
  __all__ = ["MarkupLMTokenizer", "MarkupLMTokenizerFast"]