modeling_smolvlm.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/smolvlm/modular_smolvlm.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_smolvlm.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 the HuggingFace Inc. team. All rights reserved.
  8. # Written by Orr Zohar
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from collections.abc import Callable
  22. from dataclasses import dataclass
  23. import torch
  24. from torch import nn
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationConfig, GenerationMixin
  28. from ...masking_utils import create_bidirectional_mask
  29. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check
  35. from ...utils.generic import merge_with_config_defaults
  36. from ...utils.output_capturing import capture_outputs
  37. from ..auto import AutoModel
  38. from .configuration_smolvlm import SmolVLMConfig, SmolVLMVisionConfig
  39. logger = logging.get_logger(__name__)
  @auto_docstring
  class SmolVLMPreTrainedModel(PreTrainedModel):
      # Common base class of the SmolVLM models in this file. It defines no methods;
      # it only declares the class-level settings below.
      config: SmolVLMConfig
      base_model_prefix = "model"
      input_modalities = ("image", "text")
      supports_gradient_checkpointing = True

      # Names of module classes that must be kept whole
      # (presumably when sharding the model across devices — TODO confirm).
      _no_split_modules = ["SmolVLMVisionAttention", "SmolVLMDecoderLayer"]
      _skip_keys_device_placement = "past_key_values"

      # Attention implementations this model family declares support for.
      _supports_flash_attn = True
      _supports_sdpa = True
      _supports_flex_attn = True
      _supports_attention_backend = True
  class SmolVLMVisionEmbeddings(nn.Module):
      """
      Patch and position embeddings for images of variable resolution.

      This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings`. The modifications are adapted
      from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://huggingface.co/papers/2307.06304),
      which allows treating images in their native aspect ratio without resizing them to the same fixed size.
      In particular, we start from the original pre-trained SigLIP model (which uses fixed-size square images)
      and adapt it by training on images of variable resolutions.
      """

      def __init__(self, config: SmolVLMVisionConfig):
          super().__init__()
          hidden_size = config.hidden_size
          patch_size = config.patch_size

          self.embed_dim = hidden_size
          self.image_size = config.image_size
          self.patch_size = patch_size

          # Kernel and stride are both `patch_size`, so patches do not overlap.
          self.patch_embedding = nn.Conv2d(
              in_channels=config.num_channels,
              out_channels=hidden_size,
              kernel_size=patch_size,
              stride=patch_size,
              padding="valid",
          )

          patches_per_side = self.image_size // patch_size
          self.num_patches_per_side = patches_per_side
          self.num_patches = patches_per_side**2
          self.num_positions = self.num_patches
          self.position_embedding = nn.Embedding(self.num_positions, hidden_size)

      def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
          """
          Embed the patches of `pixel_values` and add a learned position embedding to each one.

          Position ids are obtained by mapping every valid patch to a fractional coordinate in `[0, 1)` along
          each axis and bucketizing it onto the `num_patches_per_side` grid of the position table. Patches
          that are masked out in `patch_attention_mask` keep position id 0.

          Args:
              pixel_values: 4-D image batch; dims 2 and 3 are read as the padded height and width.
              patch_attention_mask: boolean mask indexed as `(batch, patch_row, patch_col)`.

          Returns:
              The patch embeddings with position embeddings added.
          """
          device = pixel_values.device
          batch_size, _, padded_height, padded_width = pixel_values.shape

          # Flatten the spatial dims of the conv output and move them in front of the channel dim.
          embeddings = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)

          grid_rows = padded_height // self.patch_size
          grid_cols = padded_width // self.patch_size

          bucket_width = 1 / self.num_patches_per_side
          boundaries = torch.arange(bucket_width, 1.0, bucket_width, device=device)

          def bucket_axis(valid_patches, axis_len):
              # One fractional step per valid patch; clamp so the last coordinate stays below 1.0
              # and therefore lands in the final bucket.
              step = 1.0 / valid_patches  # (batch_size,)
              ticks = torch.arange(axis_len, device=device, dtype=torch.float32)
              fractions = ticks[None, :] * step[:, None]
              fractions = torch.clamp(fractions, max=(1.0 - 1e-6)).to(pixel_values.dtype)
              return torch.bucketize(fractions, boundaries, right=True)

          # The number of valid patches per axis is read from the first column / first row of the mask.
          row_buckets = bucket_axis(patch_attention_mask[:, :, 0].sum(dim=1), patch_attention_mask.size(1))
          col_buckets = bucket_axis(patch_attention_mask[:, 0, :].sum(dim=1), patch_attention_mask.size(2))

          flat_ids = row_buckets[:, :, None] * self.num_patches_per_side + col_buckets[:, None, :]
          flat_ids = flat_ids.reshape(batch_size, -1)
          flat_mask = patch_attention_mask.view(batch_size, -1)

          position_ids = torch.full(size=(batch_size, grid_rows * grid_cols), fill_value=0, device=device)
          position_ids[flat_mask] = flat_ids[flat_mask]

          return embeddings + self.position_embedding(position_ids)
  109. def eager_attention_forward(
  110. module: nn.Module,
  111. query: torch.Tensor,
  112. key: torch.Tensor,
  113. value: torch.Tensor,
  114. attention_mask: torch.Tensor | None,
  115. scaling: float,
  116. dropout: float = 0.0,
  117. **kwargs,
  118. ):
  119. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  120. if attention_mask is not None:
  121. attn_weights = attn_weights + attention_mask
  122. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  123. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  124. attn_output = torch.matmul(attn_weights, value)
  125. attn_output = attn_output.transpose(1, 2).contiguous()
  126. return attn_output, attn_weights
  127. class SmolVLMVisionAttention(nn.Module):
  128. """Multi-headed attention from 'Attention Is All You Need' paper"""
  129. def __init__(self, config):
  130. super().__init__()
  131. self.config = config
  132. self.embed_dim = config.hidden_size
  133. self.num_heads = config.num_attention_heads
  134. self.head_dim = self.embed_dim // self.num_heads
  135. if self.head_dim * self.num_heads != self.embed_dim:
  136. raise ValueError(
  137. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  138. f" {self.num_heads})."
  139. )
  140. self.scale = self.head_dim**-0.5
  141. self.dropout = config.attention_dropout
  142. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  143. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  144. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  145. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  146. # Ignore copy
  147. self.is_causal = False
  148. def forward(
  149. self,
  150. hidden_states: torch.Tensor,
  151. attention_mask: torch.Tensor | None = None,
  152. **kwargs,
  153. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  154. """Input shape: Batch x Time x Channel"""
  155. input_shape = hidden_states.shape[:-1]
  156. hidden_shape = (*input_shape, -1, self.head_dim)
  157. queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  158. keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  159. values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  160. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  161. self.config._attn_implementation, eager_attention_forward
  162. )
  163. attn_output, attn_weights = attention_interface(
  164. self,
  165. queries,
  166. keys,
  167. values,
  168. attention_mask,
  169. is_causal=self.is_causal,
  170. scaling=self.scale,
  171. dropout=0.0 if not self.training else self.dropout,
  172. )
  173. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  174. attn_output = self.out_proj(attn_output)
  175. return attn_output, attn_weights
  class SmolVLMVisionMLP(nn.Module):
      """Two-layer feed-forward block: `fc1` -> activation -> `fc2`."""

      def __init__(self, config):
          super().__init__()
          self.config = config
          # The activation is looked up by name from `config.hidden_act`.
          self.activation_fn = ACT2FN[config.hidden_act]
          self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
          self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          """Expand to `intermediate_size`, apply the activation, and project back to `hidden_size`."""
          expanded = self.activation_fn(self.fc1(hidden_states))
          return self.fc2(expanded)
  class SmolVLMEncoderLayer(GradientCheckpointingLayer):
      # Pre-norm transformer encoder layer: LayerNorm -> self-attention -> residual add,
      # then LayerNorm -> MLP -> residual add.

      def __init__(self, config: SmolVLMVisionConfig):
          super().__init__()
          self.embed_dim = config.hidden_size
          self.self_attn = SmolVLMVisionAttention(config)
          self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
          self.mlp = SmolVLMVisionMLP(config)
          self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

      @auto_docstring
      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.Tensor,
          **kwargs: Unpack[TransformersKwargs],
      ) -> torch.FloatTensor:
          # Attention sub-block; the attention weights returned by `self_attn` are discarded here.
          attn_output, _ = self.self_attn(
              hidden_states=self.layer_norm1(hidden_states),
              attention_mask=attention_mask,
              **kwargs,
          )
          hidden_states = hidden_states + attn_output

          # Feed-forward sub-block.
          mlp_output = self.mlp(self.layer_norm2(hidden_states))
          return hidden_states + mlp_output
  class SmolVLMEncoder(nn.Module):
      """
      Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
      [`SmolVLMEncoderLayer`].

      Args:
          config: SmolVLMConfig
      """

      # NOTE(review): the annotation below says `SmolVLMConfig`, but `SmolVLMVisionTransformer.__init__`
      # passes its `SmolVLMVisionConfig` here and `SmolVLMEncoderLayer` is annotated with the vision
      # config as well — confirm which type is intended.
      def __init__(self, config: SmolVLMConfig):
          super().__init__()
          self.config = config
          self.layers = nn.ModuleList([SmolVLMEncoderLayer(config) for _ in range(config.num_hidden_layers)])
          self.gradient_checkpointing = False

      # Ignore copy
      @auto_docstring
      def forward(
          self,
          inputs_embeds,
          attention_mask: torch.Tensor | None = None,
      ) -> tuple | BaseModelOutput:
          # Run the layers in order, feeding each layer's output into the next one.
          hidden_states = inputs_embeds
          for layer in self.layers:
              hidden_states = layer(
                  hidden_states,
                  attention_mask,
              )
          return BaseModelOutput(last_hidden_state=hidden_states)
  @auto_docstring(
      custom_intro="""
      The SmolVLM Vision Transformer Model outputting raw image embedding.
      """
  )
  class SmolVLMVisionTransformer(SmolVLMPreTrainedModel):
      config: SmolVLMVisionConfig
      input_modalities = ("image",)
      # Module classes whose outputs are recorded under the given output names
      # (presumably consumed by `capture_outputs` — TODO confirm).
      _can_record_outputs = {
          "hidden_states": SmolVLMEncoderLayer,
          "attentions": SmolVLMVisionAttention,
      }

      def __init__(self, config: SmolVLMVisionConfig):
          super().__init__(config)
          self.embeddings = SmolVLMVisionEmbeddings(config)
          self.encoder = SmolVLMEncoder(config)
          self.patch_size = config.patch_size
          self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

          self.post_init()

      def get_input_embeddings(self):
          return self.embeddings

      def set_input_embeddings(self, value):
          self.embeddings = value

      @merge_with_config_defaults
      @capture_outputs(tie_last_hidden_states=False)
      def forward(
          self,
          pixel_values,
          patch_attention_mask: torch.BoolTensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | BaseModelOutput:
          # Embed the patches, run the encoder under a bidirectional mask built from
          # `patch_attention_mask`, and apply the final layer norm.
          batch_size = pixel_values.size(0)

          if patch_attention_mask is None:
              # No mask supplied: treat every patch of the (padded) image as valid.
              full_grid = (
                  batch_size,
                  pixel_values.size(2) // self.patch_size,
                  pixel_values.size(3) // self.patch_size,
              )
              patch_attention_mask = torch.ones(full_grid)
              # NOTE(review): the source this was restored from had lost its indentation; the cast below
              # is kept inside this branch — confirm it is not meant to apply to caller-supplied masks too.
              patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device)

          hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)

          # Create the correct attention mask based on the attention implementation
          flat_patch_mask = patch_attention_mask.view(batch_size, -1)
          attention_mask = create_bidirectional_mask(
              config=self.config,
              inputs_embeds=hidden_states,
              attention_mask=flat_patch_mask,
          )

          encoder_outputs: BaseModelOutput = self.encoder(
              inputs_embeds=hidden_states,
              attention_mask=attention_mask,
          )

          normalized = self.post_layernorm(encoder_outputs.last_hidden_state)
          return BaseModelOutput(
              last_hidden_state=normalized,
          )
  @dataclass
  @auto_docstring(
      custom_intro="""
      Base class for SmolVLM model's outputs that may also contain a past key/values (to speed up sequential decoding).
      """
  )
  class SmolVLMBaseModelOutputWithPast(ModelOutput):
      r"""
      last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
          Sequence of hidden-states at the output of the last layer of the model.
          If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
          hidden_size)` is output.
      past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
          It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
          Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
          `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
          input) to speed up sequential decoding.
      image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
          Tuple of `torch.FloatTensor` (one for the output of the image embeddings), of shape `(batch_size,
          num_images, sequence_length, hidden_size)`.
          image_hidden_states of the model produced by the vision encoder
      """

      # Every field is optional and defaults to None.
      last_hidden_state: torch.FloatTensor | None = None
      past_key_values: Cache | None = None
      hidden_states: tuple[torch.FloatTensor] | None = None
      attentions: tuple[torch.FloatTensor] | None = None
      image_hidden_states: tuple[torch.FloatTensor] | None = None
  class SmolVLMSimpleMLP(nn.Module):
      """Single bias-free linear projection from the pixel-shuffled vision width to the text hidden size."""

      def __init__(self, config):
          super().__init__()
          # The input width is the vision hidden size multiplied by `scale_factor ** 2`.
          in_features = config.vision_config.hidden_size * (config.scale_factor**2)
          out_features = config.text_config.hidden_size
          self.proj = nn.Linear(in_features, out_features, bias=False)

      def forward(self, x):
          """Apply the projection to the last dimension of `x`."""
          projected = self.proj(x)
          return projected
  class SmolVLMConnector(nn.Module):
      """Pixel-shuffles the vision encoder output and projects it with `SmolVLMSimpleMLP`."""

      def __init__(self, config):
          super().__init__()
          self.scale_factor = config.scale_factor
          self.modality_projection = SmolVLMSimpleMLP(config)

      def pixel_shuffle(self, x, scale_factor=2):
          """
          Fold each `scale_factor x scale_factor` neighbourhood of the patch grid into the channel dim.

          `x` is read as `(bsz, seq, embed_dim)` and `seq` is treated as a square grid of side
          `int(seq**0.5)`. The result has shape `(bsz, seq / scale_factor**2, embed_dim * scale_factor**2)`.
          """
          bsz, seq, embed_dim = x.size()
          side = int(seq**0.5)
          reduced_side = side // scale_factor
          widened = embed_dim * scale_factor
          folded = embed_dim * (scale_factor**2)

          grid = x.view(bsz, side, side, embed_dim)
          # Merge `scale_factor` horizontally adjacent patches into the channel dim.
          grid = grid.view(bsz, side, reduced_side, widened)
          grid = grid.permute(0, 2, 1, 3)
          # Merge `scale_factor` vertically adjacent rows into the channel dim.
          grid = grid.reshape(bsz, reduced_side, reduced_side, folded)
          grid = grid.permute(0, 2, 1, 3)
          return grid.reshape(bsz, int(seq / (scale_factor**2)), folded)

      def forward(self, image_hidden_states):
          """Pixel-shuffle with the configured scale factor, then project."""
          shuffled = self.pixel_shuffle(image_hidden_states, self.scale_factor)
          return self.modality_projection(shuffled)
  @auto_docstring(
      custom_intro="""
      SmolVLM model consisting of a SIGLIP vision encoder and Llama3 language decoder
      """
  )
  class SmolVLMModel(SmolVLMPreTrainedModel):
      """
      Vision encoder + connector + text model.

      Images are encoded by `vision_model`, projected into the text embedding space by `connector`, and
      written over the `<image>` placeholder positions of the text embeddings by `inputs_merger` before the
      merged sequence is passed to `text_model`. The merging logic is adapted from `Idefics3Model`;
      `forward` always routes image features through `inputs_merger`.
      """

      def __init__(self, config: SmolVLMConfig):
          super().__init__(config)
          self.padding_idx = self.config.text_config.pad_token_id
          self.vocab_size = self.config.text_config.vocab_size

          self.vision_model = SmolVLMVisionTransformer._from_config(config.vision_config)
          self.connector = SmolVLMConnector(config)
          self.text_model = AutoModel.from_config(config.text_config)

          # (patches per side)^2, reduced by scale_factor^2 (the connector's pixel shuffle).
          self.image_seq_len = int(
              ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2)
          )
          self.image_token_id = self.config.image_token_id

          self.post_init()

      def get_input_embeddings(self):
          return self.text_model.get_input_embeddings()

      def set_input_embeddings(self, value):
          self.text_model.set_input_embeddings(value)

      def inputs_merger(
          self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_hidden_states: torch.Tensor
      ):
          """
          Merge the token embeddings with the image hidden states into one sequence of vectors for the LM.

          The merging happens as follows:
          - The text token sequence is: `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
          - The image hidden states come from the vision encoder and, after a pixel shuffle operation, are
            projected into the text embedding space. Each image therefore contributes one block of
            `image_hidden_states` of shape `(1, image_seq_len, hidden_dim)`.
          - Every `<image>` position is overwritten by the matching image hidden state, giving
            `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {image_seq_len image hidden states} vector_fake_tok_around_image vector_tok_4`.

          Args:
              input_ids: token ids, or `None` when only `inputs_embeds` is available. In that case the
                  `<image>` positions are recovered by comparing `inputs_embeds` with the embedding of the
                  image token.
              inputs_embeds: text embeddings indexed as `(batch, seq, hidden)`.
              image_hidden_states: image features unpacked as `(num_blocks, block_len, hidden)`.

          Returns:
              A tensor shaped like `inputs_embeds` with the `<image>` positions replaced.
          """
          # Length of one image block. (The user-facing error message below calls it `patch_size`.)
          _, block_len, _ = image_hidden_states.shape

          if input_ids is None:
              image_token_embed = self.get_input_embeddings()(
                  torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
              )
              # A position is a placeholder only if its whole embedding matches the image-token embedding.
              # Checking a single hidden coordinate would also flag unrelated tokens that merely share
              # that one value.
              image_mask = (inputs_embeds == image_token_embed).all(dim=-1)
          else:
              image_mask = input_ids == self.config.image_token_id

          num_image_tokens = image_mask.sum(dim=1)
          torch_compilable_check(
              torch.all(num_image_tokens % block_len == 0),
              "At least one sample has <image> tokens not divisible by patch_size.",
          )

          # Exclusive prefix sum: index of each sample's first block inside `image_hidden_states`
          # (this assumes the blocks are laid out sample after sample).
          blocks_per_sample = num_image_tokens // block_len
          offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0)
          block_offset = offsets[:-1]

          # For the k-th image token of a row (k = row_cum - 1): which block of the sample it belongs to
          # and its position inside that block. Values at non-image positions are never read.
          row_cum = image_mask.cumsum(dim=-1)
          chunk_idx = (row_cum - 1) // block_len
          local_idx = (row_cum - 1) % block_len
          block_idx = block_offset.unsqueeze(1) + chunk_idx

          image_embeds = torch.zeros_like(inputs_embeds)
          image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :]

          return torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)

      @can_return_tuple
      @auto_docstring(
          custom_intro="Encodes images into continuous embeddings that can be forwarded to the language model."
      )
      def get_image_features(
          self,
          pixel_values: torch.FloatTensor,
          pixel_attention_mask: torch.LongTensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | BaseModelOutputWithPooling:
          r"""
          pixel_values (`torch.FloatTensor` of shape `(batch_size, num_images, num_channels, height, width)`):
              The tensors corresponding to the input images.
          pixel_attention_mask (`torch.LongTensor`, *optional*):
              The attention mask indicating padded regions in the image.
          """
          # The five-way unpack also enforces that `pixel_values` is 5-D.
          batch_size, num_images, _, _, _ = pixel_values.shape
          pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
          pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])

          # Remove padding images - padding images are full 0.
          nb_values_per_image = pixel_values.shape[1:].numel()
          real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
          # If no images, leave one empty image.
          real_images_inds[0] |= ~torch.any(real_images_inds)
          pixel_values = pixel_values[real_images_inds].contiguous()

          # Handle the vision attention mask
          if pixel_attention_mask is None:
              pixel_attention_mask = torch.ones(
                  size=[pixel_values.shape[i] for i in (0, 2, 3)],
                  dtype=torch.bool,
                  device=pixel_values.device,
              )
          else:
              # Remove padding images from the mask
              pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
              pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()

          # Pixel mask -> patch mask: a patch is valid as soon as one of its pixels is.
          patch_size = self.config.vision_config.patch_size
          patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
          patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
          patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

          # Get sequence from the vision encoder
          image_outputs = self.vision_model(
              pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, return_dict=True, **kwargs
          )
          image_hidden_states = image_outputs.last_hidden_state

          # Modality projection & resampling
          image_features = self.connector(image_hidden_states)
          image_outputs.pooler_output = image_features

          return image_outputs

      # NOTE: `can_return_tuple` is applied once. The original stacked it a second time directly above
      # `def forward`, which wrapped the method twice to no purpose.
      @can_return_tuple
      @auto_docstring(
          custom_intro="""
          Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
          the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
          max_num_images is the maximum number of images among the batch_size samples in the batch.
          Padding images are not needed beyond padding the pixel_values at the entrance of the model.
          For efficiency, we only pass through the vision_model's forward the real images by
          discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
          image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
          """
      )
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          pixel_values: torch.FloatTensor | None = None,
          pixel_attention_mask: torch.BoolTensor | None = None,
          image_hidden_states: torch.FloatTensor | None = None,
          use_cache: bool | None = None,
          **kwargs: Unpack[FlashAttentionKwargs],
      ) -> tuple | SmolVLMBaseModelOutputWithPast:
          r"""
          pixel_attention_mask (`torch.Tensor` of shape `(batch_size, num_images, height, width)`, *optional*):
              Mask to avoid performing attention on padding pixel indices.
          image_hidden_states (`torch.FloatTensor` of shape `(total_num_images, image_seq_len, hidden_size)`):
              The hidden states of the image encoder after modality projection.
          """
          if self.training and self.text_model.gradient_checkpointing and use_cache:
              logger.warning_once(
                  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
              )
              use_cache = False

          if input_ids is None and inputs_embeds is None:
              raise ValueError("You have to specify either input_ids or inputs_embeds")

          if use_cache and past_key_values is None:
              past_key_values = DynamicCache(config=self.config)

          if inputs_embeds is None:
              inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)

          # Image features come either from raw pixels or from precomputed hidden states, never both.
          if pixel_values is not None and image_hidden_states is not None:
              raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")

          if pixel_values is not None:
              image_hidden_states = self.get_image_features(
                  pixel_values, pixel_attention_mask, return_dict=True
              ).pooler_output
              image_hidden_states = image_hidden_states.to(inputs_embeds.device)
          elif image_hidden_states is not None:
              image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=inputs_embeds.device)

          if image_hidden_states is not None:
              inputs_embeds = self.inputs_merger(
                  input_ids=input_ids,
                  inputs_embeds=inputs_embeds,
                  image_hidden_states=image_hidden_states,
              )

          outputs = self.text_model(
              inputs_embeds=inputs_embeds,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              use_cache=use_cache,
              **kwargs,
          )

          return SmolVLMBaseModelOutputWithPast(
              last_hidden_state=outputs.last_hidden_state,
              past_key_values=outputs.past_key_values,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
              image_hidden_states=image_hidden_states,
          )
  @dataclass
  @auto_docstring(
      custom_intro="""
      Base class for Idefics causal language model (or autoregressive) outputs.
      """
  )
  class SmolVLMCausalLMOutputWithPast(ModelOutput):
      r"""
      loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
          Language modeling loss (for next-token prediction).
      logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
          Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
      past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
          It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
          Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
          `past_key_values` input) to speed up sequential decoding.
      image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
          Tuple of `torch.FloatTensor` (one for the output of the image embeddings), of shape `(batch_size,
          num_images, sequence_length, hidden_size)`.
          image_hidden_states of the model produced by the vision encoder
      """

      # Every field is optional and defaults to None.
      loss: torch.FloatTensor | None = None
      logits: torch.FloatTensor | None = None
      past_key_values: Cache | None = None
      hidden_states: tuple[torch.FloatTensor] | None = None
      attentions: tuple[torch.FloatTensor] | None = None
      image_hidden_states: tuple[torch.FloatTensor] | None = None
  572. @auto_docstring(
  573. custom_intro="""
  574. The SmolVLM Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top.
  575. """
  576. )
  577. class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin):
  578. _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"}
  579. def __init__(self, config):
  580. super().__init__(config)
  581. self.model = SmolVLMModel(config)
  582. self.image_token_id = self.config.image_token_id
  583. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  584. self.vocab_size = config.text_config.vocab_size
  585. self.model.text_model.generation_config = GenerationConfig.from_model_config(config)
  586. # Initialize weights and apply final processing
  587. self.post_init()
  588. def get_input_embeddings(self):
  589. return self.model.text_model.get_input_embeddings()
  590. def set_input_embeddings(self, value):
  591. self.model.text_model.set_input_embeddings(value)
  592. @auto_docstring
  593. def get_image_features(
  594. self,
  595. pixel_values: torch.FloatTensor,
  596. pixel_attention_mask: torch.LongTensor | None = None,
  597. **kwargs: Unpack[TransformersKwargs],
  598. ) -> tuple | BaseModelOutputWithPooling:
  599. r"""
  600. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  601. The tensors corresponding to the input images.
  602. pixel_attention_mask (`torch.LongTensor`, *optional*):
  603. The attention mask indicating padded regions in the image.
  604. """
  605. return self.model.get_image_features(
  606. pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask, **kwargs
  607. )
  608. @can_return_tuple
  609. @auto_docstring
  610. def forward(
  611. self,
  612. input_ids: torch.LongTensor | None = None,
  613. attention_mask: torch.Tensor | None = None,
  614. position_ids: torch.LongTensor | None = None,
  615. past_key_values: Cache | None = None,
  616. inputs_embeds: torch.FloatTensor | None = None,
  617. pixel_values: torch.FloatTensor | None = None,
  618. pixel_attention_mask: torch.BoolTensor | None = None,
  619. image_hidden_states: torch.FloatTensor | None = None,
  620. labels: torch.LongTensor | None = None,
  621. use_cache: bool | None = None,
  622. logits_to_keep: int | torch.Tensor = 0,
  623. **kwargs: Unpack[TransformersKwargs],
  624. ) -> tuple | SmolVLMCausalLMOutputWithPast:
  625. r"""
  626. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  627. Mask to avoid performing attention on padding pixel indices.
  628. image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  629. The hidden states of the image encoder after modality projection.
  630. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  631. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  632. config.vocab_size]` or `model.image_token_id`. Tokens with indices set to `model.image_token_id` are
  633. ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  634. Example:
  635. ```python
  636. >>> import httpx
  637. >>> from io import BytesIO
  638. >>> import torch
  639. >>> from PIL import Image
  640. >>> from io import BytesIO
  641. >>> from transformers import AutoProcessor, AutoModelForImageTextToText
  642. >>> from transformers.image_utils import load_image
  643. >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
  644. >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
  645. >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
  646. >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
  647. >>> processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
  648. >>> model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct", dtype=torch.bfloat16, device_map="auto")
  649. >>> # Create inputs
  650. >>> messages = [
  651. ... {
  652. ... "role": "user",
  653. ... "content": [
  654. ... {"type": "video", "path": path/to/video},
  655. ... {"type": "text", "text": "What is happening in this video?"},
  656. ... ]
  657. ... }
  658. ... ]
  659. >>> inputs = processor.apply_chat_template([messages], add_generation_prompt=True)
  660. >>> # Generate
  661. >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
  662. >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
  663. >>> print(generated_texts)
  664. ```"""
  665. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  666. outputs = self.model(
  667. input_ids=input_ids,
  668. attention_mask=attention_mask,
  669. position_ids=position_ids,
  670. past_key_values=past_key_values,
  671. inputs_embeds=inputs_embeds,
  672. pixel_values=pixel_values,
  673. pixel_attention_mask=pixel_attention_mask,
  674. image_hidden_states=image_hidden_states,
  675. use_cache=use_cache,
  676. return_dict=True,
  677. **kwargs,
  678. )
  679. hidden_states = outputs[0]
  680. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  681. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  682. logits = self.lm_head(hidden_states[:, slice_indices, :])
  683. loss = None
  684. if labels is not None:
  685. loss = self.loss_function(
  686. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  687. )
  688. return SmolVLMCausalLMOutputWithPast(
  689. loss=loss,
  690. logits=logits,
  691. past_key_values=outputs.past_key_values,
  692. hidden_states=outputs.hidden_states,
  693. attentions=outputs.attentions,
  694. image_hidden_states=outputs.image_hidden_states,
  695. )
  696. def prepare_inputs_for_generation(
  697. self,
  698. input_ids,
  699. past_key_values=None,
  700. attention_mask=None,
  701. inputs_embeds=None,
  702. pixel_values=None,
  703. pixel_attention_mask=None,
  704. image_hidden_states=None,
  705. logits_to_keep=None,
  706. is_first_iteration=False,
  707. use_cache=False,
  708. **kwargs,
  709. ):
  710. # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take
  711. # precedence is moved to the model, we can remove this fn)
  712. model_inputs = super().prepare_inputs_for_generation(
  713. input_ids,
  714. past_key_values=past_key_values,
  715. attention_mask=attention_mask,
  716. inputs_embeds=inputs_embeds,
  717. pixel_values=pixel_values,
  718. pixel_attention_mask=pixel_attention_mask,
  719. image_hidden_states=image_hidden_states,
  720. logits_to_keep=logits_to_keep,
  721. is_first_iteration=is_first_iteration,
  722. use_cache=use_cache,
  723. **kwargs,
  724. )
  725. if image_hidden_states is not None or (use_cache and not is_first_iteration):
  726. model_inputs["pixel_values"] = None
  727. model_inputs["pixel_attention_mask"] = None
  728. return model_inputs
  729. __all__ = ["SmolVLMForConditionalGeneration", "SmolVLMPreTrainedModel", "SmolVLMModel", "SmolVLMVisionTransformer"]