# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/colmodernvbert/modular_colmodernvbert.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_colmodernvbert.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2026 Illuin Technology and contributors, and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from itertools import accumulate
from typing import TYPE_CHECKING, Optional, Union

import numpy as np
import torch

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import AddedToken, BatchEncoding, TextInput
from ...utils import auto_docstring
from ...utils.import_utils import requires


if TYPE_CHECKING:
    from ...tokenization_utils_base import PreTokenizedInput


# Default per-modality keyword arguments merged into every processor call by `_merge_kwargs`.
class ColModernVBertProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {
            "padding": "longest",
        },
        "images_kwargs": {
            "return_row_col_info": True,
            "data_format": "channels_first",
            "do_convert_rgb": True,
        },
        "common_kwargs": {"return_tensors": "pt"},
    }


def is_url(val) -> bool:
    """Return `True` when `val` is a string starting with `"http"`."""
    return isinstance(val, str) and val.startswith("http")


def is_image_or_image_url(elem):
    """Return `True` when `elem` is either an image URL or a valid image object."""
    return is_url(elem) or is_valid_image(elem)


def _prompt_split_image(image_seq_len, image_rows, image_cols, fake_token_around_image, image_token, global_img_token):
    """
    Prompt with expanded image tokens for when the image is split into patches.

    Each patch contributes `fake_token_around_image`, a `<row_R_col_C>` tag (1-based indices) and `image_seq_len`
    copies of `image_token`. Every row of patches is terminated by a newline, and the global image block
    (`global_img_token` wrapped in fake tokens) is appended at the end.
    """
    text_split_images = ""
    for n_h in range(image_rows):
        for n_w in range(image_cols):
            text_split_images += (
                f"{fake_token_around_image}" + f"<row_{n_h + 1}_col_{n_w + 1}>" + f"{image_token}" * image_seq_len
            )
        text_split_images += "\n"

    text_split_images += (
        f"\n{fake_token_around_image}"
        + f"{global_img_token}"
        + f"{image_token}" * image_seq_len
        + f"{fake_token_around_image}"
    )
    return text_split_images


def _prompt_single_image(image_seq_len, fake_token_around_image, image_token, global_img_token):
    """Prompt with expanded image tokens for a single image."""
    return (
        f"{fake_token_around_image}"
        + f"{global_img_token}"
        + f"{image_token}" * image_seq_len
        + f"{fake_token_around_image}"
    )


def get_image_prompt_string(
    image_rows, image_cols, image_seq_len, fake_token_around_image, image_token, global_img_token
):
    """
    Build the expanded placeholder string for one image.

    A `(0, 0)` row/column grid selects the single-image layout; any other grid selects the layout with one
    block per patch followed by the global image block.
    """
    if image_rows == 0 and image_cols == 0:
        return _prompt_single_image(
            image_seq_len,
            fake_token_around_image=fake_token_around_image,
            image_token=image_token,
            global_img_token=global_img_token,
        )
    return _prompt_split_image(
        image_seq_len, image_rows, image_cols, fake_token_around_image, image_token, global_img_token
    )


@requires(backends=("torch",))
@auto_docstring
class ColModernVBertProcessor(ProcessorMixin):
    r"""
    Constructs a ColModernVBert processor which wraps a ModernVBertProcessor and special methods to process images and queries, as
    well as to compute the late-interaction retrieval score.

    [`ColModernVBertProcessor`] offers all the functionalities of [`ModernVBertProcessor`]. See the [`~ModernVBertProcessor.__call__`]
    for more information.

    Args:
        image_processor ([`Idefics3ImageProcessor`]):
            An instance of [`Idefics3ImageProcessor`]. The image processor is a required input.
        tokenizer (`PreTrainedTokenizerFast`, *optional*):
            An instance of [`PreTrainedTokenizerFast`]. This should correspond with the model's text model. The tokenizer is a required input.
        image_seq_len (`int`, *optional*, defaults to 64):
            The length of the image sequence i.e. the number of `<image>` tokens per image in the input.
        visual_prompt_prefix (`str`, *optional*):
            A prefix to be prepended to visual prompts.
        query_prefix (`str`, *optional*):
            A prefix to be prepended to query prompts.
    """

    def __init__(
        self,
        image_processor,
        tokenizer=None,
        chat_template=None,
        image_seq_len: int = 64,
        visual_prompt_prefix: str | None = None,
        query_prefix: str | None = None,
        **kwargs,
    ):
        r"""
        image_seq_len (`int`, *optional*, defaults to 64):
            The length of the image sequence i.e. the number of `<image>` tokens per image in the input.
        visual_prompt_prefix (`str`, *optional*):
            A string that gets tokenized and prepended to the image tokens.
        query_prefix (`str`, *optional*):
            A prefix to be used for the query.
        """
        chat_template = None  # ColModernVBert does not use chat templates

        self.fake_image_token = AddedToken("<fake_token_around_image>", normalized=False, special=True).content
        self.image_token = AddedToken("<image>", normalized=False, special=True).content
        self.end_of_utterance_token = AddedToken("<end_of_utterance>", normalized=False, special=True).content
        self.global_image_tag = "<global-img>"  # https://github.com/huggingface/transformers/pull/32473/files/8063e5e17362571b693f1db95167f5443a3be1b2#r1734825341
        self.image_seq_len = image_seq_len

        # This regex matches one or more occurrences of <global-img> tags (optionally surrounded by newline characters)
        # or <row_x_col_y> tags (where x and y are digits, also optionally surrounded by newline characters).
        self._regex_to_remove_extra_special_tokens = re.compile(r"(\n?<global-img>\n?|<row_\d+_col_\d+>\n?)+")

        tokens_to_add = {
            "additional_special_tokens": [
                self.fake_image_token,
                self.image_token,
                self.end_of_utterance_token,
            ]
        }
        tokenizer.add_special_tokens(tokens_to_add)

        # Resolve ids only after the special tokens are registered, so that a token which was missing from the
        # vocabulary does not end up cached with a stale id.
        self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
        self.fake_image_token_id = tokenizer.convert_tokens_to_ids(self.fake_image_token)
        self.global_image_token_id = tokenizer.convert_tokens_to_ids(self.global_image_tag)
        self.row_col_ids = [
            tokenizer.convert_tokens_to_ids(f"<row_{i + 1}_col_{j + 1}>") for i in range(6) for j in range(6)
        ]

        super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)

        # NOTE(review): a turn-terminator tag may be expected between "Describe the image." and "\n" - confirm
        # against the modular source before changing this default.
        self.visual_prompt_prefix = visual_prompt_prefix or (
            f"<|begin_of_text|>User:{self.image_token}Describe the image.\nAssistant:"
        )
        self.query_prefix = query_prefix or ""
        self.query_augmentation_token = self.end_of_utterance_token

    def _extract_images_from_prompts(self, prompts):
        """Collect, per prompt, the images it contains; URL elements are loaded with `load_image`."""
        prompt_images = []
        for prompt in prompts:
            images = []
            for elem in prompt:
                if is_valid_image(elem):
                    images.append(elem)
                elif is_url(elem):
                    images.append(load_image(elem))
            prompt_images.append(images)
        return prompt_images

    @auto_docstring
    def __call__(
        self,
        images: ImageInput | list[ImageInput] | list[list[ImageInput]] = None,
        text: Union[TextInput, "PreTokenizedInput", list[TextInput], list["PreTokenizedInput"]] = None,
        image_seq_len: int | None = None,
        **kwargs: Unpack[ColModernVBertProcessorKwargs],
    ) -> BatchEncoding:
        r"""
        image_seq_len (`int`, *optional*):
            The length of the image sequence. If not provided, the default value of self.image_seq_len is used.
            image_seq_len should be equal to int(((image_size // patch_size) ** 2) / (scale_factor**2))
        """
        if text is None and images is None:
            raise ValueError("You must provide either `text` or `images`.")

        output_kwargs = self._merge_kwargs(
            ColModernVBertProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        image_seq_len = image_seq_len if image_seq_len is not None else self.image_seq_len
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
        # Tensor conversion is deferred to the final `BatchFeature` so token-level post-processing works on lists.
        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)

        n_images_in_text = []
        n_images_in_images = []
        # Always bound so that `return_mm_token_type_ids` cannot hit an undefined name on the text-only path.
        batch_image_seq_lengths = []
        inputs = {}

        if text is not None:
            if isinstance(text, str):
                text = [text]
            # NOTE(review): this only rejects non-list inputs whose first element is not a string; lists are
            # accepted regardless of their content. Kept as-is to avoid rejecting inputs that pass today.
            elif not isinstance(text, list) and not isinstance(text[0], str):
                raise ValueError("Invalid input text. Please provide a string, or a list of strings")
            n_images_in_text = [sample.count(self.image_token) for sample in text]

        if images is not None:
            if is_image_or_image_url(images):
                images = [[images]]
            elif isinstance(images, (list, tuple)) and is_image_or_image_url(images[0]):
                if text is not None:
                    if sum(n_images_in_text) != len(images):
                        raise ValueError(
                            f"The total number of {self.image_token} tokens in the prompts should be the same as the number of images passed."
                            f" Found {sum(n_images_in_text)} {self.image_token} tokens and {len(images)} images."
                        )
                    # Reorganize the images to match the prompts
                    cumsum_images_in_text = [0] + list(accumulate(n_images_in_text))
                    images = [
                        images[cumsum_images_in_text[i] : cumsum_images_in_text[i + 1]]
                        for i in range(len(n_images_in_text))
                    ]
                else:
                    images = [images]
            # NOTE(review): same caveat as the text check above - the three conditions are and-ed, so a nested
            # list is never inspected here.
            elif (
                not isinstance(images, (list, tuple))
                and not isinstance(images[0], (list, tuple))
                and not is_image_or_image_url(images[0][0])
            ):
                raise ValueError(
                    "Invalid input images. Please provide a single image or a list of images or a list of list of images."
                )
            n_images_in_images = [len(sample) for sample in images]

            # Load images if they are URLs
            images = [[load_image(im) if is_url(im) else im for im in sample] for sample in images]

            image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
            inputs.update(image_inputs)

            if text is not None:
                if n_images_in_images != n_images_in_text:
                    raise ValueError(
                        f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
                    )

                # Without row/col info every image is treated as a single (non-split) image.
                image_rows = inputs.pop("rows", [[0] * n_images for n_images in n_images_in_text])
                image_cols = inputs.pop("cols", [[0] * n_images for n_images in n_images_in_text])

                fake_image_token = self.fake_image_token
                image_token = self.image_token
                global_img_token = self.global_image_tag

                prompt_strings = []
                for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
                    # Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
                    image_prompt_strings = []
                    image_seq_lengths = []
                    for n_rows, n_cols in zip(sample_rows, sample_cols):
                        image_prompt_string = get_image_prompt_string(
                            n_rows,
                            n_cols,
                            image_seq_len,
                            image_token=image_token,
                            fake_token_around_image=fake_image_token,
                            global_img_token=global_img_token,
                        )
                        # Add +2 and +3 for special BOI/EOI/fake_image_wrapper tokens.
                        # Uses the effective `image_seq_len` (per-call override included) so the spans match the
                        # prompt that was actually built above.
                        row_length = (image_seq_len + 2) * n_cols + 1
                        image_seq_lengths.append((image_seq_len + 3) + row_length * n_rows)
                        image_prompt_strings.append(image_prompt_string)

                    batch_image_seq_lengths.append(image_seq_lengths)
                    split_sample = sample.split(image_token)
                    if len(split_sample) == 0:
                        raise ValueError("The image token should be present in the text.")

                    # Place in the image prompt strings where the image tokens are
                    sample = split_sample[0]
                    for i, image_prompt_string in enumerate(image_prompt_strings):
                        sample += image_prompt_string + split_sample[i + 1]
                    prompt_strings.append(sample)

                text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
                self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
                inputs.update(text_inputs)

        elif text is not None:
            if any(n_images_in_text):
                raise ValueError(
                    f"Found {sum(n_images_in_text)} {self.image_token} tokens in the text but no images were passed."
                )
            text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
            inputs.update(text_inputs)
            # No image spans in text-only samples: one empty span list per sample yields all-zero type ids.
            batch_image_seq_lengths = [[] for _ in text]

        # Token type ids can only be derived when text was tokenized.
        if return_mm_token_type_ids and "input_ids" in inputs:
            inputs["mm_token_type_ids"] = self.create_mm_token_type_ids(inputs["input_ids"], batch_image_seq_lengths)
        return BatchFeature(data=inputs, tensor_type=return_tensors)

    def create_mm_token_type_ids(self, input_ids: list, batch_image_seq_lengths: list[list[int]]) -> list[list[int]]:
        """
        Mark the positions covered by image placeholder tokens.

        Args:
            input_ids (`list`):
                Token ids, one sequence per sample.
            batch_image_seq_lengths (`list[list[int]]`):
                For each sample, the token length of every expanded image block, in order of appearance.

        Returns:
            `list[list[int]]`: For each sample, a list with `1` at image block positions and `0` elsewhere.
        """
        # We have to iterate for each list separately because inputs
        # might be non-padded lists and we can't cast numpy on that!
        # Then cast numpy as each input for faster indexing
        mm_token_type_ids = []
        for i, seq_lengths in enumerate(batch_image_seq_lengths):
            array_ids = np.array(input_ids[i])
            mm_token_types = np.zeros_like(array_ids)
            image_start_positions = np.where(array_ids == self.fake_image_token_id)[0]
            j = 0
            for seq_len in seq_lengths:
                if j >= len(image_start_positions):
                    break
                start = image_start_positions[j]
                end = start + seq_len
                mm_token_types[start:end] = 1
                # Jump to the first fake-image token located at or after the end of the current block.
                j = np.searchsorted(image_start_positions, end)
            mm_token_type_ids.append(mm_token_types.tolist())

        return mm_token_type_ids

    def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
        """
        Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.

        Args:
            image_sizes (`list[list[int]]`, *optional*):
                The input sizes formatted as (height, width) per each image.

        Returns:
            `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
            input modalities, along with other useful data.
        """
        vision_data = {}
        if image_sizes is not None:
            # Copy the defaults: updating the class-level dict in place would leak per-call kwargs into every
            # later call and every other processor instance.
            images_kwargs = dict(ColModernVBertProcessorKwargs._defaults.get("images_kwargs", {}))
            images_kwargs.update(kwargs)

            num_image_row_cols = [
                self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
                for image_size in image_sizes
            ]

            base_image_length = self.image_seq_len + 3
            col_length = self.image_seq_len + 2
            num_image_tokens = []
            num_image_patches = []

            for num_patches, num_rows, num_cols in num_image_row_cols:
                row_length = col_length * num_cols + 1
                num_image_tokens.append(base_image_length + (row_length * num_rows))
                num_image_patches.append(num_patches)

            vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})

        return MultiModalData(**vision_data)

    def process_images(
        self,
        images: ImageInput | None = None,
        **kwargs: Unpack[ColModernVBertProcessorKwargs],
    ) -> BatchFeature:
        """
        Prepare for the model one or several image(s). Handles input validation, RGB conversion,
        and prepends the `visual_prompt_prefix` to each image. Optionally computes labels from
        `token_type_ids` when a `suffix` is provided in `text_kwargs`.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                number of channels, H and W are image height and width. A list of lists of images is flattened.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.

        Raises:
            ValueError: If `images` is not an image, a non-empty list of images or a list of lists of images.
        """
        output_kwargs = self._merge_kwargs(
            ColModernVBertProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        suffix = output_kwargs["text_kwargs"].pop("suffix", None)

        return_token_type_ids = suffix is not None

        # Normalize input to a flat list of images
        if is_valid_image(images):
            images = [images]
        elif isinstance(images, list) and images and is_valid_image(images[0]):
            pass
        elif (
            isinstance(images, list)
            and images
            and isinstance(images[0], list)
            and images[0]
            and is_valid_image(images[0][0])
        ):
            # One prompt is built per image below, so nested batches are flattened.
            images = [image for sample in images for image in sample]
        else:
            raise ValueError("images must be an image, list of images or list of list of images")

        # Ensure all images are in RGB format. Only objects exposing `convert` (PIL images) are converted here;
        # arrays and tensors are passed through unchanged to the image processor.
        images = [image.convert("RGB") if hasattr(image, "convert") else image for image in images]

        # Pair each image with the visual prompt prefix for the VLM backbone
        batch_doc = self.__call__(
            text=[self.visual_prompt_prefix] * len(images),
            images=images,
            images_kwargs=output_kwargs["images_kwargs"],
            text_kwargs=output_kwargs["text_kwargs"],
        )

        # When suffix is provided, generate labels by masking non-suffix tokens
        if return_token_type_ids:
            labels = batch_doc["input_ids"].masked_fill(batch_doc["token_type_ids"] == 0, -100)
            batch_doc.update({"labels": labels})

        return batch_doc

    def process_queries(
        self,
        text: TextInput | list[TextInput],
        **kwargs: Unpack[ColModernVBertProcessorKwargs],
    ) -> BatchFeature:
        """
        Prepare for the model one or several text queries. Handles input validation, prepends the
        `query_prefix`, and appends query augmentation tokens (used to pad query embeddings for
        better late-interaction retrieval performance).

        Args:
            text (`str`, `list[str]`):
                The query or batch of queries to be encoded.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).

        Raises:
            ValueError: If `text` is neither a string nor a non-empty list of strings.
        """
        output_kwargs = self._merge_kwargs(
            ColModernVBertProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        suffix = output_kwargs["text_kwargs"].pop("suffix", None)

        if isinstance(text, str):
            text = [text]
        elif not (isinstance(text, list) and text and isinstance(text[0], str)):
            raise ValueError("Text must be a string or a list of strings")

        # Default suffix: repeat the augmentation token to pad query embeddings
        if suffix is None:
            suffix = self.query_augmentation_token * 10

        # Build final queries: prefix + original query + augmentation suffix
        texts_query: list[str] = [self.query_prefix + query + suffix for query in text]

        batch_query = self.__call__(
            text=texts_query,
            return_token_type_ids=False,
            text_kwargs=output_kwargs["text_kwargs"],
        )

        return batch_query

    def score_retrieval(
        self,
        query_embeddings: Union["torch.Tensor", list["torch.Tensor"]],
        passage_embeddings: Union["torch.Tensor", list["torch.Tensor"]],
        batch_size: int = 128,
        output_dtype: Optional["torch.dtype"] = None,
        output_device: Union["torch.device", str] = "cpu",
    ) -> "torch.Tensor":
        """
        Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
        query embeddings (`qs`) and passage embeddings (`ps`). For ColModernVBert, a passage is the
        image of a document page.

        Because the embedding tensors are multi-vector and can thus have different shapes, they
        should be fed as:
        (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
        (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
            obtained by padding the list of tensors.

        Args:
            query_embeddings (`Union[torch.Tensor, list[torch.Tensor]]`): Query embeddings.
            passage_embeddings (`Union[torch.Tensor, list[torch.Tensor]]`): Passage embeddings.
            batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
            output_dtype (`torch.dtype`, *optional*): The dtype of the output tensor.
                If `None`, the dtype of the input embeddings is used.
            output_device (`torch.device` or `str`, *optional*, defaults to "cpu"): The device of the output tensor.

        Returns:
            `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores, placed on
            `output_device`.

        Raises:
            ValueError: If there are no queries or no passages, if queries and passages differ in device or
                dtype, or if `batch_size` is not positive.
        """
        if len(query_embeddings) == 0:
            raise ValueError("No queries provided")
        if len(passage_embeddings) == 0:
            raise ValueError("No passages provided")
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")

        if query_embeddings[0].device != passage_embeddings[0].device:
            raise ValueError("Queries and passages must be on the same device")

        if query_embeddings[0].dtype != passage_embeddings[0].dtype:
            raise ValueError("Queries and passages must have the same dtype")

        if output_dtype is None:
            output_dtype = query_embeddings[0].dtype

        scores: list[torch.Tensor] = []

        for i in range(0, len(query_embeddings), batch_size):
            batch_scores: list[torch.Tensor] = []
            batch_queries = torch.nn.utils.rnn.pad_sequence(
                query_embeddings[i : i + batch_size], batch_first=True, padding_value=0
            )
            for j in range(0, len(passage_embeddings), batch_size):
                batch_passages = torch.nn.utils.rnn.pad_sequence(
                    passage_embeddings[j : j + batch_size], batch_first=True, padding_value=0
                )
                # MaxSim: for every query token take the best-matching passage token, then sum over query tokens.
                batch_scores.append(
                    torch.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2)
                )
            scores.append(torch.cat(batch_scores, dim=1).to(output_dtype).to(output_device))

        return torch.cat(scores, dim=0)


__all__ = ["ColModernVBertProcessor"]