processing_pixtral.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. # Copyright 2024 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Processor class for Pixtral.
  16. """
  17. import numpy as np
  18. from ...feature_extraction_utils import BatchFeature
  19. from ...image_utils import ImageInput, is_valid_image
  20. from ...processing_utils import (
  21. MultiModalData,
  22. ProcessingKwargs,
  23. ProcessorMixin,
  24. Unpack,
  25. )
  26. from ...tokenization_utils_base import PreTokenizedInput, TextInput
  27. from ...utils import auto_docstring, is_vision_available, logging
  28. from ...utils.import_utils import requires
  29. if is_vision_available():
  30. from .image_processing_pixtral import get_resize_output_image_size
  31. logger = logging.get_logger(__name__)
class PixtralProcessorKwargs(ProcessingKwargs, total=False):
    # Keyword-argument schema for `PixtralProcessor.__call__`, which hands this class to
    # `self._merge_kwargs(...)` together with the caller's kwargs.
    # `_defaults` holds the per-group default values declared by this processor.
    _defaults = {
        "text_kwargs": {
            # Tokenizer padding is off unless the caller requests it.
            "padding": False,
            # `__call__` pops this flag from the text kwargs and, when it is true, adds an
            # `mm_token_type_ids` entry to the returned features.
            "return_mm_token_type_ids": False,
        },
        "common_kwargs": {
            # `__call__` pops `return_tensors` out of the merged text kwargs and uses it as the
            # `BatchFeature` tensor type.
            # NOTE(review): presumably `_merge_kwargs` propagates common kwargs into each
            # group - confirm against `ProcessorMixin`.
            "return_tensors": "pt",
        },
    }
# Returns True when `val` is a string that starts with "http"; anything else yields False.
# NOTE(review): the marker below says this body mirrors the idefics2 helper - presumably the two
# must stay identical, so change the idefics2 original rather than this copy; confirm.
# Copied from transformers.models.idefics2.processing_idefics2.is_url
def is_url(val) -> bool:
    return isinstance(val, str) and val.startswith("http")
# Returns a truthy value when `elem` is either accepted by `is_url` or by `is_valid_image`.
# NOTE(review): the marker below says this body mirrors the idefics2 helper - presumably the two
# must stay identical, so change the idefics2 original rather than this copy; confirm.
# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
def is_image_or_image_url(elem):
    return is_url(elem) or is_valid_image(elem)
  48. @auto_docstring
  49. @requires(backends=("torchvision", "torch"))
  50. class PixtralProcessor(ProcessorMixin):
  51. def __init__(
  52. self,
  53. image_processor=None,
  54. tokenizer=None,
  55. patch_size: int = 16,
  56. spatial_merge_size: int = 1,
  57. chat_template=None,
  58. image_token="[IMG]", # set the default and let users change if they have peculiar special tokens in rare cases
  59. image_break_token="[IMG_BREAK]",
  60. image_end_token="[IMG_END]",
  61. **kwargs,
  62. ):
  63. r"""
  64. patch_size (`int`, *optional*, defaults to 16):
  65. Patch size from the vision tower.
  66. spatial_merge_size (`int`, *optional*, defaults to 1):
  67. The downsampling factor for the spatial merge operation.
  68. image_token (`str`, *optional*, defaults to `"[IMG]"`):
  69. Special token used to denote image location.
  70. image_break_token (`str`, *optional*, defaults to `"[IMG_BREAK]"`):
  71. Special token used to denote the end of a line of pixels in an image.
  72. image_end_token (`str`, *optional*, defaults to `"[IMG_END]"`):
  73. Special token used to denote the end of an image input.
  74. """
  75. super().__init__(image_processor, tokenizer, chat_template=chat_template)
  76. self.patch_size = patch_size
  77. self.spatial_merge_size = spatial_merge_size
  78. self.image_token = image_token
  79. self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
  80. self.image_break_token = image_break_token
  81. self.image_end_token = image_end_token
  82. self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
  83. self.image_break_token_id = tokenizer.convert_tokens_to_ids(self.image_break_token)
  84. self.image_end_token_id = tokenizer.convert_tokens_to_ids(self.image_end_token)
  85. self.image_ids = [self.image_token_id, self.image_break_token_id, self.image_end_token_id]
  86. @auto_docstring
  87. def __call__(
  88. self,
  89. images: ImageInput | None = None,
  90. text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
  91. **kwargs: Unpack[PixtralProcessorKwargs],
  92. ) -> BatchFeature:
  93. r"""
  94. Returns:
  95. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  96. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  97. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  98. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  99. `None`).
  100. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  101. """
  102. output_kwargs = self._merge_kwargs(
  103. PixtralProcessorKwargs,
  104. tokenizer_init_kwargs=getattr(self.tokenizer, "init_kwargs", {}),
  105. **kwargs,
  106. )
  107. patch_size = self.patch_size * self.spatial_merge_size
  108. if images is not None:
  109. output_kwargs["images_kwargs"]["patch_size"] = patch_size
  110. image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
  111. else:
  112. image_inputs = {}
  113. if isinstance(text, str):
  114. text = [text]
  115. elif not isinstance(text, list) and not isinstance(text[0], str):
  116. raise TypeError("Invalid input text. Please provide a string, or a list of strings")
  117. # try to expand inputs in processing if we have the necessary parts
  118. prompt_strings = text
  119. if image_inputs.get("pixel_values") is not None:
  120. # Replace the image token with the expanded image token sequence
  121. image_sizes = iter(image_inputs["image_sizes"])
  122. prompt_strings = []
  123. replace_strings = []
  124. for sample in text:
  125. while self.image_token in sample:
  126. height, width = next(image_sizes)
  127. num_height_tokens = height // patch_size
  128. num_width_tokens = width // patch_size
  129. replace_tokens = [
  130. [self.image_token] * num_width_tokens + [self.image_break_token]
  131. ] * num_height_tokens
  132. # Flatten list
  133. replace_tokens = [item for sublist in replace_tokens for item in sublist]
  134. replace_tokens[-1] = self.image_end_token
  135. replace_str = "".join(replace_tokens)
  136. replace_strings.append(replace_str)
  137. sample = sample.replace(self.image_token, "<placeholder>", 1)
  138. while "<placeholder>" in sample:
  139. replace_str = replace_strings.pop(0)
  140. sample = sample.replace("<placeholder>", replace_str, 1)
  141. prompt_strings.append(sample)
  142. return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
  143. return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
  144. # Remove return_token_type_ids as MistralCommonBackend doesn't support it
  145. output_kwargs["text_kwargs"].pop("return_token_type_ids", None)
  146. text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
  147. self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
  148. if return_mm_token_type_ids:
  149. text_inputs["mm_token_type_ids"] = self.create_mm_token_type_ids(text_inputs["input_ids"])
  150. return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
  151. def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
  152. """
  153. Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
  154. Args:
  155. image_sizes (`list[list[int]]`, *optional*):
  156. The input sizes formatted as (height, width) per each image.
  157. Returns:
  158. `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
  159. input modalities, along with other useful data.
  160. """
  161. vision_data = {}
  162. if image_sizes is not None:
  163. images_kwargs = PixtralProcessorKwargs._defaults.get("images_kwargs", {})
  164. images_kwargs.update(kwargs)
  165. size = images_kwargs.get("size", None) or self.image_processor.size
  166. patch_size = self.patch_size * self.spatial_merge_size
  167. num_image_tokens = []
  168. for height, width in image_sizes:
  169. resized_height, resized_width = get_resize_output_image_size(
  170. np.zeros((height, width, 3)),
  171. size=(size["longest_edge"], size["longest_edge"]),
  172. patch_size=(patch_size, patch_size),
  173. )
  174. num_height_tokens = resized_height // patch_size
  175. num_width_tokens = resized_width // patch_size
  176. num_image_tokens.append((num_width_tokens + 1) * num_height_tokens)
  177. num_image_patches = [1] * len(image_sizes)
  178. vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
  179. return MultiModalData(**vision_data)
  180. @property
  181. def model_input_names(self):
  182. tokenizer_input_names = self.tokenizer.model_input_names
  183. image_processor_input_names = self.image_processor.model_input_names
  184. return tokenizer_input_names + image_processor_input_names + ["image_sizes"]
  185. __all__ = ["PixtralProcessor"]