```python
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/colqwen2/modular_colqwen2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_colqwen2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

from torch import nn

from transformers import AutoModel

from ... import initialization as init
from ...cache_utils import Cache
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torch_available
from .configuration_colqwen2 import ColQwen2Config


if is_torch_available():
    import torch
```
```python
@auto_docstring
class ColQwen2PreTrainedModel(PreTrainedModel):
    config: ColQwen2Config
    base_model_prefix = "model"
    input_modalities = ("image", "text")
    _no_split_modules = []
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True

    @torch.no_grad()
    def _init_weights(self, module):
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.vlm_config.text_config.initializer_range
        )

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=std)
            # The check must be explicit here: `zeros_` receives a slice of the weight, and the slice loses the flag
            if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
                init.zeros_(module.weight[module.padding_idx])
```
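`_init_weights` draws Linear, Conv2d and Embedding weights from a zero-mean normal distribution whose standard deviation is `initializer_range` (falling back to the wrapped VLM's text config), and zeroes biases and the embedding padding row. The sketch below mirrors that logic with plain Python lists so it runs without PyTorch. `resolve_init_std` and `init_embedding_table` are hypothetical helpers, the `SimpleNamespace` objects stand in for real config classes, and the `_is_hf_initialized` check is left out.

```python
import random
from types import SimpleNamespace


def resolve_init_std(config):
    # Prefer a top-level `initializer_range`, otherwise fall back to the text
    # config of the wrapped VLM, as in `_init_weights`.
    if hasattr(config, "initializer_range"):
        return config.initializer_range
    return config.vlm_config.text_config.initializer_range


def init_embedding_table(num_embeddings, dim, std, padding_idx=None, seed=0):
    # Sample every entry from a normal distribution with mean 0 and the given std,
    # then zero the padding row so the padding token maps to a zero vector.
    rng = random.Random(seed)
    table = [[rng.gauss(0.0, std) for _ in range(dim)] for _ in range(num_embeddings)]
    if padding_idx is not None:
        table[padding_idx] = [0.0] * dim
    return table


# Hypothetical configs: one nests `initializer_range` under the VLM text config,
# the other defines it at the top level.
nested_config = SimpleNamespace(
    vlm_config=SimpleNamespace(text_config=SimpleNamespace(initializer_range=0.02))
)
flat_config = SimpleNamespace(initializer_range=0.01)

std = resolve_init_std(nested_config)
table = init_embedding_table(num_embeddings=5, dim=4, std=std, padding_idx=0)
```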
```python
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for ColQwen2 embeddings output.
    """
)
class ColQwen2ForRetrievalOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, embedding_dim)`):
        The multi-vector embeddings: one L2-normalized vector of size `embedding_dim` per token, with padding
        positions zeroed when an `attention_mask` is passed.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    """

    loss: torch.FloatTensor | None = None
    embeddings: torch.Tensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
```
```python
@auto_docstring(
    custom_intro="""
    Following the ColPali approach, ColQwen2 leverages VLMs to construct efficient multi-vector embeddings directly
    from document images (“screenshots”) for document retrieval. The model is trained to maximize the similarity
    between these document embeddings and the corresponding query embeddings, using the late interaction method
    introduced in ColBERT.

    ColQwen2 replaces potentially complex and brittle layout recognition and OCR pipelines with a single model
    that takes into account both the textual and visual content (layout, charts, ...) of a document.

    ColQwen2 is part of the ColVision model family, which was introduced with ColPali in the following paper:
    [*ColPali: Efficient Document Retrieval with Vision Language Models*](https://huggingface.co/papers/2407.01449).
    """
)
class ColQwen2ForRetrieval(ColQwen2PreTrainedModel):
    base_model_prefix = "vlm"

    def __init__(self, config: ColQwen2Config):
        super().__init__(config)
        self.config = config
        self.vocab_size = config.vlm_config.text_config.vocab_size

        self.vlm = AutoModel.from_config(config.vlm_config)

        self.embedding_dim = self.config.embedding_dim
        self.embedding_proj_layer = nn.Linear(
            self.config.vlm_config.text_config.hidden_size,
            self.embedding_dim,
        )

        self.post_init()
```
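The class docstring says document and query embeddings are compared with ColBERT's late interaction. A minimal sketch of that MaxSim score, assuming each side is a list of per-token vectors (for example, rows of the `embeddings` output). `maxsim_score` is a hypothetical helper rather than a transformers API, and a real pipeline would compute this with batched tensor operations.

```python
def maxsim_score(query_embeddings, doc_embeddings):
    # Late interaction (MaxSim): each query token vector is matched to its most
    # similar document token vector by dot product, and the maxima are summed.
    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    return sum(max(dot(q, d) for d in doc_embeddings) for q in query_embeddings)


# Toy unit vectors (hypothetical data): 2 query tokens, 3 tokens per document.
query = [[1.0, 0.0], [0.0, 1.0]]
doc_a = [[1.0, 0.0], [0.6, 0.8], [0.0, -1.0]]
doc_b = [[-1.0, 0.0], [0.0, -1.0], [0.8, -0.6]]

scores = {"doc_a": maxsim_score(query, doc_a), "doc_b": maxsim_score(query, doc_b)}
```

Because `forward` L2-normalizes every token vector, each dot product is a cosine similarity, so a document scores highly when every query token finds at least one closely matching document token.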
```python
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        labels: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        pixel_values: torch.Tensor | None = None,
        image_grid_thw: torch.LongTensor | None = None,
        **kwargs,
    ) -> ColQwen2ForRetrievalOutput:
        r"""
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of the feature grid of each image in the LLM.
        """
        # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
        if pixel_values is not None and image_grid_thw is not None:
            # NOTE: image_grid_thw: (batch_size, 3) where image_grid_thw[i] = (grid_t, grid_h, grid_w),
            # i.e. the number of temporal, height and width patches of image i
            offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]  # (batch_size,)
            arange = torch.arange(pixel_values.shape[1], device=offsets.device)  # (max_len,)
            mask = arange.unsqueeze(0) < offsets.unsqueeze(1)  # (batch_size, max_len)
            pixel_values = pixel_values[mask]  # (total_valid_patches, patch_dim)
```
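The unpadding step keeps, for each image, only its first `grid_h * grid_w` patch rows (for a still image `grid_t` is 1) and concatenates them across the batch, which is what the boolean `mask` does. A pure-Python sketch of the same bookkeeping; `unpad_patches` is a hypothetical name, and the one-element "patches" are toy data standing in for flattened pixel patches.

```python
def unpad_patches(padded, grid_thw):
    # padded[i] holds the patch rows of image i, right-padded to a common length.
    # Image i owns grid_h * grid_w real rows (grid_t is 1 for a still image),
    # which is the quantity the forward pass computes as `offsets`.
    flat = []
    for rows, (_grid_t, grid_h, grid_w) in zip(padded, grid_thw):
        num_valid = grid_h * grid_w
        flat.extend(rows[:num_valid])  # keep the leading real rows, drop the padding
    return flat


# Hypothetical batch of two images with 1-element "patches"; padding rows are [0.0].
# Image 0 has a 1x2x2 grid (4 patches), image 1 a 1x1x2 grid (2 patches).
padded = [
    [[1.0], [2.0], [3.0], [4.0]],
    [[5.0], [6.0], [0.0], [0.0]],
]
grid_thw = [(1, 2, 2), (1, 1, 2)]
flat = unpad_patches(padded, grid_thw)
```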
```python
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs.
        if inputs_embeds is None:
            inputs_embeds = self.vlm.get_input_embeddings()(input_ids)

            if pixel_values is not None:
                image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True).pooler_output
                image_mask = (
                    (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
                )
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

        vlm_output = self.vlm(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
```
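When `inputs_embeds` is not supplied, the image embeddings from `self.vlm.visual` overwrite the text embeddings at every image-placeholder position, filled in order by `masked_scatter`. A dependency-free sketch of that substitution: `scatter_image_embeddings` is a hypothetical helper, and the placeholder id `99` is made up (the model reads the real one from `config.vlm_config.image_token_id`). The strict count check is an addition of this sketch.

```python
def scatter_image_embeddings(token_ids, token_embeds, image_embeds, image_token_id):
    # Replace the embedding at every image-placeholder position with the next image
    # embedding, in sequence order, mirroring masked_scatter with a row-wise mask.
    num_placeholders = sum(1 for tok in token_ids if tok == image_token_id)
    if num_placeholders != len(image_embeds):
        # Strict count check added for this sketch.
        raise ValueError(f"{num_placeholders} placeholders but {len(image_embeds)} image embeddings")
    out = [list(vec) for vec in token_embeds]  # copy so the input is left untouched
    remaining = iter(image_embeds)
    for pos, tok in enumerate(token_ids):
        if tok == image_token_id:
            out[pos] = list(next(remaining))
    return out


IMAGE_TOKEN_ID = 99  # hypothetical placeholder id
token_ids = [1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2]
token_embeds = [[0.1, 0.1], [0.0, 0.0], [0.0, 0.0], [0.2, 0.2]]
image_embeds = [[1.0, 1.0], [2.0, 2.0]]
merged = scatter_image_embeddings(token_ids, token_embeds, image_embeds, IMAGE_TOKEN_ID)
```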
```python
        vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None

        last_hidden_states = vlm_output[0]  # (batch_size, sequence_length, hidden_size)
        proj_dtype = self.embedding_proj_layer.weight.dtype
        embeddings = self.embedding_proj_layer(last_hidden_states.to(proj_dtype))  # (batch_size, sequence_length, dim)

        # L2 normalization
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)

        if attention_mask is not None:
            embeddings = embeddings * attention_mask.unsqueeze(-1)  # (batch_size, sequence_length, dim)

        return ColQwen2ForRetrievalOutput(
            embeddings=embeddings,
            past_key_values=vlm_output.past_key_values,
            hidden_states=vlm_hidden_states,
            attentions=vlm_output.attentions,
        )
```
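The tail of `forward` projects each hidden state to `embedding_dim`, L2-normalizes every token vector, and multiplies by the attention mask so padding positions become zero vectors. The normalization and masking steps, sketched for a single sequence with plain lists; `normalize_and_mask` is a hypothetical helper and the inputs are toy values.

```python
import math


def normalize_and_mask(embeddings, attention_mask):
    # Divide each token vector by its L2 norm, then multiply by the token's 0/1
    # attention-mask value so that padding positions become zero vectors.
    out = []
    for vec, keep in zip(embeddings, attention_mask):
        norm = math.sqrt(sum(x * x for x in vec))
        out.append([x / norm * keep for x in vec])
    return out


tokens = [[3.0, 4.0], [0.0, 2.0], [5.0, 12.0]]
mask = [1, 1, 0]  # hypothetical: the last position is padding
result = normalize_and_mask(tokens, mask)
```

Like the tensor code, the sketch does not guard against a zero-norm vector; Python raises `ZeroDivisionError` there, whereas the tensor division would produce NaN.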
```python
__all__ = ["ColQwen2ForRetrieval", "ColQwen2PreTrainedModel"]
```