modeling_colpali.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. # Copyright 2024 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch ColPali model"""
  15. from dataclasses import dataclass
  16. import torch
  17. from torch import nn
  18. from transformers import AutoModel
  19. from ... import initialization as init
  20. from ...cache_utils import Cache
  21. from ...modeling_utils import PreTrainedModel
  22. from ...processing_utils import Unpack
  23. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
  24. from .configuration_colpali import ColPaliConfig
  25. @auto_docstring
  26. class ColPaliPreTrainedModel(PreTrainedModel):
  27. config: ColPaliConfig
  28. base_model_prefix = "model"
  29. input_modalities = ("image", "text")
  30. _no_split_modules = []
  31. _supports_sdpa = True
  32. _supports_flash_attn = True
  33. _supports_flex_attn = True
  34. @torch.no_grad()
  35. def _init_weights(self, module):
  36. std = (
  37. self.config.initializer_range
  38. if hasattr(self.config, "initializer_range")
  39. else self.config.vlm_config.text_config.initializer_range
  40. )
  41. if isinstance(module, (nn.Linear, nn.Conv2d)):
  42. init.normal_(module.weight, mean=0.0, std=std)
  43. if module.bias is not None:
  44. init.zeros_(module.bias)
  45. elif isinstance(module, nn.Embedding):
  46. init.normal_(module.weight, mean=0.0, std=std)
  47. # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
  48. if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
  49. init.zeros_(module.weight[module.padding_idx])
  50. @dataclass
  51. @auto_docstring(
  52. custom_intro="""
  53. Base class for ColPali embeddings output.
  54. """
  55. )
  56. class ColPaliForRetrievalOutput(ModelOutput):
  57. r"""
  58. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  59. Language modeling loss (for next-token prediction).
  60. embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  61. The embeddings of the model.
  62. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  63. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  64. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  65. `past_key_values` input) to speed up sequential decoding.
  66. image_hidden_states (`torch.FloatTensor`, *optional*):
  67. A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
  68. image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
  69. """
  70. loss: torch.FloatTensor | None = None
  71. embeddings: torch.Tensor | None = None
  72. past_key_values: Cache | None = None
  73. hidden_states: tuple[torch.FloatTensor] | None = None
  74. attentions: tuple[torch.FloatTensor] | None = None
  75. image_hidden_states: torch.FloatTensor | None = None
  76. @auto_docstring(
  77. custom_intro="""
  78. The ColPali architecture leverages VLMs to construct efficient multi-vector embeddings directly
  79. from document images (“screenshots”) for document retrieval. The model is trained to maximize the similarity
  80. between these document embeddings and the corresponding query embeddings, using the late interaction method
  81. introduced in ColBERT.
  82. Using ColPali removes the need for potentially complex and brittle layout recognition and OCR pipelines with a
  83. single model that can take into account both the textual and visual content (layout, charts, etc.) of a document.
  84. ColPali is part of the ColVision model family, which was first introduced in the following paper:
  85. [*ColPali: Efficient Document Retrieval with Vision Language Models*](https://huggingface.co/papers/2407.01449).
  86. """
  87. )
  88. class ColPaliForRetrieval(ColPaliPreTrainedModel):
  89. base_model_prefix = "vlm"
  90. def __init__(self, config: ColPaliConfig):
  91. super().__init__(config)
  92. self.config = config
  93. self.vocab_size = config.vlm_config.text_config.vocab_size
  94. self.vlm = AutoModel.from_config(config.vlm_config)
  95. self.embedding_dim = self.config.embedding_dim
  96. self.embedding_proj_layer = nn.Linear(
  97. self.config.vlm_config.text_config.hidden_size,
  98. self.embedding_dim,
  99. )
  100. self.post_init()
  101. @can_return_tuple
  102. @auto_docstring
  103. def forward(
  104. self,
  105. input_ids: torch.LongTensor | None = None,
  106. pixel_values: torch.FloatTensor | None = None,
  107. attention_mask: torch.Tensor | None = None,
  108. **kwargs: Unpack[TransformersKwargs],
  109. ) -> ColPaliForRetrievalOutput:
  110. if pixel_values is not None:
  111. pixel_values = pixel_values.to(dtype=self.dtype)
  112. output_hidden_states = kwargs.pop("output_hidden_states", None)
  113. if output_hidden_states is None:
  114. output_hidden_states = self.config.output_hidden_states
  115. vlm_output = self.vlm(
  116. input_ids=input_ids,
  117. attention_mask=attention_mask,
  118. pixel_values=pixel_values,
  119. output_hidden_states=True,
  120. **kwargs,
  121. )
  122. vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None
  123. vlm_image_hidden_states = vlm_output.image_hidden_states if pixel_values is not None else None
  124. last_hidden_states = vlm_output[0] # (batch_size, sequence_length, hidden_size)
  125. proj_dtype = self.embedding_proj_layer.weight.dtype
  126. embeddings = self.embedding_proj_layer(last_hidden_states.to(proj_dtype)) # (batch_size, sequence_length, dim)
  127. # L2 normalization
  128. embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
  129. if attention_mask is not None:
  130. embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
  131. return ColPaliForRetrievalOutput(
  132. embeddings=embeddings,
  133. past_key_values=vlm_output.past_key_values,
  134. hidden_states=vlm_hidden_states,
  135. attentions=vlm_output.attentions,
  136. image_hidden_states=vlm_image_hidden_states,
  137. )
  138. __all__ = [
  139. "ColPaliForRetrieval",
  140. "ColPaliPreTrainedModel",
  141. ]