processing_mllama.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. # Copyright 2024 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Processor class for Mllama."""
  15. import numpy as np
  16. from ...feature_extraction_utils import BatchFeature
  17. from ...image_utils import ImageInput, make_nested_list_of_images
  18. from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
  19. from ...tokenization_utils_base import PreTokenizedInput, TextInput
  20. from ...utils import auto_docstring
  21. class MllamaProcessorKwargs(ProcessingKwargs, total=False):
  22. _defaults = {
  23. "image_kwargs": {
  24. "max_image_tiles": 4,
  25. },
  26. }
  27. def get_cross_attention_token_mask(input_ids: list[int], image_token_id: int) -> list[list[int]]:
  28. """
  29. Generate a cross-attention token mask for image tokens in the input sequence.
  30. This function identifies the positions of image tokens in the input sequence and creates
  31. a mask that defines which subsequent tokens each image token should attend to.
  32. Args:
  33. input_ids (list[int]): A list of token ids representing the input sequence.
  34. image_token_id (int): The id of the token used to represent images in the sequence.
  35. Returns:
  36. list[list[int]]: A list of [start, end] pairs, where each pair represents the range
  37. of tokens an image token should attend to.
  38. Notes:
  39. - If no image tokens are present, an empty list is returned.
  40. - For a single image token, it attends to all subsequent tokens until the end of the sequence.
  41. - For multiple image tokens, each attends to tokens up to the next image token or the end of the sequence.
  42. - Consecutive image tokens are treated as a group and attend to all subsequent tokens together.
  43. """
  44. image_token_locations = [i for i, token in enumerate(input_ids) if token == image_token_id]
  45. if len(image_token_locations) == 0:
  46. return []
  47. # only one image present, unmask until end of sequence
  48. if len(image_token_locations) == 1:
  49. return [[image_token_locations[0], -1]]
  50. vision_masks = [[loc1, loc2] for loc1, loc2 in zip(image_token_locations[:-1], image_token_locations[1:])]
  51. # last image will attend to all subsequent text
  52. vision_masks.append([image_token_locations[-1], len(input_ids)])
  53. # if there are two or more consecutive vision tokens,
  54. # they should all attend to all subsequent
  55. # text present
  56. last_mask_end = vision_masks[-1][1]
  57. for vision_mask in vision_masks[::-1]:
  58. if vision_mask[0] == vision_mask[1] - 1:
  59. vision_mask[1] = last_mask_end
  60. last_mask_end = vision_mask[1]
  61. return vision_masks
  62. def convert_sparse_cross_attention_mask_to_dense(
  63. cross_attention_token_mask: list[list[list[int]]],
  64. num_tiles: list[list[int]],
  65. max_num_tiles: int,
  66. length: int,
  67. ) -> np.ndarray:
  68. """
  69. Convert the cross attention mask indices to a cross attention mask 4D array.
  70. This function takes a sparse representation of cross attention masks and converts it to a dense 4D numpy array.
  71. The sparse representation is a nested list structure that defines attention ranges for each image in each batch item.
  72. Args:
  73. cross_attention_token_mask (list[list[list[int]]]): A nested list structure where:
  74. - The outer list represents the batch dimension.
  75. - The middle list represents different images within each batch item.
  76. - The inner list contains pairs of integers [start, end] representing token ranges for each image.
  77. num_tiles (list[list[int]]): A nested list structure specifying the number of tiles for each image in each batch item.
  78. max_num_tiles (int): The maximum possible number of tiles.
  79. length (int): The total sequence length of the input.
  80. Returns:
  81. np.ndarray: A 4D numpy array of shape (batch_size, length, max_num_images, max_num_tiles)
  82. The array contains `1` where attention is allowed and `0` where it is not.
  83. Note:
  84. - Special handling is done for cases where the end token is -1, which is interpreted as attending to the end of the sequence.
  85. """
  86. batch_size = len(cross_attention_token_mask)
  87. max_num_images = max(len(masks) for masks in cross_attention_token_mask)
  88. cross_attention_mask = np.zeros(
  89. shape=(batch_size, length, max_num_images, max_num_tiles),
  90. dtype=np.int64,
  91. )
  92. for sample_idx, (sample_masks, sample_num_tiles) in enumerate(zip(cross_attention_token_mask, num_tiles)):
  93. for mask_idx, (locations, mask_num_tiles) in enumerate(zip(sample_masks, sample_num_tiles)):
  94. if len(locations) == 2:
  95. start, end = locations
  96. end = min(end, length)
  97. if end == -1:
  98. end = length
  99. cross_attention_mask[sample_idx, start:end, mask_idx, :mask_num_tiles] = 1
  100. return cross_attention_mask
  101. def build_string_from_input(prompt: str, bos_token: str, image_token: str) -> str:
  102. """
  103. Builds a string from the input prompt by adding `bos_token` if not already present.
  104. Args:
  105. prompt (`str`):
  106. The input prompt string.
  107. bos_token (`str`):
  108. The beginning of sentence token to be added.
  109. image_token (`str`):
  110. The image token used to identify the start of an image sequence.
  111. Returns:
  112. str: The modified prompt string with the `bos_token` added if necessary.
  113. Examples:
  114. >>> build_string_from_input("Hello world", "<begin_of_text>", "<|image|>")
  115. '<begin_of_text>Hello world'
  116. >>> build_string_from_input("<|image|>Hello world", "<begin_of_text>", "<|image|>")
  117. '<|image|><begin_of_text>Hello world'
  118. >>> build_string_from_input("<begin_of_text>Hello world", "<begin_of_text>", "<|image|>")
  119. '<begin_of_text>Hello world'
  120. """
  121. if bos_token in prompt:
  122. return prompt
  123. num_image_tokens_on_start = 0
  124. while prompt.startswith(image_token):
  125. prompt = prompt[len(image_token) :]
  126. num_image_tokens_on_start += 1
  127. return f"{image_token * num_image_tokens_on_start}{bos_token}{prompt}"
  128. @auto_docstring
  129. class MllamaProcessor(ProcessorMixin):
  130. def __init__(self, image_processor, tokenizer, chat_template=None):
  131. if not hasattr(tokenizer, "image_token"):
  132. self.image_token = "<|image|>"
  133. self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
  134. else:
  135. self.image_token = tokenizer.image_token
  136. self.image_token_id = tokenizer.image_token_id
  137. self.python_token = "<|python_tag|>"
  138. self.python_token_id = tokenizer.convert_tokens_to_ids(self.python_token)
  139. self.bos_token = tokenizer.bos_token
  140. super().__init__(image_processor, tokenizer, chat_template=chat_template)
  141. @auto_docstring
  142. def __call__(
  143. self,
  144. images: ImageInput | None = None,
  145. text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
  146. **kwargs: Unpack[MllamaProcessorKwargs],
  147. ) -> BatchFeature:
  148. r"""
  149. Returns:
  150. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  151. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  152. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  153. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  154. `None`).
  155. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  156. TODO: add aspect_ratio_ids and aspect_ratio_mask and cross_attention_mask
  157. """
  158. if text is None and images is None:
  159. raise ValueError("You must specify either text or images.")
  160. output_kwargs = self._merge_kwargs(
  161. MllamaProcessorKwargs,
  162. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  163. **kwargs,
  164. )
  165. return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
  166. data = {}
  167. if text is not None:
  168. if isinstance(text, str):
  169. text = [text]
  170. elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
  171. raise ValueError("Invalid input text. Please provide a string, or a list of strings")
  172. n_images_in_text = [t.count(self.image_token) for t in text]
  173. text = [build_string_from_input(text_item, self.bos_token, self.image_token) for text_item in text]
  174. encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
  175. self._check_special_mm_tokens(text, encoding, modalities=["image"])
  176. n_images_in_ids = [token_ids.count(self.image_token_id) for token_ids in encoding["input_ids"]]
  177. data.update(encoding)
  178. n_images_in_images = [0]
  179. if images is not None:
  180. images = self.image_processor.fetch_images(images)
  181. images = make_nested_list_of_images(images)
  182. n_images_in_images = [len(sample) for sample in images]
  183. if text is not None:
  184. if any(batch_img == 0 for batch_img in n_images_in_text) and not all(
  185. batch_img == 0 for batch_img in n_images_in_text
  186. ):
  187. raise ValueError(
  188. "If a batch of text is provided, there should be either no images or at least one image per sample"
  189. )
  190. if sum(n_images_in_text) > 0 and (
  191. n_images_in_images != n_images_in_text or n_images_in_ids != n_images_in_images
  192. ):
  193. if images is None:
  194. raise ValueError("No image were provided, but there are image tokens in the prompt")
  195. else:
  196. add_message = ""
  197. if sum(n_images_in_images) == sum(n_images_in_text) and n_images_in_images != n_images_in_text:
  198. add_message = "Make sure to pass your images as a nested list, where each sub-list holds images per batch"
  199. elif n_images_in_ids != n_images_in_images:
  200. add_message = "If you activated truncation with `max_length`, increase the `max_length` so image tokens aren't cropped."
  201. raise ValueError(
  202. f"The number of image tokens in each text ({n_images_in_text}) should be the same as the "
  203. f"number of provided images per batch ({n_images_in_images}). {add_message}"
  204. )
  205. if images is not None:
  206. image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
  207. num_tiles = image_features.pop("num_tiles")
  208. data.update(image_features)
  209. # Create cross attention mask
  210. if images is not None and text is not None:
  211. cross_attention_token_mask = [
  212. get_cross_attention_token_mask(token_ids, self.image_token_id) for token_ids in encoding["input_ids"]
  213. ]
  214. cross_attention_mask = convert_sparse_cross_attention_mask_to_dense(
  215. cross_attention_token_mask,
  216. num_tiles=num_tiles,
  217. max_num_tiles=self.image_processor.max_image_tiles,
  218. length=max(len(input_ids) for input_ids in encoding["input_ids"]),
  219. )
  220. data["cross_attention_mask"] = cross_attention_mask
  221. return BatchFeature(data=data, tensor_type=return_tensors)
  222. def post_process_image_text_to_text(
  223. self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
  224. ):
  225. """
  226. Post-process the output of the model to decode the text.
  227. Args:
  228. generated_outputs (`torch.Tensor` or `np.ndarray`):
  229. The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
  230. or `(sequence_length,)`.
  231. skip_special_tokens (`bool`, *optional*, defaults to `True`):
  232. Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
  233. clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
  234. Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
  235. **kwargs:
  236. Additional arguments to be passed to the tokenizer's `batch_decode method`.
  237. Returns:
  238. `list[str]`: The decoded text.
  239. """
  240. return self.tokenizer.batch_decode(
  241. generated_outputs,
  242. skip_special_tokens=skip_special_tokens,
  243. clean_up_tokenization_spaces=clean_up_tokenization_spaces,
  244. **kwargs,
  245. )
  246. @property
  247. def model_input_names(self):
  248. tokenizer_input_names = self.tokenizer.model_input_names
  249. image_processor_input_names = self.image_processor.model_input_names
  250. # Remove `num_tiles`, it is popped and used only when processing. Make a copy of list when removing
  251. # otherwise `self.image_processor.model_input_names` is also modified
  252. image_processor_input_names = [name for name in image_processor_input_names if name != "num_tiles"]
  253. return list(tokenizer_input_names + image_processor_input_names + ["cross_attention_mask"])
# Names exported by `from <module> import *`.
__all__ = ["MllamaProcessor"]