modeling_afmoe.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/afmoe/modular_afmoe.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_afmoe.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from typing import Optional
  22. import torch
  23. from torch import nn
  24. from ... import initialization as init
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...integrations import (
  29. use_experts_implementation,
  30. use_kernel_forward_from_hub,
  31. use_kernel_func_from_hub,
  32. use_kernelized_func,
  33. )
  34. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  35. from ...modeling_layers import GradientCheckpointingLayer
  36. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  37. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
  41. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  42. from ...utils.output_capturing import OutputRecorder, capture_outputs
  43. from .configuration_afmoe import AfmoeConfig
  44. class AfmoeRotaryEmbedding(nn.Module):
  45. inv_freq: torch.Tensor # fix linting for `register_buffer`
  46. def __init__(self, config: AfmoeConfig, device=None):
  47. super().__init__()
  48. self.max_seq_len_cached = config.max_position_embeddings
  49. self.original_max_seq_len = config.max_position_embeddings
  50. self.config = config
  51. self.rope_type = self.config.rope_parameters["rope_type"]
  52. rope_init_fn: Callable = self.compute_default_rope_parameters
  53. if self.rope_type != "default":
  54. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  55. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  56. self.register_buffer("inv_freq", inv_freq, persistent=False)
  57. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  58. @staticmethod
  59. def compute_default_rope_parameters(
  60. config: AfmoeConfig | None = None,
  61. device: Optional["torch.device"] = None,
  62. seq_len: int | None = None,
  63. ) -> tuple["torch.Tensor", float]:
  64. """
  65. Computes the inverse frequencies according to the original RoPE implementation
  66. Args:
  67. config ([`~transformers.PreTrainedConfig`]):
  68. The model configuration.
  69. device (`torch.device`):
  70. The device to use for initialization of the inverse frequencies.
  71. seq_len (`int`, *optional*):
  72. The current sequence length. Unused for this type of RoPE.
  73. Returns:
  74. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  75. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  76. """
  77. base = config.rope_parameters["rope_theta"]
  78. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  79. attention_factor = 1.0 # Unused in this type of RoPE
  80. # Compute the inverse frequencies
  81. inv_freq = 1.0 / (
  82. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  83. )
  84. return inv_freq, attention_factor
  85. @torch.no_grad()
  86. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  87. def forward(self, x, position_ids):
  88. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  89. position_ids_expanded = position_ids[:, None, :].float()
  90. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  91. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  92. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  93. emb = torch.cat((freqs, freqs), dim=-1)
  94. cos = emb.cos() * self.attention_scaling
  95. sin = emb.sin() * self.attention_scaling
  96. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  97. @use_kernel_forward_from_hub("RMSNorm")
  98. class AfmoeRMSNorm(nn.Module):
  99. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  100. """
  101. AfmoeRMSNorm is equivalent to T5LayerNorm
  102. """
  103. super().__init__()
  104. self.weight = nn.Parameter(torch.ones(hidden_size))
  105. self.variance_epsilon = eps
  106. def forward(self, hidden_states) -> torch.Tensor:
  107. input_dtype = hidden_states.dtype
  108. hidden_states = hidden_states.to(torch.float32)
  109. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  110. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  111. return (self.weight * hidden_states).to(input_dtype) # main diff with Llama
  112. def extra_repr(self):
  113. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  114. class AfmoeMLP(nn.Module):
  115. def __init__(self, config, intermediate_size=None):
  116. super().__init__()
  117. self.config = config
  118. self.hidden_size = config.hidden_size
  119. self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
  120. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  121. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  122. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  123. self.act_fn = ACT2FN[config.hidden_act]
  124. def forward(self, x):
  125. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  126. return down_proj
  127. class AfmoeTokenChoiceRouter(nn.Module):
  128. """
  129. Token-choice top-K router for MoE routing.
  130. This router assigns each token to the top-K experts based on sigmoid scores, matching the released checkpoints.
  131. """
  132. def __init__(self, config):
  133. super().__init__()
  134. self.config = config
  135. self.top_k = config.num_experts_per_tok
  136. self.num_experts = config.num_experts
  137. self.route_scale = config.route_scale
  138. self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
  139. def forward(self, hidden_states: torch.Tensor, expert_bias: torch.Tensor):
  140. _, _, hidden_dim = hidden_states.shape
  141. hidden_states = hidden_states.view(-1, hidden_dim)
  142. router_logits = self.gate(hidden_states).to(torch.float32)
  143. scores = torch.sigmoid(router_logits)
  144. _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
  145. top_scores = scores.gather(dim=1, index=selected_experts)
  146. denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
  147. top_scores = top_scores / denominator
  148. top_scores = top_scores * self.route_scale
  149. return router_logits, top_scores, selected_experts
  150. @use_experts_implementation
  151. class AfmoeExperts(nn.Module):
  152. """Collection of expert weights stored as 3D tensors."""
  153. def __init__(self, config):
  154. super().__init__()
  155. self.num_experts = config.num_experts
  156. self.hidden_dim = config.hidden_size
  157. self.intermediate_dim = config.moe_intermediate_size
  158. self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
  159. self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
  160. self.act_fn = ACT2FN[config.hidden_act]
  161. def forward(
  162. self,
  163. hidden_states: torch.Tensor,
  164. top_k_index: torch.Tensor,
  165. top_k_weights: torch.Tensor,
  166. ) -> torch.Tensor:
  167. final_hidden_states = torch.zeros_like(hidden_states)
  168. with torch.no_grad():
  169. expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
  170. expert_mask = expert_mask.permute(2, 1, 0)
  171. expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
  172. for expert_idx in expert_hit:
  173. expert_idx = expert_idx[0]
  174. if expert_idx == self.num_experts:
  175. continue
  176. top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
  177. current_state = hidden_states[token_idx]
  178. gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
  179. current_hidden_states = self.act_fn(gate) * up
  180. current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
  181. current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
  182. final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
  183. return final_hidden_states
  184. class AfmoeSparseMoeBlock(nn.Module):
  185. """
  186. Mixture of Experts (MoE) module for AFMoE.
  187. This module implements a sparse MoE layer with both shared experts (always active) and
  188. routed experts (activated based on token-choice routing).
  189. """
  190. def __init__(self, config):
  191. super().__init__()
  192. self.config = config
  193. self.router = AfmoeTokenChoiceRouter(config)
  194. self.shared_experts = AfmoeMLP(config, config.moe_intermediate_size * config.num_shared_experts)
  195. self.experts = AfmoeExperts(config)
  196. self.expert_bias = nn.Parameter(torch.zeros(config.num_experts), requires_grad=False)
  197. def forward(self, hidden_states):
  198. batch_size, seq_len, hidden_dim = hidden_states.shape
  199. hidden_states_flat = hidden_states.view(-1, hidden_dim)
  200. # Get routing decisions (returns flattened top-k)
  201. router_logits, top_scores, selected_experts = self.router(hidden_states, self.expert_bias)
  202. # Process through shared experts
  203. shared_output = self.shared_experts(hidden_states_flat).view(batch_size, seq_len, hidden_dim)
  204. routed_output = self.experts(hidden_states_flat, selected_experts, top_scores).view(
  205. batch_size, seq_len, hidden_dim
  206. )
  207. return shared_output + routed_output
  208. def rotate_half(x):
  209. """Rotates half the hidden dims of the input."""
  210. x1 = x[..., : x.shape[-1] // 2]
  211. x2 = x[..., x.shape[-1] // 2 :]
  212. return torch.cat((-x2, x1), dim=-1)
  213. @use_kernel_func_from_hub("rotary_pos_emb")
  214. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  215. """Applies Rotary Position Embedding to the query and key tensors.
  216. Args:
  217. q (`torch.Tensor`): The query tensor.
  218. k (`torch.Tensor`): The key tensor.
  219. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  220. sin (`torch.Tensor`): The sine part of the rotary embedding.
  221. unsqueeze_dim (`int`, *optional*, defaults to 1):
  222. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  223. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  224. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  225. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  226. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  227. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  228. Returns:
  229. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  230. """
  231. cos = cos.unsqueeze(unsqueeze_dim)
  232. sin = sin.unsqueeze(unsqueeze_dim)
  233. q_embed = (q * cos) + (rotate_half(q) * sin)
  234. k_embed = (k * cos) + (rotate_half(k) * sin)
  235. return q_embed, k_embed
  236. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  237. """
  238. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  239. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  240. """
  241. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  242. if n_rep == 1:
  243. return hidden_states
  244. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  245. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  246. def eager_attention_forward(
  247. module: nn.Module,
  248. query: torch.Tensor,
  249. key: torch.Tensor,
  250. value: torch.Tensor,
  251. attention_mask: torch.Tensor | None,
  252. scaling: float,
  253. dropout: float = 0.0,
  254. **kwargs: Unpack[TransformersKwargs],
  255. ):
  256. key_states = repeat_kv(key, module.num_key_value_groups)
  257. value_states = repeat_kv(value, module.num_key_value_groups)
  258. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  259. if attention_mask is not None:
  260. attn_weights = attn_weights + attention_mask
  261. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  262. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  263. attn_output = torch.matmul(attn_weights, value_states)
  264. attn_output = attn_output.transpose(1, 2).contiguous()
  265. return attn_output, attn_weights
  266. @use_kernelized_func(apply_rotary_pos_emb)
  267. class AfmoeAttention(nn.Module):
  268. """
  269. Multi-headed attention module with optional sliding window and gating.
  270. This attention mechanism supports both full attention and sliding window attention,
  271. and includes Q/K normalization and gating of the output. It inherits from [`LlamaAttention`] to minimize the amount
  272. of custom logic we need to maintain.
  273. """
  274. def __init__(self, config: AfmoeConfig, layer_idx: int):
  275. super().__init__()
  276. self.config = config
  277. self.layer_idx = layer_idx
  278. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  279. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  280. self.scaling = self.head_dim**-0.5
  281. self.attention_dropout = config.attention_dropout
  282. self.is_causal = True
  283. self.q_proj = nn.Linear(
  284. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  285. )
  286. self.k_proj = nn.Linear(
  287. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  288. )
  289. self.v_proj = nn.Linear(
  290. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  291. )
  292. self.o_proj = nn.Linear(
  293. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  294. )
  295. # Parent LlamaAttention already sets: layer_idx, num_heads, num_key_value_heads, num_key_value_groups, head_dim
  296. # We only add AFMoE-specific attributes
  297. self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention"
  298. self.sliding_window = config.sliding_window if self.is_local_attention else None
  299. self.q_norm = AfmoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
  300. self.k_norm = AfmoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
  301. self.gate_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
  302. def forward(
  303. self,
  304. hidden_states: torch.Tensor,
  305. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  306. attention_mask: torch.Tensor | None,
  307. past_key_value: Cache | None = None,
  308. **kwargs: Unpack[TransformersKwargs],
  309. ) -> tuple[torch.Tensor, torch.Tensor]:
  310. input_shape = hidden_states.shape[:-1]
  311. hidden_shape = (*input_shape, -1, self.head_dim)
  312. query_states = self.q_proj(hidden_states).view(hidden_shape)
  313. key_states = self.k_proj(hidden_states).view(hidden_shape)
  314. value_states = self.v_proj(hidden_states).view(hidden_shape)
  315. gate_states = self.gate_proj(hidden_states)
  316. query_states = self.q_norm(query_states).transpose(1, 2)
  317. key_states = self.k_norm(key_states).transpose(1, 2)
  318. value_states = value_states.transpose(1, 2)
  319. if self.is_local_attention:
  320. cos, sin = position_embeddings
  321. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  322. if past_key_value is not None:
  323. key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
  324. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  325. self.config._attn_implementation, eager_attention_forward
  326. )
  327. output, attn_weights = attention_interface(
  328. self,
  329. query_states,
  330. key_states,
  331. value_states,
  332. attention_mask=attention_mask,
  333. dropout=0.0 if not self.training else self.attention_dropout,
  334. scaling=self.scaling,
  335. sliding_window=self.sliding_window,
  336. **kwargs,
  337. )
  338. output = output.view(*input_shape, -1).contiguous()
  339. output = output * torch.sigmoid(gate_states)
  340. attn_output = self.o_proj(output)
  341. return attn_output, attn_weights
  342. class AfmoeDecoderLayer(GradientCheckpointingLayer):
  343. """
  344. AFMoE decoder layer with dual normalization.
  345. This layer applies self-attention followed by either a dense MLP or MoE block,
  346. with dual normalization (pre and post) around each component.
  347. """
  348. def __init__(self, config: AfmoeConfig, layer_idx: int):
  349. super().__init__()
  350. self.hidden_size = config.hidden_size
  351. self.layer_idx = layer_idx
  352. self.self_attn = AfmoeAttention(config=config, layer_idx=layer_idx)
  353. # Dual normalization for attention
  354. self.input_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  355. self.post_attention_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  356. # Dual normalization for FFN
  357. self.pre_mlp_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  358. self.post_mlp_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  359. # MoE or dense FFN
  360. self.moe_enabled = layer_idx >= config.num_dense_layers
  361. if self.moe_enabled:
  362. self.mlp = AfmoeSparseMoeBlock(config)
  363. else:
  364. self.mlp = AfmoeMLP(config)
  365. def forward(
  366. self,
  367. hidden_states: torch.Tensor,
  368. attention_mask: torch.Tensor | None = None,
  369. position_ids: torch.LongTensor | None = None,
  370. past_key_value: Cache | None = None,
  371. use_cache: bool | None = None,
  372. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  373. **kwargs: Unpack[TransformersKwargs],
  374. ) -> torch.FloatTensor:
  375. residual = hidden_states
  376. # Self Attention with dual normalization
  377. hidden_states = self.input_layernorm(hidden_states)
  378. hidden_states, _ = self.self_attn(
  379. hidden_states=hidden_states,
  380. attention_mask=attention_mask,
  381. position_ids=position_ids,
  382. past_key_value=past_key_value,
  383. use_cache=use_cache,
  384. position_embeddings=position_embeddings,
  385. **kwargs,
  386. )
  387. hidden_states = self.post_attention_layernorm(hidden_states)
  388. hidden_states = residual + hidden_states
  389. # FFN with dual normalization
  390. residual = hidden_states
  391. hidden_states = self.pre_mlp_layernorm(hidden_states)
  392. hidden_states = self.mlp(hidden_states)
  393. hidden_states = self.post_mlp_layernorm(hidden_states)
  394. hidden_states = residual + hidden_states
  395. return hidden_states
  396. class AfmoePreTrainedModel(PreTrainedModel):
  397. """
  398. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  399. models.
  400. """
  401. config: AfmoeConfig
  402. base_model_prefix = "model"
  403. _no_split_modules = ["AfmoeDecoderLayer"]
  404. _skip_keys_device_placement = ["past_key_values"]
  405. _can_record_outputs = {
  406. "router_logits": OutputRecorder(AfmoeTokenChoiceRouter, index=0),
  407. "hidden_states": AfmoeDecoderLayer,
  408. "attentions": AfmoeAttention,
  409. }
  410. _keep_in_fp32_modules = [
  411. "input_layernorm",
  412. "post_attention_layernorm",
  413. "pre_mlp_layernorm",
  414. "post_mlp_layernorm",
  415. "q_norm",
  416. "k_norm",
  417. "norm",
  418. "expert_bias",
  419. ]
  420. _supports_sdpa = True
  421. _supports_flash_attn = True
  422. _supports_flex_attn = True
  423. _can_compile_fullgraph = (
  424. is_grouped_mm_available()
  425. ) # https://huggingface.co/docs/transformers/experts_interface#torchcompile
  426. _supports_attention_backend = True
  427. supports_gradient_checkpointing = True
  428. def _init_weights(self, module):
  429. """Initialize the weights"""
  430. super()._init_weights(module)
  431. std = self.config.initializer_range
  432. if isinstance(module, AfmoeExperts):
  433. init.normal_(module.gate_up_proj, mean=0.0, std=std)
  434. init.normal_(module.down_proj, mean=0.0, std=std)
  435. elif isinstance(module, AfmoeTokenChoiceRouter):
  436. init.zeros_(module.gate.weight)
  437. elif isinstance(module, AfmoeSparseMoeBlock):
  438. init.zeros_(module.expert_bias)
  439. @auto_docstring
  440. class AfmoeModel(AfmoePreTrainedModel):
  441. """
  442. Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AfmoeDecoderLayer`]
  443. Args:
  444. config: AfmoeConfig
  445. """
  446. def __init__(self, config: AfmoeConfig):
  447. super().__init__(config)
  448. self.padding_idx = config.pad_token_id
  449. self.vocab_size = config.vocab_size
  450. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  451. self.layers = nn.ModuleList(
  452. [AfmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  453. )
  454. self.norm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  455. self.rotary_emb = AfmoeRotaryEmbedding(config=config)
  456. self.gradient_checkpointing = False
  457. self.post_init()
  458. @auto_docstring
  459. @merge_with_config_defaults
  460. @capture_outputs
  461. def forward(
  462. self,
  463. input_ids: torch.LongTensor | None = None,
  464. attention_mask: torch.Tensor | None = None,
  465. inputs_embeds: torch.FloatTensor | None = None,
  466. position_ids: torch.LongTensor | None = None,
  467. past_key_values: Cache | None = None,
  468. use_cache: bool | None = None,
  469. **kwargs: Unpack[TransformersKwargs],
  470. ) -> tuple | MoeModelOutputWithPast:
  471. if (input_ids is None) ^ (inputs_embeds is not None):
  472. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  473. if use_cache and past_key_values is None:
  474. past_key_values = DynamicCache(config=self.config)
  475. if inputs_embeds is None:
  476. inputs_embeds = self.embed_tokens(input_ids)
  477. if position_ids is None:
  478. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  479. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  480. position_ids = position_ids.unsqueeze(0)
  481. # It may already have been prepared by e.g. `generate`
  482. if not isinstance(causal_mask_mapping := attention_mask, dict):
  483. mask_kwargs = {
  484. "config": self.config,
  485. "inputs_embeds": inputs_embeds,
  486. "attention_mask": attention_mask,
  487. "past_key_values": past_key_values,
  488. }
  489. causal_mask_mapping = {
  490. "full_attention": create_causal_mask(**mask_kwargs),
  491. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  492. }
  493. hidden_states = inputs_embeds
  494. # Apply muP input scaling if enabled
  495. if self.config.mup_enabled:
  496. hidden_states = hidden_states * (self.config.hidden_size**0.5)
  497. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  498. for i, decoder_layer in enumerate(self.layers):
  499. hidden_states = decoder_layer(
  500. hidden_states,
  501. attention_mask=causal_mask_mapping[self.config.layer_types[i]],
  502. position_ids=position_ids,
  503. past_key_value=past_key_values,
  504. use_cache=use_cache,
  505. position_embeddings=position_embeddings,
  506. **kwargs,
  507. )
  508. hidden_states = self.norm(hidden_states)
  509. return MoeModelOutputWithPast(
  510. last_hidden_state=hidden_states,
  511. past_key_values=past_key_values if use_cache else None,
  512. )
  513. @auto_docstring
  514. class AfmoeForCausalLM(AfmoePreTrainedModel, GenerationMixin):
  515. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  516. _tp_plan = {"lm_head": "colwise_gather_output"}
  517. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  518. def __init__(self, config):
  519. super().__init__(config)
  520. self.model = AfmoeModel(config)
  521. self.vocab_size = config.vocab_size
  522. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  523. self.post_init()
  524. @can_return_tuple
  525. @auto_docstring
  526. def forward(
  527. self,
  528. input_ids: torch.LongTensor | None = None,
  529. attention_mask: torch.Tensor | None = None,
  530. position_ids: torch.LongTensor | None = None,
  531. past_key_values: Cache | None = None,
  532. inputs_embeds: torch.FloatTensor | None = None,
  533. labels: torch.LongTensor | None = None,
  534. use_cache: bool | None = None,
  535. output_router_logits: bool | None = None,
  536. logits_to_keep: int | torch.Tensor = 0,
  537. **kwargs: Unpack[TransformersKwargs],
  538. ) -> MoeCausalLMOutputWithPast:
  539. r"""
  540. Example:
  541. ```python
  542. >>> from transformers import AutoTokenizer, AfmoeForCausalLM
  543. >>> model = AfmoeForCausalLM.from_pretrained("meta-afmoe/Afmoe-2-7b-hf")
  544. >>> tokenizer = AutoTokenizer.from_pretrained("meta-afmoe/Afmoe-2-7b-hf")
  545. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  546. >>> inputs = tokenizer(prompt, return_tensors="pt")
  547. >>> # Generate
  548. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  549. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  550. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  551. ```"""
  552. output_router_logits = (
  553. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  554. )
  555. outputs: MoeModelOutputWithPast = self.model(
  556. input_ids=input_ids,
  557. attention_mask=attention_mask,
  558. position_ids=position_ids,
  559. past_key_values=past_key_values,
  560. inputs_embeds=inputs_embeds,
  561. use_cache=use_cache,
  562. output_router_logits=output_router_logits,
  563. **kwargs,
  564. )
  565. hidden_states = outputs.last_hidden_state
  566. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  567. logits = self.lm_head(hidden_states[:, slice_indices, :])
  568. loss = None
  569. if labels is not None:
  570. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  571. return MoeCausalLMOutputWithPast(
  572. loss=loss,
  573. logits=logits,
  574. past_key_values=outputs.past_key_values,
  575. hidden_states=outputs.hidden_states,
  576. attentions=outputs.attentions,
  577. router_logits=outputs.router_logits,
  578. )
  579. __all__ = ["AfmoeForCausalLM", "AfmoeModel", "AfmoePreTrainedModel"]