modeling_instructblipvideo.py 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/instructblipvideo/modular_instructblipvideo.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_instructblipvideo.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import math
  21. from collections.abc import Callable
  22. from dataclasses import dataclass
  23. from typing import Any
  24. import torch
  25. from torch import nn
  26. from ... import initialization as init
  27. from ...activations import ACT2FN
  28. from ...generation import GenerationMixin
  29. from ...modeling_layers import GradientCheckpointingLayer
  30. from ...modeling_outputs import (
  31. BaseModelOutput,
  32. BaseModelOutputWithPastAndCrossAttentions,
  33. BaseModelOutputWithPooling,
  34. BaseModelOutputWithPoolingAndCrossAttentions,
  35. CausalLMOutputWithPast,
  36. Seq2SeqLMOutput,
  37. )
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...pytorch_utils import apply_chunking_to_forward
  41. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
  42. from ...utils.generic import merge_with_config_defaults
  43. from ...utils.output_capturing import OutputRecorder, capture_outputs
  44. from ..auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
  45. from .configuration_instructblipvideo import (
  46. InstructBlipVideoConfig,
  47. InstructBlipVideoQFormerConfig,
  48. InstructBlipVideoVisionConfig,
  49. )
  50. logger = logging.get_logger(__name__)
  51. class InstructBlipVideoVisionEmbeddings(nn.Module):
  52. def __init__(self, config: InstructBlipVideoVisionConfig):
  53. super().__init__()
  54. self.config = config
  55. self.embed_dim = config.hidden_size
  56. self.image_size = config.image_size
  57. self.patch_size = config.patch_size
  58. self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
  59. self.patch_embedding = nn.Conv2d(
  60. in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
  61. )
  62. self.num_patches = (self.image_size // self.patch_size) ** 2
  63. self.num_positions = self.num_patches + 1
  64. self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
  65. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  66. """
  67. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  68. images. This method is also adapted to support torch.jit tracing.
  69. Adapted from:
  70. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  71. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  72. """
  73. num_patches = embeddings.shape[1] - 1
  74. num_positions = self.position_embedding.shape[1] - 1
  75. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  76. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  77. return self.position_embedding
  78. class_pos_embed = self.position_embedding[:, :1]
  79. patch_pos_embed = self.position_embedding[:, 1:]
  80. dim = embeddings.shape[-1]
  81. new_height = height // self.patch_size
  82. new_width = width // self.patch_size
  83. sqrt_num_positions = torch_int(num_positions**0.5)
  84. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  85. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  86. patch_pos_embed = nn.functional.interpolate(
  87. patch_pos_embed,
  88. size=(new_height, new_width),
  89. mode="bicubic",
  90. align_corners=False,
  91. )
  92. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  93. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  94. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  95. batch_size, _, height, width = pixel_values.shape
  96. target_dtype = self.patch_embedding.weight.dtype
  97. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  98. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  99. class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
  100. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  101. if interpolate_pos_encoding:
  102. position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
  103. else:
  104. position_embedding = self.position_embedding
  105. embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype)
  106. return embeddings
  107. class InstructBlipVideoQFormerEmbeddings(nn.Module):
  108. """Construct the embeddings from word and position embeddings."""
  109. def __init__(self, config):
  110. super().__init__()
  111. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  112. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  113. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  114. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  115. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  116. self.register_buffer(
  117. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  118. )
  119. self.config = config
  120. def forward(
  121. self,
  122. input_ids=None,
  123. position_ids=None,
  124. query_embeds=None,
  125. past_key_values_length=0,
  126. ):
  127. if input_ids is not None:
  128. seq_length = input_ids.size()[1]
  129. else:
  130. seq_length = 0
  131. if position_ids is None:
  132. position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
  133. if input_ids is not None:
  134. embeddings = self.word_embeddings(input_ids)
  135. position_embeddings = self.position_embeddings(position_ids.to(embeddings.device))
  136. embeddings = embeddings + position_embeddings
  137. if query_embeds is not None:
  138. embeddings = torch.cat((query_embeds, embeddings), dim=1)
  139. else:
  140. embeddings = query_embeds
  141. embeddings = embeddings.to(self.layernorm.weight.dtype)
  142. embeddings = self.layernorm(embeddings)
  143. embeddings = self.dropout(embeddings)
  144. return embeddings
  145. @auto_docstring
  146. class InstructBlipVideoPreTrainedModel(PreTrainedModel):
  147. config: InstructBlipVideoConfig
  148. base_model_prefix = "blip"
  149. input_modalities = ("video", "text")
  150. supports_gradient_checkpointing = True
  151. _supports_attention_backend = True
  152. _supports_flash_attn = True
  153. _supports_sdpa = True
  154. _supports_flex_attn = True
  155. _can_compile_fullgraph = True
  156. _no_split_modules = [
  157. "InstructBlipVideoQFormerEmbeddings",
  158. "InstructBlipVideoAttention",
  159. "InstructBlipVideoQFormerMultiHeadAttention",
  160. "InstructBlipVideoQFormerSelfOutput",
  161. ]
  162. @torch.no_grad()
  163. def _init_weights(self, module):
  164. """Initialize the weights"""
  165. super()._init_weights(module)
  166. factor = self.config.initializer_range
  167. if isinstance(module, InstructBlipVideoVisionEmbeddings):
  168. init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
  169. init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
  170. elif isinstance(module, (InstructBlipVideoForConditionalGeneration, InstructBlipVideoModel)):
  171. init.zeros_(module.query_tokens)
  172. elif isinstance(module, InstructBlipVideoQFormerEmbeddings):
  173. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  174. # Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> InstructBlipVideo doesn't cast attn weights to fp32
  175. def eager_attention_forward(
  176. module: nn.Module,
  177. query: torch.Tensor,
  178. key: torch.Tensor,
  179. value: torch.Tensor,
  180. attention_mask: torch.Tensor | None,
  181. scaling: float,
  182. dropout: float = 0.0,
  183. **kwargs,
  184. ):
  185. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  186. if attention_mask is not None:
  187. attn_weights = attn_weights + attention_mask
  188. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  189. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  190. attn_output = torch.matmul(attn_weights, value)
  191. attn_output = attn_output.transpose(1, 2).contiguous()
  192. return attn_output, attn_weights
  193. class InstructBlipVideoAttention(nn.Module):
  194. """Multi-headed attention from 'Attention Is All You Need' paper"""
  195. def __init__(self, config):
  196. super().__init__()
  197. self.config = config
  198. self.embed_dim = config.hidden_size
  199. self.num_heads = config.num_attention_heads
  200. self.head_dim = self.embed_dim // self.num_heads
  201. if self.head_dim * self.num_heads != self.embed_dim:
  202. raise ValueError(
  203. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  204. f" {self.num_heads})."
  205. )
  206. self.scale = self.head_dim**-0.5
  207. self.is_causal = False
  208. self.attention_dropout = config.attention_dropout
  209. # small tweak here compared to CLIP, no bias here
  210. self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
  211. if config.qkv_bias:
  212. q_bias = nn.Parameter(torch.zeros(self.embed_dim))
  213. v_bias = nn.Parameter(torch.zeros(self.embed_dim))
  214. else:
  215. q_bias = None
  216. v_bias = None
  217. if q_bias is not None:
  218. qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
  219. self.qkv.bias = nn.Parameter(qkv_bias)
  220. self.projection = nn.Linear(self.embed_dim, self.embed_dim)
  221. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  222. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
  223. def forward(
  224. self,
  225. hidden_states: torch.Tensor,
  226. **kwargs,
  227. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  228. """Input shape: Batch x Time x Channel"""
  229. bsz, tgt_len, embed_dim = hidden_states.size()
  230. mixed_qkv = self.qkv(hidden_states)
  231. mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
  232. 2, 0, 3, 1, 4
  233. )
  234. query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
  235. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  236. self.config._attn_implementation, eager_attention_forward
  237. )
  238. attn_output, attn_weights = attention_interface(
  239. self,
  240. query_states,
  241. key_states,
  242. value_states,
  243. attention_mask=None,
  244. dropout=0.0 if not self.training else self.attention_dropout,
  245. scaling=self.scale,
  246. **kwargs,
  247. )
  248. attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
  249. attn_output = self.projection(attn_output)
  250. return attn_output, attn_weights
  251. class InstructBlipVideoMLP(nn.Module):
  252. def __init__(self, config):
  253. super().__init__()
  254. self.config = config
  255. self.activation_fn = ACT2FN[config.hidden_act]
  256. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  257. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  258. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  259. hidden_states = self.fc1(hidden_states)
  260. hidden_states = self.activation_fn(hidden_states)
  261. hidden_states = self.fc2(hidden_states)
  262. return hidden_states
  263. class InstructBlipVideoEncoderLayer(GradientCheckpointingLayer):
  264. def __init__(self, config: InstructBlipVideoConfig):
  265. super().__init__()
  266. self.embed_dim = config.hidden_size
  267. self.self_attn = InstructBlipVideoAttention(config)
  268. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  269. self.mlp = InstructBlipVideoMLP(config)
  270. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  271. @auto_docstring
  272. def forward(
  273. self,
  274. hidden_states: torch.Tensor,
  275. **kwargs: Unpack[TransformersKwargs],
  276. ) -> torch.FloatTensor:
  277. residual = hidden_states
  278. hidden_states = self.layer_norm1(hidden_states)
  279. hidden_states, _ = self.self_attn(
  280. hidden_states=hidden_states,
  281. **kwargs,
  282. )
  283. hidden_states = hidden_states + residual
  284. residual = hidden_states
  285. hidden_states = self.layer_norm2(hidden_states)
  286. hidden_states = self.mlp(hidden_states)
  287. hidden_states = hidden_states + residual
  288. return hidden_states
  289. class InstructBlipVideoEncoder(nn.Module):
  290. """
  291. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  292. [`InstructBlipVideoEncoderLayer`].
  293. Args:
  294. config (`InstructBlipVideoConfig`):
  295. The corresponding vision configuration for the `InstructBlipVideoEncoder`.
  296. """
  297. def __init__(self, config: InstructBlipVideoConfig):
  298. super().__init__()
  299. self.config = config
  300. self.layers = nn.ModuleList([InstructBlipVideoEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  301. self.gradient_checkpointing = False
  302. @auto_docstring
  303. def forward(
  304. self,
  305. inputs_embeds,
  306. **kwargs: Unpack[TransformersKwargs],
  307. ) -> tuple | BaseModelOutput:
  308. hidden_states = inputs_embeds
  309. for encoder_layer in self.layers:
  310. hidden_states = encoder_layer(
  311. hidden_states,
  312. **kwargs,
  313. )
  314. return BaseModelOutput(last_hidden_state=hidden_states)
  315. class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel):
  316. main_input_name = "pixel_values"
  317. input_modalities = "video"
  318. config: InstructBlipVideoVisionConfig
  319. _can_record_outputs = {
  320. "hidden_states": InstructBlipVideoEncoderLayer,
  321. "attentions": InstructBlipVideoAttention,
  322. }
  323. def __init__(self, config: InstructBlipVideoVisionConfig):
  324. super().__init__(config)
  325. self.config = config
  326. embed_dim = config.hidden_size
  327. self.embeddings = InstructBlipVideoVisionEmbeddings(config)
  328. self.encoder = InstructBlipVideoEncoder(config)
  329. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  330. self.post_init()
  331. @merge_with_config_defaults
  332. @capture_outputs(tie_last_hidden_states=False)
  333. @auto_docstring
  334. def forward(
  335. self,
  336. pixel_values: torch.FloatTensor | None = None,
  337. interpolate_pos_encoding: bool = False,
  338. **kwargs: Unpack[TransformersKwargs],
  339. ) -> tuple | BaseModelOutputWithPooling:
  340. if pixel_values is None:
  341. raise ValueError("You have to specify pixel_values")
  342. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  343. encoder_outputs: BaseModelOutput = self.encoder(
  344. inputs_embeds=hidden_states,
  345. **kwargs,
  346. )
  347. last_hidden_state = encoder_outputs.last_hidden_state
  348. last_hidden_state = self.post_layernorm(last_hidden_state)
  349. pooled_output = last_hidden_state[:, 0, :]
  350. pooled_output = self.post_layernorm(pooled_output)
  351. return BaseModelOutputWithPooling(
  352. last_hidden_state=last_hidden_state,
  353. pooler_output=pooled_output,
  354. )
  355. def get_input_embeddings(self):
  356. return self.embeddings
  357. class InstructBlipVideoQFormerMultiHeadAttention(nn.Module):
  358. def __init__(self, config, is_cross_attention=False):
  359. super().__init__()
  360. self.config = config
  361. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  362. raise ValueError(
  363. "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
  364. % (config.hidden_size, config.num_attention_heads)
  365. )
  366. self.num_attention_heads = config.num_attention_heads
  367. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  368. self.all_head_size = self.num_attention_heads * self.attention_head_size
  369. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  370. if is_cross_attention:
  371. self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
  372. self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
  373. else:
  374. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  375. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  376. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  377. self.save_attention = False
  378. def save_attn_gradients(self, attn_gradients):
  379. self.attn_gradients = attn_gradients
  380. def get_attn_gradients(self):
  381. return self.attn_gradients
  382. def save_attention_map(self, attention_map):
  383. self.attention_map = attention_map
  384. def get_attention_map(self):
  385. return self.attention_map
  386. def transpose_for_scores(self, x):
  387. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  388. x = x.view(*new_x_shape)
  389. return x.permute(0, 2, 1, 3)
  390. def forward(
  391. self,
  392. hidden_states,
  393. attention_mask=None,
  394. encoder_hidden_states=None,
  395. encoder_attention_mask=None,
  396. **kwargs: Unpack[TransformersKwargs],
  397. ):
  398. # If this is instantiated as a cross-attention module, the keys
  399. # and values come from an encoder; the attention mask needs to be
  400. # such that the encoder's padding tokens are not attended to.
  401. is_cross_attention = encoder_hidden_states is not None
  402. if is_cross_attention:
  403. key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
  404. value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
  405. attention_mask = encoder_attention_mask
  406. else:
  407. key_layer = self.transpose_for_scores(self.key(hidden_states))
  408. value_layer = self.transpose_for_scores(self.value(hidden_states))
  409. mixed_query_layer = self.query(hidden_states)
  410. query_layer = self.transpose_for_scores(mixed_query_layer)
  411. # Take the dot product between "query" and "key" to get the raw attention scores.
  412. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  413. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  414. attention_scores_dtype = attention_scores.dtype
  415. if attention_mask is not None:
  416. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  417. attention_scores = attention_scores + attention_mask
  418. # Normalize the attention scores to probabilities.
  419. attention_probs = nn.Softmax(dim=-1)(attention_scores).to(attention_scores_dtype)
  420. if is_cross_attention and self.save_attention:
  421. self.save_attention_map(attention_probs)
  422. attention_probs.register_hook(self.save_attn_gradients)
  423. # This is actually dropping out entire tokens to attend to, which might
  424. # seem a bit unusual, but is taken from the original Transformer paper.
  425. attention_probs_dropped = self.dropout(attention_probs)
  426. context_layer = torch.matmul(attention_probs_dropped, value_layer)
  427. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  428. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  429. context_layer = context_layer.view(*new_context_layer_shape)
  430. return context_layer, attention_probs
  431. class InstructBlipVideoQFormerSelfOutput(nn.Module):
  432. def __init__(self, config):
  433. super().__init__()
  434. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  435. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  436. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  437. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  438. hidden_states = self.dense(hidden_states)
  439. hidden_states = self.dropout(hidden_states)
  440. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  441. return hidden_states
  442. class InstructBlipVideoQFormerAttention(nn.Module):
  443. def __init__(self, config, is_cross_attention=False):
  444. super().__init__()
  445. self.attention = InstructBlipVideoQFormerMultiHeadAttention(config, is_cross_attention)
  446. self.output = InstructBlipVideoQFormerSelfOutput(config)
  447. def forward(
  448. self,
  449. hidden_states: torch.Tensor,
  450. attention_mask: torch.FloatTensor | None = None,
  451. encoder_hidden_states: torch.FloatTensor | None = None,
  452. encoder_attention_mask: torch.FloatTensor | None = None,
  453. **kwargs: Unpack[TransformersKwargs],
  454. ) -> torch.Tensor:
  455. attn_output, _ = self.attention(
  456. hidden_states=hidden_states,
  457. attention_mask=attention_mask,
  458. encoder_hidden_states=encoder_hidden_states,
  459. encoder_attention_mask=encoder_attention_mask,
  460. **kwargs,
  461. )
  462. attention_output = self.output(attn_output, hidden_states)
  463. return attention_output
  464. class InstructBlipVideoQFormerIntermediate(nn.Module):
  465. def __init__(self, config):
  466. super().__init__()
  467. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  468. if isinstance(config.hidden_act, str):
  469. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  470. else:
  471. self.intermediate_act_fn = config.hidden_act
  472. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  473. hidden_states = self.dense(hidden_states)
  474. hidden_states = self.intermediate_act_fn(hidden_states)
  475. return hidden_states
  476. class InstructBlipVideoQFormerOutput(nn.Module):
  477. def __init__(self, config):
  478. super().__init__()
  479. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  480. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  481. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  482. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  483. hidden_states = self.dense(hidden_states)
  484. hidden_states = self.dropout(hidden_states)
  485. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  486. return hidden_states
  487. class InstructBlipVideoQFormerLayer(GradientCheckpointingLayer):
  488. def __init__(self, config, layer_idx):
  489. super().__init__()
  490. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  491. self.seq_len_dim = 1
  492. self.attention = InstructBlipVideoQFormerAttention(config)
  493. self.layer_idx = layer_idx
  494. if layer_idx % config.cross_attention_frequency == 0:
  495. self.crossattention = InstructBlipVideoQFormerAttention(config, is_cross_attention=True)
  496. self.has_cross_attention = True
  497. else:
  498. self.has_cross_attention = False
  499. self.intermediate = InstructBlipVideoQFormerIntermediate(config)
  500. self.output = InstructBlipVideoQFormerOutput(config)
  501. self.intermediate_query = InstructBlipVideoQFormerIntermediate(config)
  502. self.output_query = InstructBlipVideoQFormerOutput(config)
  503. def forward(
  504. self,
  505. hidden_states,
  506. attention_mask=None,
  507. encoder_hidden_states=None,
  508. encoder_attention_mask=None,
  509. query_length=0,
  510. **kwargs: Unpack[TransformersKwargs],
  511. ):
  512. attention_output = self.attention(
  513. hidden_states,
  514. attention_mask=attention_mask,
  515. **kwargs,
  516. )
  517. if query_length > 0:
  518. query_attention_output = attention_output[:, :query_length, :]
  519. if self.has_cross_attention:
  520. if encoder_hidden_states is None:
  521. raise ValueError("encoder_hidden_states must be given for cross-attention layers")
  522. query_attention_output = self.crossattention(
  523. query_attention_output,
  524. attention_mask=attention_mask,
  525. encoder_hidden_states=encoder_hidden_states,
  526. encoder_attention_mask=encoder_attention_mask,
  527. **kwargs,
  528. )
  529. layer_output = apply_chunking_to_forward(
  530. self.feed_forward_chunk_query,
  531. self.chunk_size_feed_forward,
  532. self.seq_len_dim,
  533. query_attention_output,
  534. )
  535. if attention_output.shape[1] > query_length:
  536. layer_output_text = apply_chunking_to_forward(
  537. self.feed_forward_chunk,
  538. self.chunk_size_feed_forward,
  539. self.seq_len_dim,
  540. attention_output[:, query_length:, :],
  541. ).to(layer_output.device)
  542. layer_output = torch.cat([layer_output, layer_output_text], dim=1)
  543. else:
  544. layer_output = apply_chunking_to_forward(
  545. self.feed_forward_chunk,
  546. self.chunk_size_feed_forward,
  547. self.seq_len_dim,
  548. attention_output,
  549. )
  550. return layer_output
  551. def feed_forward_chunk(self, attention_output):
  552. intermediate_output = self.intermediate(attention_output)
  553. layer_output = self.output(intermediate_output, attention_output)
  554. return layer_output
  555. def feed_forward_chunk_query(self, attention_output):
  556. intermediate_output = self.intermediate_query(attention_output)
  557. layer_output = self.output_query(intermediate_output, attention_output)
  558. return layer_output
  559. class InstructBlipVideoQFormerEncoder(nn.Module):
  560. def __init__(self, config):
  561. super().__init__()
  562. self.config = config
  563. self.layer = nn.ModuleList(
  564. [InstructBlipVideoQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  565. )
  566. self.gradient_checkpointing = False
  567. @can_return_tuple
  568. def forward(
  569. self,
  570. hidden_states,
  571. attention_mask=None,
  572. encoder_hidden_states=None,
  573. encoder_attention_mask=None,
  574. query_length=0,
  575. **kwargs: Unpack[TransformersKwargs],
  576. ):
  577. for i in range(self.config.num_hidden_layers):
  578. layer_module = self.layer[i]
  579. hidden_states = layer_module(
  580. hidden_states,
  581. attention_mask,
  582. encoder_hidden_states, # as a positional argument for gradient checkpointing
  583. encoder_attention_mask=encoder_attention_mask,
  584. query_length=query_length,
  585. **kwargs,
  586. )
  587. return BaseModelOutputWithPastAndCrossAttentions(
  588. last_hidden_state=hidden_states,
  589. )
  590. class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel):
  591. """
  592. Querying Transformer (Q-Former), used in InstructBlipVideo. Slightly modified from BLIP-2 as it also takes the
  593. instruction as input.
  594. """
  595. _supports_attention_backend = False # adds position on attn weights before last matmul
  596. _supports_flash_attn = False
  597. _supports_sdpa = False
  598. _supports_flex_attn = False
  599. _can_record_outputs = {
  600. "hidden_states": InstructBlipVideoQFormerLayer,
  601. "attentions": [
  602. OutputRecorder(InstructBlipVideoQFormerMultiHeadAttention, index=1, layer_name=".attention"),
  603. ],
  604. "cross_attentions": [
  605. OutputRecorder(InstructBlipVideoQFormerMultiHeadAttention, index=1, layer_name=".crossattention"),
  606. ],
  607. }
  608. def __init__(self, config: InstructBlipVideoQFormerConfig):
  609. super().__init__(config)
  610. self.config = config
  611. self.embeddings = InstructBlipVideoQFormerEmbeddings(config)
  612. self.encoder = InstructBlipVideoQFormerEncoder(config)
  613. self.post_init()
  614. def get_input_embeddings(self):
  615. return self.embeddings.word_embeddings
  616. def set_input_embeddings(self, value):
  617. self.embeddings.word_embeddings = value
  618. def get_extended_attention_mask(
  619. self,
  620. attention_mask: torch.Tensor,
  621. input_shape: tuple[int],
  622. device: torch.device,
  623. has_query: bool = False,
  624. ) -> torch.Tensor:
  625. """
  626. Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
  627. Arguments:
  628. attention_mask (`torch.Tensor`):
  629. Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
  630. input_shape (`tuple[int]`):
  631. The shape of the input to the model.
  632. device: (`torch.device`):
  633. The device of the input to the model.
  634. Returns:
  635. `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
  636. """
  637. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  638. # ourselves in which case we just need to make it broadcastable to all heads.
  639. if attention_mask.dim() == 3:
  640. extended_attention_mask = attention_mask[:, None, :, :]
  641. elif attention_mask.dim() == 2:
  642. # Provided a padding mask of dimensions [batch_size, seq_length]
  643. # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
  644. extended_attention_mask = attention_mask[:, None, None, :]
  645. else:
  646. raise ValueError(
  647. f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})",
  648. )
  649. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  650. # masked positions, this operation will create a tensor which is 0.0 for
  651. # positions we want to attend and -10000.0 for masked positions.
  652. # Since we are adding it to the raw scores before the softmax, this is
  653. # effectively the same as removing these entirely.
  654. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
  655. extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
  656. return extended_attention_mask
  657. @merge_with_config_defaults
  658. @capture_outputs
  659. @auto_docstring
  660. def forward(
  661. self,
  662. input_ids: torch.LongTensor,
  663. attention_mask: torch.FloatTensor | None = None,
  664. position_ids: torch.LongTensor | None = None,
  665. query_embeds: torch.Tensor | None = None,
  666. encoder_hidden_states: torch.FloatTensor | None = None,
  667. encoder_attention_mask: torch.FloatTensor | None = None,
  668. **kwargs: Unpack[TransformersKwargs],
  669. ) -> tuple[torch.FloatTensor] | BaseModelOutputWithPoolingAndCrossAttentions:
  670. r"""
  671. query_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  672. Hidden states to be used in the attention computation. If cross-attention,
  673. will be used for the query (i.e., key and value will use the encoder_hidden_states).
  674. """
  675. if input_ids is None and query_embeds is None:
  676. raise ValueError("You have to specify query_embeds when input_ids is None")
  677. query_length = query_embeds.shape[1] if query_embeds is not None else 0
  678. embedding_output = self.embeddings(
  679. input_ids=input_ids,
  680. position_ids=position_ids,
  681. query_embeds=query_embeds,
  682. )
  683. input_shape = embedding_output.size()[:-1]
  684. batch_size, seq_length = input_shape
  685. device = embedding_output.device
  686. if attention_mask is None:
  687. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  688. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  689. # ourselves in which case we just need to make it broadcastable to all heads.
  690. extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
  691. # If a 2D or 3D attention mask is provided for the cross-attention
  692. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  693. if encoder_hidden_states is not None:
  694. if isinstance(encoder_hidden_states, list):
  695. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
  696. else:
  697. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
  698. encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
  699. if isinstance(encoder_attention_mask, list):
  700. encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
  701. elif encoder_attention_mask is None:
  702. encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
  703. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  704. else:
  705. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  706. else:
  707. encoder_extended_attention_mask = None
  708. encoder_outputs: BaseModelOutput = self.encoder(
  709. embedding_output,
  710. attention_mask=extended_attention_mask,
  711. encoder_hidden_states=encoder_hidden_states,
  712. encoder_attention_mask=encoder_extended_attention_mask,
  713. query_length=query_length,
  714. **kwargs,
  715. )
  716. sequence_output = encoder_outputs.last_hidden_state
  717. pooled_output = sequence_output[:, 0, :]
  718. return BaseModelOutputWithPoolingAndCrossAttentions(
  719. last_hidden_state=sequence_output,
  720. pooler_output=pooled_output,
  721. )
  722. @dataclass
  723. @auto_docstring(
  724. custom_intro="""
  725. Class defining the outputs of [`InstructBlipVideoForConditionalGeneration`].
  726. """
  727. )
  728. class InstructBlipVideoForConditionalGenerationModelOutput(ModelOutput):
  729. r"""
  730. loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
  731. Language modeling loss from the language model.
  732. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  733. Prediction scores of the language modeling head of the language model.
  734. vision_outputs (`BaseModelOutputWithPooling`):
  735. Outputs of the vision encoder.
  736. qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
  737. Outputs of the Q-Former (Querying Transformer).
  738. language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
  739. Outputs of the language model.
  740. """
  741. loss: tuple[torch.FloatTensor] | None = None
  742. logits: tuple[torch.FloatTensor] | None = None
  743. vision_outputs: BaseModelOutputWithPooling | None = None
  744. qformer_outputs: BaseModelOutputWithPoolingAndCrossAttentions | None = None
  745. language_model_outputs: CausalLMOutputWithPast | Seq2SeqLMOutput | None = None
  746. def to_tuple(self) -> tuple[Any]:
  747. return tuple(
  748. self[k]
  749. if k not in ["vision_outputs", "qformer_outputs", "language_model_outputs"]
  750. else getattr(self, k).to_tuple()
  751. for k in self.keys()
  752. )
  753. @auto_docstring(
  754. custom_intro="""
  755. InstructBlipVideo base Model consisting of language model, qformer and vision encoder.
  756. """
  757. )
  758. class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
  759. main_input_name = "pixel_values"
  760. _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
  761. def __init__(self, config: InstructBlipVideoConfig):
  762. super().__init__(config)
  763. self.vision_model = InstructBlipVideoVisionModel(config.vision_config)
  764. self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
  765. self.qformer = InstructBlipVideoQFormerModel(config.qformer_config)
  766. self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
  767. self.language_model = AutoModel.from_config(config.text_config)
  768. # Initialize weights and apply final processing
  769. self.post_init()
  770. def get_input_embeddings(self):
  771. return self.language_model.get_input_embeddings()
  772. def set_input_embeddings(self, value):
  773. self.language_model.set_input_embeddings(value)
  774. def _preprocess_accelerate(self):
  775. r"""
  776. Some pre-processing hacks to make the model `accelerate` compatible. Check
  777. https://github.com/huggingface/transformers/pull/21707 for more details.
  778. """
  779. hf_device_map = self.hf_device_map
  780. if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
  781. # warn users about unexpected behavior when using multi-GPU + InstructBlipVideo + `accelerate`.
  782. logger.warning(
  783. "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
  784. " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
  785. " Please pass a `device_map` that contains `language_model` to remove this warning."
  786. " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
  787. " more details on creating a `device_map` for large models.",
  788. )
  789. if hasattr(self.language_model, "_hf_hook"):
  790. self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
  791. def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
  792. """
  793. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
  794. """
  795. if input_ids is None:
  796. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  797. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  798. )
  799. special_image_mask = special_image_mask.all(-1)
  800. else:
  801. special_image_mask = input_ids == self.config.image_token_id
  802. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  803. return special_image_mask
  804. @can_return_tuple
  805. @auto_docstring
  806. def forward(
  807. self,
  808. pixel_values: torch.FloatTensor,
  809. qformer_input_ids: torch.FloatTensor,
  810. qformer_attention_mask: torch.LongTensor | None = None,
  811. input_ids: torch.FloatTensor | None = None,
  812. attention_mask: torch.LongTensor | None = None,
  813. decoder_input_ids: torch.LongTensor | None = None,
  814. decoder_attention_mask: torch.LongTensor | None = None,
  815. inputs_embeds: torch.Tensor | None = None,
  816. interpolate_pos_encoding: bool = False,
  817. use_cache: bool | None = None,
  818. **kwargs: Unpack[TransformersKwargs],
  819. ) -> tuple | InstructBlipVideoForConditionalGenerationModelOutput:
  820. r"""
  821. qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  822. Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
  823. to serve as text prompt, which the Q-Former model will encode.
  824. Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
  825. details.
  826. [What are input IDs?](../glossary#input-ids)
  827. qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  828. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  829. - 1 for tokens that are **not masked**,
  830. - 0 for tokens that are **masked**.
  831. [What are attention masks?](../glossary#attention-mask)
  832. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  833. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  834. be used by default.
  835. Only relevant in case an encoder-decoder language model (like T5) is used.
  836. """
  837. # step 1: forward the images through the vision encoder,
  838. # we process in a batched way, later unbatch it back (video has frames=4 always)
  839. batch_size, frames, channel, height, width = pixel_values.shape
  840. pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
  841. vision_outputs = self.vision_model(
  842. pixel_values=pixel_values,
  843. interpolate_pos_encoding=interpolate_pos_encoding,
  844. **kwargs,
  845. )
  846. image_embeds = vision_outputs[0]
  847. # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
  848. image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
  849. # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
  850. query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
  851. query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
  852. if qformer_attention_mask is None:
  853. qformer_attention_mask = torch.ones_like(qformer_input_ids)
  854. qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
  855. qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
  856. qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
  857. query_outputs = self.qformer(
  858. input_ids=qformer_input_ids,
  859. attention_mask=qformer_attention_mask,
  860. query_embeds=query_tokens,
  861. encoder_hidden_states=image_embeds,
  862. encoder_attention_mask=image_attention_mask,
  863. **kwargs,
  864. )
  865. query_output = query_outputs[0][:, : query_tokens.size(1), :]
  866. # step 3: use the language model, conditioned on the query outputs and the prompt
  867. language_model_inputs = self.language_projection(query_output)
  868. # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
  869. language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
  870. if inputs_embeds is None:
  871. inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
  872. special_image_mask = input_ids == self.config.video_token_id
  873. if attention_mask is None:
  874. attention_mask = torch.ones_like(input_ids)
  875. else:
  876. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  877. torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
  878. )
  879. special_image_mask = special_image_mask.all(-1)
  880. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  881. language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
  882. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
  883. if self.config.use_decoder_only_language_model:
  884. outputs = self.language_model(
  885. inputs_embeds=inputs_embeds,
  886. attention_mask=attention_mask,
  887. use_cache=use_cache,
  888. **kwargs,
  889. )
  890. else:
  891. outputs = self.language_model(
  892. inputs_embeds=inputs_embeds,
  893. attention_mask=attention_mask,
  894. decoder_input_ids=decoder_input_ids,
  895. decoder_attention_mask=decoder_attention_mask,
  896. use_cache=use_cache,
  897. **kwargs,
  898. )
  899. return InstructBlipVideoForConditionalGenerationModelOutput(
  900. vision_outputs=vision_outputs,
  901. qformer_outputs=query_outputs,
  902. language_model_outputs=outputs,
  903. )
# Output of `get_video_features`. Callers in this file read `pooler_output` as the language-model inputs
# and additionally access the nested `vision_outputs` / `qformer_outputs` fields declared below.
@dataclass
@auto_docstring
class BaseModelOutputWithVisionQformerOutputs(BaseModelOutputWithPooling):
    r"""
    vision_outputs (`BaseModelOutputWithPooling`, *optional*):
        Outputs of the vision encoder.
    qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`, *optional*):
        Outputs of the Q-Former (Querying Transformer).
    """

    # Raw output of the vision tower.
    vision_outputs: BaseModelOutputWithPooling | None = None
    # Raw output of the Q-Former.
    qformer_outputs: BaseModelOutputWithPoolingAndCrossAttentions | None = None
  915. @auto_docstring(
  916. custom_intro="""
  917. InstructBlipVideo Model for generating text given an image and an optional text prompt. The model consists of a vision
  918. encoder, Querying Transformer (Q-Former) and a language model.
  919. One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
  920. the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
  921. """
  922. )
  923. class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
  924. config: InstructBlipVideoConfig
  925. main_input_name = "pixel_values"
  926. _can_compile_fullgraph = True
  927. _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
  928. def __init__(self, config: InstructBlipVideoConfig):
  929. super().__init__(config)
  930. self.vision_model = InstructBlipVideoVisionModel._from_config(config.vision_config)
  931. self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
  932. self.qformer = InstructBlipVideoQFormerModel._from_config(config.qformer_config)
  933. self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
  934. if config.use_decoder_only_language_model:
  935. language_model = AutoModelForCausalLM.from_config(config.text_config)
  936. else:
  937. language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
  938. self.language_model = language_model
  939. # Initialize weights and apply final processing
  940. self.post_init()
  941. def get_input_embeddings(self):
  942. return self.language_model.get_input_embeddings()
  943. def set_input_embeddings(self, value):
  944. self.language_model.set_input_embeddings(value)
  945. def set_output_embeddings(self, new_embeddings):
  946. self.language_model.set_output_embeddings(new_embeddings)
  947. def get_output_embeddings(self) -> nn.Module:
  948. return self.language_model.get_output_embeddings()
  949. def get_encoder(self, modality=None):
  950. if modality is None:
  951. return self.language_model.get_encoder()
  952. else:
  953. return super().get_encoder(modality=modality)
  954. def get_decoder(self):
  955. return self.language_model.get_decoder()
  956. def _preprocess_accelerate(self):
  957. r"""
  958. Some pre-processing hacks to make the model `accelerate` compatible. Check
  959. https://github.com/huggingface/transformers/pull/21707 for more details.
  960. """
  961. hf_device_map = self.hf_device_map
  962. if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
  963. # warn users about unexpected behavior when using multi-GPU + InstructBlipVideo + `accelerate`.
  964. logger.warning(
  965. "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
  966. " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
  967. " Please pass a `device_map` that contains `language_model` to remove this warning."
  968. " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
  969. " more details on creating a `device_map` for large models.",
  970. )
  971. if hasattr(self.language_model, "_hf_hook"):
  972. self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
  973. def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
  974. """
  975. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
  976. """
  977. if input_ids is None:
  978. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  979. torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
  980. )
  981. special_image_mask = special_image_mask.all(-1)
  982. else:
  983. special_image_mask = input_ids == self.config.video_token_id
  984. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  985. return special_image_mask
  986. @can_return_tuple
  987. @auto_docstring
  988. def forward(
  989. self,
  990. pixel_values: torch.FloatTensor,
  991. qformer_input_ids: torch.FloatTensor,
  992. qformer_attention_mask: torch.LongTensor | None = None,
  993. input_ids: torch.FloatTensor | None = None,
  994. attention_mask: torch.LongTensor | None = None,
  995. decoder_input_ids: torch.LongTensor | None = None,
  996. decoder_attention_mask: torch.LongTensor | None = None,
  997. inputs_embeds: torch.FloatTensor | None = None,
  998. labels: torch.LongTensor | None = None,
  999. interpolate_pos_encoding: bool = False,
  1000. use_cache: bool | None = None,
  1001. **kwargs: Unpack[TransformersKwargs],
  1002. ) -> tuple | InstructBlipVideoForConditionalGenerationModelOutput:
  1003. r"""
  1004. qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length)):
  1005. The sequence used as a prompt to be fed to the Q-Former module.
  1006. qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  1007. Mask to avoid performing attention on padding token indices.
  1008. Examples:
  1009. ```python
  1010. >>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration
  1011. >>> import torch
  1012. >>> from huggingface_hub import hf_hub_download
  1013. >>> import av
  1014. >>> import numpy as np
  1015. >>> def read_video_pyav(container, indices):
  1016. ... '''
  1017. ... Decode the video with PyAV decoder.
  1018. ... Args:
  1019. ... container (`av.container.input.InputContainer`): PyAV container.
  1020. ... indices (`list[int]`): List of frame indices to decode.
  1021. ... Returns:
  1022. ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
  1023. ... '''
  1024. ... frames = []
  1025. ... container.seek(0)
  1026. ... start_index = indices[0]
  1027. ... end_index = indices[-1]
  1028. ... for i, frame in enumerate(container.decode(video=0)):
  1029. ... if i > end_index:
  1030. ... break
  1031. ... if i >= start_index and i in indices:
  1032. ... frames.append(frame)
  1033. ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
  1034. >>> model = InstructBlipVideoForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", device_map="auto")
  1035. >>> processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
  1036. >>> file_path = hf_hub_download(
  1037. ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
  1038. ... )
  1039. >>> container = av.open(file_path)
  1040. >>> # sample uniformly 4 frames from the videWhy is this video funny?o
  1041. >>> total_frames = container.streams.video[0].frames
  1042. >>> indices = np.arange(0, total_frames, total_frames / 4).astype(int)
  1043. >>> clip = read_video_pyav(container, indices)
  1044. >>> prompt = "What is happening in the video?"
  1045. >>> inputs = processor(text=prompt, images=clip, return_tensors="pt").to(model.device)
  1046. >>> outputs = model.generate(
  1047. ... **inputs,
  1048. ... do_sample=False,
  1049. ... num_beams=5,
  1050. ... max_length=256,
  1051. ... repetition_penalty=1.5,
  1052. ... length_penalty=1.0,
  1053. ... )
  1054. >>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
  1055. >>> print(generated_text)
  1056. "A person is eating a bowl of pasta, and they are using a fork to eat it. The person is sitting at a table, and the plate of pasta is on the table in front"
  1057. ```"""
  1058. video_features: BaseModelOutputWithVisionQformerOutputs = self.get_video_features(
  1059. pixel_values,
  1060. qformer_input_ids=qformer_input_ids,
  1061. qformer_attention_mask=qformer_attention_mask,
  1062. interpolate_pos_encoding=interpolate_pos_encoding,
  1063. **kwargs,
  1064. )
  1065. language_model_inputs = video_features.pooler_output
  1066. qformer_outputs = video_features.qformer_outputs
  1067. vision_outputs = video_features.vision_outputs
  1068. if inputs_embeds is None:
  1069. inputs_embeds = self.get_input_embeddings()(input_ids)
  1070. if attention_mask is None:
  1071. attention_mask = torch.ones_like(input_ids)
  1072. language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
  1073. special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
  1074. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
  1075. if self.config.use_decoder_only_language_model:
  1076. outputs = self.language_model(
  1077. inputs_embeds=inputs_embeds,
  1078. attention_mask=attention_mask,
  1079. use_cache=use_cache,
  1080. **kwargs,
  1081. )
  1082. logits = outputs[0]
  1083. loss = None
  1084. if labels is not None:
  1085. loss = self.loss_function(
  1086. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  1087. )
  1088. else:
  1089. outputs = self.language_model(
  1090. inputs_embeds=inputs_embeds,
  1091. attention_mask=attention_mask,
  1092. decoder_input_ids=decoder_input_ids,
  1093. decoder_attention_mask=decoder_attention_mask,
  1094. labels=labels,
  1095. use_cache=use_cache,
  1096. **kwargs,
  1097. )
  1098. loss = outputs.loss
  1099. logits = outputs.logits
  1100. return InstructBlipVideoForConditionalGenerationModelOutput(
  1101. loss=loss,
  1102. logits=logits,
  1103. vision_outputs=vision_outputs,
  1104. qformer_outputs=qformer_outputs,
  1105. language_model_outputs=outputs,
  1106. )
  1107. @torch.no_grad()
  1108. def generate(
  1109. self,
  1110. pixel_values: torch.FloatTensor,
  1111. qformer_input_ids: torch.LongTensor | None = None,
  1112. qformer_attention_mask: torch.LongTensor | None = None,
  1113. input_ids: torch.LongTensor | None = None,
  1114. attention_mask: torch.LongTensor | None = None,
  1115. inputs_embeds: torch.FloatTensor | None = None,
  1116. interpolate_pos_encoding: bool = False,
  1117. **generate_kwargs,
  1118. ) -> torch.LongTensor:
  1119. r"""
  1120. Overrides `generate` function to be able to use the model as a conditional generator.
  1121. Args:
  1122. pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width) or
  1123. (batch_size, num_frames, num_channels, height, width)): Input images or videos to be processed.
  1124. qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  1125. The sequence used as a prompt to be fed to the Q-Former module.
  1126. qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  1127. Mask to avoid performing attention on padding token indices.
  1128. input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  1129. The sequence used as a prompt for the generation.
  1130. attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  1131. Mask to avoid performing attention on padding token indices.
  1132. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  1133. Embedded representation of the inputs. Should be float, not int tokens.
  1134. interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
  1135. Whether to interpolate the positional encoding of the image embeddings.
  1136. Returns:
  1137. captions (list): A list of strings of length batch_size * num_captions.
  1138. """
  1139. if hasattr(self, "hf_device_map"):
  1140. # preprocess for `accelerate`
  1141. self._preprocess_accelerate()
  1142. batch_size = pixel_values.shape[0]
  1143. video_features: BaseModelOutputWithVisionQformerOutputs = self.get_video_features(
  1144. pixel_values,
  1145. qformer_input_ids=qformer_input_ids,
  1146. qformer_attention_mask=qformer_attention_mask,
  1147. interpolate_pos_encoding=interpolate_pos_encoding,
  1148. )
  1149. language_model_inputs = video_features.pooler_output
  1150. if inputs_embeds is None:
  1151. if input_ids is None:
  1152. video_tokens = [self.config.video_token_index] * self.config.num_query_tokens * 4
  1153. start_tokens = video_tokens + [self.config.text_config.bos_token_id]
  1154. input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
  1155. input_ids = input_ids.repeat(batch_size, 1)
  1156. inputs_embeds = self.get_input_embeddings()(input_ids)
  1157. if attention_mask is None:
  1158. attention_mask = torch.ones_like(input_ids)
  1159. language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
  1160. special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
  1161. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
  1162. inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
  1163. if not self.language_model.config.is_encoder_decoder:
  1164. inputs["input_ids"] = input_ids
  1165. outputs = self.language_model.generate(**inputs, **generate_kwargs)
  1166. return outputs
  @can_return_tuple
  @auto_docstring
  def get_video_features(
      self,
      pixel_values: torch.FloatTensor,
      qformer_input_ids: torch.LongTensor,
      qformer_attention_mask: torch.LongTensor | None = None,
      interpolate_pos_encoding: bool | None = False,
      **kwargs: Unpack[TransformersKwargs],
  ) -> tuple | BaseModelOutputWithVisionQformerOutputs:
      r"""
      pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
          The tensors corresponding to the input video frames. Every frame is encoded independently.
      qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length)):
          The sequence used as a prompt to be fed to the Q-Former module. The same prompt is reused for every
          frame of a video.
      qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
          Mask to avoid performing attention on padding token indices. Defaults to all ones.
      """
      # Encodes every frame through the vision encoder and the instruction-aware Q-Former, projects the query
      # outputs with `self.language_projection`, and regroups the result per video.
      # Returns a `BaseModelOutputWithVisionQformerOutputs` whose
      #   - `last_hidden_state` / `hidden_states` / `attentions` come from the vision encoder, with frames
      #     flattened into the batch dimension,
      #   - `vision_outputs` / `qformer_outputs` hold the raw sub-module outputs,
      #   - `pooler_output` holds the projected video features of shape
      #     (batch_size, num_frames * num_query_tokens, <output size of self.language_projection>).
      # Raises ValueError if `pixel_values` is not 5-dimensional.
      if pixel_values.dim() != 5:
          raise ValueError(
              "`pixel_values` must have shape (batch_size, num_frames, num_channels, height, width), "
              f"but got a tensor of shape {tuple(pixel_values.shape)}."
          )

      # Step 1: run all frames through the vision encoder in one batched call by folding the frame axis into
      # the batch axis; the features are unfolded back per video at the end. The frame count is taken from
      # the input, not assumed.
      batch_size, frames, channel, height, width = pixel_values.shape
      pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
      vision_model_outputs: BaseModelOutputWithPooling = self.vision_model(
          pixel_values=pixel_values,
          interpolate_pos_encoding=interpolate_pos_encoding,
          **kwargs,
      )
      vision_outputs = BaseModelOutputWithVisionQformerOutputs(
          last_hidden_state=vision_model_outputs.last_hidden_state,
          pooler_output=vision_model_outputs.pooler_output,
          hidden_states=vision_model_outputs.hidden_states,
          attentions=vision_model_outputs.attentions,
          vision_outputs=vision_model_outputs,
          qformer_outputs=None,  # filled in once the Q-Former has run below
      )
      image_embeds = vision_model_outputs.last_hidden_state

      # Step 2: forward the learned query tokens through the Q-Former, cross-attending to the frame embeddings.
      # Every image-embedding position is attended to.
      image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
      query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
      query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)

      # Difference with BLIP-2: the instruction prompt is also fed to the Q-Former. The prompt is given once
      # per video, so each row is repeated `frames` times consecutively, matching the (video, frame) order
      # produced by the reshape in step 1.
      if qformer_attention_mask is None:
          qformer_attention_mask = torch.ones_like(qformer_input_ids)
      qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
      qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
      # The query positions come first in the Q-Former sequence, followed by the prompt tokens.
      qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
      qformer_outputs = self.qformer(
          input_ids=qformer_input_ids,
          attention_mask=qformer_attention_mask,
          query_embeds=query_tokens,
          encoder_hidden_states=image_embeds,
          encoder_attention_mask=image_attention_mask,
          **kwargs,
      )
      vision_outputs.qformer_outputs = qformer_outputs
      # Keep only the query-token positions; the prompt positions are discarded.
      query_output = qformer_outputs[0][:, : query_tokens.size(1), :]

      # Step 3: project the query outputs into the language model's embedding space. The language model
      # itself is not run here; callers scatter these features into its input embeddings.
      video_features = self.language_projection(query_output)
      # Unfold the frames back out of the batch axis: each frame contributes `num_query_tokens` positions,
      # laid out frame by frame within each video.
      video_features = video_features.reshape(batch_size, self.config.num_query_tokens * frames, -1)
      vision_outputs.pooler_output = video_features
      return vision_outputs
  # Public API of this module: the only names exported by `from <module> import *`.
  __all__ = [
      "InstructBlipVideoVisionModel",
      "InstructBlipVideoPreTrainedModel",
      "InstructBlipVideoQFormerModel",
      "InstructBlipVideoModel",
      "InstructBlipVideoForConditionalGeneration",
  ]