| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546 |
- # Copyright 2024 HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import torch
- from huggingface_hub.dataclasses import strict
- from transformers.models.instructblip.configuration_instructblip import (
- InstructBlipConfig,
- InstructBlipQFormerConfig,
- InstructBlipVisionConfig,
- )
- from transformers.models.instructblip.modeling_instructblip import (
- BaseModelOutputWithVisionQformerOutputs,
- InstructBlipForConditionalGeneration,
- InstructBlipForConditionalGenerationModelOutput,
- InstructBlipModel,
- InstructBlipPreTrainedModel,
- InstructBlipQFormerModel,
- InstructBlipVisionModel,
- TransformersKwargs,
- )
- from ...modeling_outputs import BaseModelOutputWithPooling
- from ...processing_utils import Unpack
- from ...utils import auto_docstring, can_return_tuple
@auto_docstring(checkpoint="Salesforce/instructblip-flan-t5-xl")
@strict
class InstructBlipVideoVisionConfig(InstructBlipVisionConfig):
    r"""
    Example:

    ```python
    >>> from transformers import InstructBlipVideoVisionConfig, InstructBlipVideoVisionModel

    >>> # Build a vision configuration in the style of Salesforce/instructblip-flan-t5-xl
    >>> configuration = InstructBlipVideoVisionConfig()

    >>> # Create a vision model with randomly initialized weights from that configuration
    >>> model = InstructBlipVideoVisionModel(configuration)

    >>> # Read the configuration back from the model
    >>> configuration = model.config
    ```"""
@auto_docstring(checkpoint="Salesforce/instructblip-flan-t5-xl")
@strict
class InstructBlipVideoQFormerConfig(InstructBlipQFormerConfig):
    r"""
    cross_attention_frequency (`int`, *optional*, defaults to 2):
        The frequency with which cross-attention is added to the Transformer layers.
    encoder_hidden_size (`int`, *optional*, defaults to 1408):
        The hidden size of the hidden states attended to by cross-attention.

    Examples:

    ```python
    >>> from transformers import InstructBlipVideoQFormerConfig, InstructBlipVideoQFormerModel

    >>> # Build a Q-Former configuration in the style of Salesforce/instructblip-flan-t5-xl
    >>> configuration = InstructBlipVideoQFormerConfig()

    >>> # Create a Q-Former with randomly initialized weights from that configuration
    >>> model = InstructBlipVideoQFormerModel(configuration)

    >>> # Read the configuration back from the model
    >>> configuration = model.config
    ```"""
@auto_docstring(checkpoint="Salesforce/instructblip-flan-t5-xl")
@strict
class InstructBlipVideoConfig(InstructBlipConfig):
    r"""
    qformer_config (`dict`, *optional*):
        Dictionary of configuration options used to initialize [`InstructBlipVideoQFormerConfig`].
    num_query_tokens (`int`, *optional*, defaults to 32):
        The number of query tokens passed through the Transformer.

    Example:

    ```python
    >>> from transformers import (
    ...     InstructBlipVideoVisionConfig,
    ...     InstructBlipVideoQFormerConfig,
    ...     OPTConfig,
    ...     InstructBlipVideoConfig,
    ...     InstructBlipVideoForConditionalGeneration,
    ... )

    >>> # Build a composite configuration in the style of Salesforce/instructblip-flan-t5-xl
    >>> configuration = InstructBlipVideoConfig()

    >>> # Create a model with randomly initialized weights from that configuration
    >>> model = InstructBlipVideoForConditionalGeneration(configuration)

    >>> # Read the configuration back from the model
    >>> configuration = model.config

    >>> # Alternatively, assemble an InstructBlipVideoConfig from its three sub-configurations:
    >>> # an InstructBlipVideo vision config, an InstructBlipVideo Q-Former config and any language model config
    >>> vision_config = InstructBlipVideoVisionConfig()
    >>> qformer_config = InstructBlipVideoQFormerConfig()
    >>> text_config = OPTConfig()

    >>> config = InstructBlipVideoConfig(vision_config=vision_config, qformer_config=qformer_config, text_config=text_config)
    ```"""

    # Lets code read `config.video_token_id`; the value is stored under `video_token_index`.
    attribute_map = {"video_token_id": "video_token_index"}

    # Id of the placeholder token whose embedding positions receive the projected video features.
    video_token_index: int | None = None
    # NOTE(review): presumably removes the inherited `image_token_index` for this video-only
    # architecture (modular-transformers removal convention) — confirm against the converter.
    image_token_index = AttributeError()
class InstructBlipVideoPreTrainedModel(InstructBlipPreTrainedModel):
    """InstructBlip pretrained-model class specialised to accept video and text inputs."""

    input_modalities = ("video", "text")
class InstructBlipVideoVisionModel(InstructBlipVisionModel):
    """InstructBlip vision encoder, registered with video as its input modality."""

    input_modalities = "video"
class InstructBlipVideoQFormerModel(InstructBlipQFormerModel):
    """Q-Former for InstructBlipVideo; reuses the InstructBlip implementation unchanged."""
class InstructBlipVideoForConditionalGenerationModelOutput(InstructBlipForConditionalGenerationModelOutput):
    """Output container for InstructBlipVideo; same fields as the InstructBlip output."""
class InstructBlipVideoModel(InstructBlipModel):
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        qformer_input_ids: torch.LongTensor,
        qformer_attention_mask: torch.LongTensor | None = None,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.LongTensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.LongTensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        interpolate_pos_encoding: bool = False,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | InstructBlipVideoForConditionalGenerationModelOutput:
        r"""
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            The video frames to encode. All frames of all videos are encoded in one flattened batch.
        qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length)):
            The sequence used as a prompt to be fed to the Q-Former module.
        qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
            Mask to avoid performing attention on padding token indices.
        """
        # step 1: forward the frames through the vision encoder.
        # Frames are flattened into the batch dimension here and un-flattened after the
        # language projection (the frame count is read from `pixel_values`).
        batch_size, frames, channel, height, width = pixel_values.shape
        pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            **kwargs,
        )
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the Q-Former, cross-attending to the frame embeddings
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
        if qformer_attention_mask is None:
            qformer_attention_mask = torch.ones_like(qformer_input_ids)

        # The prompt is given once per video; repeat it per frame to line up with the flattened frame batch.
        qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
        qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
        # Query tokens precede the prompt tokens in the Q-Former sequence.
        qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
        query_outputs = self.qformer(
            input_ids=qformer_input_ids,
            attention_mask=qformer_attention_mask,
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            **kwargs,
        )
        # Keep only the query-token positions; the prompt positions are dropped.
        query_output = query_outputs[0][:, : query_tokens.size(1), :]

        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)

        # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
        language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)

        if inputs_embeds is None:
            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
            special_image_mask = input_ids == self.config.video_token_id
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids)
        else:
            # Without token ids, locate placeholders by comparing against the video token's embedding.
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)

        # Broadcast the per-position mask over the hidden dimension so `masked_scatter`
        # overwrites whole placeholder embeddings with the projected video features.
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

        if self.config.use_decoder_only_language_model:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                use_cache=use_cache,
                **kwargs,
            )
        else:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                use_cache=use_cache,
                **kwargs,
            )

        return InstructBlipVideoForConditionalGenerationModelOutput(
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )
class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
    @can_return_tuple
    @auto_docstring
    def get_video_features(
        self,
        pixel_values: torch.FloatTensor,
        qformer_input_ids: torch.LongTensor,
        qformer_attention_mask: torch.LongTensor | None = None,
        interpolate_pos_encoding: bool | None = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithVisionQformerOutputs:
        r"""
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            The tensors corresponding to the input video frames.
        qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length)):
            The sequence used as a prompt to be fed to the Q-Former module.
        qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
            Mask to avoid performing attention on padding token indices.
        """
        # step 1: forward the frames through the vision encoder.
        # Frames are flattened into the batch dimension here and un-flattened after the
        # language projection (the frame count is read from `pixel_values`).
        batch_size, frames, channel, height, width = pixel_values.shape
        pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)

        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            **kwargs,
        )
        # Wrap the raw vision output; the right-hand `vision_outputs` still refers to the raw
        # encoder output, which is kept nested under `.vision_outputs`.
        vision_outputs = BaseModelOutputWithVisionQformerOutputs(
            last_hidden_state=vision_outputs.last_hidden_state,
            pooler_output=vision_outputs.pooler_output,
            hidden_states=vision_outputs.hidden_states,
            attentions=vision_outputs.attentions,
            vision_outputs=vision_outputs,
            qformer_outputs=None,
        )
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the Q-Former, cross-attending to the frame embeddings
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
        if qformer_attention_mask is None:
            qformer_attention_mask = torch.ones_like(qformer_input_ids)

        # The prompt is given once per video; repeat it per frame to line up with the flattened frame batch.
        qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
        qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
        # Query tokens precede the prompt tokens in the Q-Former sequence.
        qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
        qformer_outputs = self.qformer(
            input_ids=qformer_input_ids,
            attention_mask=qformer_attention_mask,
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            **kwargs,
        )
        vision_outputs.qformer_outputs = qformer_outputs
        # Keep only the query-token positions; the prompt positions are dropped.
        query_output = qformer_outputs[0][:, : query_tokens.size(1), :]

        # step 3: use the language model, conditioned on the query outputs and the prompt
        video_features = self.language_projection(query_output)

        # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
        video_features = video_features.reshape(batch_size, self.config.num_query_tokens * frames, -1)
        # The projected video features (batch_size, num_query_tokens * frames, hidden) replace the
        # encoder's pooled output; callers read them from `.pooler_output`.
        vision_outputs.pooler_output = video_features

        return vision_outputs

    def get_image_features(self, **super_kwargs):
        # This architecture only handles videos, so image features are deliberately unavailable.
        raise AttributeError("No need to inherit as this architecture only supports videos.")

    def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.

        Marks the positions holding the video placeholder token (`config.video_token_id`). When
        `input_ids` is None, those positions are found by matching the placeholder's embedding.
        The returned boolean mask is expanded to the shape of `inputs_embeds`.
        """
        if input_ids is None:
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.video_token_id

        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        return special_image_mask

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        qformer_input_ids: torch.LongTensor,
        qformer_attention_mask: torch.LongTensor | None = None,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.LongTensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        interpolate_pos_encoding: bool = False,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | InstructBlipVideoForConditionalGenerationModelOutput:
        r"""
        qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length)):
            The sequence used as a prompt to be fed to the Q-Former module.
        qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
            Mask to avoid performing attention on padding token indices.

        Examples:

        ```python
        >>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration
        >>> import torch
        >>> from huggingface_hub import hf_hub_download
        >>> import av
        >>> import numpy as np

        >>> def read_video_pyav(container, indices):
        ...     '''
        ...     Decode the video with PyAV decoder.
        ...     Args:
        ...         container (`av.container.input.InputContainer`): PyAV container.
        ...         indices (`list[int]`): List of frame indices to decode.
        ...     Returns:
        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
        ...     '''
        ...     frames = []
        ...     container.seek(0)
        ...     start_index = indices[0]
        ...     end_index = indices[-1]
        ...     for i, frame in enumerate(container.decode(video=0)):
        ...         if i > end_index:
        ...             break
        ...         if i >= start_index and i in indices:
        ...             frames.append(frame)
        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])

        >>> model = InstructBlipVideoForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", device_map="auto")
        >>> processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")

        >>> file_path = hf_hub_download(
        ...       repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
        ... )
        >>> container = av.open(file_path)

        >>> # sample uniformly 4 frames from the video
        >>> total_frames = container.streams.video[0].frames
        >>> indices = np.arange(0, total_frames, total_frames / 4).astype(int)
        >>> clip = read_video_pyav(container, indices)

        >>> prompt = "What is happening in the video?"
        >>> inputs = processor(text=prompt, images=clip, return_tensors="pt").to(model.device)

        >>> outputs = model.generate(
        ...     **inputs,
        ...     do_sample=False,
        ...     num_beams=5,
        ...     max_length=256,
        ...     repetition_penalty=1.5,
        ...     length_penalty=1.0,
        ... )
        >>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
        >>> print(generated_text)
        "A person is eating a bowl of pasta, and they are using a fork to eat it. The person is sitting at a table, and the plate of pasta is on the table in front"
        ```"""
        video_features: BaseModelOutputWithVisionQformerOutputs = self.get_video_features(
            pixel_values,
            qformer_input_ids=qformer_input_ids,
            qformer_attention_mask=qformer_attention_mask,
            interpolate_pos_encoding=interpolate_pos_encoding,
            **kwargs,
        )
        language_model_inputs = video_features.pooler_output
        qformer_outputs = video_features.qformer_outputs
        vision_outputs = video_features.vision_outputs

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if attention_mask is None:
            if input_ids is not None:
                attention_mask = torch.ones_like(input_ids)
            else:
                # Only embeddings were given: attend to every position of the embedded sequence.
                attention_mask = torch.ones(inputs_embeds.shape[:-1], dtype=torch.long, device=inputs_embeds.device)

        # Write the projected video features into the placeholder-token positions.
        language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
        special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

        if self.config.use_decoder_only_language_model:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                use_cache=use_cache,
                **kwargs,
            )
            logits = outputs[0]
            loss = None
            if labels is not None:
                loss = self.loss_function(
                    logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
                )
        else:
            # Encoder-decoder language models compute the loss themselves from `labels`.
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                labels=labels,
                use_cache=use_cache,
                **kwargs,
            )
            loss = outputs.loss
            logits = outputs.logits

        return InstructBlipVideoForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=qformer_outputs,
            language_model_outputs=outputs,
        )

    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.FloatTensor,
        qformer_input_ids: torch.LongTensor | None = None,
        qformer_attention_mask: torch.LongTensor | None = None,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        interpolate_pos_encoding: bool = False,
        **generate_kwargs,
    ) -> torch.LongTensor:
        r"""
        Overrides `generate` function to be able to use the model as a conditional generator.

        Args:
            pixel_values (`torch.FloatTensor` of shape (batch_size, num_frames, num_channels, height, width)):
                Input videos to be processed.
            qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                The sequence used as a prompt to be fed to the Q-Former module. Although optional in the
                signature, `get_video_features` uses it unconditionally, so it must be provided in practice.
            qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                Mask to avoid performing attention on padding token indices.
            input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                The sequence used as a prompt for the generation. When neither this nor `inputs_embeds`
                is given, a prompt of video placeholder tokens followed by the BOS token is built.
            attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                Mask to avoid performing attention on padding token indices.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded representation of the inputs. Should be float, not int tokens.
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                Whether to interpolate the positional encoding of the image embeddings.

        Returns:
            `torch.LongTensor`: the token ids returned by the language model's `generate`.
        """
        if hasattr(self, "hf_device_map"):
            # preprocess for `accelerate`
            self._preprocess_accelerate()

        batch_size = pixel_values.shape[0]
        # `get_video_features` emits `num_query_tokens` features per frame, so the number of
        # placeholder tokens must follow the actual frame count rather than a fixed 4.
        num_frames = pixel_values.shape[1]
        video_features: BaseModelOutputWithVisionQformerOutputs = self.get_video_features(
            pixel_values,
            qformer_input_ids=qformer_input_ids,
            qformer_attention_mask=qformer_attention_mask,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        language_model_inputs = video_features.pooler_output

        if inputs_embeds is None:
            if input_ids is None:
                video_tokens = [self.config.video_token_index] * self.config.num_query_tokens * num_frames
                start_tokens = video_tokens + [self.config.text_config.bos_token_id]
                input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
                input_ids = input_ids.repeat(batch_size, 1)
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if attention_mask is None:
            if input_ids is not None:
                attention_mask = torch.ones_like(input_ids)
            else:
                # Only embeddings were given: attend to every position of the embedded sequence.
                attention_mask = torch.ones(inputs_embeds.shape[:-1], dtype=torch.long, device=inputs_embeds.device)

        # Write the projected video features into the placeholder-token positions.
        language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
        special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

        inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
        # Decoder-only models also receive the prompt ids; encoder-decoder models get embeddings only.
        if not self.language_model.config.is_encoder_decoder:
            inputs["input_ids"] = input_ids

        outputs = self.language_model.generate(**inputs, **generate_kwargs)

        return outputs
# Public API of this module (same names, same order as before).
__all__ = [
    # configuration classes
    "InstructBlipVideoConfig",
    "InstructBlipVideoQFormerConfig",
    "InstructBlipVideoVisionConfig",
    # model classes
    "InstructBlipVideoVisionModel",
    "InstructBlipVideoPreTrainedModel",
    "InstructBlipVideoQFormerModel",
    "InstructBlipVideoModel",
    "InstructBlipVideoForConditionalGeneration",
]
|