modeling_colmodernvbert.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/colmodernvbert/modular_colmodernvbert.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_colmodernvbert.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2026 Illuin Technology and contributors, and The HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from dataclasses import dataclass
  21. import torch
  22. from torch import nn
  23. from ... import initialization as init
  24. from ...modeling_utils import PreTrainedModel
  25. from ...processing_utils import Unpack
  26. from ...utils import ModelOutput, TransformersKwargs, auto_docstring
  27. from ...utils.generic import can_return_tuple
  28. from ..auto.modeling_auto import AutoModel
  29. from .configuration_colmodernvbert import ColModernVBertConfig
  30. @auto_docstring
  31. class ColModernVBertPreTrainedModel(PreTrainedModel):
  32. config: ColModernVBertConfig
  33. base_model_prefix = "model"
  34. input_modalities = ("image", "text")
  35. _no_split_modules = []
  36. _supports_sdpa = True
  37. _supports_flash_attn = True
  38. _supports_flex_attn = True
  39. @torch.no_grad()
  40. def _init_weights(self, module):
  41. std = (
  42. self.config.initializer_range
  43. if hasattr(self.config, "initializer_range")
  44. else self.config.vlm_config.text_config.initializer_range
  45. )
  46. if isinstance(module, (nn.Linear, nn.Conv2d)):
  47. init.normal_(module.weight, mean=0.0, std=std)
  48. if module.bias is not None:
  49. init.zeros_(module.bias)
  50. elif isinstance(module, nn.Embedding):
  51. init.normal_(module.weight, mean=0.0, std=std)
  52. # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
  53. if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
  54. init.zeros_(module.weight[module.padding_idx])
  55. @dataclass
  56. @auto_docstring(
  57. custom_intro="""
  58. Base class for ColModernVBert embeddings output.
  59. """
  60. )
  61. class ColModernVBertForRetrievalOutput(ModelOutput):
  62. r"""
  63. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  64. Language modeling loss (for next-token prediction).
  65. embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  66. The embeddings of the model.
  67. image_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True` and `pixel_values` are provided):
  68. Tuple of `torch.FloatTensor` (one for the output of the image modality projection + one for the output of each layer) of shape
  69. `(batch_size, num_channels, image_size, image_size)`.
  70. Hidden-states of the image encoder at the output of each layer plus the initial modality projection outputs.
  71. """
  72. loss: torch.FloatTensor | None = None
  73. embeddings: torch.Tensor | None = None
  74. hidden_states: tuple[torch.FloatTensor] | None = None
  75. image_hidden_states: tuple[torch.FloatTensor] | None = None
  76. attentions: tuple[torch.FloatTensor] | None = None
  77. @auto_docstring(
  78. custom_intro="""
  79. Following the ColPali approach, ColModernVBert leverages VLMs to construct efficient multi-vector embeddings directly
  80. from document images (“screenshots”) for document retrieval. The model is trained to maximize the similarity
  81. between these document embeddings and the corresponding query embeddings, using the late interaction method
  82. introduced in ColBERT.
  83. Using ColModernVBert removes the need for potentially complex and brittle layout recognition and OCR pipelines with
  84. a single model that can take into account both the textual and visual content (layout, charts, ...) of a document.
  85. ColModernVBert is trained on top of ModernVBert, and was introduced in the following paper:
  86. [*ModernVBERT: Towards Smaller Visual Document Retrievers*](https://arxiv.org/abs/2510.01149).
  87. ColModernVBert is part of the ColVision model family, which was introduced with ColPali in the following paper:
  88. [*ColPali: Efficient Document Retrieval with Vision Language Models*](https://huggingface.co/papers/2407.01449).
  89. """
  90. )
  91. class ColModernVBertForRetrieval(ColModernVBertPreTrainedModel):
  92. base_model_prefix = "vlm"
  93. def __init__(self, config: ColModernVBertConfig):
  94. super().__init__(config)
  95. self.config = config
  96. self.vocab_size = config.vlm_config.text_config.vocab_size
  97. self.vlm = AutoModel.from_config(config.vlm_config)
  98. self.embedding_dim = self.config.embedding_dim
  99. self.embedding_proj_layer = nn.Linear(
  100. self.config.vlm_config.text_config.hidden_size,
  101. self.embedding_dim,
  102. )
  103. self.post_init()
  104. @can_return_tuple
  105. @auto_docstring
  106. def forward(
  107. self,
  108. input_ids: torch.LongTensor | None = None,
  109. pixel_values: torch.FloatTensor | None = None,
  110. attention_mask: torch.Tensor | None = None,
  111. **kwargs: Unpack[TransformersKwargs],
  112. ) -> ColModernVBertForRetrievalOutput:
  113. vlm_output = self.vlm(
  114. input_ids=input_ids,
  115. attention_mask=attention_mask,
  116. pixel_values=pixel_values,
  117. **kwargs,
  118. )
  119. last_hidden_states = vlm_output[0] # (batch_size, sequence_length, hidden_size)
  120. proj_dtype = self.embedding_proj_layer.weight.dtype
  121. embeddings = self.embedding_proj_layer(last_hidden_states.to(proj_dtype)) # (batch_size, sequence_length, dim)
  122. # L2 normalization
  123. embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
  124. if attention_mask is not None:
  125. attention_mask = attention_mask.to(dtype=embeddings.dtype, device=embeddings.device)
  126. embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
  127. return ColModernVBertForRetrievalOutput(
  128. embeddings=embeddings,
  129. hidden_states=vlm_output.hidden_states,
  130. attentions=vlm_output.attentions,
  131. image_hidden_states=vlm_output.image_hidden_states,
  132. )
  133. __all__ = ["ColModernVBertForRetrieval", "ColModernVBertPreTrainedModel"]