# processing_llama4.py (14 KB)
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
  15. from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
  16. from ...image_processing_utils import BatchFeature
  17. from ...image_utils import ImageInput, make_flat_list_of_images
  18. from ...utils import auto_docstring
  19. class Llama4ProcessorKwargs(ProcessingKwargs, total=False):
  20. _defaults = {
  21. "text_kwargs": {
  22. "padding_side": "left",
  23. },
  24. }
  25. chat_template = "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %} \n {%- if messages[0]['content'] is string %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- else %}\n {#- FIXME: The processor requires an array, always. #}\n {%- set system_message = messages[0]['content'][0]['text']|trim %}\n {%- endif %}\n {%- set messages = messages[1:] %}\n {%- set user_supplied_system_message = true %}\n{%- else %}\n {%- set system_message = \"\" %}\n {%- set user_supplied_system_message = false %}\n{%- endif %}\n\n{#- System message if the user supplied one #}\n{%- if user_supplied_system_message %}\n {{- \"<|header_start|>system<|header_end|>\n\n\" }}\n {%- if tools is not none %}\n {{- \"Environment: ipython\n\" }}\n {%- endif %}\n {%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\n\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\n\n\" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- \"<|eot|>\" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|header_start|>user<|header_end|>\n\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\n\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\n\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\n\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \"<|eot|>\" }}\n {%- elif 'tool_calls' in message and message.tool_calls|length > 0 %}\n {{- '<|header_start|>assistant<|header_end|>\n\n' -}}\n {{- '<|python_start|>' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|python_end|>' }}\n {%- for tool_call in message.tool_calls %}\n {{- '{\"name\": \"' + tool_call.function.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.function.arguments | tojson }}\n {{- \"}\" }}\n {%- endfor %}\n {{- \"<|eot|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|header_start|>ipython<|header_end|>\n\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|header_start|>assistant<|header_end|>\n\n' }}\n{%- endif %}\n"
  26. @auto_docstring
  27. class Llama4Processor(ProcessorMixin):
  28. def __init__(
  29. self,
  30. image_processor=None,
  31. tokenizer=None,
  32. patch_size: int = 14,
  33. pixel_shuffle_ratio: float = 0.5,
  34. fake_image_token="<|image|>",
  35. image_token="<|image|>",
  36. start_of_image_token="<|image_start|>",
  37. end_of_image_token="<|image_end|>",
  38. patch_token="<|patch|>",
  39. tile_x_separator_token="<|tile_x_separator|>",
  40. tile_y_separator_token="<|tile_y_separator|>",
  41. chat_template=chat_template,
  42. **kwargs,
  43. ):
  44. r"""
  45. patch_size (`int`, *optional*, defaults to 28):
  46. The size of image patches for tokenization.
  47. pixel_shuffle_ratio (`float`, *optional*, defaults to `0.5`):
  48. The ratio used for pixel shuffling when processing images. This controls the downsampling factor
  49. applied to image patches. The actual downsampling ratio is calculated as `1 / (pixel_shuffle_ratio^2)`.
  50. fake_image_token (`str`, *optional*, defaults to `"<|image|>"`):
  51. The placeholder token in the text that will be replaced with actual image tokens. This token serves
  52. as a marker indicating where images should be inserted in the text sequence.
  53. image_token (`str`, *optional*, defaults to `"<|image|>"`):
  54. The token to be used to represent an image in the text.
  55. start_of_image_token (`str`, *optional*, defaults to `"<|image_start|>"`):
  56. The special token that marks the beginning of an image sequence in the text. This token is prepended
  57. to image token sequences to delimit image boundaries.
  58. end_of_image_token (`str`, *optional*, defaults to `"<|image_end|>"`):
  59. The special token that marks the end of an image sequence in the text. This token is appended to
  60. image token sequences to delimit image boundaries.
  61. patch_token (`str`, *optional*, defaults to `"<|patch|>"`):
  62. The token used to represent individual image patches. Multiple patch tokens are used to represent
  63. the full image, with the number depending on the image size and patch configuration.
  64. tile_x_separator_token (`str`, *optional*, defaults to `"<|tile_x_separator|>"`):
  65. The token used to separate tiles (patches) horizontally within an image. This token is inserted
  66. between patches in the same row when images are split into multiple tiles.
  67. tile_y_separator_token (`str`, *optional*, defaults to `"<|tile_y_separator|>"`):
  68. The token used to separate tiles (patches) vertically within an image. This token is inserted
  69. between rows of patches when images are split into multiple tiles.
  70. """
  71. super().__init__(image_processor, tokenizer, chat_template=chat_template)
  72. self.downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
  73. self.patch_size = patch_size
  74. self.fake_image_token = fake_image_token
  75. self.image_token = image_token
  76. self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
  77. self.start_of_img_token = start_of_image_token
  78. self.end_of_img_token = end_of_image_token
  79. self.img_patch_token = patch_token
  80. self.tile_token = tile_x_separator_token
  81. self.tile_global_token = tile_y_separator_token
  82. def _prompt_split_image(self, aspect_ratio, num_patches_per_chunk):
  83. """
  84. Create a structured string representation of image tokens
  85. Args:
  86. num_patches: Number of patches in the image
  87. Returns:
  88. String with appropriate image tokens
  89. """
  90. img_string = "<|image_start|>"
  91. ratio_h, ratio_w = aspect_ratio
  92. if ratio_h * ratio_w > 1:
  93. for yy in range(ratio_h):
  94. for xx in range(ratio_w):
  95. img_string += "<|patch|>" * num_patches_per_chunk
  96. if xx < ratio_w - 1:
  97. img_string += "<|tile_x_separator|>"
  98. img_string += "<|tile_y_separator|>"
  99. img_string += "<|image|>"
  100. img_string += "<|patch|>" * num_patches_per_chunk
  101. img_string += "<|image_end|>"
  102. return img_string
  103. @auto_docstring
  104. def __call__(
  105. self,
  106. images: ImageInput | None = None,
  107. text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
  108. **kwargs: Unpack[Llama4ProcessorKwargs],
  109. ) -> BatchFeature:
  110. r"""
  111. Returns:
  112. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  113. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  114. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  115. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  116. `None`).
  117. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  118. """
  119. if text is None:
  120. raise ValueError("You have to specify text.")
  121. output_kwargs = self._merge_kwargs(
  122. Llama4ProcessorKwargs,
  123. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  124. **kwargs,
  125. )
  126. if not isinstance(text, (list, tuple)):
  127. text = [text]
  128. # Process images
  129. image_inputs = {}
  130. if images is not None:
  131. images = self.image_processor.fetch_images(images)
  132. images = make_flat_list_of_images(images)
  133. image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
  134. image_height, image_width = image_inputs["pixel_values"][0].shape[-2:]
  135. num_patches_per_chunk = int(
  136. (image_height // self.patch_size) * (image_width // self.patch_size) // self.downsample_ratio
  137. )
  138. aspect_ratios = image_inputs.pop("aspect_ratios")
  139. total_placeholders = sum(prompt.count(self.fake_image_token) for prompt in text)
  140. if total_placeholders != len(images):
  141. raise ValueError(
  142. f"Found {total_placeholders} placeholders across the batch, "
  143. f"but have {len(images)} flattened images."
  144. )
  145. image_index = 0
  146. processed_text = []
  147. for prompt in text:
  148. placeholder_count = prompt.count(self.fake_image_token)
  149. if placeholder_count == 0:
  150. # do nothing if there is no image
  151. processed_text.append(prompt)
  152. continue
  153. prompt_splits = prompt.split(self.fake_image_token)
  154. new_prompt = []
  155. for local_image_index, split_part in enumerate(prompt_splits):
  156. new_prompt.append(split_part)
  157. if local_image_index < placeholder_count:
  158. tokens_for_this_image = self._prompt_split_image(
  159. aspect_ratios[image_index], num_patches_per_chunk
  160. )
  161. image_index += 1
  162. new_prompt.append(tokens_for_this_image)
  163. processed_text.append("".join(new_prompt))
  164. if image_index != len(images):
  165. raise ValueError("Number of image placeholders in the prompt does not match the number of images.")
  166. text = processed_text
  167. return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
  168. text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
  169. self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
  170. return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
  171. __all__ = ["Llama4Processor"]