| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396 |
- # Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""
- import math
- from collections.abc import Callable
- from dataclasses import dataclass
- import torch
- from torch import nn
- from ...activations import ACT2FN
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
- from ...processing_utils import Unpack
- from ...utils import (
- ModelOutput,
- TransformersKwargs,
- logging,
- )
- from .configuration_idefics import IdeficsVisionConfig
# Module-level logger; IdeficsVisionEmbeddings.interpolate_pos_encoding uses it for a one-time bf16 upcast warning.
logger = logging.get_logger(__name__)
@dataclass
class IdeficsVisionModelOutput(ModelOutput):
    """
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.

    Every field defaults to `None`.

    NOTE(review): `IdeficsVisionTransformer` in this module returns `BaseModelOutputWithPooling`, not this class —
    confirm where (or whether) this output type is produced.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    image_embeds: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
# Adapted from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings
class IdeficsVisionEmbeddings(nn.Module):
    """
    Convert `pixel_values` into the token sequence consumed by the vision encoder.

    The sequence is one learned class token followed by one token per non-overlapping
    `patch_size` x `patch_size` patch, with learned absolute position embeddings added on top.
    The position table is sized for a square `image_size` x `image_size` input and can be
    bicubically resized for other resolutions (see `interpolate_pos_encoding`).
    """

    def __init__(self, config: IdeficsVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Creation order (class token, patch conv, position table) is kept fixed: it determines which random
        # draws each parameter receives at initialisation and the order of the state dict.
        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
        # kernel_size == stride, so every output location covers exactly one patch.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        self.num_positions = self.num_patches + 1  # one extra slot for the class token
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        # persistent=False: the index buffer is rebuilt at construction and excluded from the state dict.
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    # Heavily inspired from https://github.com/huggingface/transformers/blob/v4.33.0/src/transformers/models/vit/modeling_vit.py#L82
    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Resize the learned position embeddings so the model can run on a `height` x `width` input.

        The class-token position is kept unchanged; the square grid of patch positions is bicubically
        interpolated to `(height // patch_size, width // patch_size)`.

        Args:
            embeddings: token sequence (class token + patch tokens); only its length and last dimension are read.
            height: input image height in pixels.
            width: input image width in pixels.

        Returns:
            Position embeddings of shape `(1, 1 + number_of_patches, embed_dim)`, ready to be added to `embeddings`.

        Raises:
            ValueError: if the interpolated grid does not come out at the expected number of patches.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        table = self.position_embedding(self.position_ids)
        trained_patch_count = table.shape[1] - 1
        input_patch_count = embeddings.shape[1] - 1
        # Fast path: same number of patches on a square input means the table already fits.
        if input_patch_count == trained_patch_count and height == width:
            return table

        hidden_dim = embeddings.shape[-1]
        cls_position = table[:, 0]
        grid_positions = table[:, 1:]

        # Offset by a small amount to avoid floating point error in the interpolation
        # (see discussion at https://github.com/facebookresearch/dino/issues/8).
        target_rows = height // self.config.patch_size + 0.1
        target_cols = width // self.config.patch_size + 0.1

        # The trained grid is square: sqrt(num_positions - 1) patches per side.
        trained_side = math.sqrt(trained_patch_count)
        side = int(trained_side)
        # (1, side * side, dim) -> (1, dim, side, side) so interpolate works on the spatial axes.
        grid_positions = grid_positions.reshape(1, side, side, hidden_dim).permute(0, 3, 1, 2)

        upcast = grid_positions.dtype == torch.bfloat16
        if upcast:
            logger.warning_once(
                "Upcasting patch_pos_embed to fp32 for interpolation since `upsample_bicubic2d_out_frame` in nn.functional.interpolate "
                "is not implemented for 'torch.bfloat16' dtype. This will result in a slight overhead."
            )
            grid_positions = grid_positions.to(torch.float)
        grid_positions = nn.functional.interpolate(
            grid_positions,
            scale_factor=(target_rows / trained_side, target_cols / trained_side),
            mode="bicubic",
            align_corners=False,
        )
        if upcast:
            grid_positions = grid_positions.to(torch.bfloat16)

        want_rows, want_cols = int(target_rows), int(target_cols)
        got_rows, got_cols = grid_positions.shape[-2], grid_positions.shape[-1]
        if want_rows != got_rows or want_cols != got_cols:
            raise ValueError(
                f"Number of patches for images ({want_rows, want_cols}) don't match the "
                f"shape of position embedding ({got_rows, got_cols})"
            )

        # (1, dim, rows, cols) -> (1, rows * cols, dim), then put the class-token position back in front.
        grid_positions = grid_positions.permute(0, 2, 3, 1).view(1, -1, hidden_dim)
        return torch.cat((cls_position.unsqueeze(0), grid_positions), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        """
        Embed a batch of images laid out as `(batch_size, num_channels, height, width)`.

        Without `interpolate_pos_encoding`, the input must be exactly `image_size` x `image_size`.
        Returns a tensor of shape `(batch_size, 1 + number_of_patches, embed_dim)`.
        """
        batch_size, _, height, width = pixel_values.shape
        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model"
                f" ({self.image_size}*{self.image_size}). You should try to set `interpolate_pos_encoding=True`"
            )

        # Match the conv weight dtype, then (batch, embed_dim, grid_h, grid_w) -> (batch, grid_h * grid_w, embed_dim).
        weight_dtype = self.patch_embedding.weight.dtype
        patch_tokens = self.patch_embedding(pixel_values.to(dtype=weight_dtype)).flatten(2).transpose(1, 2)
        cls_tokens = self.class_embedding.expand(batch_size, 1, -1)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)

        if interpolate_pos_encoding:
            position_embeds = self.interpolate_pos_encoding(tokens, height, width)
        else:
            position_embeds = self.position_embedding(self.position_ids)
        return tokens + position_embeds
- # Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
- def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: torch.Tensor | None,
- scaling: float,
- dropout: float = 0.0,
- **kwargs,
- ):
- attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
- if attention_mask is not None:
- attn_weights = attn_weights + attention_mask
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value)
- attn_output = attn_output.transpose(1, 2).contiguous()
- return attn_output, attn_weights
class IdeficsVisionAttention(nn.Module):
    """
    Multi-headed self-attention (as in 'Attention Is All You Need') over the vision token sequence.

    The attention kernel is looked up in `ALL_ATTENTION_FUNCTIONS` by `config._attn_implementation`,
    falling back to `eager_attention_forward`.
    """

    def __init__(self, config: IdeficsVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.embed_dim % self.num_heads != 0:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5  # 1 / sqrt(head_dim)
        self.dropout = config.attention_dropout
        self.is_causal = False  # vision tokens attend bidirectionally

        # Registration order (k, v, q, out) is preserved: it fixes init RNG draws and state-dict order.
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Attend over `hidden_states` (Batch x Time x Channel).

        Returns `(attn_output, attn_weights)`; `attn_output` has the same shape as `hidden_states`, and
        `attn_weights` is whatever the selected attention kernel returns.
        """
        leading_dims = hidden_states.shape[:-1]
        per_head_shape = (*leading_dims, -1, self.head_dim)

        # Project, split the channel axis into heads, and move heads in front of the sequence axis.
        queries, keys, values = [
            projection(hidden_states).view(per_head_shape).transpose(1, 2)
            for projection in (self.q_proj, self.k_proj, self.v_proj)
        ]

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        active_dropout = self.dropout if self.training else 0.0
        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=active_dropout,
            **kwargs,
        )

        # Merge the heads back into the channel axis before the output projection.
        merged = attn_output.reshape(*leading_dims, -1).contiguous()
        return self.out_proj(merged), attn_weights
# Adapted from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->IdeficsVision
class IdeficsVisionMLP(nn.Module):
    """Position-wise feed-forward block: `fc1` -> `config.hidden_act` -> `fc2` (hidden -> intermediate -> hidden)."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        # fc1 is created before fc2 so parameter init order is unchanged.
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the two projections with the activation in between; the last dimension is preserved."""
        expanded = self.activation_fn(self.fc1(hidden_states))
        return self.fc2(expanded)
# Adapted from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->IdeficsVision
class IdeficsVisionEncoderLayer(GradientCheckpointingLayer):
    """Pre-LayerNorm transformer block: `x = x + self_attn(layer_norm1(x))`, then `x = x + mlp(layer_norm2(x))`."""

    def __init__(self, config: IdeficsVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = IdeficsVisionAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = IdeficsVisionMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Run one encoder block.

        Args:
            hidden_states (`torch.Tensor`): input sequence whose last dimension is `config.hidden_size`
                (it is normalised by `nn.LayerNorm(hidden_size)`).
            attention_mask (`torch.Tensor`, *optional*): forwarded unchanged to the self-attention.
                The encoder passes `None` when no mask is given, hence the optional annotation/default.
            **kwargs: forwarded to the self-attention (and on to the attention kernel).

        Returns:
            `torch.Tensor` with the same shape as `hidden_states`. Only the hidden states are returned;
            the attention weights produced by `self_attn` are discarded.
        """
        # Self-attention sub-block with residual connection.
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Feed-forward sub-block with residual connection.
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
# Adapted from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->IdeficsVision
class IdeficsVisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`IdeficsVisionEncoderLayer`].

    Args:
        config: IdeficsVisionConfig
    """

    def __init__(self, config: IdeficsVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([IdeficsVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        # NOTE(review): not read in this class; presumably toggled by external gradient-checkpointing helpers.
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        r"""
        Run the embedded sequence through every encoder layer, in order.

        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Already-embedded input sequence (for this model, the image token embeddings produced by
                `IdeficsVisionEmbeddings` after the pre-encoder LayerNorm).
            attention_mask (`torch.Tensor`, *optional*):
                Forwarded unchanged to every layer's self-attention. The eager attention path in this module
                adds it directly to the raw attention scores, so it must be an *additive* mask (0 where attention
                is allowed, a large negative value such as `-inf` where it is not), broadcastable to
                `(batch_size, num_heads, sequence_length, sequence_length)`. A plain 0/1 padding mask would
                not actually mask anything on that path.
                NOTE(review): confirm the expected mask format for the non-eager attention backends.
            **kwargs: forwarded to every layer.

        Returns:
            [`BaseModelOutput`] whose `last_hidden_state` is the output of the last layer, with the same shape as
            `inputs_embeds`.
        """
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
                **kwargs,
            )

        return BaseModelOutput(
            last_hidden_state=hidden_states,
        )
# Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer
class IdeficsVisionTransformer(nn.Module):
    """
    Vision tower: patch + position embeddings, a LayerNorm before the encoder, the encoder stack, and a
    LayerNorm on the first (class) token to produce the pooled output.
    """

    def __init__(self, config: IdeficsVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = IdeficsVisionEmbeddings(config)
        # Keep the "layrnorm" spelling: the attribute name is part of the parameter (state-dict) keys, so renaming
        # it would change those keys and presumably break loading existing checkpoints.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = IdeficsVisionEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    # Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        interpolate_pos_encoding: bool | None = False,
        **kwargs,
    ) -> BaseModelOutputWithPooling:
        r"""
        Encode a batch of images.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Input images. Required despite the `None` default.
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                If true, resize the position embeddings to the input resolution instead of requiring
                `height == width == config.image_size`. `None` behaves like `False`.
            **kwargs:
                Forwarded to the encoder (and on to the attention layers).

        Returns:
            [`BaseModelOutputWithPooling`] with:
                - `last_hidden_state`: encoder output of shape `(batch_size, 1 + number_of_patches, hidden_size)`;
                  position 0 is the class token.
                - `pooler_output`: `post_layernorm` applied to the class token, shape `(batch_size, hidden_size)`.

        Raises:
            ValueError: if `pixel_values` is `None`.
        """
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs: BaseModelOutput = self.encoder(
            inputs_embeds=hidden_states,
            **kwargs,
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        # Pool by taking the class token (position 0) and normalising it.
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
        )
|