modular_afmoe.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. # Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch AFMoE model."""
  15. from collections.abc import Callable
  16. import torch
  17. from torch import nn
  18. from ... import initialization as init
  19. from ...cache_utils import Cache, DynamicCache
  20. from ...generation import GenerationMixin
  21. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  24. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  25. from ...processing_utils import Unpack
  26. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available, logging
  27. from ...utils.generic import merge_with_config_defaults
  28. from ...utils.output_capturing import OutputRecorder, capture_outputs
  29. from ..gpt_oss.modeling_gpt_oss import GptOssRMSNorm
  30. from ..llama.modeling_llama import (
  31. LlamaAttention,
  32. LlamaForCausalLM,
  33. LlamaRotaryEmbedding,
  34. apply_rotary_pos_emb,
  35. eager_attention_forward,
  36. )
  37. from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeExperts, Qwen2MoeMLP
  38. from .configuration_afmoe import AfmoeConfig
# Module-level logger obtained from the library's logging utilities.
logger = logging.get_logger(__name__)
  40. class AfmoeRotaryEmbedding(LlamaRotaryEmbedding):
  41. pass
  42. class AfmoeRMSNorm(GptOssRMSNorm):
  43. pass
  44. class AfmoeMLP(Qwen2MoeMLP):
  45. pass
  46. class AfmoeTokenChoiceRouter(nn.Module):
  47. """
  48. Token-choice top-K router for MoE routing.
  49. This router assigns each token to the top-K experts based on sigmoid scores, matching the released checkpoints.
  50. """
  51. def __init__(self, config):
  52. super().__init__()
  53. self.config = config
  54. self.top_k = config.num_experts_per_tok
  55. self.num_experts = config.num_experts
  56. self.route_scale = config.route_scale
  57. self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
  58. def forward(self, hidden_states: torch.Tensor, expert_bias: torch.Tensor):
  59. _, _, hidden_dim = hidden_states.shape
  60. hidden_states = hidden_states.view(-1, hidden_dim)
  61. router_logits = self.gate(hidden_states).to(torch.float32)
  62. scores = torch.sigmoid(router_logits)
  63. _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
  64. top_scores = scores.gather(dim=1, index=selected_experts)
  65. denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
  66. top_scores = top_scores / denominator
  67. top_scores = top_scores * self.route_scale
  68. return router_logits, top_scores, selected_experts
  69. class AfmoeExperts(Qwen2MoeExperts):
  70. pass
  71. class AfmoeSparseMoeBlock(nn.Module):
  72. """
  73. Mixture of Experts (MoE) module for AFMoE.
  74. This module implements a sparse MoE layer with both shared experts (always active) and
  75. routed experts (activated based on token-choice routing).
  76. """
  77. def __init__(self, config):
  78. super().__init__()
  79. self.config = config
  80. self.router = AfmoeTokenChoiceRouter(config)
  81. self.shared_experts = AfmoeMLP(config, config.moe_intermediate_size * config.num_shared_experts)
  82. self.experts = AfmoeExperts(config)
  83. self.expert_bias = nn.Parameter(torch.zeros(config.num_experts), requires_grad=False)
  84. def forward(self, hidden_states):
  85. batch_size, seq_len, hidden_dim = hidden_states.shape
  86. hidden_states_flat = hidden_states.view(-1, hidden_dim)
  87. # Get routing decisions (returns flattened top-k)
  88. router_logits, top_scores, selected_experts = self.router(hidden_states, self.expert_bias)
  89. # Process through shared experts
  90. shared_output = self.shared_experts(hidden_states_flat).view(batch_size, seq_len, hidden_dim)
  91. routed_output = self.experts(hidden_states_flat, selected_experts, top_scores).view(
  92. batch_size, seq_len, hidden_dim
  93. )
  94. return shared_output + routed_output
  95. class AfmoeAttention(LlamaAttention):
  96. """
  97. Multi-headed attention module with optional sliding window and gating.
  98. This attention mechanism supports both full attention and sliding window attention,
  99. and includes Q/K normalization and gating of the output. It inherits from [`LlamaAttention`] to minimize the amount
  100. of custom logic we need to maintain.
  101. """
  102. def __init__(self, config: AfmoeConfig, layer_idx: int):
  103. super().__init__(config, layer_idx)
  104. # Parent LlamaAttention already sets: layer_idx, num_heads, num_key_value_heads, num_key_value_groups, head_dim
  105. # We only add AFMoE-specific attributes
  106. self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention"
  107. self.sliding_window = config.sliding_window if self.is_local_attention else None
  108. self.q_norm = AfmoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
  109. self.k_norm = AfmoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
  110. self.gate_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
  111. def forward(
  112. self,
  113. hidden_states: torch.Tensor,
  114. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  115. attention_mask: torch.Tensor | None,
  116. past_key_value: Cache | None = None,
  117. **kwargs: Unpack[TransformersKwargs],
  118. ) -> tuple[torch.Tensor, torch.Tensor]:
  119. input_shape = hidden_states.shape[:-1]
  120. hidden_shape = (*input_shape, -1, self.head_dim)
  121. query_states = self.q_proj(hidden_states).view(hidden_shape)
  122. key_states = self.k_proj(hidden_states).view(hidden_shape)
  123. value_states = self.v_proj(hidden_states).view(hidden_shape)
  124. gate_states = self.gate_proj(hidden_states)
  125. query_states = self.q_norm(query_states).transpose(1, 2)
  126. key_states = self.k_norm(key_states).transpose(1, 2)
  127. value_states = value_states.transpose(1, 2)
  128. if self.is_local_attention:
  129. cos, sin = position_embeddings
  130. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  131. if past_key_value is not None:
  132. key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
  133. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  134. self.config._attn_implementation, eager_attention_forward
  135. )
  136. output, attn_weights = attention_interface(
  137. self,
  138. query_states,
  139. key_states,
  140. value_states,
  141. attention_mask=attention_mask,
  142. dropout=0.0 if not self.training else self.attention_dropout,
  143. scaling=self.scaling,
  144. sliding_window=self.sliding_window,
  145. **kwargs,
  146. )
  147. output = output.view(*input_shape, -1).contiguous()
  148. output = output * torch.sigmoid(gate_states)
  149. attn_output = self.o_proj(output)
  150. return attn_output, attn_weights
  151. class AfmoeDecoderLayer(GradientCheckpointingLayer):
  152. """
  153. AFMoE decoder layer with dual normalization.
  154. This layer applies self-attention followed by either a dense MLP or MoE block,
  155. with dual normalization (pre and post) around each component.
  156. """
  157. def __init__(self, config: AfmoeConfig, layer_idx: int):
  158. super().__init__()
  159. self.hidden_size = config.hidden_size
  160. self.layer_idx = layer_idx
  161. self.self_attn = AfmoeAttention(config=config, layer_idx=layer_idx)
  162. # Dual normalization for attention
  163. self.input_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  164. self.post_attention_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  165. # Dual normalization for FFN
  166. self.pre_mlp_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  167. self.post_mlp_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  168. # MoE or dense FFN
  169. self.moe_enabled = layer_idx >= config.num_dense_layers
  170. if self.moe_enabled:
  171. self.mlp = AfmoeSparseMoeBlock(config)
  172. else:
  173. self.mlp = AfmoeMLP(config)
  174. def forward(
  175. self,
  176. hidden_states: torch.Tensor,
  177. attention_mask: torch.Tensor | None = None,
  178. position_ids: torch.LongTensor | None = None,
  179. past_key_value: Cache | None = None,
  180. use_cache: bool | None = None,
  181. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  182. **kwargs: Unpack[TransformersKwargs],
  183. ) -> torch.FloatTensor:
  184. residual = hidden_states
  185. # Self Attention with dual normalization
  186. hidden_states = self.input_layernorm(hidden_states)
  187. hidden_states, _ = self.self_attn(
  188. hidden_states=hidden_states,
  189. attention_mask=attention_mask,
  190. position_ids=position_ids,
  191. past_key_value=past_key_value,
  192. use_cache=use_cache,
  193. position_embeddings=position_embeddings,
  194. **kwargs,
  195. )
  196. hidden_states = self.post_attention_layernorm(hidden_states)
  197. hidden_states = residual + hidden_states
  198. # FFN with dual normalization
  199. residual = hidden_states
  200. hidden_states = self.pre_mlp_layernorm(hidden_states)
  201. hidden_states = self.mlp(hidden_states)
  202. hidden_states = self.post_mlp_layernorm(hidden_states)
  203. hidden_states = residual + hidden_states
  204. return hidden_states
  205. class AfmoePreTrainedModel(PreTrainedModel):
  206. """
  207. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  208. models.
  209. """
  210. config: AfmoeConfig
  211. base_model_prefix = "model"
  212. _no_split_modules = ["AfmoeDecoderLayer"]
  213. _skip_keys_device_placement = ["past_key_values"]
  214. _can_record_outputs = {
  215. "router_logits": OutputRecorder(AfmoeTokenChoiceRouter, index=0),
  216. "hidden_states": AfmoeDecoderLayer,
  217. "attentions": AfmoeAttention,
  218. }
  219. _keep_in_fp32_modules = [
  220. "input_layernorm",
  221. "post_attention_layernorm",
  222. "pre_mlp_layernorm",
  223. "post_mlp_layernorm",
  224. "q_norm",
  225. "k_norm",
  226. "norm",
  227. "expert_bias",
  228. ]
  229. _supports_sdpa = True
  230. _supports_flash_attn = True
  231. _supports_flex_attn = True
  232. _can_compile_fullgraph = (
  233. is_grouped_mm_available()
  234. ) # https://huggingface.co/docs/transformers/experts_interface#torchcompile
  235. _supports_attention_backend = True
  236. supports_gradient_checkpointing = True
  237. def _init_weights(self, module):
  238. """Initialize the weights"""
  239. super()._init_weights(module)
  240. std = self.config.initializer_range
  241. if isinstance(module, AfmoeExperts):
  242. init.normal_(module.gate_up_proj, mean=0.0, std=std)
  243. init.normal_(module.down_proj, mean=0.0, std=std)
  244. elif isinstance(module, AfmoeTokenChoiceRouter):
  245. init.zeros_(module.gate.weight)
  246. elif isinstance(module, AfmoeSparseMoeBlock):
  247. init.zeros_(module.expert_bias)
  248. @auto_docstring
  249. class AfmoeModel(AfmoePreTrainedModel):
  250. """
  251. Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AfmoeDecoderLayer`]
  252. Args:
  253. config: AfmoeConfig
  254. """
  255. def __init__(self, config: AfmoeConfig):
  256. super().__init__(config)
  257. self.padding_idx = config.pad_token_id
  258. self.vocab_size = config.vocab_size
  259. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  260. self.layers = nn.ModuleList(
  261. [AfmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  262. )
  263. self.norm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  264. self.rotary_emb = AfmoeRotaryEmbedding(config=config)
  265. self.gradient_checkpointing = False
  266. self.post_init()
  267. @auto_docstring
  268. @merge_with_config_defaults
  269. @capture_outputs
  270. def forward(
  271. self,
  272. input_ids: torch.LongTensor | None = None,
  273. attention_mask: torch.Tensor | None = None,
  274. inputs_embeds: torch.FloatTensor | None = None,
  275. position_ids: torch.LongTensor | None = None,
  276. past_key_values: Cache | None = None,
  277. use_cache: bool | None = None,
  278. **kwargs: Unpack[TransformersKwargs],
  279. ) -> tuple | MoeModelOutputWithPast:
  280. if (input_ids is None) ^ (inputs_embeds is not None):
  281. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  282. if use_cache and past_key_values is None:
  283. past_key_values = DynamicCache(config=self.config)
  284. if inputs_embeds is None:
  285. inputs_embeds = self.embed_tokens(input_ids)
  286. if position_ids is None:
  287. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  288. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  289. position_ids = position_ids.unsqueeze(0)
  290. # It may already have been prepared by e.g. `generate`
  291. if not isinstance(causal_mask_mapping := attention_mask, dict):
  292. mask_kwargs = {
  293. "config": self.config,
  294. "inputs_embeds": inputs_embeds,
  295. "attention_mask": attention_mask,
  296. "past_key_values": past_key_values,
  297. }
  298. causal_mask_mapping = {
  299. "full_attention": create_causal_mask(**mask_kwargs),
  300. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  301. }
  302. hidden_states = inputs_embeds
  303. # Apply muP input scaling if enabled
  304. if self.config.mup_enabled:
  305. hidden_states = hidden_states * (self.config.hidden_size**0.5)
  306. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  307. for i, decoder_layer in enumerate(self.layers):
  308. hidden_states = decoder_layer(
  309. hidden_states,
  310. attention_mask=causal_mask_mapping[self.config.layer_types[i]],
  311. position_ids=position_ids,
  312. past_key_value=past_key_values,
  313. use_cache=use_cache,
  314. position_embeddings=position_embeddings,
  315. **kwargs,
  316. )
  317. hidden_states = self.norm(hidden_states)
  318. return MoeModelOutputWithPast(
  319. last_hidden_state=hidden_states,
  320. past_key_values=past_key_values if use_cache else None,
  321. )
  322. class AfmoeForCausalLM(LlamaForCausalLM, AfmoePreTrainedModel, GenerationMixin):
  323. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  324. _tp_plan = {"lm_head": "colwise_gather_output"}
  325. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  326. def __init__(self, config):
  327. AfmoePreTrainedModel.__init__(self, config)
  328. self.model = AfmoeModel(config)
  329. self.vocab_size = config.vocab_size
  330. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  331. self.post_init()
  332. @can_return_tuple
  333. @auto_docstring
  334. def forward(
  335. self,
  336. input_ids: torch.LongTensor | None = None,
  337. attention_mask: torch.Tensor | None = None,
  338. position_ids: torch.LongTensor | None = None,
  339. past_key_values: Cache | None = None,
  340. inputs_embeds: torch.FloatTensor | None = None,
  341. labels: torch.LongTensor | None = None,
  342. use_cache: bool | None = None,
  343. output_router_logits: bool | None = None,
  344. logits_to_keep: int | torch.Tensor = 0,
  345. **kwargs: Unpack[TransformersKwargs],
  346. ) -> MoeCausalLMOutputWithPast:
  347. output_router_logits = (
  348. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  349. )
  350. outputs: MoeModelOutputWithPast = self.model(
  351. input_ids=input_ids,
  352. attention_mask=attention_mask,
  353. position_ids=position_ids,
  354. past_key_values=past_key_values,
  355. inputs_embeds=inputs_embeds,
  356. use_cache=use_cache,
  357. output_router_logits=output_router_logits,
  358. **kwargs,
  359. )
  360. hidden_states = outputs.last_hidden_state
  361. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  362. logits = self.lm_head(hidden_states[:, slice_indices, :])
  363. loss = None
  364. if labels is not None:
  365. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  366. return MoeCausalLMOutputWithPast(
  367. loss=loss,
  368. logits=logits,
  369. past_key_values=outputs.past_key_values,
  370. hidden_states=outputs.hidden_states,
  371. attentions=outputs.attentions,
  372. router_logits=outputs.router_logits,
  373. )
# Public classes exported by this module.
__all__ = [
    "AfmoeForCausalLM",
    "AfmoeModel",
    "AfmoePreTrainedModel",
]