processing_glm4v.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/glm4v/modular_glm4v.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_glm4v.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import numpy as np
  21. from ...feature_extraction_utils import BatchFeature
  22. from ...image_utils import ImageInput
  23. from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
  24. from ...tokenization_utils_base import PreTokenizedInput, TextInput
  25. from ...utils import auto_docstring, logging
  26. from ...video_utils import VideoInput
  27. logger = logging.get_logger(__name__)
  28. class Glm4vProcessorKwargs(ProcessingKwargs, total=False):
  29. _defaults = {
  30. "text_kwargs": {
  31. "padding": False,
  32. "return_token_type_ids": False,
  33. "return_mm_token_type_ids": True,
  34. },
  35. "videos_kwargs": {"return_metadata": True},
  36. }
  37. @auto_docstring
  38. class Glm4vProcessor(ProcessorMixin):
  39. def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
  40. self.image_token = "<|image|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
  41. self.video_token = "<|video|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
  42. self.image_token_id = (
  43. tokenizer.image_token_id
  44. if getattr(tokenizer, "image_token_id", None)
  45. else tokenizer.convert_tokens_to_ids(self.image_token)
  46. )
  47. self.video_token_id = (
  48. tokenizer.video_token_id
  49. if getattr(tokenizer, "video_token_id", None)
  50. else tokenizer.convert_tokens_to_ids(self.video_token)
  51. )
  52. super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
  53. self.video_start_id = tokenizer.convert_tokens_to_ids("<|begin_of_video|>")
  54. self.video_end_id = tokenizer.convert_tokens_to_ids("<|end_of_video|>")
  55. @auto_docstring
  56. def __call__(
  57. self,
  58. images: ImageInput | None = None,
  59. text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
  60. videos: VideoInput | None = None,
  61. **kwargs: Unpack[Glm4vProcessorKwargs],
  62. ) -> BatchFeature:
  63. r"""
  64. Returns:
  65. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  66. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  67. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  68. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  69. `None`).
  70. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  71. - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
  72. - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
  73. - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
  74. """
  75. output_kwargs = self._merge_kwargs(
  76. Glm4vProcessorKwargs,
  77. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  78. **kwargs,
  79. )
  80. if images is not None:
  81. image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
  82. image_grid_thw = image_inputs["image_grid_thw"]
  83. else:
  84. image_inputs = {}
  85. image_grid_thw = None
  86. if videos is not None:
  87. videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
  88. # If user has not requested video metadata, pop it
  89. if not kwargs.get("return_metadata"):
  90. video_metadata = videos_inputs.pop("video_metadata")
  91. else:
  92. video_metadata = videos_inputs["video_metadata"]
  93. video_grid_thw = videos_inputs["video_grid_thw"]
  94. else:
  95. videos_inputs = {}
  96. video_grid_thw = None
  97. if not isinstance(text, list):
  98. text = [text]
  99. text = text.copy() # below lines change text in-place
  100. if image_grid_thw is not None:
  101. merge_length = self.image_processor.merge_size**2
  102. index = 0
  103. for i in range(len(text)):
  104. while self.image_token in text[i]:
  105. num_image_tokens = image_grid_thw[index].prod() // merge_length
  106. text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
  107. index += 1
  108. text[i] = text[i].replace("<|placeholder|>", self.image_token)
  109. if video_grid_thw is not None:
  110. merge_length = self.video_processor.merge_size**2
  111. video_index = 0
  112. for i in range(len(text)):
  113. while self.video_token in text[i]:
  114. num_frames = video_grid_thw[video_index][0]
  115. video_structure = ""
  116. metadata = video_metadata[video_index]
  117. if metadata.fps is None:
  118. logger.warning_once(
  119. "SmolVLM requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. "
  120. "Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. "
  121. "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
  122. )
  123. metadata.fps = 24 if metadata.fps is None else metadata.fps
  124. timestamps = metadata.timestamps[::2] # mrope
  125. unique_timestamps = []
  126. for idx in range(0, len(timestamps)):
  127. unique_timestamps.append(timestamps[idx])
  128. selected_timestamps = unique_timestamps[:num_frames]
  129. while len(selected_timestamps) < num_frames:
  130. selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
  131. for frame_idx in range(num_frames):
  132. timestamp_sec = selected_timestamps[frame_idx]
  133. frame_structure = self.replace_frame_token_id(timestamp_sec)
  134. video_structure += frame_structure
  135. text[i] = text[i].replace(self.video_token, video_structure, 1)
  136. num_image_tokens = (
  137. video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0]
  138. )
  139. for frame_idx in range(num_frames):
  140. if self.image_token in text[i]:
  141. text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
  142. video_index += 1
  143. text[i] = text[i].replace("<|placeholder|>", self.image_token)
  144. return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
  145. return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
  146. text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
  147. self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
  148. if return_mm_token_type_ids:
  149. text_inputs["mm_token_type_ids"] = self.create_mm_token_type_ids(text_inputs["input_ids"])
  150. return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
  151. def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
  152. """
  153. Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
  154. Args:
  155. image_sizes (`list[list[int]]`, *optional*):
  156. The input sizes formatted as (height, width) per each image.
  157. video_sizes (`list[list[int]]`, *optional*):
  158. The input sizes formatted as (num_frames, height, width) per each video.
  159. Returns:
  160. `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
  161. input modalities, along with other useful data.
  162. """
  163. vision_data = {}
  164. if image_sizes is not None:
  165. images_kwargs = Glm4vProcessorKwargs._defaults.get("images_kwargs", {})
  166. images_kwargs.update(kwargs)
  167. merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
  168. num_image_patches = [
  169. self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
  170. for image_size in image_sizes
  171. ]
  172. num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
  173. vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
  174. if video_sizes is not None:
  175. videos_kwargs = Glm4vProcessorKwargs._defaults.get("videos_kwargs", {})
  176. videos_kwargs.update(kwargs)
  177. num_video_patches = [
  178. self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
  179. for video_size in video_sizes
  180. ]
  181. num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
  182. vision_data["num_video_tokens"] = num_video_tokens
  183. return MultiModalData(**vision_data)
  184. def post_process_image_text_to_text(
  185. self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
  186. ):
  187. """
  188. Post-process the output of the model to decode the text.
  189. Args:
  190. generated_outputs (`torch.Tensor` or `np.ndarray`):
  191. The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
  192. or `(sequence_length,)`.
  193. skip_special_tokens (`bool`, *optional*, defaults to `True`):
  194. Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
  195. clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
  196. Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
  197. **kwargs:
  198. Additional arguments to be passed to the tokenizer's `batch_decode method`.
  199. Returns:
  200. `list[str]`: The decoded text.
  201. """
  202. return self.tokenizer.batch_decode(
  203. generated_outputs,
  204. skip_special_tokens=skip_special_tokens,
  205. clean_up_tokenization_spaces=clean_up_tokenization_spaces,
  206. **kwargs,
  207. )
  208. @property
  209. def model_input_names(self):
  210. model_input_names = super().model_input_names
  211. model_input_names.append("mm_token_type_ids")
  212. return model_input_names
  213. def create_mm_token_type_ids(self, input_ids: list) -> list[list[int]]:
  214. # We have to iterate for each list separately because inputs
  215. # might be non-padded lists and we can't cast numpy on that!
  216. # Then cast numpy as each input for faster indexing
  217. mm_token_type_ids = []
  218. for input in input_ids:
  219. array_ids = np.array(input)
  220. mm_token_types = np.zeros_like(input)
  221. # Replace 0 -> 2 only inside video segments because GLM4v
  222. # uses the same special token to denote images and video
  223. # Otherwise replace 0 -> 1 for image modality
  224. starts = np.cumsum(array_ids == self.video_start_id, axis=0)
  225. ends = np.cumsum(array_ids == self.video_end_id, axis=0)
  226. is_video_modality = starts > ends
  227. mm_token_types[(array_ids == self.image_token_id) & is_video_modality] = 2
  228. mm_token_types[(array_ids == self.image_token_id) & (~is_video_modality)] = 1
  229. mm_token_type_ids.append(mm_token_types.tolist())
  230. return mm_token_type_ids
  231. def replace_frame_token_id(self, timestamp_sec):
  232. return f"<|begin_of_image|>{self.image_token}<|end_of_image|>{int(timestamp_sec)}"
# Public API of this module: only the processor class is exported by `from ... import *`.
__all__ = ["Glm4vProcessor"]