| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626 |
- # Copyright 2026 Illuin Technology and contributors, and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- from dataclasses import dataclass
- from typing import Literal
- import torch
- import torch.nn as nn
- from huggingface_hub.dataclasses import strict
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from ... import initialization as init
- from ...configuration_utils import PreTrainedConfig
- from ...modeling_outputs import (
- BaseModelOutput,
- MaskedLMOutput,
- SequenceClassifierOutput,
- TokenClassifierOutput,
- )
- from ...modeling_utils import PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, logging
- from ...utils.generic import can_return_tuple
- from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel
- from ..modernbert.modeling_modernbert import ModernBertPredictionHead
- from ..smolvlm.modeling_smolvlm import SmolVLMModel, SmolVLMPreTrainedModel
logger = logging.get_logger(__name__)


@auto_docstring(checkpoint="ModernVBERT/modernvbert")
@strict
class ModernVBertConfig(PreTrainedConfig):
    r"""
    pixel_shuffle_factor (`int`, *optional*, defaults to 4):
        Downsampling factor of the connector's pixel shuffle. Each side of the image patch grid is divided by this
        factor, so the number of image tokens shrinks by its square while their feature size grows by its square.
    initializer_cutoff_factor (`float`, *optional*, defaults to 2.0):
        Cutoff factor of the truncated normal initializer: weights are truncated at `initializer_cutoff_factor`
        standard deviations on either side of the mean.
    classifier_pooling (`Literal["cls", "mean"]`, *optional*, defaults to `"cls"`):
        How token states are reduced to a single vector for sequence classification: `"cls"` keeps the first
        token, `"mean"` averages the tokens selected by the attention mask.
    classifier_bias (`bool`, *optional*, defaults to `False`):
        Whether to add a bias term to the classification head.

    Example:

    ```python
    >>> from transformers import ModernVBertConfig

    >>> # Initializing configuration
    >>> configuration = ModernVBertConfig()

    >>> # Initializing a model from the configuration (model class is implemented in
    >>> # `modernvbert.modeling_modernvbert`)
    >>> from transformers import ModernVBertModel
    >>> model = ModernVBertModel(configuration)

    >>> # Accessing the model configuration
    >>> cfg = model.config
    ```"""

    model_type = "modernvbert"
    sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}

    text_config: PreTrainedConfig | dict | None = None
    vision_config: PreTrainedConfig | dict | None = None
    image_token_id: int = 50407
    pixel_shuffle_factor: int = 4
    initializer_range: float = 0.02
    initializer_cutoff_factor: float = 2.0
    classifier_pooling: Literal["cls", "mean"] = "cls"
    classifier_dropout: float | int = 0.0
    classifier_bias: bool = False
    tie_word_embeddings: bool = False

    def __post_init__(self, **kwargs):
        # Each sub-config may be missing (build the default one), a plain dict (e.g. loaded from JSON) or an
        # already-instantiated config object (kept as-is). Text is resolved before vision.
        sub_config_model_types = (("text_config", "modernbert"), ("vision_config", "siglip_vision_model"))
        for attribute_name, default_model_type in sub_config_model_types:
            current_value = getattr(self, attribute_name)
            config_cls = CONFIG_MAPPING[default_model_type]
            if current_value is None:
                setattr(self, attribute_name, config_cls())
            elif isinstance(current_value, dict):
                setattr(self, attribute_name, config_cls(**current_value))
        super().__post_init__(**kwargs)
@dataclass
class ModernVBertBaseModelOutput(BaseModelOutput):
    """
    Base class for ModernVBERT model's outputs.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`torch.FloatTensor`, *optional*):
            Image features produced by the vision encoder and projected to the text hidden size by the connector,
            typically of shape `(num_images, image_seq_len, hidden_size)`. Only set when images were provided.
    """

    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    # A single tensor (the connector output), not a tuple.
    image_hidden_states: torch.FloatTensor | None = None
@dataclass
class ModernVBertMaskedLMOutput(MaskedLMOutput):
    """
    Base class for ModernVBERT model's outputs with masked language modeling loss.

    Args:
        loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            Masked language modeling (MLM) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`torch.FloatTensor`, *optional*):
            Image features produced by the vision encoder and projected to the text hidden size by the connector,
            typically of shape `(num_images, image_seq_len, hidden_size)`. Only set when images were provided.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    image_hidden_states: torch.FloatTensor | None = None
class ModernVBertConnector(nn.Module):
    """
    Connector module for ModernVBERT. It performs a pixel shuffle operation followed by a linear projection to match
    the text model's hidden size.

    The "pixel shuffle" used here is a space-to-depth rearrangement (similar in spirit to `torch.nn.PixelUnshuffle`).
    Every `pixel_shuffle_factor x pixel_shuffle_factor` neighbourhood of the square patch grid is folded into a single
    token. This divides the number of image tokens by `pixel_shuffle_factor**2` and multiplies their feature size by
    the same amount.
    """

    def __init__(self, config):
        super().__init__()
        self.pixel_shuffle_factor = config.pixel_shuffle_factor
        # The projection input is the vision hidden size times factor**2 because of the space-to-depth fold.
        self.modality_projection = nn.Linear(
            config.vision_config.hidden_size * (config.pixel_shuffle_factor**2),
            config.text_config.hidden_size,
            bias=False,
        )

    def pixel_shuffle(self, image_hidden_states, pixel_shuffle_factor):
        """
        Fold each `pixel_shuffle_factor x pixel_shuffle_factor` block of the square patch grid into one token.

        Args:
            image_hidden_states (`torch.Tensor` of shape `(batch_size, seq_length, embed_dim)`):
                Vision encoder outputs laid out row-major over a square patch grid, so `seq_length` must be a
                perfect square.
            pixel_shuffle_factor (`int`):
                Downsampling factor; the side of the patch grid must be divisible by it.

        Returns:
            `torch.Tensor` of shape `(batch_size, seq_length // pixel_shuffle_factor**2,
            embed_dim * pixel_shuffle_factor**2)`.

        Raises:
            ValueError: if `seq_length` is not a perfect square, or its side is not divisible by the factor.
        """
        batch_size, seq_length, embed_dim = image_hidden_states.size()
        # `math.isqrt` is exact for every integer, unlike `int(seq_length**0.5)` which goes through floating point.
        side = math.isqrt(seq_length)
        if side * side != seq_length:
            raise ValueError(
                f"Pixel shuffle expects a square grid of image patches, but got a sequence length of {seq_length}."
            )
        if side % pixel_shuffle_factor != 0:
            raise ValueError(
                f"The image patch grid side ({side}) must be divisible by pixel_shuffle_factor ({pixel_shuffle_factor})."
            )
        reduced_side = side // pixel_shuffle_factor
        folded_dim = embed_dim * (pixel_shuffle_factor**2)

        image_hidden_states = image_hidden_states.view(batch_size, side, side, embed_dim)
        # Merge `factor` horizontally adjacent patches into the feature dimension.
        image_hidden_states = image_hidden_states.view(
            batch_size, side, reduced_side, embed_dim * pixel_shuffle_factor
        )
        # Swap grid axes so `factor` vertically adjacent rows become contiguous, then merge them as well.
        image_hidden_states = image_hidden_states.permute(0, 2, 1, 3)
        image_hidden_states = image_hidden_states.reshape(batch_size, reduced_side, reduced_side, folded_dim)
        # Restore row-major order of the reduced grid and flatten it back into a token sequence.
        image_hidden_states = image_hidden_states.permute(0, 2, 1, 3)
        return image_hidden_states.reshape(batch_size, reduced_side * reduced_side, folded_dim)

    def forward(self, image_hidden_states):
        """Pixel-shuffle the vision features, then project them to the text hidden size."""
        image_hidden_states = self.pixel_shuffle(image_hidden_states, self.pixel_shuffle_factor)
        return self.modality_projection(image_hidden_states)
@auto_docstring
class ModernVBertPreTrainedModel(SmolVLMPreTrainedModel):
    config_class = ModernVBertConfig
    _no_split_modules = []

    @torch.no_grad()
    def _init_weights(self, module):
        """
        Initialize the weights of `module`.

        The generic `PreTrainedModel` initialization runs first. Then the ModernVBERT-specific layers are redrawn from
        a zero-mean truncated normal distribution: the connector projection, the masked-LM decoder and the
        classification heads. Their standard deviation is scaled down by the text model's depth or width.
        """
        PreTrainedModel._init_weights(self, module)
        config = self.config

        def _draw_truncated_normal(layer: nn.Module, std: float):
            # Samples are truncated at +/- `initializer_cutoff_factor` standard deviations around zero.
            cutoff = getattr(config, "initializer_cutoff_factor", 2.0)
            init.trunc_normal_(layer.weight, mean=0.0, std=std, a=-cutoff * std, b=cutoff * std)
            if isinstance(layer, (nn.Linear, nn.Conv2d)) and layer.bias is not None:
                init.zeros_(layer.bias)

        if isinstance(module, (ModernVBertConnector, ModernVBertForMaskedLM)):
            # Output-side projections: std shrinks with sqrt(2 * number of text layers).
            depth_scaled_std = config.initializer_range / math.sqrt(2.0 * config.text_config.num_hidden_layers)
            target_layer = (
                module.modality_projection if isinstance(module, ModernVBertConnector) else module.lm_head
            )
            _draw_truncated_normal(target_layer, depth_scaled_std)
        elif isinstance(module, (ModernVBertForSequenceClassification, ModernVBertForTokenClassification)):
            # Classifier heads: std shrinks with sqrt(text hidden size).
            width_scaled_std = config.initializer_range / math.sqrt(config.text_config.hidden_size)
            _draw_truncated_normal(module.classifier, width_scaled_std)
@auto_docstring(
    custom_intro="""
    ModernVBertModel is a model that combines a vision encoder (SigLIP) and a text encoder (ModernBert).
    ModernVBert is the base model of the visual retriever ColModernVBert, and was introduced in the following paper:
    [*ModernVBERT: Towards Smaller Visual Document Retrievers*](https://arxiv.org/abs/2510.01149).
    """
)
class ModernVBertModel(SmolVLMModel):
    def __init__(self, config: ModernVBertConfig):
        super().__init__(config)

        # init components: (re)assign the connector and the encoders built from the ModernVBERT sub-configs
        self.connector = ModernVBertConnector(config)
        self.text_model = AutoModel.from_config(config.text_config)
        self.vision_model = AutoModel.from_config(config.vision_config)

        # Number of tokens each image contributes after the connector:
        # (patches per side)**2 divided by pixel_shuffle_factor**2.
        self.image_seq_len = int(
            ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
            / (config.pixel_shuffle_factor**2)
        )

        # initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring(
        custom_intro="""
        Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
        the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
        max_num_images is the maximum number of images among the batch_size samples in the batch.
        Padding images are not needed beyond padding the pixel_values at the entrance of the model.
        For efficiency, we only pass through the vision_model's forward the real images by
        discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
        image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
        """,
        checkpoint="ModernVBERT/modernvbert",
    )
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        pixel_attention_mask: torch.BoolTensor | None = None,
        image_hidden_states: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | ModernVBertBaseModelOutput:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        image_hidden_states (`torch.FloatTensor`, *optional*):
            Precomputed image features that have already gone through the connector's modality projection. They are
            ignored when `pixel_values` is given (the features are then recomputed).
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)

        # Images processing: freshly computed features take precedence over `image_hidden_states`.
        if pixel_values is not None:
            image_hidden_states = self.get_image_features(
                pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask
            ).pooler_output

        # Merge image and text embeddings
        if image_hidden_states is not None:
            image_hidden_states = image_hidden_states.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
            inputs_embeds = self.inputs_merger(
                input_ids=input_ids, inputs_embeds=inputs_embeds, image_hidden_states=image_hidden_states
            )

        # Language model pass
        outputs = self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )

        return ModernVBertBaseModelOutput(
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_hidden_states,
        )
class ModernVBertPredictionHead(ModernBertPredictionHead):
    """Prediction head of the ModernVBERT task models: ModernBERT's head unchanged, built from the text config."""
@auto_docstring
class ModernVBertForMaskedLM(ModernVBertPreTrainedModel):
    _tied_weights_keys = {"lm_head.weight": "model.text_model.embeddings.tok_embeddings.weight"}

    def __init__(self, config: ModernVBertConfig):
        super().__init__(config)
        self.vocab_size = config.text_config.vocab_size

        self.model = ModernVBertModel(config)
        self.projection_head = ModernVBertPredictionHead(config.text_config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, self.vocab_size, bias=config.text_config.decoder_bias)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @can_return_tuple
    @auto_docstring(
        custom_intro="""
        Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
        the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
        max_num_images is the maximum number of images among the batch_size samples in the batch.
        Padding images are not needed beyond padding the pixel_values at the entrance of the model.
        For efficiency, we only pass through the vision_model's forward the real images by
        discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
        image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
        """,
        checkpoint="ModernVBERT/modernvbert",
    )
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        pixel_attention_mask: torch.BoolTensor | None = None,
        image_hidden_states: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | ModernVBertMaskedLMOutput:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        image_hidden_states (`torch.FloatTensor`, *optional*):
            Precomputed image features that have already gone through the connector's modality projection. They are
            ignored when `pixel_values` is given.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in
            `[0, ..., text_config.vocab_size - 1]` or `-100`. Tokens with label `-100` are ignored (masked); the loss
            is only computed for the tokens with labels in `[0, ..., text_config.vocab_size - 1]`.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            pixel_attention_mask=pixel_attention_mask,
            image_hidden_states=image_hidden_states,
            **kwargs,
        )
        hidden_states = outputs[0]
        logits = self.lm_head(self.projection_head(hidden_states))

        loss = None
        if labels is not None:
            # Labels may live on another device when the model is split across devices.
            labels = labels.to(logits.device)
            # Read the vocabulary size from the logits: `self.vocab_size` goes stale if the output embeddings
            # are resized. CrossEntropyLoss ignores label -100 by default.
            criterion = CrossEntropyLoss()
            loss = criterion(logits.view(-1, logits.size(-1)), labels.reshape(-1))

        return ModernVBertMaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
        )
@auto_docstring(
    custom_intro="""
    The ModernVBert Model with a sequence classification head on top that performs pooling.
    """
)
class ModernVBertForSequenceClassification(ModernVBertPreTrainedModel):
    def __init__(self, config: ModernVBertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.model = ModernVBertModel(config)
        self.head = ModernVBertPredictionHead(config.text_config)
        self.drop = nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.text_config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring(
        custom_intro="""
        Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
        the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
        max_num_images is the maximum number of images among the batch_size samples in the batch.
        Padding images are not needed beyond padding the pixel_values at the entrance of the model.
        For efficiency, we only pass through the vision_model's forward the real images by
        discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
        image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
        """,
        checkpoint="ModernVBERT/modernvbert",
    )
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        pixel_attention_mask: torch.BoolTensor | None = None,
        image_hidden_states: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | SequenceClassifierOutput:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        image_hidden_states (`torch.FloatTensor`, *optional*):
            Precomputed image features that have already gone through the connector's modality projection. They are
            ignored when `pixel_values` is given.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in
            `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss (mean-squared error) is
            computed. If `config.num_labels > 1`, integer labels give a cross-entropy loss and other labels a
            multi-label binary cross-entropy loss. An explicit `config.problem_type` takes precedence.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            pixel_attention_mask=pixel_attention_mask,
            image_hidden_states=image_hidden_states,
            **kwargs,
        )
        last_hidden_state = outputs[0]

        if self.config.classifier_pooling == "cls":
            last_hidden_state = last_hidden_state[:, 0]
        elif self.config.classifier_pooling == "mean":
            # Build / move the mask next to the hidden states it weights.
            if attention_mask is None:
                mask = torch.ones(last_hidden_state.shape[:2], device=last_hidden_state.device, dtype=torch.bool)
            else:
                mask = attention_mask.to(last_hidden_state.device)
            summed = (last_hidden_state * mask.unsqueeze(-1)).sum(dim=1)
            # clamp avoids 0 / 0 = NaN for a fully padded row (its pooled vector becomes zeros instead)
            token_counts = mask.sum(dim=1, keepdim=True).clamp(min=1)
            last_hidden_state = summed / token_counts
        else:
            raise ValueError(
                f"Invalid classifier_pooling '{self.config.classifier_pooling}', expected 'cls' or 'mean'."
            )

        pooled_output = self.head(last_hidden_state)
        pooled_output = self.drop(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the task once from num_labels and the label dtype, and cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.reshape(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring(
    custom_intro="""
    The ModernVBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.
    """
)
class ModernVBertForTokenClassification(ModernVBertPreTrainedModel):
    def __init__(self, config: ModernVBertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.model = ModernVBertModel(config)
        self.head = ModernVBertPredictionHead(config.text_config)
        self.drop = nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.text_config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring(
        custom_intro="""
        Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
        the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
        max_num_images is the maximum number of images among the batch_size samples in the batch.
        Padding images are not needed beyond padding the pixel_values at the entrance of the model.
        For efficiency, we only pass through the vision_model's forward the real images by
        discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
        image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
        """,
        checkpoint="ModernVBERT/modernvbert",
    )
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        pixel_attention_mask: torch.BoolTensor | None = None,
        image_hidden_states: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | TokenClassifierOutput:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        image_hidden_states (`torch.FloatTensor`, *optional*):
            Precomputed image features that have already gone through the connector's modality projection. They are
            ignored when `pixel_values` is given.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`;
            positions labelled `-100` are ignored.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            pixel_attention_mask=pixel_attention_mask,
            image_hidden_states=image_hidden_states,
            **kwargs,
        )
        last_hidden_state = outputs[0]

        last_hidden_state = self.head(last_hidden_state)
        last_hidden_state = self.drop(last_hidden_state)
        logits = self.classifier(last_hidden_state)

        loss = None
        if labels is not None:
            # Labels may live on another device when the model is split across devices; `reshape` also
            # accepts non-contiguous label tensors. CrossEntropyLoss ignores label -100 by default.
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.reshape(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Public API of this module: the configuration, the shared pretrained base class, the bare vision-language
# encoder and its task-specific heads (masked LM, sequence classification, token classification).
__all__ = [
    "ModernVBertConfig",
    "ModernVBertPreTrainedModel",
    "ModernVBertModel",
    "ModernVBertForMaskedLM",
    "ModernVBertForSequenceClassification",
    "ModernVBertForTokenClassification",
]
|