modular_modernvbert.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626
  1. # Copyright 2026 Illuin Technology and contributors, and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from dataclasses import dataclass
  16. from typing import Literal
  17. import torch
  18. import torch.nn as nn
  19. from huggingface_hub.dataclasses import strict
  20. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  21. from ... import initialization as init
  22. from ...configuration_utils import PreTrainedConfig
  23. from ...modeling_outputs import (
  24. BaseModelOutput,
  25. MaskedLMOutput,
  26. SequenceClassifierOutput,
  27. TokenClassifierOutput,
  28. )
  29. from ...modeling_utils import PreTrainedModel
  30. from ...processing_utils import Unpack
  31. from ...utils import TransformersKwargs, auto_docstring, logging
  32. from ...utils.generic import can_return_tuple
  33. from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel
  34. from ..modernbert.modeling_modernbert import ModernBertPredictionHead
  35. from ..smolvlm.modeling_smolvlm import SmolVLMModel, SmolVLMPreTrainedModel
  # Module-level logger for this file, created through the transformers `logging` utilities imported above.
  36. logger = logging.get_logger(__name__)
  @auto_docstring(checkpoint="ModernVBERT/modernvbert")
  @strict
  class ModernVBertConfig(PreTrainedConfig):
      r"""
      pixel_shuffle_factor (`int`, *optional*, defaults to 4):
          Side length of the square group of neighbouring vision patch embeddings that the connector merges into a
          single image token. The number of image tokens is divided by `pixel_shuffle_factor**2`.
      initializer_cutoff_factor (`float`, *optional*, defaults to 2.0):
          The cutoff factor for the truncated_normal_initializer for initializing all weight matrices.
      classifier_pooling (`Literal["cls", "mean"]`, *optional*, defaults to `"cls"`):
          The pooling strategy to use for sequence classification: `"cls"` takes the hidden state of the first
          token, `"mean"` averages the hidden states of the tokens selected by the attention mask.
      classifier_bias (`bool`, *optional*, defaults to `False`):
          Whether to add a bias term to the classification head.

      Example:

      ```python
      >>> from transformers import ModernVBertConfig, ModernVBertModel

      >>> # Initializing a ModernVBERT configuration
      >>> configuration = ModernVBertConfig()

      >>> # Initializing a model (with random weights) from the configuration
      >>> model = ModernVBertModel(configuration)

      >>> # Accessing the model configuration
      >>> configuration = model.config
      ```"""

      model_type = "modernvbert"
      sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}

      text_config: PreTrainedConfig | dict | None = None
      vision_config: PreTrainedConfig | dict | None = None
      image_token_id: int = 50407
      pixel_shuffle_factor: int = 4
      initializer_range: float = 0.02
      initializer_cutoff_factor: float = 2.0
      classifier_pooling: Literal["cls", "mean"] = "cls"
      classifier_dropout: float | int = 0.0
      # NOTE(review): not read by any head defined in this file (their `classifier` layers use nn.Linear's default
      # bias) - confirm whether it is meant to be wired in.
      classifier_bias: bool = False
      tie_word_embeddings: bool = False

      def __post_init__(self, **kwargs):
          # Each sub-config may be None (build the default backbone config), a dict (e.g. read back from a
          # serialized config) or an already-built config object. A dict is dispatched on its own `model_type`
          # when it carries one, falling back to the default backbone type, rather than being forced into a
          # single config class.
          if self.text_config is None:
              self.text_config = CONFIG_MAPPING["modernbert"]()
          elif isinstance(self.text_config, dict):
              text_model_type = self.text_config.get("model_type", "modernbert")
              self.text_config = CONFIG_MAPPING[text_model_type](**self.text_config)

          if self.vision_config is None:
              self.vision_config = CONFIG_MAPPING["siglip_vision_model"]()
          elif isinstance(self.vision_config, dict):
              vision_model_type = self.vision_config.get("model_type", "siglip_vision_model")
              self.vision_config = CONFIG_MAPPING[vision_model_type](**self.vision_config)

          super().__post_init__(**kwargs)
  @dataclass
  class ModernVBertBaseModelOutput(BaseModelOutput):
      """
      Base class for ModernVBERT model's outputs.

      Args:
          last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
              Sequence of hidden-states at the output of the last layer of the model.
          hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
              Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
              one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

              Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
          attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
              Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
              sequence_length)`.

              Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
              heads.
          image_hidden_states (`torch.FloatTensor`, *optional*):
              Image features produced by the vision encoder followed by the connector (modality projection), i.e. the
              features merged into the text input embeddings. `None` when the model received neither `pixel_values`
              nor `image_hidden_states`.
      """

      last_hidden_state: torch.FloatTensor | None = None
      hidden_states: tuple[torch.FloatTensor, ...] | None = None
      attentions: tuple[torch.FloatTensor, ...] | None = None
      # NOTE(review): presumably of shape (num_images, image_seq_len, hidden_size) - confirm against get_image_features.
      image_hidden_states: torch.FloatTensor | None = None
  @dataclass
  class ModernVBertMaskedLMOutput(MaskedLMOutput):
      """
      Base class for ModernVBERT model's outputs with masked language modeling loss.

      Args:
          loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
              Masked language modeling (MLM) loss.
          logits (`torch.FloatTensor`):
              Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
          hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
              Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
              one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

              Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
          attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
              Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
              sequence_length)`.

              Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
              heads.
          image_hidden_states (`torch.FloatTensor`, *optional*):
              Image features produced by the vision encoder followed by the connector (modality projection), i.e. the
              features merged into the text input embeddings.
      """

      loss: torch.FloatTensor | None = None
      logits: torch.FloatTensor | None = None
      hidden_states: tuple[torch.FloatTensor, ...] | None = None
      attentions: tuple[torch.FloatTensor, ...] | None = None
      # NOTE(review): presumably of shape (num_images, image_seq_len, hidden_size) - confirm against get_image_features.
      image_hidden_states: torch.FloatTensor | None = None
  class ModernVBertConnector(nn.Module):
      """
      Connector between the vision encoder and the text model of ModernVBERT.

      It merges every `pixel_shuffle_factor x pixel_shuffle_factor` neighbourhood of vision patch embeddings into a
      single token (a space-to-depth rearrangement, i.e. the inverse of
      https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html). This divides the number of image tokens
      by `pixel_shuffle_factor**2`. The merged tokens are then linearly projected to the text model's hidden size.
      """

      def __init__(self, config):
          super().__init__()
          self.pixel_shuffle_factor = config.pixel_shuffle_factor
          # Each merged token concatenates pixel_shuffle_factor**2 vision embeddings along the channel axis.
          self.modality_projection = nn.Linear(
              config.vision_config.hidden_size * (config.pixel_shuffle_factor**2),
              config.text_config.hidden_size,
              bias=False,
          )

      def pixel_shuffle(self, image_hidden_states, pixel_shuffle_factor):
          """
          Rearrange patch embeddings laid out on a square grid into coarser, wider tokens.

          Args:
              image_hidden_states: tensor of shape `(batch_size, seq_length, embed_dim)` whose `seq_length` patches
                  form a square `sqrt(seq_length) x sqrt(seq_length)` grid in row-major order.
              pixel_shuffle_factor: side of the square neighbourhood merged into one token.

          Returns:
              Tensor of shape `(batch_size, seq_length // pixel_shuffle_factor**2, embed_dim * pixel_shuffle_factor**2)`.

          Raises:
              ValueError: if `seq_length` is not a perfect square, or the grid side is not divisible by
                  `pixel_shuffle_factor`.
          """
          batch_size, seq_length, embed_dim = image_hidden_states.size()
          height = width = int(seq_length**0.5)
          # Validate up front: otherwise the truncating divisions below produce a cryptic `view` shape error.
          if height * width != seq_length or height % pixel_shuffle_factor != 0:
              raise ValueError(
                  "pixel_shuffle expects a square grid of patches whose side is divisible by "
                  f"pixel_shuffle_factor={pixel_shuffle_factor}, got seq_length={seq_length}."
              )
          reduced_height = height // pixel_shuffle_factor
          reduced_width = width // pixel_shuffle_factor
          merged_dim = embed_dim * (pixel_shuffle_factor**2)

          image_hidden_states = image_hidden_states.view(batch_size, height, width, embed_dim)
          # Fold `pixel_shuffle_factor` horizontally adjacent patches into the channel dimension.
          image_hidden_states = image_hidden_states.view(
              batch_size, height, reduced_width, embed_dim * pixel_shuffle_factor
          )
          # Swap the grid axes so vertically adjacent rows become contiguous, then fold them into channels too.
          image_hidden_states = image_hidden_states.permute(0, 2, 1, 3)
          image_hidden_states = image_hidden_states.reshape(batch_size, reduced_width, reduced_height, merged_dim)
          # Restore row-major order of the reduced grid and flatten it back into a token sequence.
          image_hidden_states = image_hidden_states.permute(0, 2, 1, 3)
          return image_hidden_states.reshape(batch_size, reduced_height * reduced_width, merged_dim)

      def forward(self, image_hidden_states):
          """Merge patch embeddings with `pixel_shuffle` and project them to the text model's hidden size."""
          image_hidden_states = self.pixel_shuffle(image_hidden_states, self.pixel_shuffle_factor)
          return self.modality_projection(image_hidden_states)
  @auto_docstring
  class ModernVBertPreTrainedModel(SmolVLMPreTrainedModel):
      config_class = ModernVBertConfig
      _no_split_modules = []

      @torch.no_grad()
      def _init_weights(self, module):
          """
          Initialize `module` in place.

          The generic `PreTrainedModel` initialization runs first. The connector projection, the MLM `lm_head` and
          the classification `classifier` layers are then re-drawn from a truncated normal distribution. For the
          connector and `lm_head`, the std is scaled down with the text model depth. For the classifiers, it is
          scaled down with the text hidden size.
          """
          PreTrainedModel._init_weights(self, module)

          config = self.config

          def _truncated_init(layer: nn.Module, std: float):
              # Truncated normal clipped at +/- cutoff * std; a Linear/Conv2d bias (when present) is zeroed.
              cutoff = getattr(config, "initializer_cutoff_factor", 2.0)
              init.trunc_normal_(layer.weight, mean=0.0, std=std, a=-cutoff * std, b=cutoff * std)
              if isinstance(layer, (nn.Linear, nn.Conv2d)) and layer.bias is not None:
                  init.zeros_(layer.bias)

          if isinstance(module, (ModernVBertConnector, ModernVBertForMaskedLM)):
              depth_scaled_std = config.initializer_range / math.sqrt(2.0 * config.text_config.num_hidden_layers)
              target = module.modality_projection if isinstance(module, ModernVBertConnector) else module.lm_head
              _truncated_init(target, depth_scaled_std)
          elif isinstance(module, (ModernVBertForSequenceClassification, ModernVBertForTokenClassification)):
              width_scaled_std = config.initializer_range / math.sqrt(config.text_config.hidden_size)
              _truncated_init(module.classifier, width_scaled_std)
  @auto_docstring(
      custom_intro="""
      ModernVBertModel is a model that combines a vision encoder (SigLIP) and a text encoder (ModernBert).
      ModernVBert is the base model of the visual retriever ColModernVBert, and was introduced in the following paper:
      [*ModernVBERT: Towards Smaller Visual Document Retrievers*](https://arxiv.org/abs/2510.01149).
      """
  )
  class ModernVBertModel(SmolVLMModel):
      def __init__(self, config: ModernVBertConfig):
          super().__init__(config)

          # ModernVBERT components: pixel-shuffle connector plus text / vision backbones built from the sub-configs.
          self.connector = ModernVBertConnector(config)
          self.text_model = AutoModel.from_config(config.text_config)
          self.vision_model = AutoModel.from_config(config.vision_config)

          # Tokens per image after the connector: (image_size // patch_size) ** 2 patches, merged in groups of
          # pixel_shuffle_factor ** 2.
          self.image_seq_len = int(
              ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
              / (config.pixel_shuffle_factor**2)
          )

          # Initialize weights and apply final processing
          self.post_init()

      @can_return_tuple
      @auto_docstring(
          custom_intro="""
          Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
          the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
          max_num_images is the maximum number of images among the batch_size samples in the batch.
          Padding images are not needed beyond padding the pixel_values at the entrance of the model.
          For efficiency, we only pass through the vision_model's forward the real images by
          discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
          image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
          """,
          checkpoint="ModernVBERT/modernvbert",
      )
      def forward(
          self,
          input_ids: torch.LongTensor = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          pixel_values: torch.FloatTensor | None = None,
          pixel_attention_mask: torch.BoolTensor | None = None,
          image_hidden_states: torch.FloatTensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | ModernVBertBaseModelOutput:
          r"""
          pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
              Mask to avoid performing attention on padding pixel indices.
          image_hidden_states (`torch.FloatTensor`, *optional*):
              Precomputed image features (vision encoder output after the connector's modality projection) to merge
              into the text embeddings. Ignored when `pixel_values` is given, since the features are then recomputed.
          """
          # Fail early with a clear message instead of calling the embedding layer on `None`.
          if input_ids is None and inputs_embeds is None:
              raise ValueError("You have to specify either input_ids or inputs_embeds")

          if inputs_embeds is None:
              inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)

          # Image features: recomputed from pixel_values when given, otherwise taken from image_hidden_states as-is.
          if pixel_values is not None:
              image_hidden_states = self.get_image_features(
                  pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask
              ).pooler_output

          # Merge image features into the text embeddings.
          if image_hidden_states is not None:
              image_hidden_states = image_hidden_states.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
              inputs_embeds = self.inputs_merger(
                  input_ids=input_ids, inputs_embeds=inputs_embeds, image_hidden_states=image_hidden_states
              )

          # Language model pass over the (possibly merged) embeddings.
          outputs = self.text_model(
              inputs_embeds=inputs_embeds,
              attention_mask=attention_mask,
              position_ids=position_ids,
              **kwargs,
          )

          return ModernVBertBaseModelOutput(
              last_hidden_state=outputs.last_hidden_state,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
              image_hidden_states=image_hidden_states,
          )
  class ModernVBertPredictionHead(ModernBertPredictionHead):
      """ModernBERT prediction head reused unchanged; the task heads in this file build it from `config.text_config`."""
  @auto_docstring
  class ModernVBertForMaskedLM(ModernVBertPreTrainedModel):
      # Maps the LM head weight to the text token-embedding weight it can be tied to.
      _tied_weights_keys = {"lm_head.weight": "model.text_model.embeddings.tok_embeddings.weight"}

      def __init__(self, config: ModernVBertConfig):
          super().__init__(config)
          self.vocab_size = config.text_config.vocab_size

          self.model = ModernVBertModel(config)
          # Hidden states -> ModernBERT prediction head -> vocabulary logits.
          self.projection_head = ModernVBertPredictionHead(config.text_config)
          self.lm_head = nn.Linear(config.text_config.hidden_size, self.vocab_size, bias=config.text_config.decoder_bias)

          # Initialize weights and apply final processing
          self.post_init()

      def get_output_embeddings(self):
          return self.lm_head

      def set_output_embeddings(self, new_embeddings):
          self.lm_head = new_embeddings

      @can_return_tuple
      @auto_docstring(
          custom_intro="""
          Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
          the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
          max_num_images is the maximum number of images among the batch_size samples in the batch.
          Padding images are not needed beyond padding the pixel_values at the entrance of the model.
          For efficiency, we only pass through the vision_model's forward the real images by
          discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
          image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
          """,
          checkpoint="ModernVBERT/modernvbert",
      )
      def forward(
          self,
          input_ids: torch.LongTensor = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          pixel_values: torch.FloatTensor | None = None,
          pixel_attention_mask: torch.BoolTensor | None = None,
          image_hidden_states: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | ModernVBertMaskedLMOutput:
          r"""
          pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
              Mask to avoid performing attention on padding pixel indices.
          image_hidden_states (`torch.FloatTensor`, *optional*):
              Precomputed image features (vision encoder output after the connector's modality projection) to merge
              into the text embeddings. Ignored when `pixel_values` is given, since the features are then recomputed.
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should be in
              `[0, ..., config.text_config.vocab_size - 1]` or `-100`. Tokens with label `-100` are ignored (masked);
              the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size - 1]`.
          """
          outputs = self.model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              pixel_values=pixel_values,
              pixel_attention_mask=pixel_attention_mask,
              image_hidden_states=image_hidden_states,
              **kwargs,
          )
          hidden_states = outputs.last_hidden_state
          logits = self.lm_head(self.projection_head(hidden_states))

          loss = None
          if labels is not None:
              # Labels may live on another device than the logits (e.g. with model parallelism), and may be
              # non-contiguous, so move them and flatten with `reshape` rather than `view`.
              labels = labels.to(logits.device)
              loss_fct = CrossEntropyLoss()
              loss = loss_fct(logits.view(-1, self.vocab_size), labels.reshape(-1))

          return ModernVBertMaskedLMOutput(
              loss=loss,
              logits=logits,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
              image_hidden_states=outputs.image_hidden_states,
          )
  @auto_docstring(
      custom_intro="""
      The ModernVBert Model with a sequence classification head on top that performs pooling.
      """
  )
  class ModernVBertForSequenceClassification(ModernVBertPreTrainedModel):
      def __init__(self, config: ModernVBertConfig):
          super().__init__(config)
          self.num_labels = config.num_labels
          self.config = config

          self.model = ModernVBertModel(config)
          # Pooled representation -> ModernBERT prediction head -> dropout -> linear classifier.
          self.head = ModernVBertPredictionHead(config.text_config)
          self.drop = nn.Dropout(config.classifier_dropout)
          self.classifier = nn.Linear(config.text_config.hidden_size, config.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      @can_return_tuple
      @auto_docstring(
          custom_intro="""
          Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
          the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
          max_num_images is the maximum number of images among the batch_size samples in the batch.
          Padding images are not needed beyond padding the pixel_values at the entrance of the model.
          For efficiency, we only pass through the vision_model's forward the real images by
          discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
          image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
          """,
          checkpoint="ModernVBERT/modernvbert",
      )
      def forward(
          self,
          input_ids: torch.LongTensor = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          pixel_values: torch.FloatTensor | None = None,
          pixel_attention_mask: torch.BoolTensor | None = None,
          image_hidden_states: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | SequenceClassifierOutput:
          r"""
          pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
              Mask to avoid performing attention on padding pixel indices.
          image_hidden_states (`torch.FloatTensor`, *optional*):
              Precomputed image features (vision encoder output after the connector's modality projection) to merge
              into the text embeddings. Ignored when `pixel_values` is given, since the features are then recomputed.
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the sequence classification/regression loss. Indices should be in
              `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed
              (Mean-Square loss); if `config.num_labels > 1` and the labels are integers a classification loss is
              computed (Cross-Entropy); otherwise a multi-label loss is computed (BCE with logits). An explicitly set
              `config.problem_type` takes precedence.
          """
          outputs = self.model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              pixel_values=pixel_values,
              pixel_attention_mask=pixel_attention_mask,
              image_hidden_states=image_hidden_states,
              **kwargs,
          )
          last_hidden_state = outputs.last_hidden_state

          if self.config.classifier_pooling == "cls":
              last_hidden_state = last_hidden_state[:, 0]
          elif self.config.classifier_pooling == "mean":
              if attention_mask is None:
                  attention_mask = torch.ones(
                      last_hidden_state.shape[:2], device=last_hidden_state.device, dtype=torch.bool
                  )
              mask = attention_mask.to(last_hidden_state.device)
              # Count attended tokens per row; clamping keeps a row with no attended token at zeros instead of
              # producing 0 / 0 = NaN.
              token_counts = mask.sum(dim=1, keepdim=True).clamp(min=1)
              last_hidden_state = (last_hidden_state * mask.unsqueeze(-1)).sum(dim=1) / token_counts

          pooled_output = self.head(last_hidden_state)
          pooled_output = self.drop(pooled_output)
          logits = self.classifier(pooled_output)

          loss = None
          if labels is not None:
              # Labels may live on another device than the logits (e.g. with model parallelism).
              labels = labels.to(logits.device)
              # Infer the problem type once from num_labels and the label dtype, then cache it on the config.
              if self.config.problem_type is None:
                  if self.num_labels == 1:
                      self.config.problem_type = "regression"
                  elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                      self.config.problem_type = "single_label_classification"
                  else:
                      self.config.problem_type = "multi_label_classification"

              if self.config.problem_type == "regression":
                  loss_fct = MSELoss()
                  if self.num_labels == 1:
                      loss = loss_fct(logits.squeeze(), labels.squeeze())
                  else:
                      loss = loss_fct(logits, labels)
              elif self.config.problem_type == "single_label_classification":
                  loss_fct = CrossEntropyLoss()
                  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
              elif self.config.problem_type == "multi_label_classification":
                  loss_fct = BCEWithLogitsLoss()
                  loss = loss_fct(logits, labels)

          return SequenceClassifierOutput(
              loss=loss,
              logits=logits,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  @auto_docstring(
      custom_intro="""
      The ModernVBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.
      """
  )
  class ModernVBertForTokenClassification(ModernVBertPreTrainedModel):
      def __init__(self, config: ModernVBertConfig):
          super().__init__(config)
          self.num_labels = config.num_labels

          self.model = ModernVBertModel(config)
          # Per-token hidden states -> ModernBERT prediction head -> dropout -> linear classifier.
          self.head = ModernVBertPredictionHead(config.text_config)
          self.drop = nn.Dropout(config.classifier_dropout)
          self.classifier = nn.Linear(config.text_config.hidden_size, config.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      @can_return_tuple
      @auto_docstring(
          custom_intro="""
          Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
          the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
          max_num_images is the maximum number of images among the batch_size samples in the batch.
          Padding images are not needed beyond padding the pixel_values at the entrance of the model.
          For efficiency, we only pass through the vision_model's forward the real images by
          discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
          image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
          """,
          checkpoint="ModernVBERT/modernvbert",
      )
      def forward(
          self,
          input_ids: torch.LongTensor = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          pixel_values: torch.FloatTensor | None = None,
          pixel_attention_mask: torch.BoolTensor | None = None,
          image_hidden_states: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | TokenClassifierOutput:
          r"""
          pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
              Mask to avoid performing attention on padding pixel indices.
          image_hidden_states (`torch.FloatTensor`, *optional*):
              Precomputed image features (vision encoder output after the connector's modality projection) to merge
              into the text embeddings. Ignored when `pixel_values` is given, since the features are then recomputed.
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the token classification loss. Indices should be in
              `[0, ..., config.num_labels - 1]`; tokens labelled `-100` are ignored by the loss.
          """
          outputs = self.model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              pixel_values=pixel_values,
              pixel_attention_mask=pixel_attention_mask,
              image_hidden_states=image_hidden_states,
              **kwargs,
          )
          last_hidden_state = outputs.last_hidden_state
          last_hidden_state = self.head(last_hidden_state)
          last_hidden_state = self.drop(last_hidden_state)
          logits = self.classifier(last_hidden_state)

          loss = None
          if labels is not None:
              # Labels may live on another device than the logits (e.g. with model parallelism), and may be
              # non-contiguous, so move them and flatten with `reshape` rather than `view`.
              labels = labels.to(logits.device)
              loss_fct = CrossEntropyLoss()
              loss = loss_fct(logits.view(-1, self.num_labels), labels.reshape(-1))

          return TokenClassifierOutput(
              loss=loss,
              logits=logits,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  # Names exported by `from <this module> import *`, i.e. the public API of this model file.
  # NOTE(review): transformers' lazy-import tooling is believed to read this literal list - keep its form; confirm.
  536. __all__ = [
  537. "ModernVBertConfig",
  538. "ModernVBertPreTrainedModel",
  539. "ModernVBertModel",
  540. "ModernVBertForMaskedLM",
  541. "ModernVBertForSequenceClassification",
  542. "ModernVBertForTokenClassification",
  543. ]