| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/colmodernvbert/modular_colmodernvbert.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_colmodernvbert.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2026 Illuin Technology and contributors, and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from dataclasses import dataclass
- import torch
- from torch import nn
- from ... import initialization as init
- from ...modeling_utils import PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import ModelOutput, TransformersKwargs, auto_docstring
- from ...utils.generic import can_return_tuple
- from ..auto.modeling_auto import AutoModel
- from .configuration_colmodernvbert import ColModernVBertConfig
@auto_docstring
class ColModernVBertPreTrainedModel(PreTrainedModel):
    # Shared base for ColModernVBert models: binds the config class, the modalities the model accepts and the
    # attention backends it declares support for.
    config: ColModernVBertConfig
    base_model_prefix = "model"
    input_modalities = ("image", "text")
    _no_split_modules = []
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True

    @torch.no_grad()
    def _init_weights(self, module):
        """
        Initialize the parameters of `module` in place.

        The standard deviation is `config.initializer_range` when the config defines that attribute, and
        `config.vlm_config.text_config.initializer_range` otherwise.

        - `nn.Linear` / `nn.Conv2d`: weight goes through `init.normal_(mean=0.0, std=std)`, bias (if any)
          through `init.zeros_`.
        - `nn.Embedding`: weight goes through `init.normal_(mean=0.0, std=std)`; the row at `padding_idx` (if
          set) goes through `init.zeros_` unless the weight already carries a truthy `_is_hf_initialized` flag.
        - Any other module type is left untouched.
        """
        cfg = self.config
        if hasattr(cfg, "initializer_range"):
            std = cfg.initializer_range
        else:
            std = cfg.vlm_config.text_config.initializer_range

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                init.zeros_(module.bias)
            return

        if not isinstance(module, nn.Embedding):
            return

        init.normal_(module.weight, mean=0.0, std=std)
        pad_row = module.padding_idx
        if pad_row is None:
            return
        # The flag has to be checked by hand: the slice handed to `init.zeros_` is a new tensor that loses the
        # `_is_hf_initialized` flag set on the full weight.
        if not getattr(module.weight, "_is_hf_initialized", False):
            init.zeros_(module.weight[pad_row])
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for ColModernVBert embeddings output.
    """
)
class ColModernVBertForRetrievalOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
        Optional loss value. `ColModernVBertForRetrieval.forward` does not compute a loss and leaves this field
        as `None`.
    embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, embedding_dim)`):
        Multi-vector embeddings of the input: one L2-normalized vector of size `config.embedding_dim` per token.
        When an `attention_mask` is given, vectors at positions where the mask is 0 are zeroed out.
    image_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True` and `pixel_values` are provided):
        Tuple of `torch.FloatTensor` (one for the output of the image modality projection + one for the output of each layer) of shape
        `(batch_size, num_channels, image_size, image_size)`.
        Hidden-states of the image encoder at the output of each layer plus the initial modality projection outputs.
    """

    loss: torch.FloatTensor | None = None
    # Projected, normalized per-token embeddings; the main output of the retrieval model.
    embeddings: torch.Tensor | None = None
    # Forwarded unchanged from the backbone VLM output.
    hidden_states: tuple[torch.FloatTensor] | None = None
    image_hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
@auto_docstring(
    custom_intro="""
    Following the ColPali approach, ColModernVBert leverages VLMs to construct efficient multi-vector embeddings directly
    from document images (“screenshots”) for document retrieval. The model is trained to maximize the similarity
    between these document embeddings and the corresponding query embeddings, using the late interaction method
    introduced in ColBERT.

    Using ColModernVBert removes the need for potentially complex and brittle layout recognition and OCR pipelines with
    a single model that can take into account both the textual and visual content (layout, charts, ...) of a document.

    ColModernVBert is trained on top of ModernVBert, and was introduced in the following paper:
    [*ModernVBERT: Towards Smaller Visual Document Retrievers*](https://arxiv.org/abs/2510.01149).

    ColModernVBert is part of the ColVision model family, which was introduced with ColPali in the following paper:
    [*ColPali: Efficient Document Retrieval with Vision Language Models*](https://huggingface.co/papers/2407.01449).
    """
)
class ColModernVBertForRetrieval(ColModernVBertPreTrainedModel):
    base_model_prefix = "vlm"

    def __init__(self, config: ColModernVBertConfig):
        super().__init__(config)
        self.config = config
        self.vocab_size = config.vlm_config.text_config.vocab_size

        # Backbone vision-language model, built from its own sub-config.
        self.vlm = AutoModel.from_config(config.vlm_config)

        # Maps each token's hidden state (text hidden_size) to the retrieval embedding size.
        self.embedding_dim = self.config.embedding_dim
        self.embedding_proj_layer = nn.Linear(
            self.config.vlm_config.text_config.hidden_size,
            self.embedding_dim,
        )

        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> ColModernVBertForRetrievalOutput:
        vlm_output = self.vlm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            **kwargs,
        )

        last_hidden_states = vlm_output[0]  # (batch_size, sequence_length, hidden_size)
        # Cast to the projection's dtype so the Linear layer receives matching dtypes; the backbone output may
        # be in a different precision.
        proj_dtype = self.embedding_proj_layer.weight.dtype
        embeddings = self.embedding_proj_layer(last_hidden_states.to(proj_dtype))  # (batch_size, sequence_length, dim)

        # L2 normalization. The denominator is clamped to the smallest positive normal value of the dtype, so an
        # all-zero vector yields zeros instead of NaN (0 / 0). A NaN would otherwise survive the mask
        # multiplication below (NaN * 0 == NaN). For any norm >= eps this is identical to `x / x.norm()`.
        # The eps is dtype-aware because the default 1e-12 underflows to 0 in float16.
        embeddings = nn.functional.normalize(
            embeddings, p=2, dim=-1, eps=torch.finfo(embeddings.dtype).tiny
        )  # (batch_size, sequence_length, dim)

        # Zero out the vectors of masked positions.
        if attention_mask is not None:
            attention_mask = attention_mask.to(dtype=embeddings.dtype, device=embeddings.device)
            embeddings = embeddings * attention_mask.unsqueeze(-1)  # (batch_size, sequence_length, dim)

        return ColModernVBertForRetrievalOutput(
            embeddings=embeddings,
            hidden_states=vlm_output.hidden_states,
            attentions=vlm_output.attentions,
            image_hidden_states=vlm_output.image_hidden_states,
        )
# Public names exported by this module.
__all__ = [
    "ColModernVBertForRetrieval",
    "ColModernVBertPreTrainedModel",
]
|