modeling_aria.py 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/aria/modular_aria.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_aria.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from dataclasses import dataclass
  22. from typing import Optional
  23. import torch
  24. from torch import nn
  25. from ... import initialization as init
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache, DynamicCache
  28. from ...generation import GenerationMixin
  29. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
  30. from ...masking_utils import create_causal_mask
  31. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  32. from ...modeling_layers import GradientCheckpointingLayer
  33. from ...modeling_outputs import (
  34. BaseModelOutputWithPast,
  35. BaseModelOutputWithPooling,
  36. CausalLMOutputWithPast,
  37. ModelOutput,
  38. )
  39. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  40. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  41. from ...processing_utils import Unpack
  42. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check
  43. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  44. from ...utils.output_capturing import capture_outputs
  45. from ..auto import AutoModel
  46. from .configuration_aria import AriaConfig, AriaTextConfig
  47. @use_kernel_forward_from_hub("RMSNorm")
  48. class AriaTextRMSNorm(nn.Module):
  49. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  50. """
  51. AriaTextRMSNorm is equivalent to T5LayerNorm
  52. """
  53. super().__init__()
  54. self.weight = nn.Parameter(torch.ones(hidden_size))
  55. self.variance_epsilon = eps
  56. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  57. input_dtype = hidden_states.dtype
  58. hidden_states = hidden_states.to(torch.float32)
  59. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  60. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  61. return self.weight * hidden_states.to(input_dtype)
  62. def extra_repr(self):
  63. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  64. class AriaProjectorMLP(nn.Module):
  65. """
  66. Feed-Forward Network module for the Aria Projector.
  67. Args:
  68. in_features (`int`):
  69. Input embedding dimension.
  70. hidden_features (`int`):
  71. Hidden dimension of the feed-forward network.
  72. output_dim (`int`):
  73. Output dimension.
  74. """
  75. def __init__(self, in_features, hidden_features, output_dim):
  76. super().__init__()
  77. self.linear_in = nn.Linear(in_features, hidden_features, bias=False)
  78. self.linear_out = nn.Linear(hidden_features, output_dim, bias=False)
  79. self.act = ACT2FN["gelu_new"]
  80. def forward(self, hidden_states):
  81. hidden_states = self.act(self.linear_in(hidden_states))
  82. hidden_states = self.linear_out(hidden_states)
  83. return hidden_states
  84. class AriaCrossAttention(nn.Module):
  85. """
  86. Aria Cross-Attention module.
  87. Args:
  88. config (`AriaConfig`):
  89. The configuration to use.
  90. """
  91. def __init__(self, config: AriaConfig, dropout_rate: float = 0):
  92. super().__init__()
  93. hidden_size = config.vision_config.hidden_size
  94. num_heads = config.vision_config.num_attention_heads
  95. self.num_heads = num_heads
  96. self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
  97. self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
  98. self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
  99. # Original code here: https://github.com/rhymes-ai/Aria/blob/719ff4e52b727443cba3793b0e27fe64e0244fe1/aria/model/projector.py#L48
  100. self.multihead_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
  101. self.linear = nn.Linear(hidden_size, hidden_size)
  102. self.dropout = nn.Dropout(dropout_rate)
  103. self.layer_norm = nn.LayerNorm(hidden_size)
  104. self.layer_norm_kv = nn.LayerNorm(hidden_size)
  105. def forward(self, key_value_states, hidden_states, attn_mask=None):
  106. """
  107. Forward pass of the AriaCrossAttention module.
  108. Args:
  109. key_value_states (`torch.Tensor`):
  110. Input tensor for key and value.
  111. hidden_states (`torch.Tensor`):
  112. Input tensor for query.
  113. attn_mask (`torch.Tensor`, *optional*, defaults to None):
  114. Attention mask.
  115. Returns:
  116. torch.Tensor:
  117. Output tensor after cross-attention.
  118. """
  119. query = self.q_proj(self.layer_norm(hidden_states))
  120. key_value_states = self.layer_norm_kv(key_value_states)
  121. key = self.k_proj(key_value_states)
  122. value = self.v_proj(key_value_states)
  123. attn_output, _ = self.multihead_attn(query, key, value, attn_mask=attn_mask)
  124. attn_output = self.dropout(self.linear(attn_output))
  125. return attn_output
  126. class AriaProjector(nn.Module):
  127. """
  128. Aria Projector module.
  129. This module projects vision features into the language model's embedding space, enabling interaction between vision and language components.
  130. Args:
  131. config (`AriaConfig`):
  132. Configuration object for the model.
  133. """
  134. def __init__(
  135. self,
  136. config: AriaConfig,
  137. ):
  138. super().__init__()
  139. self.patch_to_query_dict = config.projector_patch_to_query_dict
  140. self.in_features = config.vision_config.hidden_size
  141. self.num_heads = config.vision_config.num_attention_heads
  142. self.kv_dim = config.vision_config.hidden_size
  143. self.hidden_features = config.text_config.hidden_size
  144. self.output_dim = config.text_config.hidden_size
  145. self.query = nn.Parameter(torch.zeros(config.max_value_projector_patch_to_query_dict, self.in_features))
  146. self.cross_attn = AriaCrossAttention(config)
  147. self.layer_norm = nn.LayerNorm(self.in_features)
  148. self.feed_forward = AriaProjectorMLP(self.in_features, self.hidden_features, self.output_dim)
  149. def forward(self, key_value_states: torch.Tensor, attn_mask: torch.Tensor | None = None):
  150. """
  151. Forward pass of the Projector module.
  152. Args:
  153. key_value_states (`torch.Tensor`):
  154. Input tensor of shape (batch_size, num_patches, kv_dim).
  155. attn_mask (`torch.Tensor`, *optional*, default is None):
  156. Attention mask.
  157. Returns:
  158. `torch.Tensor`: Output tensor of shape (batch_size, query_number, output_dim).
  159. """
  160. batch_size, num_patches = key_value_states.shape[0], key_value_states.shape[1]
  161. if num_patches not in self.patch_to_query_dict:
  162. raise KeyError(
  163. f"Number of patches {num_patches} not found in patch_to_query_dict amongst possible values {self.patch_to_query_dict.keys()}."
  164. )
  165. query_num = self.patch_to_query_dict[num_patches]
  166. queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1)
  167. if attn_mask is not None:
  168. attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
  169. attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1)
  170. attention_out = self.cross_attn(key_value_states, queries, attn_mask=attn_mask)
  171. out = self.feed_forward(self.layer_norm(attention_out))
  172. return out
  173. class AriaSharedExpertsMLP(nn.Module):
  174. """
  175. Shared Expert MLP for shared experts.
  176. Unlike routed experts, shared experts process all tokens without routing.
  177. This class reconfigures the intermediate size in comparison to the LlamaMLP.
  178. Args:
  179. config (`AriaTextConfig`): Configuration object for the Aria language model.
  180. """
  181. def __init__(self, config: AriaTextConfig):
  182. super().__init__()
  183. self.config = config
  184. self.hidden_size = config.hidden_size
  185. self.intermediate_size = config.intermediate_size * config.moe_num_shared_experts
  186. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  187. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  188. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  189. self.act_fn = ACT2FN[config.hidden_act]
  190. def forward(self, x):
  191. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  192. return down_proj
  193. def sequential_experts_gemm(token_states, expert_weights, tokens_per_expert):
  194. """
  195. Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts.
  196. Args:
  197. token_states (torch.Tensor): Input tensor of shape (num_tokens, in_features).
  198. expert_weights (torch.Tensor): Weight tensor of shape (num_experts, in_features, out_features).
  199. tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
  200. Returns:
  201. torch.Tensor: Output tensor of shape (num_tokens, out_features).
  202. """
  203. num_tokens = token_states.shape[0]
  204. out_features = expert_weights.shape[-1]
  205. output = torch.zeros(num_tokens, out_features, dtype=token_states.dtype, device=token_states.device)
  206. cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
  207. # Insert zero at the beginning for offset index's convenience
  208. zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
  209. cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
  210. for expert_num in range(expert_weights.shape[0]):
  211. start = cumsum_num_tokens[expert_num]
  212. end = cumsum_num_tokens[expert_num + 1]
  213. tokens = token_states[start:end]
  214. out = torch.matmul(tokens, expert_weights[expert_num])
  215. output[start:end] = out
  216. return output
  217. class AriaGroupedExpertsGemm(nn.Module):
  218. """
  219. Grouped GEMM (General Matrix Multiplication) module for efficient expert computation.
  220. This module utilizes the grouped_gemm library (https://github.com/fanshiqing/grouped_gemm)
  221. for optimized performance. If the grouped_gemm library is not installed, it gracefully
  222. falls back to a sequential GEMM implementation, which may be slower but ensures
  223. functionality.
  224. Args:
  225. in_features (`int`):
  226. Number of input features.
  227. out_features (`int`):
  228. Number of output features.
  229. groups (`int`):
  230. Number of expert groups.
  231. """
  232. def __init__(self, in_features, out_features, groups):
  233. super().__init__()
  234. self.in_features = in_features
  235. self.out_features = out_features
  236. self.groups = groups
  237. self.weight = nn.Parameter(torch.empty(groups, in_features, out_features))
  238. def forward(self, input, tokens_per_expert):
  239. """
  240. Perform grouped matrix multiplication.
  241. Args:
  242. input (`torch.Tensor`):
  243. Input tensor of shape (num_tokens, in_features).
  244. tokens_per_expert (`torch.Tensor`):
  245. Number of tokens assigned to each expert.
  246. Returns:
  247. torch.Tensor: Output tensor of shape (num_tokens, out_features).
  248. """
  249. return sequential_experts_gemm(
  250. input,
  251. self.weight,
  252. tokens_per_expert.cpu(),
  253. )
  254. class AriaExperts(nn.Module):
  255. def __init__(self, config: AriaTextConfig) -> None:
  256. super().__init__()
  257. self.config = config
  258. self.fc1 = AriaGroupedExpertsGemm(config.hidden_size, config.intermediate_size * 2, config.moe_num_experts)
  259. self.fc2 = AriaGroupedExpertsGemm(config.intermediate_size, config.hidden_size, config.moe_num_experts)
  260. def route_tokens_to_experts(self, router_logits):
  261. top_logits, top_indices = torch.topk(router_logits, k=self.config.moe_topk, dim=1)
  262. scores = nn.functional.softmax(top_logits, dim=-1)
  263. return top_indices, scores
  264. def forward(self, hidden_states, router_logits) -> torch.Tensor:
  265. top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
  266. original_dtype = top_k_index.dtype
  267. tokens_per_expert = torch.histc(
  268. top_k_index.flatten().to(torch.float32),
  269. bins=self.config.moe_num_experts,
  270. min=0,
  271. max=self.config.moe_num_experts - 1,
  272. ).to(original_dtype)
  273. indices = top_k_index
  274. flatten_indices = indices.view(-1)
  275. sorted_indices = torch.argsort(flatten_indices)
  276. permuted_tokens = hidden_states.index_select(0, sorted_indices // self.config.moe_topk)
  277. fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
  278. projection, gate = torch.chunk(fc1_output, 2, dim=-1)
  279. fc1_output = nn.functional.silu(projection) * gate
  280. expert_output = self.fc2(fc1_output, tokens_per_expert)
  281. unpermuted_tokens = torch.zeros(
  282. (top_k_weights.shape[0] * self.config.moe_topk, expert_output.size(1)),
  283. dtype=expert_output.dtype,
  284. device=expert_output.device,
  285. )
  286. unpermuted_tokens.index_copy_(0, sorted_indices, expert_output)
  287. unpermuted_tokens = unpermuted_tokens.view(-1, self.config.moe_topk, expert_output.size(1))
  288. output = (unpermuted_tokens * top_k_weights.unsqueeze(-1)).sum(dim=1)
  289. return output
  290. class AriaTextMoELayer(nn.Module):
  291. def __init__(self, config: AriaTextConfig):
  292. super().__init__()
  293. self.router = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False)
  294. self.experts = AriaExperts(config)
  295. self.shared_experts = AriaSharedExpertsMLP(config)
  296. self.config = config
  297. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  298. original_shape = hidden_states.shape
  299. hidden_states = hidden_states.view(-1, hidden_states.size(-1))
  300. router_logits = self.router(hidden_states)
  301. expert_output = self.experts(hidden_states, router_logits).view(original_shape)
  302. shared_expert_output = self.shared_experts(hidden_states.view(original_shape))
  303. return expert_output + shared_expert_output
  304. def rotate_half(x):
  305. """Rotates half the hidden dims of the input."""
  306. x1 = x[..., : x.shape[-1] // 2]
  307. x2 = x[..., x.shape[-1] // 2 :]
  308. return torch.cat((-x2, x1), dim=-1)
  309. @use_kernel_func_from_hub("rotary_pos_emb")
  310. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  311. """Applies Rotary Position Embedding to the query and key tensors.
  312. Args:
  313. q (`torch.Tensor`): The query tensor.
  314. k (`torch.Tensor`): The key tensor.
  315. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  316. sin (`torch.Tensor`): The sine part of the rotary embedding.
  317. unsqueeze_dim (`int`, *optional*, defaults to 1):
  318. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  319. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  320. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  321. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  322. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  323. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  324. Returns:
  325. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  326. """
  327. cos = cos.unsqueeze(unsqueeze_dim)
  328. sin = sin.unsqueeze(unsqueeze_dim)
  329. q_embed = (q * cos) + (rotate_half(q) * sin)
  330. k_embed = (k * cos) + (rotate_half(k) * sin)
  331. return q_embed, k_embed
  332. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  333. """
  334. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  335. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  336. """
  337. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  338. if n_rep == 1:
  339. return hidden_states
  340. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  341. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  342. def eager_attention_forward(
  343. module: nn.Module,
  344. query: torch.Tensor,
  345. key: torch.Tensor,
  346. value: torch.Tensor,
  347. attention_mask: torch.Tensor | None,
  348. scaling: float,
  349. dropout: float = 0.0,
  350. **kwargs: Unpack[TransformersKwargs],
  351. ):
  352. key_states = repeat_kv(key, module.num_key_value_groups)
  353. value_states = repeat_kv(value, module.num_key_value_groups)
  354. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  355. if attention_mask is not None:
  356. attn_weights = attn_weights + attention_mask
  357. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  358. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  359. attn_output = torch.matmul(attn_weights, value_states)
  360. attn_output = attn_output.transpose(1, 2).contiguous()
  361. return attn_output, attn_weights
  362. @use_kernelized_func(apply_rotary_pos_emb)
  363. class AriaTextAttention(nn.Module):
  364. """Multi-headed attention from 'Attention Is All You Need' paper"""
  365. def __init__(self, config: AriaTextConfig, layer_idx: int):
  366. super().__init__()
  367. self.config = config
  368. self.layer_idx = layer_idx
  369. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  370. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  371. self.scaling = self.head_dim**-0.5
  372. self.attention_dropout = config.attention_dropout
  373. self.is_causal = True
  374. self.q_proj = nn.Linear(
  375. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  376. )
  377. self.k_proj = nn.Linear(
  378. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  379. )
  380. self.v_proj = nn.Linear(
  381. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  382. )
  383. self.o_proj = nn.Linear(
  384. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  385. )
  386. def forward(
  387. self,
  388. hidden_states: torch.Tensor,
  389. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  390. attention_mask: torch.Tensor | None = None,
  391. past_key_values: Cache | None = None,
  392. **kwargs: Unpack[TransformersKwargs],
  393. ) -> tuple[torch.Tensor, torch.Tensor]:
  394. input_shape = hidden_states.shape[:-1]
  395. hidden_shape = (*input_shape, -1, self.head_dim)
  396. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  397. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  398. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  399. cos, sin = position_embeddings
  400. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  401. if past_key_values is not None:
  402. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  403. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  404. self.config._attn_implementation, eager_attention_forward
  405. )
  406. attn_output, attn_weights = attention_interface(
  407. self,
  408. query_states,
  409. key_states,
  410. value_states,
  411. attention_mask,
  412. dropout=0.0 if not self.training else self.attention_dropout,
  413. scaling=self.scaling,
  414. **kwargs,
  415. )
  416. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  417. attn_output = self.o_proj(attn_output)
  418. return attn_output, attn_weights
  419. class AriaTextDecoderLayer(GradientCheckpointingLayer):
  420. """
  421. Aria Text Decoder Layer.
  422. This class defines a single decoder layer in the language model, incorporating self-attention and Mixture of Experts (MoE) feed-forward network.
  423. Args:
  424. config (`AriaTextConfig`):
  425. Configuration object for the text component of the model.
  426. layer_idx (`int`):
  427. Index of the layer.
  428. """
  429. def __init__(self, config: AriaTextConfig, layer_idx: int):
  430. super().__init__()
  431. self.hidden_size = config.hidden_size
  432. self.self_attn = AriaTextAttention(config=config, layer_idx=layer_idx)
  433. self.mlp = AriaTextMoELayer(config)
  434. self.input_layernorm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  435. self.post_attention_layernorm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  436. def forward(
  437. self,
  438. hidden_states: torch.Tensor,
  439. attention_mask: torch.Tensor | None = None,
  440. position_ids: torch.LongTensor | None = None,
  441. past_key_values: Cache | None = None,
  442. use_cache: bool | None = False,
  443. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  444. **kwargs: Unpack[TransformersKwargs],
  445. ) -> torch.Tensor:
  446. residual = hidden_states
  447. hidden_states = self.input_layernorm(hidden_states)
  448. # Self Attention
  449. hidden_states, _ = self.self_attn(
  450. hidden_states=hidden_states,
  451. attention_mask=attention_mask,
  452. position_ids=position_ids,
  453. past_key_values=past_key_values,
  454. use_cache=use_cache,
  455. position_embeddings=position_embeddings,
  456. **kwargs,
  457. )
  458. hidden_states = residual + hidden_states
  459. # Fully Connected
  460. residual = hidden_states
  461. hidden_states = self.post_attention_layernorm(hidden_states)
  462. hidden_states = self.mlp(hidden_states)
  463. hidden_states = residual + hidden_states
  464. return hidden_states
  465. @auto_docstring
  466. class AriaTextPreTrainedModel(PreTrainedModel):
  467. config: AriaTextConfig
  468. base_model_prefix = "model"
  469. input_modalities = ("image", "text")
  470. _no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
  471. supports_gradient_checkpointing = True
  472. _skip_keys_device_placement = "past_key_values"
  473. _supports_flash_attn = True
  474. _supports_sdpa = True
  475. _supports_attention_backend = True
  476. _can_record_outputs = {
  477. "hidden_states": AriaTextDecoderLayer,
  478. "attentions": AriaTextAttention,
  479. }
  480. @torch.no_grad()
  481. def _init_weights(self, module):
  482. super()._init_weights(module)
  483. if isinstance(module, AriaGroupedExpertsGemm):
  484. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  485. @auto_docstring
  486. class AriaPreTrainedModel(PreTrainedModel):
  487. config: AriaConfig
  488. base_model_prefix = "model"
  489. supports_gradient_checkpointing = True
  490. _no_split_modules = ["AriaDecoderLayer"]
  491. _skip_keys_device_placement = ["past_key_values"]
  492. _supports_flash_attn = True
  493. _supports_sdpa = True
  494. _supports_flex_attn = True
  495. _can_compile_fullgraph = False # MoE models don't work with torch.compile (dynamic slicing)
  496. _supports_attention_backend = True
  497. _can_record_outputs = {
  498. "hidden_states": AriaTextDecoderLayer,
  499. "attentions": AriaTextAttention,
  500. }
  501. @torch.no_grad()
  502. def _init_weights(self, module):
  503. super()._init_weights(module)
  504. if isinstance(module, AriaProjector):
  505. init.trunc_normal_(module.query, std=self.config.initializer_range)
  506. class AriaTextRotaryEmbedding(nn.Module):
  507. inv_freq: torch.Tensor # fix linting for `register_buffer`
  508. def __init__(self, config: AriaTextConfig, device=None):
  509. super().__init__()
  510. self.max_seq_len_cached = config.max_position_embeddings
  511. self.original_max_seq_len = config.max_position_embeddings
  512. self.config = config
  513. self.rope_type = self.config.rope_parameters["rope_type"]
  514. rope_init_fn: Callable = self.compute_default_rope_parameters
  515. if self.rope_type != "default":
  516. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  517. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  518. self.register_buffer("inv_freq", inv_freq, persistent=False)
  519. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  520. @staticmethod
  521. def compute_default_rope_parameters(
  522. config: AriaTextConfig | None = None,
  523. device: Optional["torch.device"] = None,
  524. seq_len: int | None = None,
  525. ) -> tuple["torch.Tensor", float]:
  526. """
  527. Computes the inverse frequencies according to the original RoPE implementation
  528. Args:
  529. config ([`~transformers.PreTrainedConfig`]):
  530. The model configuration.
  531. device (`torch.device`):
  532. The device to use for initialization of the inverse frequencies.
  533. seq_len (`int`, *optional*):
  534. The current sequence length. Unused for this type of RoPE.
  535. Returns:
  536. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  537. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  538. """
  539. base = config.rope_parameters["rope_theta"]
  540. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  541. attention_factor = 1.0 # Unused in this type of RoPE
  542. # Compute the inverse frequencies
  543. inv_freq = 1.0 / (
  544. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  545. )
  546. return inv_freq, attention_factor
  547. @torch.no_grad()
  548. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  549. def forward(self, x, position_ids):
  550. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  551. position_ids_expanded = position_ids[:, None, :].float()
  552. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  553. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  554. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  555. emb = torch.cat((freqs, freqs), dim=-1)
  556. cos = emb.cos() * self.attention_scaling
  557. sin = emb.sin() * self.attention_scaling
  558. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  559. @auto_docstring
  560. class AriaTextModel(AriaTextPreTrainedModel):
  561. def __init__(self, config: AriaTextConfig):
  562. super().__init__(config)
  563. self.padding_idx = config.pad_token_id
  564. self.vocab_size = config.vocab_size
  565. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  566. self.layers = nn.ModuleList(
  567. [AriaTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  568. )
  569. self.norm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  570. self.rotary_emb = AriaTextRotaryEmbedding(config=config)
  571. self.gradient_checkpointing = False
  572. # Initialize weights and apply final processing
  573. self.post_init()
  574. @merge_with_config_defaults
  575. @capture_outputs
  576. @auto_docstring
  577. def forward(
  578. self,
  579. input_ids: torch.LongTensor | None = None,
  580. attention_mask: torch.Tensor | None = None,
  581. position_ids: torch.LongTensor | None = None,
  582. past_key_values: Cache | None = None,
  583. inputs_embeds: torch.FloatTensor | None = None,
  584. use_cache: bool | None = None,
  585. **kwargs: Unpack[TransformersKwargs],
  586. ) -> BaseModelOutputWithPast:
  587. if (input_ids is None) ^ (inputs_embeds is not None):
  588. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  589. if inputs_embeds is None:
  590. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  591. if use_cache and past_key_values is None:
  592. past_key_values = DynamicCache(config=self.config)
  593. if position_ids is None:
  594. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  595. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  596. position_ids = position_ids.unsqueeze(0)
  597. causal_mask = create_causal_mask(
  598. config=self.config,
  599. inputs_embeds=inputs_embeds,
  600. attention_mask=attention_mask,
  601. past_key_values=past_key_values,
  602. position_ids=position_ids,
  603. )
  604. hidden_states = inputs_embeds
  605. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  606. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  607. hidden_states = decoder_layer(
  608. hidden_states,
  609. attention_mask=causal_mask,
  610. position_embeddings=position_embeddings,
  611. position_ids=position_ids,
  612. past_key_values=past_key_values,
  613. use_cache=use_cache,
  614. **kwargs,
  615. )
  616. hidden_states = self.norm(hidden_states)
  617. return BaseModelOutputWithPast(
  618. last_hidden_state=hidden_states,
  619. past_key_values=past_key_values,
  620. )
  621. @auto_docstring
  622. class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
  623. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  624. _tp_plan = {"lm_head": "colwise_gather_output"}
  625. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  626. def __init__(self, config: AriaTextConfig):
  627. super().__init__(config)
  628. self.model = AriaTextModel(config)
  629. self.vocab_size = config.vocab_size
  630. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  631. # Initialize weights and apply final processing
  632. self.post_init()
  633. @auto_docstring
  634. def forward(
  635. self,
  636. input_ids: torch.LongTensor | None = None,
  637. attention_mask: torch.Tensor | None = None,
  638. position_ids: torch.LongTensor | None = None,
  639. past_key_values: Cache | None = None,
  640. inputs_embeds: torch.FloatTensor | None = None,
  641. labels: torch.LongTensor | None = None,
  642. use_cache: bool | None = None,
  643. logits_to_keep: int | torch.Tensor = 0,
  644. **kwargs: Unpack[TransformersKwargs],
  645. ) -> CausalLMOutputWithPast:
  646. r"""
  647. Example:
  648. ```python
  649. >>> from transformers import AutoTokenizer, AriaTextForCausalLM
  650. >>> model = AriaTextForCausalLM.from_pretrained("meta-aria_text/AriaText-2-7b-hf")
  651. >>> tokenizer = AutoTokenizer.from_pretrained("meta-aria_text/AriaText-2-7b-hf")
  652. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  653. >>> inputs = tokenizer(prompt, return_tensors="pt")
  654. >>> # Generate
  655. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  656. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  657. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  658. ```"""
  659. outputs: BaseModelOutputWithPast = self.model(
  660. input_ids=input_ids,
  661. attention_mask=attention_mask,
  662. position_ids=position_ids,
  663. past_key_values=past_key_values,
  664. inputs_embeds=inputs_embeds,
  665. use_cache=use_cache,
  666. **kwargs,
  667. )
  668. hidden_states = outputs.last_hidden_state
  669. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  670. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  671. logits = self.lm_head(hidden_states[:, slice_indices, :])
  672. loss = None
  673. if labels is not None:
  674. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  675. return CausalLMOutputWithPast(
  676. loss=loss,
  677. logits=logits,
  678. past_key_values=outputs.past_key_values,
  679. hidden_states=outputs.hidden_states,
  680. attentions=outputs.attentions,
  681. )
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Aria causal language model (or autoregressive) outputs.
    """
)
class AriaCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    # Declaration order is the dataclass field order; keep it stable because model outputs
    # may be indexed positionally (this file does `outputs[0]` on a model output).
    # `hidden_states` / `attentions` are not described in the docstring above —
    # presumably @auto_docstring supplies their standard descriptions; confirm.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    # NOTE(review): the docstring gives a 4D shape, but AriaModel.get_placeholder_mask counts
    # features as `shape[0] * shape[1]`, which suggests a 3D tensor — confirm the documented shape.
    image_hidden_states: torch.FloatTensor | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Aria outputs, with hidden states and attentions.
    """
)
class AriaModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    # Only field added on top of BaseModelOutputWithPast; it is declared after the inherited
    # fields, so it comes last in the dataclass field order.
    # NOTE(review): AriaModel.get_placeholder_mask counts image features as `shape[0] * shape[1]`,
    # which suggests a 3D tensor rather than the 4D shape stated above — confirm.
    image_hidden_states: torch.FloatTensor | None = None
  725. @auto_docstring(
  726. custom_intro="""
  727. The Aria model which consists of a vision backbone and a language model, without a language modeling head.
  728. """
  729. )
  730. class AriaModel(AriaPreTrainedModel):
  731. def __init__(self, config: AriaConfig):
  732. super().__init__(config)
  733. self.vision_tower = AutoModel.from_config(config.vision_config)
  734. self.multi_modal_projector = AriaProjector(config)
  735. self.language_model = AutoModel.from_config(config.text_config)
  736. self.post_init()
  737. def get_input_embeddings(self):
  738. return self.language_model.get_input_embeddings()
  739. def set_input_embeddings(self, value):
  740. self.language_model.set_input_embeddings(value)
  741. @merge_with_config_defaults
  742. @can_return_tuple
  743. @auto_docstring(
  744. custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection."
  745. )
  746. def get_image_features(
  747. self,
  748. pixel_values: torch.FloatTensor,
  749. pixel_mask: torch.FloatTensor | None = None,
  750. vision_feature_layer: int | list[int] = -1,
  751. output_hidden_states: bool | None = None,
  752. **kwargs: Unpack[TransformersKwargs],
  753. ) -> tuple | BaseModelOutputWithPooling:
  754. patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
  755. image_outputs = self.vision_tower(
  756. pixel_values,
  757. patch_attention_mask=patch_attention_mask,
  758. output_hidden_states=True, # Ignore arg on purpose
  759. return_dict=True,
  760. **kwargs,
  761. )
  762. image_attn_mask = None
  763. if patch_attention_mask is not None:
  764. flattened_mask = patch_attention_mask.flatten(1)
  765. image_attn_mask = torch.logical_not(flattened_mask)
  766. selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
  767. image_outputs.pooler_output = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
  768. return image_outputs
  769. def get_placeholder_mask(
  770. self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
  771. ):
  772. """
  773. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  774. equal to the length of multimodal features. If the lengths are different, an error is raised.
  775. """
  776. if input_ids is None:
  777. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  778. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  779. )
  780. special_image_mask = special_image_mask.all(-1)
  781. else:
  782. special_image_mask = input_ids == self.config.image_token_id
  783. n_image_tokens = special_image_mask.sum()
  784. n_image_features = image_features.shape[0] * image_features.shape[1]
  785. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  786. torch_compilable_check(
  787. inputs_embeds[special_image_mask].numel() == image_features.numel(),
  788. f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
  789. )
  790. return special_image_mask
  791. @can_return_tuple
  792. @auto_docstring
  793. def forward(
  794. self,
  795. input_ids: torch.LongTensor | None = None,
  796. pixel_values: torch.FloatTensor | None = None,
  797. pixel_mask: torch.LongTensor | None = None,
  798. attention_mask: torch.Tensor | None = None,
  799. position_ids: torch.LongTensor | None = None,
  800. past_key_values: Cache | None = None,
  801. inputs_embeds: torch.FloatTensor | None = None,
  802. use_cache: bool | None = None,
  803. **kwargs: Unpack[FlashAttentionKwargs],
  804. ) -> tuple | AriaModelOutputWithPast:
  805. if inputs_embeds is None:
  806. inputs_embeds = self.get_input_embeddings()(input_ids)
  807. # 2. Merge text and images
  808. if pixel_values is not None and inputs_embeds.shape[1] != 1:
  809. image_features = self.get_image_features(
  810. pixel_values=pixel_values,
  811. pixel_mask=pixel_mask,
  812. vision_feature_layer=self.config.vision_feature_layer,
  813. return_dict=True,
  814. ).pooler_output
  815. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  816. special_image_mask = self.get_placeholder_mask(
  817. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  818. )
  819. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  820. outputs = self.language_model(
  821. attention_mask=attention_mask,
  822. position_ids=position_ids,
  823. past_key_values=past_key_values,
  824. inputs_embeds=inputs_embeds,
  825. use_cache=use_cache,
  826. **kwargs,
  827. )
  828. return AriaModelOutputWithPast(
  829. last_hidden_state=outputs.last_hidden_state,
  830. past_key_values=outputs.past_key_values if use_cache else None,
  831. hidden_states=outputs.hidden_states,
  832. attentions=outputs.attentions,
  833. image_hidden_states=image_features if pixel_values is not None else None,
  834. )
  835. def _create_patch_attention_mask(self, pixel_mask):
  836. if pixel_mask is None:
  837. return None
  838. patches_subgrid = pixel_mask.unfold(
  839. dimension=1,
  840. size=self.vision_tower.config.patch_size,
  841. step=self.vision_tower.config.patch_size,
  842. )
  843. patches_subgrid = patches_subgrid.unfold(
  844. dimension=2,
  845. size=self.vision_tower.config.patch_size,
  846. step=self.vision_tower.config.patch_size,
  847. )
  848. return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
  849. @auto_docstring(
  850. custom_intro="""
  851. Aria model for conditional generation tasks.
  852. This model combines a vision tower, a multi-modal projector, and a language model
  853. to perform tasks that involve both image and text inputs.
  854. """
  855. )
  856. class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
  857. _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
  858. def __init__(self, config: AriaConfig):
  859. super().__init__(config)
  860. self.model = AriaModel(config)
  861. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  862. self.post_init()
  863. def get_input_embeddings(self):
  864. return self.model.get_input_embeddings()
  865. def set_input_embeddings(self, value):
  866. self.model.set_input_embeddings(value)
  867. def get_output_embeddings(self) -> nn.Module:
  868. return self.lm_head
  869. @auto_docstring
  870. def get_image_features(
  871. self,
  872. pixel_values: torch.FloatTensor,
  873. pixel_mask: torch.FloatTensor | None = None,
  874. vision_feature_layer: int | list[int] = -1,
  875. **kwargs: Unpack[TransformersKwargs],
  876. ) -> tuple | BaseModelOutputWithPooling:
  877. return self.model.get_image_features(
  878. pixel_values=pixel_values,
  879. pixel_mask=pixel_mask,
  880. vision_feature_layer=vision_feature_layer,
  881. **kwargs,
  882. )
  883. @can_return_tuple
  884. @auto_docstring
  885. def forward(
  886. self,
  887. input_ids: torch.LongTensor | None = None,
  888. pixel_values: torch.FloatTensor | None = None,
  889. pixel_mask: torch.LongTensor | None = None,
  890. attention_mask: torch.Tensor | None = None,
  891. position_ids: torch.LongTensor | None = None,
  892. past_key_values: Cache | None = None,
  893. inputs_embeds: torch.FloatTensor | None = None,
  894. labels: torch.LongTensor | None = None,
  895. use_cache: bool | None = None,
  896. logits_to_keep: int | torch.Tensor = 0,
  897. **kwargs: Unpack[TransformersKwargs],
  898. ) -> tuple | AriaCausalLMOutputWithPast:
  899. r"""
  900. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  901. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  902. config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `AriaForConditionalGeneration`).
  903. Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
  904. computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  905. Example:
  906. ```python
  907. >>> import httpx
  908. >>> from io import BytesIO
  909. >>> import torch
  910. >>> from PIL import Image
  911. >>> from io import BytesIO
  912. >>> from transformers import AutoProcessor, AutoModel
  913. >>> from transformers.image_utils import load_image
  914. >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
  915. >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
  916. >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
  917. >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
  918. >>> processor = AutoProcessor.from_pretrained("Rhymes-AI/Aria")
  919. >>> model = AutoModel.from_pretrained("Rhymes-AI/Aria", dtype=torch.bfloat16, device_map="auto")
  920. >>> # Create inputs
  921. >>> messages = [
  922. ... {
  923. ... "role": "user",
  924. ... "content": [
  925. ... {"type": "image"},
  926. ... {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."},
  927. ... {"type": "image"},
  928. ... {"type": "text", "text": "What can we see in this image?"},
  929. ... ]
  930. ... },
  931. ... {
  932. ... "role": "user",
  933. ... "content": [
  934. ... {"type": "image"},
  935. ... {"type": "text", "text": "In which city is that bridge located?"},
  936. ... ]
  937. ... }
  938. ... ]
  939. >>> prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages]
  940. >>> images = [[image1, image2], [image3]]
  941. >>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(model.device)
  942. >>> # Generate
  943. >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
  944. >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
  945. >>> print(generated_texts[0])
  946. Assistant: There are buildings, trees, lights, and water visible in this image.
  947. >>> print(generated_texts[1])
  948. Assistant: The bridge is in San Francisco.
  949. ```"""
  950. outputs = self.model(
  951. input_ids=input_ids,
  952. pixel_values=pixel_values,
  953. pixel_mask=pixel_mask,
  954. attention_mask=attention_mask,
  955. position_ids=position_ids,
  956. past_key_values=past_key_values,
  957. inputs_embeds=inputs_embeds,
  958. use_cache=use_cache,
  959. **kwargs,
  960. )
  961. hidden_states = outputs[0]
  962. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  963. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  964. logits = self.lm_head(hidden_states[:, slice_indices, :])
  965. loss = None
  966. if labels is not None:
  967. loss = self.loss_function(
  968. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  969. )
  970. return AriaCausalLMOutputWithPast(
  971. loss=loss,
  972. logits=logits,
  973. past_key_values=outputs.past_key_values,
  974. hidden_states=outputs.hidden_states,
  975. attentions=outputs.attentions,
  976. )
  977. def prepare_inputs_for_generation(
  978. self,
  979. input_ids,
  980. past_key_values=None,
  981. inputs_embeds=None,
  982. pixel_values=None,
  983. pixel_mask=None,
  984. attention_mask=None,
  985. logits_to_keep=None,
  986. is_first_iteration=False,
  987. **kwargs,
  988. ):
  989. model_inputs = super().prepare_inputs_for_generation(
  990. input_ids,
  991. past_key_values=past_key_values,
  992. inputs_embeds=inputs_embeds,
  993. attention_mask=attention_mask,
  994. logits_to_keep=logits_to_keep,
  995. is_first_iteration=is_first_iteration,
  996. **kwargs,
  997. )
  998. if is_first_iteration or not kwargs.get("use_cache", True):
  999. # Pixel values are used only in the first iteration if available
  1000. # In subsequent iterations, they are already merged with text and cached
  1001. # NOTE: first iteration doesn't have to be prefill, it can be the first
  1002. # iteration with a question and cached system prompt (continue generate from cache)
  1003. model_inputs["pixel_values"] = pixel_values
  1004. model_inputs["pixel_mask"] = pixel_mask
  1005. return model_inputs
# Public names of this module, i.e. what `from <module> import *` exposes.
# The output dataclasses defined in this file are not listed here.
__all__ = [
    "AriaForConditionalGeneration",
    "AriaPreTrainedModel",
    "AriaTextPreTrainedModel",
    "AriaTextModel",
    "AriaModel",
    "AriaTextForCausalLM",
]