processing_aria.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/aria/modular_aria.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_aria.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from ...image_processing_utils import BatchFeature
  21. from ...image_utils import ImageInput
  22. from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
  23. from ...tokenization_python import PreTokenizedInput, TextInput
  24. from ...utils import TensorType, auto_docstring
  25. from ..auto import AutoTokenizer
# Aria-specific image keyword arguments, declared on top of the shared `ImagesKwargs` base.
# The docstring below documents each key; the annotations declare the key names and value types.
# NOTE(review): presumably `ImagesKwargs` is a `TypedDict`, in which case `total=False` makes every key
# optional — confirm against `processing_utils`.
class AriaImagesKwargs(ImagesKwargs, total=False):
    """
    split_image (`bool`, *optional*, defaults to `False`):
        Whether to split large images into multiple crops. When enabled, images exceeding the maximum size are
        divided into overlapping crops that are processed separately and then combined. This allows processing
        of very high-resolution images that exceed the model's input size limits.
    max_image_size (`int`, *optional*, defaults to `980`):
        Maximum image size (in pixels) for a single image crop. Images larger than this will be split into
        multiple crops when `split_image=True`, or resized if splitting is disabled. This parameter controls
        the maximum resolution of individual image patches processed by the model.
    min_image_size (`int`, *optional*):
        Minimum image size (in pixels) for a single image crop. Images smaller than this will be upscaled to
        meet the minimum requirement. If not specified, images are processed at their original size (subject
        to the maximum size constraint).
    """

    split_image: bool
    max_image_size: int
    min_image_size: int
# Keyword-argument schema for `AriaProcessor`.
# The class is handed to `_merge_kwargs` in `AriaProcessor.__call__`, and the `images_kwargs` entry of
# `_defaults` is read by `AriaProcessor._get_num_multimodal_tokens`.
class AriaProcessorKwargs(ProcessingKwargs, total=False):
    # Image-related keys are typed by the Aria-specific `AriaImagesKwargs`.
    images_kwargs: AriaImagesKwargs
    # Default values, grouped by the kind of kwargs they belong to.
    _defaults = {
        "text_kwargs": {
            "padding": False,
            "return_mm_token_type_ids": False,
        },
        "images_kwargs": {
            "max_image_size": 980,
            "split_image": False,
        },
        "return_tensors": TensorType.PYTORCH,
    }
  57. @auto_docstring
  58. class AriaProcessor(ProcessorMixin):
  59. def __init__(
  60. self,
  61. image_processor=None,
  62. tokenizer: AutoTokenizer | str = None,
  63. chat_template: str | None = None,
  64. size_conversion: dict[float | int, int] | None = None,
  65. ):
  66. r"""
  67. size_conversion (`Dict`, *optional*):
  68. A dictionary indicating size conversions for images.
  69. """
  70. if size_conversion is None:
  71. size_conversion = {490: 128, 980: 256}
  72. self.size_conversion = {int(k): v for k, v in size_conversion.items()}
  73. self.image_token = tokenizer.image_token
  74. self.image_token_id = tokenizer.image_token_id
  75. if tokenizer is not None and tokenizer.pad_token is None:
  76. tokenizer.pad_token = tokenizer.unk_token
  77. super().__init__(image_processor, tokenizer, chat_template=chat_template)
  78. @auto_docstring
  79. def __call__(
  80. self,
  81. text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput],
  82. images: ImageInput | None = None,
  83. **kwargs: Unpack[AriaProcessorKwargs],
  84. ) -> BatchFeature:
  85. r"""
  86. Returns:
  87. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  88. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  89. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  90. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  91. `None`).
  92. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  93. - **pixel_mask** -- Pixel mask to be fed to a model. Returned when `images` is not `None`.
  94. """
  95. output_kwargs = self._merge_kwargs(
  96. AriaProcessorKwargs,
  97. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  98. **kwargs,
  99. )
  100. if isinstance(text, str):
  101. text = [text]
  102. elif not isinstance(text, list) and not isinstance(text[0], str):
  103. raise TypeError("Invalid input text. Please provide a string, or a list of strings")
  104. if images is not None:
  105. image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
  106. # expand the image_token according to the num_crops and tokens per image
  107. tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]]
  108. prompt_strings = []
  109. num_crops = image_inputs.pop("num_crops") * tokens_per_image
  110. for sample in text:
  111. sample = sample.replace(self.tokenizer.image_token, self.tokenizer.image_token * num_crops)
  112. prompt_strings.append(sample)
  113. else:
  114. image_inputs = {}
  115. prompt_strings = text
  116. return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
  117. return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
  118. text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
  119. self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
  120. if return_mm_token_type_ids:
  121. text_inputs["mm_token_type_ids"] = self.create_mm_token_type_ids(text_inputs["input_ids"])
  122. return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
  123. def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
  124. """
  125. Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
  126. Args:
  127. image_sizes (`list[list[int]]`, *optional*):
  128. The input sizes formatted as (height, width) per each image.
  129. Returns:
  130. `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
  131. input modalities, along with other useful data.
  132. """
  133. vision_data = {}
  134. if image_sizes is not None:
  135. images_kwargs = AriaProcessorKwargs._defaults.get("images_kwargs", {})
  136. images_kwargs.update(kwargs)
  137. max_size = images_kwargs.get("max_image_size", None) or self.image_processor.max_image_size
  138. num_image_patches = [
  139. self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
  140. for image_size in image_sizes
  141. ]
  142. num_image_tokens = [self.size_conversion[max_size] * num_patches for num_patches in num_image_patches]
  143. vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
  144. return MultiModalData(**vision_data)
  145. @property
  146. def model_input_names(self):
  147. tokenizer_input_names = self.tokenizer.model_input_names
  148. image_processor_input_names = self.image_processor.model_input_names
  149. # Remove `num_crops`, it is popped and used only when processing. Make a copy of list when removing
  150. # otherwise `self.image_processor.model_input_names` is also modified
  151. image_processor_input_names = [name for name in image_processor_input_names if name != "num_crops"]
  152. return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
# Names exported by `from <module> import *`: only the processor class is listed.
__all__ = ["AriaProcessor"]