modeling_modernvbert.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/modernvbert/modular_modernvbert.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_modernvbert.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2026 Illuin Technology and contributors, and The HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import math
  21. from dataclasses import dataclass
  22. import torch
  23. import torch.nn as nn
  24. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  25. from ... import initialization as init
  26. from ...activations import ACT2FN
  27. from ...modeling_outputs import (
  28. BaseModelOutput,
  29. BaseModelOutputWithPooling,
  30. MaskedLMOutput,
  31. SequenceClassifierOutput,
  32. TokenClassifierOutput,
  33. )
  34. from ...modeling_utils import PreTrainedModel
  35. from ...processing_utils import Unpack
  36. from ...utils import TransformersKwargs, auto_docstring, torch_compilable_check
  37. from ...utils.generic import can_return_tuple
  38. from ..auto import AutoModel
  39. from .configuration_modernvbert import ModernVBertConfig
@dataclass
class ModernVBertBaseModelOutput(BaseModelOutput):
    """
    Output of the base ModernVBERT model: the text model's outputs plus the image features that were merged into
    the text embeddings.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`torch.FloatTensor`, *optional*):
            Image features produced by the vision encoder and projected by the connector, as they were merged into
            the text embeddings. `None` when the forward pass received neither `pixel_values` nor
            `image_hidden_states`.
    """

    # All fields default to None so the output can be built from whichever pieces the text model returned.
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    # A single tensor (the connector output), not a tuple.
    image_hidden_states: torch.FloatTensor | None = None
@dataclass
class ModernVBertMaskedLMOutput(MaskedLMOutput):
    """
    Output of ModernVBERT with a masked language modeling head.

    Args:
        loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            Masked language modeling (MLM) loss.
        logits (`torch.FloatTensor`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`torch.FloatTensor`, *optional*):
            Image features produced by the vision encoder and projected by the connector, as they were merged into
            the text embeddings. `None` when no image input was given.
    """

    # All fields default to None; `loss` is only populated when labels are supplied to the forward pass.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    image_hidden_states: torch.FloatTensor | None = None
  95. class ModernVBertConnector(nn.Module):
  96. """
  97. Connector module for ModernVBERT. It performs a pixel shuffle operation followed by a linear projection to match the text model's hidden size.
  98. Based on https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
  99. """
  100. def __init__(self, config):
  101. super().__init__()
  102. self.pixel_shuffle_factor = config.pixel_shuffle_factor
  103. self.modality_projection = nn.Linear(
  104. config.vision_config.hidden_size * (config.pixel_shuffle_factor**2),
  105. config.text_config.hidden_size,
  106. bias=False,
  107. )
  108. def pixel_shuffle(self, image_hidden_states, pixel_shuffle_factor):
  109. batch_size, seq_length, embed_dim = image_hidden_states.size()
  110. height = width = int(seq_length**0.5)
  111. image_hidden_states = image_hidden_states.view(batch_size, height, width, embed_dim)
  112. image_hidden_states = image_hidden_states.view(
  113. batch_size, height, int(width / pixel_shuffle_factor), embed_dim * pixel_shuffle_factor
  114. )
  115. image_hidden_states = image_hidden_states.permute(0, 2, 1, 3)
  116. image_hidden_states = image_hidden_states.reshape(
  117. batch_size,
  118. int(width / pixel_shuffle_factor),
  119. int(height / pixel_shuffle_factor),
  120. embed_dim * (pixel_shuffle_factor**2),
  121. )
  122. image_hidden_states = image_hidden_states.permute(0, 2, 1, 3)
  123. return image_hidden_states.reshape(
  124. batch_size, int(seq_length / (pixel_shuffle_factor**2)), embed_dim * (pixel_shuffle_factor**2)
  125. )
  126. def forward(self, image_hidden_states):
  127. image_hidden_states = self.pixel_shuffle(image_hidden_states, self.pixel_shuffle_factor)
  128. return self.modality_projection(image_hidden_states)
  129. @auto_docstring
  130. class ModernVBertPreTrainedModel(PreTrainedModel):
  131. config: ModernVBertConfig
  132. base_model_prefix = "model"
  133. input_modalities = ("image", "text")
  134. supports_gradient_checkpointing = True
  135. _no_split_modules = []
  136. _skip_keys_device_placement = "past_key_values"
  137. _supports_flash_attn = True
  138. _supports_sdpa = True
  139. _supports_flex_attn = True
  140. _supports_attention_backend = True
  141. config_class = ModernVBertConfig
  142. @torch.no_grad()
  143. def _init_weights(self, module):
  144. super()._init_weights(module)
  145. def init_weight(module: nn.Module, std: float):
  146. cutoff_factor = getattr(self.config, "initializer_cutoff_factor", 2.0)
  147. init.trunc_normal_(
  148. module.weight,
  149. mean=0.0,
  150. std=std,
  151. a=-cutoff_factor * std,
  152. b=cutoff_factor * std,
  153. )
  154. if isinstance(module, (nn.Linear, nn.Conv2d)):
  155. if module.bias is not None:
  156. init.zeros_(module.bias)
  157. if isinstance(module, ModernVBertConnector):
  158. out_std = self.config.initializer_range / math.sqrt(2.0 * self.config.text_config.num_hidden_layers)
  159. init_weight(module.modality_projection, out_std)
  160. elif isinstance(module, ModernVBertForMaskedLM):
  161. out_std = self.config.initializer_range / math.sqrt(2.0 * self.config.text_config.num_hidden_layers)
  162. init_weight(module.lm_head, out_std)
  163. elif isinstance(
  164. module,
  165. (
  166. ModernVBertForSequenceClassification,
  167. ModernVBertForTokenClassification,
  168. ),
  169. ):
  170. final_out_std = self.config.initializer_range / math.sqrt(self.config.text_config.hidden_size)
  171. init_weight(module.classifier, final_out_std)
  172. @auto_docstring(
  173. custom_intro="""
  174. ModernVBertModel is a model that combines a vision encoder (SigLIP) and a text encoder (ModernBert).
  175. ModernVBert is the base model of the visual retriver ColModernVBert, and was introduced in the following paper:
  176. [*ModernVBERT: Towards Smaller Visual Document Retrievers*](https://arxiv.org/abs/2510.01149).
  177. """
  178. )
  179. class ModernVBertModel(ModernVBertPreTrainedModel):
  180. """
  181. A subclass of Idefics3Model. We do *not* remove or block the call to inputs_merger
  182. in forward. Instead, we override inputs_merger here with custom logic.
  183. """
  184. def __init__(self, config: ModernVBertConfig):
  185. super().__init__(config)
  186. self.padding_idx = self.config.text_config.pad_token_id
  187. self.vocab_size = self.config.text_config.vocab_size
  188. self.vision_model = AutoModel.from_config(config.vision_config)
  189. # init components
  190. self.connector = ModernVBertConnector(config)
  191. self.text_model = AutoModel.from_config(config.text_config)
  192. self.image_seq_len = int(
  193. ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
  194. / (config.pixel_shuffle_factor**2)
  195. )
  196. self.image_token_id = self.config.image_token_id
  197. self.post_init()
  198. def get_input_embeddings(self):
  199. return self.text_model.get_input_embeddings()
  200. def set_input_embeddings(self, value):
  201. self.text_model.set_input_embeddings(value)
  202. def inputs_merger(
  203. self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_hidden_states: torch.Tensor
  204. ):
  205. """
  206. This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
  207. The merging happens as follows:
  208. - The text token sequence is: `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
  209. - We get the image hidden states for the image through the vision encoder and that hidden state, after a pixel shuffle operation, is then projected into the text embedding space.
  210. We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
  211. - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
  212. - To fit the format of that sequence, `input_ids`, `inputs_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
  213. """
  214. _, patch_size, _ = image_hidden_states.shape
  215. if input_ids is None:
  216. image_mask = inputs_embeds == self.get_input_embeddings()(
  217. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  218. )
  219. image_mask = image_mask[..., 0] # slice off the hidden dim
  220. else:
  221. image_mask = input_ids == self.config.image_token_id
  222. num_image_tokens = image_mask.sum(dim=1)
  223. torch_compilable_check(
  224. torch.all(num_image_tokens % patch_size == 0),
  225. "At least one sample has <image> tokens not divisible by patch_size.",
  226. )
  227. blocks_per_sample = num_image_tokens // patch_size
  228. offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0)
  229. block_offset = offsets[:-1]
  230. row_cum = image_mask.cumsum(dim=-1)
  231. chunk_idx = (row_cum - 1) // patch_size
  232. local_idx = (row_cum - 1) % patch_size
  233. block_idx = block_offset.unsqueeze(1) + chunk_idx
  234. image_embeds = torch.zeros_like(inputs_embeds)
  235. image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :]
  236. merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
  237. return merged_embeds
  238. @can_return_tuple
  239. @auto_docstring(
  240. custom_intro="Encodes images into continuous embeddings that can be forwarded to the language model."
  241. )
  242. def get_image_features(
  243. self,
  244. pixel_values: torch.FloatTensor,
  245. pixel_attention_mask: torch.LongTensor | None = None,
  246. **kwargs: Unpack[TransformersKwargs],
  247. ) -> tuple | BaseModelOutputWithPooling:
  248. r"""
  249. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  250. The tensors corresponding to the input images.
  251. pixel_attention_mask (`torch.LongTensor`, *optional*):
  252. The attention mask indicating padded regions in the image.
  253. """
  254. batch_size, num_images, num_channels, height, width = pixel_values.shape
  255. pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
  256. pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
  257. # Remove padding images - padding images are full 0.
  258. nb_values_per_image = pixel_values.shape[1:].numel()
  259. real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
  260. # If no images, leave one empty image.
  261. real_images_inds[0] |= ~torch.any(real_images_inds)
  262. pixel_values = pixel_values[real_images_inds].contiguous()
  263. # Handle the vision attention mask
  264. if pixel_attention_mask is None:
  265. pixel_attention_mask = torch.ones(
  266. size=[pixel_values.shape[i] for i in (0, 2, 3)],
  267. dtype=torch.bool,
  268. device=pixel_values.device,
  269. )
  270. else:
  271. # Remove padding images from the mask
  272. pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
  273. pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
  274. patch_size = self.config.vision_config.patch_size
  275. patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
  276. patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
  277. patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
  278. # Get sequence from the vision encoder
  279. image_outputs = self.vision_model(
  280. pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, return_dict=True, **kwargs
  281. )
  282. image_hidden_states = image_outputs.last_hidden_state
  283. # Modality projection & resampling
  284. image_features = self.connector(image_hidden_states)
  285. image_outputs.pooler_output = image_features
  286. return image_outputs
  287. @can_return_tuple
  288. @auto_docstring(
  289. custom_intro="""
  290. Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
  291. the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
  292. max_num_images is the maximum number of images among the batch_size samples in the batch.
  293. Padding images are not needed beyond padding the pixel_values at the entrance of the model.
  294. For efficiency, we only pass through the vision_model's forward the real images by
  295. discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
  296. image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
  297. """,
  298. checkpoint="ModernVBERT/modernvbert",
  299. )
  300. def forward(
  301. self,
  302. input_ids: torch.LongTensor = None,
  303. attention_mask: torch.Tensor | None = None,
  304. position_ids: torch.LongTensor | None = None,
  305. inputs_embeds: torch.FloatTensor | None = None,
  306. pixel_values: torch.FloatTensor | None = None,
  307. pixel_attention_mask: torch.BoolTensor | None = None,
  308. image_hidden_states: torch.FloatTensor | None = None,
  309. **kwargs: Unpack[TransformersKwargs],
  310. ) -> tuple | ModernVBertBaseModelOutput:
  311. r"""
  312. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  313. Mask to avoid performing attention on padding pixel indices.
  314. image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  315. The hidden states of the image encoder after modality projection.
  316. """
  317. if inputs_embeds is None:
  318. inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)
  319. # Images processing
  320. if pixel_values is not None:
  321. image_hidden_states = self.get_image_features(
  322. pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask
  323. ).pooler_output
  324. # Merge image and text embeddings
  325. if image_hidden_states is not None:
  326. image_hidden_states = image_hidden_states.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
  327. inputs_embeds = self.inputs_merger(
  328. input_ids=input_ids, inputs_embeds=inputs_embeds, image_hidden_states=image_hidden_states
  329. )
  330. # Language model pass
  331. outputs = self.text_model(
  332. inputs_embeds=inputs_embeds,
  333. attention_mask=attention_mask,
  334. position_ids=position_ids,
  335. **kwargs,
  336. )
  337. return ModernVBertBaseModelOutput(
  338. last_hidden_state=outputs.last_hidden_state,
  339. hidden_states=outputs.hidden_states,
  340. attentions=outputs.attentions,
  341. image_hidden_states=image_hidden_states,
  342. )
  343. class ModernVBertPredictionHead(nn.Module):
  344. def __init__(self, config: ModernVBertConfig):
  345. super().__init__()
  346. self.config = config
  347. self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
  348. self.act = ACT2FN[config.classifier_activation]
  349. self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  350. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  351. return self.norm(self.act(self.dense(hidden_states)))
  352. @auto_docstring
  353. class ModernVBertForMaskedLM(ModernVBertPreTrainedModel):
  354. _tied_weights_keys = {"lm_head.weight": "model.text_model.embeddings.tok_embeddings.weight"}
  355. def __init__(self, config: ModernVBertConfig):
  356. super().__init__(config)
  357. self.vocab_size = config.text_config.vocab_size
  358. self.model = ModernVBertModel(config)
  359. self.projection_head = ModernVBertPredictionHead(config.text_config)
  360. self.lm_head = nn.Linear(config.text_config.hidden_size, self.vocab_size, bias=config.text_config.decoder_bias)
  361. # Initialize weights and apply final processing
  362. self.post_init()
  363. def get_output_embeddings(self):
  364. return self.lm_head
  365. def set_output_embeddings(self, new_embeddings):
  366. self.lm_head = new_embeddings
  367. @can_return_tuple
  368. @auto_docstring(
  369. custom_intro="""
  370. Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
  371. the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
  372. max_num_images is the maximum number of images among the batch_size samples in the batch.
  373. Padding images are not needed beyond padding the pixel_values at the entrance of the model.
  374. For efficiency, we only pass through the vision_model's forward the real images by
  375. discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
  376. image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
  377. """,
  378. checkpoint="ModernVBERT/modernvbert",
  379. )
  380. def forward(
  381. self,
  382. input_ids: torch.LongTensor = None,
  383. attention_mask: torch.Tensor | None = None,
  384. position_ids: torch.LongTensor | None = None,
  385. inputs_embeds: torch.FloatTensor | None = None,
  386. pixel_values: torch.FloatTensor | None = None,
  387. pixel_attention_mask: torch.BoolTensor | None = None,
  388. image_hidden_states: torch.FloatTensor | None = None,
  389. labels: torch.LongTensor | None = None,
  390. **kwargs: Unpack[TransformersKwargs],
  391. ) -> tuple | ModernVBertMaskedLMOutput:
  392. r"""
  393. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  394. Mask to avoid performing attention on padding pixel indices.
  395. image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  396. The hidden states of the image encoder after modality projection.
  397. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  398. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  399. text_config.]` or `model.image_token_id`. Tokens with indices set to `model.image_token_id` are
  400. ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., text_config.]`.
  401. """
  402. outputs = self.model(
  403. input_ids=input_ids,
  404. attention_mask=attention_mask,
  405. position_ids=position_ids,
  406. inputs_embeds=inputs_embeds,
  407. pixel_values=pixel_values,
  408. pixel_attention_mask=pixel_attention_mask,
  409. image_hidden_states=image_hidden_states,
  410. **kwargs,
  411. )
  412. hidden_states = outputs[0]
  413. logits = self.lm_head(self.projection_head(hidden_states))
  414. loss = None
  415. if labels is not None:
  416. criterion = CrossEntropyLoss()
  417. loss = criterion(logits.view(-1, self.vocab_size), labels.view(-1))
  418. return ModernVBertMaskedLMOutput(
  419. loss=loss,
  420. logits=logits,
  421. hidden_states=outputs.hidden_states,
  422. attentions=outputs.attentions,
  423. image_hidden_states=outputs.image_hidden_states,
  424. )
  425. @auto_docstring(
  426. custom_intro="""
  427. The ModernVBert Model with a sequence classification head on top that performs pooling.
  428. """
  429. )
  430. class ModernVBertForSequenceClassification(ModernVBertPreTrainedModel):
  431. def __init__(self, config: ModernVBertConfig):
  432. super().__init__(config)
  433. self.num_labels = config.num_labels
  434. self.config = config
  435. self.model = ModernVBertModel(config)
  436. self.head = ModernVBertPredictionHead(config.text_config)
  437. self.drop = nn.Dropout(config.classifier_dropout)
  438. self.classifier = nn.Linear(config.text_config.hidden_size, config.num_labels)
  439. # Initialize weights and apply final processing
  440. self.post_init()
  441. @can_return_tuple
  442. @auto_docstring(
  443. custom_intro="""
  444. Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
  445. the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
  446. max_num_images is the maximum number of images among the batch_size samples in the batch.
  447. Padding images are not needed beyond padding the pixel_values at the entrance of the model.
  448. For efficiency, we only pass through the vision_model's forward the real images by
  449. discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
  450. image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
  451. """,
  452. checkpoint="ModernVBERT/modernvbert",
  453. )
  454. def forward(
  455. self,
  456. input_ids: torch.LongTensor = None,
  457. attention_mask: torch.Tensor | None = None,
  458. position_ids: torch.LongTensor | None = None,
  459. inputs_embeds: torch.FloatTensor | None = None,
  460. pixel_values: torch.FloatTensor | None = None,
  461. pixel_attention_mask: torch.BoolTensor | None = None,
  462. image_hidden_states: torch.FloatTensor | None = None,
  463. labels: torch.LongTensor | None = None,
  464. **kwargs: Unpack[TransformersKwargs],
  465. ) -> tuple | SequenceClassifierOutput:
  466. r"""
  467. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  468. Mask to avoid performing attention on padding pixel indices.
  469. image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  470. The hidden states of the image encoder after modality projection.
  471. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  472. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  473. text_config.]` or `model.image_token_id`. Tokens with indices set to `model.image_token_id` are
  474. ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., text_config.]`.
  475. """
  476. outputs = self.model(
  477. input_ids=input_ids,
  478. attention_mask=attention_mask,
  479. position_ids=position_ids,
  480. inputs_embeds=inputs_embeds,
  481. pixel_values=pixel_values,
  482. pixel_attention_mask=pixel_attention_mask,
  483. image_hidden_states=image_hidden_states,
  484. **kwargs,
  485. )
  486. last_hidden_state = outputs[0]
  487. if self.config.classifier_pooling == "cls":
  488. last_hidden_state = last_hidden_state[:, 0]
  489. elif self.config.classifier_pooling == "mean":
  490. if inputs_embeds is not None:
  491. batch_size, seq_len = inputs_embeds.shape[:2]
  492. else:
  493. batch_size, seq_len = input_ids.shape[:2]
  494. device = input_ids.device if input_ids is not None else inputs_embeds.device
  495. if attention_mask is None:
  496. attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
  497. last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
  498. dim=1, keepdim=True
  499. )
  500. pooled_output = self.head(last_hidden_state)
  501. pooled_output = self.drop(pooled_output)
  502. logits = self.classifier(pooled_output)
  503. loss = None
  504. if labels is not None:
  505. if self.config.problem_type is None:
  506. if self.num_labels == 1:
  507. self.config.problem_type = "regression"
  508. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  509. self.config.problem_type = "single_label_classification"
  510. else:
  511. self.config.problem_type = "multi_label_classification"
  512. if self.config.problem_type == "regression":
  513. loss_fct = MSELoss()
  514. if self.num_labels == 1:
  515. loss = loss_fct(logits.squeeze(), labels.squeeze())
  516. else:
  517. loss = loss_fct(logits, labels)
  518. elif self.config.problem_type == "single_label_classification":
  519. loss_fct = CrossEntropyLoss()
  520. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  521. elif self.config.problem_type == "multi_label_classification":
  522. loss_fct = BCEWithLogitsLoss()
  523. loss = loss_fct(logits, labels)
  524. return SequenceClassifierOutput(
  525. loss=loss,
  526. logits=logits,
  527. hidden_states=outputs.hidden_states,
  528. attentions=outputs.attentions,
  529. )
  530. @auto_docstring(
  531. custom_intro="""
  532. The ModernVBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.
  533. """
  534. )
  535. class ModernVBertForTokenClassification(ModernVBertPreTrainedModel):
  536. def __init__(self, config: ModernVBertConfig):
  537. super().__init__(config)
  538. self.num_labels = config.num_labels
  539. self.model = ModernVBertModel(config)
  540. self.head = ModernVBertPredictionHead(config.text_config)
  541. self.drop = nn.Dropout(config.classifier_dropout)
  542. self.classifier = nn.Linear(config.text_config.hidden_size, config.num_labels)
  543. # Initialize weights and apply final processing
  544. self.post_init()
  545. @can_return_tuple
  546. @auto_docstring(
  547. custom_intro="""
  548. Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
  549. the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
  550. max_num_images is the maximum number of images among the batch_size samples in the batch.
  551. Padding images are not needed beyond padding the pixel_values at the entrance of the model.
  552. For efficiency, we only pass through the vision_model's forward the real images by
  553. discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
  554. image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
  555. """,
  556. checkpoint="ModernVBERT/modernvbert",
  557. )
  558. def forward(
  559. self,
  560. input_ids: torch.LongTensor = None,
  561. attention_mask: torch.Tensor | None = None,
  562. position_ids: torch.LongTensor | None = None,
  563. inputs_embeds: torch.FloatTensor | None = None,
  564. pixel_values: torch.FloatTensor | None = None,
  565. pixel_attention_mask: torch.BoolTensor | None = None,
  566. image_hidden_states: torch.FloatTensor | None = None,
  567. labels: torch.LongTensor | None = None,
  568. **kwargs: Unpack[TransformersKwargs],
  569. ) -> tuple | TokenClassifierOutput:
  570. r"""
  571. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  572. Mask to avoid performing attention on padding pixel indices.
  573. image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  574. The hidden states of the image encoder after modality projection.
  575. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  576. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  577. text_config.]` or `model.image_token_id`. Tokens with indices set to `model.image_token_id` are
  578. ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., text_config.]`.
  579. """
  580. outputs = self.model(
  581. input_ids=input_ids,
  582. attention_mask=attention_mask,
  583. position_ids=position_ids,
  584. inputs_embeds=inputs_embeds,
  585. pixel_values=pixel_values,
  586. pixel_attention_mask=pixel_attention_mask,
  587. image_hidden_states=image_hidden_states,
  588. **kwargs,
  589. )
  590. last_hidden_state = outputs[0]
  591. last_hidden_state = self.head(last_hidden_state)
  592. last_hidden_state = self.drop(last_hidden_state)
  593. logits = self.classifier(last_hidden_state)
  594. loss = None
  595. if labels is not None:
  596. loss_fct = CrossEntropyLoss()
  597. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  598. return TokenClassifierOutput(
  599. loss=loss,
  600. logits=logits,
  601. hidden_states=outputs.hidden_states,
  602. attentions=outputs.attentions,
  603. )
# Public API of this module. The output dataclasses, the connector and the prediction head are deliberately
# not listed here.
__all__ = [
    "ModernVBertPreTrainedModel",
    "ModernVBertModel",
    "ModernVBertForMaskedLM",
    "ModernVBertForSequenceClassification",
    "ModernVBertForTokenClassification",
]