modeling_jetmoe.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/jetmoe/modular_jetmoe.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_jetmoe.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 JetMoe AI and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from typing import Optional
  22. import torch
  23. from torch import nn
  24. from torch.nn import functional as F
  25. from ... import initialization as init
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache, DynamicCache
  28. from ...generation import GenerationMixin
  29. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
  30. from ...masking_utils import create_causal_mask
  31. from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
  32. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  33. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  34. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  35. from ...processing_utils import Unpack
  36. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  37. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  38. from ...utils.output_capturing import OutputRecorder, capture_outputs
  39. from .configuration_jetmoe import JetMoeConfig
# Module-level logger from the transformers logging utilities; JetMoeAttention uses it for a one-time warning.
logger = logging.get_logger(__name__)
  41. @use_kernel_forward_from_hub("RMSNorm")
  42. class JetMoeRMSNorm(nn.Module):
  43. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  44. """
  45. JetMoeRMSNorm is equivalent to T5LayerNorm
  46. """
  47. super().__init__()
  48. self.weight = nn.Parameter(torch.ones(hidden_size))
  49. self.variance_epsilon = eps
  50. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  51. input_dtype = hidden_states.dtype
  52. hidden_states = hidden_states.to(torch.float32)
  53. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  54. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  55. return self.weight * hidden_states.to(input_dtype)
  56. def extra_repr(self):
  57. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  58. class JetMoeRotaryEmbedding(nn.Module):
  59. inv_freq: torch.Tensor # fix linting for `register_buffer`
  60. def __init__(self, config: JetMoeConfig, device=None):
  61. super().__init__()
  62. self.max_seq_len_cached = config.max_position_embeddings
  63. self.original_max_seq_len = config.max_position_embeddings
  64. self.config = config
  65. self.rope_type = self.config.rope_parameters["rope_type"]
  66. rope_init_fn: Callable = self.compute_default_rope_parameters
  67. if self.rope_type != "default":
  68. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  69. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  70. self.register_buffer("inv_freq", inv_freq, persistent=False)
  71. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  72. @staticmethod
  73. def compute_default_rope_parameters(
  74. config: JetMoeConfig | None = None,
  75. device: Optional["torch.device"] = None,
  76. seq_len: int | None = None,
  77. ) -> tuple["torch.Tensor", float]:
  78. """
  79. Computes the inverse frequencies according to the original RoPE implementation
  80. Args:
  81. config ([`~transformers.PreTrainedConfig`]):
  82. The model configuration.
  83. device (`torch.device`):
  84. The device to use for initialization of the inverse frequencies.
  85. seq_len (`int`, *optional*):
  86. The current sequence length. Unused for this type of RoPE.
  87. Returns:
  88. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  89. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  90. """
  91. base = config.rope_parameters["rope_theta"]
  92. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  93. attention_factor = 1.0 # Unused in this type of RoPE
  94. # Compute the inverse frequencies
  95. inv_freq = 1.0 / (
  96. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  97. )
  98. return inv_freq, attention_factor
  99. @torch.no_grad()
  100. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  101. def forward(self, x, position_ids):
  102. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  103. position_ids_expanded = position_ids[:, None, :].float()
  104. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  105. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  106. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  107. emb = torch.cat((freqs, freqs), dim=-1)
  108. cos = emb.cos() * self.attention_scaling
  109. sin = emb.sin() * self.attention_scaling
  110. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  111. class JetMoeParallelExperts(nn.Module):
  112. def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
  113. """
  114. Initialize the JetMoeParallelExperts module.
  115. The experts weights are stored in [num_experts, output_size, input_size] format. Such that it's compatible with
  116. many MoE libraries, such as [Megablock](https://github.com/databricks/megablocks) and
  117. [ScatterMoE](https://github.com/shawntan/scattermoe), as well as the
  118. [MoE kernel](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py)
  119. used in vllm.
  120. Args:
  121. num_experts (int):
  122. Number of experts.
  123. input_size (int):
  124. Size of the input.
  125. output_size (int):
  126. Size of the output.
  127. """
  128. super().__init__()
  129. self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
  130. self.num_experts = num_experts
  131. self.input_size = input_size
  132. self.output_size = output_size
  133. def forward(self, inputs, expert_size):
  134. """
  135. Forward pass of the JetMoeParallelExperts module.
  136. Args:
  137. inputs (Tensor):
  138. Input tensor.
  139. expert_size:
  140. Expert size information.
  141. Returns:
  142. Tensor: Output tensor.
  143. """
  144. input_list = inputs.split(expert_size, dim=0)
  145. output_list = []
  146. for i in range(self.num_experts):
  147. output_list.append(F.linear(input_list[i], self.weight[i]))
  148. results = torch.cat(output_list, dim=0)
  149. return results
  150. class JetMoeTopKGating(nn.Module):
  151. def __init__(self, input_size: int, num_experts: int, top_k: int):
  152. """
  153. Initialize the top-k gating mechanism.
  154. Args:
  155. input_size (`int`):
  156. Size of the input.
  157. num_experts (`int`):
  158. Number of experts.
  159. top_k (`int`):
  160. Number of top experts to select.
  161. """
  162. super().__init__()
  163. self.num_experts = num_experts
  164. self.input_size = input_size
  165. self.top_k = top_k
  166. self.layer = nn.Linear(input_size, num_experts, bias=False)
  167. def forward(self, hidden_states):
  168. # compute the top_k routing decision
  169. logits = self.layer(hidden_states).float() # [batch_size x seq_len, num_experts]
  170. top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1) # [num_tokens, top_k]
  171. top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states) # [num_tokens, top_k]
  172. # compute number of input given to each expert
  173. zeros = torch.zeros(
  174. [top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device
  175. ) # [num_tokens, num_experts]
  176. gates = zeros.scatter(1, top_k_indices, 1) # [num_tokens, num_experts]
  177. expert_size = gates.long().sum(0) # [num_experts,]
  178. # (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`)
  179. # (and `DataDependentOutputException`)
  180. expert_size = expert_size.tolist()
  181. # sort and group input tokens according to expert assignment
  182. top_k_experts = top_k_indices.flatten() # [num_tokens * top_k]
  183. _, index_sorted_experts = top_k_experts.sort(0) # [num_tokens * top_k]
  184. batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc") # [num_tokens * top_k]
  185. # gather the gate values for grouped input tokens
  186. top_k_gates = top_k_gates.flatten() # [num_tokens * top_k]
  187. batch_gates = top_k_gates[index_sorted_experts] # [num_tokens * top_k]
  188. return index_sorted_experts, batch_index, batch_gates, expert_size, logits
  189. class JetMoeMoE(nn.Module):
  190. """
  191. A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
  192. Args:
  193. config:
  194. Configuration object with model hyperparameters.
  195. """
  196. def __init__(self, config: JetMoeConfig):
  197. super().__init__()
  198. self.input_size = config.hidden_size
  199. self.hidden_size = config.intermediate_size
  200. self.activation = ACT2FN[config.activation_function]
  201. self.bias = torch.nn.Parameter(torch.empty(self.input_size))
  202. self.input_linear = JetMoeParallelExperts(config.num_local_experts, self.input_size, self.hidden_size * 2)
  203. self.output_linear = JetMoeParallelExperts(config.num_local_experts, self.hidden_size, self.input_size)
  204. self.router = JetMoeTopKGating(
  205. input_size=self.input_size,
  206. num_experts=config.num_local_experts,
  207. top_k=config.num_experts_per_tok,
  208. )
  209. def forward(self, layer_input):
  210. """
  211. Forward pass of the mixture of experts layer.
  212. Args:
  213. layer_input (Tensor):
  214. Input tensor.
  215. Returns:
  216. Tensor:
  217. Output tensor.
  218. Tensor:
  219. Router logits.
  220. """
  221. bsz, length, emb_size = layer_input.size()
  222. layer_input = layer_input.reshape(-1, emb_size)
  223. _, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input)
  224. expert_inputs = layer_input[batch_index]
  225. hidden_states = self.input_linear(expert_inputs, expert_size)
  226. chunked_hidden_states = hidden_states.chunk(2, dim=-1)
  227. hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]
  228. expert_outputs = self.output_linear(hidden_states, expert_size)
  229. expert_outputs = expert_outputs * batch_gates[:, None]
  230. zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
  231. layer_output = zeros.index_add(0, batch_index, expert_outputs)
  232. layer_output = layer_output.view(bsz, length, self.input_size)
  233. layer_output = layer_output + self.bias
  234. return layer_output
  235. class JetMoeMoA(nn.Module):
  236. """
  237. A Sparsely gated mixture of attention layer with pairs of query- and output-projections as experts.
  238. Args:
  239. config:
  240. Configuration object with model hyperparameters.
  241. """
  242. def __init__(self, config: JetMoeConfig):
  243. super().__init__()
  244. self.num_experts = config.num_local_experts
  245. self.input_size = config.hidden_size
  246. self.hidden_size = config.kv_channels * config.num_key_value_heads
  247. self.top_k = config.num_experts_per_tok
  248. self.bias = torch.nn.Parameter(torch.empty(self.input_size))
  249. self.input_linear = JetMoeParallelExperts(self.num_experts, self.input_size, self.hidden_size)
  250. self.output_linear = JetMoeParallelExperts(self.num_experts, self.hidden_size, self.input_size)
  251. self.router = JetMoeTopKGating(
  252. input_size=self.input_size,
  253. num_experts=self.num_experts,
  254. top_k=self.top_k,
  255. )
  256. def map(self, layer_input):
  257. """
  258. Map inputs to attention experts according to routing decision and compute query projection inside each experts.
  259. """
  260. # Compute gating topology
  261. bsz, length, emb_size = layer_input.size()
  262. layer_input = layer_input.reshape(-1, emb_size) # [bsz * length, emb_size]
  263. index_sorted_experts, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input)
  264. topo_info = (index_sorted_experts, batch_index, batch_gates, expert_size)
  265. # Group inputs according to topology and compute query projection
  266. expert_inputs = layer_input[batch_index] # [bsz * length * top_k, emb_size]
  267. expert_outputs = self.input_linear(expert_inputs, expert_size) # [bsz * length * top_k, hidden_size]
  268. # Ungroup queries back to original order
  269. zeros = torch.zeros(
  270. (bsz * length * self.top_k, self.hidden_size), dtype=expert_outputs.dtype, device=expert_outputs.device
  271. )
  272. layer_output = zeros.index_add(0, index_sorted_experts, expert_outputs)
  273. layer_output = layer_output.view(bsz, length, self.top_k, -1) # [bsz, length, top_k, hidden_size]
  274. return layer_output, router_logits, topo_info
  275. def reduce(self, layer_input, topo_info):
  276. """
  277. Compute output projection inside each attention experts and merge the outputs of different experts.
  278. """
  279. bsz, length, k, hidden_size = layer_input.size()
  280. layer_input = layer_input.reshape(-1, hidden_size) # [bsz * length * k, hidden_size]
  281. index_sorted_experts, batch_index, batch_gates, expert_size = topo_info
  282. # Group inputs according to topology and compute output projection
  283. expert_inputs = layer_input[index_sorted_experts] # [bsz * length * top_k, hidden_size]
  284. expert_outputs = self.output_linear(expert_inputs, expert_size) # [bsz * length * top_k, emb_size]
  285. # Apply gates to attention expert outputs
  286. expert_outputs = expert_outputs * batch_gates[:, None]
  287. # Ungroup and merge outputs to original order
  288. zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
  289. layer_output = zeros.index_add(0, batch_index, expert_outputs)
  290. layer_output = layer_output.view(bsz, length, self.input_size)
  291. layer_output = layer_output + self.bias
  292. return layer_output
  293. def forward(self, layer_input):
  294. raise NotImplementedError("This module doesn't support call and forward.")
  295. def rotate_half(x):
  296. """Rotates half the hidden dims of the input."""
  297. x1 = x[..., : x.shape[-1] // 2]
  298. x2 = x[..., x.shape[-1] // 2 :]
  299. return torch.cat((-x2, x1), dim=-1)
  300. @use_kernel_func_from_hub("rotary_pos_emb")
  301. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  302. """Applies Rotary Position Embedding to the query and key tensors.
  303. Args:
  304. q (`torch.Tensor`): The query tensor.
  305. k (`torch.Tensor`): The key tensor.
  306. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  307. sin (`torch.Tensor`): The sine part of the rotary embedding.
  308. unsqueeze_dim (`int`, *optional*, defaults to 1):
  309. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  310. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  311. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  312. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  313. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  314. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  315. Returns:
  316. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  317. """
  318. cos = cos.unsqueeze(unsqueeze_dim)
  319. sin = sin.unsqueeze(unsqueeze_dim)
  320. q_embed = (q * cos) + (rotate_half(q) * sin)
  321. k_embed = (k * cos) + (rotate_half(k) * sin)
  322. return q_embed, k_embed
  323. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  324. """
  325. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  326. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  327. """
  328. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  329. if n_rep == 1:
  330. return hidden_states
  331. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  332. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  333. def eager_attention_forward(
  334. module: nn.Module,
  335. query: torch.Tensor,
  336. key: torch.Tensor,
  337. value: torch.Tensor,
  338. attention_mask: torch.Tensor | None,
  339. scaling: float,
  340. dropout: float = 0.0,
  341. **kwargs: Unpack[TransformersKwargs],
  342. ):
  343. key_states = repeat_kv(key, module.num_key_value_groups)
  344. value_states = repeat_kv(value, module.num_key_value_groups)
  345. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  346. if attention_mask is not None:
  347. attn_weights = attn_weights + attention_mask
  348. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  349. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  350. attn_output = torch.matmul(attn_weights, value_states)
  351. attn_output = attn_output.transpose(1, 2).contiguous()
  352. return attn_output, attn_weights
  353. class JetMoeAttention(nn.Module):
  354. """
  355. Multi-headed attention from 'Attention Is All You Need' paper.
  356. """
  357. def __init__(self, config: JetMoeConfig, layer_idx: int | None = None):
  358. """
  359. Initialize the JetMoeAttention module.
  360. Args:
  361. config:
  362. Configuration object with model hyperparameters.
  363. layer_idx:
  364. Index of the layer in the model.
  365. """
  366. super().__init__()
  367. self.config = config
  368. self.layer_idx = layer_idx
  369. self.is_causal = True
  370. if layer_idx is None:
  371. logger.warning_once(
  372. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  373. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  374. "when creating this class."
  375. )
  376. self.num_key_value_groups = 1 # We ignore this by setting it to 1 as we have different repeat patterns
  377. self.top_k = config.num_experts_per_tok
  378. self.attention_dropout = config.attention_dropout
  379. self.kv_projection_size = config.kv_channels * config.num_key_value_heads
  380. self.num_key_value_heads = config.num_key_value_heads
  381. self.num_heads = config.num_attention_heads
  382. self.head_dim = config.kv_channels
  383. self.scaling = self.head_dim**-0.5
  384. self.experts = JetMoeMoA(config)
  385. self.kv_proj = torch.nn.Linear(config.hidden_size, self.kv_projection_size * 2, bias=False)
  386. def forward(
  387. self,
  388. hidden_states: torch.Tensor,
  389. attention_mask: torch.Tensor | None = None,
  390. position_embeddings: torch.LongTensor | None = None,
  391. past_key_values: Cache | None = None,
  392. **kwargs,
  393. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  394. input_shape = hidden_states.shape[:-1]
  395. hidden_shape = (*input_shape, -1, self.head_dim)
  396. query_states, router_logits, topo_info = self.experts.map(hidden_states)
  397. key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1)
  398. query_states = query_states.view(hidden_shape).transpose(1, 2)
  399. key_states = key_states.view(hidden_shape).transpose(1, 2)
  400. value_states = value_states.view(hidden_shape).transpose(1, 2)
  401. cos, sin = position_embeddings
  402. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  403. if past_key_values is not None:
  404. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  405. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  406. self.config._attn_implementation, eager_attention_forward
  407. )
  408. # This is different from other models where we repeat k/v heads
  409. # instead of repeat interleaving them
  410. key_states = key_states.repeat(1, self.top_k, 1, 1)
  411. value_states = value_states.repeat(1, self.top_k, 1, 1)
  412. attn_output, attn_weights = attention_interface(
  413. self,
  414. query_states,
  415. key_states,
  416. value_states,
  417. attention_mask,
  418. dropout=0.0 if not self.training else self.attention_dropout,
  419. scaling=self.scaling,
  420. **kwargs,
  421. )
  422. attn_output = attn_output.view(*input_shape, self.top_k, -1)
  423. attn_output = self.experts.reduce(attn_output, topo_info)
  424. attn_output = attn_output.view(*input_shape, -1)
  425. return attn_output, attn_weights, router_logits
  426. class JetMoeDecoderLayer(GradientCheckpointingLayer):
  427. def __init__(self, config: JetMoeConfig, layer_idx: int | None = None):
  428. super().__init__()
  429. self.hidden_size = config.hidden_size
  430. self.mlp = JetMoeMoE(config)
  431. self.input_layernorm = JetMoeRMSNorm(config.hidden_size)
  432. self.post_attention_layernorm = JetMoeRMSNorm(config.hidden_size)
  433. self.self_attention = JetMoeAttention(config, layer_idx)
  434. def forward(
  435. self,
  436. hidden_states: torch.Tensor,
  437. attention_mask: torch.Tensor | None = None,
  438. position_ids: torch.LongTensor | None = None,
  439. past_key_values: Cache | None = None,
  440. use_cache: bool | None = False,
  441. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  442. **kwargs: Unpack[TransformersKwargs],
  443. ) -> torch.Tensor:
  444. residual = hidden_states
  445. hidden_states = self.input_layernorm(hidden_states)
  446. # Self Attention
  447. hidden_states, _, _ = self.self_attention(
  448. hidden_states=hidden_states,
  449. attention_mask=attention_mask,
  450. position_ids=position_ids,
  451. past_key_values=past_key_values,
  452. use_cache=use_cache,
  453. position_embeddings=position_embeddings,
  454. **kwargs,
  455. )
  456. hidden_states = residual + hidden_states
  457. # Fully Connected
  458. residual = hidden_states
  459. hidden_states = self.post_attention_layernorm(hidden_states)
  460. hidden_states = self.mlp(hidden_states)
  461. hidden_states = residual + hidden_states
  462. return hidden_states
  463. @auto_docstring
  464. class JetMoePreTrainedModel(PreTrainedModel):
  465. config: JetMoeConfig
  466. base_model_prefix = "model"
  467. supports_gradient_checkpointing = False
  468. _no_split_modules = ["JetMoeDecoderLayer"]
  469. _skip_keys_device_placement = ["past_key_values"]
  470. _supports_flash_attn = True
  471. _supports_sdpa = True
  472. _supports_flex_attn = True
  473. _can_compile_fullgraph = False # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
  474. _supports_attention_backend = True
  475. _can_record_outputs = {
  476. "router_logits": [OutputRecorder(JetMoeAttention, index=2), OutputRecorder(JetMoeTopKGating, index=4)],
  477. "hidden_states": JetMoeDecoderLayer,
  478. "attentions": OutputRecorder(JetMoeAttention, index=1),
  479. }
  480. @torch.no_grad()
  481. def _init_weights(self, module):
  482. """Initialize the weights."""
  483. super()._init_weights(module)
  484. if isinstance(module, JetMoeParallelExperts):
  485. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  486. elif isinstance(module, JetMoeMoA | JetMoeMoE):
  487. init.zeros_(module.bias)
  488. @auto_docstring
  489. class JetMoeModel(JetMoePreTrainedModel):
  490. def __init__(self, config: JetMoeConfig):
  491. super().__init__(config)
  492. self.padding_idx = config.pad_token_id
  493. self.vocab_size = config.vocab_size
  494. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  495. self.layers = nn.ModuleList(
  496. [JetMoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  497. )
  498. self.norm = JetMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  499. self.rotary_emb = JetMoeRotaryEmbedding(config=config)
  500. self.gradient_checkpointing = False
  501. self._attn_implementation = config._attn_implementation
  502. # Initialize weights and apply final processing
  503. self.post_init()
  504. @merge_with_config_defaults
  505. @capture_outputs
  506. @auto_docstring
  507. def forward(
  508. self,
  509. input_ids: torch.LongTensor | None = None,
  510. attention_mask: torch.Tensor | None = None,
  511. position_ids: torch.LongTensor | None = None,
  512. past_key_values: Cache | None = None,
  513. inputs_embeds: torch.FloatTensor | None = None,
  514. use_cache: bool | None = None,
  515. **kwargs: Unpack[TransformersKwargs],
  516. ) -> MoeModelOutputWithPast:
  517. if (input_ids is None) ^ (inputs_embeds is not None):
  518. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  519. if use_cache and past_key_values is None:
  520. past_key_values = DynamicCache(config=self.config)
  521. if inputs_embeds is None:
  522. inputs_embeds = self.embed_tokens(input_ids)
  523. if position_ids is None:
  524. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  525. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  526. position_ids = position_ids.unsqueeze(0)
  527. causal_mask = create_causal_mask(
  528. config=self.config,
  529. inputs_embeds=inputs_embeds,
  530. attention_mask=attention_mask,
  531. past_key_values=past_key_values,
  532. position_ids=position_ids,
  533. )
  534. hidden_states = inputs_embeds
  535. # create position embeddings to be shared across the decoder layers
  536. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  537. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  538. hidden_states = decoder_layer(
  539. hidden_states,
  540. position_embeddings=position_embeddings,
  541. attention_mask=causal_mask,
  542. past_key_values=past_key_values,
  543. use_cache=use_cache,
  544. position_ids=position_ids,
  545. **kwargs,
  546. )
  547. hidden_states = self.norm(hidden_states)
  548. return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE
  549. last_hidden_state=hidden_states,
  550. past_key_values=past_key_values,
  551. )
  552. def load_balancing_loss_func(
  553. gate_logits: torch.Tensor | tuple[torch.Tensor] | None,
  554. num_experts: int | None = None,
  555. top_k=2,
  556. attention_mask: torch.Tensor | None = None,
  557. ) -> torch.Tensor | int:
  558. r"""
  559. Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
  560. See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
  561. function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
  562. experts is too unbalanced.
  563. Args:
  564. gate_logits:
  565. Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
  566. shape [batch_size X sequence_length, num_experts].
  567. num_experts:
  568. Number of experts
  569. top_k:
  570. The number of experts to route per-token, can be also interpreted as the `top-k` routing
  571. parameter.
  572. attention_mask (`torch.Tensor`, *optional*):
  573. The attention_mask used in forward function
  574. shape [batch_size X sequence_length] if not None.
  575. Returns:
  576. The auxiliary loss.
  577. """
  578. if gate_logits is None or not isinstance(gate_logits, tuple):
  579. return 0
  580. if isinstance(gate_logits, tuple):
  581. compute_device = gate_logits[0].device
  582. concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
  583. routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
  584. _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
  585. expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
  586. if attention_mask is None:
  587. # Compute the percentage of tokens routed to each experts
  588. tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
  589. # Compute the average probability of routing to these experts
  590. router_prob_per_expert = torch.mean(routing_weights, dim=0)
  591. else:
  592. batch_size, sequence_length = attention_mask.shape
  593. num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
  594. # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
  595. expert_attention_mask = (
  596. attention_mask[None, :, :, None, None]
  597. .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
  598. .reshape(-1, top_k, num_experts)
  599. .to(compute_device)
  600. )
  601. # Compute the percentage of tokens routed to each experts
  602. tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
  603. expert_attention_mask, dim=0
  604. )
  605. # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
  606. router_per_expert_attention_mask = (
  607. attention_mask[None, :, :, None]
  608. .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
  609. .reshape(-1, num_experts)
  610. .to(compute_device)
  611. )
  612. # Compute the average probability of routing to these experts
  613. router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
  614. router_per_expert_attention_mask, dim=0
  615. )
  616. overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
  617. return overall_loss * num_experts
  618. class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin):
  619. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  620. def __init__(self, config):
  621. super().__init__(config)
  622. self.model = JetMoeModel(config)
  623. self.vocab_size = config.vocab_size
  624. self.aux_loss_coef = config.aux_loss_coef
  625. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  626. self.tie_word_embeddings = config.tie_word_embeddings
  627. self.num_experts = config.num_local_experts
  628. self.num_experts_per_tok = config.num_experts_per_tok
  629. # Initialize weights and apply final processing
  630. self.post_init()
  631. @can_return_tuple
  632. @auto_docstring
  633. def forward(
  634. self,
  635. input_ids: torch.LongTensor | None = None,
  636. attention_mask: torch.Tensor | None = None,
  637. position_ids: torch.LongTensor | None = None,
  638. past_key_values: Cache | None = None,
  639. inputs_embeds: torch.FloatTensor | None = None,
  640. labels: torch.LongTensor | None = None,
  641. use_cache: bool | None = None,
  642. logits_to_keep: int | torch.Tensor = 0,
  643. output_router_logits: bool | None = False,
  644. **kwargs,
  645. ) -> MoeCausalLMOutputWithPast:
  646. outputs: MoeModelOutputWithPast = self.model(
  647. input_ids=input_ids,
  648. attention_mask=attention_mask,
  649. position_ids=position_ids,
  650. past_key_values=past_key_values,
  651. inputs_embeds=inputs_embeds,
  652. use_cache=use_cache,
  653. output_router_logits=output_router_logits,
  654. **kwargs,
  655. )
  656. hidden_states = outputs.last_hidden_state
  657. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  658. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  659. logits = self.lm_head(hidden_states[:, slice_indices, :])
  660. loss = None
  661. if labels is not None:
  662. loss = self.loss_function(
  663. logits,
  664. labels,
  665. vocab_size=self.config.vocab_size,
  666. **kwargs,
  667. )
  668. aux_loss = None
  669. if output_router_logits:
  670. aux_loss = load_balancing_loss_func(
  671. outputs.router_logits,
  672. self.num_experts,
  673. self.num_experts_per_tok,
  674. attention_mask,
  675. )
  676. if labels is not None:
  677. loss += self.aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
  678. return MoeCausalLMOutputWithPast(
  679. loss=loss,
  680. aux_loss=aux_loss,
  681. logits=logits,
  682. past_key_values=outputs.past_key_values,
  683. hidden_states=outputs.hidden_states,
  684. attentions=outputs.attentions,
  685. router_logits=outputs.router_logits,
  686. )
  687. class JetMoeForSequenceClassification(GenericForSequenceClassification, JetMoePreTrainedModel): ...
  688. __all__ = ["JetMoeForCausalLM", "JetMoeModel", "JetMoePreTrainedModel", "JetMoeForSequenceClassification"]