|
- # Copyright 2024 The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch ColPali model"""
- from dataclasses import dataclass
- import torch
- from torch import nn
- from transformers import AutoModel
- from ... import initialization as init
- from ...cache_utils import Cache
- from ...modeling_utils import PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
- from .configuration_colpali import ColPaliConfig
- @auto_docstring
- class ColPaliPreTrainedModel(PreTrainedModel):
- config: ColPaliConfig
- base_model_prefix = "model"
- input_modalities = ("image", "text")
- _no_split_modules = []
- _supports_sdpa = True
- _supports_flash_attn = True
- _supports_flex_attn = True
- @torch.no_grad()
- def _init_weights(self, module):
- std = (
- self.config.initializer_range
- if hasattr(self.config, "initializer_range")
- else self.config.vlm_config.text_config.initializer_range
- )
- if isinstance(module, (nn.Linear, nn.Conv2d)):
- init.normal_(module.weight, mean=0.0, std=std)
- if module.bias is not None:
- init.zeros_(module.bias)
- elif isinstance(module, nn.Embedding):
- init.normal_(module.weight, mean=0.0, std=std)
- # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
- if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
- init.zeros_(module.weight[module.padding_idx])
- @dataclass
- @auto_docstring(
- custom_intro="""
- Base class for ColPali embeddings output.
- """
- )
- class ColPaliForRetrievalOutput(ModelOutput):
- r"""
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
- Language modeling loss (for next-token prediction).
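- embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, embedding_dim)`):
- The L2-normalized multi-vector embeddings produced by the projection layer, zeroed at padded positions when an `attention_mask` is given.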
- past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
- `past_key_values` input) to speed up sequential decoding.
- image_hidden_states (`torch.FloatTensor`, *optional*):
- A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
- Image hidden states produced by the vision encoder after projecting the last hidden state.
- """
- loss: torch.FloatTensor | None = None
- embeddings: torch.Tensor | None = None
- past_key_values: Cache | None = None
- hidden_states: tuple[torch.FloatTensor] | None = None
- attentions: tuple[torch.FloatTensor] | None = None
- image_hidden_states: torch.FloatTensor | None = None
- @auto_docstring(
- custom_intro="""
- The ColPali architecture leverages VLMs to construct efficient multi-vector embeddings directly
- from document images (“screenshots”) for document retrieval. The model is trained to maximize the similarity
- between these document embeddings and the corresponding query embeddings, using the late interaction method
- introduced in ColBERT.
- Using ColPali removes the need for potentially complex and brittle layout recognition and OCR pipelines with a
- single model that can take into account both the textual and visual content (layout, charts, etc.) of a document.
- ColPali is part of the ColVision model family, which was first introduced in the following paper:
- [*ColPali: Efficient Document Retrieval with Vision Language Models*](https://huggingface.co/papers/2407.01449).
- """
- )
- class ColPaliForRetrieval(ColPaliPreTrainedModel):
- base_model_prefix = "vlm"
- def __init__(self, config: ColPaliConfig):
- super().__init__(config)
- self.config = config
- self.vocab_size = config.vlm_config.text_config.vocab_size
- self.vlm = AutoModel.from_config(config.vlm_config)
- self.embedding_dim = self.config.embedding_dim
- self.embedding_proj_layer = nn.Linear(
- self.config.vlm_config.text_config.hidden_size,
- self.embedding_dim,
- )
- self.post_init()
- @can_return_tuple
- @auto_docstring
- def forward(
- self,
- input_ids: torch.LongTensor | None = None,
- pixel_values: torch.FloatTensor | None = None,
- attention_mask: torch.Tensor | None = None,
- **kwargs: Unpack[TransformersKwargs],
- ) -> ColPaliForRetrievalOutput:
- if pixel_values is not None:
- pixel_values = pixel_values.to(dtype=self.dtype)
- output_hidden_states = kwargs.pop("output_hidden_states", None)
- if output_hidden_states is None:
- output_hidden_states = self.config.output_hidden_states
- vlm_output = self.vlm(
- input_ids=input_ids,
- attention_mask=attention_mask,
- pixel_values=pixel_values,
- output_hidden_states=True,
- **kwargs,
- )
- vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None
- vlm_image_hidden_states = vlm_output.image_hidden_states if pixel_values is not None else None
- last_hidden_states = vlm_output[0] # (batch_size, sequence_length, hidden_size)
- proj_dtype = self.embedding_proj_layer.weight.dtype
- embeddings = self.embedding_proj_layer(last_hidden_states.to(proj_dtype)) # (batch_size, sequence_length, dim)
- # L2 normalization
- embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
- if attention_mask is not None:
- embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
- return ColPaliForRetrievalOutput(
- embeddings=embeddings,
- past_key_values=vlm_output.past_key_values,
- hidden_states=vlm_hidden_states,
- attentions=vlm_output.attentions,
- image_hidden_states=vlm_image_hidden_states,
- )
- __all__ = [
- "ColPaliForRetrieval",
- "ColPaliPreTrainedModel",
- ]
|
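The class docstring mentions the late interaction method introduced in ColBERT, but the file itself only produces embeddings and never scores them. The sketch below illustrates how such multi-vector embeddings could be compared. It is a minimal, dependency-free illustration in plain Python rather than `torch`, and the helper names (`l2_normalize`, `mask_embeddings`, `maxsim_score`) are hypothetical and not part of this module. Unlike the `forward` method above, `l2_normalize` guards against a zero norm, which is an assumption added here for safety.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit L2 norm; an all-zero vector is returned unchanged."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def mask_embeddings(embeddings, attention_mask):
    """Zero out the embeddings at padded positions (mask value 0)."""
    return [[x * m for x in vec] for vec, m in zip(embeddings, attention_mask)]


def maxsim_score(query_embeddings, doc_embeddings):
    """Late-interaction score: for each query token, take the maximum
    dot product over all document tokens, then sum over query tokens."""
    total = 0.0
    for q in query_embeddings:
        total += max(sum(a * b for a, b in zip(q, d)) for d in doc_embeddings)
    return total


# Toy example: two query tokens, three document tokens (the last one is padding).
query = [l2_normalize(v) for v in [[1.0, 0.0], [0.0, 2.0]]]
doc = [l2_normalize(v) for v in [[3.0, 0.0], [1.0, 1.0], [5.0, 5.0]]]
doc = mask_embeddings(doc, [1, 1, 0])

score = maxsim_score(query, doc)
print(round(score, 4))
```

Because every token vector is unit-normalized, each dot product is a cosine similarity, so the score is bounded by the number of query tokens. A masked position contributes a similarity of exactly zero, which means it could still be selected by the maximum when every real document token has negative similarity to a query token.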