processing_markuplm.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. # Copyright 2022 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Processor class for MarkupLM.
  16. """
  17. from ...file_utils import TensorType
  18. from ...processing_utils import ProcessorMixin
  19. from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TruncationStrategy
  20. from ...utils import auto_docstring
  21. @auto_docstring
  22. class MarkupLMProcessor(ProcessorMixin):
  23. parse_html = True
  24. def __init__(self, feature_extractor, tokenizer):
  25. super().__init__(feature_extractor, tokenizer)
  26. @auto_docstring
  27. def __call__(
  28. self,
  29. html_strings=None,
  30. nodes=None,
  31. xpaths=None,
  32. node_labels=None,
  33. questions=None,
  34. add_special_tokens: bool = True,
  35. padding: bool | str | PaddingStrategy = False,
  36. truncation: bool | str | TruncationStrategy = None,
  37. max_length: int | None = None,
  38. stride: int = 0,
  39. pad_to_multiple_of: int | None = None,
  40. return_token_type_ids: bool | None = None,
  41. return_attention_mask: bool | None = None,
  42. return_overflowing_tokens: bool = False,
  43. return_special_tokens_mask: bool = False,
  44. return_offsets_mapping: bool = False,
  45. return_length: bool = False,
  46. verbose: bool = True,
  47. return_tensors: str | TensorType | None = None,
  48. **kwargs,
  49. ) -> BatchEncoding:
  50. # first, create nodes and xpaths
  51. r"""
  52. html_strings (`str` or `list[str]`, *optional*):
  53. Raw HTML strings to parse and process. When `parse_html=True` (default), these strings are parsed
  54. to extract nodes and xpaths automatically. If provided, `nodes`, `xpaths`, and `node_labels` should
  55. not be provided. Required when `parse_html=True`.
  56. nodes (`list[list[str]]`, *optional*):
  57. Pre-extracted HTML nodes as a list of lists, where each inner list contains the text content of nodes
  58. for a single document. Required when `parse_html=False`. Should not be provided when `parse_html=True`.
  59. xpaths (`list[list[str]]`, *optional*):
  60. Pre-extracted XPath expressions corresponding to the nodes. Should be a list of lists with the same
  61. structure as `nodes`, where each XPath identifies the location of the corresponding node in the HTML
  62. tree. Required when `parse_html=False`. Should not be provided when `parse_html=True`.
  63. node_labels (`list[list[int]]`, *optional*):
  64. Labels for the nodes, typically used for training or fine-tuning tasks. Should be a list of lists
  65. with the same structure as `nodes`, where each label corresponds to a node. Optional and only used
  66. when `parse_html=False`.
  67. questions (`str` or `list[str]`, *optional*):
  68. Question strings for question-answering tasks. When provided, the tokenizer processes questions
  69. as the first sequence and nodes as the second sequence (text_pair). If a single string is provided,
  70. it is converted to a list to match the batch dimension of the parsed HTML.
  71. """
  72. if self.parse_html:
  73. if html_strings is None:
  74. raise ValueError("Make sure to pass HTML strings in case `parse_html` is set to `True`")
  75. if nodes is not None or xpaths is not None or node_labels is not None:
  76. raise ValueError(
  77. "Please don't pass nodes, xpaths nor node labels in case `parse_html` is set to `True`"
  78. )
  79. features = self.feature_extractor(html_strings)
  80. nodes = features["nodes"]
  81. xpaths = features["xpaths"]
  82. else:
  83. if html_strings is not None:
  84. raise ValueError("You have passed HTML strings but `parse_html` is set to `False`.")
  85. if nodes is None or xpaths is None:
  86. raise ValueError("Make sure to pass nodes and xpaths in case `parse_html` is set to `False`")
  87. # # second, apply the tokenizer
  88. if questions is not None and self.parse_html:
  89. if isinstance(questions, str):
  90. questions = [questions] # add batch dimension (as the feature extractor always adds a batch dimension)
  91. encoded_inputs = self.tokenizer(
  92. text=questions if questions is not None else nodes,
  93. text_pair=nodes if questions is not None else None,
  94. xpaths=xpaths,
  95. node_labels=node_labels,
  96. add_special_tokens=add_special_tokens,
  97. padding=padding,
  98. truncation=truncation,
  99. max_length=max_length,
  100. stride=stride,
  101. pad_to_multiple_of=pad_to_multiple_of,
  102. return_token_type_ids=return_token_type_ids,
  103. return_attention_mask=return_attention_mask,
  104. return_overflowing_tokens=return_overflowing_tokens,
  105. return_special_tokens_mask=return_special_tokens_mask,
  106. return_offsets_mapping=return_offsets_mapping,
  107. return_length=return_length,
  108. verbose=verbose,
  109. return_tensors=return_tensors,
  110. **kwargs,
  111. )
  112. return encoded_inputs
# Names exported by `from <this module> import *`.
__all__ = ["MarkupLMProcessor"]