# Copyright 2026 Illuin Technology and contributors, and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Union

import torch
from huggingface_hub.dataclasses import strict

from ...configuration_utils import PreTrainedConfig
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image
from ...processing_utils import Unpack
from ...tokenization_utils_base import TextInput
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
from ...utils.generic import can_return_tuple
from ...utils.import_utils import requires
from ..auto import CONFIG_MAPPING
from ..auto.modeling_auto import AutoModel
from ..colpali.modeling_colpali import ColPaliForRetrieval, ColPaliPreTrainedModel
from ..colqwen2.configuration_colqwen2 import ColQwen2Config
from ..idefics3.processing_idefics3 import Idefics3Processor, Idefics3ProcessorKwargs


logger = logging.get_logger(__name__)


@auto_docstring(checkpoint="ModernVBERT/colmodernvbert-merged")
@strict
class ColModernVBertConfig(ColQwen2Config):
    r"""
    Example:

    ```python
    from transformers import ColModernVBertConfig, ColModernVBertForRetrieval

    config = ColModernVBertConfig()
    model = ColModernVBertForRetrieval(config)
    ```
    """

    model_type = "colmodernvbert"
    sub_configs = {"vlm_config": PreTrainedConfig}

    vlm_config: dict | PreTrainedConfig | None = None
    embedding_dim: int = 128
    initializer_range: float = 0.02

    def __post_init__(self, **kwargs):
        # Resolve `vlm_config` into a config object: default ModernVBert config when missing,
        # or the config class registered under the dict's `model_type` when a dict is given.
        if self.vlm_config is None:
            self.vlm_config = CONFIG_MAPPING["modernvbert"]()
            logger.info(
                "`vlm_config` is `None`. Initializing `vlm_config` with the `ModernVBertConfig` with default values."
            )
        elif isinstance(self.vlm_config, dict):
            self.vlm_config = CONFIG_MAPPING[self.vlm_config["model_type"]](**self.vlm_config)

        # Mirror the text config's vocabulary size on the VLM config when it is not exposed directly.
        if not hasattr(self.vlm_config, "vocab_size"):
            self.vlm_config.vocab_size = self.vlm_config.get_text_config().vocab_size

        super().__post_init__(**kwargs)


class ColModernVBertProcessorKwargs(Idefics3ProcessorKwargs, total=False):
    _defaults = {
        "text_kwargs": {
            "padding": "longest",
        },
        "images_kwargs": {
            "return_row_col_info": True,
            "data_format": "channels_first",
            "do_convert_rgb": True,
        },
        "common_kwargs": {"return_tensors": "pt"},
    }


@requires(backends=("torch",))
@auto_docstring
class ColModernVBertProcessor(Idefics3Processor):
    r"""
    Constructs a ColModernVBert processor which wraps a ModernVBertProcessor and special methods to process images and
    queries, as well as to compute the late-interaction retrieval score.

    [`ColModernVBertProcessor`] offers all the functionalities of [`ModernVBertProcessor`]. See the
    [`~ModernVBertProcessor.__call__`] for more information.

    Args:
        image_processor ([`Idefics3ImageProcessor`]):
            An instance of [`Idefics3ImageProcessor`]. The image processor is a required input.
        tokenizer (`PreTrainedTokenizerFast`, *optional*):
            An instance of [`PreTrainedTokenizerFast`]. This should correspond with the model's text model. The
            tokenizer is a required input.
        image_seq_len (`int`, *optional*, defaults to 64):
            The length of the image sequence i.e. the number of tokens per image in the input.
        visual_prompt_prefix (`str`, *optional*):
            A prefix to be prepended to visual prompts.
        query_prefix (`str`, *optional*):
            A prefix to be prepended to query prompts.
    """

    def __init__(
        self,
        image_processor,
        tokenizer=None,
        chat_template=None,
        image_seq_len: int = 64,
        visual_prompt_prefix: str | None = None,
        query_prefix: str | None = None,
        **kwargs,
    ):
        r"""
        image_seq_len (`int`, *optional*, defaults to 64):
            The length of the image sequence i.e. the number of tokens per image in the input.
        visual_prompt_prefix (`str`, *optional*):
            A string that gets tokenized and prepended to the image tokens.
        query_prefix (`str`, *optional*):
            A prefix to be used for the query.
        """
        chat_template = None  # ColModernVBert does not use chat templates
        super().__init__(
            image_processor,
            tokenizer,
            chat_template=chat_template,
            image_seq_len=image_seq_len,
            **kwargs,
        )
        # Falsy values (`None` or an empty string) fall back to the default prompts.
        self.visual_prompt_prefix = visual_prompt_prefix or (
            f"<|begin_of_text|>User:{self.image_token}Describe the image.\nAssistant:"
        )
        self.query_prefix = query_prefix or ""
        self.query_augmentation_token = self.end_of_utterance_token

    def process_images(
        self,
        images: ImageInput | None = None,
        **kwargs: Unpack[ColModernVBertProcessorKwargs],
    ) -> BatchFeature:
        """
        Prepare for the model one or several image(s). Handles input validation, RGB conversion,
        and prepends the `visual_prompt_prefix` to each image. Optionally computes labels from
        `token_type_ids` when a `suffix` is provided in `text_kwargs`.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                number of channels, H and W are image height and width. A list of lists of images is flattened, so
                that every image is paired with its own visual prompt.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.

        Raises:
            `ValueError`: If `images` is `None`, empty, or not an image, a list of images or a list of lists of
            images.
        """
        output_kwargs = self._merge_kwargs(
            ColModernVBertProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        # NOTE(review): `suffix` is only used as a flag here; it is not appended to the prompts and
        # `return_token_type_ids` is not forwarded to `__call__`. The label branch below therefore
        # presumably relies on the tokenizer returning `token_type_ids` by default - confirm.
        suffix = output_kwargs["text_kwargs"].pop("suffix", None)

        return_token_type_ids = suffix is not None

        # Normalize input to a flat list of images
        if is_valid_image(images):
            images = [images]
        elif isinstance(images, list) and len(images) > 0 and is_valid_image(images[0]):
            pass
        elif (
            isinstance(images, list)
            and len(images) > 0
            and isinstance(images[0], list)
            and len(images[0]) > 0
            and is_valid_image(images[0][0])
        ):
            # Nested input: flatten it so that each image gets exactly one visual prompt below
            images = [image for image_group in images for image in image_group]
        else:
            raise ValueError("images must be an image, list of images or list of list of images")

        # Ensure PIL images are in RGB format. Arrays and tensors have no `convert` method and are passed
        # through unchanged; `do_convert_rgb` is enabled in the default `images_kwargs` for them.
        images = [image.convert("RGB") if hasattr(image, "convert") else image for image in images]

        # Pair each image with the visual prompt prefix for the VLM backbone
        batch_doc = self.__call__(
            text=[self.visual_prompt_prefix] * len(images),
            images=images,
            images_kwargs=output_kwargs["images_kwargs"],
            text_kwargs=output_kwargs["text_kwargs"],
        )

        # When suffix is provided, generate labels by masking non-suffix tokens
        if return_token_type_ids:
            labels = batch_doc["input_ids"].masked_fill(batch_doc["token_type_ids"] == 0, -100)
            batch_doc.update({"labels": labels})

        return batch_doc

    def process_queries(
        self,
        text: TextInput | list[TextInput],
        **kwargs: Unpack[ColModernVBertProcessorKwargs],
    ) -> BatchFeature:
        """
        Prepare for the model one or several text queries. Handles input validation, prepends the
        `query_prefix`, and appends query augmentation tokens (used to pad query embeddings for
        better late-interaction retrieval performance).

        Args:
            text (`str`, `list[str]`):
                The query or batch of queries to be encoded. Each query must be a string.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).

        Raises:
            `ValueError`: If `text` is neither a string nor a non-empty list of strings.
        """
        output_kwargs = self._merge_kwargs(
            ColModernVBertProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        suffix = output_kwargs["text_kwargs"].pop("suffix", None)

        if isinstance(text, str):
            text = [text]
        elif not (isinstance(text, list) and len(text) > 0 and isinstance(text[0], str)):
            # The emptiness check turns an `IndexError` on `text[0]` into the documented `ValueError`
            raise ValueError("Text must be a string or a list of strings")

        # Default suffix: repeat the augmentation token to pad query embeddings
        if suffix is None:
            suffix = self.query_augmentation_token * 10

        # Build final queries: prefix + original query + augmentation suffix
        texts_query: list[str] = [self.query_prefix + query + suffix for query in text]

        batch_query = self.__call__(
            text=texts_query,
            return_token_type_ids=False,
            text_kwargs=output_kwargs["text_kwargs"],
        )

        return batch_query

    def score_retrieval(
        self,
        query_embeddings: Union["torch.Tensor", list["torch.Tensor"]],
        passage_embeddings: Union["torch.Tensor", list["torch.Tensor"]],
        batch_size: int = 128,
        output_dtype: Optional["torch.dtype"] = None,
        output_device: Union["torch.device", str] = "cpu",
    ) -> "torch.Tensor":
        """
        Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
        query embeddings (`query_embeddings`) and passage embeddings (`passage_embeddings`). For ColModernVBert, a
        passage is the image of a document page.

        Because the embedding tensors are multi-vector and can thus have different shapes, they
        should be fed as:
        (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
        (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
            obtained by padding the list of tensors.

        Args:
            query_embeddings (`Union[torch.Tensor, list[torch.Tensor]`): Query embeddings.
            passage_embeddings (`Union[torch.Tensor, list[torch.Tensor]`): Passage embeddings.
            batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
            output_dtype (`torch.dtype`, *optional*): The dtype of the output
                tensor. If `None`, the dtype of the input embeddings is used.
            output_device (`torch.device` or `str`, *optional*, defaults to "cpu"): The device of the output tensor.

        Returns:
            `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
            tensor is moved to `output_device`.

        Raises:
            `ValueError`: If no queries or passages are given, if `batch_size` is not positive, or if queries
            and passages differ in device or dtype.
        """

        if len(query_embeddings) == 0:
            raise ValueError("No queries provided")
        if len(passage_embeddings) == 0:
            raise ValueError("No passages provided")
        if batch_size <= 0:
            # A negative step would yield no batches and fail later on an empty `torch.cat`
            raise ValueError("`batch_size` must be a positive integer")

        if query_embeddings[0].device != passage_embeddings[0].device:
            raise ValueError("Queries and passages must be on the same device")

        if query_embeddings[0].dtype != passage_embeddings[0].dtype:
            raise ValueError("Queries and passages must have the same dtype")

        if output_dtype is None:
            output_dtype = query_embeddings[0].dtype

        scores: list[torch.Tensor] = []

        for i in range(0, len(query_embeddings), batch_size):
            batch_scores: list[torch.Tensor] = []
            # Zero-pad the variable-length sequences of this chunk into a single tensor
            batch_queries = torch.nn.utils.rnn.pad_sequence(
                query_embeddings[i : i + batch_size], batch_first=True, padding_value=0
            )
            for j in range(0, len(passage_embeddings), batch_size):
                batch_passages = torch.nn.utils.rnn.pad_sequence(
                    passage_embeddings[j : j + batch_size], batch_first=True, padding_value=0
                )
                # MaxSim: for each query token take the best-matching passage token, then sum over query tokens
                batch_scores.append(
                    torch.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2)
                )
            scores.append(torch.cat(batch_scores, dim=1).to(output_dtype).to(output_device))

        return torch.cat(scores, dim=0)


@auto_docstring
class ColModernVBertPreTrainedModel(ColPaliPreTrainedModel):
    config: ColModernVBertConfig


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for ColModernVBert embeddings output.
    """
)
class ColModernVBertForRetrievalOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        The embeddings of the model.
    image_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True` and `pixel_values` are provided):
        Tuple of `torch.FloatTensor` (one for the output of the image modality projection + one for the output of each layer) of shape `(batch_size, num_channels, image_size, image_size)`.
        Hidden-states of the image encoder at the output of each layer plus the initial modality projection outputs.
    """

    loss: torch.FloatTensor | None = None
    embeddings: torch.Tensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    image_hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None


@auto_docstring(
    custom_intro="""
    Following the ColPali approach, ColModernVBert leverages VLMs to construct efficient multi-vector embeddings directly
    from document images (“screenshots”) for document retrieval. The model is trained to maximize the similarity
    between these document embeddings and the corresponding query embeddings, using the late interaction method
    introduced in ColBERT.

    Using ColModernVBert removes the need for potentially complex and brittle layout recognition and OCR pipelines with
    a single model that can take into account both the textual and visual content (layout, charts, ...) of a document.

    ColModernVBert is trained on top of ModernVBert, and was introduced in the following paper:
    [*ModernVBERT: Towards Smaller Visual Document Retrievers*](https://arxiv.org/abs/2510.01149).

    ColModernVBert is part of the ColVision model family, which was introduced with ColPali in the following paper:
    [*ColPali: Efficient Document Retrieval with Vision Language Models*](https://huggingface.co/papers/2407.01449).
    """
)
class ColModernVBertForRetrieval(ColPaliForRetrieval):
    def __init__(self, config: ColModernVBertConfig):
        super().__init__(config)
        # Replace the backbone with the one described by `vlm_config`
        self.vlm = AutoModel.from_config(config.vlm_config)
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> ColModernVBertForRetrievalOutput:
        vlm_output = self.vlm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            **kwargs,
        )

        last_hidden_states = vlm_output[0]  # (batch_size, sequence_length, hidden_size)

        # Cast to the projection layer's dtype before projecting
        proj_dtype = self.embedding_proj_layer.weight.dtype
        embeddings = self.embedding_proj_layer(last_hidden_states.to(proj_dtype))  # (batch_size, sequence_length, dim)

        # L2 normalization
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)

        # Zero out the embeddings of masked (padding) positions
        if attention_mask is not None:
            attention_mask = attention_mask.to(dtype=embeddings.dtype, device=embeddings.device)
            embeddings = embeddings * attention_mask.unsqueeze(-1)  # (batch_size, sequence_length, dim)

        return ColModernVBertForRetrievalOutput(
            embeddings=embeddings,
            hidden_states=vlm_output.hidden_states,
            attentions=vlm_output.attentions,
            image_hidden_states=vlm_output.image_hidden_states,
        )


__all__ = [
    "ColModernVBertConfig",
    "ColModernVBertForRetrieval",
    "ColModernVBertPreTrainedModel",
    "ColModernVBertProcessor",
]