modeling_llama4.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418
  1. # Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import math
  16. from collections.abc import Callable
  17. from dataclasses import dataclass
  18. from typing import Optional
  19. import torch
  20. import torch.nn as nn
  21. import torch.nn.functional as F
  22. from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig
  23. from ... import initialization as init
  24. from ...activations import ACT2FN
  25. from ...cache_utils import Cache, DynamicCache
  26. from ...generation import GenerationMixin
  27. from ...integrations import use_kernel_forward_from_hub
  28. from ...masking_utils import create_causal_mask, create_chunked_causal_mask
  29. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import (
  32. BaseModelOutput,
  33. BaseModelOutputWithPast,
  34. BaseModelOutputWithPooling,
  35. CausalLMOutputWithPast,
  36. ModelOutput,
  37. )
  38. from ...modeling_rope_utils import (
  39. ROPE_INIT_FUNCTIONS,
  40. dynamic_rope_update,
  41. )
  42. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  43. from ...processing_utils import Unpack
  44. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check
  45. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  46. from ...utils.output_capturing import capture_outputs
  47. from .configuration_llama4 import Llama4Config, Llama4TextConfig
# Module-level logger obtained through the library's logging utilities.
logger = logging.get_logger(__name__)
  49. class Llama4TextExperts(nn.Module):
  50. def __init__(self, config: Llama4TextConfig):
  51. super().__init__()
  52. self.num_experts = config.num_local_experts
  53. self.intermediate_size = config.intermediate_size
  54. self.hidden_size = config.hidden_size
  55. self.expert_dim = self.intermediate_size
  56. self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
  57. self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
  58. self.act_fn = ACT2FN[config.hidden_act]
  59. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  60. """
  61. This should really not be run on a single machine, as we are reaching compute bound:
  62. - the inputs are expected to be "sorted" per expert already.
  63. - the weights are viewed with another dim, to match num_expert, 1, shape * num_tokens, shape
  64. Args:
  65. hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
  66. selected_experts (torch.Tensor): (batch_size * token_num, top_k)
  67. routing_weights (torch.Tensor): (batch_size * token_num, top_k)
  68. Returns:
  69. torch.Tensor
  70. """
  71. hidden_states = hidden_states.view(self.gate_up_proj.shape[0], -1, self.hidden_size)
  72. gate_up = torch.bmm(hidden_states, self.gate_up_proj)
  73. gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors
  74. next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj)
  75. next_states = next_states.view(-1, self.hidden_size)
  76. return next_states
  77. # Phi3MLP
  78. class Llama4TextMLP(nn.Module):
  79. def __init__(self, config, intermediate_size=None):
  80. super().__init__()
  81. if intermediate_size is None:
  82. intermediate_size = config.intermediate_size
  83. self.config = config
  84. self.gate_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
  85. self.up_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
  86. self.down_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)
  87. self.activation_fn = ACT2FN[config.hidden_act]
  88. def forward(self, x):
  89. down_proj = self.activation_fn(self.gate_proj(x)) * self.up_proj(x)
  90. return self.down_proj(down_proj)
  91. class Llama4TextL2Norm(torch.nn.Module):
  92. def __init__(self, eps: float = 1e-6):
  93. super().__init__()
  94. self.eps = eps
  95. def _norm(self, x):
  96. return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  97. def forward(self, x):
  98. return self._norm(x.float()).type_as(x)
  99. def extra_repr(self):
  100. return f"eps={self.eps}"
  101. class Llama4TextRMSNorm(nn.Module):
  102. def __init__(self, hidden_size, eps=1e-5):
  103. """
  104. Llama4RMSNorm is equivalent to T5LayerNorm
  105. """
  106. super().__init__()
  107. self.eps = eps
  108. self.weight = nn.Parameter(torch.ones(hidden_size))
  109. def _norm(self, x):
  110. return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  111. def forward(self, x):
  112. output = self._norm(x.float()).type_as(x)
  113. return output * self.weight
  114. def extra_repr(self):
  115. return f"{tuple(self.weight.shape)}, eps={self.eps}"
  116. class Llama4Router(nn.Linear):
  117. def __init__(self, config):
  118. super().__init__(config.hidden_size, config.num_local_experts, bias=False)
  119. self.num_experts = config.num_local_experts
  120. self.top_k = config.num_experts_per_tok
  121. def forward(self, hidden_states):
  122. router_logits = super().forward(hidden_states)
  123. router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
  124. router_scores = torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value)
  125. router_scores = torch.nn.functional.sigmoid(router_scores.float()).to(router_scores.dtype)
  126. return router_scores, router_logits
  127. @use_kernel_forward_from_hub("Llama4TextMoe")
  128. class Llama4TextMoe(nn.Module):
  129. def __init__(self, config):
  130. super().__init__()
  131. self.top_k = config.num_experts_per_tok
  132. self.hidden_dim = config.hidden_size
  133. self.num_experts = config.num_local_experts
  134. self.experts = Llama4TextExperts(config)
  135. self.router = Llama4Router(config)
  136. self.shared_expert = Llama4TextMLP(config)
  137. def forward(self, hidden_states):
  138. hidden_states = hidden_states.reshape(-1, self.hidden_dim)
  139. router_scores, router_logits = self.router(hidden_states)
  140. routed_in = hidden_states.repeat(router_scores.shape[1], 1)
  141. routed_in = routed_in * router_scores.transpose(0, 1).reshape(-1, 1)
  142. routed_out = self.experts(routed_in)
  143. out = self.shared_expert(hidden_states)
  144. out.add_(routed_out.reshape(router_scores.shape[1], -1, routed_out.shape[-1]).sum(dim=0))
  145. return out, router_logits
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Llama4Text
class Llama4TextRotaryEmbedding(nn.Module):
    """Rotary position embedding for the Llama4 text model.

    `forward` returns a single complex tensor: unit-magnitude rotations `exp(i * position * inv_freq)`
    multiplied by `attention_scaling`. `Llama4TextAttention` passes it to `apply_rotary_emb`.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    # Ignore copy
    def __init__(self, config: Llama4TextConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        self.rope_type = self.config.rope_parameters["rope_type"]
        # "default" uses the local static method below; any other type is looked up in ROPE_INIT_FUNCTIONS.
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Non-persistent buffers: they follow the module's device but are not stored in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Untouched copy of the initial frequencies (presumably restored by `dynamic_rope_update`; confirm).
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: Llama4TextConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        # Use the explicit head_dim when set (and truthy); otherwise derive it from hidden_size.
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies: base ** (-2k / dim) for k = 0, 1, ..., one per rotated pair.
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    # Ignore copy
    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """
        Compute complex rotation factors for `position_ids`.

        Args:
            x (`torch.Tensor`): only its device is read in this body.
            position_ids (`torch.Tensor`): 2-D `(batch, seq_len)` positions.

        Returns:
            `torch.Tensor`: complex tensor of shape `(batch, seq_len, len(inv_freq))`.
        """
        # (batch, n_freq, 1) @ (batch, 1, seq_len) -> (batch, n_freq, seq_len), transposed below.
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # Convert to complex representation
            freqs_cis = freqs_cis * self.attention_scaling

        return freqs_cis
  201. def apply_rotary_emb(
  202. xq: torch.Tensor,
  203. xk: torch.Tensor,
  204. freqs_cis: torch.Tensor,
  205. ) -> tuple[torch.Tensor, torch.Tensor]:
  206. xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
  207. xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
  208. xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
  209. xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
  210. return xq_out.type_as(xq), xk_out.type_as(xk)
  211. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  212. """
  213. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  214. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  215. """
  216. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  217. if n_rep == 1:
  218. return hidden_states
  219. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  220. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  221. # Adapted from transformers.models.llama.modeling_llama.eager_attention_forward -> llama4 doesn't cast attn weights to fp32
  222. def eager_attention_forward(
  223. module: nn.Module,
  224. query: torch.Tensor,
  225. key: torch.Tensor,
  226. value: torch.Tensor,
  227. attention_mask: torch.Tensor | None,
  228. scaling: float,
  229. dropout: float = 0.0,
  230. **kwargs,
  231. ):
  232. key_states = repeat_kv(key, module.num_key_value_groups)
  233. value_states = repeat_kv(value, module.num_key_value_groups)
  234. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  235. if attention_mask is not None:
  236. attn_weights = attn_weights + attention_mask
  237. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  238. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  239. attn_output = torch.matmul(attn_weights, value_states)
  240. attn_output = attn_output.transpose(1, 2).contiguous()
  241. return attn_output, attn_weights
  242. # Adapted from transformers.models.llama.modeling_llama.eager_attention_forward -> llama4 doesn't cast attn weights to fp32
  243. def vision_eager_attention_forward(
  244. module: nn.Module,
  245. query: torch.Tensor,
  246. key: torch.Tensor,
  247. value: torch.Tensor,
  248. attention_mask: torch.Tensor | None,
  249. scaling: float,
  250. dropout: float = 0.0,
  251. **kwargs,
  252. ):
  253. key_states = repeat_kv(key, module.num_key_value_groups)
  254. value_states = repeat_kv(value, module.num_key_value_groups)
  255. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * module.head_dim**-0.5
  256. if attention_mask is not None:
  257. attn_weights = attn_weights + attention_mask
  258. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  259. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  260. attn_output = torch.matmul(attn_weights, value_states)
  261. attn_output = attn_output.transpose(1, 2).contiguous()
  262. return attn_output, attn_weights
class Llama4TextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper, as used by the Llama4 text model.

    Features: grouped key/value heads, RoPE on layers flagged in `config.no_rope_layers`, an optional
    parameter-free QK norm on those RoPE layers, and attention-temperature tuning on the remaining
    (NoRoPE) layers.
    """

    def __init__(self, config: Llama4TextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_attention_heads = config.num_attention_heads
        # Query heads per key/value head (read by `repeat_kv` in the eager attention path).
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        # Attention-temperature tuning parameters, used in `forward` only on layers without RoPE.
        self.attn_scale = config.attn_scale
        self.floor_scale = config.floor_scale
        self.attn_temperature_tuning = config.attn_temperature_tuning
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        # NOTE: despite the config field's name, a truthy entry means this layer DOES apply RoPE.
        self.use_rope = config.no_rope_layers[layer_idx]
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        # The QK norm exists only on RoPE layers; `forward` detects it with `hasattr`.
        if self.config.use_qk_norm and self.use_rope:
            self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor,
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Args:
            hidden_states: `(batch, seq_len, hidden_size)` input.
            position_embeddings: complex RoPE factors (moved to the query device). Used only when `use_rope`.
            attention_mask: mask passed unchanged to the selected attention implementation.
            past_key_values: optional cache. Its `update` receives this layer's new keys/values and
                returns the keys/values that are attended over.

        Returns:
            `(attn_output, attn_weights)`: `attn_output` has shape `(batch, seq_len, hidden_size)`.
            `attn_weights` is whatever the attention implementation returns.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Queries and keys stay (batch, seq, heads, head_dim) until RoPE / norm / temperature are applied;
        # values go straight to (batch, heads, seq, head_dim).
        query_states = self.q_proj(hidden_states).view(hidden_shape)
        key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        if self.use_rope:  # the 16E model skips rope for long context on certain layers
            query_states, key_states = apply_rotary_emb(
                query_states, key_states, position_embeddings.to(query_states.device)
            )

        if hasattr(self, "qk_norm"):  # the 128E model does not use qk_norm
            query_states = self.qk_norm(query_states)
            key_states = self.qk_norm(key_states)

        # Apply temperature tuning (https://huggingface.co/papers/2501.19399) on NoRoPE layers.
        if self.attn_temperature_tuning and not self.use_rope:
            # Absolute positions of the current tokens, offset by the tokens already in the cache.
            past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0
            positions = torch.arange(hidden_states.shape[1], device=hidden_states.device) + past_seen_tokens
            # scale = 1 + attn_scale * log1p(floor((pos + 1) / floor_scale)), one value per position.
            attn_scales = (
                torch.log1p(torch.floor((positions.float() + 1.0) / self.floor_scale)) * self.attn_scale + 1.0
            )
            attn_scales = attn_scales.view((1, input_shape[-1], 1, 1)).expand((*input_shape, 1, 1))  # batch size > 1
            query_states = (query_states * attn_scales).to(query_states.dtype)

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        # Configured backend, with the local eager implementation passed as the fallback.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge heads back into the hidden dimension and project out.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
  343. class Llama4TextDecoderLayer(GradientCheckpointingLayer):
  344. def __init__(self, config, layer_idx):
  345. super().__init__()
  346. self.hidden_size = config.hidden_size
  347. self.layer_idx = layer_idx
  348. self.self_attn = Llama4TextAttention(config, layer_idx)
  349. self.is_moe_layer = layer_idx in config.moe_layers
  350. if self.is_moe_layer: # the 128E model interleaves dense / sparse
  351. self.feed_forward = Llama4TextMoe(config)
  352. else:
  353. self.feed_forward = Llama4TextMLP(config, intermediate_size=config.intermediate_size_mlp)
  354. self.input_layernorm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  355. self.post_attention_layernorm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  356. def forward(
  357. self,
  358. hidden_states: torch.Tensor,
  359. attention_mask: torch.Tensor | None = None,
  360. position_ids: torch.LongTensor | None = None,
  361. past_key_values: Cache | None = None,
  362. use_cache: bool | None = False,
  363. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  364. **kwargs: Unpack[FlashAttentionKwargs],
  365. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  366. residual = hidden_states
  367. hidden_states = self.input_layernorm(hidden_states)
  368. # Self Attention
  369. attention_states, _ = self.self_attn(
  370. hidden_states=hidden_states,
  371. position_embeddings=position_embeddings,
  372. attention_mask=attention_mask,
  373. past_key_values=past_key_values,
  374. use_cache=use_cache,
  375. **kwargs,
  376. )
  377. hidden_states = residual + attention_states
  378. # Fully Connected
  379. residual = hidden_states
  380. hidden_states = self.post_attention_layernorm(hidden_states)
  381. hidden_states = self.feed_forward(hidden_states)
  382. if self.is_moe_layer:
  383. hidden_states, _ = hidden_states
  384. hidden_states = residual + hidden_states.view(residual.shape)
  385. return hidden_states
@auto_docstring
class Llama4PreTrainedModel(PreTrainedModel):
    # Shared base for the Llama4 models: declares supported backends and features, and adds
    # Llama4-specific weight initialisation on top of the generic one.
    config: Llama4Config
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = ["past_key_values"]
    # Flash-attention is declared unsupported; SDPA and flex attention are declared supported.
    _supports_flash_attn = False
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialise `module`: generic parent-class init first, then the Llama4-specific tensors."""
        super()._init_weights(module)
        # Use the top-level `initializer_range` when present; otherwise fall back to the one on
        # `text_config` (presumably the multimodal `Llama4Config` case; confirm).
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )
        if isinstance(module, Llama4TextExperts):
            # Stacked expert weights are raw Parameters, so they need explicit init here.
            init.normal_(module.gate_up_proj, mean=0.0, std=std)
            init.normal_(module.down_proj, mean=0.0, std=std)
        elif isinstance(module, Llama4VisionRotaryEmbedding):
            # Recompute the vision rotary table from the module's config.
            init.copy_(module.freqs_ci, module._compute_freqs_ci(module.config))
        elif isinstance(module, Llama4VisionModel):
            init.normal_(module.class_embedding, std=module.scale)
            init.normal_(module.positional_embedding_vlm, std=module.scale)
@auto_docstring
class Llama4TextModel(Llama4PreTrainedModel):
    # Text-only Llama4 decoder stack: token embeddings -> `num_hidden_layers` decoder layers -> final RMSNorm.
    _no_split_modules = ["Llama4TextDecoderLayer"]
    base_model_prefix = "model"
    input_modalities = ("text",)
    config: Llama4TextConfig
    # Module classes whose outputs populate `attentions`, `hidden_states` and `router_logits`
    # (presumably collected by the `capture_outputs` decorator on `forward`; confirm).
    _can_record_outputs = {
        "attentions": Llama4TextAttention,
        "hidden_states": Llama4TextDecoderLayer,
        "router_logits": Llama4TextMoe,
    }

    def __init__(self, config: Llama4TextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Llama4TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # One rotary module; its output is computed once per forward and shared by all layers.
        self.rotary_emb = Llama4TextRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPast:
        # Exactly one of `input_ids` / `inputs_embeds` must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            # Move ids to the embedding table's device before lookup (presumably for multi-device placement).
            inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))

        # Lazily create a cache when caching is requested but none was passed in.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if position_ids is None:
            # Default positions continue after the cached tokens, with shape (1, seq_len).
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "inputs_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks: one per attention type; each layer picks its own via `config.layer_types`.
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "chunked_attention": create_chunked_causal_mask(**mask_kwargs),
            }

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        freq_cis = self.rotary_emb(hidden_states, position_ids)

        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                position_embeddings=freq_cis,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        # The cache is returned only when caching was requested.
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )
class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin):
    """Text-only Llama4 causal language model.

    A `Llama4TextModel` followed by a bias-free `lm_head` that projects hidden states to the
    vocabulary. `_tied_weights_keys` declares `lm_head.weight` tied to the input embeddings.
    """

    _no_split_modules = ["Llama4TextDecoderLayer"]
    base_model_prefix = "language_model"
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    # Tensor-parallel plan: shard lm_head column-wise and gather its output.
    _tp_plan = {"lm_head": "colwise_gather_output"}
    config: Llama4TextConfig

    def __init__(self, config: Llama4TextConfig):
        super().__init__(config)
        self.model = Llama4TextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Llama4ForCausalLM

        >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        # An int keeps the last `logits_to_keep` positions (0 keeps all, since slice(-0, None) == slice(0, None));
        # a tensor is used directly as an index along the sequence axis.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
  561. @dataclass
  562. @auto_docstring(
  563. custom_intro="""
  564. Base class for Llava causal language model (or autoregressive) outputs.
  565. """
  566. )
  567. class Llama4CausalLMOutputWithPast(ModelOutput):
  568. r"""
  569. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  570. Language modeling loss (for next-token prediction).
  571. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  572. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  573. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  574. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  575. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  576. `past_key_values` input) to speed up sequential decoding.
  577. image_hidden_states (`torch.FloatTensor`, *optional*):
  578. A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
  579. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
  580. """
  581. loss: torch.FloatTensor | None = None
  582. logits: torch.FloatTensor | None = None
  583. past_key_values: Cache | None = None
  584. hidden_states: tuple[torch.FloatTensor] | None = None
  585. attentions: tuple[torch.FloatTensor] | None = None
  586. image_hidden_states: torch.FloatTensor | None = None
  587. class Llama4VisionMLP2(torch.nn.Module):
  588. def __init__(self, config):
  589. super().__init__()
  590. self.hidden_size = config.hidden_size
  591. self.intermediate_size = config.intermediate_size
  592. self.fc1 = nn.Linear(self.intermediate_size, config.projector_input_dim, bias=False)
  593. self.fc2 = nn.Linear(config.projector_output_dim, config.projector_output_dim, bias=False)
  594. self.activation_fn = nn.GELU() # ACT2FN[config.hidden_act]
  595. self.dropout = config.projector_dropout
  596. def forward(self, hidden_states):
  597. hidden_states = self.fc1(hidden_states)
  598. hidden_states = self.activation_fn(hidden_states)
  599. hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
  600. return self.activation_fn(self.fc2(hidden_states))
  601. class Llama4MultiModalProjector(nn.Module):
  602. def __init__(self, config):
  603. super().__init__()
  604. self.linear_1 = nn.Linear(
  605. config.vision_config.vision_output_dim,
  606. config.text_config.hidden_size,
  607. bias=False,
  608. )
  609. def forward(self, image_features):
  610. hidden_states = self.linear_1(image_features)
  611. return hidden_states
  612. def pixel_shuffle(input_tensor, shuffle_ratio):
  613. # input_tensor: [batch_size, num_patches, channels]
  614. batch_size, num_patches, channels = input_tensor.shape
  615. patch_size = int(math.sqrt(num_patches))
  616. input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
  617. batch_size, height, width, channels = input_tensor.size()
  618. reshaped_tensor = input_tensor.view(batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio))
  619. reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
  620. reshaped_tensor = reshaped_tensor.view(
  621. batch_size, int(height * shuffle_ratio), int(width * shuffle_ratio), int(channels / (shuffle_ratio**2))
  622. )
  623. reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
  624. output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1])
  625. return output_tensor
  626. class Llama4VisionPixelShuffleMLP(nn.Module):
  627. def __init__(self, config):
  628. super().__init__()
  629. self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
  630. self.inner_dim = int(config.projector_input_dim // (self.pixel_shuffle_ratio**2))
  631. self.output_dim = config.projector_output_dim
  632. self.mlp = Llama4VisionMLP2(config)
  633. def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
  634. encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
  635. return self.mlp(encoded_patches)
  636. # TODO there is a different RoPE for vision encoder, defined as below
  637. def reshape_for_broadcast(freqs_ci: torch.Tensor, query: torch.Tensor):
  638. ndim = query.ndim
  639. shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(query.shape)]
  640. return freqs_ci.view(*shape)
  641. def vision_apply_rotary_emb(
  642. query: torch.Tensor,
  643. key: torch.Tensor,
  644. freqs_ci: torch.Tensor,
  645. ) -> tuple[torch.Tensor, torch.Tensor]:
  646. query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2))
  647. key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2))
  648. freqs_ci = reshape_for_broadcast(freqs_ci=freqs_ci, query=query_) # freqs_ci[:,:,None,:]
  649. freqs_ci = freqs_ci.to(query_.device)
  650. query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
  651. key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
  652. return query_out.type_as(query), key_out.type_as(key) # but this drops to 8e-3
  653. class Llama4VisionAttention(nn.Module):
  654. def __init__(self, config: Llama4VisionConfig):
  655. super().__init__()
  656. self.config = config
  657. self.embed_dim = config.hidden_size
  658. self.num_heads = config.num_attention_heads
  659. self.head_dim = config.hidden_size // config.num_attention_heads
  660. self.num_key_value_groups = 1
  661. self.attention_dropout = config.attention_dropout
  662. self.scaling = self.head_dim**-0.5
  663. self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=True)
  664. self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=True)
  665. self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=True)
  666. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.embed_dim, bias=True)
  667. def forward(
  668. self,
  669. hidden_states: torch.Tensor,
  670. freqs_ci: torch.Tensor,
  671. attention_mask: torch.Tensor | None = None,
  672. past_key_values: Cache | None = None,
  673. **kwargs: Unpack[FlashAttentionKwargs],
  674. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  675. input_shape = hidden_states.shape[:-1]
  676. hidden_shape = (*input_shape, -1, self.head_dim)
  677. query_states = self.q_proj(hidden_states).view(hidden_shape)
  678. key_states = self.k_proj(hidden_states).view(hidden_shape)
  679. value_states = self.v_proj(hidden_states).view(hidden_shape)
  680. query_states, key_states = vision_apply_rotary_emb(query_states, key_states, freqs_ci=freqs_ci)
  681. query_states = query_states.transpose(1, 2)
  682. key_states = key_states.transpose(1, 2)
  683. value_states = value_states.transpose(1, 2)
  684. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  685. self.config._attn_implementation, vision_eager_attention_forward
  686. )
  687. attn_output, attn_weights = attention_interface(
  688. self,
  689. query_states,
  690. key_states,
  691. value_states,
  692. None,
  693. dropout=0.0 if not self.training else self.attention_dropout,
  694. scaling=None, # TODO Might be enforced here for TP compatibility as scaling is not just sqrt(head_dim)
  695. is_causal=False, # HAS TO BE ENFORCED
  696. **kwargs,
  697. )
  698. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  699. attn_output = self.o_proj(attn_output)
  700. return attn_output, attn_weights
  701. class Llama4VisionMLP(nn.Module):
  702. def __init__(self, config):
  703. super().__init__()
  704. self.config = config
  705. self.activation_fn = nn.GELU() # ACT2FN[config.hidden_act]
  706. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
  707. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True)
  708. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  709. hidden_states = self.fc1(hidden_states)
  710. hidden_states = self.activation_fn(hidden_states)
  711. hidden_states = self.fc2(hidden_states)
  712. return hidden_states
  713. class Llama4VisionEncoderLayer(GradientCheckpointingLayer):
  714. def __init__(self, config: Llama4VisionConfig):
  715. super().__init__()
  716. self.hidden_size = config.hidden_size
  717. self.self_attn = Llama4VisionAttention(config)
  718. self.mlp = Llama4VisionMLP(config)
  719. self.input_layernorm = nn.LayerNorm(config.hidden_size)
  720. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
  721. def forward(
  722. self,
  723. hidden_state: torch.Tensor,
  724. freqs_ci: torch.Tensor,
  725. attention_mask: torch.Tensor | None = None,
  726. output_attentions: bool | None = None,
  727. ):
  728. # Self Attention
  729. residual = hidden_state
  730. hidden_state = self.input_layernorm(hidden_state)
  731. hidden_state, attn_weights = self.self_attn(
  732. hidden_state,
  733. freqs_ci=freqs_ci,
  734. attention_mask=attention_mask,
  735. )
  736. hidden_state = residual + hidden_state
  737. # Feed forward
  738. residual = hidden_state
  739. hidden_state = self.post_attention_layernorm(hidden_state)
  740. hidden_state = self.mlp(hidden_state)
  741. hidden_state = residual + hidden_state
  742. outputs = (hidden_state,)
  743. if output_attentions:
  744. outputs += (attn_weights,)
  745. return outputs
  746. class Llama4VisionEncoder(nn.Module):
  747. """
  748. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  749. [`Llama4VisionEncoderLayer`].
  750. Args:
  751. config: Llama4VisionConfig
  752. """
  753. def __init__(self, config: Llama4VisionConfig):
  754. super().__init__()
  755. self.config = config
  756. self.layers = nn.ModuleList([Llama4VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  757. self.gradient_checkpointing = False
  758. self.config = config
  759. def forward(
  760. self,
  761. hidden_states: torch.Tensor,
  762. freqs_ci: torch.Tensor, # TODO move this to an attribute instead of keeping it around
  763. attention_mask: torch.Tensor | None = None,
  764. output_attentions: bool | None = None,
  765. output_hidden_states: bool | None = None,
  766. return_dict: bool | None = None,
  767. ) -> tuple | BaseModelOutput:
  768. r"""
  769. Args:
  770. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  771. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  772. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  773. than the model's internal embedding lookup matrix.
  774. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  775. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  776. - 1 for tokens that are **not masked**,
  777. - 0 for tokens that are **masked**.
  778. [What are attention masks?](../glossary#attention-mask)
  779. output_attentions (`bool`, *optional*):
  780. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  781. returned tensors for more detail.
  782. output_hidden_states (`bool`, *optional*):
  783. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  784. for more detail.
  785. return_dict (`bool`, *optional*):
  786. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  787. """
  788. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  789. output_hidden_states = (
  790. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  791. )
  792. return_dict = return_dict if return_dict is not None else self.config.return_dict
  793. encoder_states = () if output_hidden_states else None
  794. all_attentions = () if output_attentions else None
  795. for encoder_layer in self.layers:
  796. if output_hidden_states:
  797. encoder_states = encoder_states + (hidden_states,)
  798. layer_outputs = encoder_layer(
  799. hidden_state=hidden_states,
  800. attention_mask=attention_mask,
  801. output_attentions=output_attentions,
  802. freqs_ci=freqs_ci,
  803. )
  804. if output_attentions:
  805. all_attentions = all_attentions + (layer_outputs[1],)
  806. hidden_states = layer_outputs[0]
  807. if output_hidden_states:
  808. encoder_states = encoder_states + (hidden_states,)
  809. if not return_dict:
  810. return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
  811. return BaseModelOutput(
  812. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  813. )
  814. class Llama4UnfoldConvolution(nn.Module):
  815. def __init__(self, config):
  816. super().__init__()
  817. kernel_size = config.patch_size
  818. if isinstance(kernel_size, int):
  819. kernel_size = (kernel_size, kernel_size)
  820. self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)
  821. self.linear = nn.Linear(
  822. config.num_channels * kernel_size[0] * kernel_size[1],
  823. config.hidden_size,
  824. bias=False,
  825. )
  826. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  827. hidden_states = self.unfold(hidden_states)
  828. hidden_states = hidden_states.permute(0, 2, 1)
  829. hidden_states = self.linear(hidden_states)
  830. return hidden_states
  831. class Llama4VisionRotaryEmbedding(nn.Module):
  832. def __init__(self, config: Llama4VisionConfig):
  833. super().__init__()
  834. self.config = config
  835. self.register_buffer("freqs_ci", self._compute_freqs_ci(config), persistent=False)
  836. @staticmethod
  837. def _compute_freqs_ci(config):
  838. idx = config.image_size // config.patch_size
  839. img_idx = torch.arange(idx**2, dtype=torch.int32).reshape(idx**2, 1)
  840. img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
  841. img_idx[-1, -1] = -2 # ID_CLS_TOKEN
  842. frequencies_x = img_idx % idx # get the coordinates of the 2d matrix along x
  843. frequencies_y = img_idx // idx # get the coordinates of the 2d matrix along y
  844. freq_dim = config.hidden_size // config.num_attention_heads // 2
  845. rope_freq = 1.0 / (
  846. config.rope_parameters["rope_theta"]
  847. ** (torch.arange(0, freq_dim, 2)[: (freq_dim // 2)].float() / freq_dim)
  848. )
  849. freqs_x = ((frequencies_x + 1)[..., None] * rope_freq[None, None, :]).repeat_interleave(2, dim=-1)
  850. freqs_y = ((frequencies_y + 1)[..., None] * rope_freq[None, None, :]).repeat_interleave(2, dim=-1)
  851. freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
  852. freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
  853. freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
  854. return freq_cis # idx**2, idx**2, idx * 2
  855. def forward(self, hidden_states):
  856. return self.freqs_ci.to(hidden_states.device)
  857. class Llama4VisionModel(Llama4PreTrainedModel):
  858. base_model_prefix = "vision_model"
  859. input_modalities = ("image",)
  860. _no_split_modules = ["Llama4VisionEncoderLayer"]
  861. config: Llama4VisionConfig
  862. def __init__(self, config: Llama4VisionConfig):
  863. super().__init__(config)
  864. self.image_size = config.image_size
  865. self.patch_size = config.patch_size
  866. self.hidden_size = config.hidden_size
  867. self.num_channels = config.num_channels
  868. self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
  869. self.scale = config.hidden_size**-0.5
  870. self.patch_embedding = Llama4UnfoldConvolution(config)
  871. self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
  872. self.positional_embedding_vlm = nn.Parameter(self.scale * torch.randn(self.num_patches, self.hidden_size))
  873. self.rotary_embedding = Llama4VisionRotaryEmbedding(config)
  874. # layer norms
  875. self.layernorm_pre = nn.LayerNorm(self.hidden_size)
  876. self.layernorm_post = nn.LayerNorm(self.hidden_size)
  877. # encoders
  878. self.model = Llama4VisionEncoder(config)
  879. self.vision_adapter = Llama4VisionPixelShuffleMLP(config)
  880. self.post_init()
  881. def get_input_embeddings(self):
  882. """
  883. This function is used to fetch the first embedding layer to activate grads on inputs.
  884. """
  885. return self.patch_embedding
  886. def forward(
  887. self,
  888. pixel_values: torch.Tensor,
  889. attention_mask: torch.Tensor | None = None,
  890. output_attentions: bool | None = None,
  891. output_hidden_states: bool | None = None,
  892. return_dict: bool | None = None,
  893. **kwargs,
  894. ) -> BaseModelOutputWithPooling | tuple[torch.Tensor, ...]:
  895. r"""
  896. Example:
  897. ```python
  898. >>> from PIL import Image
  899. >>> import httpx
  900. >>> from io import BytesIO
  901. >>> from transformers import AutoProcessor, MllamaVisionModel
  902. >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
  903. >>> model = MllamaVisionModel.from_pretrained(checkpoint)
  904. >>> processor = AutoProcessor.from_pretrained(checkpoint)
  905. >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
  906. >>> with httpx.stream("GET", url) as response:
  907. ... image = Image.open(BytesIO(response.read()))
  908. >>> inputs = processor(images=image, return_tensors="pt")
  909. >>> output = model(**inputs)
  910. >>> print(output.last_hidden_state.shape)
  911. torch.Size([1, 1, 4, 1025, 7680])
  912. ```
  913. """
  914. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  915. output_hidden_states = (
  916. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  917. )
  918. return_dict = return_dict if return_dict is not None else self.config.return_dict
  919. # num_concurrent_media and num_chunks are both currently 1
  920. batch_size_times_num_tiles, num_channels, height, width = pixel_values.shape
  921. num_concurrent_media = 1
  922. num_chunks = 1
  923. hidden_state = self.patch_embedding(pixel_values)
  924. _, num_patches, hidden_dim = hidden_state.shape
  925. # Add cls token
  926. hidden_state = hidden_state.reshape(
  927. batch_size_times_num_tiles * num_concurrent_media * num_chunks, num_patches, hidden_dim
  928. )
  929. class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1, hidden_state.shape[-1])
  930. hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
  931. num_patches += 1
  932. # Position embeddings
  933. hidden_state = hidden_state.reshape(
  934. batch_size_times_num_tiles * num_concurrent_media, num_chunks, num_patches, hidden_dim
  935. )
  936. positional_embedding = self.positional_embedding_vlm.to(dtype=hidden_state.dtype, device=hidden_state.device)
  937. hidden_state = hidden_state + positional_embedding
  938. hidden_state = self.layernorm_pre(hidden_state)
  939. hidden_state = hidden_state.view(batch_size_times_num_tiles, -1, hidden_dim)
  940. freqs_ci = self.rotary_embedding(pixel_values)
  941. output = self.model(
  942. hidden_state,
  943. attention_mask=None,
  944. output_hidden_states=output_hidden_states,
  945. output_attentions=output_attentions,
  946. freqs_ci=freqs_ci,
  947. )
  948. hidden_state = output.last_hidden_state
  949. hidden_state = self.layernorm_post(hidden_state)
  950. hidden_state = hidden_state[:, :-1, :]
  951. # now, we use Llama4VisionPixelShuffle + mlp to project embeddings
  952. hidden_state = self.vision_adapter(hidden_state)
  953. hidden_states = output.hidden_states if output_hidden_states else None
  954. if output_attentions:
  955. attentions = output[2]
  956. else:
  957. attentions = None
  958. if not return_dict:
  959. return tuple(v for v in [hidden_state, hidden_states, attentions] if v is not None)
  960. return BaseModelOutputWithPooling(
  961. last_hidden_state=hidden_state,
  962. hidden_states=hidden_states,
  963. attentions=attentions,
  964. )
  965. class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin):
  966. _no_split_modules = ["Llama4TextDecoderLayer", "Llama4VisionEncoderLayer"]
  967. _tp_plan = {}
  968. base_model_prefix = "model"
  969. config: Llama4Config
  970. def __init__(self, config: Llama4Config):
  971. super().__init__(config)
  972. self.vision_model = Llama4VisionModel(config.vision_config)
  973. self.multi_modal_projector = Llama4MultiModalProjector(config)
  974. self.language_model = Llama4ForCausalLM(config.text_config)
  975. self.vocab_size = config.text_config.vocab_size
  976. if hasattr(self.config, "pad_token_id"):
  977. self.pad_token_id = self.config.pad_token_id
  978. else:
  979. self.pad_token_id = self.config.text_config.pad_token_id or -1
  980. self.post_init()
  981. def get_input_embeddings(self):
  982. return self.language_model.get_input_embeddings()
  983. def set_input_embeddings(self, value):
  984. self.language_model.set_input_embeddings(value)
  985. def get_output_embeddings(self):
  986. return self.language_model.get_output_embeddings()
  987. def set_output_embeddings(self, new_embeddings):
  988. self.language_model.set_output_embeddings(new_embeddings)
  989. def set_decoder(self, decoder):
  990. self.language_model.set_decoder(decoder)
  991. def get_decoder(self):
  992. return self.language_model.get_decoder()
  993. @merge_with_config_defaults
  994. @capture_outputs(tie_last_hidden_states=False)
  995. @auto_docstring(custom_intro="Obtains image last hidden states from the vision tower and apply al projection.")
  996. def get_image_features(
  997. self,
  998. pixel_values: torch.FloatTensor,
  999. vision_feature_select_strategy: str,
  1000. **kwargs: Unpack[TransformersKwargs],
  1001. ) -> tuple | BaseModelOutputWithPooling:
  1002. r"""
  1003. pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
  1004. The tensors corresponding to the input images.
  1005. vision_feature_select_strategy (`str`):
  1006. The feature selection strategy used to select the vision feature from the vision backbone.
  1007. Can be one of `"default"` or `"full"`
  1008. """
  1009. kwargs = {k: v for k, v in kwargs.items() if v is not None}
  1010. return self.vision_model(pixel_values, **kwargs)
  1011. def get_placeholder_mask(
  1012. self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
  1013. ):
  1014. """
  1015. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  1016. equal to the length of multimodal features. If the lengths are different, an error is raised.
  1017. """
  1018. if input_ids is None:
  1019. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  1020. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  1021. )
  1022. special_image_mask = special_image_mask.all(-1)
  1023. else:
  1024. special_image_mask = input_ids == self.config.image_token_id
  1025. n_image_tokens = special_image_mask.sum()
  1026. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  1027. torch_compilable_check(
  1028. inputs_embeds[special_image_mask].numel() == image_features.numel(),
  1029. f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}",
  1030. )
  1031. return special_image_mask
  1032. @merge_with_config_defaults
  1033. @capture_outputs(tie_last_hidden_states=False)
  1034. @auto_docstring
  1035. def forward(
  1036. self,
  1037. input_ids: torch.LongTensor | None = None,
  1038. pixel_values: torch.FloatTensor | None = None,
  1039. attention_mask: torch.Tensor | None = None,
  1040. position_ids: torch.LongTensor | None = None,
  1041. past_key_values: Cache | None = None,
  1042. inputs_embeds: torch.FloatTensor | None = None,
  1043. vision_feature_select_strategy: str | None = None,
  1044. labels: torch.LongTensor | None = None,
  1045. use_cache: bool | None = None,
  1046. output_attentions: bool | None = None,
  1047. output_hidden_states: bool | None = None,
  1048. return_dict: bool | None = None,
  1049. logits_to_keep: int | torch.Tensor = 0,
  1050. **kwargs: Unpack[TransformersKwargs],
  1051. ) -> tuple | Llama4CausalLMOutputWithPast:
  1052. r"""
  1053. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1054. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1055. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1056. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1057. Example:
  1058. ```python
  1059. >>> from PIL import Image
  1060. >>> import httpx
  1061. >>> from io import BytesIO
  1062. >>> from transformers import AutoProcessor, LlavaForConditionalGeneration
  1063. >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
  1064. >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
  1065. >>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
  1066. >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
  1067. >>> with httpx.stream("GET", url) as response:
  1068. ... image = Image.open(BytesIO(response.read()))
  1069. >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
  1070. >>> # Generate
  1071. >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
  1072. >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1073. "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
  1074. ```"""
  1075. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1076. output_hidden_states = (
  1077. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1078. )
  1079. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1080. if (input_ids is None) ^ (inputs_embeds is not None):
  1081. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  1082. if pixel_values is not None and inputs_embeds is not None:
  1083. raise ValueError(
  1084. "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
  1085. )
  1086. if inputs_embeds is None:
  1087. inputs_embeds = self.get_input_embeddings()(input_ids)
  1088. if pixel_values is not None:
  1089. image_features = self.get_image_features(
  1090. pixel_values=pixel_values,
  1091. vision_feature_select_strategy=vision_feature_select_strategy,
  1092. return_dict=True,
  1093. ).last_hidden_state
  1094. vision_flat = image_features.view(-1, image_features.size(-1))
  1095. projected_vision_flat = self.multi_modal_projector(vision_flat).to(
  1096. inputs_embeds.device, inputs_embeds.dtype
  1097. )
  1098. special_image_mask = self.get_placeholder_mask(
  1099. input_ids, inputs_embeds=inputs_embeds, image_features=projected_vision_flat
  1100. )
  1101. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, projected_vision_flat)
  1102. outputs = self.language_model(
  1103. attention_mask=attention_mask,
  1104. position_ids=position_ids,
  1105. past_key_values=past_key_values,
  1106. inputs_embeds=inputs_embeds,
  1107. use_cache=use_cache,
  1108. output_attentions=output_attentions,
  1109. output_hidden_states=output_hidden_states,
  1110. return_dict=return_dict,
  1111. logits_to_keep=logits_to_keep,
  1112. **kwargs,
  1113. )
  1114. logits = outputs[0]
  1115. loss = None
  1116. if labels is not None:
  1117. # Shift so that tokens < n predict n
  1118. if attention_mask is not None:
  1119. # we use the input attention mask to shift the logits and labels, because it is 2D.
  1120. # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
  1121. shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
  1122. shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
  1123. shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
  1124. else:
  1125. shift_logits = logits[..., :-1, :].contiguous()
  1126. shift_labels = labels[..., 1:].contiguous()
  1127. # Flatten the tokens
  1128. loss_fct = nn.CrossEntropyLoss()
  1129. loss = loss_fct(
  1130. shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
  1131. )
  1132. if not return_dict:
  1133. output = (logits,) + outputs[1:]
  1134. return (loss,) + output if loss is not None else output
  1135. return Llama4CausalLMOutputWithPast(
  1136. loss=loss,
  1137. logits=logits,
  1138. past_key_values=outputs.past_key_values,
  1139. hidden_states=outputs.hidden_states,
  1140. attentions=outputs.attentions,
  1141. image_hidden_states=image_features if pixel_values is not None else None,
  1142. )
  1143. def prepare_inputs_for_generation(
  1144. self,
  1145. input_ids,
  1146. past_key_values=None,
  1147. inputs_embeds=None,
  1148. pixel_values=None,
  1149. attention_mask=None,
  1150. logits_to_keep=None,
  1151. is_first_iteration=False,
  1152. **kwargs,
  1153. ):
  1154. # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
  1155. model_inputs = self.language_model.prepare_inputs_for_generation(
  1156. input_ids,
  1157. past_key_values=past_key_values,
  1158. inputs_embeds=inputs_embeds,
  1159. attention_mask=attention_mask,
  1160. logits_to_keep=logits_to_keep,
  1161. is_first_iteration=is_first_iteration,
  1162. **kwargs,
  1163. )
  1164. if is_first_iteration or not kwargs.get("use_cache", True):
  1165. # Pixel values are used only in the first iteration if available
  1166. # In subsequent iterations, they are already merged with text and cached
  1167. # NOTE: first iteration doesn't have to be prefill, it can be the first
  1168. # iteration with a question and cached system prompt (continue generate from cache)
  1169. model_inputs["pixel_values"] = pixel_values
  1170. return model_inputs
  1171. __all__ = [
  1172. "Llama4PreTrainedModel",
  1173. "Llama4TextModel",
  1174. "Llama4VisionModel",
  1175. "Llama4ForCausalLM",
  1176. "Llama4ForConditionalGeneration",
  1177. ]