modeling_instructblip.py 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475
  1. # Copyright 2023 The Salesforce Authors and The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch InstructBLIP model."""
  15. import math
  16. from collections.abc import Callable
  17. from dataclasses import dataclass
  18. from typing import Any
  19. import torch
  20. from torch import nn
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...generation import GenerationMixin
  24. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import (
  27. BaseModelOutput,
  28. BaseModelOutputWithPastAndCrossAttentions,
  29. BaseModelOutputWithPooling,
  30. BaseModelOutputWithPoolingAndCrossAttentions,
  31. CausalLMOutputWithPast,
  32. Seq2SeqLMOutput,
  33. )
  34. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  35. from ...processing_utils import Unpack
  36. from ...pytorch_utils import apply_chunking_to_forward
  37. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
  38. from ...utils.generic import merge_with_config_defaults
  39. from ...utils.output_capturing import OutputRecorder, capture_outputs
  40. from ..auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
  41. from .configuration_instructblip import InstructBlipConfig, InstructBlipQFormerConfig, InstructBlipVisionConfig
  42. logger = logging.get_logger(__name__)
@dataclass
@auto_docstring
class BaseModelOutputWithVisionQformerOutputs(BaseModelOutputWithPooling):
    r"""
    vision_outputs (`BaseModelOutputWithPooling`):
        Outputs of the vision encoder.
    qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
        Outputs of the Q-Former (Querying Transformer).
    """

    # Extends the pooled base output with the intermediate sub-model outputs.
    # Both fields default to None, so they may be omitted when constructing an instance.
    vision_outputs: BaseModelOutputWithPooling | None = None
    qformer_outputs: BaseModelOutputWithPoolingAndCrossAttentions | None = None
  54. @dataclass
  55. @auto_docstring(
  56. custom_intro="""
  57. Class defining the outputs of [`InstructBlipForConditionalGeneration`].
  58. """
  59. )
  60. # Copied from transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGenerationModelOutput with Blip2->InstructBlip
  61. class InstructBlipForConditionalGenerationModelOutput(ModelOutput):
  62. r"""
  63. loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
  64. Language modeling loss from the language model.
  65. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  66. Prediction scores of the language modeling head of the language model.
  67. vision_outputs (`BaseModelOutputWithPooling`):
  68. Outputs of the vision encoder.
  69. qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
  70. Outputs of the Q-Former (Querying Transformer).
  71. language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
  72. Outputs of the language model.
  73. """
  74. loss: tuple[torch.FloatTensor] | None = None
  75. logits: tuple[torch.FloatTensor] | None = None
  76. vision_outputs: BaseModelOutputWithPooling | None = None
  77. qformer_outputs: BaseModelOutputWithPoolingAndCrossAttentions | None = None
  78. language_model_outputs: CausalLMOutputWithPast | Seq2SeqLMOutput | None = None
  79. def to_tuple(self) -> tuple[Any]:
  80. return tuple(
  81. self[k]
  82. if k not in ["vision_outputs", "qformer_outputs", "language_model_outputs"]
  83. else getattr(self, k).to_tuple()
  84. for k in self.keys()
  85. )
  86. # Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->InstructBlip
  87. class InstructBlipVisionEmbeddings(nn.Module):
  88. def __init__(self, config: InstructBlipVisionConfig):
  89. super().__init__()
  90. self.config = config
  91. self.embed_dim = config.hidden_size
  92. self.image_size = config.image_size
  93. self.patch_size = config.patch_size
  94. self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
  95. self.patch_embedding = nn.Conv2d(
  96. in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
  97. )
  98. self.num_patches = (self.image_size // self.patch_size) ** 2
  99. self.num_positions = self.num_patches + 1
  100. self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
  101. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  102. """
  103. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  104. images. This method is also adapted to support torch.jit tracing.
  105. Adapted from:
  106. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  107. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  108. """
  109. num_patches = embeddings.shape[1] - 1
  110. num_positions = self.position_embedding.shape[1] - 1
  111. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  112. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  113. return self.position_embedding
  114. class_pos_embed = self.position_embedding[:, :1]
  115. patch_pos_embed = self.position_embedding[:, 1:]
  116. dim = embeddings.shape[-1]
  117. new_height = height // self.patch_size
  118. new_width = width // self.patch_size
  119. sqrt_num_positions = torch_int(num_positions**0.5)
  120. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  121. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  122. patch_pos_embed = nn.functional.interpolate(
  123. patch_pos_embed,
  124. size=(new_height, new_width),
  125. mode="bicubic",
  126. align_corners=False,
  127. )
  128. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  129. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  130. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  131. batch_size, _, height, width = pixel_values.shape
  132. target_dtype = self.patch_embedding.weight.dtype
  133. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  134. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  135. class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
  136. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  137. if interpolate_pos_encoding:
  138. position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
  139. else:
  140. position_embedding = self.position_embedding
  141. embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype)
  142. return embeddings
  143. # Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> InstructBLIP doesn't cast attn weights to fp32
  144. def eager_attention_forward(
  145. module: nn.Module,
  146. query: torch.Tensor,
  147. key: torch.Tensor,
  148. value: torch.Tensor,
  149. attention_mask: torch.Tensor | None,
  150. scaling: float,
  151. dropout: float = 0.0,
  152. **kwargs,
  153. ):
  154. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  155. if attention_mask is not None:
  156. attn_weights = attn_weights + attention_mask
  157. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  158. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  159. attn_output = torch.matmul(attn_weights, value)
  160. attn_output = attn_output.transpose(1, 2).contiguous()
  161. return attn_output, attn_weights
  162. # Copied from transformers.models.blip_2.modeling_blip_2.Blip2Attention with Blip2->InstructBlip
  163. class InstructBlipAttention(nn.Module):
  164. """Multi-headed attention from 'Attention Is All You Need' paper"""
  165. def __init__(self, config):
  166. super().__init__()
  167. self.config = config
  168. self.embed_dim = config.hidden_size
  169. self.num_heads = config.num_attention_heads
  170. self.head_dim = self.embed_dim // self.num_heads
  171. if self.head_dim * self.num_heads != self.embed_dim:
  172. raise ValueError(
  173. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  174. f" {self.num_heads})."
  175. )
  176. self.scale = self.head_dim**-0.5
  177. self.is_causal = False
  178. self.attention_dropout = config.attention_dropout
  179. # small tweak here compared to CLIP, no bias here
  180. self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
  181. if config.qkv_bias:
  182. q_bias = nn.Parameter(torch.zeros(self.embed_dim))
  183. v_bias = nn.Parameter(torch.zeros(self.embed_dim))
  184. else:
  185. q_bias = None
  186. v_bias = None
  187. if q_bias is not None:
  188. qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
  189. self.qkv.bias = nn.Parameter(qkv_bias)
  190. self.projection = nn.Linear(self.embed_dim, self.embed_dim)
  191. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  192. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
  193. def forward(
  194. self,
  195. hidden_states: torch.Tensor,
  196. **kwargs,
  197. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  198. """Input shape: Batch x Time x Channel"""
  199. bsz, tgt_len, embed_dim = hidden_states.size()
  200. mixed_qkv = self.qkv(hidden_states)
  201. mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
  202. 2, 0, 3, 1, 4
  203. )
  204. query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
  205. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  206. self.config._attn_implementation, eager_attention_forward
  207. )
  208. attn_output, attn_weights = attention_interface(
  209. self,
  210. query_states,
  211. key_states,
  212. value_states,
  213. attention_mask=None,
  214. dropout=0.0 if not self.training else self.attention_dropout,
  215. scaling=self.scale,
  216. **kwargs,
  217. )
  218. attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
  219. attn_output = self.projection(attn_output)
  220. return attn_output, attn_weights
  221. # Copied from transformers.models.blip.modeling_blip.BlipMLP
  222. class InstructBlipMLP(nn.Module):
  223. def __init__(self, config):
  224. super().__init__()
  225. self.config = config
  226. self.activation_fn = ACT2FN[config.hidden_act]
  227. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  228. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  229. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  230. hidden_states = self.fc1(hidden_states)
  231. hidden_states = self.activation_fn(hidden_states)
  232. hidden_states = self.fc2(hidden_states)
  233. return hidden_states
  234. # Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->InstructBlip
  235. class InstructBlipEncoderLayer(GradientCheckpointingLayer):
  236. def __init__(self, config: InstructBlipConfig):
  237. super().__init__()
  238. self.embed_dim = config.hidden_size
  239. self.self_attn = InstructBlipAttention(config)
  240. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  241. self.mlp = InstructBlipMLP(config)
  242. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  243. @auto_docstring
  244. def forward(
  245. self,
  246. hidden_states: torch.Tensor,
  247. **kwargs: Unpack[TransformersKwargs],
  248. ) -> torch.FloatTensor:
  249. residual = hidden_states
  250. hidden_states = self.layer_norm1(hidden_states)
  251. hidden_states, _ = self.self_attn(
  252. hidden_states=hidden_states,
  253. **kwargs,
  254. )
  255. hidden_states = hidden_states + residual
  256. residual = hidden_states
  257. hidden_states = self.layer_norm2(hidden_states)
  258. hidden_states = self.mlp(hidden_states)
  259. hidden_states = hidden_states + residual
  260. return hidden_states
  261. @auto_docstring
  262. class InstructBlipPreTrainedModel(PreTrainedModel):
  263. config: InstructBlipConfig
  264. base_model_prefix = "blip"
  265. input_modalities = ("image", "text")
  266. supports_gradient_checkpointing = True
  267. _supports_attention_backend = True
  268. _supports_flash_attn = True
  269. _supports_sdpa = True
  270. _supports_flex_attn = True
  271. _can_compile_fullgraph = True
  272. _no_split_modules = [
  273. "InstructBlipQFormerEmbeddings",
  274. "InstructBlipAttention",
  275. "InstructBlipQFormerMultiHeadAttention",
  276. "InstructBlipQFormerSelfOutput",
  277. ]
  278. @torch.no_grad()
  279. def _init_weights(self, module):
  280. """Initialize the weights"""
  281. super()._init_weights(module)
  282. factor = self.config.initializer_range
  283. if isinstance(module, InstructBlipVisionEmbeddings):
  284. init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
  285. init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
  286. elif isinstance(module, (InstructBlipForConditionalGeneration, InstructBlipModel)):
  287. init.zeros_(module.query_tokens)
  288. elif isinstance(module, InstructBlipQFormerEmbeddings):
  289. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  290. # Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->InstructBlip
  291. class InstructBlipEncoder(nn.Module):
  292. """
  293. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  294. [`InstructBlipEncoderLayer`].
  295. Args:
  296. config (`InstructBlipConfig`):
  297. The corresponding vision configuration for the `InstructBlipEncoder`.
  298. """
  299. def __init__(self, config: InstructBlipConfig):
  300. super().__init__()
  301. self.config = config
  302. self.layers = nn.ModuleList([InstructBlipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  303. self.gradient_checkpointing = False
  304. @auto_docstring
  305. def forward(
  306. self,
  307. inputs_embeds,
  308. **kwargs: Unpack[TransformersKwargs],
  309. ) -> tuple | BaseModelOutput:
  310. hidden_states = inputs_embeds
  311. for encoder_layer in self.layers:
  312. hidden_states = encoder_layer(
  313. hidden_states,
  314. **kwargs,
  315. )
  316. return BaseModelOutput(last_hidden_state=hidden_states)
  317. class InstructBlipVisionModel(InstructBlipPreTrainedModel):
  318. main_input_name = "pixel_values"
  319. input_modalities = ("image",)
  320. config: InstructBlipVisionConfig
  321. _can_record_outputs = {
  322. "hidden_states": InstructBlipEncoderLayer,
  323. "attentions": InstructBlipAttention,
  324. }
  325. def __init__(self, config: InstructBlipVisionConfig):
  326. super().__init__(config)
  327. self.config = config
  328. embed_dim = config.hidden_size
  329. self.embeddings = InstructBlipVisionEmbeddings(config)
  330. self.encoder = InstructBlipEncoder(config)
  331. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  332. self.post_init()
  333. @merge_with_config_defaults
  334. @capture_outputs(tie_last_hidden_states=False)
  335. @auto_docstring
  336. def forward(
  337. self,
  338. pixel_values: torch.FloatTensor | None = None,
  339. interpolate_pos_encoding: bool = False,
  340. **kwargs: Unpack[TransformersKwargs],
  341. ) -> tuple | BaseModelOutputWithPooling:
  342. if pixel_values is None:
  343. raise ValueError("You have to specify pixel_values")
  344. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  345. encoder_outputs: BaseModelOutput = self.encoder(
  346. inputs_embeds=hidden_states,
  347. **kwargs,
  348. )
  349. last_hidden_state = encoder_outputs.last_hidden_state
  350. last_hidden_state = self.post_layernorm(last_hidden_state)
  351. pooled_output = last_hidden_state[:, 0, :]
  352. pooled_output = self.post_layernorm(pooled_output)
  353. return BaseModelOutputWithPooling(
  354. last_hidden_state=last_hidden_state,
  355. pooler_output=pooled_output,
  356. )
  357. def get_input_embeddings(self):
  358. return self.embeddings
  359. class InstructBlipQFormerMultiHeadAttention(nn.Module):
  360. def __init__(self, config, is_cross_attention=False):
  361. super().__init__()
  362. self.config = config
  363. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  364. raise ValueError(
  365. "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
  366. % (config.hidden_size, config.num_attention_heads)
  367. )
  368. self.num_attention_heads = config.num_attention_heads
  369. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  370. self.all_head_size = self.num_attention_heads * self.attention_head_size
  371. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  372. if is_cross_attention:
  373. self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
  374. self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
  375. else:
  376. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  377. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  378. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  379. self.save_attention = False
  380. def save_attn_gradients(self, attn_gradients):
  381. self.attn_gradients = attn_gradients
  382. def get_attn_gradients(self):
  383. return self.attn_gradients
  384. def save_attention_map(self, attention_map):
  385. self.attention_map = attention_map
  386. def get_attention_map(self):
  387. return self.attention_map
  388. def transpose_for_scores(self, x):
  389. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  390. x = x.view(*new_x_shape)
  391. return x.permute(0, 2, 1, 3)
  392. def forward(
  393. self,
  394. hidden_states,
  395. attention_mask=None,
  396. encoder_hidden_states=None,
  397. encoder_attention_mask=None,
  398. **kwargs: Unpack[TransformersKwargs],
  399. ):
  400. # If this is instantiated as a cross-attention module, the keys
  401. # and values come from an encoder; the attention mask needs to be
  402. # such that the encoder's padding tokens are not attended to.
  403. is_cross_attention = encoder_hidden_states is not None
  404. if is_cross_attention:
  405. key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
  406. value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
  407. attention_mask = encoder_attention_mask
  408. else:
  409. key_layer = self.transpose_for_scores(self.key(hidden_states))
  410. value_layer = self.transpose_for_scores(self.value(hidden_states))
  411. mixed_query_layer = self.query(hidden_states)
  412. query_layer = self.transpose_for_scores(mixed_query_layer)
  413. # Take the dot product between "query" and "key" to get the raw attention scores.
  414. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  415. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  416. attention_scores_dtype = attention_scores.dtype
  417. if attention_mask is not None:
  418. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  419. attention_scores = attention_scores + attention_mask
  420. # Normalize the attention scores to probabilities.
  421. attention_probs = nn.Softmax(dim=-1)(attention_scores).to(attention_scores_dtype)
  422. if is_cross_attention and self.save_attention:
  423. self.save_attention_map(attention_probs)
  424. attention_probs.register_hook(self.save_attn_gradients)
  425. # This is actually dropping out entire tokens to attend to, which might
  426. # seem a bit unusual, but is taken from the original Transformer paper.
  427. attention_probs_dropped = self.dropout(attention_probs)
  428. context_layer = torch.matmul(attention_probs_dropped, value_layer)
  429. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  430. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  431. context_layer = context_layer.view(*new_context_layer_shape)
  432. return context_layer, attention_probs
  433. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->InstructBlipQFormer
  434. class InstructBlipQFormerSelfOutput(nn.Module):
  435. def __init__(self, config):
  436. super().__init__()
  437. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  438. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  439. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  440. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  441. hidden_states = self.dense(hidden_states)
  442. hidden_states = self.dropout(hidden_states)
  443. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  444. return hidden_states
  445. # Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerAttention with Blip2->InstructBlip
  446. class InstructBlipQFormerAttention(nn.Module):
  447. def __init__(self, config, is_cross_attention=False):
  448. super().__init__()
  449. self.attention = InstructBlipQFormerMultiHeadAttention(config, is_cross_attention)
  450. self.output = InstructBlipQFormerSelfOutput(config)
  451. def forward(
  452. self,
  453. hidden_states: torch.Tensor,
  454. attention_mask: torch.FloatTensor | None = None,
  455. encoder_hidden_states: torch.FloatTensor | None = None,
  456. encoder_attention_mask: torch.FloatTensor | None = None,
  457. **kwargs: Unpack[TransformersKwargs],
  458. ) -> torch.Tensor:
  459. attn_output, _ = self.attention(
  460. hidden_states=hidden_states,
  461. attention_mask=attention_mask,
  462. encoder_hidden_states=encoder_hidden_states,
  463. encoder_attention_mask=encoder_attention_mask,
  464. **kwargs,
  465. )
  466. attention_output = self.output(attn_output, hidden_states)
  467. return attention_output
  468. # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->InstructBlipQFormer
  469. class InstructBlipQFormerIntermediate(nn.Module):
  470. def __init__(self, config):
  471. super().__init__()
  472. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  473. if isinstance(config.hidden_act, str):
  474. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  475. else:
  476. self.intermediate_act_fn = config.hidden_act
  477. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  478. hidden_states = self.dense(hidden_states)
  479. hidden_states = self.intermediate_act_fn(hidden_states)
  480. return hidden_states
  481. # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->InstructBlipQFormer
  482. class InstructBlipQFormerOutput(nn.Module):
  483. def __init__(self, config):
  484. super().__init__()
  485. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  486. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  487. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  488. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  489. hidden_states = self.dense(hidden_states)
  490. hidden_states = self.dropout(hidden_states)
  491. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  492. return hidden_states
  493. class InstructBlipQFormerLayer(GradientCheckpointingLayer):
  494. def __init__(self, config, layer_idx):
  495. super().__init__()
  496. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  497. self.seq_len_dim = 1
  498. self.attention = InstructBlipQFormerAttention(config)
  499. self.layer_idx = layer_idx
  500. if layer_idx % config.cross_attention_frequency == 0:
  501. self.crossattention = InstructBlipQFormerAttention(config, is_cross_attention=True)
  502. self.has_cross_attention = True
  503. else:
  504. self.has_cross_attention = False
  505. self.intermediate = InstructBlipQFormerIntermediate(config)
  506. self.output = InstructBlipQFormerOutput(config)
  507. self.intermediate_query = InstructBlipQFormerIntermediate(config)
  508. self.output_query = InstructBlipQFormerOutput(config)
  509. def forward(
  510. self,
  511. hidden_states,
  512. attention_mask=None,
  513. encoder_hidden_states=None,
  514. encoder_attention_mask=None,
  515. query_length=0,
  516. **kwargs: Unpack[TransformersKwargs],
  517. ):
  518. attention_output = self.attention(
  519. hidden_states,
  520. attention_mask=attention_mask,
  521. **kwargs,
  522. )
  523. if query_length > 0:
  524. query_attention_output = attention_output[:, :query_length, :]
  525. if self.has_cross_attention:
  526. if encoder_hidden_states is None:
  527. raise ValueError("encoder_hidden_states must be given for cross-attention layers")
  528. query_attention_output = self.crossattention(
  529. query_attention_output,
  530. attention_mask=attention_mask,
  531. encoder_hidden_states=encoder_hidden_states,
  532. encoder_attention_mask=encoder_attention_mask,
  533. **kwargs,
  534. )
  535. layer_output = apply_chunking_to_forward(
  536. self.feed_forward_chunk_query,
  537. self.chunk_size_feed_forward,
  538. self.seq_len_dim,
  539. query_attention_output,
  540. )
  541. if attention_output.shape[1] > query_length:
  542. layer_output_text = apply_chunking_to_forward(
  543. self.feed_forward_chunk,
  544. self.chunk_size_feed_forward,
  545. self.seq_len_dim,
  546. attention_output[:, query_length:, :],
  547. ).to(layer_output.device)
  548. layer_output = torch.cat([layer_output, layer_output_text], dim=1)
  549. else:
  550. layer_output = apply_chunking_to_forward(
  551. self.feed_forward_chunk,
  552. self.chunk_size_feed_forward,
  553. self.seq_len_dim,
  554. attention_output,
  555. )
  556. return layer_output
  557. def feed_forward_chunk(self, attention_output):
  558. intermediate_output = self.intermediate(attention_output)
  559. layer_output = self.output(intermediate_output, attention_output)
  560. return layer_output
  561. def feed_forward_chunk_query(self, attention_output):
  562. intermediate_output = self.intermediate_query(attention_output)
  563. layer_output = self.output_query(intermediate_output, attention_output)
  564. return layer_output
  565. # Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerEncoder with Blip2->InstructBlip
  566. class InstructBlipQFormerEncoder(nn.Module):
  567. def __init__(self, config):
  568. super().__init__()
  569. self.config = config
  570. self.layer = nn.ModuleList(
  571. [InstructBlipQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  572. )
  573. self.gradient_checkpointing = False
  574. @can_return_tuple
  575. def forward(
  576. self,
  577. hidden_states,
  578. attention_mask=None,
  579. encoder_hidden_states=None,
  580. encoder_attention_mask=None,
  581. query_length=0,
  582. **kwargs: Unpack[TransformersKwargs],
  583. ):
  584. for i in range(self.config.num_hidden_layers):
  585. layer_module = self.layer[i]
  586. hidden_states = layer_module(
  587. hidden_states,
  588. attention_mask,
  589. encoder_hidden_states, # as a positional argument for gradient checkpointing
  590. encoder_attention_mask=encoder_attention_mask,
  591. query_length=query_length,
  592. **kwargs,
  593. )
  594. return BaseModelOutputWithPastAndCrossAttentions(
  595. last_hidden_state=hidden_states,
  596. )
  597. class InstructBlipQFormerEmbeddings(nn.Module):
  598. """Construct the embeddings from word and position embeddings."""
  599. def __init__(self, config):
  600. super().__init__()
  601. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  602. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  603. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  604. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  605. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  606. self.register_buffer(
  607. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  608. )
  609. self.config = config
  610. def forward(
  611. self,
  612. input_ids=None,
  613. position_ids=None,
  614. query_embeds=None,
  615. past_key_values_length=0,
  616. ):
  617. if input_ids is not None:
  618. seq_length = input_ids.size()[1]
  619. else:
  620. seq_length = 0
  621. if position_ids is None:
  622. position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
  623. if input_ids is not None:
  624. embeddings = self.word_embeddings(input_ids)
  625. position_embeddings = self.position_embeddings(position_ids.to(embeddings.device))
  626. embeddings = embeddings + position_embeddings
  627. if query_embeds is not None:
  628. embeddings = torch.cat((query_embeds, embeddings), dim=1)
  629. else:
  630. embeddings = query_embeds
  631. embeddings = embeddings.to(self.layernorm.weight.dtype)
  632. embeddings = self.layernorm(embeddings)
  633. embeddings = self.dropout(embeddings)
  634. return embeddings
  635. class InstructBlipQFormerModel(InstructBlipPreTrainedModel):
  636. """
  637. Querying Transformer (Q-Former), used in InstructBLIP. Slightly modified from BLIP-2 as it also takes the
  638. instruction as input.
  639. """
  640. _supports_attention_backend = False # adds position on attn weights before last matmul
  641. _supports_flash_attn = False
  642. _supports_sdpa = False
  643. _supports_flex_attn = False
  644. _can_record_outputs = {
  645. "hidden_states": InstructBlipQFormerLayer,
  646. "attentions": [
  647. OutputRecorder(InstructBlipQFormerMultiHeadAttention, index=1, layer_name=".attention"),
  648. ],
  649. "cross_attentions": [
  650. OutputRecorder(InstructBlipQFormerMultiHeadAttention, index=1, layer_name=".crossattention"),
  651. ],
  652. }
  653. def __init__(self, config: InstructBlipQFormerConfig):
  654. super().__init__(config)
  655. self.config = config
  656. self.embeddings = InstructBlipQFormerEmbeddings(config)
  657. self.encoder = InstructBlipQFormerEncoder(config)
  658. self.post_init()
  659. def get_input_embeddings(self):
  660. return self.embeddings.word_embeddings
  661. def set_input_embeddings(self, value):
  662. self.embeddings.word_embeddings = value
  663. def get_extended_attention_mask(
  664. self,
  665. attention_mask: torch.Tensor,
  666. input_shape: tuple[int],
  667. device: torch.device,
  668. has_query: bool = False,
  669. ) -> torch.Tensor:
  670. """
  671. Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
  672. Arguments:
  673. attention_mask (`torch.Tensor`):
  674. Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
  675. input_shape (`tuple[int]`):
  676. The shape of the input to the model.
  677. device: (`torch.device`):
  678. The device of the input to the model.
  679. Returns:
  680. `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
  681. """
  682. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  683. # ourselves in which case we just need to make it broadcastable to all heads.
  684. if attention_mask.dim() == 3:
  685. extended_attention_mask = attention_mask[:, None, :, :]
  686. elif attention_mask.dim() == 2:
  687. # Provided a padding mask of dimensions [batch_size, seq_length]
  688. # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
  689. extended_attention_mask = attention_mask[:, None, None, :]
  690. else:
  691. raise ValueError(
  692. f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})",
  693. )
  694. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  695. # masked positions, this operation will create a tensor which is 0.0 for
  696. # positions we want to attend and -10000.0 for masked positions.
  697. # Since we are adding it to the raw scores before the softmax, this is
  698. # effectively the same as removing these entirely.
  699. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
  700. extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
  701. return extended_attention_mask
  702. @merge_with_config_defaults
  703. @capture_outputs
  704. @auto_docstring
  705. def forward(
  706. self,
  707. input_ids: torch.LongTensor,
  708. attention_mask: torch.FloatTensor | None = None,
  709. position_ids: torch.LongTensor | None = None,
  710. query_embeds: torch.Tensor | None = None,
  711. encoder_hidden_states: torch.FloatTensor | None = None,
  712. encoder_attention_mask: torch.FloatTensor | None = None,
  713. **kwargs: Unpack[TransformersKwargs],
  714. ) -> tuple[torch.FloatTensor] | BaseModelOutputWithPoolingAndCrossAttentions:
  715. r"""
  716. query_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  717. Hidden states to be used in the attention computation. If cross-attention,
  718. will be used for the query (i.e., key and value will use the encoder_hidden_states).
  719. """
  720. if input_ids is None and query_embeds is None:
  721. raise ValueError("You have to specify query_embeds when input_ids is None")
  722. query_length = query_embeds.shape[1] if query_embeds is not None else 0
  723. embedding_output = self.embeddings(
  724. input_ids=input_ids,
  725. position_ids=position_ids,
  726. query_embeds=query_embeds,
  727. )
  728. input_shape = embedding_output.size()[:-1]
  729. batch_size, seq_length = input_shape
  730. device = embedding_output.device
  731. if attention_mask is None:
  732. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  733. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  734. # ourselves in which case we just need to make it broadcastable to all heads.
  735. extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
  736. # If a 2D or 3D attention mask is provided for the cross-attention
  737. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  738. if encoder_hidden_states is not None:
  739. if isinstance(encoder_hidden_states, list):
  740. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
  741. else:
  742. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
  743. encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
  744. if isinstance(encoder_attention_mask, list):
  745. encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
  746. elif encoder_attention_mask is None:
  747. encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
  748. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  749. else:
  750. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  751. else:
  752. encoder_extended_attention_mask = None
  753. encoder_outputs: BaseModelOutput = self.encoder(
  754. embedding_output,
  755. attention_mask=extended_attention_mask,
  756. encoder_hidden_states=encoder_hidden_states,
  757. encoder_attention_mask=encoder_extended_attention_mask,
  758. query_length=query_length,
  759. **kwargs,
  760. )
  761. sequence_output = encoder_outputs.last_hidden_state
  762. pooled_output = sequence_output[:, 0, :]
  763. return BaseModelOutputWithPoolingAndCrossAttentions(
  764. last_hidden_state=sequence_output,
  765. pooler_output=pooled_output,
  766. )
  767. @auto_docstring(
  768. custom_intro="""
  769. InstructBLIP base Model consisting of language model, qformer and vision encoder.
  770. """
  771. )
  772. class InstructBlipModel(InstructBlipPreTrainedModel):
  773. main_input_name = "pixel_values"
  774. _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
  775. def __init__(self, config: InstructBlipConfig):
  776. super().__init__(config)
  777. self.vision_model = InstructBlipVisionModel(config.vision_config)
  778. self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
  779. self.qformer = InstructBlipQFormerModel(config.qformer_config)
  780. self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
  781. self.language_model = AutoModel.from_config(config.text_config)
  782. # Initialize weights and apply final processing
  783. self.post_init()
  784. def get_input_embeddings(self):
  785. return self.language_model.get_input_embeddings()
  786. def set_input_embeddings(self, value):
  787. self.language_model.set_input_embeddings(value)
  788. def _preprocess_accelerate(self):
  789. r"""
  790. Some pre-processing hacks to make the model `accelerate` compatible. Check
  791. https://github.com/huggingface/transformers/pull/21707 for more details.
  792. """
  793. hf_device_map = self.hf_device_map
  794. if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
  795. # warn users about unexpected behavior when using multi-GPU + InstructBLIP + `accelerate`.
  796. logger.warning(
  797. "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
  798. " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
  799. " Please pass a `device_map` that contains `language_model` to remove this warning."
  800. " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
  801. " more details on creating a `device_map` for large models.",
  802. )
  803. if hasattr(self.language_model, "_hf_hook"):
  804. self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
  805. def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
  806. """
  807. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
  808. """
  809. if input_ids is None:
  810. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  811. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  812. )
  813. special_image_mask = special_image_mask.all(-1)
  814. else:
  815. special_image_mask = input_ids == self.config.image_token_id
  816. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  817. return special_image_mask
  818. @can_return_tuple
  819. @auto_docstring
  820. def forward(
  821. self,
  822. pixel_values: torch.FloatTensor,
  823. qformer_input_ids: torch.FloatTensor,
  824. qformer_attention_mask: torch.LongTensor | None = None,
  825. input_ids: torch.FloatTensor | None = None,
  826. attention_mask: torch.LongTensor | None = None,
  827. decoder_input_ids: torch.LongTensor | None = None,
  828. decoder_attention_mask: torch.LongTensor | None = None,
  829. inputs_embeds: torch.Tensor | None = None,
  830. interpolate_pos_encoding: bool = False,
  831. **kwargs: Unpack[FlashAttentionKwargs],
  832. ) -> tuple | InstructBlipForConditionalGenerationModelOutput:
  833. r"""
  834. qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  835. Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
  836. to serve as text prompt, which the Q-Former model will encode.
  837. Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for
  838. details.
  839. [What are input IDs?](../glossary#input-ids)
  840. qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  841. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  842. - 1 for tokens that are **not masked**,
  843. - 0 for tokens that are **masked**.
  844. [What are attention masks?](../glossary#attention-mask)
  845. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  846. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  847. be used by default.
  848. Only relevant in case an encoder-decoder language model (like T5) is used.
  849. """
  850. # step 1: forward the images through the vision encoder,
  851. # to get image embeddings of shape (batch_size, seq_len, hidden_size)
  852. vision_outputs = self.vision_model(
  853. pixel_values=pixel_values,
  854. interpolate_pos_encoding=interpolate_pos_encoding,
  855. **kwargs,
  856. )
  857. image_embeds = vision_outputs[0]
  858. # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
  859. image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
  860. # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
  861. query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
  862. query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
  863. if qformer_attention_mask is None:
  864. qformer_attention_mask = torch.ones_like(qformer_input_ids)
  865. qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
  866. query_outputs = self.qformer(
  867. input_ids=qformer_input_ids,
  868. attention_mask=qformer_attention_mask,
  869. query_embeds=query_tokens,
  870. encoder_hidden_states=image_embeds,
  871. encoder_attention_mask=image_attention_mask,
  872. **kwargs,
  873. )
  874. query_output = query_outputs[0][:, : query_tokens.size(1), :]
  875. if inputs_embeds is None:
  876. inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
  877. if attention_mask is None:
  878. attention_mask = torch.ones_like(input_ids)
  879. # step 3: use the language model, conditioned on the query outputs and the prompt
  880. language_model_inputs = self.language_projection(query_output)
  881. language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
  882. special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
  883. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
  884. if self.config.use_decoder_only_language_model:
  885. outputs = self.language_model(
  886. inputs_embeds=inputs_embeds,
  887. attention_mask=attention_mask,
  888. **kwargs,
  889. )
  890. else:
  891. outputs = self.language_model(
  892. inputs_embeds=inputs_embeds,
  893. attention_mask=attention_mask,
  894. decoder_input_ids=decoder_input_ids,
  895. decoder_attention_mask=decoder_attention_mask,
  896. **kwargs,
  897. )
  898. return InstructBlipForConditionalGenerationModelOutput(
  899. vision_outputs=vision_outputs,
  900. qformer_outputs=query_outputs,
  901. language_model_outputs=outputs,
  902. )
  903. @auto_docstring(
  904. custom_intro="""
  905. InstructBLIP Model for generating text given an image and an optional text prompt. The model consists of a vision
  906. encoder, Querying Transformer (Q-Former) and a language model.
  907. One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
  908. the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
  909. """
  910. )
  911. class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
  912. config: InstructBlipConfig
  913. main_input_name = "pixel_values"
  914. _can_compile_fullgraph = True
  915. _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
  916. def __init__(self, config: InstructBlipConfig):
  917. super().__init__(config)
  918. self.vision_model = InstructBlipVisionModel._from_config(config.vision_config)
  919. self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
  920. self.qformer = InstructBlipQFormerModel._from_config(config.qformer_config)
  921. self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
  922. if config.use_decoder_only_language_model:
  923. language_model = AutoModelForCausalLM.from_config(config.text_config)
  924. else:
  925. language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
  926. self.language_model = language_model
  927. # Initialize weights and apply final processing
  928. self.post_init()
  929. def get_input_embeddings(self):
  930. return self.language_model.get_input_embeddings()
  931. def set_input_embeddings(self, value):
  932. self.language_model.set_input_embeddings(value)
  933. def set_output_embeddings(self, new_embeddings):
  934. self.language_model.set_output_embeddings(new_embeddings)
  935. def get_output_embeddings(self) -> nn.Module:
  936. return self.language_model.get_output_embeddings()
  937. def get_encoder(self, modality=None):
  938. if modality is None:
  939. return self.language_model.get_encoder()
  940. else:
  941. return super().get_encoder(modality=modality)
  942. def get_decoder(self):
  943. return self.language_model.get_decoder()
  944. # Copied from transformers.models.instructblip.modeling_instructblip.InstructBlipModel._preprocess_accelerate
  945. def _preprocess_accelerate(self):
  946. r"""
  947. Some pre-processing hacks to make the model `accelerate` compatible. Check
  948. https://github.com/huggingface/transformers/pull/21707 for more details.
  949. """
  950. hf_device_map = self.hf_device_map
  951. if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
  952. # warn users about unexpected behavior when using multi-GPU + InstructBLIP + `accelerate`.
  953. logger.warning(
  954. "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
  955. " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
  956. " Please pass a `device_map` that contains `language_model` to remove this warning."
  957. " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
  958. " more details on creating a `device_map` for large models.",
  959. )
  960. if hasattr(self.language_model, "_hf_hook"):
  961. self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
  962. @can_return_tuple
  963. @auto_docstring
  964. def get_image_features(
  965. self,
  966. pixel_values: torch.FloatTensor,
  967. qformer_input_ids: torch.LongTensor,
  968. qformer_attention_mask: torch.LongTensor | None = None,
  969. interpolate_pos_encoding: bool | None = False,
  970. **kwargs: Unpack[TransformersKwargs],
  971. ) -> tuple | BaseModelOutputWithVisionQformerOutputs:
  972. r"""
  973. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  974. The tensors corresponding to the input images.
  975. qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  976. Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
  977. to serve as text prompt, which the Q-Former model will encode.
  978. Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for
  979. details.
  980. [What are input IDs?](../glossary#input-ids)
  981. qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  982. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  983. - 1 for tokens that are **not masked**,
  984. - 0 for tokens that are **masked**.
  985. [What are attention masks?](../glossary#attention-mask)
  986. """
  987. # step 1: forward the images through the vision encoder,
  988. # to get image embeddings of shape (batch_size, seq_len, hidden_size)
  989. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  990. pixel_values=pixel_values,
  991. interpolate_pos_encoding=interpolate_pos_encoding,
  992. return_dict=True,
  993. **kwargs,
  994. )
  995. vision_outputs = BaseModelOutputWithVisionQformerOutputs(**vision_outputs, vision_outputs=vision_outputs)
  996. image_embeds = vision_outputs[0]
  997. # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
  998. image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
  999. # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
  1000. query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
  1001. query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
  1002. if qformer_attention_mask is None:
  1003. qformer_attention_mask = torch.ones_like(qformer_input_ids)
  1004. qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
  1005. qformer_outputs = self.qformer(
  1006. input_ids=qformer_input_ids,
  1007. attention_mask=qformer_attention_mask,
  1008. query_embeds=query_tokens,
  1009. encoder_hidden_states=image_embeds,
  1010. encoder_attention_mask=image_attention_mask,
  1011. return_dict=True,
  1012. **kwargs,
  1013. )
  1014. vision_outputs.qformer_outputs = qformer_outputs
  1015. query_output = qformer_outputs[0][:, : query_tokens.size(1), :]
  1016. # step 3: use the language model, conditioned on the query outputs and the prompt
  1017. image_features = self.language_projection(query_output)
  1018. vision_outputs.pooler_output = image_features
  1019. return vision_outputs
  1020. def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
  1021. """
  1022. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
  1023. """
  1024. if input_ids is None:
  1025. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  1026. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  1027. )
  1028. special_image_mask = special_image_mask.all(-1)
  1029. else:
  1030. special_image_mask = input_ids == self.config.image_token_id
  1031. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  1032. return special_image_mask
  1033. @can_return_tuple
  1034. @auto_docstring
  1035. def forward(
  1036. self,
  1037. pixel_values: torch.FloatTensor,
  1038. qformer_input_ids: torch.FloatTensor,
  1039. qformer_attention_mask: torch.LongTensor | None = None,
  1040. input_ids: torch.FloatTensor | None = None,
  1041. attention_mask: torch.LongTensor | None = None,
  1042. decoder_input_ids: torch.LongTensor | None = None,
  1043. decoder_attention_mask: torch.LongTensor | None = None,
  1044. inputs_embeds: torch.FloatTensor | None = None,
  1045. labels: torch.LongTensor | None = None,
  1046. interpolate_pos_encoding: bool = False,
  1047. **kwargs: Unpack[TransformersKwargs],
  1048. ) -> tuple | InstructBlipForConditionalGenerationModelOutput:
  1049. r"""
  1050. qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1051. Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
  1052. to serve as text prompt, which the Q-Former model will encode.
  1053. Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for
  1054. details.
  1055. [What are input IDs?](../glossary#input-ids)
  1056. qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1057. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1058. - 1 for tokens that are **not masked**,
  1059. - 0 for tokens that are **masked**.
  1060. [What are attention masks?](../glossary#attention-mask)
  1061. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1062. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1063. be used by default.
  1064. Only relevant in case an encoder-decoder language model (like T5) is used.
  1065. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1066. Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size -
  1067. 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
  1068. config.vocab_size]`
  1069. Examples:
  1070. ```python
  1071. >>> from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
  1072. >>> import torch
  1073. >>> from PIL import Image
  1074. >>> import httpx
  1075. >>> from io import BytesIO
  1076. >>> model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b")
  1077. >>> processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
  1078. >>> device = "cuda" if torch.cuda.is_available() else "cpu"
  1079. >>> model.to(device) # doctest: +IGNORE_RESULT
  1080. >>> url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg"
  1081. >>> with httpx.stream("GET", url) as response:
  1082. ... image = Image.open(BytesIO(response.read())).convert("RGB")
  1083. >>> prompt = "What is unusual about this image?"
  1084. >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
  1085. >>> outputs = model.generate(
  1086. ... **inputs,
  1087. ... do_sample=False,
  1088. ... num_beams=5,
  1089. ... max_length=256,
  1090. ... min_length=1,
  1091. ... top_p=0.9,
  1092. ... repetition_penalty=1.5,
  1093. ... length_penalty=1.0,
  1094. ... temperature=1,
  1095. ... )
  1096. >>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
  1097. >>> print(generated_text)
  1098. The unusual aspect of this image is that a man is ironing clothes on the back of a yellow SUV, which is parked in the middle of a busy city street. This is an unconventional approach to ironing clothes, as it requires the man to balance himself and his ironing equipment on top of the vehicle while navigating through traffic. Additionally, the presence of taxis and other vehicles in the scene further emphasizes the unusual nature of this situation.
  1099. ```"""
  1100. image_features: BaseModelOutputWithVisionQformerOutputs = self.get_image_features(
  1101. pixel_values,
  1102. qformer_input_ids=qformer_input_ids,
  1103. qformer_attention_mask=qformer_attention_mask,
  1104. interpolate_pos_encoding=interpolate_pos_encoding,
  1105. return_dict=True,
  1106. )
  1107. language_model_inputs = image_features.pooler_output
  1108. qformer_outputs = image_features.qformer_outputs
  1109. vision_outputs = image_features.vision_outputs
  1110. if inputs_embeds is None:
  1111. inputs_embeds = self.get_input_embeddings()(input_ids)
  1112. if attention_mask is None:
  1113. attention_mask = torch.ones_like(input_ids)
  1114. language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
  1115. special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
  1116. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
  1117. if self.config.use_decoder_only_language_model:
  1118. outputs = self.language_model(
  1119. inputs_embeds=inputs_embeds,
  1120. attention_mask=attention_mask,
  1121. **kwargs,
  1122. )
  1123. logits = outputs[0]
  1124. loss = None
  1125. if labels is not None:
  1126. loss = self.loss_function(
  1127. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  1128. )
  1129. else:
  1130. kwargs["return_dict"] = True
  1131. outputs = self.language_model(
  1132. inputs_embeds=inputs_embeds,
  1133. attention_mask=attention_mask,
  1134. decoder_input_ids=decoder_input_ids,
  1135. decoder_attention_mask=decoder_attention_mask,
  1136. labels=labels,
  1137. **kwargs,
  1138. )
  1139. loss = outputs.loss
  1140. logits = outputs.logits
  1141. return InstructBlipForConditionalGenerationModelOutput(
  1142. loss=loss,
  1143. logits=logits,
  1144. vision_outputs=vision_outputs,
  1145. qformer_outputs=qformer_outputs,
  1146. language_model_outputs=outputs,
  1147. )
  1148. @torch.no_grad()
  1149. def generate(
  1150. self,
  1151. pixel_values: torch.FloatTensor,
  1152. qformer_input_ids: torch.LongTensor | None = None,
  1153. qformer_attention_mask: torch.LongTensor | None = None,
  1154. input_ids: torch.LongTensor | None = None,
  1155. attention_mask: torch.LongTensor | None = None,
  1156. inputs_embeds: torch.FloatTensor | None = None,
  1157. interpolate_pos_encoding: bool = False,
  1158. **generate_kwargs,
  1159. ) -> torch.LongTensor:
  1160. """
  1161. Overrides `generate` function to be able to use the model as a conditional generator.
  1162. Args:
  1163. pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)):
  1164. Input images to be processed.
  1165. qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  1166. The sequence used as a prompt to be fed to the Q-Former module.
  1167. qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  1168. Mask to avoid performing attention on padding token indices.
  1169. input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  1170. The sequence used as a prompt for the generation.
  1171. attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  1172. Mask to avoid performing attention on padding token indices.
  1173. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  1174. Embedded representation of the inputs. Should be float, not int tokens.
  1175. interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
  1176. Whether to interpolate the positional encoding of the image embeddings.
  1177. Returns:
  1178. captions (list): A list of strings of length batch_size * num_captions.
  1179. """
  1180. if hasattr(self, "hf_device_map"):
  1181. # preprocess for `accelerate`
  1182. self._preprocess_accelerate()
  1183. batch_size = pixel_values.shape[0]
  1184. image_features: BaseModelOutputWithVisionQformerOutputs = self.get_image_features(
  1185. pixel_values,
  1186. qformer_input_ids=qformer_input_ids,
  1187. qformer_attention_mask=qformer_attention_mask,
  1188. interpolate_pos_encoding=interpolate_pos_encoding,
  1189. return_dict=True,
  1190. )
  1191. language_model_inputs = image_features.pooler_output
  1192. if inputs_embeds is None:
  1193. if input_ids is None:
  1194. image_tokens = [self.config.image_token_index] * self.config.num_query_tokens
  1195. start_tokens = image_tokens + [self.config.text_config.bos_token_id]
  1196. input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
  1197. input_ids = input_ids.repeat(batch_size, 1)
  1198. inputs_embeds = self.get_input_embeddings()(input_ids)
  1199. if attention_mask is None:
  1200. attention_mask = torch.ones_like(input_ids)
  1201. language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
  1202. special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
  1203. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
  1204. inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
  1205. if not self.language_model.config.is_encoder_decoder:
  1206. inputs["input_ids"] = input_ids
  1207. outputs = self.language_model.generate(**inputs, **generate_kwargs)
  1208. return outputs
  # Public API of this module: the names exported by `from <module> import *`.
  # Kept as a plain literal list of strings.
  # NOTE(review): the listed classes are presumably defined earlier in this file (outside the reviewed chunk) — confirm.
  __all__ = [
      "InstructBlipQFormerModel",
      "InstructBlipPreTrainedModel",
      "InstructBlipModel",
      "InstructBlipForConditionalGeneration",
      "InstructBlipVisionModel",
  ]