modeling_ovis2.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/ovis2/modular_ovis2.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_ovis2.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import math
  21. from collections.abc import Callable
  22. from dataclasses import dataclass
  23. import torch
  24. from torch import nn
  25. from ... import initialization as init
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache
  28. from ...generation import GenerationMixin
  29. from ...integrations import use_kernel_forward_from_hub
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check
  35. from ...utils.generic import merge_with_config_defaults
  36. from ...utils.output_capturing import capture_outputs
  37. from ..auto import AutoModel
  38. from .configuration_ovis2 import Ovis2Config, Ovis2VisionConfig
  39. @dataclass
  40. @auto_docstring
  41. class BaseModelOutputWithVisualIndicatorFeatures(BaseModelOutputWithPooling):
  42. r"""
  43. visual_indicator_features (`torch.FloatTensor` of shape `(batch_size, visual_indicator_size)`):
  44. Visual indicator features extracted from the model, which can be used for auxiliary tasks or further processing.
  45. """
  46. visual_indicator_features: torch.FloatTensor | None = None
  47. @dataclass
  48. @auto_docstring(
  49. custom_intro="""
  50. Base class for Llava outputs, with hidden states and attentions.
  51. """
  52. )
  53. class Ovis2ModelOutputWithPast(BaseModelOutputWithPast):
  54. r"""
  55. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  56. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  57. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  58. `past_key_values` input) to speed up sequential decoding.
  59. image_hidden_states (`torch.FloatTensor`, *optional*):
  60. A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
  61. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
  62. """
  63. image_hidden_states: torch.FloatTensor | None = None
  64. @dataclass
  65. @auto_docstring(
  66. custom_intro="""
  67. Base class for Ovis2 causal language model (or autoregressive) outputs.
  68. """
  69. )
  70. class Ovis2CausalLMOutputWithPast(ModelOutput):
  71. r"""
  72. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  73. Language modeling loss (for next-token prediction).
  74. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  75. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  76. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  77. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  78. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  79. `past_key_values` input) to speed up sequential decoding.
  80. image_hidden_states (`torch.FloatTensor`, *optional*):
  81. A `torch.FloatTensor` of size (batch_size * num_patches, num_images, sequence_length, hidden_size)`.
  82. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
  83. """
  84. loss: torch.FloatTensor | None = None
  85. logits: torch.FloatTensor | None = None
  86. past_key_values: Cache | None = None
  87. hidden_states: tuple[torch.FloatTensor] | None = None
  88. attentions: tuple[torch.FloatTensor] | None = None
  89. image_hidden_states: torch.FloatTensor | None = None
  90. @use_kernel_forward_from_hub("RMSNorm")
  91. class Ovis2RMSNorm(nn.Module):
  92. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  93. """
  94. Ovis2RMSNorm is equivalent to T5LayerNorm
  95. """
  96. super().__init__()
  97. self.weight = nn.Parameter(torch.ones(hidden_size))
  98. self.variance_epsilon = eps
  99. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  100. input_dtype = hidden_states.dtype
  101. hidden_states = hidden_states.to(torch.float32)
  102. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  103. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  104. return self.weight * hidden_states.to(input_dtype)
  105. def extra_repr(self):
  106. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  107. class Ovis2VisionMLP(nn.Module):
  108. def __init__(self, config):
  109. super().__init__()
  110. self.config = config
  111. self.hidden_size = config.hidden_size
  112. self.intermediate_size = config.intermediate_size
  113. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  114. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  115. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  116. self.act_fn = ACT2FN[config.hidden_act]
  117. def forward(self, x):
  118. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  119. return down_proj
  120. class Ovis2VisionEmbeddings(nn.Module):
  121. def __init__(self, config: Ovis2VisionConfig):
  122. super().__init__()
  123. self.config = config
  124. self.embed_dim = config.hidden_size
  125. self.image_size = config.image_size
  126. self.patch_size = config.patch_size
  127. self.patch_embedding = nn.Conv2d(
  128. in_channels=config.num_channels,
  129. out_channels=self.embed_dim,
  130. kernel_size=self.patch_size,
  131. stride=self.patch_size,
  132. padding="valid",
  133. )
  134. self.num_patches = (self.image_size // self.patch_size) ** 2
  135. self.num_positions = self.num_patches
  136. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  137. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  138. self.rms_norm = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
  139. def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
  140. target_dtype = self.patch_embedding.weight.dtype
  141. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
  142. embeddings = patch_embeds.flatten(2).transpose(1, 2)
  143. embeddings = self.rms_norm(embeddings)
  144. embeddings = embeddings + self.position_embedding(self.position_ids)
  145. return embeddings
  146. def eager_attention_forward(
  147. module: nn.Module,
  148. query: torch.Tensor,
  149. key: torch.Tensor,
  150. value: torch.Tensor,
  151. attention_mask: torch.Tensor | None,
  152. scaling: float,
  153. dropout: float = 0.0,
  154. **kwargs,
  155. ):
  156. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  157. if attention_mask is not None:
  158. attn_weights = attn_weights + attention_mask
  159. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  160. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  161. attn_output = torch.matmul(attn_weights, value)
  162. attn_output = attn_output.transpose(1, 2).contiguous()
  163. return attn_output, attn_weights
  164. class Ovis2VisionAttention(nn.Module):
  165. """Multi-headed attention from 'Attention Is All You Need' paper"""
  166. def __init__(self, config):
  167. super().__init__()
  168. self.config = config
  169. self.embed_dim = config.hidden_size
  170. self.num_heads = config.num_attention_heads
  171. self.head_dim = self.embed_dim // self.num_heads
  172. if self.head_dim * self.num_heads != self.embed_dim:
  173. raise ValueError(
  174. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  175. f" {self.num_heads})."
  176. )
  177. self.scale = self.head_dim**-0.5
  178. self.dropout = config.attention_dropout
  179. self.is_causal = False
  180. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  181. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  182. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  183. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  184. def forward(
  185. self,
  186. hidden_states: torch.Tensor,
  187. attention_mask: torch.Tensor | None = None,
  188. **kwargs,
  189. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  190. """Input shape: Batch x Time x Channel"""
  191. input_shape = hidden_states.shape[:-1]
  192. hidden_shape = (*input_shape, -1, self.head_dim)
  193. queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  194. keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  195. values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  196. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  197. self.config._attn_implementation, eager_attention_forward
  198. )
  199. attn_output, attn_weights = attention_interface(
  200. self,
  201. queries,
  202. keys,
  203. values,
  204. attention_mask,
  205. is_causal=self.is_causal,
  206. scaling=self.scale,
  207. dropout=0.0 if not self.training else self.dropout,
  208. )
  209. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  210. attn_output = self.out_proj(attn_output)
  211. return attn_output, attn_weights
  212. class Ovis2MLP(nn.Module):
  213. def __init__(self, config):
  214. super().__init__()
  215. self.config = config
  216. self.hidden_size = config.hidden_size
  217. self.intermediate_size = config.intermediate_size
  218. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  219. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  220. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  221. self.act_fn = ACT2FN[config.hidden_act]
  222. def forward(self, x):
  223. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  224. return down_proj
  225. class Ovis2VisionEncoderLayer(GradientCheckpointingLayer):
  226. def __init__(self, config: Ovis2VisionConfig):
  227. super().__init__()
  228. self.attention = Ovis2VisionAttention(config)
  229. self.ffn = Ovis2MLP(config)
  230. self.rms_norm1 = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
  231. self.rms_norm2 = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
  232. def forward(
  233. self,
  234. hidden_states: torch.Tensor,
  235. attention_mask: torch.Tensor | None = None,
  236. **kwargs: Unpack[TransformersKwargs],
  237. ) -> torch.Tensor:
  238. norm_hidden_states = self.rms_norm1(hidden_states)
  239. attn_output, _ = self.attention(hidden_states=norm_hidden_states, attention_mask=attention_mask, **kwargs)
  240. hidden_states = hidden_states + attn_output
  241. norm_hidden_states = self.rms_norm2(hidden_states)
  242. mlp_output = self.ffn(norm_hidden_states)
  243. hidden_states = hidden_states + mlp_output
  244. return hidden_states
  245. class Ovis2VisionEncoder(nn.Module):
  246. """
  247. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  248. [`Ovis2VisionEncoderLayer`].
  249. Args:
  250. config: Ovis2VisionConfig
  251. """
  252. def __init__(self, config: Ovis2VisionConfig):
  253. super().__init__()
  254. self.config = config
  255. self.layers = nn.ModuleList([Ovis2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  256. self.gradient_checkpointing = False
  257. # Ignore copy
  258. @can_return_tuple
  259. @auto_docstring
  260. def forward(
  261. self,
  262. inputs_embeds,
  263. attention_mask: torch.Tensor | None = None,
  264. **kwargs: Unpack[TransformersKwargs],
  265. ) -> BaseModelOutput:
  266. hidden_states = inputs_embeds
  267. for encoder_layer in self.layers:
  268. hidden_states = encoder_layer(hidden_states, attention_mask, **kwargs)
  269. return BaseModelOutput(last_hidden_state=hidden_states)
  270. class Ovis2VisionTransformer(nn.Module):
  271. def __init__(self, config: Ovis2VisionConfig):
  272. super().__init__()
  273. self.config = config
  274. self.embeddings = Ovis2VisionEmbeddings(config)
  275. self.encoder = Ovis2VisionEncoder(config)
  276. self.rms_norm = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
  277. self.gradient_checkpointing = False
  278. @can_return_tuple
  279. def forward(
  280. self,
  281. pixel_values,
  282. attention_mask: torch.Tensor | None = None,
  283. **kwargs,
  284. ):
  285. hidden_states = self.embeddings(pixel_values)
  286. encoder_outputs: BaseModelOutput = self.encoder(
  287. inputs_embeds=hidden_states,
  288. attention_mask=attention_mask,
  289. **kwargs,
  290. )
  291. last_hidden_state = encoder_outputs.last_hidden_state
  292. last_hidden_state = self.rms_norm(last_hidden_state)
  293. return BaseModelOutput(last_hidden_state=last_hidden_state)
  294. class Ovis2VisualEmbeddingTable(nn.Embedding):
  295. def forward(self, visual_tokens: torch.Tensor) -> torch.Tensor:
  296. if visual_tokens.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.long]:
  297. return super().forward(visual_tokens)
  298. return torch.matmul(visual_tokens, self.weight)
  299. class Ovis2PreTrainedModel(PreTrainedModel):
  300. config: Ovis2Config
  301. base_model_prefix = "model"
  302. input_modalities = ("image", "text")
  303. supports_gradient_checkpointing = True
  304. _no_split_modules = ["Ovis2VisionAttention"]
  305. _skip_keys_device_placement = "past_key_values"
  306. _supports_cache_class = True
  307. _supports_flash_attn = True
  308. _supports_flex_attn = True
  309. _supports_sdpa = True
  310. _can_compile_fullgraph = True
  311. _supports_attention_backend = True
  312. def _init_weights(self, module):
  313. super()._init_weights(module)
  314. if isinstance(module, Ovis2VisionEmbeddings):
  315. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  316. def hard_softmax(logits: torch.Tensor, dim: int):
  317. y_soft = logits.softmax(dim)
  318. # Straight through.
  319. index = y_soft.max(dim, keepdim=True)[1]
  320. y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
  321. ret = y_hard - y_soft.detach() + y_soft
  322. return ret
  323. class Ovis2VisionModel(Ovis2PreTrainedModel):
  324. config: Ovis2VisionConfig
  325. _can_record_outputs = {
  326. "hidden_states": Ovis2VisionEncoderLayer,
  327. "attentions": Ovis2VisionAttention,
  328. }
  329. def __init__(self, config: Ovis2VisionConfig):
  330. super().__init__(config)
  331. self.config = config
  332. self.transformer = Ovis2VisionTransformer(config)
  333. self.num_visual_indicator_tokens = config.num_visual_indicator_tokens
  334. self.vocab_size = config.vocab_size
  335. self.head_linear = nn.Linear(
  336. config.hidden_size * config.hidden_stride * config.hidden_stride,
  337. self.vocab_size - self.num_visual_indicator_tokens,
  338. bias=False,
  339. )
  340. self.head_norm = nn.LayerNorm(self.vocab_size - self.num_visual_indicator_tokens)
  341. self.post_init()
  342. @merge_with_config_defaults
  343. @capture_outputs
  344. def forward(
  345. self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
  346. ) -> tuple | BaseModelOutputWithVisualIndicatorFeatures:
  347. outputs = self.transformer(pixel_values, **kwargs)
  348. last_hidden_state = outputs[0]
  349. if self.config.hidden_stride > 1:
  350. num_images, seq_len, hidden_dim = last_hidden_state.shape
  351. hidden_stride = self.config.hidden_stride
  352. sqrt_l = int(math.sqrt(seq_len))
  353. if sqrt_l * sqrt_l != seq_len:
  354. raise ValueError("Token sequence length must be a perfect square")
  355. pad_size = (hidden_stride - (sqrt_l % hidden_stride)) % hidden_stride
  356. last_hidden_state = nn.functional.pad(last_hidden_state, (0, 0, 0, pad_size, 0, pad_size), "constant", 0)
  357. sqrt_l += pad_size
  358. last_hidden_state = last_hidden_state.reshape(
  359. num_images, sqrt_l // hidden_stride, hidden_stride, sqrt_l // hidden_stride, hidden_stride, hidden_dim
  360. )
  361. last_hidden_state = last_hidden_state.permute(0, 1, 3, 2, 4, 5)
  362. last_hidden_state = last_hidden_state.reshape(
  363. num_images, -1, hidden_stride * hidden_stride * hidden_dim
  364. ) # (n, (sqrt_l//hs)^2, hs^2*d)
  365. logits = self.head_linear(last_hidden_state)
  366. logits = self.head_norm(logits)
  367. if self.config.tokenize_function == "gumbel_argmax":
  368. prob_token = nn.functional.gumbel_softmax(logits, dim=-1, hard=True)
  369. elif self.config.tokenize_function == "st_argmax":
  370. prob_token = hard_softmax(logits, dim=-1)
  371. elif self.config.tokenize_function == "softmax":
  372. prob_token = nn.functional.softmax(logits, dim=-1)
  373. return BaseModelOutputWithVisualIndicatorFeatures(
  374. last_hidden_state=last_hidden_state,
  375. pooler_output=prob_token,
  376. )
  377. @auto_docstring(
  378. custom_intro="""
  379. The Ovis2 model which consists of a vision backbone and a language model, without a language modeling head.
  380. """
  381. )
  382. class Ovis2Model(Ovis2PreTrainedModel):
  383. def __init__(self, config: Ovis2Config):
  384. super().__init__(config)
  385. self.vision_tower = Ovis2VisionModel(config.vision_config)
  386. self.language_model = AutoModel.from_config(config.text_config)
  387. self.visual_embeddings_table = Ovis2VisualEmbeddingTable(config.vision_config.vocab_size, config.hidden_size)
  388. self.visual_vocab_size = config.vision_config.vocab_size
  389. self.vocab_size = config.vocab_size
  390. self.visual_indicator_token_ids = config.visual_indicator_token_ids
  391. self.post_init()
  392. def get_input_embeddings(self):
  393. return self.language_model.get_input_embeddings()
  394. def set_input_embeddings(self, value):
  395. self.language_model.set_input_embeddings(value)
  396. @can_return_tuple
  397. @auto_docstring(
  398. custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection."
  399. )
  400. def get_image_features(
  401. self,
  402. pixel_values: torch.FloatTensor,
  403. **kwargs: Unpack[TransformersKwargs],
  404. ) -> tuple | BaseModelOutputWithVisualIndicatorFeatures:
  405. image_outputs = self.vision_tower(pixel_values, return_dict=True, **kwargs)
  406. image_features = image_outputs.pooler_output
  407. batch_size, img_seq_len, _ = image_features.shape
  408. padding_tensor = torch.zeros(
  409. (batch_size, img_seq_len, self.vision_tower.num_visual_indicator_tokens),
  410. dtype=image_features.dtype,
  411. device=image_features.device,
  412. requires_grad=False,
  413. layout=image_features.layout,
  414. )
  415. image_features = torch.cat([image_features, padding_tensor], dim=2)
  416. image_features = self.visual_embeddings_table(image_features)
  417. visual_indicator = torch.arange(
  418. self.visual_vocab_size - self.vision_tower.num_visual_indicator_tokens,
  419. self.visual_vocab_size,
  420. dtype=torch.long,
  421. ).to(image_features.device)
  422. image_outputs.pooler_output = image_features
  423. image_outputs.visual_indicator_features = self.visual_embeddings_table(visual_indicator)
  424. return image_outputs
  425. def get_placeholder_mask(
  426. self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
  427. ):
  428. """
  429. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  430. equal to the length of multimodal features. If the lengths are different, an error is raised.
  431. """
  432. if input_ids is None:
  433. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  434. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  435. )
  436. special_image_mask = special_image_mask.all(-1)
  437. else:
  438. special_image_mask = input_ids == self.config.image_token_id
  439. n_image_tokens = special_image_mask.sum()
  440. n_image_features = image_features.shape[0] * image_features.shape[1]
  441. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  442. torch_compilable_check(
  443. inputs_embeds[special_image_mask].numel() == image_features.numel(),
  444. f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
  445. )
  446. return special_image_mask
  447. @can_return_tuple
  448. @auto_docstring
  449. def forward(
  450. self,
  451. input_ids: torch.LongTensor | None = None,
  452. pixel_values: torch.FloatTensor | None = None,
  453. attention_mask: torch.Tensor | None = None,
  454. position_ids: torch.LongTensor | None = None,
  455. past_key_values: Cache | None = None,
  456. inputs_embeds: torch.FloatTensor | None = None,
  457. labels: torch.LongTensor | None = None,
  458. use_cache: bool | None = None,
  459. logits_to_keep: int | torch.Tensor = 0,
  460. **kwargs,
  461. ) -> tuple | Ovis2ModelOutputWithPast:
  462. if (input_ids is None) ^ (inputs_embeds is not None):
  463. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  464. if inputs_embeds is None:
  465. inputs_embeds = self.get_input_embeddings()(input_ids)
  466. if pixel_values is not None:
  467. image_outputs = self.get_image_features(pixel_values=pixel_values, return_dict=True)
  468. image_features = image_outputs.pooler_output
  469. visual_indicator_features = image_outputs.visual_indicator_features
  470. special_image_mask = self.get_placeholder_mask(
  471. input_ids,
  472. inputs_embeds=inputs_embeds,
  473. image_features=image_features,
  474. )
  475. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  476. for i, visual_indicator_id in enumerate(self.visual_indicator_token_ids):
  477. if input_ids is None:
  478. mask = inputs_embeds == self.get_input_embeddings()(
  479. torch.tensor(visual_indicator_id, dtype=torch.long, device=inputs_embeds.device)
  480. )
  481. mask = mask.all(-1)
  482. else:
  483. mask = (input_ids == visual_indicator_id).to(inputs_embeds.device)
  484. if mask.any():
  485. inputs_embeds[mask] = (
  486. visual_indicator_features[i]
  487. .expand_as(inputs_embeds[mask])
  488. .to(inputs_embeds.device, inputs_embeds.dtype)
  489. )
  490. outputs = self.language_model(
  491. attention_mask=attention_mask,
  492. position_ids=position_ids,
  493. past_key_values=past_key_values,
  494. inputs_embeds=inputs_embeds,
  495. use_cache=use_cache,
  496. logits_to_keep=logits_to_keep,
  497. **kwargs,
  498. )
  499. return Ovis2ModelOutputWithPast(
  500. last_hidden_state=outputs.last_hidden_state,
  501. past_key_values=outputs.past_key_values,
  502. hidden_states=outputs.hidden_states,
  503. attentions=outputs.attentions,
  504. image_hidden_states=image_features if pixel_values is not None else None,
  505. )
  506. @auto_docstring
  507. class Ovis2ForConditionalGeneration(Ovis2PreTrainedModel, GenerationMixin):
  508. _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
  509. def __init__(self, config: Ovis2Config):
  510. super().__init__(config)
  511. self.model = Ovis2Model(config)
  512. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  513. self.post_init()
  514. def get_input_embeddings(self):
  515. return self.model.get_input_embeddings()
  516. def set_input_embeddings(self, value):
  517. self.model.set_input_embeddings(value)
  518. def get_output_embeddings(self) -> nn.Module:
  519. return self.lm_head
  520. @auto_docstring
  521. def get_image_features(
  522. self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
  523. ) -> tuple | BaseModelOutputWithVisualIndicatorFeatures:
  524. return self.model.get_image_features(pixel_values=pixel_values, **kwargs)
  525. @can_return_tuple
  526. @auto_docstring
  527. def forward(
  528. self,
  529. input_ids: torch.LongTensor | None = None,
  530. pixel_values: torch.FloatTensor | None = None,
  531. attention_mask: torch.Tensor | None = None,
  532. position_ids: torch.LongTensor | None = None,
  533. past_key_values: Cache | None = None,
  534. inputs_embeds: torch.FloatTensor | None = None,
  535. labels: torch.LongTensor | None = None,
  536. use_cache: bool | None = None,
  537. logits_to_keep: int | torch.Tensor = 0,
  538. **kwargs,
  539. ) -> tuple | Ovis2CausalLMOutputWithPast:
  540. r"""
  541. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  542. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  543. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  544. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  545. Example:
  546. ```python
  547. >>> from PIL import Image
  548. >>> import httpx
  549. >>> from io import BytesIO
  550. >>> from transformers import AutoProcessor, Ovis2ForConditionalGeneration
  551. >>> model = Ovis2ForConditionalGeneration.from_pretrained("thisisiron/Ovis2-2B-hf")
  552. >>> processor = AutoProcessor.from_pretrained("thisisiron/Ovis2-2B-hf")
  553. >>> prompt = "<|im_start|>user\n<image>\nDescribe the image.<|im_end|>\n<|im_start|>assistant\n"
  554. >>> url = "http://images.cocodataset.org/val2014/COCO_val2014_000000537955.jpg"
  555. >>> with httpx.stream("GET", url) as response:
  556. ... image = Image.open(BytesIO(response.read()))
  557. >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
  558. >>> # Generate
  559. >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
  560. >>> processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
  561. "user\n\nDescribe the image.\nassistant\nThe image features a brown dog standing on a wooden floor, looking up with"
  562. ```"""
  563. outputs = self.model(
  564. input_ids=input_ids,
  565. pixel_values=pixel_values,
  566. attention_mask=attention_mask,
  567. position_ids=position_ids,
  568. past_key_values=past_key_values,
  569. inputs_embeds=inputs_embeds,
  570. use_cache=use_cache,
  571. **kwargs,
  572. )
  573. hidden_states = outputs[0]
  574. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  575. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  576. logits = self.lm_head(hidden_states[:, slice_indices, :])
  577. loss = None
  578. if labels is not None:
  579. loss = self.loss_function(
  580. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  581. )
  582. return Ovis2CausalLMOutputWithPast(
  583. loss=loss,
  584. logits=logits,
  585. past_key_values=outputs.past_key_values,
  586. hidden_states=outputs.hidden_states,
  587. attentions=outputs.attentions,
  588. image_hidden_states=outputs.image_hidden_states,
  589. )
  590. def prepare_inputs_for_generation(
  591. self,
  592. input_ids,
  593. past_key_values=None,
  594. inputs_embeds=None,
  595. pixel_values=None,
  596. attention_mask=None,
  597. logits_to_keep=None,
  598. is_first_iteration=False,
  599. **kwargs,
  600. ):
  601. # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
  602. model_inputs = super().prepare_inputs_for_generation(
  603. input_ids,
  604. past_key_values=past_key_values,
  605. inputs_embeds=inputs_embeds,
  606. attention_mask=attention_mask,
  607. logits_to_keep=logits_to_keep,
  608. is_first_iteration=is_first_iteration,
  609. **kwargs,
  610. )
  611. if is_first_iteration or not kwargs.get("use_cache", True):
  612. # Pixel values are used only in the first iteration if available
  613. # In subsequent iterations, they are already merged with text and cached
  614. # NOTE: first iteration doesn't have to be prefill, it can be the first
  615. # iteration with a question and cached system prompt (continue generate from cache)
  616. model_inputs["pixel_values"] = pixel_values
  617. return model_inputs
  618. __all__ = ["Ovis2PreTrainedModel", "Ovis2Model", "Ovis2ForConditionalGeneration"]