| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/glm46v/modular_glm46v.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_glm46v.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2025 the HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import numpy as np
- from ...image_processing_utils import BatchFeature
- from ...image_utils import ImageInput
- from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
- from ...tokenization_utils_base import PreTokenizedInput, TextInput
- from ...utils import auto_docstring, logging
- from ...video_utils import VideoInput
# Module-level logger, keyed by this module's name; used below for the one-time fps warning.
- logger = logging.get_logger(__name__)
class Glm46VProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword arguments accepted by `Glm46VProcessor`, together with its per-modality defaults."""

    # Defaults that apply unless the caller overrides them.
    _defaults = {
        "text_kwargs": {
            "padding": False,
            "return_token_type_ids": False,
            # `Glm46VProcessor.__call__` pops this flag and, when it is true, adds
            # `mm_token_type_ids` to the returned features.
            "return_mm_token_type_ids": True,
        },
        "videos_kwargs": {
            # Forwarded to the video processor; `__call__` reads the returned
            # `video_metadata` to build the per-frame timestamps.
            "return_metadata": True,
        },
    }
# Multimodal processor for Glm46V. It bundles an image processor, a video processor and a
# tokenizer, and expands the image/video placeholder tokens found in the prompt text so that
# their count matches the number of merged visual patches of each image or video frame.
@auto_docstring
class Glm46VProcessor(ProcessorMixin):
    def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
        # Use the tokenizer's own special tokens when it defines them, else the literal defaults.
        self.image_token = "<|image|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
        self.video_token = "<|video|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
        # Same for the ids: a missing (or falsy) id attribute is resolved through the vocabulary.
        self.image_token_id = (
            tokenizer.image_token_id
            if getattr(tokenizer, "image_token_id", None)
            else tokenizer.convert_tokens_to_ids(self.image_token)
        )
        self.video_token_id = (
            tokenizer.video_token_id
            if getattr(tokenizer, "video_token_id", None)
            else tokenizer.convert_tokens_to_ids(self.video_token)
        )
        super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
        # Ids of the markers delimiting a video segment; `create_mm_token_type_ids` uses them
        # to tell video frames apart from still images.
        self.video_start_id = tokenizer.convert_tokens_to_ids("<|begin_of_video|>")
        self.video_end_id = tokenizer.convert_tokens_to_ids("<|end_of_video|>")

    @auto_docstring
    def __call__(
        self,
        images: ImageInput | None = None,
        text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
        videos: VideoInput | None = None,
        **kwargs: Unpack[Glm46VProcessorKwargs],
    ) -> BatchFeature:
        r"""
        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
            - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
        """
        output_kwargs = self._merge_kwargs(
            Glm46VProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if images is not None:
            image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None

        video_metadata = None
        if videos is not None:
            videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
            # The metadata is always needed here to build timestamps; it only stays in the
            # returned features when the caller explicitly asked for it.
            if not kwargs.get("return_metadata"):
                video_metadata = videos_inputs.pop("video_metadata")
            else:
                video_metadata = videos_inputs["video_metadata"]
            video_grid_thw = videos_inputs["video_grid_thw"]
        else:
            videos_inputs = {}
            video_grid_thw = None

        if not isinstance(text, list):
            text = [text]
        text = text.copy()  # the expansions below rewrite entries, so never touch the caller's list

        if image_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
            image_index = 0
            for i, sample in enumerate(text):
                # Expand through a temporary placeholder so the `while` test only sees image
                # tokens that have not been expanded yet.
                while self.image_token in sample:
                    num_image_tokens = image_grid_thw[image_index].prod() // merge_length
                    sample = sample.replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
                    image_index += 1
                text[i] = sample.replace("<|placeholder|>", self.image_token)

        if video_grid_thw is not None:
            merge_length = self.video_processor.merge_size**2
            video_index = 0
            for i, sample in enumerate(text):
                while self.video_token in sample:
                    video_structure = self._build_video_structure(
                        video_metadata[video_index], video_grid_thw[video_index], merge_length
                    )
                    sample = sample.replace(self.video_token, video_structure, 1)
                    video_index += 1
                text[i] = sample.replace("<|placeholder|>", self.image_token)

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
        self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])

        if return_mm_token_type_ids:
            text_inputs["mm_token_type_ids"] = self.create_mm_token_type_ids(text_inputs["input_ids"])
        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)

    def _build_video_structure(self, metadata, grid_thw, merge_length):
        """
        Build the text that replaces a single video token.

        The result is one timestamped frame block (see `replace_frame_token_id`) per temporal grid
        step, with the image token of each block already expanded to `"<|placeholder|>"` repeated
        once per merged patch of that frame. Expanding here, instead of searching the whole sample
        for image tokens afterwards, keeps image tokens that precede the video untouched.

        Args:
            metadata: Metadata entry of the video; its `fps` and `timestamps` attributes are read,
                and `fps` is set to 24 when it is `None`.
            grid_thw: Grid entry of the video; its first element is used as the number of frames.
            merge_length: Number of patches merged into one token (`merge_size**2`).

        Returns:
            `str`: The expanded video structure.
        """
        num_frames = grid_thw[0]

        if metadata.fps is None:
            logger.warning_once(
                "Glm46V requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. "
                "Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. "
                "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
            )
            metadata.fps = 24

        # Keep every second timestamp (mrope), then truncate to the number of frames.
        selected_timestamps = list(metadata.timestamps[::2])[:num_frames]
        # Pad with the last known timestamp (or 0 when there is none) if too few are available.
        while len(selected_timestamps) < num_frames:
            selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)

        tokens_per_frame = grid_thw.prod() // merge_length // grid_thw[0]
        frame_placeholders = "<|placeholder|>" * tokens_per_frame

        frames = []
        for frame_idx in range(num_frames):
            frame_structure = self.replace_frame_token_id(selected_timestamps[frame_idx])
            frames.append(frame_structure.replace(self.image_token, frame_placeholders, 1))
        return "".join(frames)

    def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
        """
        Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.

        Args:
            image_sizes (`list[list[int]]`, *optional*):
                The input sizes formatted as (height, width) per each image.
            video_sizes (`list[list[int]]`, *optional*):
                The input sizes formatted as (num_frames, height, width) per each video.

        Returns:
            `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
            input modalities, along with other useful data.
        """
        vision_data = {}
        if image_sizes is not None:
            # Merge into a fresh dict so the class-level defaults are never mutated.
            images_kwargs = {**Glm46VProcessorKwargs._defaults.get("images_kwargs", {}), **kwargs}
            merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size

            num_image_patches = [
                self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
                for image_size in image_sizes
            ]
            num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
            vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})

        if video_sizes is not None:
            videos_kwargs = {**Glm46VProcessorKwargs._defaults.get("videos_kwargs", {}), **kwargs}
            # Resolved independently of the image branch so that a video-only call works;
            # `__call__` likewise uses the video processor's merge size for videos.
            merge_size = videos_kwargs.get("merge_size", None) or self.video_processor.merge_size

            num_video_patches = [
                self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
                for video_size in video_sizes
            ]
            num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
            vision_data["num_video_tokens"] = num_video_tokens

        return MultiModalData(**vision_data)

    def post_process_image_text_to_text(
        self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
    ):
        """
        Post-process the output of the model to decode the text.

        Args:
            generated_outputs (`torch.Tensor` or `np.ndarray`):
                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
                or `(sequence_length,)`.
            skip_special_tokens (`bool`, *optional*, defaults to `True`):
                Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
                Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
            **kwargs:
                Additional arguments to be passed to the tokenizer's `batch_decode method`.

        Returns:
            `list[str]`: The decoded text.
        """
        return self.tokenizer.batch_decode(
            generated_outputs,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    @property
    def model_input_names(self):
        # Build a new list instead of appending to whatever the parent returned, so repeated
        # access can never grow a list owned by the parent.
        return [*super().model_input_names, "mm_token_type_ids"]

    def create_mm_token_type_ids(self, input_ids: list) -> list[list[int]]:
        """
        Label every token of every sequence with its modality.

        Args:
            input_ids (`list`):
                One sequence of token ids per sample; sequences may have different lengths.

        Returns:
            `list[list[int]]`: For each sequence, a list of the same length holding `0` for tokens
            that are not the image token, `1` for image tokens outside a video segment and `2` for
            image tokens inside a video segment.
        """
        # Sequences are handled one by one because they may be non-padded (ragged) lists that
        # cannot be stacked into a single array; each one is converted for fast masking.
        mm_token_type_ids = []
        for sequence in input_ids:
            array_ids = np.array(sequence)
            mm_token_types = np.zeros_like(array_ids)
            # Glm46V uses the same special token for images and video frames, so a position
            # counts as video when more video-start than video-end markers have been seen.
            starts = np.cumsum(array_ids == self.video_start_id, axis=0)
            ends = np.cumsum(array_ids == self.video_end_id, axis=0)
            is_video_modality = starts > ends
            is_image_token = array_ids == self.image_token_id

            mm_token_types[is_image_token & is_video_modality] = 2
            mm_token_types[is_image_token & (~is_video_modality)] = 1
            mm_token_type_ids.append(mm_token_types.tolist())
        return mm_token_type_ids

    def replace_frame_token_id(self, timestamp_sec):
        """Return the prompt text for one video frame: a wrapped image token followed by its timestamp."""
        return f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec:.1f} seconds"
# Public API of this module: only the processor class is listed for export.
- __all__ = ["Glm46VProcessor"]
|