modeling_idefics3.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907
  1. # Copyright 2024 the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Idefics3 model."""
  15. from collections.abc import Callable
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from ...activations import ACT2FN
  20. from ...cache_utils import Cache, DynamicCache
  21. from ...generation import GenerationMixin
  22. from ...masking_utils import create_bidirectional_mask
  23. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
  26. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  27. from ...processing_utils import Unpack
  28. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  29. from ...utils.generic import merge_with_config_defaults
  30. from ...utils.output_capturing import capture_outputs
  31. from ..auto import AutoModel
  32. from .configuration_idefics3 import Idefics3Config, Idefics3VisionConfig
# Module-level logger, used e.g. to warn when `use_cache` is disabled under gradient checkpointing.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Idefics3 model's outputs that may also contain a past key/values (to speed up sequential decoding).
    """
)
class Idefics3BaseModelOutputWithPast(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.
        If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
        hidden_size)` is output.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
        `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
        input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        Image features produced by the vision encoder and projected into the text embedding space by the
        connector. When computed from `pixel_values` the shape is `(num_images, image_sequence_length, hidden_size)`,
        where `num_images` counts only the non-padding images of the whole batch.
    """

    last_hidden_state: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    # A single tensor (the connector output or the user-provided features), not a tuple.
    image_hidden_states: torch.FloatTensor | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Idefics causal language model (or autoregressive) outputs.
    """
)
class Idefics3CausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
        sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    # NOTE(review): Idefics3Model returns a single connector-output tensor for this field, not a tuple;
    # confirm what Idefics3ForConditionalGeneration stores here and align the type/doc accordingly.
    image_hidden_states: tuple[torch.FloatTensor] | None = None
# Adapted from transformers.models.idefics2.modeling_idefics2.Idefics2VisionEmbeddings (Idefics2->Idefics3)
class Idefics3VisionEmbeddings(nn.Module):
    """
    This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` to enable images of variable
    resolution.

    The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://huggingface.co/papers/2307.06304)
    which allows treating images in their native aspect ratio and without the need to resize them to the same
    fixed size. In particular, we start from the original pre-trained SigLIP model
    (which uses fixed-size square images) and adapt it by training on images of variable resolutions.

    Position ids are computed per image from its number of valid patches. The fractional (row, column)
    coordinate of each valid patch in [0, 1) is bucketized onto the
    `num_patches_per_side x num_patches_per_side` grid of the learned position table. A smaller image is
    therefore mapped onto a spread-out subset of the same positions.
    """

    def __init__(self, config: Idefics3VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patch_size x patch_size patches (kernel == stride, "valid" padding), each projected
        # to embed_dim channels.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        # The learned position table covers the square patch grid of an image_size x image_size image.
        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

    def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
        """
        Embed the patches of `pixel_values` and add per-image position embeddings.

        Args:
            pixel_values: images indexed as `(batch_size, num_channels, height, width)`.
            patch_attention_mask: boolean mask indexed as `(batch_size, patches_h, patches_w)`. `True` marks
                patches that belong to the real (unpadded) image.

        Returns:
            Tensor of shape `(batch_size, patches_h * patches_w, embed_dim)`. Masked-out patches receive the
            position embedding of id 0.
        """
        batch_size, _, max_im_h, max_im_w = pixel_values.shape

        patch_embeds = self.patch_embedding(pixel_values)
        # (batch, embed_dim, h, w) -> (batch, h * w, embed_dim)
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
        # Bucket boundaries at the multiples of 1 / num_patches_per_side inside (0, 1).
        boundaries = torch.arange(
            1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device
        )
        # Start with id 0 everywhere; only valid patches are overwritten below.
        position_ids = torch.full(
            size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device
        )

        # Number of valid patch rows / columns per image. Only the first column / row of the mask is read,
        # so the valid region is treated as a rectangle starting at patch (0, 0).
        nb_patches_h = patch_attention_mask[:, :, 0].sum(dim=1)  # (batch_size,)
        nb_patches_w = patch_attention_mask[:, 0, :].sum(dim=1)  # (batch_size,)

        step_h = 1.0 / nb_patches_h  # (batch_size,)
        step_w = 1.0 / nb_patches_w  # (batch_size,)

        max_patches_h = patch_attention_mask.size(1)
        max_patches_w = patch_attention_mask.size(2)
        h_indices = torch.arange(max_patches_h, device=position_ids.device, dtype=torch.float32)
        w_indices = torch.arange(max_patches_w, device=position_ids.device, dtype=torch.float32)

        # Fractional coordinate of every row / column index, per image: (batch_size, max_patches_h|w).
        fractional_coords_h = h_indices[None, :] * step_h[:, None]
        fractional_coords_w = w_indices[None, :] * step_w[:, None]

        # Clamp to just below 1.0. Indices beyond an image's valid extent produce coordinates >= 1.
        fractional_coords_h = torch.clamp(fractional_coords_h, max=(1.0 - 1e-6))
        fractional_coords_w = torch.clamp(fractional_coords_w, max=(1.0 - 1e-6))

        # NOTE(review): the coordinates are rounded to the input dtype before bucketizing, so under fp16/bf16
        # the bucket assignment can differ from float32. This is presumably intentional, to match the
        # reference behavior; do not change it without comparing against reference outputs.
        fractional_coords_h = fractional_coords_h.to(pixel_values.dtype)
        fractional_coords_w = fractional_coords_w.to(pixel_values.dtype)

        bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
        bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

        # Flattened grid id = row_bucket * num_patches_per_side + col_bucket, for every (row, col) pair.
        pos_ids = bucket_coords_h[:, :, None] * self.num_patches_per_side + bucket_coords_w[:, None, :]
        pos_ids = pos_ids.reshape(batch_size, -1)

        # Only valid patches receive their computed id; masked-out patches keep id 0.
        position_ids[patch_attention_mask.view(batch_size, -1)] = pos_ids[patch_attention_mask.view(batch_size, -1)]

        embeddings = embeddings + self.position_embedding(position_ids)
        return embeddings
  146. # Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
  147. def eager_attention_forward(
  148. module: nn.Module,
  149. query: torch.Tensor,
  150. key: torch.Tensor,
  151. value: torch.Tensor,
  152. attention_mask: torch.Tensor | None,
  153. scaling: float,
  154. dropout: float = 0.0,
  155. **kwargs,
  156. ):
  157. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  158. if attention_mask is not None:
  159. attn_weights = attn_weights + attention_mask
  160. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  161. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  162. attn_output = torch.matmul(attn_weights, value)
  163. attn_output = attn_output.transpose(1, 2).contiguous()
  164. return attn_output, attn_weights
  165. # Copied from transformers.models.siglip.modeling_siglip.SiglipAttention with Siglip->Idefics3Vision
  166. class Idefics3VisionAttention(nn.Module):
  167. """Multi-headed attention from 'Attention Is All You Need' paper"""
  168. # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
  169. def __init__(self, config):
  170. super().__init__()
  171. self.config = config
  172. self.embed_dim = config.hidden_size
  173. self.num_heads = config.num_attention_heads
  174. self.head_dim = self.embed_dim // self.num_heads
  175. if self.head_dim * self.num_heads != self.embed_dim:
  176. raise ValueError(
  177. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  178. f" {self.num_heads})."
  179. )
  180. self.scale = self.head_dim**-0.5
  181. self.dropout = config.attention_dropout
  182. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  183. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  184. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  185. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  186. # Ignore copy
  187. self.is_causal = False
  188. def forward(
  189. self,
  190. hidden_states: torch.Tensor,
  191. attention_mask: torch.Tensor | None = None,
  192. **kwargs,
  193. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  194. """Input shape: Batch x Time x Channel"""
  195. input_shape = hidden_states.shape[:-1]
  196. hidden_shape = (*input_shape, -1, self.head_dim)
  197. queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  198. keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  199. values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  200. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  201. self.config._attn_implementation, eager_attention_forward
  202. )
  203. attn_output, attn_weights = attention_interface(
  204. self,
  205. queries,
  206. keys,
  207. values,
  208. attention_mask,
  209. is_causal=self.is_causal,
  210. scaling=self.scale,
  211. dropout=0.0 if not self.training else self.dropout,
  212. )
  213. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  214. attn_output = self.out_proj(attn_output)
  215. return attn_output, attn_weights
  216. # Copied from transformers.models.siglip.modeling_siglip.SiglipMLP with Siglip->Idefics3Vision
  217. class Idefics3VisionMLP(nn.Module):
  218. def __init__(self, config):
  219. super().__init__()
  220. self.config = config
  221. self.activation_fn = ACT2FN[config.hidden_act]
  222. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  223. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  224. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  225. hidden_states = self.fc1(hidden_states)
  226. hidden_states = self.activation_fn(hidden_states)
  227. hidden_states = self.fc2(hidden_states)
  228. return hidden_states
  229. class Idefics3SimpleMLP(nn.Module):
  230. def __init__(self, config):
  231. super().__init__()
  232. input_size = config.vision_config.hidden_size * (config.scale_factor**2)
  233. output_size = config.text_config.hidden_size
  234. self.proj = nn.Linear(input_size, output_size, bias=False)
  235. def forward(self, x):
  236. return self.proj(x)
  237. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2EncoderLayer with Idefics2->Idefics3
  238. class Idefics3EncoderLayer(GradientCheckpointingLayer):
  239. def __init__(self, config: Idefics3VisionConfig):
  240. super().__init__()
  241. self.embed_dim = config.hidden_size
  242. self.self_attn = Idefics3VisionAttention(config)
  243. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  244. self.mlp = Idefics3VisionMLP(config)
  245. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  246. @auto_docstring
  247. # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
  248. def forward(
  249. self,
  250. hidden_states: torch.Tensor,
  251. attention_mask: torch.Tensor,
  252. **kwargs: Unpack[TransformersKwargs],
  253. ) -> torch.FloatTensor:
  254. residual = hidden_states
  255. hidden_states = self.layer_norm1(hidden_states)
  256. hidden_states, _ = self.self_attn(
  257. hidden_states=hidden_states,
  258. attention_mask=attention_mask,
  259. **kwargs,
  260. )
  261. hidden_states = residual + hidden_states
  262. residual = hidden_states
  263. hidden_states = self.layer_norm2(hidden_states)
  264. hidden_states = self.mlp(hidden_states)
  265. hidden_states = residual + hidden_states
  266. return hidden_states
  267. # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoder with Siglip->Idefics3
  268. class Idefics3Encoder(nn.Module):
  269. """
  270. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  271. [`Idefics3EncoderLayer`].
  272. Args:
  273. config: Idefics3Config
  274. """
  275. def __init__(self, config: Idefics3Config):
  276. super().__init__()
  277. self.config = config
  278. self.layers = nn.ModuleList([Idefics3EncoderLayer(config) for _ in range(config.num_hidden_layers)])
  279. self.gradient_checkpointing = False
  280. # Ignore copy
  281. @auto_docstring
  282. def forward(
  283. self,
  284. inputs_embeds,
  285. attention_mask: torch.Tensor | None = None,
  286. ) -> tuple | BaseModelOutput:
  287. hidden_states = inputs_embeds
  288. for encoder_layer in self.layers:
  289. layer_outputs = encoder_layer(
  290. hidden_states,
  291. attention_mask,
  292. )
  293. hidden_states = layer_outputs
  294. return BaseModelOutput(last_hidden_state=hidden_states)
  295. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  296. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  297. """
  298. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  299. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  300. """
  301. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  302. if n_rep == 1:
  303. return hidden_states
  304. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  305. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  306. # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Idefics3
  307. class Idefics3RMSNorm(nn.Module):
  308. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  309. """
  310. Idefics3RMSNorm is equivalent to T5LayerNorm
  311. """
  312. super().__init__()
  313. self.weight = nn.Parameter(torch.ones(hidden_size))
  314. self.variance_epsilon = eps
  315. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  316. input_dtype = hidden_states.dtype
  317. hidden_states = hidden_states.to(torch.float32)
  318. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  319. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  320. return self.weight * hidden_states.to(input_dtype)
  321. def extra_repr(self):
  322. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  323. class Idefics3Connector(nn.Module):
  324. def __init__(self, config):
  325. super().__init__()
  326. self.scale_factor = config.scale_factor
  327. self.modality_projection = Idefics3SimpleMLP(config)
  328. def pixel_shuffle(self, x, scale_factor=2):
  329. bsz, seq, embed_dim = x.size()
  330. height = width = int(seq**0.5)
  331. x = x.view(bsz, height, width, embed_dim)
  332. x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
  333. x = x.permute(0, 2, 1, 3)
  334. x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2))
  335. x = x.permute(0, 2, 1, 3)
  336. x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
  337. return x
  338. def forward(self, image_hidden_states):
  339. image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
  340. image_hidden_states = self.modality_projection(image_hidden_states)
  341. return image_hidden_states
# Shared base class of the Idefics3 models. It declares the config class, the weight prefix and the
# features/backends supported by the models below.
@auto_docstring
class Idefics3PreTrainedModel(PreTrainedModel):
    config: Idefics3Config
    base_model_prefix = "model"
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True
    # NOTE(review): no `Idefics3DecoderLayer` is defined in the visible part of this file. The text decoder is
    # built with `AutoModel` in `Idefics3Model`, so this entry may never match; confirm the intended class name.
    _no_split_modules = ["Idefics3VisionAttention", "Idefics3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    # Attention backends this model declares support for.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True
  354. @auto_docstring(
  355. custom_intro="""
  356. The Idefics3 Vision Transformer Model outputting raw image embedding.
  357. """
  358. )
  359. class Idefics3VisionTransformer(Idefics3PreTrainedModel):
  360. config: Idefics3VisionConfig
  361. input_modalities = ("image",)
  362. _can_record_outputs = {
  363. "hidden_states": Idefics3EncoderLayer,
  364. "attentions": Idefics3VisionAttention,
  365. }
  366. def __init__(self, config: Idefics3VisionConfig):
  367. super().__init__(config)
  368. embed_dim = config.hidden_size
  369. self.embeddings = Idefics3VisionEmbeddings(config)
  370. self.encoder = Idefics3Encoder(config)
  371. self.patch_size = config.patch_size
  372. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  373. self.post_init()
  374. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer.get_input_embeddings
  375. def get_input_embeddings(self):
  376. return self.embeddings
  377. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer.set_input_embeddings
  378. def set_input_embeddings(self, value):
  379. self.embeddings = value
  380. @merge_with_config_defaults
  381. @capture_outputs(tie_last_hidden_states=False)
  382. def forward(
  383. self,
  384. pixel_values,
  385. patch_attention_mask: torch.BoolTensor | None = None,
  386. **kwargs: Unpack[TransformersKwargs],
  387. ) -> tuple | BaseModelOutput:
  388. batch_size = pixel_values.size(0)
  389. if patch_attention_mask is None:
  390. patch_size = self.patch_size
  391. patch_attention_mask = torch.ones(
  392. (
  393. batch_size,
  394. pixel_values.size(2) // patch_size,
  395. pixel_values.size(3) // patch_size,
  396. )
  397. )
  398. patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device)
  399. hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
  400. patch_attention_mask = patch_attention_mask.view(batch_size, -1)
  401. # Create the correct attention mask based on the attention implementation
  402. patch_attention_mask = create_bidirectional_mask(
  403. config=self.config,
  404. inputs_embeds=hidden_states,
  405. attention_mask=patch_attention_mask,
  406. )
  407. encoder_outputs: BaseModelOutput = self.encoder(
  408. inputs_embeds=hidden_states,
  409. attention_mask=patch_attention_mask,
  410. )
  411. last_hidden_state = encoder_outputs.last_hidden_state
  412. last_hidden_state = self.post_layernorm(last_hidden_state)
  413. return BaseModelOutput(
  414. last_hidden_state=last_hidden_state,
  415. )
  416. @auto_docstring(
  417. custom_intro="""
  418. Idefics3 model consisting of a SIGLIP vision encoder and Llama3 language decoder
  419. """
  420. )
  421. class Idefics3Model(Idefics3PreTrainedModel):
  422. def __init__(self, config: Idefics3Config):
  423. super().__init__(config)
  424. self.padding_idx = self.config.text_config.pad_token_id
  425. self.vocab_size = self.config.text_config.vocab_size
  426. self.vision_model = Idefics3VisionTransformer._from_config(config.vision_config)
  427. self.connector = Idefics3Connector(config)
  428. self.text_model = AutoModel.from_config(config.text_config)
  429. self.image_seq_len = int(
  430. ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2)
  431. )
  432. self.image_token_id = self.config.image_token_id
  433. self.post_init()
  434. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.get_input_embeddings
  435. def get_input_embeddings(self):
  436. return self.text_model.get_input_embeddings()
  437. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.set_input_embeddings
  438. def set_input_embeddings(self, value):
  439. self.text_model.set_input_embeddings(value)
  440. def inputs_merger(
  441. self,
  442. input_ids: torch.LongTensor,
  443. inputs_embeds: torch.Tensor | None,
  444. image_hidden_states: torch.Tensor | None,
  445. ):
  446. """
  447. This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
  448. The merging happens as follows:
  449. - The text token sequence is: `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
  450. - We get the image hidden states for the image through the vision encoder and that hidden state, after a pixel shuffle operation, is then projected into the text embedding space.
  451. We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
  452. - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
  453. - To fit the format of that sequence, `input_ids`, `inputs_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
  454. """
  455. if input_ids is None:
  456. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  457. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  458. )
  459. special_image_mask = special_image_mask.all(-1)
  460. else:
  461. special_image_mask = input_ids == self.config.image_token_id
  462. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  463. image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
  464. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_hidden_states)
  465. return inputs_embeds
  466. @can_return_tuple
  467. @auto_docstring
  468. def get_image_features(
  469. self,
  470. pixel_values: torch.FloatTensor,
  471. pixel_attention_mask: torch.LongTensor | None = None,
  472. **kwargs: Unpack[TransformersKwargs],
  473. ) -> tuple | BaseModelOutputWithPooling:
  474. r"""
  475. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  476. The tensors corresponding to the input images.
  477. pixel_attention_mask (`torch.LongTensor`, *optional*):
  478. The attention mask indicating padded regions in the image.
  479. """
  480. batch_size, num_images, num_channels, height, width = pixel_values.shape
  481. pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
  482. pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
  483. # Remove padding images - padding images are full 0.
  484. nb_values_per_image = pixel_values.shape[1:].numel()
  485. real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
  486. pixel_values = pixel_values[real_images_inds].contiguous()
  487. # Handle the vision attention mask
  488. if pixel_attention_mask is None:
  489. pixel_attention_mask = torch.ones(
  490. size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
  491. dtype=torch.bool,
  492. device=pixel_values.device,
  493. )
  494. else:
  495. # Remove padding images from the mask
  496. pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
  497. pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
  498. patch_size = self.config.vision_config.patch_size
  499. patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
  500. patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
  501. patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
  502. # Get sequence from the vision encoder
  503. image_outputs = self.vision_model(
  504. pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, return_dict=True, **kwargs
  505. )
  506. image_hidden_states = image_outputs.last_hidden_state
  507. # Modality projection & resampling
  508. image_features = self.connector(image_hidden_states)
  509. image_outputs.pooler_output = image_features
  510. return image_outputs
  511. @can_return_tuple
  512. @auto_docstring(
  513. custom_intro="""
  514. Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
  515. the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
  516. max_num_images is the maximum number of images among the batch_size samples in the batch.
  517. Padding images are not needed beyond padding the pixel_values at the entrance of the model.
  518. For efficiency, we only pass through the vision_model's forward the real images by
  519. discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
  520. image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
  521. """
  522. )
  523. def forward(
  524. self,
  525. input_ids: torch.LongTensor | None = None,
  526. attention_mask: torch.Tensor | None = None,
  527. position_ids: torch.LongTensor | None = None,
  528. past_key_values: Cache | None = None,
  529. inputs_embeds: torch.FloatTensor | None = None,
  530. pixel_values: torch.FloatTensor | None = None,
  531. pixel_attention_mask: torch.BoolTensor | None = None,
  532. image_hidden_states: torch.FloatTensor | None = None,
  533. use_cache: bool | None = None,
  534. **kwargs: Unpack[FlashAttentionKwargs],
  535. ) -> tuple | Idefics3BaseModelOutputWithPast:
  536. r"""
  537. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  538. Mask to avoid performing attention on padding pixel indices.
  539. image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  540. The hidden states of the image encoder after modality projection.
  541. """
  542. if self.training and self.text_model.gradient_checkpointing and use_cache:
  543. logger.warning_once(
  544. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  545. )
  546. use_cache = False
  547. # retrieve input_ids and inputs_embeds
  548. if input_ids is not None:
  549. batch_size, seq_length = input_ids.shape
  550. elif inputs_embeds is not None:
  551. batch_size, seq_length, _ = inputs_embeds.shape
  552. else:
  553. raise ValueError("You have to specify either input_ids or inputs_embeds")
  554. if use_cache and past_key_values is None:
  555. past_key_values = DynamicCache(config=self.config)
  556. if inputs_embeds is None:
  557. inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device)
  558. # START VISUAL INPUTS INTEGRATION
  559. if pixel_values is not None and image_hidden_states is not None:
  560. raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
  561. elif pixel_values is not None:
  562. image_hidden_states = self.get_image_features(
  563. pixel_values, pixel_attention_mask, return_dict=True
  564. ).pooler_output
  565. elif image_hidden_states is not None:
  566. image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
  567. if image_hidden_states is not None:
  568. # When we generate, we don't want to replace the potential image_token_id that we generated by images
  569. # that simply don't exist
  570. inputs_embeds = self.inputs_merger(
  571. input_ids=input_ids,
  572. inputs_embeds=inputs_embeds,
  573. image_hidden_states=image_hidden_states,
  574. )
  575. outputs = self.text_model(
  576. inputs_embeds=inputs_embeds,
  577. attention_mask=attention_mask,
  578. position_ids=position_ids,
  579. past_key_values=past_key_values,
  580. use_cache=use_cache,
  581. **kwargs,
  582. )
  583. return Idefics3BaseModelOutputWithPast(
  584. last_hidden_state=outputs.last_hidden_state,
  585. past_key_values=outputs.past_key_values,
  586. hidden_states=outputs.hidden_states,
  587. attentions=outputs.attentions,
  588. image_hidden_states=image_hidden_states,
  589. )
  590. @auto_docstring(
  591. custom_intro="""
  592. The Idefics3 Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top.
  593. """
  594. )
  595. class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin):
  596. _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"}
  597. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.__init__ with Idefics2->Idefics3
  598. def __init__(self, config):
  599. super().__init__(config)
  600. self.model = Idefics3Model(config)
  601. self.image_token_id = self.config.image_token_id
  602. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  603. self.vocab_size = config.text_config.vocab_size
  604. # Initialize weights and apply final processing
  605. self.post_init()
  606. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.get_input_embeddings
  607. def get_input_embeddings(self):
  608. return self.model.text_model.get_input_embeddings()
  609. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.set_input_embeddings
  610. def set_input_embeddings(self, value):
  611. self.model.text_model.set_input_embeddings(value)
  612. @auto_docstring
  613. def get_image_features(
  614. self,
  615. pixel_values: torch.FloatTensor,
  616. pixel_attention_mask: torch.LongTensor | None = None,
  617. **kwargs: Unpack[TransformersKwargs],
  618. ) -> tuple | BaseModelOutputWithPooling:
  619. r"""
  620. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  621. The tensors corresponding to the input images.
  622. pixel_attention_mask (`torch.LongTensor`, *optional*):
  623. The attention mask indicating padded regions in the image.
  624. """
  625. return self.model.get_image_features(
  626. pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask, **kwargs
  627. )
  628. @can_return_tuple
  629. @auto_docstring
  630. def forward(
  631. self,
  632. input_ids: torch.LongTensor | None = None,
  633. attention_mask: torch.Tensor | None = None,
  634. position_ids: torch.LongTensor | None = None,
  635. past_key_values: Cache | None = None,
  636. inputs_embeds: torch.FloatTensor | None = None,
  637. pixel_values: torch.FloatTensor | None = None,
  638. pixel_attention_mask: torch.BoolTensor | None = None,
  639. image_hidden_states: torch.FloatTensor | None = None,
  640. labels: torch.LongTensor | None = None,
  641. use_cache: bool | None = None,
  642. logits_to_keep: int | torch.Tensor = 0,
  643. **kwargs: Unpack[TransformersKwargs],
  644. ) -> tuple | Idefics3CausalLMOutputWithPast:
  645. r"""
  646. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  647. Mask to avoid performing attention on padding pixel indices.
  648. image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  649. The hidden states of the image encoder after modality projection.
  650. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  651. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  652. config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`).
  653. Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
  654. computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  655. Example:
  656. ```python
  657. >>> import torch
  658. >>> from PIL import Image
  659. >>> from io import BytesIO
  660. >>> from transformers import AutoProcessor, AutoModelForImageTextToText
  661. >>> from transformers.image_utils import load_image
  662. >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
  663. >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
  664. >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
  665. >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
  666. >>> processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
  667. >>> model = AutoModelForImageTextToText.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3", dtype=torch.bfloat16, device_map="auto")
  668. >>> # Create inputs
  669. >>> messages = [
  670. ... {
  671. ... "role": "user",
  672. ... "content": [
  673. ... {"type": "image"},
  674. ... {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."},
  675. ... {"type": "image"},
  676. ... {"type": "text", "text": "What can we see in this image?"},
  677. ... ]
  678. ... },
  679. ... {
  680. ... "role": "user",
  681. ... "content": [
  682. ... {"type": "image"},
  683. ... {"type": "text", "text": "In which city is that bridge located?"},
  684. ... ]
  685. ... }
  686. ... ]
  687. >>> prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages]
  688. >>> images = [[image1, image2], [image3]]
  689. >>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(model.device)
  690. >>> # Generate
  691. >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
  692. >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
  693. >>> print(generated_texts[0])
  694. Assistant: There are buildings, trees, lights, and water visible in this image.
  695. >>> print(generated_texts[1])
  696. Assistant: The bridge is in San Francisco.
  697. ```"""
  698. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  699. outputs = self.model(
  700. input_ids=input_ids,
  701. attention_mask=attention_mask,
  702. position_ids=position_ids,
  703. past_key_values=past_key_values,
  704. inputs_embeds=inputs_embeds,
  705. pixel_values=pixel_values,
  706. pixel_attention_mask=pixel_attention_mask,
  707. image_hidden_states=image_hidden_states,
  708. use_cache=use_cache,
  709. return_dict=True,
  710. **kwargs,
  711. )
  712. hidden_states = outputs[0]
  713. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  714. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  715. logits = self.lm_head(hidden_states[:, slice_indices, :])
  716. loss = None
  717. if labels is not None:
  718. loss = self.loss_function(
  719. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  720. )
  721. return Idefics3CausalLMOutputWithPast(
  722. loss=loss,
  723. logits=logits,
  724. past_key_values=outputs.past_key_values,
  725. hidden_states=outputs.hidden_states,
  726. attentions=outputs.attentions,
  727. image_hidden_states=outputs.image_hidden_states,
  728. )
  729. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.prepare_inputs_for_generation
  730. def prepare_inputs_for_generation(
  731. self,
  732. input_ids,
  733. past_key_values=None,
  734. attention_mask=None,
  735. inputs_embeds=None,
  736. pixel_values=None,
  737. pixel_attention_mask=None,
  738. image_hidden_states=None,
  739. logits_to_keep=None,
  740. is_first_iteration=False,
  741. use_cache=False,
  742. **kwargs,
  743. ):
  744. # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take
  745. # precedence is moved to the model, we can remove this fn)
  746. model_inputs = super().prepare_inputs_for_generation(
  747. input_ids,
  748. past_key_values=past_key_values,
  749. attention_mask=attention_mask,
  750. inputs_embeds=inputs_embeds,
  751. pixel_values=pixel_values,
  752. pixel_attention_mask=pixel_attention_mask,
  753. image_hidden_states=image_hidden_states,
  754. logits_to_keep=logits_to_keep,
  755. is_first_iteration=is_first_iteration,
  756. use_cache=use_cache,
  757. **kwargs,
  758. )
  759. if image_hidden_states is not None or (use_cache and not is_first_iteration):
  760. model_inputs["pixel_values"] = None
  761. model_inputs["pixel_attention_mask"] = None
  762. return model_inputs
  763. __all__ = ["Idefics3ForConditionalGeneration", "Idefics3PreTrainedModel", "Idefics3Model", "Idefics3VisionTransformer"]