| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/ovis2/modular_ovis2.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_ovis2.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- from collections.abc import Callable
- from dataclasses import dataclass
- import torch
- from torch import nn
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...cache_utils import Cache
- from ...generation import GenerationMixin
- from ...integrations import use_kernel_forward_from_hub
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check
- from ...utils.generic import merge_with_config_defaults
- from ...utils.output_capturing import capture_outputs
- from ..auto import AutoModel
- from .configuration_ovis2 import Ovis2Config, Ovis2VisionConfig
@dataclass
@auto_docstring
class BaseModelOutputWithVisualIndicatorFeatures(BaseModelOutputWithPooling):
    r"""
    visual_indicator_features (`torch.FloatTensor` of shape `(num_visual_indicator_tokens, hidden_size)`, *optional*):
        Embeddings of the visual indicator tokens, looked up in the visual embedding table. `Ovis2Model` writes them
        into the input embeddings at the positions of the configured visual indicator token ids. The vision tower on
        its own leaves this field as `None`.
    """

    visual_indicator_features: torch.FloatTensor | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Ovis2 outputs, with hidden states and attentions.
    """
)
class Ovis2ModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(num_images, image_sequence_length, hidden_size)`, where `num_images` is the
        leading dimension of `pixel_values`.
        Image features produced by the vision tower and mapped through the visual embedding table, i.e. the
        embeddings that were scattered into the image placeholder positions.
    """

    image_hidden_states: torch.FloatTensor | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Ovis2 causal language model (or autoregressive) outputs.
    """
)
class Ovis2CausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(num_images, image_sequence_length, hidden_size)`, where `num_images` is the
        leading dimension of `pixel_values`.
        Image features produced by the vision tower and mapped through the visual embedding table.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    image_hidden_states: torch.FloatTensor | None = None
@use_kernel_forward_from_hub("RMSNorm")
class Ovis2RMSNorm(nn.Module):
    """
    Root-mean-square normalization with a learned per-channel scale (equivalent to T5LayerNorm).

    The mean of squares is computed in float32. The normalized activations are cast back to the input dtype
    before being multiplied by `weight`.
    """

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Ovis2VisionMLP(nn.Module):
    """Gated feed-forward block computing `down_proj(act_fn(gate_proj(x)) * up_proj(x))`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size, self.intermediate_size = config.hidden_size, config.intermediate_size
        use_bias = config.mlp_bias
        # Creation order (gate, up, down) is kept so parameter registration and init order are unchanged.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=use_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
class Ovis2VisionEmbeddings(nn.Module):
    """
    Turns pixels into patch tokens with a strided convolution.

    The tokens are then RMS-normalized and learned absolute position embeddings are added. The position table
    has exactly `(image_size // patch_size) ** 2` entries, so the number of patches must match it for the final
    addition to broadcast.
    """

    def __init__(self, config: Ovis2VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        # kernel_size == stride, so every conv output location covers one non-overlapping patch.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )
        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
        self.rms_norm = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        # Run the conv in its own parameter dtype (e.g. bf16 weights with fp32 pixels).
        conv_dtype = self.patch_embedding.weight.dtype
        # Assumes pixel_values is (batch, channels, H, W), so patch_grid is (batch, embed_dim, H/p, W/p).
        patch_grid = self.patch_embedding(pixel_values.to(dtype=conv_dtype))
        # (batch, embed_dim, h, w) -> (batch, h*w, embed_dim)
        tokens = patch_grid.flatten(2).transpose(1, 2)
        tokens = self.rms_norm(tokens)
        return tokens + self.position_embedding(self.position_ids)
- def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: torch.Tensor | None,
- scaling: float,
- dropout: float = 0.0,
- **kwargs,
- ):
- attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
- if attention_mask is not None:
- attn_weights = attn_weights + attention_mask
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value)
- attn_output = attn_output.transpose(1, 2).contiguous()
- return attn_output, attn_weights
class Ovis2VisionAttention(nn.Module):
    """
    Non-causal multi-head self-attention used by the Ovis2 vision encoder.

    The concrete kernel (eager, sdpa, flash, ...) is picked from `config._attn_implementation`, falling back to
    `eager_attention_forward`.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.embed_dim != self.head_dim * self.num_heads:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        # All four projections share the `qkv_bias` flag. Creation order (k, v, q, out) is kept as-is.
        proj_bias = config.qkv_bias
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=proj_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=proj_bias)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=proj_bias)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=proj_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Attend over `hidden_states` laid out as (batch, time, channels); returns (output, weights)."""
        batch_dims = hidden_states.shape[:-1]
        # Split channels into heads, then move the head axis in front of time: (batch, heads, time, head_dim).
        heads_view = (*batch_dims, -1, self.head_dim)
        q = self.q_proj(hidden_states).view(heads_view).transpose(1, 2)
        k = self.k_proj(hidden_states).view(heads_view).transpose(1, 2)
        v = self.v_proj(hidden_states).view(heads_view).transpose(1, 2)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        dropout_p = self.dropout if self.training else 0.0
        context, weights = attention_interface(
            self,
            q,
            k,
            v,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=dropout_p,
        )

        # Merge heads back into the channel axis before the output projection.
        context = context.reshape(*batch_dims, -1).contiguous()
        return self.out_proj(context), weights
class Ovis2MLP(nn.Module):
    """SwiGLU-style feed-forward: project up twice, gate one branch with `act_fn`, multiply, project down."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        widen = (self.hidden_size, self.intermediate_size)
        # Registration order (gate, up, down) is kept so parameter order and init are unchanged.
        self.gate_proj = nn.Linear(*widen, bias=config.mlp_bias)
        self.up_proj = nn.Linear(*widen, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
class Ovis2VisionEncoderLayer(GradientCheckpointingLayer):
    """
    Pre-norm transformer block.

    First `x + attention(rms_norm1(x))`, then `x + ffn(rms_norm2(x))`.
    """

    def __init__(self, config: Ovis2VisionConfig):
        super().__init__()
        self.attention = Ovis2VisionAttention(config)
        self.ffn = Ovis2MLP(config)
        self.rms_norm1 = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.rms_norm2 = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        # Attention sub-block (attention weights are discarded here).
        residual = hidden_states
        attn_output, _ = self.attention(
            hidden_states=self.rms_norm1(hidden_states), attention_mask=attention_mask, **kwargs
        )
        hidden_states = residual + attn_output

        # Feed-forward sub-block.
        residual = hidden_states
        hidden_states = residual + self.ffn(self.rms_norm2(hidden_states))
        return hidden_states
class Ovis2VisionEncoder(nn.Module):
    """
    Stack of `config.num_hidden_layers` [`Ovis2VisionEncoderLayer`] blocks.

    The blocks are applied sequentially to already-embedded inputs.

    Args:
        config: Ovis2VisionConfig
    """

    def __init__(self, config: Ovis2VisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(Ovis2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers))
        self.gradient_checkpointing = False

    # Ignore copy
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        inputs_embeds,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        # attention_mask is passed positionally so GradientCheckpointingLayer treats it as a regular argument.
        states = inputs_embeds
        for layer in self.layers:
            states = layer(states, attention_mask, **kwargs)
        return BaseModelOutput(last_hidden_state=states)
class Ovis2VisionTransformer(nn.Module):
    """Vision backbone: patch embeddings, then the encoder stack, then a final RMSNorm."""

    def __init__(self, config: Ovis2VisionConfig):
        super().__init__()
        self.config = config
        self.embeddings = Ovis2VisionEmbeddings(config)
        self.encoder = Ovis2VisionEncoder(config)
        self.rms_norm = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.gradient_checkpointing = False

    @can_return_tuple
    def forward(
        self,
        pixel_values,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ):
        embedded = self.embeddings(pixel_values)
        encoded: BaseModelOutput = self.encoder(
            inputs_embeds=embedded,
            attention_mask=attention_mask,
            **kwargs,
        )
        # Only the normalized final hidden state is returned; intermediate states are captured elsewhere.
        normalized = self.rms_norm(encoded.last_hidden_state)
        return BaseModelOutput(last_hidden_state=normalized)
class Ovis2VisualEmbeddingTable(nn.Embedding):
    """
    Embedding table that accepts either hard token ids or soft token distributions.

    Signed integer inputs are looked up like a normal `nn.Embedding`. Any other dtype is treated as weights over
    the vocabulary and mixed as `visual_tokens @ weight`.
    """

    def forward(self, visual_tokens: torch.Tensor) -> torch.Tensor:
        # torch.long is an alias of torch.int64, so listing int64 covers it.
        is_index_tensor = visual_tokens.dtype in (torch.int8, torch.int16, torch.int32, torch.int64)
        if not is_index_tensor:
            # Soft tokens: weighted sum over embedding rows.
            return torch.matmul(visual_tokens, self.weight)
        return super().forward(visual_tokens)
class Ovis2PreTrainedModel(PreTrainedModel):
    """Shared base for Ovis2 models: capability flags plus Ovis2-specific weight initialization."""

    config: Ovis2Config
    base_model_prefix = "model"
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True
    _no_split_modules = ["Ovis2VisionAttention"]
    _skip_keys_device_placement = "past_key_values"

    # Supported cache / attention-backend / compilation features.
    _supports_cache_class = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        super()._init_weights(module)
        if not isinstance(module, Ovis2VisionEmbeddings):
            return
        # position_ids is registered with persistent=False, so it is (re)filled here with 0..N-1.
        num_positions = module.position_ids.shape[-1]
        init.copy_(module.position_ids, torch.arange(num_positions).expand((1, -1)))
def hard_softmax(logits: torch.Tensor, dim: int):
    """
    Straight-through argmax over `dim`.

    The forward value is a one-hot tensor at the argmax of `softmax(logits)`. Gradients flow as if the output
    were `softmax(logits)`, because the soft term is subtracted in detached form and added back live.
    """
    soft = logits.softmax(dim)
    winner = soft.max(dim, keepdim=True).indices
    one_hot = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, winner, 1.0)
    return one_hot - soft.detach() + soft
class Ovis2VisionModel(Ovis2PreTrainedModel):
    """
    Ovis2 vision tower: a ViT-style encoder followed by a visual-tokenizer head.

    The steps are:
    1. When `hidden_stride > 1`, each `hidden_stride x hidden_stride` neighbourhood of the patch grid is merged
       into one token by concatenating channels.
    2. The merged tokens are projected onto `vocab_size - num_visual_indicator_tokens` logits and layer-normalized.
    3. The logits are turned into per-token distributions according to `config.tokenize_function`.
    """

    config: Ovis2VisionConfig
    _can_record_outputs = {
        "hidden_states": Ovis2VisionEncoderLayer,
        "attentions": Ovis2VisionAttention,
    }

    def __init__(self, config: Ovis2VisionConfig):
        super().__init__(config)
        self.config = config
        self.transformer = Ovis2VisionTransformer(config)
        self.num_visual_indicator_tokens = config.num_visual_indicator_tokens
        self.vocab_size = config.vocab_size
        # The head only covers the first `vocab_size - num_visual_indicator_tokens` visual ids.
        self.head_linear = nn.Linear(
            config.hidden_size * config.hidden_stride * config.hidden_stride,
            self.vocab_size - self.num_visual_indicator_tokens,
            bias=False,
        )
        self.head_norm = nn.LayerNorm(self.vocab_size - self.num_visual_indicator_tokens)
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | BaseModelOutputWithVisualIndicatorFeatures:
        """
        Encode images and tokenize them into visual-token distributions.

        Returns:
            `BaseModelOutputWithVisualIndicatorFeatures` with two fields set:
            - `last_hidden_state`: the stride-merged encoder output of shape
              `(num_images, num_merged_tokens, hidden_stride**2 * hidden_size)`. It is the unmerged encoder output
              when `hidden_stride <= 1`.
            - `pooler_output`: the visual-token distributions of shape
              `(num_images, num_merged_tokens, vocab_size - num_visual_indicator_tokens)`.

        Raises:
            ValueError: if the patch sequence length is not a perfect square, or if `config.tokenize_function` is
                not one of `"gumbel_argmax"`, `"st_argmax"` or `"softmax"`.
        """
        outputs = self.transformer(pixel_values, **kwargs)
        last_hidden_state = outputs[0]

        if self.config.hidden_stride > 1:
            num_images, seq_len, hidden_dim = last_hidden_state.shape
            hidden_stride = self.config.hidden_stride
            sqrt_l = int(math.sqrt(seq_len))
            if sqrt_l * sqrt_l != seq_len:
                raise ValueError("Token sequence length must be a perfect square")

            # Restore the 2D patch grid first so padding only touches the spatial axes. Padding the flat
            # (num_images, seq_len, hidden) tensor would grow the sequence and image axes instead.
            last_hidden_state = last_hidden_state.reshape(num_images, sqrt_l, sqrt_l, hidden_dim)
            pad_size = (hidden_stride - (sqrt_l % hidden_stride)) % hidden_stride
            # F.pad reads pairs from the last dim backwards: hidden (0, 0), width (0, pad), height (0, pad).
            last_hidden_state = nn.functional.pad(last_hidden_state, (0, 0, 0, pad_size, 0, pad_size), "constant", 0)
            sqrt_l += pad_size

            merged_side = sqrt_l // hidden_stride
            last_hidden_state = last_hidden_state.reshape(
                num_images, merged_side, hidden_stride, merged_side, hidden_stride, hidden_dim
            )
            # (n, H/hs, hs, W/hs, hs, d) -> (n, H/hs, W/hs, hs, hs, d): group each hs x hs window together.
            last_hidden_state = last_hidden_state.permute(0, 1, 3, 2, 4, 5)
            last_hidden_state = last_hidden_state.reshape(
                num_images, -1, hidden_stride * hidden_stride * hidden_dim
            )  # (n, (sqrt_l//hs)^2, hs^2*d)

        logits = self.head_linear(last_hidden_state)
        logits = self.head_norm(logits)

        tokenize_function = self.config.tokenize_function
        if tokenize_function == "gumbel_argmax":
            prob_token = nn.functional.gumbel_softmax(logits, dim=-1, hard=True)
        elif tokenize_function == "st_argmax":
            prob_token = hard_softmax(logits, dim=-1)
        elif tokenize_function == "softmax":
            prob_token = nn.functional.softmax(logits, dim=-1)
        else:
            # Previously fell through and failed later with an UnboundLocalError on `prob_token`.
            raise ValueError(
                f"Unsupported `tokenize_function` {tokenize_function!r}; expected one of "
                "'gumbel_argmax', 'st_argmax' or 'softmax'."
            )

        return BaseModelOutputWithVisualIndicatorFeatures(
            last_hidden_state=last_hidden_state,
            pooler_output=prob_token,
        )
@auto_docstring(
    custom_intro="""
    The Ovis2 model which consists of a vision backbone and a language model, without a language modeling head.
    """
)
class Ovis2Model(Ovis2PreTrainedModel):
    def __init__(self, config: Ovis2Config):
        super().__init__(config)
        self.vision_tower = Ovis2VisionModel(config.vision_config)
        self.language_model = AutoModel.from_config(config.text_config)
        # Maps visual-token ids or distributions over the visual vocabulary to language-model embeddings.
        self.visual_embeddings_table = Ovis2VisualEmbeddingTable(config.vision_config.vocab_size, config.hidden_size)
        self.visual_vocab_size = config.vision_config.vocab_size
        self.vocab_size = config.vocab_size
        self.visual_indicator_token_ids = config.visual_indicator_token_ids
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring(
        custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection."
    )
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithVisualIndicatorFeatures:
        image_outputs = self.vision_tower(pixel_values, return_dict=True, **kwargs)
        # (num_images, img_seq_len, visual_vocab_size - num_indicator_tokens) visual-token distributions
        image_features = image_outputs.pooler_output
        batch_size, img_seq_len, _ = image_features.shape
        num_indicator_tokens = self.vision_tower.num_visual_indicator_tokens

        # The vision head does not cover the trailing indicator ids. Give them zero probability so the
        # distribution spans the full visual vocabulary before the table lookup.
        padding_tensor = torch.zeros(
            (batch_size, img_seq_len, num_indicator_tokens),
            dtype=image_features.dtype,
            device=image_features.device,
            requires_grad=False,
            layout=image_features.layout,
        )
        image_features = torch.cat([image_features, padding_tensor], dim=2)
        image_features = self.visual_embeddings_table(image_features)

        # The last `num_indicator_tokens` ids of the visual vocabulary are the indicator tokens.
        # Allocate the ids directly on the target device rather than on CPU followed by a copy.
        visual_indicator = torch.arange(
            self.visual_vocab_size - num_indicator_tokens,
            self.visual_vocab_size,
            dtype=torch.long,
            device=image_features.device,
        )
        image_outputs.pooler_output = image_features
        image_outputs.visual_indicator_features = self.visual_embeddings_table(visual_indicator)
        return image_outputs

    def get_placeholder_mask(
        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
    ):
        """
        Build a mask of image placeholder positions and check it against the image features.

        Placeholders are located in one of two ways:
        - `input_ids == config.image_token_id`, when `input_ids` is given;
        - otherwise, embeddings that exactly equal the embedding of `config.image_token_id`.

        The mask is expanded to the full shape of `inputs_embeds`, ready for `masked_scatter`. If the number of
        masked elements differs from `image_features.numel()`, `torch_compilable_check` reports the mismatch.
        """
        if input_ids is None:
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id

        n_image_tokens = special_image_mask.sum()
        n_image_features = image_features.shape[0] * image_features.shape[1]
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        torch_compilable_check(
            inputs_embeds[special_image_mask].numel() == image_features.numel(),
            f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
        )
        return special_image_mask

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ) -> tuple | Ovis2ModelOutputWithPast:
        # NOTE: `labels` is accepted but not used in this method.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if pixel_values is not None:
            image_outputs = self.get_image_features(pixel_values=pixel_values, return_dict=True)
            image_features = image_outputs.pooler_output
            visual_indicator_features = image_outputs.visual_indicator_features

            special_image_mask = self.get_placeholder_mask(
                input_ids,
                inputs_embeds=inputs_embeds,
                image_features=image_features,
            )
            # Fill the image placeholder positions, in order, with the image feature vectors.
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

            # Overwrite each visual indicator token with its own embedding: the i-th configured id
            # takes the i-th row of `visual_indicator_features`.
            for i, visual_indicator_id in enumerate(self.visual_indicator_token_ids):
                if input_ids is None:
                    # Without ids, locate indicator tokens by exact equality with their text embedding.
                    mask = inputs_embeds == self.get_input_embeddings()(
                        torch.tensor(visual_indicator_id, dtype=torch.long, device=inputs_embeds.device)
                    )
                    mask = mask.all(-1)
                else:
                    mask = (input_ids == visual_indicator_id).to(inputs_embeds.device)

                if mask.any():
                    inputs_embeds[mask] = (
                        visual_indicator_features[i]
                        .expand_as(inputs_embeds[mask])
                        .to(inputs_embeds.device, inputs_embeds.dtype)
                    )

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        return Ovis2ModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features if pixel_values is not None else None,
        )
@auto_docstring
class Ovis2ForConditionalGeneration(Ovis2PreTrainedModel, GenerationMixin):
    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}

    def __init__(self, config: Ovis2Config):
        super().__init__(config)
        self.model = Ovis2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    @auto_docstring
    def get_image_features(
        self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | BaseModelOutputWithVisualIndicatorFeatures:
        return self.model.get_image_features(pixel_values=pixel_values, **kwargs)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ) -> tuple | Ovis2CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoProcessor, Ovis2ForConditionalGeneration

        >>> model = Ovis2ForConditionalGeneration.from_pretrained("thisisiron/Ovis2-2B-hf")
        >>> processor = AutoProcessor.from_pretrained("thisisiron/Ovis2-2B-hf")

        >>> prompt = "<|im_start|>user\n<image>\nDescribe the image.<|im_end|>\n<|im_start|>assistant\n"
        >>> url = "http://images.cocodataset.org/val2014/COCO_val2014_000000537955.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
        "user\n\nDescribe the image.\nassistant\nThe image features a brown dog standing on a wooden floor, looking up with"
        ```"""
        model_outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )
        sequence_output = model_outputs[0]

        # Only project the positions that are needed. An int keeps the last `logits_to_keep` tokens
        # (0 keeps all of them, since slice(-0, None) == slice(0, None)); a tensor is used directly as an index.
        if isinstance(logits_to_keep, int):
            kept_positions = slice(-logits_to_keep, None)
        else:
            kept_positions = logits_to_keep
        logits = self.lm_head(sequence_output[:, kept_positions, :])

        loss = (
            self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs)
            if labels is not None
            else None
        )

        return Ovis2CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
            image_hidden_states=model_outputs.image_hidden_states,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        attention_mask=None,
        logits_to_keep=None,
        is_first_iteration=False,
        **kwargs,
    ):
        # Overridden so that image inputs are only forwarded to the model when they are actually needed.
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            logits_to_keep=logits_to_keep,
            is_first_iteration=is_first_iteration,
            **kwargs,
        )

        # Pixel values are passed on the first iteration (not necessarily a prefill; it can follow a cached
        # system prompt), or on every step when caching is disabled. Otherwise the image information already
        # lives in the cache.
        forward_pixels = is_first_iteration or not kwargs.get("use_cache", True)
        if forward_pixels:
            model_inputs["pixel_values"] = pixel_values

        return model_inputs
# Public API of this module.
__all__ = [
    "Ovis2PreTrainedModel",
    "Ovis2Model",
    "Ovis2ForConditionalGeneration",
]
|