modeling_idefics2.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144
  1. # Copyright 2024 the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Idefics2 model."""
  15. from collections.abc import Callable
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from ... import initialization as init
  20. from ...activations import ACT2FN
  21. from ...cache_utils import Cache, DynamicCache
  22. from ...generation import GenerationMixin
  23. from ...masking_utils import create_bidirectional_mask
  24. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
  27. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  28. from ...processing_utils import Unpack
  29. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  30. from ...utils.generic import merge_with_config_defaults
  31. from ...utils.output_capturing import capture_outputs
  32. from ..auto import AutoModel
  33. from .configuration_idefics2 import Idefics2Config, Idefics2PerceiverConfig, Idefics2VisionConfig
  34. logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Idefics2 model's outputs that may also contain a past key/values (to speed up sequential decoding).
    """
)
class Idefics2BaseModelOutputWithPast(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.

        If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
        hidden_size)` is output.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
        `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
        input) to speed up sequential decoding.
    image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size,
        num_images, sequence_length, hidden_size)`.

        image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
    """

    # Every field defaults to `None`, so an output can be built with any subset of them.
    # NOTE(review): `hidden_states` and `attentions` are not described in the docstring above;
    # presumably `auto_docstring` supplies the standard descriptions -- confirm.
    last_hidden_state: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    image_hidden_states: tuple[torch.FloatTensor] | None = None
# Output container for the Idefics2 causal-LM head: an optional loss, the logits, the KV cache
# and the optional hidden-state / attention / image-hidden-state tuples. Every field defaults
# to `None`.
# NOTE(review): the class body below is marked `# Copied from`, so it is presumably kept in sync
# with the Idefics original by repository tooling -- confirm before editing it here. The
# `image_hidden_states` description has an unbalanced parenthesis that would need to be fixed at
# the source of the copy.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Idefics2 causal language model (or autoregressive) outputs.
    """
)
# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Idefics2
class Idefics2CausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
        sequence_length, hidden_size)`.

        image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    image_hidden_states: tuple[torch.FloatTensor] | None = None
  90. class Idefics2VisionEmbeddings(nn.Module):
  91. """
  92. This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable
  93. resolution.
  94. The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://huggingface.co/papers/2307.06304)
  95. which allows treating images in their native aspect ratio and without the need to resize them to the same
  96. fixed size. In particular, we start from the original pre-trained SigLIP model
  97. (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.
  98. """
  99. def __init__(self, config: Idefics2VisionConfig):
  100. super().__init__()
  101. self.embed_dim = config.hidden_size
  102. self.image_size = config.image_size
  103. self.patch_size = config.patch_size
  104. self.patch_embedding = nn.Conv2d(
  105. in_channels=config.num_channels,
  106. out_channels=self.embed_dim,
  107. kernel_size=self.patch_size,
  108. stride=self.patch_size,
  109. padding="valid",
  110. )
  111. self.num_patches_per_side = self.image_size // self.patch_size
  112. self.num_patches = self.num_patches_per_side**2
  113. self.num_positions = self.num_patches
  114. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  115. def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
  116. batch_size, _, max_im_h, max_im_w = pixel_values.shape
  117. patch_embeds = self.patch_embedding(pixel_values)
  118. embeddings = patch_embeds.flatten(2).transpose(1, 2)
  119. max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
  120. boundaries = torch.arange(
  121. 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device
  122. )
  123. position_ids = torch.full(
  124. size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device
  125. )
  126. nb_patches_h = patch_attention_mask[:, :, 0].sum(dim=1) # (batch_size,)
  127. nb_patches_w = patch_attention_mask[:, 0, :].sum(dim=1) # (batch_size,)
  128. step_h = 1.0 / nb_patches_h # (batch_size,)
  129. step_w = 1.0 / nb_patches_w # (batch_size,)
  130. max_patches_h = patch_attention_mask.size(1)
  131. max_patches_w = patch_attention_mask.size(2)
  132. h_indices = torch.arange(max_patches_h, device=position_ids.device, dtype=torch.float32)
  133. w_indices = torch.arange(max_patches_w, device=position_ids.device, dtype=torch.float32)
  134. fractional_coords_h = h_indices[None, :] * step_h[:, None]
  135. fractional_coords_w = w_indices[None, :] * step_w[:, None]
  136. fractional_coords_h = torch.clamp(fractional_coords_h, max=(1.0 - 1e-6))
  137. fractional_coords_w = torch.clamp(fractional_coords_w, max=(1.0 - 1e-6))
  138. fractional_coords_h = fractional_coords_h.to(pixel_values.dtype)
  139. fractional_coords_w = fractional_coords_w.to(pixel_values.dtype)
  140. bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
  141. bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
  142. pos_ids = bucket_coords_h[:, :, None] * self.num_patches_per_side + bucket_coords_w[:, None, :]
  143. pos_ids = pos_ids.reshape(batch_size, -1)
  144. position_ids[patch_attention_mask.view(batch_size, -1)] = pos_ids[patch_attention_mask.view(batch_size, -1)]
  145. embeddings = embeddings + self.position_embedding(position_ids)
  146. return embeddings
  147. def eager_attention_forward(
  148. module: nn.Module,
  149. query: torch.Tensor,
  150. key: torch.Tensor,
  151. value: torch.Tensor,
  152. attention_mask: torch.Tensor | None,
  153. scaling: float,
  154. dropout: float = 0.0,
  155. **kwargs: Unpack[TransformersKwargs],
  156. ):
  157. if hasattr(module, "num_key_value_groups"):
  158. key = repeat_kv(key, module.num_key_value_groups)
  159. value = repeat_kv(value, module.num_key_value_groups)
  160. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  161. if attention_mask is not None:
  162. attn_weights = attn_weights + attention_mask
  163. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  164. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  165. attn_output = torch.matmul(attn_weights, value)
  166. attn_output = attn_output.transpose(1, 2).contiguous()
  167. return attn_output, attn_weights
  168. # Copied from transformers.models.siglip.modeling_siglip.SiglipAttention with Siglip->Idefics2Vision
  169. class Idefics2VisionAttention(nn.Module):
  170. """Multi-headed attention from 'Attention Is All You Need' paper"""
  171. # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
  172. def __init__(self, config):
  173. super().__init__()
  174. self.config = config
  175. self.embed_dim = config.hidden_size
  176. self.num_heads = config.num_attention_heads
  177. self.head_dim = self.embed_dim // self.num_heads
  178. if self.head_dim * self.num_heads != self.embed_dim:
  179. raise ValueError(
  180. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  181. f" {self.num_heads})."
  182. )
  183. self.scale = self.head_dim**-0.5
  184. self.dropout = config.attention_dropout
  185. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  186. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  187. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  188. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  189. # Ignore copy
  190. self.is_causal = False
  191. def forward(
  192. self,
  193. hidden_states: torch.Tensor,
  194. attention_mask: torch.Tensor | None = None,
  195. **kwargs,
  196. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  197. """Input shape: Batch x Time x Channel"""
  198. input_shape = hidden_states.shape[:-1]
  199. hidden_shape = (*input_shape, -1, self.head_dim)
  200. queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  201. keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  202. values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  203. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  204. self.config._attn_implementation, eager_attention_forward
  205. )
  206. attn_output, attn_weights = attention_interface(
  207. self,
  208. queries,
  209. keys,
  210. values,
  211. attention_mask,
  212. is_causal=self.is_causal,
  213. scaling=self.scale,
  214. dropout=0.0 if not self.training else self.dropout,
  215. )
  216. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  217. attn_output = self.out_proj(attn_output)
  218. return attn_output, attn_weights
  219. # Copied from transformers.models.siglip.modeling_siglip.SiglipMLP with Siglip->Idefics2Vision
  220. class Idefics2VisionMLP(nn.Module):
  221. def __init__(self, config):
  222. super().__init__()
  223. self.config = config
  224. self.activation_fn = ACT2FN[config.hidden_act]
  225. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  226. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  227. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  228. hidden_states = self.fc1(hidden_states)
  229. hidden_states = self.activation_fn(hidden_states)
  230. hidden_states = self.fc2(hidden_states)
  231. return hidden_states
  232. class Idefics2MLP(nn.Module):
  233. def __init__(
  234. self,
  235. hidden_size: int,
  236. intermediate_size: int,
  237. output_size: int,
  238. hidden_act: str,
  239. ):
  240. super().__init__()
  241. self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
  242. self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
  243. self.down_proj = nn.Linear(intermediate_size, output_size, bias=False)
  244. self.act_fn = ACT2FN[hidden_act]
  245. def forward(self, x):
  246. return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  247. # Copied from transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead with Siglip->Idefics2
  248. class Idefics2MultiheadAttentionPoolingHead(nn.Module):
  249. """Multihead Attention Pooling."""
  250. def __init__(self, config: Idefics2VisionConfig):
  251. super().__init__()
  252. self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
  253. self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
  254. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  255. # Ignore copy
  256. self.mlp = Idefics2MLP(
  257. hidden_size=config.hidden_size,
  258. intermediate_size=config.intermediate_size,
  259. hidden_act=config.hidden_act,
  260. output_size=config.hidden_size,
  261. )
  262. def forward(self, hidden_state):
  263. batch_size = hidden_state.shape[0]
  264. probe = self.probe.repeat(batch_size, 1, 1)
  265. hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
  266. residual = hidden_state
  267. hidden_state = self.layernorm(hidden_state)
  268. hidden_state = residual + self.mlp(hidden_state)
  269. return hidden_state[:, 0]
  270. class Idefics2EncoderLayer(GradientCheckpointingLayer):
  271. def __init__(self, config: Idefics2VisionConfig):
  272. super().__init__()
  273. self.embed_dim = config.hidden_size
  274. self.self_attn = Idefics2VisionAttention(config)
  275. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  276. self.mlp = Idefics2VisionMLP(config)
  277. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  278. @auto_docstring
  279. # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
  280. def forward(
  281. self,
  282. hidden_states: torch.Tensor,
  283. attention_mask: torch.Tensor,
  284. **kwargs: Unpack[TransformersKwargs],
  285. ) -> torch.FloatTensor:
  286. residual = hidden_states
  287. hidden_states = self.layer_norm1(hidden_states)
  288. hidden_states, _ = self.self_attn(
  289. hidden_states=hidden_states,
  290. attention_mask=attention_mask,
  291. **kwargs,
  292. )
  293. hidden_states = residual + hidden_states
  294. residual = hidden_states
  295. hidden_states = self.layer_norm2(hidden_states)
  296. hidden_states = self.mlp(hidden_states)
  297. hidden_states = residual + hidden_states
  298. return hidden_states
  299. # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoder with Siglip->Idefics2
  300. class Idefics2Encoder(nn.Module):
  301. """
  302. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  303. [`Idefics2EncoderLayer`].
  304. Args:
  305. config: Idefics2Config
  306. """
  307. def __init__(self, config: Idefics2Config):
  308. super().__init__()
  309. self.config = config
  310. self.layers = nn.ModuleList([Idefics2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
  311. self.gradient_checkpointing = False
  312. # Ignore copy
  313. @auto_docstring
  314. def forward(
  315. self,
  316. inputs_embeds,
  317. attention_mask: torch.Tensor | None = None,
  318. **kwargs: Unpack[TransformersKwargs],
  319. ) -> BaseModelOutput:
  320. hidden_states = inputs_embeds
  321. for encoder_layer in self.layers:
  322. hidden_states = encoder_layer(
  323. hidden_states,
  324. attention_mask,
  325. **kwargs,
  326. )
  327. return BaseModelOutput(last_hidden_state=hidden_states)
  328. @auto_docstring
  329. class Idefics2PreTrainedModel(PreTrainedModel):
  330. config: Idefics2Config
  331. base_model_prefix = "model"
  332. input_modalities = ("image", "text")
  333. supports_gradient_checkpointing = True
  334. _no_split_modules = ["Idefics2VisionAttention", "Idefics2MLP", "Idefics2PerceiverLayer", "Idefics2DecoderLayer"]
  335. _skip_keys_device_placement = "past_key_values"
  336. _supports_flash_attn = True
  337. _supports_sdpa = True
  338. _supports_flex_attn = True
  339. _supports_attention_backend = True
  340. @torch.no_grad()
  341. def _init_weights(self, module):
  342. super()._init_weights(module)
  343. if isinstance(module, Idefics2MultiheadAttentionPoolingHead):
  344. init.normal_(module.probe)
  345. elif isinstance(module, Idefics2PerceiverResampler):
  346. init.ones_(module.latents)
  347. @auto_docstring(
  348. custom_intro="""
  349. Idefics2 vision encoder model that returnss raw image embeddings.
  350. """
  351. )
  352. class Idefics2VisionTransformer(Idefics2PreTrainedModel):
  353. config: Idefics2VisionConfig
  354. input_modalities = ("image",)
  355. _can_record_outputs = {
  356. "hidden_states": Idefics2EncoderLayer,
  357. "attentions": Idefics2VisionAttention,
  358. }
  359. def __init__(self, config: Idefics2VisionConfig):
  360. super().__init__(config)
  361. embed_dim = config.hidden_size
  362. self.config = config
  363. self.embeddings = Idefics2VisionEmbeddings(config)
  364. self.encoder = Idefics2Encoder(config)
  365. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  366. self.post_init()
  367. def get_input_embeddings(self):
  368. return self.embeddings
  369. def set_input_embeddings(self, value):
  370. self.embeddings = value
  371. @merge_with_config_defaults
  372. @capture_outputs(tie_last_hidden_states=False)
  373. @auto_docstring
  374. def forward(
  375. self,
  376. pixel_values,
  377. patch_attention_mask: torch.BoolTensor | None = None,
  378. **kwargs: Unpack[TransformersKwargs],
  379. ) -> tuple | BaseModelOutput:
  380. r"""
  381. patch_attention_mask (`torch.BoolTensor` of shape `(batch_size, num_patches_height, num_patches_width)`, *optional*):
  382. The attention mask for the patches.
  383. """
  384. batch_size = pixel_values.size(0)
  385. if patch_attention_mask is None:
  386. patch_size = self.config.patch_size
  387. patch_attention_mask = torch.ones(
  388. (
  389. batch_size,
  390. pixel_values.size(2) // patch_size,
  391. pixel_values.size(3) // patch_size,
  392. )
  393. )
  394. patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device)
  395. hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
  396. patch_attention_mask = patch_attention_mask.view(batch_size, -1)
  397. patch_attention_mask = create_bidirectional_mask(
  398. config=self.config,
  399. inputs_embeds=hidden_states,
  400. attention_mask=patch_attention_mask,
  401. )
  402. encoder_outputs: BaseModelOutput = self.encoder(
  403. inputs_embeds=hidden_states,
  404. attention_mask=patch_attention_mask,
  405. **kwargs,
  406. )
  407. last_hidden_state = encoder_outputs.last_hidden_state
  408. last_hidden_state = self.post_layernorm(last_hidden_state)
  409. return BaseModelOutput(last_hidden_state=last_hidden_state)
  410. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  411. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  412. """
  413. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  414. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  415. """
  416. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  417. if n_rep == 1:
  418. return hidden_states
  419. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  420. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  421. # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Idefics2
  422. class Idefics2RMSNorm(nn.Module):
  423. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  424. """
  425. Idefics2RMSNorm is equivalent to T5LayerNorm
  426. """
  427. super().__init__()
  428. self.weight = nn.Parameter(torch.ones(hidden_size))
  429. self.variance_epsilon = eps
  430. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  431. input_dtype = hidden_states.dtype
  432. hidden_states = hidden_states.to(torch.float32)
  433. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  434. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  435. return self.weight * hidden_states.to(input_dtype)
  436. def extra_repr(self):
  437. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  438. class Idefics2PerceiverAttention(nn.Module):
  439. def __init__(self, config, layer_idx: int | None = None) -> None:
  440. """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
  441. super().__init__()
  442. self.config = config
  443. self.layer_idx = None
  444. self.hidden_size = config.hidden_size
  445. self.num_heads = config.resampler_n_heads
  446. self.head_dim = config.resampler_head_dim
  447. self.num_key_value_heads = config.num_key_value_heads
  448. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  449. self.attention_dropout = config.attention_dropout
  450. self.scaling = self.head_dim**-0.5
  451. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
  452. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  453. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  454. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
  455. self.is_causal = False
  456. def forward(
  457. self,
  458. latents: torch.Tensor,
  459. context: torch.Tensor,
  460. attention_mask: torch.Tensor | None = None,
  461. position_ids: torch.LongTensor | None = None,
  462. past_key_values: Cache | None = None,
  463. **kwargs: Unpack[TransformersKwargs],
  464. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  465. """
  466. Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!
  467. Args:
  468. latents (`torch.Tensor`): Tensor of shape [bsz, n_latents, embed_dim] representing fixed length latents to compress to.
  469. context (`torch.Tensor`): Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample.
  470. attention_mask (`torch.Tensor`, *optional*): Tensor of shape [bsz, 1, seq, n_latents] representing attention mask.
  471. position_ids (`torch.LongTensor`, *optional*): Tensor of shape [bsz, seq] representing position indices of each input token.
  472. past_key_values (`Cache`, *optional*): Tuple of tensors containing cached key and value states.
  473. output_attentions (`bool`, *optional*, defaults to `False`): Whether to return attention weights.
  474. use_cache (`bool`, *optional*, defaults to `False`): Whether to use past_key_values for caching.
  475. """
  476. bsz, q_len, _ = latents.size()
  477. kv_seq_len = q_len + context.size()[1]
  478. hidden_states = torch.concat([context, latents], dim=-2)
  479. queries = self.q_proj(latents)
  480. keys = self.k_proj(hidden_states)
  481. values = self.v_proj(hidden_states)
  482. queries = queries.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  483. keys = keys.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  484. values = values.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  485. past_key_values = getattr(self, "past_key_values", past_key_values)
  486. if past_key_values is not None:
  487. keys, values = past_key_values.update(keys, values, self.layer_idx)
  488. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  489. self.config._attn_implementation, eager_attention_forward
  490. )
  491. attn_output, attn_weights = attention_interface(
  492. self,
  493. queries,
  494. keys,
  495. values,
  496. attention_mask,
  497. is_causal=self.is_causal,
  498. scaling=self.scaling,
  499. dropout=0.0 if not self.training else self.attention_dropout,
  500. **kwargs,
  501. )
  502. attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
  503. attn_output = self.o_proj(attn_output)
  504. return attn_output, attn_weights
  505. class Idefics2PerceiverLayer(nn.Module):
  506. def __init__(self, config, layer_idx: int):
  507. super().__init__()
  508. self.hidden_size = config.hidden_size
  509. self.n_latents = config.resampler_n_latents
  510. self.depth = config.resampler_depth
  511. self.rms_norm_eps = config.rms_norm_eps
  512. self.input_latents_norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
  513. self.input_context_norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
  514. self.self_attn = Idefics2PerceiverAttention(config, layer_idx=layer_idx)
  515. self.post_attention_layernorm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
  516. self.mlp = Idefics2MLP(
  517. hidden_size=config.hidden_size,
  518. intermediate_size=config.hidden_size * 4,
  519. output_size=config.hidden_size,
  520. hidden_act=config.hidden_act,
  521. )
  522. def forward(
  523. self,
  524. latents: torch.Tensor,
  525. context: torch.Tensor,
  526. attention_mask: torch.Tensor | None = None,
  527. position_ids: torch.LongTensor | None = None,
  528. past_key_values: Cache | None = None,
  529. **kwargs: Unpack[TransformersKwargs],
  530. ) -> torch.FloatTensor:
  531. """
  532. Args:
  533. latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  534. context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  535. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
  536. `(batch, sequence_length)` where padding elements are indicated by 0.
  537. output_attentions (`bool`, *optional*):
  538. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  539. returned tensors for more detail.
  540. use_cache (`bool`, *optional*):
  541. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  542. (see `past_key_values`).
  543. past_key_values (`Cache`, *optional*): cached past key and value projection states
  544. """
  545. residual = latents
  546. latents = self.input_latents_norm(latents)
  547. context = self.input_context_norm(context)
  548. latents, _ = self.self_attn(
  549. latents=latents,
  550. context=context,
  551. attention_mask=attention_mask,
  552. **kwargs,
  553. )
  554. latents = residual + latents
  555. residual = latents
  556. latents = self.post_attention_layernorm(latents)
  557. latents = self.mlp(latents)
  558. latents = residual + latents
  559. return latents
  560. @auto_docstring(
  561. custom_intro="""
  562. Idefics2 perceiver resampler model that performs `depth` blocks of cross-attention with a fixed
  563. """
  564. )
  565. class Idefics2PerceiverResampler(Idefics2PreTrainedModel):
  566. config: Idefics2PerceiverConfig
  567. input_modalities = ("image",)
  568. _supports_sdpa = True
  569. _supports_flash_attn = True
  570. _supports_flex_attn = True
  571. def __init__(self, config) -> None:
  572. super().__init__(config)
  573. self.hidden_size = config.hidden_size
  574. self.hidden_act = config.hidden_act
  575. self.n_latents = config.resampler_n_latents
  576. self.depth = config.resampler_depth
  577. self.rms_norm_eps = config.rms_norm_eps
  578. # Create Latents for Perceiver
  579. self.latents = nn.Parameter(torch.ones(self.n_latents, self.hidden_size))
  580. # Create Transformer Blocks
  581. self.layers = nn.ModuleList([Idefics2PerceiverLayer(config, idx) for idx in range(self.depth)])
  582. self.norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
  583. self.post_init()
  584. @auto_docstring
  585. def forward(
  586. self,
  587. context: torch.Tensor,
  588. attention_mask: torch.Tensor,
  589. **kwargs: Unpack[TransformersKwargs],
  590. ) -> torch.Tensor:
  591. r"""
  592. context (`torch.FloatTensor` of shape `(batch, seq_len, embed_dim)`):
  593. Input to the layer.
  594. """
  595. # seq embed -> bsz seq embed
  596. latents = self.latents.unsqueeze(0).expand((context.shape[0], *self.latents.size()))
  597. latent_attention_mask = torch.ones(
  598. (attention_mask.size(0), latents.size(1)), dtype=attention_mask.dtype, device=attention_mask.device
  599. )
  600. attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1)
  601. attention_mask = create_bidirectional_mask(
  602. config=self.config,
  603. inputs_embeds=latents,
  604. attention_mask=attention_mask,
  605. )
  606. compressed_context = latents
  607. for perceiver_layer in self.layers:
  608. compressed_context = perceiver_layer(
  609. compressed_context,
  610. context,
  611. attention_mask=attention_mask,
  612. position_ids=None,
  613. **kwargs,
  614. )
  615. compressed_context = self.norm(compressed_context)
  616. return compressed_context
  617. class Idefics2Connector(nn.Module):
  618. def __init__(self, config):
  619. super().__init__()
  620. self.modality_projection = Idefics2MLP(
  621. hidden_size=config.vision_config.hidden_size,
  622. intermediate_size=config.text_config.intermediate_size,
  623. output_size=config.text_config.hidden_size,
  624. hidden_act=config.text_config.hidden_act,
  625. )
  626. self.perceiver_resampler = Idefics2PerceiverResampler._from_config(config.perceiver_config)
  627. def forward(self, image_hidden_states, attention_mask):
  628. image_hidden_states = self.modality_projection(image_hidden_states)
  629. image_hidden_states = self.perceiver_resampler(context=image_hidden_states, attention_mask=attention_mask)
  630. return image_hidden_states
  631. @auto_docstring(
  632. custom_intro="""
  633. Idefics2 model consisting of a SIGLIP vision encoder and Mistral language decoder
  634. """
  635. )
  636. class Idefics2Model(Idefics2PreTrainedModel):
  637. def __init__(self, config: Idefics2Config):
  638. super().__init__(config)
  639. self.padding_idx = self.config.text_config.pad_token_id
  640. self.vocab_size = self.config.text_config.vocab_size
  641. self.vision_model = Idefics2VisionTransformer._from_config(config.vision_config)
  642. self.connector = Idefics2Connector(config)
  643. self.text_model = AutoModel.from_config(config.text_config)
  644. self.image_seq_len = config.perceiver_config.resampler_n_latents
  645. self.image_token_id = self.config.image_token_id
  646. self.post_init()
  647. def get_input_embeddings(self):
  648. return self.text_model.get_input_embeddings()
  649. def set_input_embeddings(self, value):
  650. self.text_model.set_input_embeddings(value)
  651. def inputs_merger(
  652. self,
  653. input_ids: torch.LongTensor,
  654. inputs_embeds: torch.Tensor | None,
  655. image_hidden_states: torch.Tensor | None,
  656. ):
  657. """
  658. This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
  659. The merging happens as follows:
  660. - The text token sequence is: `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
  661. - We get the image hidden states for the image through the vision encoder (and potentially the perceiver), and that hidden state is then projected into the text embedding space.
  662. We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
  663. - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
  664. - To fit the format of that sequence, `input_ids`, `inputs_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
  665. """
  666. if input_ids is None:
  667. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  668. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  669. )
  670. special_image_mask = special_image_mask.all(-1)
  671. else:
  672. special_image_mask = input_ids == self.config.image_token_id
  673. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  674. image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
  675. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_hidden_states)
  676. return inputs_embeds
  677. @can_return_tuple
  678. @auto_docstring
  679. def get_image_features(
  680. self,
  681. pixel_values: torch.FloatTensor,
  682. pixel_attention_mask: torch.LongTensor | None = None,
  683. **kwargs: Unpack[TransformersKwargs],
  684. ) -> tuple | BaseModelOutputWithPooling:
  685. r"""
  686. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  687. The tensors corresponding to the input images.
  688. pixel_attention_mask (`torch.LongTensor`, *optional*):
  689. The attention mask indicating padded regions in the image.
  690. """
  691. batch_size, num_images, num_channels, height, width = pixel_values.shape
  692. pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
  693. pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
  694. # Remove padding images - padding images are full 0.
  695. nb_values_per_image = pixel_values.shape[1:].numel()
  696. real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
  697. pixel_values = pixel_values[real_images_inds].contiguous()
  698. # Handle the vision attention mask
  699. if pixel_attention_mask is None:
  700. pixel_attention_mask = torch.ones(
  701. size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
  702. dtype=torch.bool,
  703. device=pixel_values.device,
  704. )
  705. else:
  706. # Remove padding images from the mask/pP p
  707. pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
  708. pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
  709. patch_size = self.config.vision_config.patch_size
  710. patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
  711. patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
  712. patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) == patch_size * patch_size).bool()
  713. # Get sequence from the vision encoder
  714. image_outputs = self.vision_model(
  715. pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, **kwargs
  716. )
  717. image_hidden_states = image_outputs.last_hidden_state
  718. # Modality projection & resampling
  719. image_features = self.connector(
  720. image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
  721. )
  722. image_outputs.pooler_output = image_features.view(-1, image_features.shape[-1])
  723. return image_outputs
  724. @can_return_tuple
  725. @auto_docstring(
  726. custom_intro="""
  727. Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
  728. the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
  729. max_num_images is the maximum number of images among the batch_size samples in the batch.
  730. Padding images are not needed beyond padding the pixel_values at the entrance of the model.
  731. For efficiency, we only pass through the vision_model's forward the real images by
  732. discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
  733. image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
  734. """
  735. )
  736. def forward(
  737. self,
  738. input_ids: torch.LongTensor | None = None,
  739. attention_mask: torch.Tensor | None = None,
  740. position_ids: torch.LongTensor | None = None,
  741. past_key_values: Cache | None = None,
  742. inputs_embeds: torch.FloatTensor | None = None,
  743. pixel_values: torch.FloatTensor | None = None,
  744. pixel_attention_mask: torch.BoolTensor | None = None,
  745. image_hidden_states: torch.FloatTensor | None = None,
  746. use_cache: bool | None = None,
  747. **kwargs: Unpack[FlashAttentionKwargs],
  748. ) -> tuple | Idefics2BaseModelOutputWithPast:
  749. r"""
  750. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  751. Mask to avoid performing attention on padding pixel indices.
  752. image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  753. The hidden states of the image encoder after modality projection and perceiver resampling.
  754. """
  755. if self.training and self.text_model.gradient_checkpointing and use_cache:
  756. logger.warning_once(
  757. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  758. )
  759. use_cache = False
  760. # retrieve input_ids and inputs_embeds
  761. if input_ids is not None:
  762. batch_size, seq_length = input_ids.shape
  763. elif inputs_embeds is not None:
  764. batch_size, seq_length, _ = inputs_embeds.shape
  765. else:
  766. raise ValueError("You have to specify either input_ids or inputs_embeds")
  767. if use_cache and past_key_values is None:
  768. past_key_values = DynamicCache(config=self.config)
  769. if inputs_embeds is None:
  770. inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
  771. # START VISUAL INPUTS INTEGRATION
  772. if pixel_values is not None and image_hidden_states is not None:
  773. raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
  774. elif pixel_values is not None:
  775. image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask, **kwargs).pooler_output
  776. elif image_hidden_states is not None:
  777. image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
  778. if image_hidden_states is not None:
  779. # When we generate, we don't want to replace the potential image_token_id that we generated by images
  780. # that simply don't exist
  781. inputs_embeds = self.inputs_merger(
  782. input_ids=input_ids,
  783. inputs_embeds=inputs_embeds,
  784. image_hidden_states=image_hidden_states,
  785. )
  786. kwargs["return_dict"] = True
  787. outputs = self.text_model(
  788. inputs_embeds=inputs_embeds,
  789. attention_mask=attention_mask,
  790. position_ids=position_ids,
  791. past_key_values=past_key_values,
  792. use_cache=use_cache,
  793. **kwargs,
  794. )
  795. return Idefics2BaseModelOutputWithPast(
  796. last_hidden_state=outputs.last_hidden_state,
  797. past_key_values=outputs.past_key_values,
  798. hidden_states=outputs.hidden_states,
  799. attentions=outputs.attentions,
  800. image_hidden_states=image_hidden_states,
  801. )
  802. @auto_docstring(
  803. custom_intro="""
  804. The Idefics2 Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top.
  805. """
  806. )
  807. class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin):
  808. _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"}
  809. def __init__(self, config):
  810. super().__init__(config)
  811. self.model = Idefics2Model(config)
  812. self.image_token_id = self.config.image_token_id
  813. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  814. self.vocab_size = config.text_config.vocab_size
  815. # Initialize weights and apply final processing
  816. self.post_init()
  817. def get_input_embeddings(self):
  818. return self.model.text_model.get_input_embeddings()
  819. def set_input_embeddings(self, value):
  820. self.model.text_model.set_input_embeddings(value)
  821. @auto_docstring
  822. def get_image_features(
  823. self,
  824. pixel_values: torch.FloatTensor,
  825. pixel_attention_mask: torch.LongTensor | None = None,
  826. **kwargs: Unpack[TransformersKwargs],
  827. ) -> tuple | BaseModelOutputWithPooling:
  828. r"""
  829. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  830. The tensors corresponding to the input images.
  831. pixel_attention_mask (`torch.LongTensor`, *optional*):
  832. The attention mask indicating padded regions in the image.
  833. """
  834. return self.model.get_image_features(
  835. pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask, **kwargs
  836. )
  837. @can_return_tuple
  838. @auto_docstring
  839. def forward(
  840. self,
  841. input_ids: torch.LongTensor | None = None,
  842. attention_mask: torch.Tensor | None = None,
  843. position_ids: torch.LongTensor | None = None,
  844. past_key_values: Cache | None = None,
  845. inputs_embeds: torch.FloatTensor | None = None,
  846. pixel_values: torch.FloatTensor | None = None,
  847. pixel_attention_mask: torch.BoolTensor | None = None,
  848. image_hidden_states: torch.FloatTensor | None = None,
  849. labels: torch.LongTensor | None = None,
  850. use_cache: bool | None = None,
  851. logits_to_keep: int | torch.Tensor = 0,
  852. **kwargs: Unpack[TransformersKwargs],
  853. ) -> tuple | Idefics2CausalLMOutputWithPast:
  854. r"""
  855. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  856. Mask to avoid performing attention on padding pixel indices.
  857. image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  858. The hidden states of the image encoder after modality projection and perceiver resampling.
  859. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  860. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  861. config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics2ForConditionalGeneration`).
  862. Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
  863. computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  864. Example:
  865. ```python
  866. >>> import torch
  867. >>> from PIL import Image
  868. >>> from io import BytesIO
  869. >>> from transformers import AutoProcessor, AutoModelForImageTextToText
  870. >>> from transformers.image_utils import load_image
  871. >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
  872. >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
  873. >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
  874. >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
  875. >>> processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b-base")
  876. >>> model = AutoModelForImageTextToText.from_pretrained("HuggingFaceM4/idefics2-8b-base", device_map="auto")
  877. >>> BAD_WORDS_IDS = processor.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
  878. >>> EOS_WORDS_IDS = [processor.tokenizer.eos_token_id]
  879. >>> # Create inputs
  880. >>> prompts = [
  881. ... "<image>In this image, we can see the city of New York, and more specifically the Statue of Liberty.<image>In this image,",
  882. ... "In which city is that bridge located?<image>",
  883. ... ]
  884. >>> images = [[image1, image2], [image3]]
  885. >>> inputs = processor(images=images, text=prompts, padding=True, return_tensors="pt").to("cuda")
  886. >>> # Generate
  887. >>> generated_ids = model.generate(**inputs, bad_words_ids=BAD_WORDS_IDS, max_new_tokens=20)
  888. >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
  889. >>> print(generated_texts)
  890. ['In this image, we can see the city of New York, and more specifically the Statue of Liberty. In this image, we can see the city of New York, and more specifically the Statue of Liberty.\n\n', 'In which city is that bridge located?\n\nThe bridge is located in the city of Pittsburgh, Pennsylvania.\n\n\nThe bridge is']
  891. ```"""
  892. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  893. outputs = self.model(
  894. input_ids=input_ids,
  895. attention_mask=attention_mask,
  896. position_ids=position_ids,
  897. past_key_values=past_key_values,
  898. inputs_embeds=inputs_embeds,
  899. pixel_values=pixel_values,
  900. pixel_attention_mask=pixel_attention_mask,
  901. image_hidden_states=image_hidden_states,
  902. use_cache=use_cache,
  903. **kwargs,
  904. )
  905. hidden_states = outputs[0]
  906. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  907. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  908. logits = self.lm_head(hidden_states[:, slice_indices, :])
  909. loss = None
  910. if labels is not None:
  911. loss = self.loss_function(
  912. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  913. )
  914. return Idefics2CausalLMOutputWithPast(
  915. loss=loss,
  916. logits=logits,
  917. past_key_values=outputs.past_key_values,
  918. hidden_states=outputs.hidden_states,
  919. attentions=outputs.attentions,
  920. image_hidden_states=outputs.image_hidden_states,
  921. )
  922. def prepare_inputs_for_generation(
  923. self,
  924. input_ids,
  925. past_key_values=None,
  926. attention_mask=None,
  927. inputs_embeds=None,
  928. pixel_values=None,
  929. pixel_attention_mask=None,
  930. image_hidden_states=None,
  931. logits_to_keep=None,
  932. is_first_iteration=False,
  933. use_cache=False,
  934. **kwargs,
  935. ):
  936. # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take
  937. # precedence is moved to the model, we can remove this fn)
  938. model_inputs = super().prepare_inputs_for_generation(
  939. input_ids,
  940. past_key_values=past_key_values,
  941. attention_mask=attention_mask,
  942. inputs_embeds=inputs_embeds,
  943. pixel_values=pixel_values,
  944. pixel_attention_mask=pixel_attention_mask,
  945. image_hidden_states=image_hidden_states,
  946. logits_to_keep=logits_to_keep,
  947. is_first_iteration=is_first_iteration,
  948. use_cache=use_cache,
  949. **kwargs,
  950. )
  951. if image_hidden_states is not None or (use_cache and not is_first_iteration):
  952. model_inputs["pixel_values"] = None
  953. model_inputs["pixel_attention_mask"] = None
  954. return model_inputs
  955. __all__ = ["Idefics2ForConditionalGeneration", "Idefics2PreTrainedModel", "Idefics2Model"]