| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302 |
- # Copyright 2022 The Salesforce Team Authors and The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch BLIP model."""
- from dataclasses import dataclass
- from typing import Any
- import torch
- from torch import nn
- from torch.nn.functional import normalize
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...generation import GenerationMixin
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutput,
- BaseModelOutputWithPooling,
- BaseModelOutputWithPoolingAndCrossAttentions,
- )
- from ...modeling_utils import PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
- from ...utils.generic import merge_with_config_defaults
- from ...utils.output_capturing import capture_outputs
- from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig
- from .modeling_blip_text import BlipTextLMHeadModel, BlipTextModel
- logger = logging.get_logger(__name__)
- # Copied from transformers.models.clip.modeling_clip.contrastive_loss
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    """Return the cross-entropy of `logits` against diagonal targets.

    Row `i` of `logits` is scored against class index `i`, i.e. the matching pair of every row is expected to sit
    on the diagonal of the similarity matrix.
    """
    diagonal_targets = torch.arange(len(logits), device=logits.device)
    return nn.functional.cross_entropy(logits, diagonal_targets)
- # Copied from transformers.models.clip.modeling_clip.clip_loss with clip->blip
def blip_loss(similarity: torch.Tensor) -> torch.Tensor:
    """Return the symmetric contrastive loss of a similarity matrix.

    The loss is the mean of `contrastive_loss` evaluated on `similarity` and on its transpose.
    """
    row_wise_loss = contrastive_loss(similarity)
    column_wise_loss = contrastive_loss(similarity.t())
    total = row_wise_loss + column_wise_loss
    return total / 2.0
@dataclass
@auto_docstring(
    custom_intro="""
    Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
    last hidden states. This class also adds the loss term from the text decoder.
    """
)
class BlipForConditionalGenerationModelOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss from the text decoder.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
        Prediction scores of the language modeling head of the text decoder model.
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*):
        The image embeddings obtained after applying the Vision Transformer model to the input image.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

        Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attention weights after the attention softmax, used to compute the weighted average in the self-attention
        heads.
    """

    # NOTE(review): `loss` and `logits` are documented above as plain tensors, so they are annotated as such here
    # rather than as one-element tuples.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    image_embeds: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
    last hidden states. This class also adds the loss term from the text decoder.
    """
)
class BlipTextVisionModelOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss from the text decoder.
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    """

    # Every field is optional and defaults to None.
    loss: torch.FloatTensor | None = None
    image_embeds: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
    last hidden states. This class also adds the loss term from the text decoder as well as the image-text similarity
    scores.
    """
)
class BlipImageTextMatchingModelOutput(ModelOutput):
    r"""
    itm_score (`torch.FloatTensor`):
        The image-text similarity scores.
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss from the text decoder.
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    vision_pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*):
        Last layer hidden-state of the vision-only branch of the model.
    question_embeds (`torch.FloatTensor`):
        The question embeddings obtained by the text projection layer.
    """

    itm_score: torch.FloatTensor | None = None
    loss: torch.FloatTensor | None = None
    image_embeds: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    vision_pooler_output: torch.FloatTensor | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    # NOTE(review): documented above as a plain `torch.FloatTensor`, so annotated as a tensor rather than a tuple.
    question_embeds: torch.FloatTensor | None = None
@dataclass
@auto_docstring
class BlipOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Contrastive loss for image-text similarity.
    logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
        The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
        similarity scores.
    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
        The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
        similarity scores.
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The text embeddings obtained by applying the projection layer to the pooled output of [`BlipTextModel`].
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The image embeddings obtained by applying the projection layer to the pooled output of [`BlipVisionModel`].
    text_model_output (`BaseModelOutputWithPooling`):
        The output of the [`BlipTextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`BlipVisionModel`].
    """

    loss: torch.FloatTensor | None = None
    logits_per_image: torch.FloatTensor | None = None
    logits_per_text: torch.FloatTensor | None = None
    text_embeds: torch.FloatTensor | None = None
    image_embeds: torch.FloatTensor | None = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> tuple[Any]:
        """Convert the output to a tuple, converting the two nested sub-model outputs to tuples as well."""
        nested_output_keys = ["text_model_output", "vision_model_output"]
        flattened = []
        for key in self.keys():
            if key in nested_output_keys:
                # Sub-model outputs are output objects themselves and are flattened via their own `to_tuple`.
                flattened.append(getattr(self, key).to_tuple())
            else:
                flattened.append(self[key])
        return tuple(flattened)
class BlipVisionEmbeddings(nn.Module):
    """
    Builds the vision token sequence: a learned class token followed by convolutional patch embeddings, with learned
    position embeddings added on top.
    """

    def __init__(self, config: BlipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Parameters are created in the order class token -> patch conv -> positions.
        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        self.patch_embedding = nn.Conv2d(
            3,
            self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        # One extra position for the class token.
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Resize the learned position embeddings to the patch grid implied by `height` x `width`, so the model can be
        used on resolutions other than the one it was created for. Also written to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        sequence_patches = embeddings.shape[1] - 1
        stored_patches = self.position_embedding.shape[1] - 1

        # While tracing we never take the shortcut, so the exported graph also works for dynamic input shapes.
        if not torch.jit.is_tracing():
            if sequence_patches == stored_patches and height == width:
                return self.position_embedding

        cls_positions = self.position_embedding[:, :1]
        grid_positions = self.position_embedding[:, 1:]
        channels = embeddings.shape[-1]

        target_grid = (height // self.patch_size, width // self.patch_size)
        stored_side = torch_int(stored_patches**0.5)

        # (1, side * side, channels) -> (1, channels, side, side), the layout `interpolate` expects.
        grid_positions = grid_positions.reshape(1, stored_side, stored_side, channels).permute(0, 3, 1, 2)
        grid_positions = nn.functional.interpolate(
            grid_positions, size=target_grid, mode="bicubic", align_corners=False
        )
        # Back to a flat (1, num_patches, channels) sequence.
        grid_positions = grid_positions.permute(0, 2, 3, 1).view(1, -1, channels)

        return torch.cat((cls_positions, grid_positions), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        """Embed `pixel_values` into a sequence of class + patch tokens with position information added."""
        batch_size, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype

        # Conv output is (batch, embed_dim, grid_h, grid_w); flatten the grid and move it to the sequence axis.
        patch_tokens = self.patch_embedding(pixel_values.to(dtype=target_dtype))
        patch_tokens = patch_tokens.flatten(2).transpose(1, 2)

        cls_tokens = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)

        positions = (
            self.interpolate_pos_encoding(tokens, height, width)
            if interpolate_pos_encoding
            else self.position_embedding
        )
        return tokens + positions[:, : tokens.size(1), :].to(target_dtype)
- # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Blip
class BlipTextEmbeddings(nn.Module):
    """Sums learned token embeddings and learned absolute position embeddings."""

    def __init__(self, config: BlipTextConfig):
        super().__init__()
        hidden_size = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, hidden_size)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, hidden_size)

        # `position_ids` has shape (1, max_position_embeddings); it is registered as a non-persistent buffer.
        default_position_ids = torch.arange(config.max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", default_position_ids, persistent=False)

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
    ) -> torch.Tensor:
        """
        Embed `input_ids` (or reuse `inputs_embeds`) and add position embeddings.

        Raises:
            ValueError: if the sequence is longer than the position embedding table.
        """
        if input_ids is not None:
            seq_length = input_ids.shape[-1]
        else:
            seq_length = inputs_embeds.shape[-2]

        position_capacity = self.position_embedding.weight.shape[0]
        if seq_length > position_capacity:
            raise ValueError(
                f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
                f"{seq_length} and max_position_embeddings: {position_capacity}"
            )

        # Default to the first `seq_length` positions when the caller supplies none.
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        return inputs_embeds + self.position_embedding(position_ids)
class BlipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        # The hidden size has to split evenly across heads.
        if self.embed_dim != self.num_heads * self.head_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale = self.head_dim**-0.5
        self.dropout = nn.Dropout(config.attention_dropout)

        # A single linear layer produces query, key and value at once.
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim)
        self.projection = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape `tensor` to (bsz, num_heads, seq_len, head_dim) and make it contiguous."""
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Input shape: Batch x Time x Channel"""
        bsz, tgt_len, embed_dim = hidden_states.size()

        # (bsz, tgt_len, 3 * embed_dim) -> (3, bsz, num_heads, tgt_len, head_dim)
        fused = self.qkv(hidden_states)
        fused = fused.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads)
        fused = fused.permute(2, 0, 3, 1, 4)
        query_states = fused[0]
        key_states = fused[1]
        value_states = fused[2]

        # Raw attention scores: scaled dot product between queries and keys.
        scores = torch.matmul(query_states, key_states.transpose(-1, -2))
        scores = scores * self.scale

        # Softmax over the key axis turns scores into probabilities.
        attention_probs = nn.functional.softmax(scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Weighted sum of values, then merge the heads back into one channel axis.
        context = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
        merged_shape = context.size()[:-2] + (self.embed_dim,)
        context = context.reshape(merged_shape)

        return self.projection(context), attention_probs
- # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Blip
class BlipMLP(nn.Module):
    """Two-layer feed-forward block: `fc1` -> activation -> `fc2`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # The activation is looked up by name from the config.
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(in_features=config.hidden_size, out_features=config.intermediate_size)
        self.fc2 = nn.Linear(in_features=config.intermediate_size, out_features=config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up to the intermediate size, apply the activation and project back to the hidden size."""
        expanded = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        return self.fc2(activated)
class BlipEncoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer layer: self-attention and an MLP, each wrapped in a residual connection."""

    def __init__(self, config: BlipConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = BlipAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = BlipMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    @auto_docstring
    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        # Attention sub-block: normalise, attend, add the skip connection. The attention weights are discarded.
        attention_input = self.layer_norm1(hidden_states)
        attention_output, _ = self.self_attn(hidden_states=attention_input, **kwargs)
        after_attention = attention_output + hidden_states

        # Feed-forward sub-block: normalise, transform, add the skip connection.
        mlp_output = self.mlp(self.layer_norm2(after_attention))
        return mlp_output + after_attention
@auto_docstring
class BlipPreTrainedModel(PreTrainedModel):
    config: BlipConfig
    base_model_prefix = "blip"
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True
    _no_split_modules = ["BlipEncoderLayer", "BlipTextEmbeddings"]
    _skip_keys_device_placement = ["past_key_values"]

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights"""
        super()._init_weights(module)
        default_std = self.config.initializer_range

        if isinstance(module, BlipVisionEmbeddings):
            # A composite config carries a dedicated vision config; use its range when present.
            if hasattr(self.config, "vision_config"):
                embedding_std = self.config.vision_config.initializer_range
            else:
                embedding_std = default_std
            for parameter in (module.position_embedding, module.class_embedding):
                init.trunc_normal_(parameter, mean=0.0, std=embedding_std)
            return

        if isinstance(module, BlipTextEmbeddings):
            # Refill the position-id buffer with 0..N-1.
            num_positions = module.position_ids.shape[-1]
            init.copy_(module.position_ids, torch.arange(num_positions).expand((1, -1)))
class BlipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`BlipEncoderLayer`].

    Args:
        config (`BlipConfig`):
            The corresponding vision configuration for the `BlipEncoder`.
    """

    def __init__(self, config: BlipConfig):
        super().__init__()
        self.config = config
        stacked_layers = []
        for _ in range(config.num_hidden_layers):
            stacked_layers.append(BlipEncoderLayer(config))
        self.layers = nn.ModuleList(stacked_layers)
        self.gradient_checkpointing = False

    @auto_docstring
    def forward(
        self,
        inputs_embeds,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutput:
        # Feed the embeddings through every layer in order; each layer's output is the next layer's input.
        current = inputs_embeds
        for layer in self.layers:
            current = layer(current, **kwargs)

        return BaseModelOutput(last_hidden_state=current)
class BlipVisionModel(BlipPreTrainedModel):
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    config: BlipVisionConfig
    _can_record_outputs = {
        "hidden_states": BlipEncoderLayer,
        "attentions": BlipAttention,
    }

    def __init__(self, config: BlipVisionConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = BlipVisionEmbeddings(config)
        self.encoder = BlipEncoder(config)
        self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.post_init()

    @merge_with_config_defaults
    @capture_outputs(tie_last_hidden_states=False)
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        interpolate_pos_encoding: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedded_pixels = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        encoder_outputs: BaseModelOutput = self.encoder(inputs_embeds=embedded_pixels, **kwargs)

        sequence_output = self.post_layernorm(encoder_outputs.last_hidden_state)

        # The pooled representation is the first (class) token.
        # NOTE(review): `post_layernorm` is applied to it a second time here; this mirrors the existing computation
        # and is intentionally left untouched.
        class_token = sequence_output[:, 0, :]
        class_token = self.post_layernorm(class_token)

        return BaseModelOutputWithPooling(last_hidden_state=sequence_output, pooler_output=class_token)

    def get_input_embeddings(self):
        """Return the vision embedding module."""
        return self.embeddings
@auto_docstring(
    custom_intro="""
    This model is going to be deprecated in future versions. Please use `BlipForConditionalGeneration`, `BlipForQuestionAnswering` or `BlipForImageTextRetrieval` depending on your usecase.
    """
)
class BlipModel(BlipPreTrainedModel):
    config: BlipConfig

    def __init__(self, config: BlipConfig):
        super().__init__(config)

        if not isinstance(config.text_config, BlipTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type BlipTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, BlipVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type BlipVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = BlipTextModel(text_config)
        self.vision_model = BlipVisionModel(vision_config)

        # Bias-free linear maps that bring both modalities into the shared `projection_dim` space.
        self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
        # Stored in log space; `forward` exponentiates it before scaling the similarity logits.
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        logger.warning(
            "`BlipModel` is going to be deprecated in future release, please use `BlipForConditionalGeneration`, `BlipForQuestionAnswering` or `BlipForImageTextRetrieval` depending on your usecase."
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the input embeddings of the text model."""
        return self.text_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the input embeddings of the text model with `value`."""
        self.text_model.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring
    def get_text_features(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoProcessor, BlipModel

        >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

        >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
        >>> text_features = model.get_text_features(**inputs)
        ```"""
        text_outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
            **kwargs,
        )
        # The text model's output object is returned with its pooled output replaced by the projected version.
        pooled_output = text_outputs.pooler_output
        text_outputs.pooler_output = self.text_projection(pooled_output)

        return text_outputs

    @can_return_tuple
    @auto_docstring
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor | None = None,
        interpolate_pos_encoding: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        Examples:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoProcessor, BlipModel

        >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> image_features = model.get_image_features(**inputs)
        ```"""
        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=True,
            **kwargs,
        )
        # The vision model's output object is returned with its pooled output replaced by the projected version.
        pooled_output = vision_outputs.pooler_output
        vision_outputs.pooler_output = self.visual_projection(pooled_output)

        return vision_outputs

    @auto_docstring
    def get_multimodal_features(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            multimodal_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The multimodal embeddings
            obtained by applying the image embeddings to the text encoder using the cross-attention mechanism.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoProcessor, BlipModel

        >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))
        >>> texts = ["a photo of a cat", "a photo of a dog"]
        >>> inputs = processor(images=image, text=texts, padding=True, return_tensors="pt")

        >>> multimodal_features = model.get_multimodal_features(**inputs)
        ```"""
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        image_embeds = vision_outputs[0]
        # All image tokens are attended to. The mask is created on the same device as the image embeddings;
        # without an explicit device it would be allocated on the default device and mismatch on accelerators.
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
        )

        pooled_output = text_outputs[1]  # pooled_output
        multimodal_features = self.text_projection(pooled_output)

        return multimodal_features

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        return_loss: bool | None = None,
        interpolate_pos_encoding: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BlipOutput:
        r"""
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoProcessor, BlipModel

        >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> inputs = processor(
        ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
        ... )

        >>> outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```"""
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            **kwargs,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )

        image_embeds = vision_outputs.pooler_output
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs.pooler_output
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits; everything is moved to the text embeddings' device/dtype first
        logit_scale = self.logit_scale.exp().to(device=text_embeds.device)
        image_embeds = image_embeds.to(device=text_embeds.device, dtype=text_embeds.dtype)
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            loss = blip_loss(logits_per_text)

        return BlipOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
@auto_docstring(
    custom_intro="""
    BLIP Model for image captioning. The model consists of a vision encoder and a text decoder. One can optionally pass
    `input_ids` to the model, which serve as a text prompt, to make the text decoder continue the prompt. If no text
    input is provided, the decoder starts generating the caption from the [BOS] (beginning-of-sequence) token only.
    """
)
class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin):
    config: BlipConfig
    main_input_name = "pixel_values"

    _tied_weights_keys = {
        "text_decoder.cls.predictions.decoder.bias": "text_decoder.cls.predictions.bias",
        "text_decoder.cls.predictions.decoder.weight": "text_decoder.bert.embeddings.word_embeddings.weight",
    }  # TODO @arthurzucker check why we need this when for other models, their subPreTrainedModel handle it themselves.

    def __init__(self, config: BlipConfig):
        super().__init__(config)

        # Vision encoder produces the image features the text decoder cross-attends to.
        self.vision_model = BlipVisionModel(config.vision_config)
        self.text_decoder = BlipTextLMHeadModel(config.text_config)

        # Token ids used by `generate` to build a default prompt when none is given.
        self.decoder_input_ids = config.text_config.bos_token_id
        self.decoder_pad_token_id = config.text_config.pad_token_id

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the input embedding module of the text decoder."""
        return self.text_decoder.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the input embedding module of the text decoder with `value`."""
        self.text_decoder.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.LongTensor | None = None,
        labels: torch.LongTensor | None = None,
        interpolate_pos_encoding: bool = False,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BlipForConditionalGenerationModelOutput:
        r"""
        Examples:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoProcessor, BlipForConditionalGeneration

        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        >>> model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))
        >>> text = "A picture of"

        >>> inputs = processor(images=image, text=text, return_tensors="pt")

        >>> outputs = model(**inputs)
        ```"""
        # Encode the image; the same extra kwargs are forwarded to both sub-models.
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            **kwargs,
        )
        image_embeds = vision_outputs.last_hidden_state

        # The decoder cross-attends to the image features and computes the loss
        # itself (mean-reduced) when `labels` are provided.
        decoder_outputs = self.text_decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=image_embeds,
            labels=labels,
            reduction="mean",
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        # Only loss/logits come from the decoder; hidden states and attentions
        # reported here are those of the vision encoder.
        return BlipForConditionalGenerationModelOutput(
            loss=decoder_outputs.loss,
            logits=decoder_outputs.logits,
            image_embeds=image_embeds,
            last_hidden_state=vision_outputs.last_hidden_state,
            hidden_states=vision_outputs.hidden_states,
            attentions=vision_outputs.attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.LongTensor | None = None,
        interpolate_pos_encoding: bool = False,
        **generate_kwargs,
    ) -> torch.LongTensor:
        r"""
        Overrides *generate* function to be able to use the model as a conditional generator

        Parameters:
            pixel_values (*torch.FloatTensor* of shape *(batch_size, num_channels, image_height, image_width)*):
                Input image to be processed
            input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
                The sequence used as a prompt for the generation. A nested list of ids is also accepted. The first
                token is replaced by the BOS token and the last token is dropped before decoding. The caller's
                tensor is not modified.
            attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. Its
                last column is dropped to stay aligned with the truncated `input_ids`.
            interpolate_pos_encoding (*bool*, *optional*, defaults to `False`):
                Forwarded unchanged to the vision model.
            **generate_kwargs:
                Additional arguments passed to the *generate* function of the decoder

        Examples:
        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoProcessor, BlipForConditionalGeneration

        >>> model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model.generate(**inputs)
        >>> print(processor.decode(outputs[0], skip_special_tokens=True))
        two cats sleeping on a couch
        ```
        """
        batch_size = pixel_values.shape[0]
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        image_embeds = vision_outputs[0]

        # Every image position is attended to: the mask is all ones.
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        if input_ids is None:
            # No prompt: build a two-token [BOS, EOS] row per image; the trailing
            # token is dropped below, leaving only BOS as the decoder prompt.
            input_ids = (
                torch.LongTensor([[self.decoder_input_ids, self.config.text_config.eos_token_id]])
                .repeat(batch_size, 1)
                .to(image_embeds.device)
            )
        elif isinstance(input_ids, list):
            # Keep list-built prompts on the same device as the image features,
            # matching what the `None` branch does.
            input_ids = torch.LongTensor(input_ids).to(image_embeds.device)
        else:
            # Work on a copy so the in-place BOS overwrite below does not leak
            # into the tensor owned by the caller.
            input_ids = input_ids.clone()

        input_ids[:, 0] = self.config.text_config.bos_token_id

        # Drop the last prompt position (presumably the separator appended by the
        # tokenizer -- confirm) from both the ids and the mask.
        if attention_mask is not None:
            attention_mask = attention_mask[:, :-1]

        outputs = self.text_decoder.generate(
            input_ids=input_ids[:, :-1],
            eos_token_id=self.config.text_config.sep_token_id,
            pad_token_id=self.config.text_config.pad_token_id,
            attention_mask=attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            **generate_kwargs,
        )

        return outputs
@auto_docstring(
    custom_intro="""
    BLIP Model for visual question answering. The model consists of a vision encoder, a text encoder as well as a text
    decoder. The vision encoder will encode the input image, the text encoder will encode the input question together
    with the encoding of the image, and the text decoder will output the answer to the question.
    """
)
class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin):
    config: BlipConfig

    _tied_weights_keys = {
        "text_decoder.cls.predictions.decoder.bias": "text_decoder.cls.predictions.bias",
        "text_decoder.cls.predictions.decoder.weight": "text_decoder.bert.embeddings.word_embeddings.weight",
    }

    def __init__(self, config: BlipConfig):
        super().__init__(config)

        self.vision_model = BlipVisionModel(config.vision_config)
        # The question encoder is built without a pooling layer; only its
        # per-token hidden states are consumed below.
        self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)
        self.text_decoder = BlipTextLMHeadModel(config.text_config)

        self.decoder_pad_token_id = config.text_config.pad_token_id
        self.decoder_start_token_id = config.text_config.bos_token_id

        # Initialize weights and apply final processing
        self.post_init()

    def set_input_embeddings(self, value):
        """Replace the input embedding module of the text encoder with `value`."""
        self.text_encoder.set_input_embeddings(value)

    def get_input_embeddings(self):
        """Return the input embedding module of the text encoder."""
        # This will return shared embeddings if they are shared else specific to encoder.
        return self.text_encoder.get_input_embeddings()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.LongTensor | None = None,
        attention_mask: torch.LongTensor | None = None,
        labels: torch.LongTensor | None = None,
        interpolate_pos_encoding: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BlipTextVisionModelOutput:
        r"""
        Examples:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoProcessor, BlipForQuestionAnswering

        >>> model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> # training
        >>> text = "How many cats are in the picture?"
        >>> label = "2"
        >>> inputs = processor(images=image, text=text, return_tensors="pt")
        >>> labels = processor(text=label, return_tensors="pt").input_ids

        >>> inputs["labels"] = labels
        >>> outputs = model(**inputs)
        >>> loss = outputs.loss
        >>> loss.backward()

        >>> # inference
        >>> text = "How many cats are in the picture?"
        >>> inputs = processor(images=image, text=text, return_tensors="pt")
        >>> outputs = model.generate(**inputs)
        >>> print(processor.decode(outputs[0], skip_special_tokens=True))
        2
        ```"""
        # The decoder needs something to decode from: raise ValueError when the
        # caller supplied neither `labels` nor `decoder_input_ids`.
        if labels is None and decoder_input_ids is None:
            raise ValueError(
                "Either `decoder_input_ids` or `labels` should be passed when calling `forward` with"
                " `BlipForQuestionAnswering`. if you are training the model make sure that `labels` is passed, if you"
                " are using the model for inference make sure that `decoder_input_ids` is passed or call `generate`"
            )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            **kwargs,
        )
        image_embeds = vision_outputs.last_hidden_state

        # All-ones mask over the image positions, created directly on the device
        # of the image features (as `generate` does) instead of on the CPU.
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        # Encode the question while cross-attending to the image features.
        question_outputs = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            **kwargs,
        )
        question_embeds = question_outputs[0]

        if labels is not None and decoder_input_ids is None:
            # labels are already shifted right, see: https://github.com/huggingface/transformers/pull/23153
            decoder_input_ids = labels

        # Decode the answer conditioned on the question states; the question's
        # own attention mask is used as the cross-attention mask.
        answer_output = self.text_decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=question_embeds,
            encoder_attention_mask=attention_mask,
            labels=labels,
            reduction="mean",
            **kwargs,
        )

        decoder_loss = answer_output.loss.mean() if labels is not None else None

        # Hidden states and attentions reported here are those of the vision encoder.
        return BlipTextVisionModelOutput(
            loss=decoder_loss,
            image_embeds=image_embeds,
            last_hidden_state=vision_outputs.last_hidden_state,
            hidden_states=vision_outputs.hidden_states,
            attentions=vision_outputs.attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        attention_mask: torch.LongTensor | None = None,
        interpolate_pos_encoding: bool = False,
        **generate_kwargs,
    ) -> torch.LongTensor:
        r"""
        Overrides *generate* function to be able to use the model as a conditional generator

        Parameters:
            input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*):
                The sequence used as a prompt for the generation. A nested list of ids is also accepted.
            pixel_values (*torch.FloatTensor* of shape *(batch_size, num_channels, image_height, image_width)*):
                Input image to be processed
            attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. `1` for
                tokens that are NOT MASKED, `0` for MASKED tokens.
            interpolate_pos_encoding (*bool*, *optional*, defaults to `False`):
                Forwarded unchanged to the vision model.
            **generate_kwargs:
                Additional arguments passed to the *generate* function of the decoder

        Examples:
        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoProcessor, BlipForQuestionAnswering

        >>> model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))
        >>> text = "How many cats are in the picture?"

        >>> inputs = processor(images=image, text=text, return_tensors="pt")

        >>> outputs = model.generate(**inputs)
        >>> print(processor.decode(outputs[0], skip_special_tokens=True))
        2
        ```
        """
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        image_embeds = vision_outputs[0]
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        if isinstance(input_ids, list):
            # Keep list-built prompts on the same device as the image features.
            input_ids = torch.LongTensor(input_ids).to(image_embeds.device)

        question_outputs = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=False,
        )
        question_embeds = question_outputs[0]

        # The decoder attends to every question position (all-ones mask); the
        # caller's `attention_mask` is only applied inside the text encoder above.
        question_attention_mask = torch.ones(
            question_embeds.size()[:-1], dtype=torch.long, device=question_embeds.device
        )

        # One start token per batch row seeds the answer decoder.
        bos_ids = torch.full(
            (question_embeds.size(0), 1), fill_value=self.decoder_start_token_id, device=question_embeds.device
        )

        outputs = self.text_decoder.generate(
            input_ids=bos_ids,
            eos_token_id=self.config.text_config.sep_token_id,
            pad_token_id=self.config.text_config.pad_token_id,
            encoder_hidden_states=question_embeds,
            encoder_attention_mask=question_attention_mask,
            **generate_kwargs,
        )

        return outputs
@auto_docstring(
    custom_intro="""
    BLIP Model with a vision and text projector, and a classification head on top. The model is used in the context of
    image-text retrieval. Given an image and a text, the model returns the probability of the text being relevant to
    the image.
    """
)
class BlipForImageTextRetrieval(BlipPreTrainedModel):
    config: BlipConfig

    def __init__(self, config: BlipConfig):
        super().__init__(config)

        self.vision_model = BlipVisionModel(config.vision_config)
        self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)

        # vision projection layer
        self.vision_proj = nn.Linear(config.vision_config.hidden_size, config.image_text_hidden_size)

        # text projection layer
        self.text_proj = nn.Linear(config.text_config.hidden_size, config.image_text_hidden_size)

        # image text matching head
        self.itm_head = nn.Linear(config.text_config.hidden_size, 2)

        # A value set on the top-level config takes precedence over the one
        # from the text sub-config.
        self.decoder_pad_token_id = (
            config.text_config.pad_token_id
            if not hasattr(config, "decoder_pad_token_id")
            else config.decoder_pad_token_id
        )
        self.decoder_start_token_id = (
            config.text_config.bos_token_id
            if not hasattr(config, "decoder_start_token_id")
            else config.decoder_start_token_id
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the input embedding module of the text encoder."""
        return self.text_encoder.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the input embedding module of the text encoder with `value`."""
        self.text_encoder.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        use_itm_head: bool | None = True,
        attention_mask: torch.LongTensor | None = None,
        interpolate_pos_encoding: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BlipImageTextMatchingModelOutput:
        r"""
        use_itm_head (`bool`, *optional*, defaults to `True`):
            Whether or not to use the image-text matching head.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoProcessor, BlipForImageTextRetrieval

        >>> model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))
        >>> text = "an image of a cat"

        >>> inputs = processor(images=image, text=text, return_tensors="pt")
        >>> outputs = model(**inputs)
        ```
        """
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            **kwargs,
        )
        image_embeds = vision_outputs.last_hidden_state

        if use_itm_head:
            # Matching head: the text encoder cross-attends to the image, and the
            # first text position is classified into two logits.
            # The all-ones image mask is created directly on the image device.
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
            question_embeds = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                **kwargs,
            ).last_hidden_state

            output = self.itm_head(question_embeds[:, 0, :])
        else:
            # Contrastive path: text is encoded without looking at the image, then
            # the first position of each modality is projected, normalized and
            # compared pairwise (rows index images, columns index texts).
            question_embeds = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                **kwargs,
            ).last_hidden_state

            image_feat = normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
            text_feat = normalize(self.text_proj(question_embeds[:, 0, :]), dim=-1)

            output = image_feat @ text_feat.t()

        # Hidden states and attentions reported here are those of the vision encoder.
        return BlipImageTextMatchingModelOutput(
            itm_score=output,
            last_hidden_state=vision_outputs.last_hidden_state,
            hidden_states=vision_outputs.hidden_states,
            attentions=vision_outputs.attentions,
            question_embeds=question_embeds,
        )
# Names exported by this module (what `from <module> import *` exposes).
__all__ = [
    "BlipModel",
    "BlipPreTrainedModel",
    "BlipForConditionalGeneration",
    "BlipForQuestionAnswering",
    "BlipVisionModel",
    "BlipTextModel",
    "BlipForImageTextRetrieval",
]
|