modular_internvl.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639
  1. # Copyright 2025 HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import collections.abc
  15. from collections.abc import Callable
  16. from dataclasses import dataclass
  17. import torch
  18. import torch.nn as nn
  19. from ... import initialization as init
  20. from ...activations import ACT2FN
  21. from ...cache_utils import Cache
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  24. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  25. from ...processing_utils import Unpack
  26. from ...utils import TransformersKwargs, auto_docstring, torch_int
  27. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  28. from ...utils.output_capturing import capture_outputs
  29. from ..clip.modeling_clip import CLIPMLP
  30. from ..janus.modeling_janus import JanusVisionAttention
  31. from ..llama.modeling_llama import LlamaRMSNorm
  32. from ..llava.modeling_llava import (
  33. LlavaCausalLMOutputWithPast,
  34. LlavaForConditionalGeneration,
  35. LlavaModel,
  36. LlavaModelOutputWithPast,
  37. LlavaPreTrainedModel,
  38. )
  39. from .configuration_internvl import InternVLConfig, InternVLVisionConfig
  40. def eager_attention_forward(
  41. module: nn.Module,
  42. query: torch.Tensor,
  43. key: torch.Tensor,
  44. value: torch.Tensor,
  45. attention_mask: torch.Tensor | None,
  46. scaling: float,
  47. dropout: float | int = 0.0,
  48. **kwargs,
  49. ):
  50. key_states = key
  51. value_states = value
  52. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  53. if attention_mask is not None:
  54. attn_weights = attn_weights + attention_mask
  55. # No upcasting of the attention weights to float32 in this implementation
  56. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  57. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  58. attn_output = torch.matmul(attn_weights, value_states)
  59. attn_output = attn_output.transpose(1, 2).contiguous()
  60. return attn_output, attn_weights
# RMS normalization layer used by the InternVL vision tower; inherits `LlamaRMSNorm` without modification.
class InternVLVisionRMSNorm(LlamaRMSNorm):
    pass
  63. class InternVLVisionAttention(JanusVisionAttention):
  64. def __init__(self, config: InternVLVisionConfig):
  65. super().__init__(config)
  66. del self.num_key_value_groups
  67. # Needed for flash attention
  68. self.is_causal = False
  69. qk_norm = config.use_qk_norm
  70. self.q_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
  71. self.k_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
  72. def forward(
  73. self,
  74. hidden_states: torch.Tensor,
  75. attention_mask: torch.Tensor | None = None,
  76. **kwargs: Unpack[TransformersKwargs],
  77. ):
  78. batch_size, seq_len, _ = hidden_states.size()
  79. query_states = self.q_proj(hidden_states)
  80. key_states = self.k_proj(hidden_states)
  81. value_states = self.v_proj(hidden_states)
  82. query_states = self.q_norm(query_states)
  83. key_states = self.k_norm(key_states)
  84. query_states = query_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  85. key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  86. value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  87. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  88. self.config._attn_implementation, eager_attention_forward
  89. )
  90. attn_output, attn_weights = attention_interface(
  91. self,
  92. query_states,
  93. key_states,
  94. value_states,
  95. attention_mask,
  96. dropout=0.0 if not self.training else self.attention_dropout,
  97. scaling=self.scale,
  98. is_causal=False,
  99. **kwargs,
  100. )
  101. attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
  102. output = self.projection_layer(attn_output)
  103. output = self.projection_dropout(output)
  104. return output, attn_weights
# Output container returned by `InternVLVisionModel.forward`. It declares no fields of its own; the
# docstring below only documents `pooler_output`.
@dataclass
@auto_docstring(
    custom_intro="""
    Class for outputs of [`InternVLVisionModel`].
    """
)
class InternVLVisionModelOutputWithPooling(BaseModelOutputWithPooling):
    r"""
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
        Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
        *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
        will be returned.
    """
  118. class InternVLVisionPatchEmbeddings(nn.Module):
  119. """
  120. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  121. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  122. Transformer.
  123. """
  124. def __init__(self, config):
  125. super().__init__()
  126. image_size, patch_size = config.image_size, config.patch_size
  127. num_channels, hidden_size = config.num_channels, config.hidden_size
  128. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  129. patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
  130. self.image_size = image_size
  131. self.patch_size = patch_size
  132. self.num_channels = num_channels
  133. self.num_patches = num_patches
  134. self.patch_shape = patch_shape
  135. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  136. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  137. batch_size, num_channels, height, width = pixel_values.shape
  138. if num_channels != self.num_channels:
  139. raise ValueError(
  140. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  141. )
  142. embeddings = self.projection(pixel_values.to(self.projection.weight.dtype))
  143. embeddings = embeddings.flatten(2).transpose(1, 2)
  144. return embeddings
  145. # Based on timm implementation, which can be found here:
  146. # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  147. class InternVLVisionEmbeddings(nn.Module):
  148. """
  149. Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
  150. """
  151. def __init__(self, config: InternVLVisionConfig) -> None:
  152. super().__init__()
  153. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  154. if config.use_mask_token:
  155. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  156. else:
  157. self.mask_token = None
  158. self.patch_embeddings = InternVLVisionPatchEmbeddings(config)
  159. self.patch_size = config.patch_size
  160. self.image_size = (
  161. config.image_size
  162. if isinstance(config.image_size, collections.abc.Iterable)
  163. else (config.image_size, config.image_size)
  164. )
  165. num_patches = self.patch_embeddings.num_patches
  166. if config.use_absolute_position_embeddings:
  167. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
  168. else:
  169. self.position_embeddings = None
  170. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  171. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  172. """
  173. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  174. images. This method is also adapted to support torch.jit tracing.
  175. Adapted from:
  176. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  177. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  178. """
  179. num_patches = embeddings.shape[1] - 1
  180. num_positions = self.position_embeddings.shape[1] - 1
  181. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  182. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  183. return self.position_embeddings
  184. class_pos_embed = self.position_embeddings[:, :1]
  185. patch_pos_embed = self.position_embeddings[:, 1:]
  186. dim = embeddings.shape[-1]
  187. new_height = height // self.patch_size[0]
  188. new_width = width // self.patch_size[1]
  189. sqrt_num_positions = torch_int(num_positions**0.5)
  190. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  191. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  192. patch_pos_embed = nn.functional.interpolate(
  193. patch_pos_embed,
  194. size=(new_height, new_width),
  195. mode="bicubic",
  196. align_corners=False,
  197. )
  198. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  199. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  200. def forward(
  201. self,
  202. pixel_values: torch.Tensor,
  203. bool_masked_pos: torch.BoolTensor | None = None,
  204. ) -> torch.Tensor:
  205. _, _, height, width = pixel_values.shape
  206. embeddings = self.patch_embeddings(pixel_values)
  207. batch_size, seq_len, _ = embeddings.size()
  208. if bool_masked_pos is not None:
  209. mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
  210. # replace the masked visual tokens by mask_tokens
  211. w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
  212. embeddings = embeddings * (1 - w) + mask_tokens * w
  213. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  214. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  215. if self.position_embeddings is not None:
  216. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  217. embeddings = self.dropout(embeddings)
  218. return embeddings
# Feed-forward block of a vision layer; inherits `CLIPMLP` without modification.
class InternVLVisionMLP(CLIPMLP):
    pass


# Maps the `config.norm_type` string to the normalization class instantiated inside each vision layer.
NORM2FN = {"layer_norm": nn.LayerNorm, "rms_norm": InternVLVisionRMSNorm}
  222. class InternVLVisionLayer(GradientCheckpointingLayer):
  223. """This corresponds to the Block class in the timm implementation."""
  224. def __init__(self, config: InternVLVisionConfig) -> None:
  225. super().__init__()
  226. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  227. self.seq_len_dim = 1
  228. self.attention = InternVLVisionAttention(config)
  229. self.mlp = InternVLVisionMLP(config)
  230. # InternVL uses different layernorm implementations for different models
  231. self.layernorm_before = NORM2FN[config.norm_type](config.hidden_size, eps=config.layer_norm_eps)
  232. self.layernorm_after = NORM2FN[config.norm_type](config.hidden_size, eps=config.layer_norm_eps)
  233. init_values = config.layer_scale_init_value
  234. self.lambda_1 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True)
  235. self.lambda_2 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True)
  236. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  237. def forward(
  238. self,
  239. hidden_states: torch.Tensor,
  240. ) -> tuple[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]:
  241. attention_output, _ = self.attention(
  242. self.layernorm_before(hidden_states), # in InternVLVision, layernorm is applied before self-attention
  243. )
  244. attention_output = self.lambda_1 * attention_output
  245. # first residual connection
  246. hidden_states = attention_output + hidden_states
  247. # in InternVLVision, layernorm is also applied after self-attention
  248. layer_output = self.layernorm_after(hidden_states)
  249. layer_output = self.mlp(layer_output)
  250. layer_output = self.dropout(layer_output)
  251. if self.lambda_2 is not None:
  252. layer_output = self.lambda_2 * layer_output
  253. # second residual connection
  254. layer_output = layer_output + hidden_states
  255. return layer_output
  256. class InternVLVisionEncoder(nn.Module):
  257. def __init__(self, config: InternVLVisionConfig) -> None:
  258. super().__init__()
  259. self.config = config
  260. self.layer = nn.ModuleList([InternVLVisionLayer(config) for i in range(config.num_hidden_layers)])
  261. self.gradient_checkpointing = False
  262. def forward(
  263. self,
  264. hidden_states: torch.Tensor,
  265. ) -> tuple | BaseModelOutput:
  266. for layer_module in self.layer:
  267. hidden_states = layer_module(hidden_states)
  268. return BaseModelOutput(
  269. last_hidden_state=hidden_states,
  270. )
  271. @auto_docstring
  272. class InternVLVisionPreTrainedModel(PreTrainedModel):
  273. config: InternVLVisionConfig
  274. base_model_prefix = "internvl_vision"
  275. main_input_name = "pixel_values"
  276. input_modalities = ("image", "video")
  277. supports_gradient_checkpointing = True
  278. _no_split_modules = ["InternVLVisionLayer"]
  279. _supports_sdpa = True
  280. _supports_flash_attn = True
  281. _supports_flex_attn = True
  282. _supports_attention_backend = True
  283. _can_record_outputs = {
  284. "hidden_states": InternVLVisionLayer,
  285. "attentions": InternVLVisionAttention,
  286. }
  287. @torch.no_grad()
  288. def _init_weights(self, module):
  289. """Initialize the weights"""
  290. super()._init_weights(module)
  291. if isinstance(module, InternVLVisionEmbeddings):
  292. init.zeros_(module.cls_token)
  293. if module.mask_token is not None:
  294. init.zeros_(module.mask_token)
  295. if module.position_embeddings is not None:
  296. init.zeros_(module.position_embeddings)
  297. elif isinstance(module, InternVLVisionLayer):
  298. init.constant_(module.lambda_1, self.config.layer_scale_init_value)
  299. init.constant_(module.lambda_2, self.config.layer_scale_init_value)
  300. @auto_docstring
  301. class InternVLVisionModel(InternVLVisionPreTrainedModel):
  302. def __init__(self, config: InternVLVisionConfig) -> None:
  303. super().__init__(config)
  304. self.config = config
  305. self.embeddings = InternVLVisionEmbeddings(config)
  306. self.encoder = InternVLVisionEncoder(config)
  307. self.layernorm = (
  308. nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  309. )
  310. # Initialize weights and apply final processing
  311. self.post_init()
  312. def get_input_embeddings(self):
  313. return self.embeddings.patch_embeddings
  314. @merge_with_config_defaults
  315. @capture_outputs(tie_last_hidden_states=False)
  316. @auto_docstring
  317. def forward(
  318. self, pixel_values: torch.Tensor, bool_masked_pos: torch.BoolTensor | None = None, **kwargs
  319. ) -> tuple | InternVLVisionModelOutputWithPooling:
  320. r"""
  321. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
  322. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  323. """
  324. embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
  325. encoder_outputs = self.encoder(embedding_output)
  326. sequence_output = encoder_outputs[0]
  327. sequence_output = self.layernorm(sequence_output)
  328. return InternVLVisionModelOutputWithPooling(
  329. last_hidden_state=sequence_output,
  330. hidden_states=encoder_outputs.hidden_states,
  331. attentions=encoder_outputs.attentions,
  332. )
# Base class of the multimodal InternVL models; inherits `LlavaPreTrainedModel` and only overrides the
# declared input modalities.
class InternVLPreTrainedModel(LlavaPreTrainedModel):
    input_modalities = ("image", "text", "video")


# Set to None and not referenced anywhere else in this file.
# NOTE(review): presumably a placeholder consumed by tooling outside this file — confirm before removing.
INTERNVL_INPUTS_DOCSTRING = None
  336. class InternVLMultiModalProjector(nn.Module):
  337. def __init__(self, config: InternVLConfig):
  338. super().__init__()
  339. self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2)
  340. self.linear_1 = nn.Linear(
  341. config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2, config.text_config.hidden_size
  342. )
  343. self.act = ACT2FN[config.projector_hidden_act]
  344. self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
  345. def forward(self, image_features):
  346. hidden_states = self.layer_norm(image_features)
  347. hidden_states = self.linear_1(hidden_states)
  348. hidden_states = self.act(hidden_states)
  349. hidden_states = self.linear_2(hidden_states)
  350. return hidden_states
# Output container returned by `InternVLModel.forward`; inherits `LlavaModelOutputWithPast` without modification.
class InternVLModelOutputWithPast(LlavaModelOutputWithPast):
    pass
  353. class InternVLModel(LlavaModel):
  354. def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
  355. """Perform pixel shuffle downsampling on vision features.
  356. Args:
  357. vision_features (`torch.Tensor`):
  358. Input tensor of shape (batch_size, width, height, channels).
  359. scale_factor (`float`, *optional*, defaults to `0.5`):
  360. Factor by which to downsample. Default is 0.5, which halves the dimensions.
  361. Returns:
  362. vision_features (`torch.Tensor`):
  363. Downsampled tensor of shape (batch_size, height*scale_factor, width*scale_factor, channels/(scale_factor^2)).
  364. """
  365. batch_size, width, height, channels = vision_features.size()
  366. if height % scale_factor != 0 or width % scale_factor != 0:
  367. raise ValueError("Height and width must be divisible by scale_factor for proper downsampling.")
  368. # Reshape to allow downsampling
  369. vision_features = vision_features.view(
  370. batch_size, width, int(height * scale_factor), int(channels / scale_factor)
  371. )
  372. # Permute dimensions to align downsampled axis correctly
  373. vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
  374. # Reshape to achieve final downsampled dimensions
  375. vision_features = vision_features.view(
  376. batch_size, int(height * scale_factor), int(width * scale_factor), int(channels / (scale_factor**2))
  377. )
  378. # Swap height and width back for proper orientation
  379. vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
  380. return vision_features
  381. @merge_with_config_defaults
  382. @can_return_tuple
  383. @auto_docstring(
  384. custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection."
  385. )
  386. def get_image_features(
  387. self,
  388. pixel_values: torch.FloatTensor,
  389. vision_feature_layer: int | list[int] | list[int] | None = None,
  390. vision_feature_select_strategy: str | None = None,
  391. **kwargs: Unpack[TransformersKwargs],
  392. ) -> tuple | BaseModelOutputWithPooling:
  393. r"""
  394. pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
  395. The tensors corresponding to the input images.
  396. vision_feature_layer (`int` or `list[int]`):
  397. Layer index or list of layer indices to extract features from.
  398. """
  399. pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
  400. downsample_ratio = self.config.downsample_ratio
  401. if vision_feature_layer != -1:
  402. kwargs["output_hidden_states"] = True
  403. vision_outputs = self.vision_tower(pixel_values=pixel_values, return_dict=True, **kwargs)
  404. if vision_feature_layer == -1:
  405. vision_features = vision_outputs.last_hidden_state
  406. else:
  407. vision_features = vision_outputs.hidden_states[vision_feature_layer]
  408. if vision_feature_select_strategy == "default":
  409. vision_features = vision_features[:, 1:, :]
  410. # Calculate dimensions based on vision features
  411. channels = vision_features.shape[1]
  412. feature_size = int(channels**0.5)
  413. batch_size = vision_features.shape[0]
  414. # Reshape tensor to spatial dimensions
  415. vision_features = vision_features.reshape(batch_size, feature_size, feature_size, -1)
  416. # Apply downsampling using pixel shuffle
  417. vision_features = self.pixel_shuffle(vision_features, scale_factor=downsample_ratio)
  418. # Reshape tensor to prepare for projection
  419. vision_features = vision_features.reshape(batch_size, -1, vision_features.shape[-1])
  420. # Project features through multi-modal projector
  421. vision_features = self.multi_modal_projector(vision_features)
  422. vision_outputs.pooler_output = vision_features
  423. return vision_outputs
  424. @can_return_tuple
  425. @auto_docstring
  426. def forward(
  427. self,
  428. input_ids: torch.LongTensor | None = None,
  429. pixel_values: torch.FloatTensor | None = None,
  430. attention_mask: torch.Tensor | None = None,
  431. position_ids: torch.LongTensor | None = None,
  432. past_key_values: Cache | None = None,
  433. inputs_embeds: torch.FloatTensor | None = None,
  434. vision_feature_layer: int | list[int] | list[int] | None = None,
  435. vision_feature_select_strategy: str | None = None,
  436. **kwargs: Unpack[TransformersKwargs],
  437. ) -> tuple | InternVLModelOutputWithPast:
  438. if (input_ids is None) ^ (inputs_embeds is not None):
  439. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  440. if inputs_embeds is None:
  441. inputs_embeds = self.get_input_embeddings()(input_ids)
  442. if pixel_values is not None:
  443. image_features = self.get_image_features(
  444. pixel_values=pixel_values,
  445. vision_feature_layer=vision_feature_layer,
  446. vision_feature_select_strategy=vision_feature_select_strategy,
  447. return_dict=True,
  448. ).pooler_output
  449. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  450. special_image_mask = self.get_placeholder_mask(
  451. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  452. )
  453. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  454. outputs = self.language_model(
  455. attention_mask=attention_mask,
  456. position_ids=position_ids,
  457. past_key_values=past_key_values,
  458. inputs_embeds=inputs_embeds,
  459. **kwargs,
  460. )
  461. return InternVLModelOutputWithPast(
  462. last_hidden_state=outputs.last_hidden_state,
  463. past_key_values=outputs.past_key_values,
  464. hidden_states=outputs.hidden_states,
  465. attentions=outputs.attentions,
  466. image_hidden_states=image_features if pixel_values is not None else None,
  467. )
# Output container for the conditional-generation model; inherits `LlavaCausalLMOutputWithPast` without modification.
class InternVLCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
    pass
class InternVLForConditionalGeneration(LlavaForConditionalGeneration):
    # This override only attaches the usage example below and delegates to the parent `forward`.
    # NOTE(review): the signature has no `self` and the parent's result is not returned, so this looks like
    # a stub that a code generator expands into the real method rather than code meant to run as written —
    # confirm before "fixing" it.
    def forward(**super_kwargs):
        r"""
        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, AutoModelForImageTextToText

        >>> torch_device = "cuda"
        >>> processor = AutoProcessor.from_pretrained("OpenGVLab/InternVL3-1B-hf")
        >>> model = AutoModelForImageTextToText.from_pretrained(
        ... "OpenGVLab/InternVL3-1B-hf", dtype=torch.bfloat16, device_map=torch_device
        ... )

        >>> messages = [
        ... {
        ... "role": "user",
        ... "content": [
        ... {
        ... "type": "image",
        ... "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
        ... },
        ... {
        ... "type": "image",
        ... "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
        ... },
        ... {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
        ... ],
        ... },
        ... ]

        >>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(torch_device)
        >>> generate_ids = model.generate(**inputs, max_new_tokens=200)
        >>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True))
        The images depict the Statue of Liberty and the Golden Gate Bridge.
        ```"""
        super().forward(**super_kwargs)
# Public names exported by this module. Helper layers and output containers defined above are
# intentionally not listed here.
__all__ = [
    "InternVLVisionPreTrainedModel",
    "InternVLVisionModel",
    "InternVLPreTrainedModel",
    "InternVLModel",
    "InternVLForConditionalGeneration",
]