modular_jetmoe.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597
  1. # Copyright 2024 JetMoe AI and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch JetMoe model."""
  15. from collections.abc import Callable
  16. import torch
  17. from torch import nn
  18. from torch.nn import functional as F
  19. from ... import initialization as init
  20. from ...activations import ACT2FN
  21. from ...cache_utils import Cache, DynamicCache
  22. from ...generation import GenerationMixin
  23. from ...masking_utils import create_causal_mask
  24. from ...modeling_layers import (
  25. GenericForSequenceClassification,
  26. )
  27. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  28. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  29. from ...processing_utils import Unpack
  30. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  31. from ...utils.generic import merge_with_config_defaults
  32. from ...utils.output_capturing import OutputRecorder, capture_outputs
  33. from ..llama.modeling_llama import LlamaDecoderLayer
  34. from ..mixtral.modeling_mixtral import (
  35. MixtralModel,
  36. MixtralPreTrainedModel,
  37. MixtralRMSNorm,
  38. MixtralRotaryEmbedding,
  39. apply_rotary_pos_emb,
  40. eager_attention_forward,
  41. load_balancing_loss_func,
  42. )
  43. from .configuration_jetmoe import JetMoeConfig
  44. logger = logging.get_logger(__name__)
  45. class JetMoeRMSNorm(MixtralRMSNorm):
  46. pass
  47. class JetMoeRotaryEmbedding(MixtralRotaryEmbedding):
  48. pass
  49. class JetMoeParallelExperts(nn.Module):
  50. def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
  51. """
  52. Initialize the JetMoeParallelExperts module.
  53. The experts weights are stored in [num_experts, output_size, input_size] format. Such that it's compatible with
  54. many MoE libraries, such as [Megablock](https://github.com/databricks/megablocks) and
  55. [ScatterMoE](https://github.com/shawntan/scattermoe), as well as the
  56. [MoE kernel](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py)
  57. used in vllm.
  58. Args:
  59. num_experts (int):
  60. Number of experts.
  61. input_size (int):
  62. Size of the input.
  63. output_size (int):
  64. Size of the output.
  65. """
  66. super().__init__()
  67. self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
  68. self.num_experts = num_experts
  69. self.input_size = input_size
  70. self.output_size = output_size
  71. def forward(self, inputs, expert_size):
  72. """
  73. Forward pass of the JetMoeParallelExperts module.
  74. Args:
  75. inputs (Tensor):
  76. Input tensor.
  77. expert_size:
  78. Expert size information.
  79. Returns:
  80. Tensor: Output tensor.
  81. """
  82. input_list = inputs.split(expert_size, dim=0)
  83. output_list = []
  84. for i in range(self.num_experts):
  85. output_list.append(F.linear(input_list[i], self.weight[i]))
  86. results = torch.cat(output_list, dim=0)
  87. return results
  88. class JetMoeTopKGating(nn.Module):
  89. def __init__(self, input_size: int, num_experts: int, top_k: int):
  90. """
  91. Initialize the top-k gating mechanism.
  92. Args:
  93. input_size (`int`):
  94. Size of the input.
  95. num_experts (`int`):
  96. Number of experts.
  97. top_k (`int`):
  98. Number of top experts to select.
  99. """
  100. super().__init__()
  101. self.num_experts = num_experts
  102. self.input_size = input_size
  103. self.top_k = top_k
  104. self.layer = nn.Linear(input_size, num_experts, bias=False)
  105. def forward(self, hidden_states):
  106. # compute the top_k routing decision
  107. logits = self.layer(hidden_states).float() # [batch_size x seq_len, num_experts]
  108. top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1) # [num_tokens, top_k]
  109. top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states) # [num_tokens, top_k]
  110. # compute number of input given to each expert
  111. zeros = torch.zeros(
  112. [top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device
  113. ) # [num_tokens, num_experts]
  114. gates = zeros.scatter(1, top_k_indices, 1) # [num_tokens, num_experts]
  115. expert_size = gates.long().sum(0) # [num_experts,]
  116. # (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`)
  117. # (and `DataDependentOutputException`)
  118. expert_size = expert_size.tolist()
  119. # sort and group input tokens according to expert assignment
  120. top_k_experts = top_k_indices.flatten() # [num_tokens * top_k]
  121. _, index_sorted_experts = top_k_experts.sort(0) # [num_tokens * top_k]
  122. batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc") # [num_tokens * top_k]
  123. # gather the gate values for grouped input tokens
  124. top_k_gates = top_k_gates.flatten() # [num_tokens * top_k]
  125. batch_gates = top_k_gates[index_sorted_experts] # [num_tokens * top_k]
  126. return index_sorted_experts, batch_index, batch_gates, expert_size, logits
  127. class JetMoeMoE(nn.Module):
  128. """
  129. A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
  130. Args:
  131. config:
  132. Configuration object with model hyperparameters.
  133. """
  134. def __init__(self, config: JetMoeConfig):
  135. super().__init__()
  136. self.input_size = config.hidden_size
  137. self.hidden_size = config.intermediate_size
  138. self.activation = ACT2FN[config.activation_function]
  139. self.bias = torch.nn.Parameter(torch.empty(self.input_size))
  140. self.input_linear = JetMoeParallelExperts(config.num_local_experts, self.input_size, self.hidden_size * 2)
  141. self.output_linear = JetMoeParallelExperts(config.num_local_experts, self.hidden_size, self.input_size)
  142. self.router = JetMoeTopKGating(
  143. input_size=self.input_size,
  144. num_experts=config.num_local_experts,
  145. top_k=config.num_experts_per_tok,
  146. )
  147. def forward(self, layer_input):
  148. """
  149. Forward pass of the mixture of experts layer.
  150. Args:
  151. layer_input (Tensor):
  152. Input tensor.
  153. Returns:
  154. Tensor:
  155. Output tensor.
  156. Tensor:
  157. Router logits.
  158. """
  159. bsz, length, emb_size = layer_input.size()
  160. layer_input = layer_input.reshape(-1, emb_size)
  161. _, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input)
  162. expert_inputs = layer_input[batch_index]
  163. hidden_states = self.input_linear(expert_inputs, expert_size)
  164. chunked_hidden_states = hidden_states.chunk(2, dim=-1)
  165. hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]
  166. expert_outputs = self.output_linear(hidden_states, expert_size)
  167. expert_outputs = expert_outputs * batch_gates[:, None]
  168. zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
  169. layer_output = zeros.index_add(0, batch_index, expert_outputs)
  170. layer_output = layer_output.view(bsz, length, self.input_size)
  171. layer_output = layer_output + self.bias
  172. return layer_output
  173. class JetMoeMoA(nn.Module):
  174. """
  175. A Sparsely gated mixture of attention layer with pairs of query- and output-projections as experts.
  176. Args:
  177. config:
  178. Configuration object with model hyperparameters.
  179. """
  180. def __init__(self, config: JetMoeConfig):
  181. super().__init__()
  182. self.num_experts = config.num_local_experts
  183. self.input_size = config.hidden_size
  184. self.hidden_size = config.kv_channels * config.num_key_value_heads
  185. self.top_k = config.num_experts_per_tok
  186. self.bias = torch.nn.Parameter(torch.empty(self.input_size))
  187. self.input_linear = JetMoeParallelExperts(self.num_experts, self.input_size, self.hidden_size)
  188. self.output_linear = JetMoeParallelExperts(self.num_experts, self.hidden_size, self.input_size)
  189. self.router = JetMoeTopKGating(
  190. input_size=self.input_size,
  191. num_experts=self.num_experts,
  192. top_k=self.top_k,
  193. )
  194. def map(self, layer_input):
  195. """
  196. Map inputs to attention experts according to routing decision and compute query projection inside each experts.
  197. """
  198. # Compute gating topology
  199. bsz, length, emb_size = layer_input.size()
  200. layer_input = layer_input.reshape(-1, emb_size) # [bsz * length, emb_size]
  201. index_sorted_experts, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input)
  202. topo_info = (index_sorted_experts, batch_index, batch_gates, expert_size)
  203. # Group inputs according to topology and compute query projection
  204. expert_inputs = layer_input[batch_index] # [bsz * length * top_k, emb_size]
  205. expert_outputs = self.input_linear(expert_inputs, expert_size) # [bsz * length * top_k, hidden_size]
  206. # Ungroup queries back to original order
  207. zeros = torch.zeros(
  208. (bsz * length * self.top_k, self.hidden_size), dtype=expert_outputs.dtype, device=expert_outputs.device
  209. )
  210. layer_output = zeros.index_add(0, index_sorted_experts, expert_outputs)
  211. layer_output = layer_output.view(bsz, length, self.top_k, -1) # [bsz, length, top_k, hidden_size]
  212. return layer_output, router_logits, topo_info
  213. def reduce(self, layer_input, topo_info):
  214. """
  215. Compute output projection inside each attention experts and merge the outputs of different experts.
  216. """
  217. bsz, length, k, hidden_size = layer_input.size()
  218. layer_input = layer_input.reshape(-1, hidden_size) # [bsz * length * k, hidden_size]
  219. index_sorted_experts, batch_index, batch_gates, expert_size = topo_info
  220. # Group inputs according to topology and compute output projection
  221. expert_inputs = layer_input[index_sorted_experts] # [bsz * length * top_k, hidden_size]
  222. expert_outputs = self.output_linear(expert_inputs, expert_size) # [bsz * length * top_k, emb_size]
  223. # Apply gates to attention expert outputs
  224. expert_outputs = expert_outputs * batch_gates[:, None]
  225. # Ungroup and merge outputs to original order
  226. zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
  227. layer_output = zeros.index_add(0, batch_index, expert_outputs)
  228. layer_output = layer_output.view(bsz, length, self.input_size)
  229. layer_output = layer_output + self.bias
  230. return layer_output
  231. def forward(self, layer_input):
  232. raise NotImplementedError("This module doesn't support call and forward.")
  233. class JetMoeAttention(nn.Module):
  234. """
  235. Multi-headed attention from 'Attention Is All You Need' paper.
  236. """
  237. def __init__(self, config: JetMoeConfig, layer_idx: int | None = None):
  238. """
  239. Initialize the JetMoeAttention module.
  240. Args:
  241. config:
  242. Configuration object with model hyperparameters.
  243. layer_idx:
  244. Index of the layer in the model.
  245. """
  246. super().__init__()
  247. self.config = config
  248. self.layer_idx = layer_idx
  249. self.is_causal = True
  250. if layer_idx is None:
  251. logger.warning_once(
  252. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  253. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  254. "when creating this class."
  255. )
  256. self.num_key_value_groups = 1 # We ignore this by setting it to 1 as we have different repeat patterns
  257. self.top_k = config.num_experts_per_tok
  258. self.attention_dropout = config.attention_dropout
  259. self.kv_projection_size = config.kv_channels * config.num_key_value_heads
  260. self.num_key_value_heads = config.num_key_value_heads
  261. self.num_heads = config.num_attention_heads
  262. self.head_dim = config.kv_channels
  263. self.scaling = self.head_dim**-0.5
  264. self.experts = JetMoeMoA(config)
  265. self.kv_proj = torch.nn.Linear(config.hidden_size, self.kv_projection_size * 2, bias=False)
  266. def forward(
  267. self,
  268. hidden_states: torch.Tensor,
  269. attention_mask: torch.Tensor | None = None,
  270. position_embeddings: torch.LongTensor | None = None,
  271. past_key_values: Cache | None = None,
  272. **kwargs,
  273. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  274. input_shape = hidden_states.shape[:-1]
  275. hidden_shape = (*input_shape, -1, self.head_dim)
  276. query_states, router_logits, topo_info = self.experts.map(hidden_states)
  277. key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1)
  278. query_states = query_states.view(hidden_shape).transpose(1, 2)
  279. key_states = key_states.view(hidden_shape).transpose(1, 2)
  280. value_states = value_states.view(hidden_shape).transpose(1, 2)
  281. cos, sin = position_embeddings
  282. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  283. if past_key_values is not None:
  284. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  285. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  286. self.config._attn_implementation, eager_attention_forward
  287. )
  288. # This is different from other models where we repeat k/v heads
  289. # instead of repeat interleaving them
  290. key_states = key_states.repeat(1, self.top_k, 1, 1)
  291. value_states = value_states.repeat(1, self.top_k, 1, 1)
  292. attn_output, attn_weights = attention_interface(
  293. self,
  294. query_states,
  295. key_states,
  296. value_states,
  297. attention_mask,
  298. dropout=0.0 if not self.training else self.attention_dropout,
  299. scaling=self.scaling,
  300. **kwargs,
  301. )
  302. attn_output = attn_output.view(*input_shape, self.top_k, -1)
  303. attn_output = self.experts.reduce(attn_output, topo_info)
  304. attn_output = attn_output.view(*input_shape, -1)
  305. return attn_output, attn_weights, router_logits
  306. class JetMoeDecoderLayer(LlamaDecoderLayer):
  307. def __init__(self, config: JetMoeConfig, layer_idx: int | None = None):
  308. super().__init__(config, layer_idx)
  309. self.input_layernorm = JetMoeRMSNorm(config.hidden_size)
  310. self.self_attention = JetMoeAttention(config, layer_idx)
  311. self.post_attention_layernorm = JetMoeRMSNorm(config.hidden_size)
  312. self.mlp = JetMoeMoE(config)
  313. del self.self_attn
  314. def forward(
  315. self,
  316. hidden_states: torch.Tensor,
  317. attention_mask: torch.Tensor | None = None,
  318. position_ids: torch.LongTensor | None = None,
  319. past_key_values: Cache | None = None,
  320. use_cache: bool | None = False,
  321. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  322. **kwargs: Unpack[TransformersKwargs],
  323. ) -> torch.Tensor:
  324. residual = hidden_states
  325. hidden_states = self.input_layernorm(hidden_states)
  326. # Self Attention
  327. hidden_states, _, _ = self.self_attention(
  328. hidden_states=hidden_states,
  329. attention_mask=attention_mask,
  330. position_ids=position_ids,
  331. past_key_values=past_key_values,
  332. use_cache=use_cache,
  333. position_embeddings=position_embeddings,
  334. **kwargs,
  335. )
  336. hidden_states = residual + hidden_states
  337. # Fully Connected
  338. residual = hidden_states
  339. hidden_states = self.post_attention_layernorm(hidden_states)
  340. hidden_states = self.mlp(hidden_states)
  341. hidden_states = residual + hidden_states
  342. return hidden_states
  343. @auto_docstring
  344. class JetMoePreTrainedModel(MixtralPreTrainedModel):
  345. _can_record_outputs = {
  346. "router_logits": [OutputRecorder(JetMoeAttention, index=2), OutputRecorder(JetMoeTopKGating, index=4)],
  347. "hidden_states": JetMoeDecoderLayer,
  348. "attentions": OutputRecorder(JetMoeAttention, index=1),
  349. }
  350. config: JetMoeConfig
  351. base_model_prefix = "model"
  352. supports_gradient_checkpointing = False
  353. _no_split_modules = ["JetMoeDecoderLayer"]
  354. _skip_keys_device_placement = ["past_key_values"]
  355. _supports_flash_attn = True
  356. _supports_sdpa = True
  357. _can_compile_fullgraph = False # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
  358. @torch.no_grad()
  359. def _init_weights(self, module):
  360. """Initialize the weights."""
  361. PreTrainedModel._init_weights(self, module)
  362. if isinstance(module, JetMoeParallelExperts):
  363. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  364. elif isinstance(module, JetMoeMoA | JetMoeMoE):
  365. init.zeros_(module.bias)
  366. @auto_docstring
  367. class JetMoeModel(MixtralModel):
  368. def __init__(self, config: JetMoeConfig):
  369. super().__init__(config)
  370. self.padding_idx = config.pad_token_id
  371. self.vocab_size = config.vocab_size
  372. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  373. self.layers = nn.ModuleList(
  374. [JetMoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  375. )
  376. self._attn_implementation = config._attn_implementation
  377. self.norm = JetMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  378. @merge_with_config_defaults
  379. @capture_outputs
  380. @auto_docstring
  381. def forward(
  382. self,
  383. input_ids: torch.LongTensor | None = None,
  384. attention_mask: torch.Tensor | None = None,
  385. position_ids: torch.LongTensor | None = None,
  386. past_key_values: Cache | None = None,
  387. inputs_embeds: torch.FloatTensor | None = None,
  388. use_cache: bool | None = None,
  389. **kwargs: Unpack[TransformersKwargs],
  390. ) -> MoeModelOutputWithPast:
  391. if (input_ids is None) ^ (inputs_embeds is not None):
  392. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  393. if use_cache and past_key_values is None:
  394. past_key_values = DynamicCache(config=self.config)
  395. if inputs_embeds is None:
  396. inputs_embeds = self.embed_tokens(input_ids)
  397. if position_ids is None:
  398. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  399. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  400. position_ids = position_ids.unsqueeze(0)
  401. causal_mask = create_causal_mask(
  402. config=self.config,
  403. inputs_embeds=inputs_embeds,
  404. attention_mask=attention_mask,
  405. past_key_values=past_key_values,
  406. position_ids=position_ids,
  407. )
  408. hidden_states = inputs_embeds
  409. # create position embeddings to be shared across the decoder layers
  410. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  411. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  412. hidden_states = decoder_layer(
  413. hidden_states,
  414. position_embeddings=position_embeddings,
  415. attention_mask=causal_mask,
  416. past_key_values=past_key_values,
  417. use_cache=use_cache,
  418. position_ids=position_ids,
  419. **kwargs,
  420. )
  421. hidden_states = self.norm(hidden_states)
  422. return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE
  423. last_hidden_state=hidden_states,
  424. past_key_values=past_key_values,
  425. )
  426. class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin):
  427. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  428. def __init__(self, config):
  429. super().__init__(config)
  430. self.model = JetMoeModel(config)
  431. self.vocab_size = config.vocab_size
  432. self.aux_loss_coef = config.aux_loss_coef
  433. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  434. self.tie_word_embeddings = config.tie_word_embeddings
  435. self.num_experts = config.num_local_experts
  436. self.num_experts_per_tok = config.num_experts_per_tok
  437. # Initialize weights and apply final processing
  438. self.post_init()
  439. @can_return_tuple
  440. @auto_docstring
  441. def forward(
  442. self,
  443. input_ids: torch.LongTensor | None = None,
  444. attention_mask: torch.Tensor | None = None,
  445. position_ids: torch.LongTensor | None = None,
  446. past_key_values: Cache | None = None,
  447. inputs_embeds: torch.FloatTensor | None = None,
  448. labels: torch.LongTensor | None = None,
  449. use_cache: bool | None = None,
  450. logits_to_keep: int | torch.Tensor = 0,
  451. output_router_logits: bool | None = False,
  452. **kwargs,
  453. ) -> MoeCausalLMOutputWithPast:
  454. outputs: MoeModelOutputWithPast = self.model(
  455. input_ids=input_ids,
  456. attention_mask=attention_mask,
  457. position_ids=position_ids,
  458. past_key_values=past_key_values,
  459. inputs_embeds=inputs_embeds,
  460. use_cache=use_cache,
  461. output_router_logits=output_router_logits,
  462. **kwargs,
  463. )
  464. hidden_states = outputs.last_hidden_state
  465. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  466. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  467. logits = self.lm_head(hidden_states[:, slice_indices, :])
  468. loss = None
  469. if labels is not None:
  470. loss = self.loss_function(
  471. logits,
  472. labels,
  473. vocab_size=self.config.vocab_size,
  474. **kwargs,
  475. )
  476. aux_loss = None
  477. if output_router_logits:
  478. aux_loss = load_balancing_loss_func(
  479. outputs.router_logits,
  480. self.num_experts,
  481. self.num_experts_per_tok,
  482. attention_mask,
  483. )
  484. if labels is not None:
  485. loss += self.aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
  486. return MoeCausalLMOutputWithPast(
  487. loss=loss,
  488. aux_loss=aux_loss,
  489. logits=logits,
  490. past_key_values=outputs.past_key_values,
  491. hidden_states=outputs.hidden_states,
  492. attentions=outputs.attentions,
  493. router_logits=outputs.router_logits,
  494. )
  495. class JetMoeForSequenceClassification(GenericForSequenceClassification, JetMoePreTrainedModel): ...
  496. __all__ = ["JetMoeForCausalLM", "JetMoeModel", "JetMoePreTrainedModel", "JetMoeForSequenceClassification"]