modeling_blip.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302
  1. # Copyright 2022 The Salesforce Team Authors and The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch BLIP model."""
  15. from dataclasses import dataclass
  16. from typing import Any
  17. import torch
  18. from torch import nn
  19. from torch.nn.functional import normalize
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...generation import GenerationMixin
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import (
  25. BaseModelOutput,
  26. BaseModelOutputWithPooling,
  27. BaseModelOutputWithPoolingAndCrossAttentions,
  28. )
  29. from ...modeling_utils import PreTrainedModel
  30. from ...processing_utils import Unpack
  31. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
  32. from ...utils.generic import merge_with_config_defaults
  33. from ...utils.output_capturing import capture_outputs
  34. from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig
  35. from .modeling_blip_text import BlipTextLMHeadModel, BlipTextModel
  36. logger = logging.get_logger(__name__)
  37. # Copied from transformers.models.clip.modeling_clip.contrastive_loss
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    """Return the cross-entropy of `logits` against the targets `0, 1, ..., len(logits) - 1`.

    Row `i` of `logits` is scored against class index `i`, i.e. the diagonal entries are
    treated as the correct matches.
    """
    # The target indices are created on the same device as `logits`.
    return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
  40. # Copied from transformers.models.clip.modeling_clip.clip_loss with clip->blip
def blip_loss(similarity: torch.Tensor) -> torch.Tensor:
    """Return the mean of the contrastive loss of `similarity` and of its transpose."""
    caption_loss = contrastive_loss(similarity)
    # Transposing swaps rows and columns, so the same loss is taken in the other direction.
    image_loss = contrastive_loss(similarity.t())
    return (caption_loss + image_loss) / 2.0
  45. @dataclass
  46. @auto_docstring(
  47. custom_intro="""
  48. Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
  49. last hidden states. This class also adds the loss term from the text decoder.
  50. """
  51. )
  52. class BlipForConditionalGenerationModelOutput(ModelOutput):
  53. r"""
  54. loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
  55. Language modeling loss from the text decoder.
  56. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
  57. Prediction scores of the language modeling head of the text decoder model.
  58. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*):
  59. The image embeddings obtained after applying the Vision Transformer model to the input image.
  60. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
  61. Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
  62. one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
  63. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
  64. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed):
  65. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  66. sequence_length)`.
  67. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  68. heads.
  69. """
  70. loss: tuple[torch.FloatTensor] | None = None
  71. logits: tuple[torch.FloatTensor] | None = None
  72. image_embeds: torch.FloatTensor | None = None
  73. last_hidden_state: torch.FloatTensor | None = None
  74. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  75. attentions: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
    last hidden states. This class also adds the loss term from the text decoder.
    """
)
class BlipTextVisionModelOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss from the text decoder.
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    """

    # Every field defaults to None. `last_hidden_state`, `hidden_states` and `attentions`
    # have no entry in the docstring above.
    loss: torch.FloatTensor | None = None
    image_embeds: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
  95. @dataclass
  96. @auto_docstring(
  97. custom_intro="""
  98. Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
  99. last hidden states. This class also adds the loss term from the text decoder as well as the image-text similarity
  100. scores.
  101. """
  102. )
  103. class BlipImageTextMatchingModelOutput(ModelOutput):
  104. r"""
  105. itm_score (`torch.FloatTensor`):
  106. The image-text similarity scores.
  107. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  108. Language modeling loss from the text decoder.
  109. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
  110. The image embeddings obtained by applying the projection layer to the pooler_output.
  111. vision_pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*):
  112. Last layer hidden-state of the vision of the vision-only branch of the model.
  113. question_embeds (`torch.FloatTensor`):
  114. The question embeddings obtained by the text projection layer.
  115. """
  116. itm_score: torch.FloatTensor | None = None
  117. loss: torch.FloatTensor | None = None
  118. image_embeds: torch.FloatTensor | None = None
  119. last_hidden_state: torch.FloatTensor | None = None
  120. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  121. vision_pooler_output: torch.FloatTensor | None = None
  122. attentions: tuple[torch.FloatTensor, ...] | None = None
  123. question_embeds: tuple[torch.FloatTensor] | None = None
  124. @dataclass
  125. @auto_docstring
  126. class BlipOutput(ModelOutput):
  127. r"""
  128. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  129. Contrastive loss for image-text similarity.
  130. logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
  131. The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
  132. similarity scores.
  133. logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
  134. The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
  135. similarity scores.
  136. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  137. The text embeddings obtained by applying the projection layer to the pooled output of [`BlipTextModel`].
  138. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  139. The image embeddings obtained by applying the projection layer to the pooled output of [`BlipVisionModel`].
  140. text_model_output (`BaseModelOutputWithPooling`):
  141. The output of the [`BlipTextModel`].
  142. vision_model_output (`BaseModelOutputWithPooling`):
  143. The output of the [`BlipVisionModel`].
  144. """
  145. loss: torch.FloatTensor | None = None
  146. logits_per_image: torch.FloatTensor | None = None
  147. logits_per_text: torch.FloatTensor | None = None
  148. text_embeds: torch.FloatTensor | None = None
  149. image_embeds: torch.FloatTensor | None = None
  150. text_model_output: BaseModelOutputWithPooling = None
  151. vision_model_output: BaseModelOutputWithPooling = None
  152. def to_tuple(self) -> tuple[Any]:
  153. return tuple(
  154. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  155. for k in self.keys()
  156. )
  157. class BlipVisionEmbeddings(nn.Module):
  158. def __init__(self, config: BlipVisionConfig):
  159. super().__init__()
  160. self.config = config
  161. self.embed_dim = config.hidden_size
  162. self.image_size = config.image_size
  163. self.patch_size = config.patch_size
  164. self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
  165. self.patch_embedding = nn.Conv2d(
  166. in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
  167. )
  168. self.num_patches = (self.image_size // self.patch_size) ** 2
  169. self.num_positions = self.num_patches + 1
  170. self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
  171. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  172. """
  173. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  174. images. This method is also adapted to support torch.jit tracing.
  175. Adapted from:
  176. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  177. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  178. """
  179. num_patches = embeddings.shape[1] - 1
  180. num_positions = self.position_embedding.shape[1] - 1
  181. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  182. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  183. return self.position_embedding
  184. class_pos_embed = self.position_embedding[:, :1]
  185. patch_pos_embed = self.position_embedding[:, 1:]
  186. dim = embeddings.shape[-1]
  187. new_height = height // self.patch_size
  188. new_width = width // self.patch_size
  189. sqrt_num_positions = torch_int(num_positions**0.5)
  190. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  191. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  192. patch_pos_embed = nn.functional.interpolate(
  193. patch_pos_embed,
  194. size=(new_height, new_width),
  195. mode="bicubic",
  196. align_corners=False,
  197. )
  198. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  199. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  200. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  201. batch_size, _, height, width = pixel_values.shape
  202. target_dtype = self.patch_embedding.weight.dtype
  203. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  204. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  205. class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
  206. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  207. if interpolate_pos_encoding:
  208. position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
  209. else:
  210. position_embedding = self.position_embedding
  211. embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype)
  212. return embeddings
  213. # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Blip
class BlipTextEmbeddings(nn.Module):
    """Sum of token embeddings and learned position embeddings."""

    def __init__(self, config: BlipTextConfig):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

        # position_ids has shape (1, max_position_embeddings). It is registered with
        # persistent=False, so it is not included in the module's state_dict.
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
    ) -> torch.Tensor:
        """Return `inputs_embeds` (or the embedded `input_ids`) plus position embeddings.

        Args:
            input_ids: Token ids; looked up in `token_embedding` when `inputs_embeds` is None.
            position_ids: Position indices; defaults to the first `seq_length` entries of the
                `position_ids` buffer.
            inputs_embeds: Pre-computed token embeddings, used instead of `input_ids` when given.

        Raises:
            ValueError: If the sequence length is greater than the number of rows in the
                position embedding table.
        """
        # The sequence length comes from `input_ids` when present, otherwise from `inputs_embeds`.
        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
        max_position_embedding = self.position_embedding.weight.shape[0]

        # NOTE: the comparison is `>`, so a sequence exactly as long as the table is accepted.
        if seq_length > max_position_embedding:
            raise ValueError(
                f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
                f"{seq_length} and max_position_embeddings: {max_position_embedding}"
            )

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings
  244. class BlipAttention(nn.Module):
  245. """Multi-headed attention from 'Attention Is All You Need' paper"""
  246. def __init__(self, config):
  247. super().__init__()
  248. self.config = config
  249. self.embed_dim = config.hidden_size
  250. self.num_heads = config.num_attention_heads
  251. self.head_dim = self.embed_dim // self.num_heads
  252. if self.head_dim * self.num_heads != self.embed_dim:
  253. raise ValueError(
  254. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  255. f" {self.num_heads})."
  256. )
  257. self.scale = self.head_dim**-0.5
  258. self.dropout = nn.Dropout(config.attention_dropout)
  259. self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim)
  260. self.projection = nn.Linear(self.embed_dim, self.embed_dim)
  261. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  262. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
  263. def forward(
  264. self,
  265. hidden_states: torch.Tensor,
  266. **kwargs: Unpack[TransformersKwargs],
  267. ) -> tuple[torch.Tensor, torch.Tensor]:
  268. """Input shape: Batch x Time x Channel"""
  269. bsz, tgt_len, embed_dim = hidden_states.size()
  270. mixed_qkv = (
  271. self.qkv(hidden_states)
  272. .reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads)
  273. .permute(2, 0, 3, 1, 4)
  274. )
  275. query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
  276. # Take the dot product between "query" and "key" to get the raw attention scores.
  277. attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
  278. attention_scores = attention_scores * self.scale
  279. # Normalize the attention scores to probabilities.
  280. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  281. # This is actually dropping out entire tokens to attend to, which might
  282. # seem a bit unusual, but is taken from the original Transformer paper.
  283. attention_probs = self.dropout(attention_probs)
  284. context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
  285. new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
  286. context_layer = context_layer.reshape(new_context_layer_shape)
  287. output = self.projection(context_layer)
  288. return output, attention_probs
  289. # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Blip
class BlipMLP(nn.Module):
    """Two-layer feed-forward block: linear -> activation -> linear."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # The activation is looked up by name from `config.hidden_act`.
        self.activation_fn = ACT2FN[config.hidden_act]
        # Expand from hidden_size to intermediate_size, then project back to hidden_size.
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `fc1`, the activation and `fc2` to `hidden_states`, in that order."""
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states
  302. class BlipEncoderLayer(GradientCheckpointingLayer):
  303. def __init__(self, config: BlipConfig):
  304. super().__init__()
  305. self.embed_dim = config.hidden_size
  306. self.self_attn = BlipAttention(config)
  307. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  308. self.mlp = BlipMLP(config)
  309. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  310. @auto_docstring
  311. def forward(
  312. self,
  313. hidden_states: torch.Tensor,
  314. **kwargs: Unpack[TransformersKwargs],
  315. ) -> torch.FloatTensor:
  316. residual = hidden_states
  317. hidden_states = self.layer_norm1(hidden_states)
  318. hidden_states, _ = self.self_attn(
  319. hidden_states=hidden_states,
  320. **kwargs,
  321. )
  322. hidden_states = hidden_states + residual
  323. residual = hidden_states
  324. hidden_states = self.layer_norm2(hidden_states)
  325. hidden_states = self.mlp(hidden_states)
  326. hidden_states = hidden_states + residual
  327. return hidden_states
  328. @auto_docstring
  329. class BlipPreTrainedModel(PreTrainedModel):
  330. config: BlipConfig
  331. base_model_prefix = "blip"
  332. input_modalities = ("image", "text")
  333. supports_gradient_checkpointing = True
  334. _no_split_modules = ["BlipEncoderLayer", "BlipTextEmbeddings"]
  335. _skip_keys_device_placement = ["past_key_values"]
  336. @torch.no_grad()
  337. def _init_weights(self, module):
  338. """Initialize the weights"""
  339. super()._init_weights(module)
  340. std = self.config.initializer_range
  341. if isinstance(module, BlipVisionEmbeddings):
  342. if hasattr(self.config, "vision_config"):
  343. std = self.config.vision_config.initializer_range
  344. init.trunc_normal_(module.position_embedding, mean=0.0, std=std)
  345. init.trunc_normal_(module.class_embedding, mean=0.0, std=std)
  346. elif isinstance(module, BlipTextEmbeddings):
  347. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  348. class BlipEncoder(nn.Module):
  349. """
  350. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  351. [`BlipEncoderLayer`].
  352. Args:
  353. config (`BlipConfig`):
  354. The corresponding vision configuration for the `BlipEncoder`.
  355. """
  356. def __init__(self, config: BlipConfig):
  357. super().__init__()
  358. self.config = config
  359. self.layers = nn.ModuleList([BlipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  360. self.gradient_checkpointing = False
  361. @auto_docstring
  362. def forward(
  363. self,
  364. inputs_embeds,
  365. **kwargs: Unpack[TransformersKwargs],
  366. ) -> tuple | BaseModelOutput:
  367. hidden_states = inputs_embeds
  368. for encoder_layer in self.layers:
  369. hidden_states = encoder_layer(
  370. hidden_states,
  371. **kwargs,
  372. )
  373. return BaseModelOutput(last_hidden_state=hidden_states)
  374. class BlipVisionModel(BlipPreTrainedModel):
  375. main_input_name = "pixel_values"
  376. input_modalities = ("image",)
  377. config: BlipVisionConfig
  378. _can_record_outputs = {
  379. "hidden_states": BlipEncoderLayer,
  380. "attentions": BlipAttention,
  381. }
  382. def __init__(self, config: BlipVisionConfig):
  383. super().__init__(config)
  384. self.config = config
  385. embed_dim = config.hidden_size
  386. self.embeddings = BlipVisionEmbeddings(config)
  387. self.encoder = BlipEncoder(config)
  388. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  389. self.post_init()
  390. @merge_with_config_defaults
  391. @capture_outputs(tie_last_hidden_states=False)
  392. @auto_docstring
  393. def forward(
  394. self,
  395. pixel_values: torch.FloatTensor | None = None,
  396. interpolate_pos_encoding: bool = False,
  397. **kwargs: Unpack[TransformersKwargs],
  398. ) -> tuple | BaseModelOutputWithPooling:
  399. if pixel_values is None:
  400. raise ValueError("You have to specify pixel_values")
  401. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  402. encoder_outputs: BaseModelOutput = self.encoder(
  403. inputs_embeds=hidden_states,
  404. **kwargs,
  405. )
  406. last_hidden_state = encoder_outputs.last_hidden_state
  407. last_hidden_state = self.post_layernorm(last_hidden_state)
  408. pooled_output = last_hidden_state[:, 0, :]
  409. pooled_output = self.post_layernorm(pooled_output)
  410. return BaseModelOutputWithPooling(
  411. last_hidden_state=last_hidden_state,
  412. pooler_output=pooled_output,
  413. )
  414. def get_input_embeddings(self):
  415. return self.embeddings
  416. @auto_docstring(
  417. custom_intro="""
  418. This model is going to be deprecated in future versions. Please use `BlipForConditionalGeneration`, `BlipForQuestionAnswering` or `BlipForImageTextRetrieval` depending on your usecase.
  419. """
  420. )
  421. class BlipModel(BlipPreTrainedModel):
  422. config: BlipConfig
  423. def __init__(self, config: BlipConfig):
  424. super().__init__(config)
  425. if not isinstance(config.text_config, BlipTextConfig):
  426. raise TypeError(
  427. "config.text_config is expected to be of type BlipTextConfig but is of type"
  428. f" {type(config.text_config)}."
  429. )
  430. if not isinstance(config.vision_config, BlipVisionConfig):
  431. raise TypeError(
  432. "config.vision_config is expected to be of type BlipVisionConfig but is of type"
  433. f" {type(config.vision_config)}."
  434. )
  435. text_config = config.text_config
  436. vision_config = config.vision_config
  437. self.projection_dim = config.projection_dim
  438. self.text_embed_dim = text_config.hidden_size
  439. self.vision_embed_dim = vision_config.hidden_size
  440. self.text_model = BlipTextModel(text_config)
  441. self.vision_model = BlipVisionModel(vision_config)
  442. self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
  443. self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
  444. self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
  445. logger.warning(
  446. "`BlipModel` is going to be deprecated in future release, please use `BlipForConditionalGeneration`, `BlipForQuestionAnswering` or `BlipForImageTextRetrieval` depending on your usecase."
  447. )
  448. # Initialize weights and apply final processing
  449. self.post_init()
  450. def get_input_embeddings(self):
  451. return self.text_model.get_input_embeddings()
  452. def set_input_embeddings(self, value):
  453. self.text_model.set_input_embeddings(value)
  454. @can_return_tuple
  455. @auto_docstring
  456. def get_text_features(
  457. self,
  458. input_ids: torch.Tensor | None = None,
  459. attention_mask: torch.Tensor | None = None,
  460. position_ids: torch.Tensor | None = None,
  461. **kwargs: Unpack[TransformersKwargs],
  462. ) -> tuple | BaseModelOutputWithPooling:
  463. r"""
  464. Examples:
  465. ```python
  466. >>> from transformers import AutoProcessor, BlipModel
  467. >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
  468. >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
  469. >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
  470. >>> text_features = model.get_text_features(**inputs)
  471. ```"""
  472. text_outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.text_model(
  473. input_ids=input_ids,
  474. attention_mask=attention_mask,
  475. position_ids=position_ids,
  476. return_dict=True,
  477. **kwargs,
  478. )
  479. pooled_output = text_outputs.pooler_output
  480. text_outputs.pooler_output = self.text_projection(pooled_output)
  481. return text_outputs
  482. @can_return_tuple
  483. @auto_docstring
  484. def get_image_features(
  485. self,
  486. pixel_values: torch.FloatTensor | None = None,
  487. interpolate_pos_encoding: bool = False,
  488. **kwargs: Unpack[TransformersKwargs],
  489. ) -> tuple | BaseModelOutputWithPooling:
  490. r"""
  491. Examples:
  492. ```python
  493. >>> from PIL import Image
  494. >>> import httpx
  495. >>> from io import BytesIO
  496. >>> from transformers import AutoProcessor, BlipModel
  497. >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
  498. >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
  499. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  500. >>> with httpx.stream("GET", url) as response:
  501. ... image = Image.open(BytesIO(response.read()))
  502. >>> inputs = processor(images=image, return_tensors="pt")
  503. >>> image_features = model.get_image_features(**inputs)
  504. ```"""
  505. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  506. pixel_values=pixel_values,
  507. interpolate_pos_encoding=interpolate_pos_encoding,
  508. return_dict=True,
  509. **kwargs,
  510. )
  511. pooled_output = vision_outputs.pooler_output
  512. vision_outputs.pooler_output = self.visual_projection(pooled_output)
  513. return vision_outputs
  @auto_docstring
  def get_multimodal_features(
      self,
      input_ids: torch.LongTensor | None = None,
      pixel_values: torch.FloatTensor | None = None,
      attention_mask: torch.Tensor | None = None,
      interpolate_pos_encoding: bool = False,
  ) -> torch.FloatTensor:
      r"""
      Returns:
          multimodal_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The multimodal embeddings
          obtained by applying the image embeddings to the text encoder using the cross-attention mechanism.

      Examples:

      ```python
      >>> from PIL import Image
      >>> import httpx
      >>> from io import BytesIO
      >>> from transformers import AutoProcessor, BlipModel

      >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
      >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

      >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
      >>> with httpx.stream("GET", url) as response:
      ...     image = Image.open(BytesIO(response.read()))

      >>> texts = ["a photo of a cat", "a photo of a dog"]
      >>> inputs = processor(images=image, text=texts, padding=True, return_tensors="pt")

      >>> multimodal_features = model.get_multimodal_features(**inputs)
      ```"""
      vision_outputs = self.vision_model(
          pixel_values=pixel_values,
          interpolate_pos_encoding=interpolate_pos_encoding,
      )

      # The first vision output is handed to the text model as its cross-attention memory.
      image_embeds = vision_outputs[0]
      # All image positions are attendable. The mask is created on the same device as the image
      # embeddings (matching the `generate` methods in this file) so it never lives on CPU while
      # the model runs on an accelerator.
      image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

      text_outputs = self.text_model(
          input_ids=input_ids,
          attention_mask=attention_mask,
          encoder_hidden_states=image_embeds,
          encoder_attention_mask=image_atts,
      )

      pooled_output = text_outputs[1]  # pooled_output
      multimodal_features = self.text_projection(pooled_output)

      return multimodal_features
  @can_return_tuple
  @auto_docstring
  def forward(
      self,
      input_ids: torch.LongTensor | None = None,
      pixel_values: torch.FloatTensor | None = None,
      attention_mask: torch.Tensor | None = None,
      position_ids: torch.LongTensor | None = None,
      return_loss: bool | None = None,
      interpolate_pos_encoding: bool = False,
      **kwargs: Unpack[TransformersKwargs],
  ) -> tuple | BlipOutput:
      r"""
      return_loss (`bool`, *optional*):
          Whether or not to return the contrastive loss.

      Examples:

      ```python
      >>> from PIL import Image
      >>> import httpx
      >>> from io import BytesIO
      >>> from transformers import AutoProcessor, BlipModel

      >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
      >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

      >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
      >>> with httpx.stream("GET", url) as response:
      ...     image = Image.open(BytesIO(response.read()))

      >>> inputs = processor(
      ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
      ... )

      >>> outputs = model(**inputs)
      >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
      >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
      ```"""
      # Run the two encoders independently; the extra keyword arguments are forwarded to both.
      vision_outputs = self.vision_model(
          pixel_values=pixel_values,
          interpolate_pos_encoding=interpolate_pos_encoding,
          **kwargs,
      )
      text_outputs = self.text_model(
          input_ids=input_ids,
          attention_mask=attention_mask,
          position_ids=position_ids,
          **kwargs,
      )

      # Project each pooled output, then scale every row to unit L2 norm.
      projected_image = self.visual_projection(vision_outputs.pooler_output)
      projected_text = self.text_projection(text_outputs.pooler_output)
      image_embeds = projected_image / projected_image.norm(p=2, dim=-1, keepdim=True)
      text_embeds = projected_text / projected_text.norm(p=2, dim=-1, keepdim=True)

      # Cosine similarity scaled by exp(logit_scale); everything is aligned to the text embeddings'
      # device (and dtype, for the image embeddings) before the matmul.
      target_device = text_embeds.device
      logit_scale = self.logit_scale.exp().to(device=target_device)
      image_embeds = image_embeds.to(device=target_device, dtype=text_embeds.dtype)
      logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale

      # The loss is only computed when explicitly requested.
      loss = blip_loss(logits_per_text) if return_loss else None

      return BlipOutput(
          loss=loss,
          logits_per_image=logits_per_text.t(),
          logits_per_text=logits_per_text,
          text_embeds=text_embeds,
          image_embeds=image_embeds,
          text_model_output=text_outputs,
          vision_model_output=vision_outputs,
      )
  @auto_docstring(
      custom_intro="""
      BLIP Model for image captioning. The model consists of a vision encoder and a text decoder. One can optionally pass
      `input_ids` to the model, which serve as a text prompt, to make the text decoder continue the prompt. Otherwise,
      the decoder starts generating text from the [BOS] (beginning-of-sequence) token only.
      """
  )
  class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin):
      config: BlipConfig
      main_input_name = "pixel_values"
      _tied_weights_keys = {
          "text_decoder.cls.predictions.decoder.bias": "text_decoder.cls.predictions.bias",
          "text_decoder.cls.predictions.decoder.weight": "text_decoder.bert.embeddings.word_embeddings.weight",
      }  # TODO @arthurzucker check why we need this when for other models, their subPreTrainedModel handle it themselves.

      def __init__(self, config: BlipConfig):
          super().__init__(config)

          self.vision_model = BlipVisionModel(config.vision_config)
          self.text_decoder = BlipTextLMHeadModel(config.text_config)

          # Token id used to start decoding when `generate` is called without a prompt.
          self.decoder_input_ids = config.text_config.bos_token_id
          self.decoder_pad_token_id = config.text_config.pad_token_id

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          return self.text_decoder.get_input_embeddings()

      def set_input_embeddings(self, value):
          self.text_decoder.set_input_embeddings(value)

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          pixel_values: torch.FloatTensor,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.LongTensor | None = None,
          labels: torch.LongTensor | None = None,
          interpolate_pos_encoding: bool = False,
          logits_to_keep: int | torch.Tensor = 0,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | BlipForConditionalGenerationModelOutput:
          r"""
          Examples:

          ```python
          >>> from PIL import Image
          >>> import httpx
          >>> from io import BytesIO
          >>> from transformers import AutoProcessor, BlipForConditionalGeneration

          >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
          >>> model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> with httpx.stream("GET", url) as response:
          ...     image = Image.open(BytesIO(response.read()))
          >>> text = "A picture of"

          >>> inputs = processor(images=image, text=text, return_tensors="pt")

          >>> outputs = model(**inputs)
          ```"""
          vision_outputs = self.vision_model(
              pixel_values=pixel_values,
              interpolate_pos_encoding=interpolate_pos_encoding,
              **kwargs,
          )

          image_embeds = vision_outputs.last_hidden_state

          # The text decoder cross-attends to the image hidden states; when `labels` are given it
          # also produces the (mean-reduced) language-modeling loss.
          outputs = self.text_decoder(
              input_ids=input_ids,
              attention_mask=attention_mask,
              encoder_hidden_states=image_embeds,
              labels=labels,
              reduction="mean",
              logits_to_keep=logits_to_keep,
              **kwargs,
          )

          return BlipForConditionalGenerationModelOutput(
              loss=outputs.loss,
              logits=outputs.logits,
              image_embeds=image_embeds,
              last_hidden_state=vision_outputs.last_hidden_state,
              hidden_states=vision_outputs.hidden_states,
              attentions=vision_outputs.attentions,
          )

      @torch.no_grad()
      def generate(
          self,
          pixel_values: torch.FloatTensor,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.LongTensor | None = None,
          interpolate_pos_encoding: bool = False,
          **generate_kwargs,
      ) -> torch.LongTensor:
          r"""
          Overrides *generate* function to be able to use the model as a conditional generator

          Parameters:
              pixel_values (*torch.FloatTensor* of shape *(batch_size, num_channels, image_height, image_width)*):
                  Input image to be processed
              input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
                  The sequence used as a prompt for the generation.
              attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
                  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. `1` for
                  tokens that are NOT MASKED, `0` for MASKED tokens.
              interpolate_pos_encoding (*bool*, *optional*, defaults to `False`):
                  Forwarded unchanged to the vision model.
              **generate_kwargs:
                  Additional arguments passed to the *generate* function of the decoder


          Examples:
          ```python
          >>> from PIL import Image
          >>> import httpx
          >>> from io import BytesIO
          >>> from transformers import AutoProcessor, BlipForConditionalGeneration

          >>> model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
          >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> with httpx.stream("GET", url) as response:
          ...     image = Image.open(BytesIO(response.read()))

          >>> inputs = processor(images=image, return_tensors="pt")

          >>> outputs = model.generate(**inputs)
          >>> print(processor.decode(outputs[0], skip_special_tokens=True))
          two cats sleeping on a couch
          ```
          """
          batch_size = pixel_values.shape[0]
          vision_outputs = self.vision_model(
              pixel_values=pixel_values,
              interpolate_pos_encoding=interpolate_pos_encoding,
          )

          image_embeds = vision_outputs[0]

          # Every image position is attendable by the decoder's cross-attention.
          image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

          if isinstance(input_ids, list):
              # Place list prompts on the same device as the image embeddings instead of leaving them on CPU.
              input_ids = torch.LongTensor(input_ids).to(image_embeds.device)
          elif input_ids is None:
              # No prompt: start from [BOS, EOS]; the trailing token is sliced off below, leaving only BOS.
              input_ids = (
                  torch.LongTensor([[self.decoder_input_ids, self.config.text_config.eos_token_id]])
                  .repeat(batch_size, 1)
                  .to(image_embeds.device)
              )
          elif isinstance(input_ids, torch.Tensor):
              # Work on a copy so the in-place BOS overwrite below does not modify the caller's tensor.
              input_ids = input_ids.clone()

          # The first prompt token is always replaced by the text config's BOS token.
          input_ids[:, 0] = self.config.text_config.bos_token_id
          # Drop the last position of both the prompt and its mask before handing them to the decoder.
          attention_mask = attention_mask[:, :-1] if attention_mask is not None else None

          outputs = self.text_decoder.generate(
              input_ids=input_ids[:, :-1],
              eos_token_id=self.config.text_config.sep_token_id,
              pad_token_id=self.config.text_config.pad_token_id,
              attention_mask=attention_mask,
              encoder_hidden_states=image_embeds,
              encoder_attention_mask=image_attention_mask,
              **generate_kwargs,
          )

          return outputs
  @auto_docstring(
      custom_intro="""
      BLIP Model for visual question answering. The model consists of a vision encoder, a text encoder as well as a text
      decoder. The vision encoder will encode the input image, the text encoder will encode the input question together
      with the encoding of the image, and the text decoder will output the answer to the question.
      """
  )
  class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin):
      config: BlipConfig
      _tied_weights_keys = {
          "text_decoder.cls.predictions.decoder.bias": "text_decoder.cls.predictions.bias",
          "text_decoder.cls.predictions.decoder.weight": "text_decoder.bert.embeddings.word_embeddings.weight",
      }

      def __init__(self, config: BlipConfig):
          super().__init__(config)

          self.vision_model = BlipVisionModel(config.vision_config)

          self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)

          self.text_decoder = BlipTextLMHeadModel(config.text_config)

          self.decoder_pad_token_id = config.text_config.pad_token_id
          # Token id that seeds answer decoding in `generate`.
          self.decoder_start_token_id = config.text_config.bos_token_id

          # Initialize weights and apply final processing
          self.post_init()

      def set_input_embeddings(self, value):
          self.text_encoder.set_input_embeddings(value)

      def get_input_embeddings(self):
          # This will return shared embeddings if they are shared else specific to encoder.
          return self.text_encoder.get_input_embeddings()

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor,
          pixel_values: torch.FloatTensor,
          decoder_input_ids: torch.LongTensor | None = None,
          decoder_attention_mask: torch.LongTensor | None = None,
          attention_mask: torch.LongTensor | None = None,
          labels: torch.LongTensor | None = None,
          interpolate_pos_encoding: bool = False,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | BlipTextVisionModelOutput:
          r"""
          Examples:

          ```python
          >>> from PIL import Image
          >>> import httpx
          >>> from io import BytesIO
          >>> from transformers import AutoProcessor, BlipForQuestionAnswering

          >>> model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
          >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> with httpx.stream("GET", url) as response:
          ...     image = Image.open(BytesIO(response.read()))

          >>> # training
          >>> text = "How many cats are in the picture?"
          >>> label = "2"
          >>> inputs = processor(images=image, text=text, return_tensors="pt")
          >>> labels = processor(text=label, return_tensors="pt").input_ids

          >>> inputs["labels"] = labels
          >>> outputs = model(**inputs)
          >>> loss = outputs.loss
          >>> loss.backward()

          >>> # inference
          >>> text = "How many cats are in the picture?"
          >>> inputs = processor(images=image, text=text, return_tensors="pt")
          >>> outputs = model.generate(**inputs)
          >>> print(processor.decode(outputs[0], skip_special_tokens=True))
          2
          ```"""
          # The decoder needs something to consume: either explicit decoder inputs or training labels.
          if labels is None and decoder_input_ids is None:
              raise ValueError(
                  "Either `decoder_input_ids` or `labels` should be passed when calling `forward` with"
                  " `BlipForQuestionAnswering`. if you are training the model make sure that `labels` is passed, if you"
                  " are using the model for inference make sure that `decoder_input_ids` is passed or call `generate`"
              )

          vision_outputs = self.vision_model(
              pixel_values=pixel_values,
              interpolate_pos_encoding=interpolate_pos_encoding,
              **kwargs,
          )

          image_embeds = vision_outputs.last_hidden_state
          # All image positions are attendable. The mask is created on the embeddings' device,
          # consistent with `generate`, rather than implicitly on CPU.
          image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

          # Encode the question while cross-attending to the image.
          question_outputs = self.text_encoder(
              input_ids=input_ids,
              attention_mask=attention_mask,
              encoder_hidden_states=image_embeds,
              encoder_attention_mask=image_attention_mask,
              **kwargs,
          )

          if labels is not None and decoder_input_ids is None:
              # labels are already shifted right, see: https://github.com/huggingface/transformers/pull/23153
              decoder_input_ids = labels

          question_embeds = question_outputs[0]

          # Decode the answer while cross-attending to the question states; the question's own
          # attention mask is reused as the cross-attention mask.
          answer_output = self.text_decoder(
              input_ids=decoder_input_ids,
              attention_mask=decoder_attention_mask,
              encoder_hidden_states=question_embeds,
              encoder_attention_mask=attention_mask,
              labels=labels,
              reduction="mean",
              **kwargs,
          )

          decoder_loss = answer_output.loss.mean() if labels is not None else None

          return BlipTextVisionModelOutput(
              loss=decoder_loss,
              image_embeds=image_embeds,
              last_hidden_state=vision_outputs.last_hidden_state,
              hidden_states=vision_outputs.hidden_states,
              attentions=vision_outputs.attentions,
          )

      @torch.no_grad()
      def generate(
          self,
          input_ids: torch.LongTensor,
          pixel_values: torch.FloatTensor,
          attention_mask: torch.LongTensor | None = None,
          interpolate_pos_encoding: bool = False,
          **generate_kwargs,
      ) -> torch.LongTensor:
          r"""
          Overrides *generate* function to be able to use the model as a conditional generator

          Parameters:
              input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*):
                  The sequence used as a prompt for the generation.
              pixel_values (*torch.FloatTensor* of shape *(batch_size, num_channels, image_height, image_width)*):
                  Input image to be processed
              attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
                  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. `1` for
                  tokens that are NOT MASKED, `0` for MASKED tokens.
              interpolate_pos_encoding (*bool*, *optional*, defaults to `False`):
                  Forwarded unchanged to the vision model.
              **generate_kwargs:
                  Additional arguments passed to the *generate* function of the decoder


          Examples:
          ```python
          >>> from PIL import Image
          >>> import httpx
          >>> from io import BytesIO
          >>> from transformers import AutoProcessor, BlipForQuestionAnswering

          >>> model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
          >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> with httpx.stream("GET", url) as response:
          ...     image = Image.open(BytesIO(response.read()))
          >>> text = "How many cats are in the picture?"

          >>> inputs = processor(images=image, text=text, return_tensors="pt")

          >>> outputs = model.generate(**inputs)
          >>> print(processor.decode(outputs[0], skip_special_tokens=True))
          2
          ```
          """
          vision_outputs = self.vision_model(
              pixel_values=pixel_values,
              interpolate_pos_encoding=interpolate_pos_encoding,
          )

          image_embeds = vision_outputs[0]

          image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

          if isinstance(input_ids, list):
              # Place list prompts on the same device as the image embeddings instead of leaving them on CPU.
              input_ids = torch.LongTensor(input_ids).to(image_embeds.device)

          question_outputs = self.text_encoder(
              input_ids=input_ids,
              attention_mask=attention_mask,
              encoder_hidden_states=image_embeds,
              encoder_attention_mask=image_attention_mask,
              return_dict=False,
          )

          question_embeds = question_outputs[0]

          # NOTE(review): the decoder attends to every question position here, whereas `forward`
          # passes the question `attention_mask` as the cross-attention mask. Left as-is because
          # changing it would alter generated outputs; confirm this difference is intended.
          question_attention_mask = torch.ones(
              question_embeds.size()[:-1], dtype=torch.long, device=question_embeds.device
          )

          # One start token per batch element seeds the answer decoder.
          bos_ids = torch.full(
              (question_embeds.size(0), 1), fill_value=self.decoder_start_token_id, device=question_embeds.device
          )

          outputs = self.text_decoder.generate(
              input_ids=bos_ids,
              eos_token_id=self.config.text_config.sep_token_id,
              pad_token_id=self.config.text_config.pad_token_id,
              encoder_hidden_states=question_embeds,
              encoder_attention_mask=question_attention_mask,
              **generate_kwargs,
          )

          return outputs
  @auto_docstring(
      custom_intro="""
      BLIP Model with a vision and text projector, and a classification head on top. The model is used in the context of
      image-text retrieval. Given an image and a text, the model returns the probability of the text being relevant to
      the image.
      """
  )
  class BlipForImageTextRetrieval(BlipPreTrainedModel):
      config: BlipConfig

      def __init__(self, config: BlipConfig):
          super().__init__(config)

          self.vision_model = BlipVisionModel(config.vision_config)

          self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)

          # vision projection layer
          self.vision_proj = nn.Linear(config.vision_config.hidden_size, config.image_text_hidden_size)

          # text projection layer
          self.text_proj = nn.Linear(config.text_config.hidden_size, config.image_text_hidden_size)

          # image text matching head
          self.itm_head = nn.Linear(config.text_config.hidden_size, 2)

          # Prefer the top-level config values when present, otherwise fall back to the text config.
          self.decoder_pad_token_id = getattr(config, "decoder_pad_token_id", config.text_config.pad_token_id)
          self.decoder_start_token_id = getattr(config, "decoder_start_token_id", config.text_config.bos_token_id)

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          return self.text_encoder.get_input_embeddings()

      def set_input_embeddings(self, value):
          self.text_encoder.set_input_embeddings(value)

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor,
          pixel_values: torch.FloatTensor,
          use_itm_head: bool | None = True,
          attention_mask: torch.LongTensor | None = None,
          interpolate_pos_encoding: bool = False,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | BlipImageTextMatchingModelOutput:
          r"""
          use_itm_head (`bool`, *optional*, defaults to `True`):
              Whether or not to use the image-text matching head.

          Examples:

          ```python
          >>> from PIL import Image
          >>> import httpx
          >>> from io import BytesIO
          >>> from transformers import AutoProcessor, BlipForImageTextRetrieval

          >>> model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
          >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco")

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> with httpx.stream("GET", url) as response:
          ...     image = Image.open(BytesIO(response.read()))
          >>> text = "an image of a cat"

          >>> inputs = processor(images=image, text=text, return_tensors="pt")
          >>> outputs = model(**inputs)
          ```
          """
          vision_outputs = self.vision_model(
              pixel_values=pixel_values,
              interpolate_pos_encoding=interpolate_pos_encoding,
              **kwargs,
          )

          image_embeds = vision_outputs.last_hidden_state

          if use_itm_head:
              # Image-text matching: the text encoder cross-attends to the image, and the hidden
              # state at position 0 is fed to the 2-way matching head. The all-ones image mask is
              # only needed on this path and is created on the embeddings' device.
              image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
              text_outputs = self.text_encoder(
                  input_ids=input_ids,
                  attention_mask=attention_mask,
                  encoder_hidden_states=image_embeds,
                  encoder_attention_mask=image_atts,
                  **kwargs,
              )
              question_embeds = text_outputs.last_hidden_state

              output = self.itm_head(question_embeds[:, 0, :])
          else:
              # Contrastive similarity: the text is encoded without looking at the image, then the
              # position-0 states of both modalities are projected, normalized and compared.
              text_outputs = self.text_encoder(
                  input_ids=input_ids,
                  attention_mask=attention_mask,
                  **kwargs,
              )
              question_embeds = text_outputs.last_hidden_state

              image_feat = normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
              text_feat = normalize(self.text_proj(question_embeds[:, 0, :]), dim=-1)

              output = image_feat @ text_feat.t()

          return BlipImageTextMatchingModelOutput(
              itm_score=output,
              last_hidden_state=vision_outputs.last_hidden_state,
              hidden_states=vision_outputs.hidden_states,
              attentions=vision_outputs.attentions,
              question_embeds=question_embeds,
          )
  # Public names exported by this module.
  __all__ = [
      "BlipModel",
      "BlipPreTrainedModel",
      "BlipForConditionalGeneration",
      "BlipForQuestionAnswering",
      "BlipVisionModel",
      "BlipTextModel",
      "BlipForImageTextRetrieval",
  ]