modeling_paligemma.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557
  1. # Copyright 2024 the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch PaliGemma model."""
  15. from collections.abc import Callable
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from ...cache_utils import Cache
  20. from ...configuration_utils import PreTrainedConfig
  21. from ...generation import GenerationMixin
  22. from ...masking_utils import create_masks_for_generate
  23. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  24. from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
  25. from ...modeling_utils import PreTrainedModel
  26. from ...processing_utils import Unpack
  27. from ...utils import (
  28. ModelOutput,
  29. TransformersKwargs,
  30. auto_docstring,
  31. can_return_tuple,
  32. logging,
  33. torch_compilable_check,
  34. )
  35. from ...utils.deprecation import deprecate_kwarg
  36. from ..auto import AutoModel
  37. from .configuration_paligemma import PaliGemmaConfig
# Module-level logger; in this file it is used for the one-time warning emitted by `create_causal_mask_mapping`.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Paligemma outputs, with hidden states and attentions.
    """
)
class PaligemmaModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    # The only field declared here; `PaliGemmaModel.forward` fills it with the projected image features when
    # `pixel_values` is given and with `None` otherwise. All other fields come from `BaseModelOutputWithPast`.
    image_hidden_states: torch.FloatTensor | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for PaliGemma causal language model (or autoregressive) outputs.
    """
)
class PaliGemmaCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
    """

    # Every field defaults to `None`. `PaliGemmaForConditionalGeneration.forward` leaves `loss` as `None` unless
    # `labels` is passed.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    # NOTE(review): `hidden_states` and `attentions` have no entry in the docstring above; presumably
    # `auto_docstring` supplies the standard descriptions for them -- confirm.
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    image_hidden_states: torch.FloatTensor | None = None
  78. class PaliGemmaMultiModalProjector(nn.Module):
  79. def __init__(self, config: PaliGemmaConfig):
  80. super().__init__()
  81. self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True)
  82. def forward(self, image_features):
  83. hidden_states = self.linear(image_features)
  84. return hidden_states
  85. def token_type_ids_mask_function(group_ids: torch.Tensor) -> Callable:
  86. """
  87. This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
  88. not start and end indices.
  89. Args:
  90. group_ids (`torch.Tensor`):
  91. A tensor of shape `(bs, len)` assigning each token to a vision group. Tokens with the same group
  92. come from the same input image. Text is denoted by `-1`.
  93. """
  94. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  95. seq_length = group_ids.shape[-1]
  96. # clamp indices because with static cache they can go beyond `group_ids.shape[-1]`
  97. q_idx_clamped = q_idx.clamp(max=seq_length - 1)
  98. kv_idx_clamped = kv_idx.clamp(max=seq_length - 1)
  99. # Unmask if the q and kv come from same group which is not -1 (i.e. non-text)
  100. q_group = group_ids[batch_idx, q_idx_clamped]
  101. kv_group = group_ids[batch_idx, kv_idx_clamped]
  102. q_group = torch.where(q_idx < seq_length, q_group, -1)
  103. kv_group = torch.where(kv_idx < seq_length, kv_group, -1)
  104. return (q_group == kv_group) & (q_group >= 0)
  105. return inner_mask
  106. @deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
  107. def create_causal_mask_mapping(
  108. config: PreTrainedConfig,
  109. inputs_embeds: torch.Tensor,
  110. attention_mask: torch.Tensor | None,
  111. past_key_values: Cache | None,
  112. position_ids: torch.Tensor | None,
  113. token_type_ids: torch.Tensor | None = None,
  114. pixel_values: torch.FloatTensor | None = None,
  115. is_training: bool | None = False,
  116. is_first_iteration: bool | None = None,
  117. **kwargs,
  118. ) -> dict:
  119. """
  120. Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
  121. for all kinds of forward passes. Paligemma uses a bidirectional mask on the prompt tokens.
  122. Uses `pixel_values` as an optional input to disambiguate edge cases.
  123. """
  124. if is_training and token_type_ids is None:
  125. raise ValueError("`token_type_ids` is required as a model input when training")
  126. mask_kwargs = {
  127. "config": config.get_text_config(),
  128. "inputs_embeds": inputs_embeds,
  129. "attention_mask": attention_mask,
  130. "past_key_values": past_key_values,
  131. "position_ids": position_ids,
  132. }
  133. # Infer if prefill or decoding stage, if the flag isn't passed. This happens only when the mask is constructed
  134. # from `forward` call. If users run a `forward` call, we have no option to infer `is_first_iteration` because users may be
  135. # running generation with custom loop. Thus we need to infer it in a `non-perfect` way
  136. # NOTE: Determining prefill in that case requires checking data values, which is not compile-compatible.
  137. is_first_iteration = (
  138. is_first_iteration
  139. if is_first_iteration
  140. else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
  141. )
  142. if is_first_iteration or not kwargs.get("use_cache", True):
  143. if token_type_ids is not None:
  144. # The logic bellow was originally written for Gemma3, where `token_type_ids` is reversed. Let's reverse
  145. # it to then use exactly the same logic.
  146. token_type_ids = 1 - token_type_ids
  147. else:
  148. logger.warning_once(
  149. "It is a prefill stage but The `token_type_ids` is not provided. We recommend "
  150. "passing `token_type_ids` to the model to prevent bad attention masking."
  151. )
  152. # NOTE: this branch can't be reached when training because `token_type_ids` is required as a model input.
  153. token_type_ids = torch.ones_like(inputs_embeds)[:, :, 0]
  154. # Logic originally copied from Gemma3. It holds up for Paligemma as well because Paligemma assumes up to one image
  155. # per prompt AND we reverse `token_type_ids` above. Gemma3 uses a bidirectional mask for images, tagged through
  156. # `token_type_ids` 1s.
  157. if token_type_ids is not None and is_first_iteration:
  158. # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to
  159. # undo the causal masking)
  160. # First find where a new image block starts: 1 if image and previous not image
  161. # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
  162. is_image = (token_type_ids == 1).to(inputs_embeds.device)
  163. is_previous_image = nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
  164. new_image_start = is_image & ~is_previous_image
  165. group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
  166. group_ids = torch.where(is_image, group_ids, torch.full_like(token_type_ids, -1))
  167. mask_kwargs["or_mask_function"] = token_type_ids_mask_function(group_ids)
  168. return create_masks_for_generate(**mask_kwargs)
@auto_docstring
class PaliGemmaPreTrainedModel(PreTrainedModel):
    # Common base class of the PaliGemma models in this file. It only declares class-level settings and defines
    # no methods of its own.
    config: PaliGemmaConfig
    # Matches the attribute under which `PaliGemmaForConditionalGeneration` stores its `PaliGemmaModel`.
    base_model_prefix = "model"
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True
    _no_split_modules = ["PaliGemmaMultiModalProjector"]
    _skip_keys_device_placement = "past_key_values"
    # NOTE(review): presumably disabled because the prefill detection in `create_causal_mask_mapping` checks data
    # values, which that function's own NOTE describes as not compile-compatible -- confirm.
    _can_compile_fullgraph = False
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True
@auto_docstring(
    custom_intro="""
    The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
    """
)
class PaliGemmaModel(PaliGemmaPreTrainedModel):
    # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
    accepts_loss_kwargs = False

    def __init__(self, config: PaliGemmaConfig):
        super().__init__(config)
        # Vision tower and language model are both built from their sub-configs through `AutoModel`.
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size

        language_model = AutoModel.from_config(config=config.text_config)
        self.language_model = language_model

        # NOTE(review): `text_config_dtype` is only stored here; nothing in this class reads it.
        self.text_config_dtype = self.config.get_text_config().dtype or self.dtype
        self.post_init()

    # Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma
    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    # Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma
    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring(
        custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection."
    )
    def get_image_features(
        self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | BaseModelOutputWithPooling:
        image_outputs = self.vision_tower(pixel_values, **kwargs)
        selected_image_feature = image_outputs.last_hidden_state
        image_features = self.multi_modal_projector(selected_image_feature)
        # The projected features are returned through the vision tower's own output object, by overwriting its
        # `pooler_output` field.
        image_outputs.pooler_output = image_features

        return image_outputs

    def get_placeholder_mask(
        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
    ):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        """
        if input_ids is None:
            # Without token ids, a placeholder is a position whose embedding equals the embedding of the image
            # token on every element of the last dimension.
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id

        # Both counts are only used in the message of the check below.
        n_image_tokens = special_image_mask.sum()
        n_image_features = image_features.shape[0] * image_features.shape[1]
        # Expand the per-position mask to the shape of `inputs_embeds` so that it can be used for a scatter.
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        torch_compilable_check(
            inputs_embeds[special_image_mask].numel() == image_features.numel(),
            f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
        )
        return special_image_mask

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        token_type_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple | PaligemmaModelOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

        Example:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

        >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
        >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")

        >>> prompt = "Where is the cat standing?"
        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs,)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Where is the cat standing?\nsnow"
        ```"""
        # NOTE: `labels` is accepted but never read in this method.
        # Raises when both or neither of `input_ids` / `inputs_embeds` are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # Replace image id with PAD if the image token is OOV, to avoid index-errors
        if input_ids is not None and self.config.image_token_id >= self.vocab_size:
            special_image_mask = input_ids == self.config.image_token_id
            # Work on a copy so that the caller's `input_ids` are left untouched.
            llm_input_ids = input_ids.clone()
            llm_input_ids[special_image_mask] = 0
        else:
            llm_input_ids = input_ids

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        if position_ids is None:
            # Default positions continue after the tokens already held by the cache (if any).
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0) + 1  # Paligemma positions are 1-indexed

        # Merge text and images
        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values).pooler_output
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            # The placeholder mask is computed from the original `input_ids`, not from `llm_input_ids`, whose
            # image tokens may have been replaced by 0 above.
            special_image_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
            )
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            causal_mask_mapping = create_causal_mask_mapping(
                self.config,
                inputs_embeds,
                attention_mask,
                past_key_values,
                position_ids,
                token_type_ids,
                pixel_values,
                is_training=self.training,
            )

        outputs = self.language_model(
            attention_mask=causal_mask_mapping,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        return PaligemmaModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            # `image_features` only exists when the image branch above ran.
            image_hidden_states=image_features if pixel_values is not None else None,
        )
  327. @auto_docstring(
  328. custom_intro="""
  329. The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
  330. """
  331. )
  332. class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
  333. _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
  334. def __init__(self, config: PaliGemmaConfig):
  335. super().__init__(config)
  336. self.model = PaliGemmaModel(config)
  337. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  338. self.post_init()
  339. def get_input_embeddings(self):
  340. return self.model.get_input_embeddings()
  341. def set_input_embeddings(self, value):
  342. self.model.set_input_embeddings(value)
  343. @auto_docstring
  344. def get_image_features(self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]):
  345. return self.model.get_image_features(pixel_values, **kwargs)
  346. @can_return_tuple
  347. @auto_docstring
  348. def forward(
  349. self,
  350. input_ids: torch.LongTensor | None = None,
  351. pixel_values: torch.FloatTensor | None = None,
  352. attention_mask: torch.Tensor | None = None,
  353. position_ids: torch.LongTensor | None = None,
  354. past_key_values: Cache | None = None,
  355. token_type_ids: torch.LongTensor | None = None,
  356. inputs_embeds: torch.FloatTensor | None = None,
  357. labels: torch.LongTensor | None = None,
  358. use_cache: bool | None = None,
  359. logits_to_keep: int | torch.Tensor = 0,
  360. **kwargs: Unpack[TransformersKwargs],
  361. ) -> tuple | PaliGemmaCausalLMOutputWithPast:
  362. r"""
  363. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  364. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  365. config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  366. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
  367. Example:
  368. ```python
  369. >>> from PIL import Image
  370. >>> import httpx
  371. >>> from io import BytesIO
  372. >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
  373. >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
  374. >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
  375. >>> prompt = "Where is the cat standing?"
  376. >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
  377. >>> with httpx.stream("GET", url) as response:
  378. ... image = Image.open(BytesIO(response.read()))
  379. >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
  380. >>> # Generate
  381. >>> generate_ids = model.generate(**inputs,)
  382. >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  383. "Where is the cat standing?\nsnow"
  384. ```"""
  385. outputs = self.model(
  386. input_ids=input_ids,
  387. pixel_values=pixel_values,
  388. token_type_ids=token_type_ids,
  389. attention_mask=attention_mask,
  390. position_ids=position_ids,
  391. past_key_values=past_key_values,
  392. inputs_embeds=inputs_embeds,
  393. use_cache=use_cache,
  394. labels=labels,
  395. **kwargs,
  396. )
  397. hidden_states = outputs[0]
  398. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  399. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  400. logits = self.lm_head(hidden_states[:, slice_indices, :])
  401. loss = None
  402. if labels is not None:
  403. loss = self.loss_function(
  404. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  405. )
  406. return PaliGemmaCausalLMOutputWithPast(
  407. loss=loss,
  408. logits=logits,
  409. past_key_values=outputs.past_key_values,
  410. hidden_states=outputs.hidden_states,
  411. attentions=outputs.attentions,
  412. image_hidden_states=outputs.image_hidden_states,
  413. )
  414. def prepare_inputs_for_generation(
  415. self,
  416. input_ids,
  417. past_key_values=None,
  418. inputs_embeds=None,
  419. position_ids=None,
  420. pixel_values=None,
  421. attention_mask=None,
  422. token_type_ids=None,
  423. use_cache=True,
  424. logits_to_keep=None,
  425. labels=None,
  426. is_first_iteration=False,
  427. **kwargs,
  428. ):
  429. # Overwritten -- custom `position_ids` and `pixel_values` handling
  430. model_inputs = super().prepare_inputs_for_generation(
  431. input_ids,
  432. past_key_values=past_key_values,
  433. inputs_embeds=inputs_embeds,
  434. attention_mask=attention_mask,
  435. position_ids=position_ids,
  436. use_cache=use_cache,
  437. logits_to_keep=logits_to_keep,
  438. token_type_ids=token_type_ids,
  439. is_first_iteration=is_first_iteration,
  440. **kwargs,
  441. )
  442. # position_ids in Paligemma are 1-indexed
  443. if model_inputs.get("position_ids") is not None:
  444. # NOTE: we need this op out-of-place, otherwise it modifies the `model_kwargs` dict used in `generate` in-place!
  445. model_inputs["position_ids"] = model_inputs["position_ids"] + 1
  446. # Pixel values are used only in the first iteration if available
  447. # In subsequent iterations, they are already merged with text and cached
  448. # NOTE: first iteration doesn't have to be prefill, it can be the first
  449. # iteration with a question and cached system prompt (continue generate from cache). NOTE: use_cache=False needs pixel_values always
  450. if is_first_iteration or not use_cache:
  451. model_inputs["pixel_values"] = pixel_values
  452. return model_inputs
  453. @staticmethod
  454. @deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
  455. def create_masks_for_generate(
  456. config: PreTrainedConfig,
  457. inputs_embeds: torch.Tensor,
  458. attention_mask: torch.Tensor | None,
  459. past_key_values: Cache | None,
  460. position_ids: torch.Tensor | None,
  461. token_type_ids: torch.Tensor | None = None,
  462. is_first_iteration: bool | None = False,
  463. **kwargs,
  464. ) -> dict:
  465. # Uses the overwritten `create_masks_for_generate` with `token_type_ids` masking
  466. return create_causal_mask_mapping(
  467. config,
  468. inputs_embeds,
  469. attention_mask,
  470. past_key_values,
  471. position_ids,
  472. token_type_ids,
  473. is_first_iteration=is_first_iteration,
  474. **{k: v for k, v in kwargs.items() if k != "pixel_values"},
  475. )
# Names exported by `from ... import *`; the output dataclasses and the projector are not part of this list.
__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel", "PaliGemmaModel"]