modular_colmodernvbert.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. # Copyright 2026 Illuin Technology and contributors, and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from dataclasses import dataclass
  15. from typing import Optional, Union
  16. import torch
  17. from huggingface_hub.dataclasses import strict
  18. from ...configuration_utils import PreTrainedConfig
  19. from ...feature_extraction_utils import BatchFeature
  20. from ...image_utils import ImageInput, is_valid_image
  21. from ...processing_utils import Unpack
  22. from ...tokenization_utils_base import TextInput
  23. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
  24. from ...utils.generic import can_return_tuple
  25. from ...utils.import_utils import requires
  26. from ..auto import CONFIG_MAPPING
  27. from ..auto.modeling_auto import AutoModel
  28. from ..colpali.modeling_colpali import ColPaliForRetrieval, ColPaliPreTrainedModel
  29. from ..colqwen2.configuration_colqwen2 import ColQwen2Config
  30. from ..idefics3.processing_idefics3 import Idefics3Processor, Idefics3ProcessorKwargs
# Module-level logger, named after this module; used by `ColModernVBertConfig.__post_init__`.
logger = logging.get_logger(__name__)
  32. @auto_docstring(checkpoint="ModernVBERT/colmodernvbert-merged")
  33. @strict
  34. class ColModernVBertConfig(ColQwen2Config):
  35. r"""
  36. Example:
  37. ```python
  38. from transformers import ColModernVBertConfig, ColModernVBertForRetrieval
  39. config = ColModernVBertConfig()
  40. model = ColModernVBertForRetrieval(config)
  41. ```
  42. """
  43. model_type = "colmodernvbert"
  44. sub_configs = {"vlm_config": PreTrainedConfig}
  45. vlm_config: dict | PreTrainedConfig | None = None
  46. embedding_dim: int = 128
  47. initializer_range: float = 0.02
  48. def __post_init__(self, **kwargs):
  49. if self.vlm_config is None:
  50. self.vlm_config = CONFIG_MAPPING["modernvbert"]()
  51. logger.info(
  52. "`vlm_config` is `None`. Initializing `vlm_config` with the `ModernVBertConfig` with default values."
  53. )
  54. elif isinstance(self.vlm_config, dict):
  55. self.vlm_config = CONFIG_MAPPING[self.vlm_config["model_type"]](**self.vlm_config)
  56. if not hasattr(self.vlm_config, "vocab_size"):
  57. self.vlm_config.vocab_size = self.vlm_config.get_text_config().vocab_size
  58. super().__post_init__(**kwargs)
class ColModernVBertProcessorKwargs(Idefics3ProcessorKwargs, total=False):
    # Default keyword arguments, grouped by the component they are destined for. This class is handed to
    # `self._merge_kwargs(...)` in `ColModernVBertProcessor.process_images` / `process_queries`;
    # presumably call-time kwargs override these defaults there -- confirm in `_merge_kwargs`.
    _defaults = {
        # Tokenizer defaults.
        "text_kwargs": {
            "padding": "longest",
        },
        # Image-processor defaults.
        "images_kwargs": {
            "return_row_col_info": True,
            "data_format": "channels_first",
            "do_convert_rgb": True,
        },
        # Defaults shared by both: return PyTorch tensors.
        "common_kwargs": {"return_tensors": "pt"},
    }
  71. @requires(backends=("torch",))
  72. @auto_docstring
  73. class ColModernVBertProcessor(Idefics3Processor):
  74. r"""
  75. Constructs a ColModernVBert processor which wraps a ModernVBertProcessor and special methods to process images and queries, as
  76. well as to compute the late-interaction retrieval score.
  77. [`ColModernVBertProcessor`] offers all the functionalities of [`ModernVBertProcessor`]. See the [`~ModernVBertProcessor.__call__`]
  78. for more information.
  79. Args:
  80. image_processor ([`Idefics3ImageProcessor`]): An instance of [`Idefics3ImageProcessor`]. The image processor is a required input.
  81. tokenizer (`PreTrainedTokenizerFast`, *optional*): An instance of [`PreTrainedTokenizerFast`]. This should correspond with the model's text model. The tokenizer is a required input.
  82. image_seq_len (`int`, *optional*, defaults to 64): The length of the image sequence i.e. the number of <image> tokens per image in the input.
  83. visual_prompt_prefix (`Optional`, *optional*): A prefix to be prepended to visual prompts.
  84. query_prefix (`Optional`, *optional*): A prefix to be prepended to query prompts.
  85. """
  86. def __init__(
  87. self,
  88. image_processor,
  89. tokenizer=None,
  90. chat_template=None,
  91. image_seq_len: int = 64,
  92. visual_prompt_prefix: str | None = None,
  93. query_prefix: str | None = None,
  94. **kwargs,
  95. ):
  96. r"""
  97. image_seq_len (`int`, *optional*, defaults to 64):
  98. The length of the image sequence i.e. the number of <image> tokens per image in the input.
  99. visual_prompt_prefix (`str`, *optional*):
  100. A string that gets tokenized and prepended to the image tokens.
  101. query_prefix (`str`, *optional*):
  102. A prefix to be used for the query.
  103. """
  104. chat_template = None # ColModernVBert does not use chat templates
  105. super().__init__(
  106. image_processor,
  107. tokenizer,
  108. chat_template=chat_template,
  109. image_seq_len=image_seq_len,
  110. **kwargs,
  111. )
  112. self.visual_prompt_prefix = visual_prompt_prefix or (
  113. f"<|begin_of_text|>User:{self.image_token}Describe the image.<end_of_utterance>\nAssistant:"
  114. )
  115. self.query_prefix = query_prefix or ""
  116. self.query_augmentation_token = self.end_of_utterance_token
  117. def process_images(
  118. self,
  119. images: ImageInput | None = None,
  120. **kwargs: Unpack[ColModernVBertProcessorKwargs],
  121. ) -> BatchFeature:
  122. """
  123. Prepare for the model one or several image(s). Handles input validation, RGB conversion,
  124. and prepends the `visual_prompt_prefix` to each image. Optionally computes labels from
  125. `token_type_ids` when a `suffix` is provided in `text_kwargs`.
  126. Args:
  127. images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
  128. The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
  129. tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
  130. number of channels, H and W are image height and width.
  131. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  132. If set, will return tensors of a particular framework. Acceptable values are:
  133. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  134. - `'np'`: Return NumPy `np.ndarray` objects.
  135. Returns:
  136. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  137. - **input_ids** -- List of token ids to be fed to a model.
  138. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  139. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  140. `None`).
  141. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  142. """
  143. output_kwargs = self._merge_kwargs(
  144. ColModernVBertProcessorKwargs,
  145. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  146. **kwargs,
  147. )
  148. suffix = output_kwargs["text_kwargs"].pop("suffix", None)
  149. return_token_type_ids = suffix is not None
  150. # Normalize input to a flat list of images
  151. if is_valid_image(images):
  152. images = [images]
  153. elif isinstance(images, list) and is_valid_image(images[0]):
  154. pass
  155. elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])):
  156. raise ValueError("images must be an image, list of images or list of list of images")
  157. # Ensure all images are in RGB format
  158. images = [image.convert("RGB") for image in images]
  159. # Pair each image with the visual prompt prefix for the VLM backbone
  160. batch_doc = self.__call__(
  161. text=[self.visual_prompt_prefix] * len(images),
  162. images=images,
  163. images_kwargs=output_kwargs["images_kwargs"],
  164. text_kwargs=output_kwargs["text_kwargs"],
  165. )
  166. # When suffix is provided, generate labels by masking non-suffix tokens
  167. if return_token_type_ids:
  168. labels = batch_doc["input_ids"].masked_fill(batch_doc["token_type_ids"] == 0, -100)
  169. batch_doc.update({"labels": labels})
  170. return batch_doc
  171. def process_queries(
  172. self,
  173. text: TextInput | list[TextInput],
  174. **kwargs: Unpack[ColModernVBertProcessorKwargs],
  175. ) -> BatchFeature:
  176. """
  177. Prepare for the model one or several text queries. Handles input validation, prepends the
  178. `query_prefix`, and appends query augmentation tokens (used to pad query embeddings for
  179. better late-interaction retrieval performance).
  180. Args:
  181. text (`str`, `list[str]`, `list[list[str]]`):
  182. The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
  183. (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
  184. `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
  185. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  186. If set, will return tensors of a particular framework. Acceptable values are:
  187. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  188. - `'np'`: Return NumPy `np.ndarray` objects.
  189. Returns:
  190. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  191. - **input_ids** -- List of token ids to be fed to a model.
  192. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  193. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  194. `None`).
  195. """
  196. output_kwargs = self._merge_kwargs(
  197. ColModernVBertProcessorKwargs,
  198. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  199. **kwargs,
  200. )
  201. suffix = output_kwargs["text_kwargs"].pop("suffix", None)
  202. if isinstance(text, str):
  203. text = [text]
  204. elif not (isinstance(text, list) and isinstance(text[0], str)):
  205. raise ValueError("Text must be a string or a list of strings")
  206. # Default suffix: repeat the augmentation token to pad query embeddings
  207. if suffix is None:
  208. suffix = self.query_augmentation_token * 10
  209. # Build final queries: prefix + original query + augmentation suffix
  210. texts_query: list[str] = [self.query_prefix + query + suffix for query in text]
  211. batch_query = self.__call__(
  212. text=texts_query,
  213. return_token_type_ids=False,
  214. text_kwargs=output_kwargs["text_kwargs"],
  215. )
  216. return batch_query
  217. def score_retrieval(
  218. self,
  219. query_embeddings: Union["torch.Tensor", list["torch.Tensor"]],
  220. passage_embeddings: Union["torch.Tensor", list["torch.Tensor"]],
  221. batch_size: int = 128,
  222. output_dtype: Optional["torch.dtype"] = None,
  223. output_device: Union["torch.device", str] = "cpu",
  224. ) -> "torch.Tensor":
  225. """
  226. Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
  227. query embeddings (`qs`) and passage embeddings (`ps`). For ColQwen2, a passage is the
  228. image of a document page.
  229. Because the embedding tensors are multi-vector and can thus have different shapes, they
  230. should be fed as:
  231. (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
  232. (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
  233. obtained by padding the list of tensors.
  234. Args:
  235. query_embeddings (`Union[torch.Tensor, list[torch.Tensor]`): Query embeddings.
  236. passage_embeddings (`Union[torch.Tensor, list[torch.Tensor]`): Passage embeddings.
  237. batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
  238. output_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The dtype of the output tensor.
  239. If `None`, the dtype of the input embeddings is used.
  240. output_device (`torch.device` or `str`, *optional*, defaults to "cpu"): The device of the output tensor.
  241. Returns:
  242. `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
  243. tensor is saved on the "cpu" device.
  244. """
  245. if len(query_embeddings) == 0:
  246. raise ValueError("No queries provided")
  247. if len(passage_embeddings) == 0:
  248. raise ValueError("No passages provided")
  249. if query_embeddings[0].device != passage_embeddings[0].device:
  250. raise ValueError("Queries and passages must be on the same device")
  251. if query_embeddings[0].dtype != passage_embeddings[0].dtype:
  252. raise ValueError("Queries and passages must have the same dtype")
  253. if output_dtype is None:
  254. output_dtype = query_embeddings[0].dtype
  255. scores: list[torch.Tensor] = []
  256. for i in range(0, len(query_embeddings), batch_size):
  257. batch_scores: list[torch.Tensor] = []
  258. batch_queries = torch.nn.utils.rnn.pad_sequence(
  259. query_embeddings[i : i + batch_size], batch_first=True, padding_value=0
  260. )
  261. for j in range(0, len(passage_embeddings), batch_size):
  262. batch_passages = torch.nn.utils.rnn.pad_sequence(
  263. passage_embeddings[j : j + batch_size], batch_first=True, padding_value=0
  264. )
  265. batch_scores.append(
  266. torch.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2)
  267. )
  268. scores.append(torch.cat(batch_scores, dim=1).to(output_dtype).to(output_device))
  269. return torch.cat(scores, dim=0)
# Pretrained-model base for ColModernVBert. All behaviour is inherited from `ColPaliPreTrainedModel`;
# the only thing declared here is the config type annotation. A `#` comment is used rather than a
# docstring so that the text produced by `@auto_docstring` is left unchanged.
@auto_docstring
class ColModernVBertPreTrainedModel(ColPaliPreTrainedModel):
    config: ColModernVBertConfig
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for ColModernVBert embeddings output.
    """
)
class ColModernVBertForRetrievalOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`):
        The multi-vector embeddings of the model, where `dim` is the output size of the embedding projection
        layer. In `ColModernVBertForRetrieval.forward` they are L2-normalized along the last dimension and
        multiplied by the attention mask when one is given.
    image_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True` and `pixel_values` are provided):
        Tuple of `torch.FloatTensor` (one for the output of the image modality projection + one for the output of each layer) of shape
        `(batch_size, num_channels, image_size, image_size)`.
        Hidden-states of the image encoder at the output of each layer plus the initial modality projection outputs.
    """

    # NOTE(review): `ColModernVBertForRetrieval.forward` in this file never passes `loss`,
    # so it stays `None` on that code path.
    loss: torch.FloatTensor | None = None
    embeddings: torch.Tensor | None = None
    # The three fields below are copied straight from the backbone output by the forward pass in this file.
    hidden_states: tuple[torch.FloatTensor] | None = None
    image_hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
  295. @auto_docstring(
  296. custom_intro="""
  297. Following the ColPali approach, ColModernVBert leverages VLMs to construct efficient multi-vector embeddings directly
  298. from document images (“screenshots”) for document retrieval. The model is trained to maximize the similarity
  299. between these document embeddings and the corresponding query embeddings, using the late interaction method
  300. introduced in ColBERT.
  301. Using ColModernVBert removes the need for potentially complex and brittle layout recognition and OCR pipelines with
  302. a single model that can take into account both the textual and visual content (layout, charts, ...) of a document.
  303. ColModernVBert is trained on top of ModernVBert, and was introduced in the following paper:
  304. [*ModernVBERT: Towards Smaller Visual Document Retrievers*](https://arxiv.org/abs/2510.01149).
  305. ColModernVBert is part of the ColVision model family, which was introduced with ColPali in the following paper:
  306. [*ColPali: Efficient Document Retrieval with Vision Language Models*](https://huggingface.co/papers/2407.01449).
  307. """
  308. )
  309. class ColModernVBertForRetrieval(ColPaliForRetrieval):
  310. def __init__(self, config: ColModernVBertConfig):
  311. super().__init__(config)
  312. self.vlm = AutoModel.from_config(config.vlm_config)
  313. self.post_init()
  314. @can_return_tuple
  315. @auto_docstring
  316. def forward(
  317. self,
  318. input_ids: torch.LongTensor | None = None,
  319. pixel_values: torch.FloatTensor | None = None,
  320. attention_mask: torch.Tensor | None = None,
  321. **kwargs: Unpack[TransformersKwargs],
  322. ) -> ColModernVBertForRetrievalOutput:
  323. vlm_output = self.vlm(
  324. input_ids=input_ids,
  325. attention_mask=attention_mask,
  326. pixel_values=pixel_values,
  327. **kwargs,
  328. )
  329. last_hidden_states = vlm_output[0] # (batch_size, sequence_length, hidden_size)
  330. proj_dtype = self.embedding_proj_layer.weight.dtype
  331. embeddings = self.embedding_proj_layer(last_hidden_states.to(proj_dtype)) # (batch_size, sequence_length, dim)
  332. # L2 normalization
  333. embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
  334. if attention_mask is not None:
  335. attention_mask = attention_mask.to(dtype=embeddings.dtype, device=embeddings.device)
  336. embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
  337. return ColModernVBertForRetrievalOutput(
  338. embeddings=embeddings,
  339. hidden_states=vlm_output.hidden_states,
  340. attentions=vlm_output.attentions,
  341. image_hidden_states=vlm_output.image_hidden_states,
  342. )
# Public names exported by this module. `ColModernVBertProcessorKwargs` and
# `ColModernVBertForRetrievalOutput` are defined above but are not listed here.
__all__ = [
    "ColModernVBertConfig",
    "ColModernVBertForRetrieval",
    "ColModernVBertPreTrainedModel",
    "ColModernVBertProcessor",
]