| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907 |
- # Copyright 2024 the HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch Idefics3 model."""
- from collections.abc import Callable
- from dataclasses import dataclass
- import torch
- from torch import nn
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache
- from ...generation import GenerationMixin
- from ...masking_utils import create_bidirectional_mask
- from ...modeling_flash_attention_utils import FlashAttentionKwargs
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
- from ...utils.generic import merge_with_config_defaults
- from ...utils.output_capturing import capture_outputs
- from ..auto import AutoModel
- from .configuration_idefics3 import Idefics3Config, Idefics3VisionConfig
# Module-level logger; e.g. `Idefics3Model.forward` uses it to warn (once) when `use_cache` is turned off
# because gradient checkpointing is active during training.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Idefics3 model's outputs that may also contain a past key/values (to speed up sequential decoding).
    """
)
class Idefics3BaseModelOutputWithPast(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.
        If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
        hidden_size)` is output.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
        `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
        input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        Image features produced by the vision encoder and projected into the text embedding space by the
        connector, as a single tensor of shape `(num_images, sequence_length, hidden_size)`, where `num_images`
        counts the non-padding images of the whole batch.
    """

    last_hidden_state: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    # A single tensor (the connector output), not a tuple of tensors.
    image_hidden_states: torch.FloatTensor | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Idefics causal language model (or autoregressive) outputs.
    """
)
class Idefics3CausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        Image features produced by the vision encoder and projected into the text embedding space by the
        connector, as a single tensor of shape `(num_images, sequence_length, hidden_size)`, where `num_images`
        counts the non-padding images of the whole batch.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    # A single tensor (the connector output), not a tuple of tensors.
    image_hidden_states: torch.FloatTensor | None = None
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionEmbeddings with Idefics2->Idefics3
class Idefics3VisionEmbeddings(nn.Module):
    """
    This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` to enable images of variable
    resolution.

    The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://huggingface.co/papers/2307.06304)
    which allows treating images in their native aspect ratio and without the need to resize them to the same
    fixed size. In particular, we start from the original pre-trained SigLIP model
    (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.
    """

    def __init__(self, config: Idefics3VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patchify: kernel size == stride == patch_size, no padding.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        # The learned position table covers a square grid of `num_patches_per_side` x `num_patches_per_side`
        # patches (the grid of an `image_size` x `image_size` image); `forward` maps every image onto this grid.
        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

    def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
        """
        Patchify `pixel_values` and add position embeddings rescaled to each image's valid extent.

        Args:
            pixel_values: Images unpacked as `(batch_size, num_channels, height, width)`.
            patch_attention_mask: Boolean mask indexed as `(batch_size, patches_h, patches_w)`; `True` marks
                patches that belong to the image. Its flattened size is assumed to equal
                `(height // patch_size) * (width // patch_size)` — TODO confirm at call sites.

        Returns:
            Patch embeddings of shape `(batch_size, num_patches, embed_dim)` with position embeddings added.
            Masked-out patches receive the embedding of position 0.
        """
        batch_size, _, max_im_h, max_im_w = pixel_values.shape

        # (B, embed_dim, H/p, W/p) -> (B, (H/p) * (W/p), embed_dim)
        patch_embeds = self.patch_embedding(pixel_values)
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
        # Interior bucket edges 1/n, 2/n, ..., (n-1)/n of [0, 1), with n = num_patches_per_side.
        boundaries = torch.arange(
            1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device
        )
        # Every patch starts at position 0; only the unmasked patches are overwritten below.
        position_ids = torch.full(
            size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device
        )

        # Valid extent per image, counted along the first patch column / first patch row of the mask.
        # NOTE(review): this assumes the valid region is a rectangle anchored at the top-left corner.
        nb_patches_h = patch_attention_mask[:, :, 0].sum(dim=1)  # (batch_size,)
        nb_patches_w = patch_attention_mask[:, 0, :].sum(dim=1)  # (batch_size,)

        # NOTE(review): an image whose mask has no True entry yields 1.0 / 0 here.
        step_h = 1.0 / nb_patches_h  # (batch_size,)
        step_w = 1.0 / nb_patches_w  # (batch_size,)

        max_patches_h = patch_attention_mask.size(1)
        max_patches_w = patch_attention_mask.size(2)
        h_indices = torch.arange(max_patches_h, device=position_ids.device, dtype=torch.float32)
        w_indices = torch.arange(max_patches_w, device=position_ids.device, dtype=torch.float32)

        # Express each row/column index as a fraction of the image's valid extent, clamped strictly below 1.0
        # so it always lands in one of the `num_patches_per_side` buckets.
        fractional_coords_h = h_indices[None, :] * step_h[:, None]
        fractional_coords_w = w_indices[None, :] * step_w[:, None]

        fractional_coords_h = torch.clamp(fractional_coords_h, max=(1.0 - 1e-6))
        fractional_coords_w = torch.clamp(fractional_coords_w, max=(1.0 - 1e-6))

        # Coordinates are rounded to the input dtype before bucketing (matters for low-precision inputs).
        fractional_coords_h = fractional_coords_h.to(pixel_values.dtype)
        fractional_coords_w = fractional_coords_w.to(pixel_values.dtype)

        # Bucket index = row / column on the pretrained position grid.
        bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
        bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

        # Flatten (row, col) on the pretrained grid into a single position id.
        pos_ids = bucket_coords_h[:, :, None] * self.num_patches_per_side + bucket_coords_w[:, None, :]
        pos_ids = pos_ids.reshape(batch_size, -1)

        # Only patches inside the mask take their computed id; the others keep position 0.
        position_ids[patch_attention_mask.view(batch_size, -1)] = pos_ids[patch_attention_mask.view(batch_size, -1)]

        embeddings = embeddings + self.position_embedding(position_ids)
        return embeddings
- # Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
- def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: torch.Tensor | None,
- scaling: float,
- dropout: float = 0.0,
- **kwargs,
- ):
- attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
- if attention_mask is not None:
- attn_weights = attn_weights + attention_mask
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value)
- attn_output = attn_output.transpose(1, 2).contiguous()
- return attn_output, attn_weights
- # Copied from transformers.models.siglip.modeling_siglip.SiglipAttention with Siglip->Idefics3Vision
- class Idefics3VisionAttention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
- # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.embed_dim = config.hidden_size
- self.num_heads = config.num_attention_heads
- self.head_dim = self.embed_dim // self.num_heads
- if self.head_dim * self.num_heads != self.embed_dim:
- raise ValueError(
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
- f" {self.num_heads})."
- )
- self.scale = self.head_dim**-0.5
- self.dropout = config.attention_dropout
- self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
- self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
- self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
- self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
- # Ignore copy
- self.is_causal = False
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor | None = None,
- **kwargs,
- ) -> tuple[torch.Tensor, torch.Tensor | None]:
- """Input shape: Batch x Time x Channel"""
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
- queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
- self.config._attn_implementation, eager_attention_forward
- )
- attn_output, attn_weights = attention_interface(
- self,
- queries,
- keys,
- values,
- attention_mask,
- is_causal=self.is_causal,
- scaling=self.scale,
- dropout=0.0 if not self.training else self.dropout,
- )
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
- attn_output = self.out_proj(attn_output)
- return attn_output, attn_weights
# Adapted from transformers.models.siglip.modeling_siglip.SiglipMLP with Siglip->Idefics3Vision
class Idefics3VisionMLP(nn.Module):
    """Position-wise feed-forward block: `fc1` -> activation -> `fc2`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Registration order (activation, fc1, fc2) matches the original module layout.
        self.activation_fn = ACT2FN[config.hidden_act]
        outer, inner = config.hidden_size, config.intermediate_size
        self.fc1 = nn.Linear(outer, inner)
        self.fc2 = nn.Linear(inner, outer)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
class Idefics3SimpleMLP(nn.Module):
    """Bias-free linear map from pixel-shuffled vision features to the text hidden size."""

    def __init__(self, config):
        super().__init__()
        # Pixel shuffle concatenates scale_factor**2 vision features into one vector.
        shuffled_width = config.vision_config.hidden_size * config.scale_factor**2
        text_width = config.text_config.hidden_size
        self.proj = nn.Linear(shuffled_width, text_width, bias=False)

    def forward(self, x):
        projected = self.proj(x)
        return projected
# Adapted from transformers.models.idefics2.modeling_idefics2.Idefics2EncoderLayer with Idefics2->Idefics3
class Idefics3EncoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer block: LN -> self-attention -> residual, then LN -> MLP -> residual."""

    def __init__(self, config: Idefics3VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        # Submodule order is part of the checkpoint layout; keep it.
        self.self_attn = Idefics3VisionAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Idefics3VisionMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    @auto_docstring
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        # Attention sub-block; the attention weights are not needed here.
        attn_out, _ = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block.
        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_out
# Adapted from transformers.models.siglip.modeling_siglip.SiglipEncoder with Siglip->Idefics3
class Idefics3Encoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Idefics3EncoderLayer`].

    Args:
        config: Idefics3VisionConfig
            The vision-tower configuration (this encoder is built by `Idefics3VisionTransformer`).
    """

    def __init__(self, config: Idefics3VisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([Idefics3EncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    @auto_docstring
    def forward(
        self,
        inputs_embeds,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        # Run the layers sequentially. Extra keyword arguments (if any) go to every layer, which forwards
        # them to its attention module; existing callers that pass none are unaffected.
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states, attention_mask, **kwargs)

        return BaseModelOutput(last_hidden_state=hidden_states)
# Adapted from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head `n_rep` times along the head axis.

    Same result as `torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)`:
    `(batch, num_key_value_heads, seqlen, head_dim)` -> `(batch, num_key_value_heads * n_rep, seqlen, head_dim)`.
    The input is returned unchanged (same object) when `n_rep == 1`.
    """
    # Unpack first so a non-4-D input is rejected even when n_rep == 1.
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    widened = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return widened.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
# Adapted from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Idefics3
class Idefics3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        """
        Idefics3RMSNorm is equivalent to T5LayerNorm: scale by the reciprocal root-mean-square of the last
        dimension (no mean subtraction, no bias), then multiply by a learned per-channel weight.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        original_dtype = hidden_states.dtype
        # Statistics are computed in float32 for stability, then cast back before applying the weight.
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalized = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Idefics3Connector(nn.Module):
    """Pixel-shuffles vision features (fewer, wider tokens) and projects them into the text embedding space."""

    def __init__(self, config):
        super().__init__()
        self.scale_factor = config.scale_factor
        self.modality_projection = Idefics3SimpleMLP(config)

    def pixel_shuffle(self, x, scale_factor=2):
        """
        Merge every `scale_factor x scale_factor` neighbourhood of a square patch grid into one token.

        `x` is `(batch, seq, embed_dim)` with `seq` a perfect square whose side is divisible by `scale_factor`.
        The result is `(batch, seq // scale_factor**2, embed_dim * scale_factor**2)`. Output tokens follow the
        coarse grid in row-major order. Inside each token the merged patches are concatenated row by row
        (top row of the neighbourhood first, left to right).
        """
        bsz, seq, embed_dim = x.size()
        side = int(seq**0.5)
        coarse = side // scale_factor
        # Axes: (batch, coarse_row, row_offset, coarse_col, col_offset, channel)
        grid = x.view(bsz, coarse, scale_factor, coarse, scale_factor, embed_dim)
        # Bring both offsets next to the channel axis, then fold them into the feature dimension.
        grid = grid.permute(0, 1, 3, 2, 4, 5)
        return grid.reshape(bsz, coarse * coarse, embed_dim * scale_factor * scale_factor)

    def forward(self, image_hidden_states):
        shuffled = self.pixel_shuffle(image_hidden_states, self.scale_factor)
        return self.modality_projection(shuffled)
@auto_docstring
class Idefics3PreTrainedModel(PreTrainedModel):
    # Model wiring
    config: Idefics3Config
    base_model_prefix = "model"
    input_modalities = ("image", "text")

    # Device placement / sharding hints.
    # NOTE(review): no `Idefics3DecoderLayer` is defined in the visible part of this file (the decoder comes
    # from the `AutoModel` text backbone) — confirm this entry is intentional.
    _no_split_modules = ["Idefics3VisionAttention", "Idefics3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    # Feature flags
    supports_gradient_checkpointing = True
    _supports_attention_backend = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_sdpa = True
@auto_docstring(
    custom_intro="""
    The Idefics3 Vision Transformer Model outputting raw image embedding.
    """
)
class Idefics3VisionTransformer(Idefics3PreTrainedModel):
    config: Idefics3VisionConfig
    input_modalities = ("image",)
    # Output capture: per-layer hidden states come from the encoder layers, attentions from the attention modules.
    _can_record_outputs = {
        "hidden_states": Idefics3EncoderLayer,
        "attentions": Idefics3VisionAttention,
    }

    def __init__(self, config: Idefics3VisionConfig):
        super().__init__(config)
        self.embeddings = Idefics3VisionEmbeddings(config)
        self.encoder = Idefics3Encoder(config)
        self.patch_size = config.patch_size
        self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.post_init()

    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer.get_input_embeddings
    def get_input_embeddings(self):
        return self.embeddings

    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer.set_input_embeddings
    def set_input_embeddings(self, value):
        self.embeddings = value

    @merge_with_config_defaults
    @capture_outputs(tie_last_hidden_states=False)
    def forward(
        self,
        pixel_values,
        patch_attention_mask: torch.BoolTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutput:
        """
        Encode `pixel_values` (`(batch, channels, height, width)`) into patch features.

        If `patch_attention_mask` is omitted, every patch of the `(height // patch_size, width // patch_size)`
        grid is treated as valid.
        """
        batch_size = pixel_values.size(0)

        if patch_attention_mask is None:
            grid_h = pixel_values.size(2) // self.patch_size
            grid_w = pixel_values.size(3) // self.patch_size
            patch_attention_mask = torch.ones((batch_size, grid_h, grid_w))
            patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device)

        embedded = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)

        # Flatten the patch grid, then build the mask format expected by the active attention implementation.
        encoder_mask = create_bidirectional_mask(
            config=self.config,
            inputs_embeds=embedded,
            attention_mask=patch_attention_mask.view(batch_size, -1),
        )

        encoded = self.encoder(inputs_embeds=embedded, attention_mask=encoder_mask).last_hidden_state
        return BaseModelOutput(last_hidden_state=self.post_layernorm(encoded))
@auto_docstring(
    custom_intro="""
    Idefics3 model consisting of a SIGLIP vision encoder and Llama3 language decoder
    """
)
class Idefics3Model(Idefics3PreTrainedModel):
    def __init__(self, config: Idefics3Config):
        super().__init__(config)
        self.padding_idx = self.config.text_config.pad_token_id
        self.vocab_size = self.config.text_config.vocab_size

        self.vision_model = Idefics3VisionTransformer._from_config(config.vision_config)
        self.connector = Idefics3Connector(config)
        self.text_model = AutoModel.from_config(config.text_config)

        # Text-side tokens per full-size image: (patches per side)**2 reduced by the pixel shuffle (scale_factor**2).
        self.image_seq_len = int(
            ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2)
        )
        self.image_token_id = self.config.image_token_id

        self.post_init()

    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.get_input_embeddings
    def get_input_embeddings(self):
        return self.text_model.get_input_embeddings()

    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.set_input_embeddings
    def set_input_embeddings(self, value):
        self.text_model.set_input_embeddings(value)

    def inputs_merger(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.Tensor | None,
        image_hidden_states: torch.Tensor | None,
    ):
        """
        Write the projected image features into the image-token slots of the text embeddings.

        The text token sequence looks like
        `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
        The vision encoder output is pixel-shuffled and projected into the text embedding space (see
        `get_image_features`). Every `<image>` position (`config.image_token_id`) is then overwritten, in order,
        with those features. The result is
        `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {image hidden states}
        vector_fake_tok_around_image vector_tok_4`, which is fed to the LM.

        The sequence length does not change. Only the embeddings at image-token positions are replaced;
        `input_ids` and the attention mask are left untouched.

        Args:
            input_ids: Token ids, or `None` when only `inputs_embeds` is available. In that case image positions
                are found by comparing every embedding with the embedding of `config.image_token_id`.
            inputs_embeds: Text embeddings of shape `(batch_size, sequence_length, hidden_size)`.
            image_hidden_states: Projected image features. `masked_scatter` consumes them in order, so they must
                provide at least `num_image_tokens * hidden_size` elements.

        Returns:
            `inputs_embeds` with the image-token positions filled in.
        """
        if input_ids is None:
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id

        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_hidden_states)
        return inputs_embeds

    @can_return_tuple
    @auto_docstring
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_images, num_channels, height, width)`):
            The tensors corresponding to the input images. Images that are entirely zero are treated as padding
            and are not passed through the vision encoder.
        pixel_attention_mask (`torch.LongTensor` of shape `(batch_size, num_images, height, width)`, *optional*):
            The attention mask indicating padded regions in the image.
        """
        batch_size, num_images, num_channels, height, width = pixel_values.shape
        pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
        pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])

        # Remove padding images - padding images are full 0.
        nb_values_per_image = pixel_values.shape[1:].numel()
        real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
        pixel_values = pixel_values[real_images_inds].contiguous()

        # Handle the vision attention mask
        if pixel_attention_mask is None:
            pixel_attention_mask = torch.ones(
                size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
                dtype=torch.bool,
                device=pixel_values.device,
            )
        else:
            # Remove padding images from the mask
            pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
            pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()

        # A patch counts as valid if any pixel inside it is unmasked.
        patch_size = self.config.vision_config.patch_size
        patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
        patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
        patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

        # Get sequence from the vision encoder
        image_outputs = self.vision_model(
            pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, return_dict=True, **kwargs
        )
        image_hidden_states = image_outputs.last_hidden_state

        # Modality projection & resampling; the projected features are exposed as `pooler_output`.
        image_features = self.connector(image_hidden_states)
        image_outputs.pooler_output = image_features

        return image_outputs

    @can_return_tuple
    @auto_docstring(
        custom_intro="""
        Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
        the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
        max_num_images is the maximum number of images among the batch_size samples in the batch.
        Padding images are not needed beyond padding the pixel_values at the entrance of the model.
        For efficiency, we only pass through the vision_model's forward the real images by
        discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
        image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
        """
    )
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        pixel_attention_mask: torch.BoolTensor | None = None,
        image_hidden_states: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple | Idefics3BaseModelOutputWithPast:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, num_images, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        image_hidden_states (`torch.FloatTensor` of shape `(num_images, image_seq_len, hidden_size)`, *optional*):
            Image features already projected into the text embedding space (the `pooler_output` of
            `get_image_features`). Mutually exclusive with `pixel_values`.
        """
        if self.training and self.text_model.gradient_checkpointing and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        # retrieve input_ids and inputs_embeds
        if input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device)

        # START VISUAL INPUTS INTEGRATION
        if pixel_values is not None and image_hidden_states is not None:
            raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
        elif pixel_values is not None:
            image_hidden_states = self.get_image_features(
                pixel_values, pixel_attention_mask, return_dict=True
            ).pooler_output
        elif image_hidden_states is not None:
            # `input_ids` may be None here (embeddings-only call), but `inputs_embeds` is always set at this point
            # and is where `inputs_merger` places the features anyway.
            image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=inputs_embeds.device)

        if image_hidden_states is not None:
            # When we generate, we don't want to replace the potential image_token_id that we generated by images
            # that simply don't exist
            inputs_embeds = self.inputs_merger(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                image_hidden_states=image_hidden_states,
            )

        outputs = self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            **kwargs,
        )

        return Idefics3BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_hidden_states,
        )
@auto_docstring(
    custom_intro="""
    The Idefics3 Model with a language modeling head. It is made up of a SigLIP vision encoder and a text model, with a language modeling head on top.
    """
)
class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin):
    # Declares that `lm_head.weight` is tied to the text model's input embedding weight.
    _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"}

    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.__init__ with Idefics2->Idefics3
    def __init__(self, config):
        super().__init__(config)
        self.model = Idefics3Model(config)
        self.image_token_id = self.config.image_token_id

        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.vocab_size = config.text_config.vocab_size

        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.get_input_embeddings
    def get_input_embeddings(self):
        return self.model.text_model.get_input_embeddings()

    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.set_input_embeddings
    def set_input_embeddings(self, value):
        self.model.text_model.set_input_embeddings(value)

    @auto_docstring
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input images.
        pixel_attention_mask (`torch.LongTensor`, *optional*):
            The attention mask indicating padded regions in the image.
        """
        # Thin delegation: image encoding is implemented by the wrapped `Idefics3Model`.
        return self.model.get_image_features(
            pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask, **kwargs
        )

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        pixel_attention_mask: torch.BoolTensor | None = None,
        image_hidden_states: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Idefics3CausalLMOutputWithPast:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The hidden states of the image encoder after modality projection.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`).
            Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
            computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, AutoModelForImageTextToText
        >>> from transformers.image_utils import load_image

        >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
        >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
        >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
        >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")

        >>> processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
        >>> model = AutoModelForImageTextToText.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3", dtype=torch.bfloat16, device_map="auto")

        >>> # Create inputs
        >>> messages = [
        ...     {
        ...         "role": "user",
        ...         "content": [
        ...             {"type": "image"},
        ...             {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."},
        ...             {"type": "image"},
        ...             {"type": "text", "text": "What can we see in this image?"},
        ...         ]
        ...     },
        ...     {
        ...         "role": "user",
        ...         "content": [
        ...             {"type": "image"},
        ...             {"type": "text", "text": "In which city is that bridge located?"},
        ...         ]
        ...     }
        ... ]

        >>> prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages]
        >>> images = [[image1, image2], [image3]]
        >>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(model.device)

        >>> # Generate
        >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
        >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)

        >>> print(generated_texts[0])
        Assistant: There are buildings, trees, lights, and water visible in this image.

        >>> print(generated_texts[1])
        Assistant: The bridge is in San Francisco.
        ```"""
        # `self.model` builds an `Idefics3BaseModelOutputWithPast` (last_hidden_state, past_key_values,
        # hidden_states, attentions, image_hidden_states). `return_dict=True` requests that output
        # object so its fields can be read by name below.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            pixel_attention_mask=pixel_attention_mask,
            image_hidden_states=image_hidden_states,
            use_cache=use_cache,
            return_dict=True,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
        # An int `logits_to_keep` keeps the last N sequence positions; 0 keeps all of them, because
        # `slice(-0, None)` is the full range. A tensor is used directly as the position index.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # NOTE(review): the docstring says labels equal to `model.image_token_id` are ignored, but
            # nothing in this method masks them. Any masking is up to `self.loss_function` — confirm.
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
            )

        return Idefics3CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
        )

    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.prepare_inputs_for_generation
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        pixel_values=None,
        pixel_attention_mask=None,
        image_hidden_states=None,
        logits_to_keep=None,
        is_first_iteration=False,
        use_cache=False,
        **kwargs,
    ):
        # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take
        # precedence is moved to the model, we can remove this fn)

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            pixel_attention_mask=pixel_attention_mask,
            image_hidden_states=image_hidden_states,
            logits_to_keep=logits_to_keep,
            is_first_iteration=is_first_iteration,
            use_cache=use_cache,
            **kwargs,
        )

        if image_hidden_states is not None or (use_cache and not is_first_iteration):
            model_inputs["pixel_values"] = None
            model_inputs["pixel_attention_mask"] = None

        return model_inputs
# Public API of this module, in the same order as before.
__all__ = [
    "Idefics3ForConditionalGeneration",
    "Idefics3PreTrainedModel",
    "Idefics3Model",
    "Idefics3VisionTransformer",
]
|