processing_gemma4.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. # Copyright 2026 the HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import re
  15. import numpy as np
  16. from ...audio_utils import AudioInput
  17. from ...image_processing_utils import BatchFeature
  18. from ...image_utils import ImageInput, make_nested_list_of_images
  19. from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
  20. from ...tokenization_utils_base import PreTokenizedInput, TextInput
  21. from ...utils import auto_docstring, is_vision_available, logging
  22. from ...utils.import_utils import requires
  23. from ...video_utils import VideoInput
  24. if is_vision_available():
  25. from .image_processing_pil_gemma4 import Gemma4ImageProcessorKwargs, get_aspect_ratio_preserving_size
  26. logger = logging.get_logger(__name__)
  27. class Gemma4ProcessorKwargs(ProcessingKwargs, total=False):
  28. images_kwargs: Gemma4ImageProcessorKwargs
  29. _defaults = {
  30. "text_kwargs": {
  31. "padding": True,
  32. "return_mm_token_type_ids": True,
  33. },
  34. "images_kwargs": {
  35. "do_convert_rgb": True,
  36. },
  37. "audio_kwargs": {},
  38. "videos_kwargs": {"return_metadata": True},
  39. }
  40. @auto_docstring
  41. @requires(backends=("vision",))
  42. class Gemma4Processor(ProcessorMixin):
  43. def __init__(
  44. self,
  45. feature_extractor,
  46. image_processor,
  47. tokenizer,
  48. video_processor,
  49. chat_template=None,
  50. image_seq_length: int = 280,
  51. audio_seq_length: int = 750,
  52. audio_ms_per_token: int = 40,
  53. **kwargs,
  54. ):
  55. r"""
  56. image_seq_length (`int`, *optional*, defaults to 280):
  57. The number of soft tokens per image used for placeholder expansion.
  58. audio_seq_length (`int`, *optional*, defaults to 750):
  59. The maximum number of audio soft tokens per audio segment. Serves as an
  60. upper-bound cap when dynamic audio token counts are computed.
  61. audio_ms_per_token (`int`, *optional*, defaults to 40):
  62. Milliseconds of audio per output soft token. Used to dynamically compute
  63. the number of audio placeholder tokens as ``ceil(duration_ms / audio_ms_per_token)``.
  64. The default of 40 comes from the SSCP convolution's 4× time reduction on 10ms frames.
  65. """
  66. self.image_seq_length = image_seq_length
  67. self.image_token_id = tokenizer.image_token_id
  68. self.boi_token = tokenizer.boi_token
  69. self.eoi_token = tokenizer.eoi_token
  70. self.image_token = tokenizer.image_token
  71. # FIXME: add the token to config and ask Ryan to re-upload
  72. tokenizer.add_special_tokens({"additional_special_tokens": ["<|video|>"]})
  73. self.video_token = "<|video|>"
  74. self.video_token_id = tokenizer.convert_tokens_to_ids(self.video_token)
  75. # Audio token handling, mirroring the vision pattern.
  76. # audio_seq_length serves as the maximum cap on the number of audio soft tokens
  77. # any single audio segment can produce. With dynamic audio tokens, the actual
  78. # number of placeholders inserted per audio is computed from the audio duration.
  79. self.audio_seq_length = audio_seq_length
  80. # Milliseconds of audio per output soft token. The default of 40 comes from the
  81. # SSCP convolution's 4× time reduction applied to 10ms mel spectrogram frames.
  82. self.audio_ms_per_token = audio_ms_per_token
  83. self.audio_token_id = getattr(tokenizer, "audio_token_id", None)
  84. self.audio_token = getattr(tokenizer, "audio_token", None)
  85. self.boa_token = getattr(tokenizer, "boa_token", None)
  86. self.eoa_token = getattr(tokenizer, "eoa_token", None)
  87. super().__init__(
  88. feature_extractor=feature_extractor,
  89. image_processor=image_processor,
  90. tokenizer=tokenizer,
  91. video_processor=video_processor,
  92. chat_template=chat_template,
  93. **kwargs,
  94. )
  95. @auto_docstring
  96. def __call__(
  97. self,
  98. images: ImageInput | None = None,
  99. text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
  100. audio: AudioInput | None = None,
  101. videos: VideoInput | None = None,
  102. **kwargs: Unpack[Gemma4ProcessorKwargs],
  103. ) -> BatchFeature:
  104. if text is None and images is None and audio is None and videos is None:
  105. raise ValueError("Provide at least one of `text`, `images`, `audio`, or `videos`.")
  106. output_kwargs = self._merge_kwargs(
  107. Gemma4ProcessorKwargs,
  108. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  109. **kwargs,
  110. )
  111. if isinstance(text, str):
  112. text = [text]
  113. elif not isinstance(text, list) and not isinstance(text[0], str):
  114. raise TypeError("Invalid input text. Please provide a string, or a list of strings")
  115. image_inputs = {}
  116. if images is not None:
  117. images = self.image_processor.fetch_images(images)
  118. batched_images = make_nested_list_of_images(images)
  119. image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
  120. num_soft_tokens = image_inputs.pop("num_soft_tokens_per_image")
  121. # Create empty text to be replaced with placeholders
  122. if not text:
  123. text = [" ".join([self.image_token] * len(images)) for images in batched_images]
  124. if len(batched_images) != len(text):
  125. raise ValueError(
  126. f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
  127. )
  128. replacements = [f"{self.boi_token}{self.image_token * n}{self.eoi_token}" for n in num_soft_tokens]
  129. replacements_iter = iter(replacements)
  130. # Expand image_token placeholders to per-image soft token sequences.
  131. # re.sub never re-scans replaced text, so it is safe
  132. pattern = re.escape(self.image_token)
  133. text = [re.sub(pattern, lambda _: next(replacements_iter), prompt) for prompt in text]
  134. # Process video inputs in same way
  135. video_inputs = {}
  136. if videos is not None:
  137. video_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
  138. num_video_tokens = video_inputs.pop("num_soft_tokens_per_video")
  139. # If user has not requested video metadata, pop it so it isn't returned
  140. if not kwargs.get("return_metadata"):
  141. video_metadata = video_inputs.pop("video_metadata")
  142. else:
  143. video_metadata = video_inputs["video_metadata"]
  144. video_replacements = []
  145. for metadata, n_tokens in zip(video_metadata, num_video_tokens):
  146. if metadata.fps is None:
  147. logger.warning_once(
  148. "Gemma 4 requires frame timestamps to construct prompts, but the `fps` of the input video "
  149. "could not be inferred. Probably `video_metadata` was missing from inputs and you passed "
  150. "pre-sampled frames. Defaulting to `fps=24`. Please provide `video_metadata` for more "
  151. "accurate results."
  152. )
  153. metadata.fps = 24 if metadata.fps is None else metadata.fps
  154. # mm:ss format for timestamps
  155. timestamp_str = [
  156. f"{int(seconds // 60):02d}:{int(seconds % 60):02d}" for seconds in metadata.timestamps
  157. ]
  158. video_replacements.append(
  159. " ".join(
  160. [f"{t} {self.boi_token}{self.video_token * n_tokens}{self.eoi_token}" for t in timestamp_str]
  161. )
  162. )
  163. video_replacements = iter(video_replacements)
  164. pattern = re.escape(self.video_token)
  165. text = [re.sub(pattern, lambda _: next(video_replacements), prompt) for prompt in text]
  166. # Process audio inputs
  167. audio_inputs = {}
  168. if audio is not None:
  169. if self.audio_token is None or self.boa_token is None or self.eoa_token is None:
  170. raise ValueError(
  171. "Audio inputs were provided, but the tokenizer does not have an `audio_token` defined."
  172. )
  173. # Normalize audio input to list of waveforms
  174. if isinstance(audio, np.ndarray) and audio.ndim == 1:
  175. audio = [audio]
  176. # TODO: Add tests for audio-only processor inputs.
  177. if not text:
  178. text = [self.audio_token] * len(audio)
  179. # Dynamic audio token expansion wihtout padding:
  180. # * Extract audio features with feature extractor;
  181. # * Compute precise per-audio token counts from the waveform duration;
  182. # * Generate full audio token sequence for each computed audio length;
  183. # * Expand text prompts with full audio token sequences.
  184. audio_kwargs = output_kwargs.get("audio_kwargs", {})
  185. audio_inputs = self.feature_extractor(audio, **audio_kwargs)
  186. sampling_rate = self.feature_extractor.sampling_rate
  187. num_audio_tokens = [self._compute_audio_num_tokens(a, sampling_rate) for a in audio]
  188. replacements = [f"{self.boa_token}{self.audio_token * n}{self.eoa_token}" for n in num_audio_tokens]
  189. replacements_iter = iter(replacements)
  190. audio_pattern = re.escape(self.audio_token)
  191. text = [re.sub(audio_pattern, lambda _: next(replacements_iter), prompt) for prompt in text]
  192. return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
  193. return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
  194. text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
  195. # Check special tokens for all active modalities
  196. active_modalities = []
  197. if images is not None:
  198. active_modalities.append("image")
  199. if videos is not None:
  200. active_modalities.append("video")
  201. if audio is not None:
  202. active_modalities.append("audio")
  203. if active_modalities:
  204. self._check_special_mm_tokens(text, text_inputs, modalities=active_modalities)
  205. if return_mm_token_type_ids:
  206. text_inputs["mm_token_type_ids"] = self.create_mm_token_type_ids(text_inputs["input_ids"])
  207. return BatchFeature(
  208. data={**text_inputs, **image_inputs, **audio_inputs, **video_inputs},
  209. tensor_type=return_tensors,
  210. )
  211. def _compute_audio_num_tokens(self, audio_waveform, sampling_rate: int) -> int:
  212. """Compute the number of audio soft tokens for a single waveform.
  213. Replicates the exact sequence-length arithmetic of the audio encoder
  214. so that the processor inserts the correct number of placeholder tokens.
  215. The computation mirrors:
  216. 1. Mel framing via ``_unfold`` in ``Gemma4AudioFeatureExtractor``
  217. 2. Two ``Conv2d`` subsampling layers in ``Gemma4AudioSubSampleConvProjection``
  218. (each: kernel=3, stride=2, semicausal padding top=1, bottom=1)
  219. The result is capped at ``self.audio_seq_length`` (the configured maximum).
  220. Args:
  221. audio_waveform: A 1-D numpy array or list containing the raw audio samples.
  222. sampling_rate: The sampling rate of the audio waveform in Hz.
  223. Returns:
  224. The number of audio soft tokens to insert as placeholders.
  225. """
  226. num_samples = len(audio_waveform)
  227. # Step 1: Mel frames (matches feature_extraction_gemma4.py _unfold)
  228. frame_length = int(round(sampling_rate * 20.0 / 1000.0)) # 320 @ 16kHz
  229. hop_length = int(round(sampling_rate * 10.0 / 1000.0)) # 160 @ 16kHz
  230. frame_size_for_unfold = frame_length + 1 # 321
  231. # The feature extractor prepends (frame_length // 2) zero samples as
  232. # semicausal time-padding before the unfold. We must include this to
  233. # match the actual number of mel frames it produces.
  234. pad_left = frame_length // 2 # 160 @ 16kHz
  235. padded_samples = num_samples + pad_left
  236. num_mel_frames = (padded_samples - frame_size_for_unfold) // hop_length + 1
  237. if num_mel_frames <= 0:
  238. return 0
  239. # Step 2: Two SSCP conv layers (kernel=3, stride=2, semicausal pad top=1, bottom=1)
  240. # Each layer: T_out = (T_in + pad_top + pad_bottom - kernel) // stride + 1
  241. t = num_mel_frames
  242. for _ in range(2):
  243. t_padded = t + 2 # pad_top=1, pad_bottom=1
  244. t = (t_padded - 3) // 2 + 1
  245. # Cap at the configured maximum
  246. return min(t, self.audio_seq_length)
  247. def _get_num_multimodal_tokens(self, image_sizes=None, audio_lengths=None, **kwargs):
  248. """
  249. Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
  250. Args:
  251. image_sizes (`list[list[int]]`, *optional*):
  252. The input sizes formatted as (height, width) per each image.
  253. audio_lengths (`list[int]`, *optional*):
  254. The lengths of audio inputs in number of samples. Used to dynamically
  255. compute per-audio token counts.
  256. Returns:
  257. `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
  258. input modalities, along with other useful data.
  259. """
  260. images_kwargs = Gemma4ProcessorKwargs._defaults.get("images_kwargs", {})
  261. images_kwargs.update(kwargs)
  262. patch_size = images_kwargs.get("patch_size", None) or self.image_processor.patch_size
  263. pooling_kernel_size = (
  264. images_kwargs.get("pooling_kernel_size", None) or self.image_processor.pooling_kernel_size
  265. )
  266. max_soft_tokens = images_kwargs.get("max_soft_tokens", None) or self.image_processor.max_soft_tokens
  267. max_patches = max_soft_tokens * pooling_kernel_size**2
  268. vision_data = {}
  269. if image_sizes is not None:
  270. num_image_tokens = []
  271. for image_size in image_sizes:
  272. target_h, target_w = get_aspect_ratio_preserving_size(
  273. height=image_size[0],
  274. width=image_size[1],
  275. patch_size=patch_size,
  276. max_patches=max_patches,
  277. pooling_kernel_size=pooling_kernel_size,
  278. )
  279. patch_height = target_h // patch_size
  280. patch_width = target_w // patch_size
  281. num_image_tokens.append(patch_height * patch_width // pooling_kernel_size**2)
  282. num_image_patches = [1] * len(image_sizes)
  283. vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
  284. if audio_lengths is not None:
  285. # Dynamically compute per-audio token counts from sample lengths.
  286. # audio_lengths are in number of samples; assume default sampling rate.
  287. sampling_rate = getattr(self.feature_extractor, "sampling_rate", 16_000)
  288. num_audio_tokens = [
  289. self._compute_audio_num_tokens(np.zeros(length), sampling_rate) for length in audio_lengths
  290. ]
  291. vision_data.update({"num_audio_tokens": num_audio_tokens})
  292. return MultiModalData(**vision_data)
  293. @property
  294. def model_input_names(self):
  295. model_input_names = super().model_input_names
  296. model_input_names = [
  297. name
  298. for name in model_input_names
  299. if name not in ["num_soft_tokens_per_image", "num_soft_tokens_per_video"]
  300. ]
  301. # Include audio feature extractor input names if available
  302. if self.feature_extractor is not None:
  303. feature_extractor_input_names = self.feature_extractor.model_input_names
  304. model_input_names.extend([name for name in feature_extractor_input_names if name not in model_input_names])
  305. return model_input_names + ["mm_token_type_ids"]
  306. __all__ = ["Gemma4Processor"]