# modular_instructblipvideo.py (24 KB)
  1. # Copyright 2024 HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. from huggingface_hub.dataclasses import strict
  16. from transformers.models.instructblip.configuration_instructblip import (
  17. InstructBlipConfig,
  18. InstructBlipQFormerConfig,
  19. InstructBlipVisionConfig,
  20. )
  21. from transformers.models.instructblip.modeling_instructblip import (
  22. BaseModelOutputWithVisionQformerOutputs,
  23. InstructBlipForConditionalGeneration,
  24. InstructBlipForConditionalGenerationModelOutput,
  25. InstructBlipModel,
  26. InstructBlipPreTrainedModel,
  27. InstructBlipQFormerModel,
  28. InstructBlipVisionModel,
  29. TransformersKwargs,
  30. )
  31. from ...modeling_outputs import BaseModelOutputWithPooling
  32. from ...processing_utils import Unpack
  33. from ...utils import auto_docstring, can_return_tuple
# Vision-encoder configuration for InstructBLIP-Video. It declares no fields of its own and reuses every
# option of `InstructBlipVisionConfig`; only the docstring (consumed by `auto_docstring`) is specific here.
@auto_docstring(checkpoint="Salesforce/instructblip-flan-t5-xl")
@strict
class InstructBlipVideoVisionConfig(InstructBlipVisionConfig):
    r"""
    Example:

    ```python
    >>> from transformers import InstructBlipVideoVisionConfig, InstructBlipVideoVisionModel

    >>> # Initializing an InstructBlipVideoVisionConfig with Salesforce/instructblip-flan-t5-xl style configuration
    >>> configuration = InstructBlipVideoVisionConfig()

    >>> # Initializing an InstructBlipVideoVisionModel (with random weights) from the Salesforce/instructblip-flan-t5-xl style configuration
    >>> model = InstructBlipVideoVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
# Q-Former configuration for InstructBLIP-Video. It declares no fields of its own; the docstring below
# restates two of the options that come from `InstructBlipQFormerConfig`.
@auto_docstring(checkpoint="Salesforce/instructblip-flan-t5-xl")
@strict
class InstructBlipVideoQFormerConfig(InstructBlipQFormerConfig):
    r"""
    cross_attention_frequency (`int`, *optional*, defaults to 2):
        The frequency of adding cross-attention to the Transformer layers.
    encoder_hidden_size (`int`, *optional*, defaults to 1408):
        The hidden size of the hidden states for cross-attention.

    Examples:

    ```python
    >>> from transformers import InstructBlipVideoQFormerConfig, InstructBlipVideoQFormerModel

    >>> # Initializing an InstructBlipVideo Salesforce/instructblip-flan-t5-xl style configuration
    >>> configuration = InstructBlipVideoQFormerConfig()

    >>> # Initializing a model (with random weights) from the Salesforce/instructblip-flan-t5-xl style configuration
    >>> model = InstructBlipVideoQFormerModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
# Top-level InstructBLIP-Video configuration: vision encoder + Q-Former + language model. Compared with
# `InstructBlipConfig` it swaps the image placeholder token for a video placeholder token.
@auto_docstring(checkpoint="Salesforce/instructblip-flan-t5-xl")
@strict
class InstructBlipVideoConfig(InstructBlipConfig):
    r"""
    qformer_config (`dict`, *optional*):
        Dictionary of configuration options used to initialize [`InstructBlipVideoQFormerConfig`].
    num_query_tokens (`int`, *optional*, defaults to 32):
        The number of query tokens passed through the Transformer.
    video_token_index (`int`, *optional*):
        Id of the placeholder token in the language-model prompt whose embeddings are replaced by the
        projected video features.

    Example:

    ```python
    >>> from transformers import (
    ...     InstructBlipVideoVisionConfig,
    ...     InstructBlipVideoQFormerConfig,
    ...     OPTConfig,
    ...     InstructBlipVideoConfig,
    ...     InstructBlipVideoForConditionalGeneration,
    ... )

    >>> # Initializing an InstructBlipVideoConfig with Salesforce/instructblip-flan-t5-xl style configuration
    >>> configuration = InstructBlipVideoConfig()

    >>> # Initializing an InstructBlipVideoForConditionalGeneration (with random weights) from the Salesforce/instructblip-flan-t5-xl style configuration
    >>> model = InstructBlipVideoForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config

    >>> # We can also initialize an InstructBlipVideoConfig from an InstructBlipVideoVisionConfig, InstructBlipVideoQFormerConfig and any PreTrainedConfig

    >>> # Initializing InstructBlipVideo vision, InstructBlipVideo Q-Former and language model configurations
    >>> vision_config = InstructBlipVideoVisionConfig()
    >>> qformer_config = InstructBlipVideoQFormerConfig()
    >>> text_config = OPTConfig()

    >>> config = InstructBlipVideoConfig(vision_config=vision_config, qformer_config=qformer_config, text_config=text_config)
    ```"""

    # The model code reads both `config.video_token_id` and `config.video_token_index`; this map presumably
    # makes the former an alias of the latter — confirm against `PreTrainedConfig.attribute_map` semantics.
    attribute_map = {"video_token_id": "video_token_index"}
    # Placeholder token id marking where video features are scattered into the language-model inputs.
    video_token_index: int | None = None
    # NOTE(review): presumably the modular-converter sentinel for dropping the parent's `image_token_index`
    # (this architecture only uses `video_token_index`) — confirm against the modular converter docs.
    image_token_index = AttributeError()
# Shared base class for the InstructBLIP-Video models. Its only difference from `InstructBlipPreTrainedModel`
# is the declared input modalities (video and text).
class InstructBlipVideoPreTrainedModel(InstructBlipPreTrainedModel):
    input_modalities = ("video", "text")
# Vision encoder applied to individual video frames; identical to `InstructBlipVisionModel` apart from its
# declared input modality.
class InstructBlipVideoVisionModel(InstructBlipVisionModel):
    # NOTE(review): declared as a bare string while InstructBlipVideoPreTrainedModel uses a tuple —
    # confirm that whatever consumes `input_modalities` accepts both forms.
    input_modalities = "video"
# Q-Former reused unchanged from InstructBLIP, exposed under the InstructBlipVideo name.
class InstructBlipVideoQFormerModel(InstructBlipQFormerModel):
    pass
# Output container reused unchanged from InstructBLIP (loss, logits, vision / Q-Former / language-model
# outputs), exposed under the InstructBlipVideo name.
class InstructBlipVideoForConditionalGenerationModelOutput(InstructBlipForConditionalGenerationModelOutput):
    pass
  107. class InstructBlipVideoModel(InstructBlipModel):
  108. @can_return_tuple
  109. @auto_docstring
  110. def forward(
  111. self,
  112. pixel_values: torch.FloatTensor,
  113. qformer_input_ids: torch.FloatTensor,
  114. qformer_attention_mask: torch.LongTensor | None = None,
  115. input_ids: torch.FloatTensor | None = None,
  116. attention_mask: torch.LongTensor | None = None,
  117. decoder_input_ids: torch.LongTensor | None = None,
  118. decoder_attention_mask: torch.LongTensor | None = None,
  119. inputs_embeds: torch.Tensor | None = None,
  120. interpolate_pos_encoding: bool = False,
  121. use_cache: bool | None = None,
  122. **kwargs: Unpack[TransformersKwargs],
  123. ) -> tuple | InstructBlipVideoForConditionalGenerationModelOutput:
  124. # step 1: forward the images through the vision encoder,
  125. # we process in a batched way, later unbatch it back (video has frames=4 always)
  126. batch_size, frames, channel, height, width = pixel_values.shape
  127. pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
  128. vision_outputs = self.vision_model(
  129. pixel_values=pixel_values,
  130. interpolate_pos_encoding=interpolate_pos_encoding,
  131. **kwargs,
  132. )
  133. image_embeds = vision_outputs[0]
  134. # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
  135. image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
  136. # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
  137. query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
  138. query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
  139. if qformer_attention_mask is None:
  140. qformer_attention_mask = torch.ones_like(qformer_input_ids)
  141. qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
  142. qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
  143. qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
  144. query_outputs = self.qformer(
  145. input_ids=qformer_input_ids,
  146. attention_mask=qformer_attention_mask,
  147. query_embeds=query_tokens,
  148. encoder_hidden_states=image_embeds,
  149. encoder_attention_mask=image_attention_mask,
  150. **kwargs,
  151. )
  152. query_output = query_outputs[0][:, : query_tokens.size(1), :]
  153. # step 3: use the language model, conditioned on the query outputs and the prompt
  154. language_model_inputs = self.language_projection(query_output)
  155. # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
  156. language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
  157. if inputs_embeds is None:
  158. inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
  159. special_image_mask = input_ids == self.config.video_token_id
  160. if attention_mask is None:
  161. attention_mask = torch.ones_like(input_ids)
  162. else:
  163. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  164. torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
  165. )
  166. special_image_mask = special_image_mask.all(-1)
  167. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  168. language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
  169. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
  170. if self.config.use_decoder_only_language_model:
  171. outputs = self.language_model(
  172. inputs_embeds=inputs_embeds,
  173. attention_mask=attention_mask,
  174. use_cache=use_cache,
  175. **kwargs,
  176. )
  177. else:
  178. outputs = self.language_model(
  179. inputs_embeds=inputs_embeds,
  180. attention_mask=attention_mask,
  181. decoder_input_ids=decoder_input_ids,
  182. decoder_attention_mask=decoder_attention_mask,
  183. use_cache=use_cache,
  184. **kwargs,
  185. )
  186. return InstructBlipVideoForConditionalGenerationModelOutput(
  187. vision_outputs=vision_outputs,
  188. qformer_outputs=query_outputs,
  189. language_model_outputs=outputs,
  190. )
  191. class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
  192. @can_return_tuple
  193. @auto_docstring
  194. def get_video_features(
  195. self,
  196. pixel_values: torch.FloatTensor,
  197. qformer_input_ids: torch.LongTensor,
  198. qformer_attention_mask: torch.LongTensor | None = None,
  199. interpolate_pos_encoding: bool | None = False,
  200. **kwargs: Unpack[TransformersKwargs],
  201. ) -> tuple | BaseModelOutputWithVisionQformerOutputs:
  202. r"""
  203. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  204. The tensors corresponding to the input images.
  205. qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length)):
  206. The sequence used as a prompt to be fed to the Q-Former module.
  207. qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  208. Mask to avoid performing attention on padding token indices.
  209. """
  210. # step 1: forward the images through the vision encoder,
  211. # we process in a batched way, later unbatch it back (video has frames=4 always)
  212. batch_size, frames, channel, height, width = pixel_values.shape
  213. pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
  214. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  215. pixel_values=pixel_values,
  216. interpolate_pos_encoding=interpolate_pos_encoding,
  217. **kwargs,
  218. )
  219. vision_outputs = BaseModelOutputWithVisionQformerOutputs(
  220. last_hidden_state=vision_outputs.last_hidden_state,
  221. pooler_output=vision_outputs.pooler_output,
  222. hidden_states=vision_outputs.hidden_states,
  223. attentions=vision_outputs.attentions,
  224. vision_outputs=vision_outputs,
  225. qformer_outputs=None,
  226. )
  227. image_embeds = vision_outputs[0]
  228. # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
  229. image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
  230. # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
  231. query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
  232. query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
  233. if qformer_attention_mask is None:
  234. qformer_attention_mask = torch.ones_like(qformer_input_ids)
  235. qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
  236. qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
  237. qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
  238. qformer_outputs = self.qformer(
  239. input_ids=qformer_input_ids,
  240. attention_mask=qformer_attention_mask,
  241. query_embeds=query_tokens,
  242. encoder_hidden_states=image_embeds,
  243. encoder_attention_mask=image_attention_mask,
  244. **kwargs,
  245. )
  246. vision_outputs.qformer_outputs = qformer_outputs
  247. query_output = qformer_outputs[0][:, : query_tokens.size(1), :]
  248. # step 3: use the language model, conditioned on the query outputs and the prompt
  249. video_features = self.language_projection(query_output)
  250. # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
  251. video_features = video_features.reshape(batch_size, self.config.num_query_tokens * frames, -1)
  252. vision_outputs.pooler_output = video_features
  253. return vision_outputs
  254. def get_image_features(**super_kwargs):
  255. raise AttributeError("No need to inherit as this architecture only supports videos.")
  256. def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
  257. """
  258. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
  259. """
  260. if input_ids is None:
  261. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  262. torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
  263. )
  264. special_image_mask = special_image_mask.all(-1)
  265. else:
  266. special_image_mask = input_ids == self.config.video_token_id
  267. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  268. return special_image_mask
  269. @can_return_tuple
  270. @auto_docstring
  271. def forward(
  272. self,
  273. pixel_values: torch.FloatTensor,
  274. qformer_input_ids: torch.FloatTensor,
  275. qformer_attention_mask: torch.LongTensor | None = None,
  276. input_ids: torch.FloatTensor | None = None,
  277. attention_mask: torch.LongTensor | None = None,
  278. decoder_input_ids: torch.LongTensor | None = None,
  279. decoder_attention_mask: torch.LongTensor | None = None,
  280. inputs_embeds: torch.FloatTensor | None = None,
  281. labels: torch.LongTensor | None = None,
  282. interpolate_pos_encoding: bool = False,
  283. use_cache: bool | None = None,
  284. **kwargs: Unpack[TransformersKwargs],
  285. ) -> tuple | InstructBlipVideoForConditionalGenerationModelOutput:
  286. r"""
  287. qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length)):
  288. The sequence used as a prompt to be fed to the Q-Former module.
  289. qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  290. Mask to avoid performing attention on padding token indices.
  291. Examples:
  292. ```python
  293. >>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration
  294. >>> import torch
  295. >>> from huggingface_hub import hf_hub_download
  296. >>> import av
  297. >>> import numpy as np
  298. >>> def read_video_pyav(container, indices):
  299. ... '''
  300. ... Decode the video with PyAV decoder.
  301. ... Args:
  302. ... container (`av.container.input.InputContainer`): PyAV container.
  303. ... indices (`list[int]`): List of frame indices to decode.
  304. ... Returns:
  305. ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
  306. ... '''
  307. ... frames = []
  308. ... container.seek(0)
  309. ... start_index = indices[0]
  310. ... end_index = indices[-1]
  311. ... for i, frame in enumerate(container.decode(video=0)):
  312. ... if i > end_index:
  313. ... break
  314. ... if i >= start_index and i in indices:
  315. ... frames.append(frame)
  316. ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
  317. >>> model = InstructBlipVideoForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", device_map="auto")
  318. >>> processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
  319. >>> file_path = hf_hub_download(
  320. ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
  321. ... )
  322. >>> container = av.open(file_path)
  323. >>> # sample uniformly 4 frames from the videWhy is this video funny?o
  324. >>> total_frames = container.streams.video[0].frames
  325. >>> indices = np.arange(0, total_frames, total_frames / 4).astype(int)
  326. >>> clip = read_video_pyav(container, indices)
  327. >>> prompt = "What is happening in the video?"
  328. >>> inputs = processor(text=prompt, images=clip, return_tensors="pt").to(model.device)
  329. >>> outputs = model.generate(
  330. ... **inputs,
  331. ... do_sample=False,
  332. ... num_beams=5,
  333. ... max_length=256,
  334. ... repetition_penalty=1.5,
  335. ... length_penalty=1.0,
  336. ... )
  337. >>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
  338. >>> print(generated_text)
  339. "A person is eating a bowl of pasta, and they are using a fork to eat it. The person is sitting at a table, and the plate of pasta is on the table in front"
  340. ```"""
  341. video_features: BaseModelOutputWithVisionQformerOutputs = self.get_video_features(
  342. pixel_values,
  343. qformer_input_ids=qformer_input_ids,
  344. qformer_attention_mask=qformer_attention_mask,
  345. interpolate_pos_encoding=interpolate_pos_encoding,
  346. **kwargs,
  347. )
  348. language_model_inputs = video_features.pooler_output
  349. qformer_outputs = video_features.qformer_outputs
  350. vision_outputs = video_features.vision_outputs
  351. if inputs_embeds is None:
  352. inputs_embeds = self.get_input_embeddings()(input_ids)
  353. if attention_mask is None:
  354. attention_mask = torch.ones_like(input_ids)
  355. language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
  356. special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
  357. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
  358. if self.config.use_decoder_only_language_model:
  359. outputs = self.language_model(
  360. inputs_embeds=inputs_embeds,
  361. attention_mask=attention_mask,
  362. use_cache=use_cache,
  363. **kwargs,
  364. )
  365. logits = outputs[0]
  366. loss = None
  367. if labels is not None:
  368. loss = self.loss_function(
  369. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  370. )
  371. else:
  372. outputs = self.language_model(
  373. inputs_embeds=inputs_embeds,
  374. attention_mask=attention_mask,
  375. decoder_input_ids=decoder_input_ids,
  376. decoder_attention_mask=decoder_attention_mask,
  377. labels=labels,
  378. use_cache=use_cache,
  379. **kwargs,
  380. )
  381. loss = outputs.loss
  382. logits = outputs.logits
  383. return InstructBlipVideoForConditionalGenerationModelOutput(
  384. loss=loss,
  385. logits=logits,
  386. vision_outputs=vision_outputs,
  387. qformer_outputs=qformer_outputs,
  388. language_model_outputs=outputs,
  389. )
  390. @torch.no_grad()
  391. def generate(
  392. self,
  393. pixel_values: torch.FloatTensor,
  394. qformer_input_ids: torch.LongTensor | None = None,
  395. qformer_attention_mask: torch.LongTensor | None = None,
  396. input_ids: torch.LongTensor | None = None,
  397. attention_mask: torch.LongTensor | None = None,
  398. inputs_embeds: torch.FloatTensor | None = None,
  399. interpolate_pos_encoding: bool = False,
  400. **generate_kwargs,
  401. ) -> torch.LongTensor:
  402. r"""
  403. Overrides `generate` function to be able to use the model as a conditional generator.
  404. Args:
  405. pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width) or
  406. (batch_size, num_frames, num_channels, height, width)): Input images or videos to be processed.
  407. qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  408. The sequence used as a prompt to be fed to the Q-Former module.
  409. qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  410. Mask to avoid performing attention on padding token indices.
  411. input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  412. The sequence used as a prompt for the generation.
  413. attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  414. Mask to avoid performing attention on padding token indices.
  415. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  416. Embedded representation of the inputs. Should be float, not int tokens.
  417. interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
  418. Whether to interpolate the positional encoding of the image embeddings.
  419. Returns:
  420. captions (list): A list of strings of length batch_size * num_captions.
  421. """
  422. if hasattr(self, "hf_device_map"):
  423. # preprocess for `accelerate`
  424. self._preprocess_accelerate()
  425. batch_size = pixel_values.shape[0]
  426. video_features: BaseModelOutputWithVisionQformerOutputs = self.get_video_features(
  427. pixel_values,
  428. qformer_input_ids=qformer_input_ids,
  429. qformer_attention_mask=qformer_attention_mask,
  430. interpolate_pos_encoding=interpolate_pos_encoding,
  431. )
  432. language_model_inputs = video_features.pooler_output
  433. if inputs_embeds is None:
  434. if input_ids is None:
  435. video_tokens = [self.config.video_token_index] * self.config.num_query_tokens * 4
  436. start_tokens = video_tokens + [self.config.text_config.bos_token_id]
  437. input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
  438. input_ids = input_ids.repeat(batch_size, 1)
  439. inputs_embeds = self.get_input_embeddings()(input_ids)
  440. if attention_mask is None:
  441. attention_mask = torch.ones_like(input_ids)
  442. language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
  443. special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
  444. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
  445. inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
  446. if not self.language_model.config.is_encoder_decoder:
  447. inputs["input_ids"] = input_ids
  448. outputs = self.language_model.generate(**inputs, **generate_kwargs)
  449. return outputs
# Public names exported by this module. The output dataclass
# `InstructBlipVideoForConditionalGenerationModelOutput` is intentionally not listed here.
__all__ = [
    "InstructBlipVideoConfig",
    "InstructBlipVideoQFormerConfig",
    "InstructBlipVideoVisionConfig",
    "InstructBlipVideoVisionModel",
    "InstructBlipVideoPreTrainedModel",
    "InstructBlipVideoQFormerModel",
    "InstructBlipVideoModel",
    "InstructBlipVideoForConditionalGeneration",
]