# modeling_mixtral.py (30 KB)
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/mixtral/modular_mixtral.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_mixtral.py file directly. One of our CI checks enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  25. from collections.abc import Callable
  26. from typing import Optional
  27. import torch
  28. import torch.nn.functional as F
  29. from torch import nn
  30. from ... import initialization as init
  31. from ...activations import ACT2FN
  32. from ...cache_utils import Cache, DynamicCache
  33. from ...generation import GenerationMixin
  34. from ...integrations import (
  35. use_experts_implementation,
  36. use_kernel_forward_from_hub,
  37. use_kernel_func_from_hub,
  38. use_kernelized_func,
  39. )
  40. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  41. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  42. from ...modeling_layers import (
  43. GenericForQuestionAnswering,
  44. GenericForSequenceClassification,
  45. GenericForTokenClassification,
  46. GradientCheckpointingLayer,
  47. )
  48. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  49. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  50. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  51. from ...processing_utils import Unpack
  52. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  53. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  54. from ...utils.output_capturing import OutputRecorder, capture_outputs
  55. from .configuration_mixtral import MixtralConfig
  56. @use_experts_implementation
  57. class MixtralExperts(nn.Module):
  58. """Collection of expert weights stored as 3D tensors."""
  59. def __init__(self, config: MixtralConfig):
  60. super().__init__()
  61. self.num_experts = config.num_local_experts
  62. self.hidden_dim = config.hidden_size
  63. self.intermediate_dim = config.intermediate_size
  64. self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
  65. self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
  66. self.act_fn = ACT2FN[config.hidden_act]
  67. def forward(
  68. self,
  69. hidden_states: torch.Tensor,
  70. top_k_index: torch.Tensor,
  71. top_k_weights: torch.Tensor,
  72. ) -> torch.Tensor:
  73. final_hidden_states = torch.zeros_like(hidden_states)
  74. with torch.no_grad():
  75. expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
  76. expert_mask = expert_mask.permute(2, 1, 0)
  77. expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
  78. for expert_idx in expert_hit:
  79. expert_idx = expert_idx[0]
  80. if expert_idx == self.num_experts:
  81. continue
  82. top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
  83. current_state = hidden_states[token_idx]
  84. gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
  85. current_hidden_states = self.act_fn(gate) * up
  86. current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
  87. current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
  88. final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
  89. return final_hidden_states
  90. class MixtralTopKRouter(nn.Module):
  91. def __init__(self, config):
  92. super().__init__()
  93. self.top_k = config.num_experts_per_tok
  94. self.num_experts = config.num_local_experts
  95. self.hidden_dim = config.hidden_size
  96. self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
  97. def forward(self, hidden_states):
  98. hidden_states = hidden_states.reshape(-1, self.hidden_dim)
  99. router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts)
  100. router_logits = torch.nn.functional.softmax(router_logits.float(), dim=-1)
  101. router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
  102. router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
  103. router_scores = router_top_value
  104. return router_logits, router_scores, router_indices
  105. class MixtralSparseMoeBlock(nn.Module):
  106. def __init__(self, config):
  107. super().__init__()
  108. self.top_k = config.num_experts_per_tok
  109. self.jitter_noise = config.router_jitter_noise
  110. self.gate = MixtralTopKRouter(config)
  111. self.experts = MixtralExperts(config)
  112. def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  113. batch_size, sequence_length, hidden_dim = hidden_states.shape
  114. if self.training and self.jitter_noise > 0:
  115. hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
  116. hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
  117. _, top_k_weights, top_k_index = self.gate(hidden_states)
  118. hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)
  119. hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
  120. return hidden_states
  121. @use_kernel_forward_from_hub("RMSNorm")
  122. class MixtralRMSNorm(nn.Module):
  123. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  124. """
  125. MixtralRMSNorm is equivalent to T5LayerNorm
  126. """
  127. super().__init__()
  128. self.weight = nn.Parameter(torch.ones(hidden_size))
  129. self.variance_epsilon = eps
  130. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  131. input_dtype = hidden_states.dtype
  132. hidden_states = hidden_states.to(torch.float32)
  133. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  134. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  135. return self.weight * hidden_states.to(input_dtype)
  136. def extra_repr(self):
  137. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  138. class MixtralRotaryEmbedding(nn.Module):
  139. inv_freq: torch.Tensor # fix linting for `register_buffer`
  140. def __init__(self, config: MixtralConfig, device=None):
  141. super().__init__()
  142. self.max_seq_len_cached = config.max_position_embeddings
  143. self.original_max_seq_len = config.max_position_embeddings
  144. self.config = config
  145. self.rope_type = self.config.rope_parameters["rope_type"]
  146. rope_init_fn: Callable = self.compute_default_rope_parameters
  147. if self.rope_type != "default":
  148. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  149. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  150. self.register_buffer("inv_freq", inv_freq, persistent=False)
  151. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  152. @staticmethod
  153. def compute_default_rope_parameters(
  154. config: MixtralConfig | None = None,
  155. device: Optional["torch.device"] = None,
  156. seq_len: int | None = None,
  157. ) -> tuple["torch.Tensor", float]:
  158. """
  159. Computes the inverse frequencies according to the original RoPE implementation
  160. Args:
  161. config ([`~transformers.PreTrainedConfig`]):
  162. The model configuration.
  163. device (`torch.device`):
  164. The device to use for initialization of the inverse frequencies.
  165. seq_len (`int`, *optional*):
  166. The current sequence length. Unused for this type of RoPE.
  167. Returns:
  168. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  169. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  170. """
  171. base = config.rope_parameters["rope_theta"]
  172. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  173. attention_factor = 1.0 # Unused in this type of RoPE
  174. # Compute the inverse frequencies
  175. inv_freq = 1.0 / (
  176. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  177. )
  178. return inv_freq, attention_factor
  179. @torch.no_grad()
  180. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  181. def forward(self, x, position_ids):
  182. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  183. position_ids_expanded = position_ids[:, None, :].float()
  184. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  185. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  186. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  187. emb = torch.cat((freqs, freqs), dim=-1)
  188. cos = emb.cos() * self.attention_scaling
  189. sin = emb.sin() * self.attention_scaling
  190. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  191. def rotate_half(x):
  192. """Rotates half the hidden dims of the input."""
  193. x1 = x[..., : x.shape[-1] // 2]
  194. x2 = x[..., x.shape[-1] // 2 :]
  195. return torch.cat((-x2, x1), dim=-1)
  196. @use_kernel_func_from_hub("rotary_pos_emb")
  197. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  198. """Applies Rotary Position Embedding to the query and key tensors.
  199. Args:
  200. q (`torch.Tensor`): The query tensor.
  201. k (`torch.Tensor`): The key tensor.
  202. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  203. sin (`torch.Tensor`): The sine part of the rotary embedding.
  204. unsqueeze_dim (`int`, *optional*, defaults to 1):
  205. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  206. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  207. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  208. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  209. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  210. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  211. Returns:
  212. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  213. """
  214. cos = cos.unsqueeze(unsqueeze_dim)
  215. sin = sin.unsqueeze(unsqueeze_dim)
  216. q_embed = (q * cos) + (rotate_half(q) * sin)
  217. k_embed = (k * cos) + (rotate_half(k) * sin)
  218. return q_embed, k_embed
  219. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  220. """
  221. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  222. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  223. """
  224. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  225. if n_rep == 1:
  226. return hidden_states
  227. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  228. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  229. def eager_attention_forward(
  230. module: nn.Module,
  231. query: torch.Tensor,
  232. key: torch.Tensor,
  233. value: torch.Tensor,
  234. attention_mask: torch.Tensor | None,
  235. scaling: float,
  236. dropout: float = 0.0,
  237. **kwargs: Unpack[TransformersKwargs],
  238. ):
  239. key_states = repeat_kv(key, module.num_key_value_groups)
  240. value_states = repeat_kv(value, module.num_key_value_groups)
  241. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  242. if attention_mask is not None:
  243. attn_weights = attn_weights + attention_mask
  244. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  245. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  246. attn_output = torch.matmul(attn_weights, value_states)
  247. attn_output = attn_output.transpose(1, 2).contiguous()
  248. return attn_output, attn_weights
  249. @use_kernelized_func(apply_rotary_pos_emb)
  250. class MixtralAttention(nn.Module):
  251. """Multi-headed attention from 'Attention Is All You Need' paper"""
  252. def __init__(self, config: MixtralConfig, layer_idx: int):
  253. super().__init__()
  254. self.config = config
  255. self.layer_idx = layer_idx
  256. self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  257. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  258. self.scaling = self.head_dim**-0.5
  259. self.attention_dropout = config.attention_dropout
  260. self.is_causal = True
  261. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
  262. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  263. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  264. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  265. def forward(
  266. self,
  267. hidden_states: torch.Tensor,
  268. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  269. attention_mask: torch.Tensor | None,
  270. past_key_values: Cache | None = None,
  271. **kwargs: Unpack[FlashAttentionKwargs],
  272. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  273. input_shape = hidden_states.shape[:-1]
  274. hidden_shape = (*input_shape, -1, self.head_dim)
  275. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  276. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  277. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  278. cos, sin = position_embeddings
  279. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  280. if past_key_values is not None:
  281. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  282. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  283. self.config._attn_implementation, eager_attention_forward
  284. )
  285. attn_output, attn_weights = attention_interface(
  286. self,
  287. query_states,
  288. key_states,
  289. value_states,
  290. attention_mask,
  291. dropout=0.0 if not self.training else self.attention_dropout,
  292. scaling=self.scaling,
  293. sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
  294. **kwargs,
  295. )
  296. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  297. attn_output = self.o_proj(attn_output)
  298. return attn_output, attn_weights
  299. class MixtralDecoderLayer(GradientCheckpointingLayer):
  300. def __init__(self, config: MixtralConfig, layer_idx: int):
  301. super().__init__()
  302. self.hidden_size = config.hidden_size
  303. self.self_attn = MixtralAttention(config, layer_idx)
  304. self.mlp = MixtralSparseMoeBlock(config)
  305. self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  306. self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  307. def forward(
  308. self,
  309. hidden_states: torch.Tensor,
  310. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  311. attention_mask: torch.Tensor | None = None,
  312. position_ids: torch.LongTensor | None = None,
  313. past_key_values: Cache | None = None,
  314. **kwargs: Unpack[TransformersKwargs],
  315. ) -> torch.Tensor:
  316. residual = hidden_states
  317. hidden_states = self.input_layernorm(hidden_states)
  318. hidden_states, _ = self.self_attn(
  319. hidden_states=hidden_states,
  320. position_embeddings=position_embeddings,
  321. attention_mask=attention_mask,
  322. position_ids=position_ids,
  323. past_key_values=past_key_values,
  324. **kwargs,
  325. )
  326. hidden_states = residual + hidden_states
  327. residual = hidden_states
  328. hidden_states = self.post_attention_layernorm(hidden_states)
  329. hidden_states = self.mlp(hidden_states)
  330. hidden_states = residual + hidden_states
  331. return hidden_states
  332. @auto_docstring
  333. class MixtralPreTrainedModel(PreTrainedModel):
  334. config: MixtralConfig
  335. base_model_prefix = "model"
  336. supports_gradient_checkpointing = True
  337. _no_split_modules = ["MixtralDecoderLayer"]
  338. _skip_keys_device_placement = ["past_key_values"]
  339. _supports_flash_attn = True
  340. _supports_sdpa = True
  341. _supports_flex_attn = True
  342. _can_compile_fullgraph = True
  343. _supports_attention_backend = True
  344. _can_record_outputs = {
  345. "router_logits": OutputRecorder(MixtralTopKRouter, index=0),
  346. "hidden_states": MixtralDecoderLayer,
  347. "attentions": MixtralAttention,
  348. }
  349. @torch.no_grad()
  350. def _init_weights(self, module):
  351. super()._init_weights(module)
  352. std = self.config.initializer_range
  353. if isinstance(module, MixtralExperts):
  354. init.normal_(module.gate_up_proj, mean=0.0, std=std)
  355. init.normal_(module.down_proj, mean=0.0, std=std)
  356. elif isinstance(module, MixtralTopKRouter):
  357. init.normal_(module.weight, mean=0.0, std=std)
  358. @auto_docstring
  359. class MixtralModel(MixtralPreTrainedModel):
  360. def __init__(self, config: MixtralConfig):
  361. super().__init__(config)
  362. self.padding_idx = config.pad_token_id
  363. self.vocab_size = config.vocab_size
  364. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  365. self.layers = nn.ModuleList(
  366. [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  367. )
  368. self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  369. self.rotary_emb = MixtralRotaryEmbedding(config=config)
  370. self.gradient_checkpointing = False
  371. # Initialize weights and apply final processing
  372. self.post_init()
  373. @merge_with_config_defaults
  374. @capture_outputs
  375. @auto_docstring
  376. def forward(
  377. self,
  378. input_ids: torch.LongTensor | None = None,
  379. attention_mask: torch.Tensor | None = None,
  380. position_ids: torch.LongTensor | None = None,
  381. past_key_values: Cache | None = None,
  382. inputs_embeds: torch.FloatTensor | None = None,
  383. use_cache: bool | None = None,
  384. **kwargs: Unpack[TransformersKwargs],
  385. ) -> MoeModelOutputWithPast:
  386. if (input_ids is None) ^ (inputs_embeds is not None):
  387. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  388. if use_cache and past_key_values is None:
  389. past_key_values = DynamicCache(config=self.config)
  390. if inputs_embeds is None:
  391. inputs_embeds = self.embed_tokens(input_ids)
  392. if position_ids is None:
  393. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  394. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  395. position_ids = position_ids.unsqueeze(0)
  396. mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
  397. causal_mask = mask_function(
  398. config=self.config,
  399. inputs_embeds=inputs_embeds,
  400. attention_mask=attention_mask,
  401. past_key_values=past_key_values,
  402. position_ids=position_ids,
  403. )
  404. hidden_states = inputs_embeds
  405. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  406. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  407. hidden_states = decoder_layer(
  408. hidden_states,
  409. attention_mask=causal_mask,
  410. position_ids=position_ids,
  411. past_key_values=past_key_values,
  412. use_cache=use_cache,
  413. position_embeddings=position_embeddings,
  414. **kwargs,
  415. )
  416. hidden_states = self.norm(hidden_states)
  417. return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE
  418. last_hidden_state=hidden_states,
  419. past_key_values=past_key_values,
  420. )
  421. def load_balancing_loss_func(
  422. gate_logits: torch.Tensor | tuple[torch.Tensor] | None,
  423. num_experts: int | None = None,
  424. top_k=2,
  425. attention_mask: torch.Tensor | None = None,
  426. ) -> torch.Tensor | int:
  427. r"""
  428. Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
  429. See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
  430. function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
  431. experts is too unbalanced.
  432. Args:
  433. gate_logits:
  434. Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
  435. shape [batch_size X sequence_length, num_experts].
  436. num_experts:
  437. Number of experts
  438. top_k:
  439. The number of experts to route per-token, can be also interpreted as the `top-k` routing
  440. parameter.
  441. attention_mask (`torch.Tensor`, *optional*):
  442. The attention_mask used in forward function
  443. shape [batch_size X sequence_length] if not None.
  444. Returns:
  445. The auxiliary loss.
  446. """
  447. if gate_logits is None or not isinstance(gate_logits, tuple):
  448. return 0
  449. if isinstance(gate_logits, tuple):
  450. compute_device = gate_logits[0].device
  451. concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
  452. routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
  453. _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
  454. expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
  455. if attention_mask is None:
  456. # Compute the percentage of tokens routed to each experts
  457. tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
  458. # Compute the average probability of routing to these experts
  459. router_prob_per_expert = torch.mean(routing_weights, dim=0)
  460. else:
  461. batch_size, sequence_length = attention_mask.shape
  462. num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
  463. # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
  464. expert_attention_mask = (
  465. attention_mask[None, :, :, None, None]
  466. .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
  467. .reshape(-1, top_k, num_experts)
  468. .to(compute_device)
  469. )
  470. # Compute the percentage of tokens routed to each experts
  471. tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
  472. expert_attention_mask, dim=0
  473. )
  474. # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
  475. router_per_expert_attention_mask = (
  476. attention_mask[None, :, :, None]
  477. .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
  478. .reshape(-1, num_experts)
  479. .to(compute_device)
  480. )
  481. # Compute the average probability of routing to these experts
  482. router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
  483. router_per_expert_attention_mask, dim=0
  484. )
  485. overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
  486. return overall_loss * num_experts
  487. @auto_docstring
  488. class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin):
  489. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  490. _tp_plan = {"lm_head": "colwise_gather_output"}
  491. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  492. def __init__(self, config):
  493. super().__init__(config)
  494. self.model = MixtralModel(config)
  495. self.vocab_size = config.vocab_size
  496. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  497. self.router_aux_loss_coef = config.router_aux_loss_coef
  498. self.num_experts = config.num_local_experts
  499. self.num_experts_per_tok = config.num_experts_per_tok
  500. # Initialize weights and apply final processing
  501. self.post_init()
  502. @can_return_tuple
  503. @auto_docstring
  504. def forward(
  505. self,
  506. input_ids: torch.LongTensor | None = None,
  507. attention_mask: torch.Tensor | None = None,
  508. position_ids: torch.LongTensor | None = None,
  509. past_key_values: Cache | None = None,
  510. inputs_embeds: torch.FloatTensor | None = None,
  511. labels: torch.LongTensor | None = None,
  512. use_cache: bool | None = None,
  513. output_router_logits: bool | None = None,
  514. logits_to_keep: int | torch.Tensor = 0,
  515. **kwargs: Unpack[TransformersKwargs],
  516. ) -> MoeCausalLMOutputWithPast:
  517. r"""
  518. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  519. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  520. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  521. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  522. Example:
  523. ```python
  524. >>> from transformers import AutoTokenizer, MixtralForCausalLM
  525. >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
  526. >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
  527. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  528. >>> inputs = tokenizer(prompt, return_tensors="pt")
  529. >>> # Generate
  530. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  531. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  532. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  533. ```"""
  534. output_router_logits = (
  535. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  536. )
  537. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  538. outputs: MoeModelOutputWithPast = self.model(
  539. input_ids=input_ids,
  540. attention_mask=attention_mask,
  541. position_ids=position_ids,
  542. past_key_values=past_key_values,
  543. inputs_embeds=inputs_embeds,
  544. use_cache=use_cache,
  545. output_router_logits=output_router_logits,
  546. **kwargs,
  547. )
  548. hidden_states = outputs.last_hidden_state
  549. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  550. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  551. logits = self.lm_head(hidden_states[:, slice_indices, :])
  552. loss = None
  553. if labels is not None:
  554. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  555. aux_loss = None
  556. if output_router_logits:
  557. aux_loss = load_balancing_loss_func(
  558. outputs.router_logits,
  559. self.num_experts,
  560. self.num_experts_per_tok,
  561. attention_mask,
  562. )
  563. if labels is not None:
  564. loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
  565. return MoeCausalLMOutputWithPast(
  566. loss=loss,
  567. aux_loss=aux_loss,
  568. logits=logits,
  569. past_key_values=outputs.past_key_values,
  570. hidden_states=outputs.hidden_states,
  571. attentions=outputs.attentions,
  572. router_logits=outputs.router_logits,
  573. )
  574. class MixtralForSequenceClassification(GenericForSequenceClassification, MixtralPreTrainedModel):
  575. pass
  576. class MixtralForTokenClassification(GenericForTokenClassification, MixtralPreTrainedModel):
  577. pass
  578. class MixtralForQuestionAnswering(GenericForQuestionAnswering, MixtralPreTrainedModel):
  579. pass
  580. __all__ = [
  581. "MixtralForCausalLM",
  582. "MixtralForQuestionAnswering",
  583. "MixtralModel",
  584. "MixtralPreTrainedModel",
  585. "MixtralForSequenceClassification",
  586. "MixtralForTokenClassification",
  587. ]