modeling_mllama.py 71 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634
  1. # Copyright 2024 the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Mllama model."""
  15. import math
  16. from collections.abc import Callable
  17. from typing import Optional
  18. import torch
  19. import torch.nn.functional as F
  20. from torch import nn
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache, DynamicCache
  24. from ...generation import GenerationMixin
  25. from ...masking_utils import create_causal_mask
  26. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  27. from ...modeling_layers import GradientCheckpointingLayer
  28. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast
  29. from ...modeling_rope_utils import (
  30. ROPE_INIT_FUNCTIONS,
  31. dynamic_rope_update,
  32. )
  33. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  34. from ...processing_utils import Unpack
  35. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  36. from ...utils.generic import (
  37. maybe_autocast,
  38. merge_with_config_defaults,
  39. )
  40. from ...utils.output_capturing import OutputRecorder, capture_outputs
  41. from .configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig
  42. logger = logging.get_logger(__name__)
  43. def _prepare_cross_attention_mask(
  44. cross_attention_mask: torch.Tensor,
  45. num_vision_tokens: int,
  46. dtype: str,
  47. ) -> tuple[torch.Tensor, torch.Tensor]:
  48. # reshape so it can be used by attn module
  49. batch_size, text_total_length, *_ = cross_attention_mask.shape
  50. cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3)
  51. cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
  52. cross_attention_mask = cross_attention_mask.unsqueeze(1)
  53. # invert the mask
  54. inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
  55. cross_attention_mask = inverted_cross_attn_mask.masked_fill(
  56. inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
  57. )
  58. # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's
  59. # last dimension contains negative infinity values, otherwise it's 1
  60. negative_inf_value = torch.finfo(dtype).min
  61. full_text_row_masked_out_mask = (
  62. (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None]
  63. )
  64. cross_attention_mask *= full_text_row_masked_out_mask
  65. return cross_attention_mask, full_text_row_masked_out_mask
  66. def _prepare_aspect_ratio_attention_mask(
  67. aspect_ratio_mask: torch.Tensor,
  68. num_patches: int,
  69. target_length: int,
  70. dtype: torch.dtype,
  71. ) -> torch.Tensor:
  72. # Expand aspect ratio mask to target_length
  73. batch_size, max_num_tiles = aspect_ratio_mask.shape
  74. attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)
  75. attention_mask = attention_mask.repeat(1, 1, target_length, 1)
  76. # Mask padding patches
  77. pad_patches = target_length - num_patches
  78. attention_mask[:, :, -pad_patches:] = 0
  79. # Invert the mask (0 -> 1, 1 -> 0)
  80. attention_mask = 1 - attention_mask
  81. # Reshape to 2D and create 4D attention mask
  82. # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
  83. attention_mask = attention_mask.reshape(batch_size, max_num_tiles * target_length, 1)
  84. attention_mask = attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min
  85. attention_mask = attention_mask.unsqueeze(1)
  86. return attention_mask
  87. class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
  88. def __init__(self, config: MllamaVisionConfig, is_gated: bool = True):
  89. super().__init__()
  90. self.max_num_tiles = config.max_num_tiles
  91. self.hidden_size = config.hidden_size
  92. self.max_aspect_ratio_id = config.max_aspect_ratio_id
  93. self.is_gated = is_gated
  94. self.embedding = nn.Embedding(self.max_aspect_ratio_id + 1, self.max_num_tiles * self.hidden_size)
  95. if is_gated:
  96. self.gate = nn.Parameter(torch.zeros(1))
  97. def forward(self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
  98. embeddings = self.embedding(aspect_ratio_ids)
  99. embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)
  100. if self.is_gated:
  101. embeddings = embeddings * self.gate.tanh()
  102. hidden_state = hidden_state + embeddings
  103. return hidden_state
  104. class MllamaPrecomputedPositionEmbedding(nn.Module):
  105. def __init__(self, config: MllamaVisionConfig):
  106. super().__init__()
  107. self.max_num_tiles = config.max_num_tiles
  108. self.max_aspect_ratio_id = config.max_aspect_ratio_id
  109. self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
  110. self.hidden_size = config.hidden_size
  111. self.scale = config.hidden_size**-0.5
  112. self.gate = nn.Parameter(torch.zeros(1))
  113. # position embedding
  114. position_embedding = torch.randn(self.num_patches, self.hidden_size)
  115. self.embedding = nn.Parameter(self.scale * position_embedding)
  116. # tile position embedding
  117. self.tile_embedding = nn.Embedding(
  118. self.max_aspect_ratio_id + 1, self.max_num_tiles * self.num_patches * self.hidden_size
  119. )
  120. def forward(self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
  121. # position embeddings
  122. gated_position_embedding = (1 - self.gate.tanh()) * self.embedding
  123. hidden_state = hidden_state + gated_position_embedding.view(1, 1, self.num_patches, self.hidden_size)
  124. # precomputed tile position embeddings
  125. tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
  126. batch_size = hidden_state.shape[0]
  127. tile_position_embedding = tile_position_embedding.reshape(
  128. batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
  129. )
  130. gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding
  131. hidden_state = hidden_state + gated_tile_position_embedding
  132. return hidden_state
  133. # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision
  134. class MllamaVisionMLP(nn.Module):
  135. def __init__(self, config):
  136. super().__init__()
  137. self.config = config
  138. self.activation_fn = ACT2FN[config.hidden_act]
  139. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  140. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  141. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  142. hidden_states = self.fc1(hidden_states)
  143. hidden_states = self.activation_fn(hidden_states)
  144. hidden_states = self.fc2(hidden_states)
  145. return hidden_states
  146. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  147. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  148. """
  149. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  150. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  151. """
  152. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  153. if n_rep == 1:
  154. return hidden_states
  155. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  156. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  157. # Copied from transformers.models.llama.modeling_llama.eager_attention_forward
  158. def eager_attention_forward(
  159. module: nn.Module,
  160. query: torch.Tensor,
  161. key: torch.Tensor,
  162. value: torch.Tensor,
  163. attention_mask: torch.Tensor | None,
  164. scaling: float,
  165. dropout: float = 0.0,
  166. **kwargs: Unpack[TransformersKwargs],
  167. ):
  168. key_states = repeat_kv(key, module.num_key_value_groups)
  169. value_states = repeat_kv(value, module.num_key_value_groups)
  170. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  171. if attention_mask is not None:
  172. attn_weights = attn_weights + attention_mask
  173. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  174. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  175. attn_output = torch.matmul(attn_weights, value_states)
  176. attn_output = attn_output.transpose(1, 2).contiguous()
  177. return attn_output, attn_weights
  178. class MllamaVisionAttention(nn.Module):
  179. def __init__(self, config: MllamaVisionConfig):
  180. super().__init__()
  181. self.config = config
  182. self.embed_dim = config.hidden_size
  183. self.num_heads = config.attention_heads
  184. self.head_dim = config.hidden_size // config.attention_heads
  185. self.scaling = self.head_dim**-0.5
  186. self.num_key_value_groups = 1
  187. self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False)
  188. self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False)
  189. self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False)
  190. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.embed_dim, bias=False)
  191. def forward(
  192. self,
  193. hidden_state: torch.Tensor,
  194. attention_mask: torch.Tensor | None = None,
  195. **kwargs,
  196. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  197. query = self.q_proj(hidden_state)
  198. key = self.k_proj(hidden_state)
  199. value = self.v_proj(hidden_state)
  200. batch_size, q_seq_len, _ = query.shape
  201. _, kv_seq_len, _ = key.shape
  202. query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  203. key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  204. value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  205. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  206. self.config._attn_implementation, eager_attention_forward
  207. )
  208. attn_output, attn_weights = attention_interface(
  209. self,
  210. query,
  211. key,
  212. value,
  213. attention_mask,
  214. dropout=0.0,
  215. scaling=self.scaling,
  216. **kwargs,
  217. )
  218. attn_output = attn_output.reshape(batch_size, q_seq_len, -1).contiguous()
  219. attn_output = self.o_proj(attn_output)
  220. return attn_output, attn_weights
  221. class MllamaVisionEncoderLayer(nn.Module):
  222. def __init__(self, config: MllamaVisionConfig, is_gated: bool = False):
  223. super().__init__()
  224. self.hidden_size = config.hidden_size
  225. self.num_attention_heads = config.attention_heads
  226. self.is_gated = is_gated
  227. self.intermediate_size = config.intermediate_size
  228. self.self_attn = MllamaVisionAttention(config)
  229. self.mlp = MllamaVisionMLP(config)
  230. self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
  231. self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
  232. if is_gated:
  233. self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4)
  234. self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4)
  235. def forward(
  236. self,
  237. hidden_state: torch.Tensor,
  238. attention_mask: torch.Tensor | None = None,
  239. ):
  240. # Self Attention
  241. residual = hidden_state
  242. hidden_state = self.input_layernorm(hidden_state)
  243. hidden_state, attn_weights = self.self_attn(hidden_state, attention_mask=attention_mask)
  244. if self.is_gated:
  245. hidden_state = self.gate_attn.tanh() * hidden_state
  246. hidden_state = residual + hidden_state
  247. # Feed forward
  248. residual = hidden_state
  249. hidden_state = self.post_attention_layernorm(hidden_state)
  250. hidden_state = self.mlp(hidden_state)
  251. if self.is_gated:
  252. hidden_state = self.gate_ffn.tanh() * hidden_state
  253. hidden_state = residual + hidden_state
  254. return hidden_state
  255. class MllamaVisionEncoder(nn.Module):
  256. """
  257. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  258. [`MllamaEncoderLayer`].
  259. Args:
  260. config: MllamaConfig
  261. """
  262. def __init__(self, config: MllamaVisionConfig, num_layers=32, is_gated=False):
  263. super().__init__()
  264. self.config = config
  265. self.layers = nn.ModuleList([MllamaVisionEncoderLayer(config, is_gated) for _ in range(num_layers)])
  266. self.gradient_checkpointing = False
  267. self.config = config
  268. def forward(
  269. self,
  270. hidden_states: torch.Tensor,
  271. attention_mask: torch.Tensor | None = None,
  272. ) -> BaseModelOutput:
  273. r"""
  274. Args:
  275. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  276. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  277. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  278. than the model's internal embedding lookup matrix.
  279. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  280. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  281. - 1 for tokens that are **not masked**,
  282. - 0 for tokens that are **masked**.
  283. [What are attention masks?](../glossary#attention-mask)
  284. """
  285. encoder_states = ()
  286. for encoder_layer in self.layers:
  287. hidden_states = encoder_layer(
  288. hidden_state=hidden_states,
  289. attention_mask=attention_mask,
  290. )
  291. encoder_states = encoder_states + (hidden_states,)
  292. return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states)
  293. # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
  294. class MllamaTextRMSNorm(nn.Module):
  295. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  296. """
  297. MllamaTextRMSNorm is equivalent to T5LayerNorm
  298. """
  299. super().__init__()
  300. self.weight = nn.Parameter(torch.ones(hidden_size))
  301. self.variance_epsilon = eps
  302. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  303. input_dtype = hidden_states.dtype
  304. hidden_states = hidden_states.to(torch.float32)
  305. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  306. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  307. return self.weight * hidden_states.to(input_dtype)
  308. def extra_repr(self):
  309. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  310. class MllamaTextCrossAttention(nn.Module):
  311. """Multi-headed attention from 'Attention Is All You Need' paper"""
  312. def __init__(
  313. self,
  314. config: MllamaTextConfig | None = None,
  315. layer_idx: int | None = None,
  316. ):
  317. super().__init__()
  318. self.config = config
  319. self.num_heads = self.config.num_attention_heads
  320. self.num_key_value_heads = self.config.num_key_value_heads
  321. self.dropout = config.dropout
  322. self.hidden_size = config.hidden_size
  323. self.head_dim = config.hidden_size // self.num_heads
  324. self.layer_idx = layer_idx
  325. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  326. self.scaling = self.head_dim**-0.5
  327. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
  328. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  329. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  330. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
  331. self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
  332. self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
  333. def forward(
  334. self,
  335. hidden_states: torch.Tensor,
  336. cross_attention_states: torch.Tensor | None = None,
  337. past_key_values: Cache | None = None,
  338. attention_mask: torch.Tensor | None = None,
  339. use_cache: bool | None = None,
  340. **kwargs,
  341. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  342. """Input shape: Batch x Time x Channel"""
  343. bsz, q_len, _ = hidden_states.size()
  344. query_states = self.q_proj(hidden_states)
  345. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  346. query_states = self.q_norm(query_states)
  347. if cross_attention_states is not None:
  348. key_states = self.k_proj(cross_attention_states)
  349. value_states = self.v_proj(cross_attention_states)
  350. key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  351. value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  352. key_states = self.k_norm(key_states)
  353. if past_key_values is not None:
  354. # if we have a new image + new tokens, we only computed key_states on that new image
  355. # we still update the cross key states, past_image, new_image. And use it!
  356. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  357. elif past_key_values is not None and past_key_values.get_seq_length() > 0:
  358. key_states, value_states = (
  359. past_key_values.layers[self.layer_idx].keys,
  360. past_key_values.layers[self.layer_idx].values,
  361. )
  362. else:
  363. raise ValueError(
  364. "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
  365. )
  366. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  367. self.config._attn_implementation, eager_attention_forward
  368. )
  369. attn_output, attn_weights = attention_interface(
  370. self,
  371. query_states,
  372. key_states,
  373. value_states,
  374. attention_mask,
  375. dropout=0.0 if not self.training else self.dropout,
  376. scaling=self.scaling,
  377. **kwargs,
  378. )
  379. attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
  380. attn_output = self.o_proj(attn_output)
  381. return attn_output, attn_weights
  382. # Copied from transformers.models.llama.modeling_llama.rotate_half
  383. def rotate_half(x):
  384. """Rotates half the hidden dims of the input."""
  385. x1 = x[..., : x.shape[-1] // 2]
  386. x2 = x[..., x.shape[-1] // 2 :]
  387. return torch.cat((-x2, x1), dim=-1)
  388. # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
  389. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  390. """Applies Rotary Position Embedding to the query and key tensors.
  391. Args:
  392. q (`torch.Tensor`): The query tensor.
  393. k (`torch.Tensor`): The key tensor.
  394. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  395. sin (`torch.Tensor`): The sine part of the rotary embedding.
  396. unsqueeze_dim (`int`, *optional*, defaults to 1):
  397. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  398. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  399. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  400. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  401. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  402. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  403. Returns:
  404. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  405. """
  406. cos = cos.unsqueeze(unsqueeze_dim)
  407. sin = sin.unsqueeze(unsqueeze_dim)
  408. q_embed = (q * cos) + (rotate_half(q) * sin)
  409. k_embed = (k * cos) + (rotate_half(k) * sin)
  410. return q_embed, k_embed
  411. class MllamaTextSelfAttention(nn.Module):
  412. def __init__(self, config: MllamaTextConfig, layer_idx: int):
  413. super().__init__()
  414. self.config = config
  415. self.num_heads = config.num_attention_heads
  416. self.dropout = config.dropout
  417. self.hidden_size = config.hidden_size
  418. self.num_key_value_heads = config.num_key_value_heads
  419. self.head_dim = config.hidden_size // self.num_heads
  420. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  421. self.scaling = self.head_dim**-0.5
  422. self.layer_idx = layer_idx
  423. self.is_causal = True
  424. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
  425. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  426. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  427. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
  428. def forward(
  429. self,
  430. hidden_states: torch.Tensor,
  431. attention_mask: torch.Tensor,
  432. position_embeddings: torch.Tensor,
  433. past_key_values=None,
  434. **kwargs,
  435. ):
  436. bsz, q_len, _ = hidden_states.size()
  437. query_states = self.q_proj(hidden_states)
  438. key_states = self.k_proj(hidden_states)
  439. value_states = self.v_proj(hidden_states)
  440. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  441. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  442. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  443. cos, sin = position_embeddings
  444. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  445. if past_key_values is not None:
  446. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  447. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  448. self.config._attn_implementation, eager_attention_forward
  449. )
  450. attn_output, attn_weights = attention_interface(
  451. self,
  452. query_states,
  453. key_states,
  454. value_states,
  455. attention_mask,
  456. dropout=0.0 if not self.training else self.dropout,
  457. scaling=self.scaling,
  458. **kwargs,
  459. )
  460. attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
  461. attn_output = self.o_proj(attn_output)
  462. return attn_output, attn_weights
  463. # Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
  464. class MllamaTextMLP(nn.Module):
  465. def __init__(self, config):
  466. super().__init__()
  467. self.config = config
  468. self.hidden_size = config.hidden_size
  469. self.intermediate_size = config.intermediate_size
  470. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  471. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  472. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  473. # Ignore copy
  474. self.act_fn = ACT2FN[config.hidden_act]
  475. def forward(self, x):
  476. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  477. return down_proj
  478. # Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer
  479. class MllamaSelfAttentionDecoderLayer(GradientCheckpointingLayer):
  480. def __init__(self, config: MllamaTextConfig, layer_idx: int):
  481. super().__init__()
  482. self.hidden_size = config.hidden_size
  483. self.self_attn = MllamaTextSelfAttention(config=config, layer_idx=layer_idx)
  484. self.mlp = MllamaTextMLP(config)
  485. self.input_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  486. self.post_attention_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  487. self.layer_idx = layer_idx
  def forward(
      self,
      hidden_states: torch.Tensor,
      cross_attention_states: torch.Tensor | None = None,
      cross_attention_mask: torch.Tensor | None = None,
      attention_mask: torch.Tensor | None = None,
      full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor] | None = None,
      position_ids: torch.LongTensor | None = None,
      past_key_values: Cache | None = None,
      use_cache: bool | None = False,
      position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
      **kwargs: Unpack[FlashAttentionKwargs],
  ) -> torch.Tensor:
      """
      Apply a pre-norm self-attention sub-block followed by a pre-norm MLP sub-block, each wrapped in a
      residual connection.

      Args:
          hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
          cross_attention_states (`torch.Tensor`, *optional*):
              Not read by this layer. Accepted so the layer can be called with the same arguments as a
              cross-attention layer.
          cross_attention_mask (`torch.Tensor`, *optional*):
              Not read by this layer.
          attention_mask (`torch.FloatTensor`, *optional*):
              attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
              query_sequence_length, key_sequence_length)` if default attention is used.
          full_text_row_masked_out_mask (`tuple[torch.Tensor, torch.Tensor]`, *optional*):
              Not read by this layer.
          position_ids (`torch.LongTensor`, *optional*):
              Forwarded unchanged to the self-attention module.
          past_key_values (`Cache`, *optional*): cached past key and value projection states
          use_cache (`bool`, *optional*):
              If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
              (see `past_key_values`).
          position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
              Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
              with `head_dim` being the embedding dimension of each attention head.
          kwargs (`dict`, *optional*):
              Arbitrary kwargs forwarded to the self-attention module, used for FSDP and other methods that
              inject code into the model

      Returns:
          `torch.Tensor`: the updated hidden states. The attention weights produced by the attention module
          are not returned by this method.
      """
      residual = hidden_states
      hidden_states = self.input_layernorm(hidden_states)

      # Self Attention. The second return value (attention weights) is intentionally discarded here.
      hidden_states, _ = self.self_attn(
          hidden_states=hidden_states,
          attention_mask=attention_mask,
          position_ids=position_ids,
          past_key_values=past_key_values,
          use_cache=use_cache,
          position_embeddings=position_embeddings,
          **kwargs,
      )
      hidden_states = residual + hidden_states

      # Fully Connected
      residual = hidden_states
      hidden_states = self.post_attention_layernorm(hidden_states)
      hidden_states = self.mlp(hidden_states)
      hidden_states = residual + hidden_states

      return hidden_states
  class MllamaCrossAttentionDecoderLayer(GradientCheckpointingLayer):
      """Cross-attention transformer block with tanh-gated attention and feedforward.

      Both sub-blocks are scaled by `tanh(gate)` before being added to the residual stream. The two gates are
      created as zeros, so `tanh(gate) == 0` until they are trained or loaded and the block then returns its
      input unchanged.
      """

      def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None:
          super().__init__()
          self.layer_idx = layer_idx
          self.cross_attn = MllamaTextCrossAttention(config, layer_idx=layer_idx)

          self.input_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          # Scalar gate for the attention branch; passed through tanh in `forward`.
          self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1))

          self.mlp = MllamaTextMLP(config)
          self.post_attention_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          # Scalar gate for the feed-forward branch; passed through tanh in `forward`.
          self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1))

      def forward(
          self,
          hidden_states: torch.Tensor,
          cross_attention_states: torch.Tensor,
          cross_attention_mask: torch.Tensor,
          attention_mask: torch.Tensor,
          full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor],
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          use_cache: bool | None = False,
          position_embeddings: torch.Tensor | None = None,
          **kwargs: Unpack[FlashAttentionKwargs],
      ) -> torch.Tensor:
          """
          Apply gated cross-attention over `cross_attention_states`, then a gated MLP.

          Args:
              hidden_states (`torch.Tensor`): text hidden states entering the layer.
              cross_attention_states (`torch.Tensor`): states handed to the cross-attention module.
              cross_attention_mask (`torch.Tensor`): passed to the cross-attention module as its `attention_mask`.
              attention_mask (`torch.Tensor`): not read by this layer.
              full_text_row_masked_out_mask: when not `None`, `full_text_row_masked_out_mask[:, 0]` multiplies
                  the MLP output before the gated residual add.
              position_ids, use_cache, position_embeddings: not read by this layer.
              past_key_values (`Cache`, *optional*): forwarded to the cross-attention module.
              kwargs: forwarded to the cross-attention module.

          Returns:
              `torch.Tensor`: the updated hidden states. Attention weights are not returned.
          """
          residual = hidden_states
          hidden_states = self.input_layernorm(hidden_states)

          # The second return value (attention weights) is intentionally discarded here.
          hidden_states, _ = self.cross_attn(
              hidden_states=hidden_states,
              attention_mask=cross_attention_mask,
              cross_attention_states=cross_attention_states,
              past_key_values=past_key_values,
              **kwargs,
          )
          hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states

          residual = hidden_states
          hidden_states = self.post_attention_layernorm(hidden_states)
          hidden_states = self.mlp(hidden_states)
          if full_text_row_masked_out_mask is not None:
              # NOTE(review): the annotation says tuple, but `[:, 0]` is tensor-style indexing, so callers
              # presumably pass a tensor here - confirm and align the annotation.
              hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states  # type: ignore
          hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states

          return hidden_states
  578. # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with LlamaConfig->MllamaTextConfig,Llama->Mllama
  class MllamaRotaryEmbedding(nn.Module):
      inv_freq: torch.Tensor  # fix linting for `register_buffer`

      def __init__(self, config: MllamaTextConfig, device=None):
          super().__init__()
          self.max_seq_len_cached = config.max_position_embeddings
          self.original_max_seq_len = config.max_position_embeddings

          self.config = config
          self.rope_type = self.config.rope_parameters["rope_type"]

          # "default" uses the static method below; every other type is looked up in the registry.
          rope_init_fn: Callable = (
              self.compute_default_rope_parameters
              if self.rope_type == "default"
              else ROPE_INIT_FUNCTIONS[self.rope_type]
          )
          inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

          # Both buffers are non-persistent; the second one keeps a separate copy of the initial frequencies.
          self.register_buffer("inv_freq", inv_freq, persistent=False)
          self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

      @staticmethod
      def compute_default_rope_parameters(
          config: MllamaTextConfig | None = None,
          device: Optional["torch.device"] = None,
          seq_len: int | None = None,
      ) -> tuple["torch.Tensor", float]:
          """
          Build the inverse frequencies of the original (unscaled) RoPE formulation.

          Args:
              config ([`~transformers.PreTrainedConfig`]):
                  The model configuration. `rope_parameters["rope_theta"]` gives the base; the rotary dimension
                  is `head_dim` when set, otherwise `hidden_size // num_attention_heads`.
              device (`torch.device`):
                  The device on which the inverse frequencies are created.
              seq_len (`int`, *optional*):
                  The current sequence length. Unused for this type of RoPE.
          Returns:
              Tuple of (`torch.Tensor`, `float`): the inverse frequencies, and the scaling factor applied to the
              computed cos/sin (always `1.0` for this type of RoPE).
          """
          base = config.rope_parameters["rope_theta"]
          dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

          # Even indices 0, 2, ..., normalised by the rotary dimension, give the exponent of each frequency.
          exponents = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim
          inv_freq = 1.0 / (base**exponents)

          attention_factor = 1.0  # Unused in this type of RoPE
          return inv_freq, attention_factor

      # Ignore copy
      @torch.no_grad()
      @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
      def forward(self, x, position_ids):
          """Return `(cos, sin)` for `position_ids`, cast to the dtype of `x`."""
          batch = position_ids.shape[0]
          freq_column = self.inv_freq[None, :, None].float().expand(batch, -1, 1)
          position_row = position_ids[:, None, :].float()

          # Anything that is not a usable string device type, and "mps" specifically, falls back to "cpu".
          if isinstance(x.device.type, str) and x.device.type != "mps":
              device_type = x.device.type
          else:
              device_type = "cpu"

          with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
              angles = (freq_column.float() @ position_row.float()).transpose(1, 2)
              emb = torch.cat((angles, angles), dim=-1)
              cos = emb.cos() * self.attention_scaling
              sin = emb.sin() * self.attention_scaling

          return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  @auto_docstring
  class MllamaPreTrainedModel(PreTrainedModel):
      # Common base of the Mllama models in this file: class-level flags plus the weight initialisation scheme.
      config: MllamaConfig
      base_model_prefix = "model"
      input_modalities = ("image", "text")
      supports_gradient_checkpointing = True
      _no_split_modules = [
          "MllamaVisionEncoderLayer",
          "MllamaCrossAttentionDecoderLayer",
          "MllamaSelfAttentionDecoderLayer",
      ]
      _can_compile_fullgraph = False  # static cache cannot have different shapes for each layer
      _supports_sdpa = True
      _supports_flash_attn = True
      _supports_flex_attn = True
      _supports_attention_backend = True
      _can_record_outputs = {
          "hidden_states": [MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer],
          "attentions": [
              OutputRecorder(MllamaTextSelfAttention, index=1, layer_name="self_attn"),
              OutputRecorder(MllamaTextSelfAttention, index=1, layer_name="cross_attn"),
              OutputRecorder(MllamaTextCrossAttention, index=1, layer_name="cross_attn"),
          ],
      }

      @torch.no_grad()
      def _init_weights(self, module):
          """Initialise `module` in place; the first matching module type below decides the scheme."""
          # Use the top-level `initializer_range` when the config has one, else the text config's value.
          std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)

          if isinstance(module, (nn.Linear, nn.Conv2d)):
              init.normal_(module.weight, mean=0.0, std=std)
              if module.bias is not None:
                  init.zeros_(module.bias)
              return

          if isinstance(module, nn.Embedding):
              init.normal_(module.weight, mean=0.0, std=std)
              # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
              already_initialized = getattr(module.weight, "_is_hf_initialized", False)
              if module.padding_idx is not None and not already_initialized:
                  init.zeros_(module.weight[module.padding_idx])
              return

          if isinstance(module, nn.LayerNorm):
              init.ones_(module.weight)
              init.zeros_(module.bias)
              return

          if isinstance(module, MllamaTextRMSNorm):
              init.ones_(module.weight)
              return

          if isinstance(module, MllamaVisionModel):
              init.normal_(module.class_embedding, std=std)
              return

          if isinstance(module, MllamaPrecomputedPositionEmbedding):
              init.normal_(module.embedding, std=std)
              init.zeros_(module.gate)
              return

          # A non-gated vision encoder layer does not match here and is tested against the remaining types.
          if isinstance(module, MllamaVisionEncoderLayer) and module.is_gated:
              init.normal_(module.gate_attn, std=std)
              init.normal_(module.gate_ffn, std=std)
              return

          if isinstance(module, MllamaCrossAttentionDecoderLayer):
              init.zeros_(module.cross_attn_attn_gate)
              init.zeros_(module.cross_attn_mlp_gate)
              return

          if isinstance(module, MllamaPrecomputedAspectRatioEmbedding):
              if module.is_gated:
                  init.zeros_(module.gate)
              return

          if isinstance(module, MllamaRotaryEmbedding):
              # Recompute the inverse frequencies with the same function selection the module uses in `__init__`.
              if module.rope_type != "default":
                  rope_fn = ROPE_INIT_FUNCTIONS[module.rope_type]
              else:
                  rope_fn = module.compute_default_rope_parameters
              buffer_value, _ = rope_fn(module.config)
              init.copy_(module.inv_freq, buffer_value)
              init.copy_(module.original_inv_freq, buffer_value)
  @auto_docstring(
      custom_intro="""
      The Mllama Vision Model which consists of two vision encoders.
      """
  )
  class MllamaVisionModel(MllamaPreTrainedModel):
      config: MllamaVisionConfig
      base_model_prefix = "vision_model"
      input_modalities = ("image",)

      def __init__(self, config: MllamaVisionConfig):
          super().__init__(config)
          self.image_size = config.image_size
          self.patch_size = config.patch_size
          self.max_num_tiles = config.max_num_tiles
          self.hidden_size = config.hidden_size
          self.num_channels = config.num_channels
          self.intermediate_layers_indices = config.intermediate_layers_indices

          # Patches per tile, plus one slot for the class token.
          self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
          self.scale = config.hidden_size**-0.5

          self.patch_embedding = nn.Conv2d(
              in_channels=config.num_channels,
              out_channels=self.hidden_size,
              kernel_size=self.patch_size,
              stride=self.patch_size,
              padding="valid",
              bias=False,
          )

          self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
          self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(config)

          self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True)
          self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True)

          # layer norms
          self.layernorm_pre = nn.LayerNorm(self.hidden_size)
          self.layernorm_post = nn.LayerNorm(self.hidden_size)

          # encoders: an ungated local encoder followed by a gated global encoder
          self.transformer = MllamaVisionEncoder(config, config.num_hidden_layers, is_gated=False)
          self.global_transformer = MllamaVisionEncoder(config, config.num_global_layers, is_gated=True)

          self.post_init()

      def get_input_embeddings(self):
          """
          This function is used to fetch the first embedding layer to activate grads on inputs.
          """
          return self.patch_embedding

      def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
          """Prepend the learned class embedding as the first token of every row of `hidden_state`."""
          num_rows, _, embed_dim = hidden_state.shape
          cls_tokens = self.class_embedding.expand(num_rows, 1, embed_dim)
          return torch.cat([cls_tokens, hidden_state], dim=1)

      @merge_with_config_defaults
      @capture_outputs
      @auto_docstring
      def forward(
          self, pixel_values: torch.Tensor, aspect_ratio_ids: torch.Tensor, aspect_ratio_mask: torch.Tensor, **kwargs
      ) -> BaseModelOutput:
          r"""
          aspect_ratio_ids (`torch.Tensor` of shape `(batch_size, max_num_images)`, *optional*):
              Aspect ratio ids used to select the appropriate precomputed tile embeddings based on the aspect ratio of each input image.
              These ids correspond to indices in the model's list of supported aspect ratios, offset by 1.

              For example, if the model supports aspect ratios [[1, 1], [1, 2], [2, 1]]:
              - An image with aspect ratio [1, 1] would have ID 1
              - An image with aspect ratio [1, 2] would have ID 2
              - An image with aspect ratio [2, 1] would have ID 3

              The id 0 is reserved for padding (i.e., no image).

              If an image has aspect ratio [1, 2], that means it was split into 2 tiles horizontally, and its `aspect_ratio_id` would be 2.
          aspect_ratio_mask (`torch.Tensor` of shape `(batch_size, max_num_images, max_num_tiles)`, *optional*):
              Mask to avoid performing attention on padding tiles. Mask values selected in `[0, 1]`:

              - 1 for tiles that are **not masked**,
              - 0 for tiles that are **masked**.

          Example:

          ```python
          >>> from PIL import Image
          >>> import httpx
          >>> from io import BytesIO
          >>> from transformers import AutoProcessor, MllamaVisionModel

          >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
          >>> model = MllamaVisionModel.from_pretrained(checkpoint)
          >>> processor = AutoProcessor.from_pretrained(checkpoint)

          >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
          >>> with httpx.stream("GET", url) as response:
          ...     image = Image.open(BytesIO(response.read()))
          >>> inputs = processor(images=image, return_tensors="pt")

          >>> output = model(**inputs)

          >>> print(output.last_hidden_state.shape)
          torch.Size([1, 1, 4, 1025, 7680])
          ```
          """
          batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape
          # Batch and media axes are folded together for most of the computation below.
          num_media = batch_size * num_concurrent_media

          pixel_values = pixel_values.reshape(num_media * num_tiles, num_channels, height, width)
          aspect_ratio_ids = aspect_ratio_ids.reshape(num_media, -1)

          # Patch embedding (inputs are moved to the device/dtype of the convolution weights first)
          patch_weight = self.patch_embedding.weight
          patch_embeds = self.patch_embedding(pixel_values.to(patch_weight.device, patch_weight.dtype))
          hidden_state = patch_embeds.flatten(2).transpose(1, 2)

          # Tile embeddings
          _, num_patches, dim = hidden_state.shape
          hidden_state = hidden_state.reshape(num_media, num_tiles, -1, dim)
          hidden_state = self.pre_tile_positional_embedding(hidden_state, aspect_ratio_ids)

          # Add cls token
          hidden_state = hidden_state.reshape(num_media * num_tiles, num_patches, dim)
          hidden_state = self.apply_class_embedding(hidden_state)
          num_patches += 1

          # Position embeddings
          hidden_state = hidden_state.reshape(num_media, num_tiles, num_patches, dim)
          hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)

          hidden_state = self.layernorm_pre(hidden_state)

          # Zero-pad the patch axis (dim -2) at its end up to the next multiple of 8
          num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
          hidden_state = F.pad(hidden_state, (0, 0, 0, num_padding_patches), mode="constant", value=0)
          # `None` keeps the whole axis when nothing was padded
          slice_index = -num_padding_patches if num_padding_patches > 0 else None
          padded_num_patches = num_patches + num_padding_patches

          # Prepare attention mask
          attention_mask = _prepare_aspect_ratio_attention_mask(
              aspect_ratio_mask=aspect_ratio_mask.reshape(num_media, -1),
              num_patches=self.num_patches,
              target_length=hidden_state.shape[2],
              dtype=self.dtype,
          )

          # Apply encoder
          hidden_state = hidden_state.view(num_media, -1, dim)
          output = self.transformer(
              hidden_state,
              attention_mask=attention_mask,
          )
          hidden_state = self.layernorm_post(output.last_hidden_state)

          # Apply global encoder
          hidden_state = hidden_state.reshape(num_media, num_tiles, padded_num_patches, dim)
          hidden_state = self.post_tile_positional_embedding(hidden_state, aspect_ratio_ids)
          hidden_state = hidden_state.reshape(num_media, num_tiles * padded_num_patches, dim)
          global_output = self.global_transformer(
              hidden_state,
              attention_mask=attention_mask,
          )
          hidden_state = global_output.last_hidden_state

          # Remove padding from hidden state
          hidden_state = hidden_state.reshape(num_media, num_tiles, padded_num_patches, dim)
          hidden_state = hidden_state[:, :, :slice_index]
          hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, num_tiles, num_patches, dim)

          # Collect intermediate layer outputs from encoder output, stacked along a new last axis
          intermediate_hidden_states = torch.stack(
              [output.hidden_states[i] for i in self.intermediate_layers_indices], dim=-1
          )

          # Remove padding from intermediate hidden states
          intermediate_hidden_states = intermediate_hidden_states.reshape(
              num_media, num_tiles, padded_num_patches, -1
          )
          intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
          intermediate_hidden_states = intermediate_hidden_states.reshape(
              batch_size, num_concurrent_media, num_tiles, num_patches, -1
          )

          # Concatenate final hidden state and intermediate hidden states
          hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)

          return BaseModelOutput(last_hidden_state=hidden_state)
  @auto_docstring(
      custom_intro="""
      The Mllama Text Model which consists of transformer with self and cross attention layers.
      """
  )
  class MllamaTextModel(MllamaPreTrainedModel):
      config: MllamaTextConfig
      base_model_prefix = "language_model.model"
      input_modalities = ("text",)

      def __init__(self, config: MllamaTextConfig):
          super().__init__(config)
          self.padding_idx = config.pad_token_id
          self.vocab_size = config.vocab_size
          # The embedding table holds 8 more rows than `config.vocab_size`.
          self.embed_tokens = nn.Embedding(config.vocab_size + 8, config.hidden_size, self.padding_idx)
          self.cross_attention_layers = config.cross_attention_layers

          # Indices listed in `cross_attention_layers` get a cross-attention block, all others self-attention.
          decoder_layers = [
              MllamaCrossAttentionDecoderLayer(config, layer_idx)
              if layer_idx in self.cross_attention_layers
              else MllamaSelfAttentionDecoderLayer(config, layer_idx)
              for layer_idx in range(config.num_hidden_layers)
          ]
          self.layers = nn.ModuleList(decoder_layers)

          self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.rotary_emb = MllamaRotaryEmbedding(config=config)
          self.gradient_checkpointing = False
          self.post_init()

      @merge_with_config_defaults
      @capture_outputs
      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          cross_attention_states: torch.FloatTensor | None = None,
          cross_attention_mask: torch.Tensor | None = None,
          full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor] | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          use_cache: bool | None = None,
          **kwargs: Unpack[FlashAttentionKwargs],
      ) -> BaseModelOutputWithPast:
          r"""
          cross_attention_states (`torch.FloatTensor`, *optional*):
              Output of the vision model, used for cross-attention. This tensor contains the processed image features that
              the language model will attend to.
          cross_attention_mask (`torch.Tensor` of shape `(batch_size, seq_length, max_num_images, max_num_tiles)`, *optional*):
              Cross-attention mask to control the interaction between text tokens and image tiles.
              This 4D tensor defines which image tiles each text token should attend to.

              For each text token (in seq_length):
              - 1 indicates the token **should attend** to the corresponding image tile
              - 0 indicates the token **should not attend** to the corresponding image tile
          full_text_row_masked_out_mask (`tuple[torch.Tensor, torch.Tensor]`, *optional*):
              A tuple containing two tensors that mask out rows in the cross-attention mechanism:
              - The first tensor has shape `(batch_size, 1, seq_length, 1)` and contains values of 0 or 1.
                A value of 0 indicates that the corresponding text token's entire row in the cross-attention
                matrix should be masked out (all image tokens ignored).
              - The second tensor has the same shape and is used internally to apply the masking during
                the forward pass of cross-attention layers.
              This mask is derived from the cross_attention_mask and is used to handle cases where a text token
              should not attend to any image token.

          Example:

          ```python
          >>> from transformers import AutoProcessor, MllamaTextModel

          >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
          >>> model = MllamaTextModel.from_pretrained(checkpoint)
          >>> processor = AutoProcessor.from_pretrained(checkpoint)

          >>> text = "<|image|>If I had to write a haiku for this one"
          >>> inputs = processor(text=text, return_tensors="pt")

          >>> output = model(**inputs)

          >>> print(output.last_hidden_state.shape)
          torch.Size([1, 13, 4096])
          ```
          """
          if use_cache is None:
              use_cache = self.config.use_cache

          # Exactly one of the two inputs may be given: both present or both absent is an error.
          if (input_ids is None) ^ (inputs_embeds is not None):
              raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

          if inputs_embeds is None:
              inputs_embeds = self.embed_tokens(input_ids)

          hidden_states = inputs_embeds

          if use_cache and past_key_values is None:
              past_key_values = DynamicCache(config=self.config)

          if position_ids is None:
              # Default positions continue after whatever the cache has already seen.
              past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
              new_positions = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
              position_ids = new_positions.unsqueeze(0)

          causal_mask = create_causal_mask(
              config=self.config,
              inputs_embeds=inputs_embeds,
              attention_mask=attention_mask,
              past_key_values=past_key_values,
              position_ids=position_ids,
          )

          position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

          # decoder layers
          for idx, decoder_layer in enumerate(self.layers):
              # Text-only path: a cross-attention layer is skipped when there are neither fresh
              # cross-attention states nor cached ones for that layer.
              is_cross_attention_layer = idx in self.cross_attention_layers
              is_cross_attention_cache_empty = past_key_values is None or past_key_values.get_seq_length(idx) == 0

              if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty:
                  continue

              hidden_states = decoder_layer(
                  hidden_states,
                  cross_attention_states=cross_attention_states,
                  cross_attention_mask=cross_attention_mask,
                  attention_mask=causal_mask,
                  full_text_row_masked_out_mask=full_text_row_masked_out_mask,
                  position_ids=position_ids,
                  past_key_values=past_key_values,
                  use_cache=use_cache,
                  position_embeddings=position_embeddings,
                  **kwargs,
              )

          hidden_states = self.norm(hidden_states)

          return BaseModelOutputWithPast(
              last_hidden_state=hidden_states,
              past_key_values=past_key_values,
          )
  @auto_docstring(
      custom_intro="""
      The Mllama Text Model with a language modeling head on top.
      """
  )
  class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin):
      config: MllamaTextConfig
      _can_compile_fullgraph = True  # only the LLM without cross attn can do compile
      base_model_prefix = "language_model"

      def __init__(self, config):
          # `get_text_config()` is used both for the base-class init and for the sub-modules below.
          super().__init__(config.get_text_config())
          self.text_config = config.get_text_config()
          self.vocab_size = self.text_config.vocab_size
          self.model = MllamaTextModel._from_config(self.text_config)
          self.lm_head = nn.Linear(self.text_config.hidden_size, self.vocab_size, bias=False)

          self.post_init()

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          cross_attention_states: torch.FloatTensor | None = None,
          cross_attention_mask: torch.Tensor | None = None,
          full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor] | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          use_cache: bool | None = None,
          logits_to_keep: int | torch.Tensor = 0,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | CausalLMOutputWithPast:
          r"""
          cross_attention_states (`torch.FloatTensor`, *optional*):
              Output of the vision model, used for cross-attention. This tensor contains the processed image features that
              the language model will attend to.
          cross_attention_mask (`torch.Tensor` of shape `(batch_size, seq_length, max_num_images, max_num_tiles)`, *optional*):
              Cross-attention mask to control the interaction between text tokens and image tiles.
              This 4D tensor defines which image tiles each text token should attend to.

              For each text token (in seq_length):
              - 1 indicates the token **should attend** to the corresponding image tile
              - 0 indicates the token **should not attend** to the corresponding image tile
          full_text_row_masked_out_mask (`tuple[torch.Tensor, torch.Tensor]`, *optional*):
              A tuple containing two tensors that mask out rows in the cross-attention mechanism:
              - The first tensor has shape `(batch_size, 1, seq_length, 1)` and contains values of 0 or 1.
                A value of 0 indicates that the corresponding text token's entire row in the cross-attention
                matrix should be masked out (all image tokens ignored).
              - The second tensor has the same shape and is used internally to apply the masking during
                the forward pass of cross-attention layers.
              This mask is derived from the cross_attention_mask and is used to handle cases where a text token
              should not attend to any image token.
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
              config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
              (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

          Example:

          ```python
          >>> from transformers import AutoTokenizer, MllamaForCausalLM

          >>> model = MllamaForCausalLM.from_pretrained("meta-llama/Llama-3.2-11B-Vision")
          >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-11B-Vision")

          >>> prompt = "If I had to write a haiku, it would be:"
          >>> inputs = tokenizer(prompt, return_tensors="pt")

          >>> # Generate
          >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
          >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
          >>> print(result)
          If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
          I love the idea of snowflakes gently falling, each one
          ```
          """
          # Run the text decoder; only `last_hidden_state`, `past_key_values`, `hidden_states` and
          # `attentions` are read from its output below.
          outputs = self.model(
              input_ids=input_ids,
              cross_attention_states=cross_attention_states,
              attention_mask=attention_mask,
              position_ids=position_ids,
              cross_attention_mask=cross_attention_mask,
              full_text_row_masked_out_mask=full_text_row_masked_out_mask,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              **kwargs,
          )

          hidden_states = outputs.last_hidden_state
          # An int keeps the last `logits_to_keep` positions (`0` gives `slice(0, None)`, i.e. all of them);
          # any other value is used directly as the index along the sequence axis.
          slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
          logits = self.lm_head(hidden_states[:, slice_indices, :]).float()

          loss = None
          if labels is not None:
              loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

          return CausalLMOutputWithPast(
              loss=loss,
              logits=logits,
              past_key_values=outputs.past_key_values,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  1079. @auto_docstring(
  1080. custom_intro="""
  1081. The Mllama model which consists of a vision encoder and a language model without language modeling head.
  1082. """
  1083. )
  1084. class MllamaModel(MllamaPreTrainedModel):
  def __init__(self, config: MllamaConfig):
      super().__init__(config)
      text_config = config.text_config
      vision_config = config.vision_config

      # Frequently used sizes, copied from the two sub-configs.
      self.vocab_size = text_config.vocab_size
      self.hidden_size = text_config.hidden_size
      self.max_num_tiles = vision_config.max_num_tiles
      self.vision_output_dim = vision_config.vision_output_dim

      self.vision_model = MllamaVisionModel._from_config(vision_config)
      self.language_model = MllamaTextModel._from_config(text_config)
      # Linear map from the vision output width to the text hidden width.
      self.multi_modal_projector = nn.Linear(
          vision_config.vision_output_dim,
          text_config.hidden_size,
          bias=True,
      )
      self.post_init()
  def get_input_embeddings(self):
      """Return the input embeddings of the wrapped language model."""
      text_model = self.language_model
      return text_model.get_input_embeddings()
  def set_input_embeddings(self, value):
      """Hand `value` to the wrapped language model as its new input embeddings."""
      text_model = self.language_model
      text_model.set_input_embeddings(value)
  @can_return_tuple
  @auto_docstring
  def forward(
      self,
      input_ids: torch.LongTensor | None = None,
      pixel_values: torch.FloatTensor | None = None,
      aspect_ratio_mask: torch.Tensor | None = None,
      aspect_ratio_ids: torch.Tensor | None = None,
      attention_mask: torch.Tensor | None = None,
      cross_attention_mask: torch.Tensor | None = None,
      cross_attention_states: torch.Tensor | None = None,
      position_ids: torch.LongTensor | None = None,
      past_key_values: Cache | None = None,
      inputs_embeds: torch.FloatTensor | None = None,
      use_cache: bool | None = None,
      **kwargs: Unpack[FlashAttentionKwargs],
  ) -> BaseModelOutputWithPast:
      r"""
      aspect_ratio_mask (`torch.Tensor` of shape `(batch_size, max_num_images, max_num_tiles)`, *optional*):
          Mask to avoid performing attention on padding tiles. Mask values selected in `[0, 1]`:

          - 1 for tiles that are **not masked**,
          - 0 for tiles that are **masked**.
      aspect_ratio_ids (`torch.Tensor` of shape `(batch_size, max_num_images)`, *optional*):
          Aspect ratio ids used to select the appropriate precomputed tile embeddings based on the aspect ratio of each input image.
          These ids correspond to indices in the model's list of supported aspect ratios, offset by 1.

          For example, if the model supports aspect ratios [[1, 1], [1, 2], [2, 1]]:
          - An image with aspect ratio [1, 1] would have ID 1
          - An image with aspect ratio [1, 2] would have ID 2
          - An image with aspect ratio [2, 1] would have ID 3

          The id 0 is reserved for padding (i.e., no image).

          If an image has aspect ratio [1, 2], that means it was split into 2 tiles horizontally, and its `aspect_ratio_id` would be 2.
      cross_attention_mask (`torch.Tensor` of shape `(batch_size, seq_length, max_num_images, max_num_tiles)`, *optional*):
          Cross-attention mask to control the interaction between text tokens and image tiles.
          This 4D tensor defines which image tiles each text token should attend to.

          For each text token (in seq_length):
          - 1 indicates the token **should attend** to the corresponding image tile
          - 0 indicates the token **should not attend** to the corresponding image tile
      cross_attention_states (`torch.FloatTensor`, *optional*):
          Output of the vision model, used for cross-attention. This tensor contains the processed image features that
          the language model will attend to.
      """
      # Exactly one token source is allowed: reject both "neither given" and "both given".
      if (input_ids is None) == (inputs_embeds is None):
          raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

      if pixel_values is not None:
          # Raw images and precomputed vision features are mutually exclusive inputs.
          if cross_attention_states is not None:
              raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously")
          if aspect_ratio_ids is None:
              raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")

          # Encode the images, then project the vision features to the text hidden size.
          vision_outputs = self.vision_model(
              pixel_values=pixel_values,
              aspect_ratio_ids=aspect_ratio_ids,
              aspect_ratio_mask=aspect_ratio_mask,
          )
          vision_features = vision_outputs.last_hidden_state
          projected_features = self.multi_modal_projector(vision_features)
          # Collapse all leading axes; keep the second-to-last axis of the vision output as-is.
          cross_attention_states = projected_features.reshape(
              -1, vision_features.shape[-2], self.hidden_size
          )

      full_text_row_masked_out_mask = None
      if cross_attention_mask is not None:
          cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
              cross_attention_mask,
              num_vision_tokens=self.vision_model.num_patches,
              dtype=self.dtype,
          )

      # Re-checked on purpose: this tests the mask returned by the helper above.
      if cross_attention_mask is not None:
          if past_key_values is None:
              tokens_already_cached = 0
          else:
              tokens_already_cached = past_key_values.get_seq_length()

          token_source = inputs_embeds if input_ids is None else input_ids
          # Positions of the tokens processed in this call, offset by what the cache already holds.
          step_positions = (
              torch.arange(token_source.shape[1], device=token_source.device) + tokens_already_cached
          )
          # Keep only the mask entries for the current positions (selected along dim 2).
          cross_attention_mask = cross_attention_mask[:, :, step_positions]
          full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, step_positions]

      text_outputs = self.language_model(
          input_ids=input_ids,
          attention_mask=attention_mask,
          position_ids=position_ids,
          cross_attention_states=cross_attention_states,
          cross_attention_mask=cross_attention_mask,
          full_text_row_masked_out_mask=full_text_row_masked_out_mask,
          past_key_values=past_key_values,
          use_cache=use_cache,
          inputs_embeds=inputs_embeds,
          **kwargs,
      )

      # Repackage the text model's result into this model's declared output type.
      return BaseModelOutputWithPast(
          last_hidden_state=text_outputs.last_hidden_state,
          past_key_values=text_outputs.past_key_values,
          hidden_states=text_outputs.hidden_states,
          attentions=text_outputs.attentions,
      )
  @auto_docstring(
      custom_intro="""
      The Mllama model which consists of a vision encoder and a language model.
      """,
  )
  class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
      # NOTE: weight tying between the LM head and the input embeddings is left disabled;
      # the intended mapping is kept here for reference only.
      # _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}

      def __init__(self, config: MllamaConfig):
          """Build the multimodal backbone and a bias-free linear LM head on top of it."""
          super().__init__(config)
          text_config = config.text_config
          self.model = MllamaModel(config)
          self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)

          self.post_init()

      def get_input_embeddings(self):
          """Return the input embeddings exposed by the wrapped `MllamaModel`."""
          backbone = self.model
          return backbone.get_input_embeddings()

      def set_input_embeddings(self, value):
          """Install `value` as the input embeddings of the wrapped `MllamaModel`."""
          backbone = self.model
          backbone.set_input_embeddings(value)

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          pixel_values: torch.FloatTensor | None = None,
          aspect_ratio_mask: torch.Tensor | None = None,
          aspect_ratio_ids: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          cross_attention_mask: torch.Tensor | None = None,
          cross_attention_states: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          use_cache: bool | None = None,
          logits_to_keep: int | torch.Tensor = 0,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | CausalLMOutputWithPast:
          r"""
          aspect_ratio_mask (`torch.Tensor` of shape `(batch_size, max_num_images, max_num_tiles)`, *optional*):
              Mask to avoid performing attention on padding tiles. Mask values selected in `[0, 1]`:

              - 1 for tiles that are **not masked**,
              - 0 for tiles that are **masked**.
          aspect_ratio_ids (`torch.Tensor` of shape `(batch_size, max_num_images)`, *optional*):
              Aspect ratio ids used to select the appropriate precomputed tile embeddings based on the aspect ratio of each input image.
              These ids correspond to indices in the model's list of supported aspect ratios, offset by 1.

              For example, if the model supports aspect ratios [[1, 1], [1, 2], [2, 1]]:
              - An image with aspect ratio [1, 1] would have ID 1
              - An image with aspect ratio [1, 2] would have ID 2
              - An image with aspect ratio [2, 1] would have ID 3

              The id 0 is reserved for padding (i.e., no image).

              If an image has aspect ratio [1, 2], that means it was split into 2 tiles horizontally, and its `aspect_ratio_id` would be 2.
          cross_attention_mask (`torch.Tensor` of shape `(batch_size, seq_length, max_num_images, max_num_tiles)`, *optional*):
              Cross-attention mask to control the interaction between text tokens and image tiles.
              This 4D tensor defines which image tiles each text token should attend to.

              For each text token (in seq_length):
              - 1 indicates the token **should attend** to the corresponding image tile
              - 0 indicates the token **should not attend** to the corresponding image tile
          cross_attention_states (`torch.FloatTensor`, *optional*):
              Output of the vision model, used for cross-attention. This tensor contains the processed image features that
              the language model will attend to.
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
              config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
              (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

          Example:

          ```python
          >>> from PIL import Image
          >>> import httpx
          >>> from io import BytesIO
          >>> from transformers import AutoProcessor, MllamaForConditionalGeneration

          >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
          >>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
          >>> processor = AutoProcessor.from_pretrained(checkpoint)

          >>> prompt = "<|image|>If I had to write a haiku for this one"
          >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
          >>> with httpx.stream("GET", url) as response:
          ...     image = Image.open(BytesIO(response.read()))

          >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

          >>> # Generate
          >>> output = model.generate(**inputs, max_new_tokens=15)

          >>> prompt_len = inputs.input_ids.shape[-1]
          >>> generated_ids = output[:, prompt_len:]
          >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
          >>> print(generated_text)
          [', it would be:.\\nA stop sign in Chinatown.\\n']
          ```
          """
          backbone_outputs = self.model(
              input_ids=input_ids,
              pixel_values=pixel_values,
              aspect_ratio_mask=aspect_ratio_mask,
              aspect_ratio_ids=aspect_ratio_ids,
              cross_attention_mask=cross_attention_mask,
              cross_attention_states=cross_attention_states,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              **kwargs,
          )
          hidden_states = backbone_outputs.last_hidden_state

          # Only run the LM head on the requested positions: an int keeps the trailing
          # `logits_to_keep` positions (0 keeps all, since slice(-0, None) is the full range);
          # any other value is used directly as the index along the sequence axis.
          if isinstance(logits_to_keep, int):
              kept_positions = slice(-logits_to_keep, None)
          else:
              kept_positions = logits_to_keep
          logits = self.lm_head(hidden_states[:, kept_positions, :])

          loss = None
          if labels is not None:
              loss = self.loss_function(logits, labels, self.config.text_config.vocab_size, **kwargs)

          return CausalLMOutputWithPast(
              loss=loss,
              logits=logits,
              past_key_values=backbone_outputs.past_key_values,
              hidden_states=backbone_outputs.hidden_states,
              attentions=backbone_outputs.attentions,
          )

      def prepare_inputs_for_generation(
          self,
          input_ids=None,
          inputs_embeds=None,
          attention_mask=None,
          position_ids=None,
          pixel_values=None,
          aspect_ratio_ids=None,
          aspect_ratio_mask=None,
          cross_attention_mask=None,
          past_key_values=None,
          use_cache=False,
          logits_to_keep=None,
          is_first_iteration=False,
          **kwargs,
      ):
          """Assemble the per-step model inputs, dropping vision inputs once they are cached.

          Overrides the parent implementation because, in some generation steps, the image
          inputs must not be forwarded to the model again.
          """
          model_inputs = super().prepare_inputs_for_generation(
              input_ids,
              past_key_values=past_key_values,
              use_cache=use_cache,
              inputs_embeds=inputs_embeds,
              position_ids=position_ids,
              attention_mask=attention_mask,
              pixel_values=pixel_values,
              aspect_ratio_ids=aspect_ratio_ids,
              aspect_ratio_mask=aspect_ratio_mask,
              cross_attention_mask=cross_attention_mask,
              logits_to_keep=logits_to_keep,
              is_first_iteration=is_first_iteration,
              **kwargs,
          )

          # The pre-fill step and cache-less decoding need the pixel values and aspect ratios to
          # compute the image hidden states. On later cached steps those states already live in
          # each cross-attention layer's cache, so the vision inputs are blanked out.
          if use_cache and not is_first_iteration:
              for vision_key in ("pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"):
                  model_inputs[vision_key] = None

          return model_inputs

      def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
          """Run the parent update, then grow the cross-attention mask by one text position."""
          # Capture the mask before delegating to the parent implementation.
          previous_mask = model_kwargs.get("cross_attention_mask")

          model_kwargs = super()._update_model_kwargs_for_generation(
              outputs=outputs,
              model_kwargs=model_kwargs,
              is_encoder_decoder=is_encoder_decoder,
              **kwargs,
          )

          if previous_mask is None:
              return model_kwargs

          # The newly generated token reuses the mask row of the last existing position.
          last_position_row = previous_mask[:, -1:, ...]
          model_kwargs["cross_attention_mask"] = torch.cat([previous_mask, last_position_row], dim=1)
          return model_kwargs
  # Public API of this module: the names exported by a star-import of it.
  __all__ = [
      "MllamaForConditionalGeneration",
      "MllamaForCausalLM",
      "MllamaTextModel",
      "MllamaVisionModel",
      "MllamaPreTrainedModel",
      "MllamaModel",
  ]