processing_internvl.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. # Copyright 2025 HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from ...image_processing_utils import BatchFeature
  16. from ...image_utils import ImageInput, concatenate_list, make_flat_list_of_images
  17. from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
  18. from ...tokenization_utils_base import PreTokenizedInput, TextInput
  19. from ...utils import auto_docstring
  20. from ...video_utils import VideoInput
  21. class InternVLProcessorKwargs(ProcessingKwargs, total=False):
  22. _defaults = {
  23. "text_kwargs": {
  24. "padding_side": "left",
  25. "return_mm_token_type_ids": False,
  26. },
  27. "images_kwargs": {
  28. "crop_to_patches": True,
  29. },
  30. "videos_kwargs": {
  31. "return_tensors": "pt",
  32. },
  33. }
  34. @auto_docstring
  35. class InternVLProcessor(ProcessorMixin):
  36. def __init__(
  37. self,
  38. image_processor=None,
  39. tokenizer=None,
  40. video_processor=None,
  41. image_seq_length: int = 256,
  42. chat_template=None,
  43. **kwargs,
  44. ):
  45. r"""
  46. image_seq_length (`int`, *optional*, defaults to 256):
  47. The number of image token to use per image patch. it should be set so that:
  48. image_seq_length = (config.image_size // config.patch_size) ** 2 * (config.scale_factor**2)
  49. """
  50. super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template, **kwargs)
  51. self.image_seq_length = image_seq_length
  52. self.start_image_token = tokenizer.start_image_token
  53. self.end_image_token = tokenizer.end_image_token
  54. self.start_image_token_id = tokenizer.start_image_token_id
  55. self.end_image_token_id = tokenizer.end_image_token_id
  56. self.image_token = tokenizer.context_image_token
  57. self.video_token = tokenizer.video_token
  58. self.image_token_id = tokenizer.context_image_token_id
  59. self.image_ids = [self.image_token_id, self.start_image_token_id, self.end_image_token_id]
  60. def _insert_media_placeholders(
  61. self,
  62. text: list[str],
  63. image_pixel_values,
  64. video_pixel_values,
  65. image_num_patches: list[int],
  66. video_num_patches: list[int],
  67. image_num_patches_indices: np.ndarray,
  68. video_num_patches_indices: np.ndarray,
  69. video_patch_indices: np.ndarray,
  70. ):
  71. """
  72. Processes interleaved text with <image> and <video> placeholders, replacing them with appropriate
  73. image and video tokens while keeping track of the patches used.
  74. """
  75. image_index = 0
  76. video_index = 0
  77. processed_text = []
  78. image_video_patches = []
  79. replace_strings = []
  80. # Support interleaved image and video in prompts:
  81. # Processed patches of images and videos are inserted in `image_video_patches` in the order they appear in the prompts
  82. for prompt in text:
  83. new_prompt = prompt
  84. while self.image_token in new_prompt or self.video_token in new_prompt:
  85. if self.image_token in new_prompt and (
  86. self.video_token not in new_prompt
  87. or new_prompt.index(self.image_token) < new_prompt.index(self.video_token)
  88. ):
  89. # Get the slice of patches corresponding to the current image
  90. start_index = image_num_patches_indices[image_index - 1] if image_index > 0 else 0
  91. end_index = image_num_patches_indices[image_index]
  92. image_video_patches.append(image_pixel_values[start_index:end_index])
  93. # Replace the corresponding image placeholder with the correct number of image tokens
  94. new_prompt = new_prompt.replace(self.image_token, "<placeholder>", 1)
  95. replace_strings.append(
  96. f"{self.start_image_token}{self.image_token * self.image_seq_length * image_num_patches[image_index]}{self.end_image_token}"
  97. )
  98. image_index += 1
  99. else:
  100. # Get the slice of patches corresponding to the current video
  101. # Here we need to account for both the multiple video frames and the potential multiple patches per frame
  102. # As of now, InternVL only supports one patch per frame, but we keep the code flexible for future updates
  103. current_patch_index = video_patch_indices[video_index]
  104. end_patch_index = video_patch_indices[video_index + 1]
  105. start_index = video_num_patches_indices[current_patch_index]
  106. end_index = video_num_patches_indices[end_patch_index]
  107. image_video_patches.append(video_pixel_values[start_index:end_index])
  108. # Get the number of patches per frame and replace the video placeholder with the correct number of image tokens
  109. num_patches = list(video_num_patches[current_patch_index:end_patch_index])
  110. video_prompt = "\n".join(
  111. f"Frame{i + 1}: {self.start_image_token}{self.image_token * self.image_seq_length * num_patches[i]}{self.end_image_token}"
  112. for i in range(len(num_patches))
  113. )
  114. replace_strings.append(video_prompt)
  115. new_prompt = new_prompt.replace(self.video_token, "<placeholder>", 1)
  116. video_index += 1
  117. while "<placeholder>" in new_prompt:
  118. replace_str = replace_strings.pop(0)
  119. new_prompt = new_prompt.replace("<placeholder>", replace_str, 1)
  120. processed_text.append(new_prompt)
  121. return processed_text, image_video_patches, image_index, video_index
  122. @auto_docstring
  123. def __call__(
  124. self,
  125. images: ImageInput | None = None,
  126. text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
  127. videos: VideoInput | None = None,
  128. **kwargs: Unpack[InternVLProcessorKwargs],
  129. ) -> BatchFeature:
  130. r"""
  131. Returns:
  132. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  133. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  134. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  135. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  136. `None`).
  137. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  138. """
  139. if text is None:
  140. raise ValueError("You have to specify text.")
  141. output_kwargs = self._merge_kwargs(
  142. InternVLProcessorKwargs,
  143. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  144. **kwargs,
  145. )
  146. if not isinstance(text, (list, tuple)):
  147. text = [text]
  148. # Process images and videos separately, as videos don't support crop_to_patches
  149. image_num_patches = []
  150. image_pixel_values = None
  151. image_num_patches_indices = np.array([0])
  152. if images is not None:
  153. images = self.image_processor.fetch_images(images)
  154. images = make_flat_list_of_images(images)
  155. image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
  156. image_num_patches = image_inputs.pop("num_patches")
  157. image_pixel_values = image_inputs.pop("pixel_values")
  158. image_num_patches_indices = np.cumsum(image_num_patches)
  159. video_num_patches = [] # per frame
  160. video_pixel_values = None
  161. video_patch_indices = np.array([0])
  162. video_num_patches_indices = np.array([0])
  163. if videos is not None:
  164. video_kwargs = output_kwargs["videos_kwargs"]
  165. video_inputs = self.video_processor(videos=videos, **video_kwargs)
  166. video_pixel_values = video_inputs.pop("pixel_values_videos")
  167. batch_size, num_frames, *_ = video_pixel_values.shape
  168. num_frames_per_video = np.full(batch_size, num_frames)
  169. num_frames = sum(num_frames_per_video) # total
  170. video_patch_indices = np.empty(batch_size + 1, int)
  171. video_patch_indices[0] = 0
  172. video_patch_indices[1:] = np.cumsum(num_frames_per_video)
  173. video_num_patches = [1] * num_frames
  174. video_num_patches_indices = np.empty(num_frames + 1, int)
  175. video_num_patches_indices[0] = 0
  176. video_num_patches_indices[1:] = np.cumsum(video_num_patches)
  177. video_pixel_values = video_pixel_values.flatten(0, 1)
  178. image_videos_inputs = {}
  179. if images is not None or videos is not None:
  180. text, image_video_patches, image_index, video_index = self._insert_media_placeholders(
  181. text,
  182. image_pixel_values,
  183. video_pixel_values,
  184. image_num_patches,
  185. video_num_patches,
  186. image_num_patches_indices,
  187. video_num_patches_indices,
  188. video_patch_indices,
  189. )
  190. if images is not None and image_index != len(images):
  191. raise ValueError("Number of image placeholders in the prompt does not match the number of images.")
  192. if videos is not None and video_index != len(num_frames_per_video):
  193. raise ValueError("Number of video placeholders in the prompt does not match the number of videos.")
  194. # Concatenate the interleaved image and video patches (function agnostic to the patches type (list, numpy array, torch tensor))
  195. image_videos_inputs = {"pixel_values": concatenate_list(image_video_patches)}
  196. return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
  197. return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
  198. text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
  199. self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
  200. if return_mm_token_type_ids:
  201. text_inputs["mm_token_type_ids"] = self.create_mm_token_type_ids(text_inputs["input_ids"])
  202. return BatchFeature(data={**text_inputs, **image_videos_inputs}, tensor_type=return_tensors)
  203. def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
  204. """
  205. Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
  206. Args:
  207. image_sizes (`list[list[int]]`, *optional*):
  208. The input sizes formatted as (height, width) per each image.
  209. Returns:
  210. `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
  211. input modalities, along with other useful data.
  212. """
  213. vision_data = {}
  214. if image_sizes is not None:
  215. images_kwargs = InternVLProcessorKwargs._defaults.get("images_kwargs", {})
  216. images_kwargs.update(kwargs)
  217. num_image_patches = [
  218. self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
  219. for image_size in image_sizes
  220. ]
  221. # Add 2 for BOI and EOI tokens
  222. num_image_tokens = [2 + (self.image_seq_length * num_patches) for num_patches in num_image_patches]
  223. vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
  224. return MultiModalData(**vision_data)
  225. @property
  226. def model_input_names(self):
  227. # Overwritten because InternVL renames video inputs to `pixel_values` before returning
  228. tokenizer_input_names = self.tokenizer.model_input_names
  229. image_processor_input_names = self.image_processor.model_input_names
  230. return tokenizer_input_names + image_processor_input_names
  231. __all__ = ["InternVLProcessor"]