processing_nougat.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. # Copyright 2023 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Processor class for Nougat.
  16. """
  17. from typing import Optional, Union
  18. from transformers.tokenization_utils_base import PreTokenizedInput, TextInput, TruncationStrategy
  19. from ...processing_utils import ProcessorMixin
  20. from ...utils import PaddingStrategy, TensorType, auto_docstring
  @auto_docstring
  class NougatProcessor(ProcessorMixin):
      # Bundles an image processor and a tokenizer behind a single callable.
      # Notes on the class and on `__init__` are written as `#` comments so the
      # docstring text visible to `@auto_docstring` stays exactly as it was.

      def __init__(self, image_processor, tokenizer):
          # All set-up is delegated to `ProcessorMixin`; nothing Nougat-specific happens here.
          super().__init__(image_processor, tokenizer)

      @auto_docstring
      def __call__(
          self,
          images=None,
          text=None,
          do_crop_margin: bool | None = None,
          do_resize: bool | None = None,
          size: dict[str, int] | None = None,
          resample: "PILImageResampling" = None,  # noqa: F821
          do_thumbnail: bool | None = None,
          do_align_long_axis: bool | None = None,
          do_pad: bool | None = None,
          do_rescale: bool | None = None,
          rescale_factor: int | float | None = None,
          do_normalize: bool | None = None,
          image_mean: float | list[float] | None = None,
          image_std: float | list[float] | None = None,
          data_format: Optional["ChannelDimension"] = "channels_first",  # noqa: F821
          input_data_format: Union[str, "ChannelDimension"] | None = None,  # noqa: F821
          text_pair: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
          text_target: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
          text_pair_target: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
          add_special_tokens: bool = True,
          padding: bool | str | PaddingStrategy = False,
          truncation: bool | str | TruncationStrategy | None = None,
          max_length: int | None = None,
          stride: int = 0,
          is_split_into_words: bool = False,
          pad_to_multiple_of: int | None = None,
          return_tensors: str | TensorType | None = None,
          return_token_type_ids: bool | None = None,
          return_attention_mask: bool | None = None,
          return_overflowing_tokens: bool = False,
          return_special_tokens_mask: bool = False,
          return_offsets_mapping: bool = False,
          return_length: bool = False,
          verbose: bool = True,
      ):
          r"""
          do_crop_margin (`bool`, *optional*):
              Whether to automatically crop the white margins of document images. When enabled, the white space
              around the edges of a document page is detected and removed, which is useful for scanned documents
              or PDFs with large margins.
          do_thumbnail (`bool`, *optional*):
              Whether to create a thumbnail version of the image. When enabled, a smaller version of the image is
              generated alongside the main processed image, which can be useful for previews or faster processing.
          do_align_long_axis (`bool`, *optional*):
              Whether to automatically align images so that their longer axis is horizontal. When enabled,
              portrait images are rotated to landscape orientation, which is typically better suited to document
              processing tasks.
          """
          # Outcome, depending on which inputs were supplied:
          #   * images only -> the image processor's output
          #   * text only   -> the tokenizer's output
          #   * both        -> the image processor's output, with the tokenizer's
          #                    "input_ids" stored under the "labels" key
          #   * neither     -> ValueError
          if images is None and text is None:
              raise ValueError("You need to specify either an `images` or `text` input to process.")

          image_features = None
          if images is not None:
              # Only image-related options (plus `return_tensors`) reach the image processor.
              image_options = {
                  "do_crop_margin": do_crop_margin,
                  "do_resize": do_resize,
                  "size": size,
                  "resample": resample,
                  "do_thumbnail": do_thumbnail,
                  "do_align_long_axis": do_align_long_axis,
                  "do_pad": do_pad,
                  "do_rescale": do_rescale,
                  "rescale_factor": rescale_factor,
                  "do_normalize": do_normalize,
                  "image_mean": image_mean,
                  "image_std": image_std,
                  "return_tensors": return_tensors,
                  "data_format": data_format,
                  "input_data_format": input_data_format,
              }
              image_features = self.image_processor(images, **image_options)

          # Without text there is nothing to tokenize; `images` is guaranteed to be set here.
          if text is None:
              return image_features

          # Only text-related options (plus `return_tensors`) reach the tokenizer.
          text_options = {
              "text_pair": text_pair,
              "text_target": text_target,
              "text_pair_target": text_pair_target,
              "add_special_tokens": add_special_tokens,
              "padding": padding,
              "truncation": truncation,
              "max_length": max_length,
              "stride": stride,
              "is_split_into_words": is_split_into_words,
              "pad_to_multiple_of": pad_to_multiple_of,
              "return_tensors": return_tensors,
              "return_token_type_ids": return_token_type_ids,
              "return_attention_mask": return_attention_mask,
              "return_overflowing_tokens": return_overflowing_tokens,
              "return_special_tokens_mask": return_special_tokens_mask,
              "return_offsets_mapping": return_offsets_mapping,
              "return_length": return_length,
              "verbose": verbose,
          }
          text_encodings = self.tokenizer(text, **text_options)

          if images is None:
              return text_encodings

          # Both modalities were given: attach the token ids to the image features as labels.
          image_features["labels"] = text_encodings["input_ids"]
          return image_features

      def post_process_generation(self, *args, **kwargs):
          """
          Forward every positional and keyword argument to the tokenizer's
          [`~PreTrainedTokenizer.post_process_generation`] and return its result unchanged.
          See that method's docstring for details.
          """
          tokenizer_post_process = self.tokenizer.post_process_generation
          return tokenizer_post_process(*args, **kwargs)
  # Names exported by `from <this module> import *`: only the processor class is public.
  131. __all__ = ["NougatProcessor"]