modeling_internvl.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/internvl/modular_internvl.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_internvl.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import collections.abc
  21. from collections.abc import Callable
  22. from dataclasses import dataclass
  23. import torch
  24. import torch.nn as nn
  25. from ... import initialization as init
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache
  28. from ...generation import GenerationMixin
  29. from ...integrations import use_kernel_forward_from_hub
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check, torch_int
  35. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  36. from ...utils.output_capturing import capture_outputs
  37. from ..auto import AutoModel
  38. from .configuration_internvl import InternVLConfig, InternVLVisionConfig
  39. @use_kernel_forward_from_hub("RMSNorm")
  40. class InternVLVisionRMSNorm(nn.Module):
  41. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  42. """
  43. InternVLVisionRMSNorm is equivalent to T5LayerNorm
  44. """
  45. super().__init__()
  46. self.weight = nn.Parameter(torch.ones(hidden_size))
  47. self.variance_epsilon = eps
  48. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  49. input_dtype = hidden_states.dtype
  50. hidden_states = hidden_states.to(torch.float32)
  51. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  52. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  53. return self.weight * hidden_states.to(input_dtype)
  54. def extra_repr(self):
  55. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  56. def eager_attention_forward(
  57. module: nn.Module,
  58. query: torch.Tensor,
  59. key: torch.Tensor,
  60. value: torch.Tensor,
  61. attention_mask: torch.Tensor | None,
  62. scaling: float,
  63. dropout: float | int = 0.0,
  64. **kwargs,
  65. ):
  66. key_states = key
  67. value_states = value
  68. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  69. if attention_mask is not None:
  70. attn_weights = attn_weights + attention_mask
  71. # No upcasting of the attention weights to float32 in this implementation
  72. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  73. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  74. attn_output = torch.matmul(attn_weights, value_states)
  75. attn_output = attn_output.transpose(1, 2).contiguous()
  76. return attn_output, attn_weights
  77. class InternVLVisionAttention(nn.Module):
  78. """Attention Class for InternVL Vision Encoder"""
  79. def __init__(self, config: InternVLVisionConfig):
  80. super().__init__()
  81. self.config = config
  82. self.embed_dim = config.hidden_size
  83. self.num_heads = config.num_attention_heads
  84. self.head_dim = self.embed_dim // self.num_heads
  85. if self.head_dim * self.num_heads != self.embed_dim:
  86. raise ValueError(
  87. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  88. f" {self.num_heads})."
  89. )
  90. self.scale = self.head_dim**-0.5
  91. self.attention_dropout = config.attention_dropout
  92. proj_dropout = config.projection_dropout
  93. qk_norm = config.use_qk_norm
  94. # Needed for flash attention
  95. self.is_causal = False
  96. self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
  97. self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
  98. self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
  99. self.projection_layer = nn.Linear(self.embed_dim, self.embed_dim)
  100. self.projection_dropout = nn.Dropout(proj_dropout) if proj_dropout > 0 else nn.Identity()
  101. self.q_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
  102. self.k_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
  103. def forward(
  104. self,
  105. hidden_states: torch.Tensor,
  106. attention_mask: torch.Tensor | None = None,
  107. **kwargs: Unpack[TransformersKwargs],
  108. ):
  109. batch_size, seq_len, _ = hidden_states.size()
  110. query_states = self.q_proj(hidden_states)
  111. key_states = self.k_proj(hidden_states)
  112. value_states = self.v_proj(hidden_states)
  113. query_states = self.q_norm(query_states)
  114. key_states = self.k_norm(key_states)
  115. query_states = query_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  116. key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  117. value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  118. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  119. self.config._attn_implementation, eager_attention_forward
  120. )
  121. attn_output, attn_weights = attention_interface(
  122. self,
  123. query_states,
  124. key_states,
  125. value_states,
  126. attention_mask,
  127. dropout=0.0 if not self.training else self.attention_dropout,
  128. scaling=self.scale,
  129. is_causal=False,
  130. **kwargs,
  131. )
  132. attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
  133. output = self.projection_layer(attn_output)
  134. output = self.projection_dropout(output)
  135. return output, attn_weights
# Output container for the vision model. The class body declares no fields of its own; it only
# supplies the documentation for `pooler_output` on top of `BaseModelOutputWithPooling`.
@dataclass
@auto_docstring(
    custom_intro="""
    Class for outputs of [`InternVLVisionModel`].
    """
)
class InternVLVisionModelOutputWithPooling(BaseModelOutputWithPooling):
    r"""
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
        Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
        *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
        will be returned.
    """
  149. class InternVLVisionPatchEmbeddings(nn.Module):
  150. """
  151. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  152. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  153. Transformer.
  154. """
  155. def __init__(self, config):
  156. super().__init__()
  157. image_size, patch_size = config.image_size, config.patch_size
  158. num_channels, hidden_size = config.num_channels, config.hidden_size
  159. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  160. patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
  161. self.image_size = image_size
  162. self.patch_size = patch_size
  163. self.num_channels = num_channels
  164. self.num_patches = num_patches
  165. self.patch_shape = patch_shape
  166. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  167. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  168. batch_size, num_channels, height, width = pixel_values.shape
  169. if num_channels != self.num_channels:
  170. raise ValueError(
  171. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  172. )
  173. embeddings = self.projection(pixel_values.to(self.projection.weight.dtype))
  174. embeddings = embeddings.flatten(2).transpose(1, 2)
  175. return embeddings
  176. # Based on timm implementation, which can be found here:
  177. # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  178. class InternVLVisionEmbeddings(nn.Module):
  179. """
  180. Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
  181. """
  182. def __init__(self, config: InternVLVisionConfig) -> None:
  183. super().__init__()
  184. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  185. if config.use_mask_token:
  186. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  187. else:
  188. self.mask_token = None
  189. self.patch_embeddings = InternVLVisionPatchEmbeddings(config)
  190. self.patch_size = config.patch_size
  191. self.image_size = (
  192. config.image_size
  193. if isinstance(config.image_size, collections.abc.Iterable)
  194. else (config.image_size, config.image_size)
  195. )
  196. num_patches = self.patch_embeddings.num_patches
  197. if config.use_absolute_position_embeddings:
  198. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
  199. else:
  200. self.position_embeddings = None
  201. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  202. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  203. """
  204. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  205. images. This method is also adapted to support torch.jit tracing.
  206. Adapted from:
  207. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  208. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  209. """
  210. num_patches = embeddings.shape[1] - 1
  211. num_positions = self.position_embeddings.shape[1] - 1
  212. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  213. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  214. return self.position_embeddings
  215. class_pos_embed = self.position_embeddings[:, :1]
  216. patch_pos_embed = self.position_embeddings[:, 1:]
  217. dim = embeddings.shape[-1]
  218. new_height = height // self.patch_size[0]
  219. new_width = width // self.patch_size[1]
  220. sqrt_num_positions = torch_int(num_positions**0.5)
  221. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  222. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  223. patch_pos_embed = nn.functional.interpolate(
  224. patch_pos_embed,
  225. size=(new_height, new_width),
  226. mode="bicubic",
  227. align_corners=False,
  228. )
  229. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  230. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  231. def forward(
  232. self,
  233. pixel_values: torch.Tensor,
  234. bool_masked_pos: torch.BoolTensor | None = None,
  235. ) -> torch.Tensor:
  236. _, _, height, width = pixel_values.shape
  237. embeddings = self.patch_embeddings(pixel_values)
  238. batch_size, seq_len, _ = embeddings.size()
  239. if bool_masked_pos is not None:
  240. mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
  241. # replace the masked visual tokens by mask_tokens
  242. w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
  243. embeddings = embeddings * (1 - w) + mask_tokens * w
  244. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  245. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  246. if self.position_embeddings is not None:
  247. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  248. embeddings = self.dropout(embeddings)
  249. return embeddings
  250. class InternVLVisionMLP(nn.Module):
  251. def __init__(self, config):
  252. super().__init__()
  253. self.config = config
  254. self.activation_fn = ACT2FN[config.hidden_act]
  255. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  256. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  257. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  258. hidden_states = self.fc1(hidden_states)
  259. hidden_states = self.activation_fn(hidden_states)
  260. hidden_states = self.fc2(hidden_states)
  261. return hidden_states
  262. NORM2FN = {"layer_norm": nn.LayerNorm, "rms_norm": InternVLVisionRMSNorm}
  263. class InternVLVisionLayer(GradientCheckpointingLayer):
  264. """This corresponds to the Block class in the timm implementation."""
  265. def __init__(self, config: InternVLVisionConfig) -> None:
  266. super().__init__()
  267. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  268. self.seq_len_dim = 1
  269. self.attention = InternVLVisionAttention(config)
  270. self.mlp = InternVLVisionMLP(config)
  271. # InternVL uses different layernorm implementations for different models
  272. self.layernorm_before = NORM2FN[config.norm_type](config.hidden_size, eps=config.layer_norm_eps)
  273. self.layernorm_after = NORM2FN[config.norm_type](config.hidden_size, eps=config.layer_norm_eps)
  274. init_values = config.layer_scale_init_value
  275. self.lambda_1 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True)
  276. self.lambda_2 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True)
  277. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  278. def forward(
  279. self,
  280. hidden_states: torch.Tensor,
  281. ) -> tuple[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]:
  282. attention_output, _ = self.attention(
  283. self.layernorm_before(hidden_states), # in InternVLVision, layernorm is applied before self-attention
  284. )
  285. attention_output = self.lambda_1 * attention_output
  286. # first residual connection
  287. hidden_states = attention_output + hidden_states
  288. # in InternVLVision, layernorm is also applied after self-attention
  289. layer_output = self.layernorm_after(hidden_states)
  290. layer_output = self.mlp(layer_output)
  291. layer_output = self.dropout(layer_output)
  292. if self.lambda_2 is not None:
  293. layer_output = self.lambda_2 * layer_output
  294. # second residual connection
  295. layer_output = layer_output + hidden_states
  296. return layer_output
  297. class InternVLVisionEncoder(nn.Module):
  298. def __init__(self, config: InternVLVisionConfig) -> None:
  299. super().__init__()
  300. self.config = config
  301. self.layer = nn.ModuleList([InternVLVisionLayer(config) for i in range(config.num_hidden_layers)])
  302. self.gradient_checkpointing = False
  303. def forward(
  304. self,
  305. hidden_states: torch.Tensor,
  306. ) -> tuple | BaseModelOutput:
  307. for layer_module in self.layer:
  308. hidden_states = layer_module(hidden_states)
  309. return BaseModelOutput(
  310. last_hidden_state=hidden_states,
  311. )
  312. @auto_docstring
  313. class InternVLVisionPreTrainedModel(PreTrainedModel):
  314. config: InternVLVisionConfig
  315. base_model_prefix = "internvl_vision"
  316. main_input_name = "pixel_values"
  317. input_modalities = ("image", "video")
  318. supports_gradient_checkpointing = True
  319. _no_split_modules = ["InternVLVisionLayer"]
  320. _supports_sdpa = True
  321. _supports_flash_attn = True
  322. _supports_flex_attn = True
  323. _supports_attention_backend = True
  324. _can_record_outputs = {
  325. "hidden_states": InternVLVisionLayer,
  326. "attentions": InternVLVisionAttention,
  327. }
  328. @torch.no_grad()
  329. def _init_weights(self, module):
  330. """Initialize the weights"""
  331. super()._init_weights(module)
  332. if isinstance(module, InternVLVisionEmbeddings):
  333. init.zeros_(module.cls_token)
  334. if module.mask_token is not None:
  335. init.zeros_(module.mask_token)
  336. if module.position_embeddings is not None:
  337. init.zeros_(module.position_embeddings)
  338. elif isinstance(module, InternVLVisionLayer):
  339. init.constant_(module.lambda_1, self.config.layer_scale_init_value)
  340. init.constant_(module.lambda_2, self.config.layer_scale_init_value)
  341. @auto_docstring
  342. class InternVLVisionModel(InternVLVisionPreTrainedModel):
  343. def __init__(self, config: InternVLVisionConfig) -> None:
  344. super().__init__(config)
  345. self.config = config
  346. self.embeddings = InternVLVisionEmbeddings(config)
  347. self.encoder = InternVLVisionEncoder(config)
  348. self.layernorm = (
  349. nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  350. )
  351. # Initialize weights and apply final processing
  352. self.post_init()
  353. def get_input_embeddings(self):
  354. return self.embeddings.patch_embeddings
  355. @merge_with_config_defaults
  356. @capture_outputs(tie_last_hidden_states=False)
  357. @auto_docstring
  358. def forward(
  359. self, pixel_values: torch.Tensor, bool_masked_pos: torch.BoolTensor | None = None, **kwargs
  360. ) -> tuple | InternVLVisionModelOutputWithPooling:
  361. r"""
  362. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
  363. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  364. """
  365. embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
  366. encoder_outputs = self.encoder(embedding_output)
  367. sequence_output = encoder_outputs[0]
  368. sequence_output = self.layernorm(sequence_output)
  369. return InternVLVisionModelOutputWithPooling(
  370. last_hidden_state=sequence_output,
  371. hidden_states=encoder_outputs.hidden_states,
  372. attentions=encoder_outputs.attentions,
  373. )
  374. @auto_docstring
  375. class InternVLPreTrainedModel(PreTrainedModel):
  376. config: InternVLConfig
  377. base_model_prefix = "model"
  378. input_modalities = ("image", "text", "video")
  379. supports_gradient_checkpointing = True
  380. _skip_keys_device_placement = "past_key_values"
  381. _supports_flash_attn = True
  382. _supports_sdpa = True
  383. _can_compile_fullgraph = True
  384. _supports_flex_attn = True
  385. _supports_attention_backend = True
  386. class InternVLMultiModalProjector(nn.Module):
  387. def __init__(self, config: InternVLConfig):
  388. super().__init__()
  389. self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2)
  390. self.linear_1 = nn.Linear(
  391. config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2, config.text_config.hidden_size
  392. )
  393. self.act = ACT2FN[config.projector_hidden_act]
  394. self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
  395. def forward(self, image_features):
  396. hidden_states = self.layer_norm(image_features)
  397. hidden_states = self.linear_1(hidden_states)
  398. hidden_states = self.act(hidden_states)
  399. hidden_states = self.linear_2(hidden_states)
  400. return hidden_states
# Output container for `InternVLModel`: the fields of `BaseModelOutputWithPast` plus the image features.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for InternVL outputs, with hidden states and attentions.
    """
)
class InternVLModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    # Defaults to None, so an output can be built without image features.
    image_hidden_states: torch.FloatTensor | None = None
  418. @auto_docstring(
  419. custom_intro="""
  420. The InternVL model which consists of a vision backbone and a language model, without a language modeling head.
  421. """
  422. )
  423. class InternVLModel(InternVLPreTrainedModel):
  424. def __init__(self, config: InternVLConfig):
  425. super().__init__(config)
  426. self.vision_tower = AutoModel.from_config(config.vision_config)
  427. self.multi_modal_projector = InternVLMultiModalProjector(config)
  428. self.language_model = AutoModel.from_config(config.text_config)
  429. self.post_init()
  430. def get_input_embeddings(self):
  431. return self.language_model.get_input_embeddings()
  432. def set_input_embeddings(self, value):
  433. self.language_model.set_input_embeddings(value)
  434. @merge_with_config_defaults
  435. @can_return_tuple
  436. @auto_docstring(
  437. custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection."
  438. )
  439. def get_image_features(
  440. self,
  441. pixel_values: torch.FloatTensor,
  442. vision_feature_layer: int | list[int] | list[int] | None = None,
  443. vision_feature_select_strategy: str | None = None,
  444. **kwargs: Unpack[TransformersKwargs],
  445. ) -> tuple | BaseModelOutputWithPooling:
  446. r"""
  447. pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
  448. The tensors corresponding to the input images.
  449. vision_feature_layer (`int` or `list[int]`):
  450. Layer index or list of layer indices to extract features from.
  451. """
  452. pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
  453. downsample_ratio = self.config.downsample_ratio
  454. if vision_feature_layer != -1:
  455. kwargs["output_hidden_states"] = True
  456. vision_outputs = self.vision_tower(pixel_values=pixel_values, return_dict=True, **kwargs)
  457. if vision_feature_layer == -1:
  458. vision_features = vision_outputs.last_hidden_state
  459. else:
  460. vision_features = vision_outputs.hidden_states[vision_feature_layer]
  461. if vision_feature_select_strategy == "default":
  462. vision_features = vision_features[:, 1:, :]
  463. # Calculate dimensions based on vision features
  464. channels = vision_features.shape[1]
  465. feature_size = int(channels**0.5)
  466. batch_size = vision_features.shape[0]
  467. # Reshape tensor to spatial dimensions
  468. vision_features = vision_features.reshape(batch_size, feature_size, feature_size, -1)
  469. # Apply downsampling using pixel shuffle
  470. vision_features = self.pixel_shuffle(vision_features, scale_factor=downsample_ratio)
  471. # Reshape tensor to prepare for projection
  472. vision_features = vision_features.reshape(batch_size, -1, vision_features.shape[-1])
  473. # Project features through multi-modal projector
  474. vision_features = self.multi_modal_projector(vision_features)
  475. vision_outputs.pooler_output = vision_features
  476. return vision_outputs
  477. def get_placeholder_mask(
  478. self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
  479. ):
  480. """
  481. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  482. equal to the length of multimodal features. If the lengths are different, an error is raised.
  483. """
  484. if input_ids is None:
  485. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  486. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  487. )
  488. special_image_mask = special_image_mask.all(-1)
  489. else:
  490. special_image_mask = input_ids == self.config.image_token_id
  491. n_image_tokens = special_image_mask.sum()
  492. n_image_features = image_features.shape[0] * image_features.shape[1]
  493. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  494. torch_compilable_check(
  495. inputs_embeds[special_image_mask].numel() == image_features.numel(),
  496. f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
  497. )
  498. return special_image_mask
  499. @can_return_tuple
  500. @auto_docstring
  501. def forward(
  502. self,
  503. input_ids: torch.LongTensor | None = None,
  504. pixel_values: torch.FloatTensor | None = None,
  505. attention_mask: torch.Tensor | None = None,
  506. position_ids: torch.LongTensor | None = None,
  507. past_key_values: Cache | None = None,
  508. inputs_embeds: torch.FloatTensor | None = None,
  509. vision_feature_layer: int | list[int] | list[int] | None = None,
  510. vision_feature_select_strategy: str | None = None,
  511. **kwargs: Unpack[TransformersKwargs],
  512. ) -> tuple | InternVLModelOutputWithPast:
  513. if (input_ids is None) ^ (inputs_embeds is not None):
  514. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  515. if inputs_embeds is None:
  516. inputs_embeds = self.get_input_embeddings()(input_ids)
  517. if pixel_values is not None:
  518. image_features = self.get_image_features(
  519. pixel_values=pixel_values,
  520. vision_feature_layer=vision_feature_layer,
  521. vision_feature_select_strategy=vision_feature_select_strategy,
  522. return_dict=True,
  523. ).pooler_output
  524. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  525. special_image_mask = self.get_placeholder_mask(
  526. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  527. )
  528. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  529. outputs = self.language_model(
  530. attention_mask=attention_mask,
  531. position_ids=position_ids,
  532. past_key_values=past_key_values,
  533. inputs_embeds=inputs_embeds,
  534. **kwargs,
  535. )
  536. return InternVLModelOutputWithPast(
  537. last_hidden_state=outputs.last_hidden_state,
  538. past_key_values=outputs.past_key_values,
  539. hidden_states=outputs.hidden_states,
  540. attentions=outputs.attentions,
  541. image_hidden_states=image_features if pixel_values is not None else None,
  542. )
  543. def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
  544. """Perform pixel shuffle downsampling on vision features.
  545. Args:
  546. vision_features (`torch.Tensor`):
  547. Input tensor of shape (batch_size, width, height, channels).
  548. scale_factor (`float`, *optional*, defaults to `0.5`):
  549. Factor by which to downsample. Default is 0.5, which halves the dimensions.
  550. Returns:
  551. vision_features (`torch.Tensor`):
  552. Downsampled tensor of shape (batch_size, height*scale_factor, width*scale_factor, channels/(scale_factor^2)).
  553. """
  554. batch_size, width, height, channels = vision_features.size()
  555. if height % scale_factor != 0 or width % scale_factor != 0:
  556. raise ValueError("Height and width must be divisible by scale_factor for proper downsampling.")
  557. # Reshape to allow downsampling
  558. vision_features = vision_features.view(
  559. batch_size, width, int(height * scale_factor), int(channels / scale_factor)
  560. )
  561. # Permute dimensions to align downsampled axis correctly
  562. vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
  563. # Reshape to achieve final downsampled dimensions
  564. vision_features = vision_features.view(
  565. batch_size, int(height * scale_factor), int(width * scale_factor), int(channels / (scale_factor**2))
  566. )
  567. # Swap height and width back for proper orientation
  568. vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
  569. return vision_features
# Output container for the InternVL model with a language modeling head.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for InternVL causal language model (or autoregressive) outputs.
    """
)
class InternVLCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    # Every field defaults to None, so an output can be built from any subset of them.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    image_hidden_states: torch.FloatTensor | None = None
  596. @auto_docstring(
  597. custom_intro="""
  598. The INTERNVL model which consists of a vision backbone and a language model.
  599. """
  600. )
  601. class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin):
  602. _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
  603. def __init__(self, config: InternVLConfig):
  604. super().__init__(config)
  605. self.model = InternVLModel(config)
  606. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  607. self.post_init()
  608. def get_input_embeddings(self):
  609. return self.model.get_input_embeddings()
  610. def set_input_embeddings(self, value):
  611. self.model.set_input_embeddings(value)
  612. def get_output_embeddings(self) -> nn.Module:
  613. return self.lm_head
  614. @auto_docstring
  615. def get_image_features(
  616. self,
  617. pixel_values: torch.FloatTensor,
  618. vision_feature_layer: int | list[int] | list[int] | None = None,
  619. vision_feature_select_strategy: str | None = None,
  620. **kwargs: Unpack[TransformersKwargs],
  621. ) -> tuple | BaseModelOutputWithPooling:
  622. return self.model.get_image_features(
  623. pixel_values=pixel_values,
  624. vision_feature_layer=vision_feature_layer,
  625. vision_feature_select_strategy=vision_feature_select_strategy,
  626. **kwargs,
  627. )
  628. @can_return_tuple
  629. @auto_docstring
  630. def forward(
  631. self,
  632. input_ids: torch.LongTensor | None = None,
  633. pixel_values: torch.FloatTensor | None = None,
  634. attention_mask: torch.Tensor | None = None,
  635. position_ids: torch.LongTensor | None = None,
  636. past_key_values: Cache | None = None,
  637. inputs_embeds: torch.FloatTensor | None = None,
  638. vision_feature_layer: int | list[int] | list[int] | None = None,
  639. vision_feature_select_strategy: str | None = None,
  640. labels: torch.LongTensor | None = None,
  641. logits_to_keep: int | torch.Tensor = 0,
  642. image_sizes: torch.Tensor | None = None,
  643. **kwargs: Unpack[TransformersKwargs],
  644. ) -> tuple | InternVLCausalLMOutputWithPast:
  645. r"""
  646. Example:
  647. ```python
  648. >>> import torch
  649. >>> from transformers import AutoProcessor, AutoModelForImageTextToText
  650. >>> torch_device = "cuda"
  651. >>> processor = AutoProcessor.from_pretrained("OpenGVLab/InternVL3-1B-hf")
  652. >>> model = AutoModelForImageTextToText.from_pretrained(
  653. ... "OpenGVLab/InternVL3-1B-hf", dtype=torch.bfloat16, device_map=torch_device
  654. ... )
  655. >>> messages = [
  656. ... {
  657. ... "role": "user",
  658. ... "content": [
  659. ... {
  660. ... "type": "image",
  661. ... "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
  662. ... },
  663. ... {
  664. ... "type": "image",
  665. ... "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
  666. ... },
  667. ... {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
  668. ... ],
  669. ... },
  670. ... ]
  671. >>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(torch_device)
  672. >>> generate_ids = model.generate(**inputs, max_new_tokens=200)
  673. >>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True))
  674. The images depict the Statue of Liberty and the Golden Gate Bridge.
  675. ```"""
  676. outputs = self.model(
  677. input_ids=input_ids,
  678. pixel_values=pixel_values,
  679. attention_mask=attention_mask,
  680. position_ids=position_ids,
  681. past_key_values=past_key_values,
  682. inputs_embeds=inputs_embeds,
  683. vision_feature_layer=vision_feature_layer,
  684. vision_feature_select_strategy=vision_feature_select_strategy,
  685. image_sizes=image_sizes,
  686. **kwargs,
  687. )
  688. hidden_states = outputs[0]
  689. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  690. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  691. logits = self.lm_head(hidden_states[:, slice_indices, :])
  692. loss = None
  693. if labels is not None:
  694. loss = self.loss_function(
  695. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  696. )
  697. return InternVLCausalLMOutputWithPast(
  698. loss=loss,
  699. logits=logits,
  700. past_key_values=outputs.past_key_values,
  701. hidden_states=outputs.hidden_states,
  702. attentions=outputs.attentions,
  703. image_hidden_states=outputs.image_hidden_states,
  704. )
  705. def prepare_inputs_for_generation(
  706. self,
  707. input_ids,
  708. past_key_values=None,
  709. inputs_embeds=None,
  710. pixel_values=None,
  711. attention_mask=None,
  712. logits_to_keep=None,
  713. is_first_iteration=False,
  714. **kwargs,
  715. ):
  716. # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
  717. model_inputs = super().prepare_inputs_for_generation(
  718. input_ids,
  719. past_key_values=past_key_values,
  720. inputs_embeds=inputs_embeds,
  721. attention_mask=attention_mask,
  722. logits_to_keep=logits_to_keep,
  723. is_first_iteration=is_first_iteration,
  724. **kwargs,
  725. )
  726. if is_first_iteration or not kwargs.get("use_cache", True):
  727. # Pixel values are used only in the first iteration if available
  728. # In subsequent iterations, they are already merged with text and cached
  729. # NOTE: first iteration doesn't have to be prefill, it can be the first
  730. # iteration with a question and cached system prompt (continue generate from cache)
  731. model_inputs["pixel_values"] = pixel_values
  732. return model_inputs
# Names exported by `from <this module> import *`.
# NOTE(review): the output dataclass `InternVLCausalLMOutputWithPast` defined in this file
# is not listed here -- presumably intentional; confirm before adding it.
__all__ = [
    "InternVLVisionPreTrainedModel",
    "InternVLVisionModel",
    "InternVLPreTrainedModel",
    "InternVLModel",
    "InternVLForConditionalGeneration",
]