modular_mixtral.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  4. # and OPT implementations in this library. It has been modified from its
  5. # original forms to accommodate minor architectural differences compared
  6. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. """PyTorch Mixtral model."""
  20. import torch
  21. import torch.nn.functional as F
  22. from torch import nn
  23. from ... import initialization as init
  24. from ...activations import ACT2FN
  25. from ...cache_utils import Cache, DynamicCache
  26. from ...integrations import use_experts_implementation
  27. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  28. from ...modeling_layers import GradientCheckpointingLayer
  29. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  30. from ...modeling_utils import PreTrainedModel
  31. from ...processing_utils import Unpack
  32. from ...utils import TransformersKwargs, logging
  33. from ...utils.output_capturing import OutputRecorder
  34. from ..mistral.modeling_mistral import (
  35. MistralAttention,
  36. MistralForCausalLM,
  37. MistralForQuestionAnswering,
  38. MistralForSequenceClassification,
  39. MistralForTokenClassification,
  40. MistralModel,
  41. MistralPreTrainedModel,
  42. MistralRMSNorm,
  43. MistralRotaryEmbedding,
  44. )
  45. from .configuration_mixtral import MixtralConfig
# Module-level logger obtained from the library's logging helper, named after this module.
logger = logging.get_logger(__name__)
  47. def load_balancing_loss_func(
  48. gate_logits: torch.Tensor | tuple[torch.Tensor] | None,
  49. num_experts: int | None = None,
  50. top_k=2,
  51. attention_mask: torch.Tensor | None = None,
  52. ) -> torch.Tensor | int:
  53. r"""
  54. Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
  55. See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
  56. function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
  57. experts is too unbalanced.
  58. Args:
  59. gate_logits:
  60. Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
  61. shape [batch_size X sequence_length, num_experts].
  62. num_experts:
  63. Number of experts
  64. top_k:
  65. The number of experts to route per-token, can be also interpreted as the `top-k` routing
  66. parameter.
  67. attention_mask (`torch.Tensor`, *optional*):
  68. The attention_mask used in forward function
  69. shape [batch_size X sequence_length] if not None.
  70. Returns:
  71. The auxiliary loss.
  72. """
  73. if gate_logits is None or not isinstance(gate_logits, tuple):
  74. return 0
  75. if isinstance(gate_logits, tuple):
  76. compute_device = gate_logits[0].device
  77. concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
  78. routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
  79. _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
  80. expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
  81. if attention_mask is None:
  82. # Compute the percentage of tokens routed to each experts
  83. tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
  84. # Compute the average probability of routing to these experts
  85. router_prob_per_expert = torch.mean(routing_weights, dim=0)
  86. else:
  87. batch_size, sequence_length = attention_mask.shape
  88. num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
  89. # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
  90. expert_attention_mask = (
  91. attention_mask[None, :, :, None, None]
  92. .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
  93. .reshape(-1, top_k, num_experts)
  94. .to(compute_device)
  95. )
  96. # Compute the percentage of tokens routed to each experts
  97. tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
  98. expert_attention_mask, dim=0
  99. )
  100. # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
  101. router_per_expert_attention_mask = (
  102. attention_mask[None, :, :, None]
  103. .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
  104. .reshape(-1, num_experts)
  105. .to(compute_device)
  106. )
  107. # Compute the average probability of routing to these experts
  108. router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
  109. router_per_expert_attention_mask, dim=0
  110. )
  111. overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
  112. return overall_loss * num_experts
  113. @use_experts_implementation
  114. class MixtralExperts(nn.Module):
  115. """Collection of expert weights stored as 3D tensors."""
  116. def __init__(self, config: MixtralConfig):
  117. super().__init__()
  118. self.num_experts = config.num_local_experts
  119. self.hidden_dim = config.hidden_size
  120. self.intermediate_dim = config.intermediate_size
  121. self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
  122. self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
  123. self.act_fn = ACT2FN[config.hidden_act]
  124. def forward(
  125. self,
  126. hidden_states: torch.Tensor,
  127. top_k_index: torch.Tensor,
  128. top_k_weights: torch.Tensor,
  129. ) -> torch.Tensor:
  130. final_hidden_states = torch.zeros_like(hidden_states)
  131. with torch.no_grad():
  132. expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
  133. expert_mask = expert_mask.permute(2, 1, 0)
  134. expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
  135. for expert_idx in expert_hit:
  136. expert_idx = expert_idx[0]
  137. if expert_idx == self.num_experts:
  138. continue
  139. top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
  140. current_state = hidden_states[token_idx]
  141. gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
  142. current_hidden_states = self.act_fn(gate) * up
  143. current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
  144. current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
  145. final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
  146. return final_hidden_states
  147. class MixtralTopKRouter(nn.Module):
  148. def __init__(self, config):
  149. super().__init__()
  150. self.top_k = config.num_experts_per_tok
  151. self.num_experts = config.num_local_experts
  152. self.hidden_dim = config.hidden_size
  153. self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
  154. def forward(self, hidden_states):
  155. hidden_states = hidden_states.reshape(-1, self.hidden_dim)
  156. router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts)
  157. router_logits = torch.nn.functional.softmax(router_logits.float(), dim=-1)
  158. router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
  159. router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
  160. router_scores = router_top_value
  161. return router_logits, router_scores, router_indices
  162. class MixtralSparseMoeBlock(nn.Module):
  163. def __init__(self, config):
  164. super().__init__()
  165. self.top_k = config.num_experts_per_tok
  166. self.jitter_noise = config.router_jitter_noise
  167. self.gate = MixtralTopKRouter(config)
  168. self.experts = MixtralExperts(config)
  169. def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  170. batch_size, sequence_length, hidden_dim = hidden_states.shape
  171. if self.training and self.jitter_noise > 0:
  172. hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
  173. hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
  174. _, top_k_weights, top_k_index = self.gate(hidden_states)
  175. hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)
  176. hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
  177. return hidden_states
# Mixtral-named subclass of MistralRMSNorm; the empty body adds no behavior of its own.
class MixtralRMSNorm(MistralRMSNorm):
    pass
# Mixtral-named subclass of MistralRotaryEmbedding; the empty body adds no behavior of its own.
class MixtralRotaryEmbedding(MistralRotaryEmbedding):
    pass
# Mixtral-named subclass of MistralAttention; the empty body adds no behavior of its own.
class MixtralAttention(MistralAttention):
    pass
  184. class MixtralDecoderLayer(GradientCheckpointingLayer):
  185. def __init__(self, config: MixtralConfig, layer_idx: int):
  186. super().__init__()
  187. self.hidden_size = config.hidden_size
  188. self.self_attn = MixtralAttention(config, layer_idx)
  189. self.mlp = MixtralSparseMoeBlock(config)
  190. self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  191. self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  192. def forward(
  193. self,
  194. hidden_states: torch.Tensor,
  195. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  196. attention_mask: torch.Tensor | None = None,
  197. position_ids: torch.LongTensor | None = None,
  198. past_key_values: Cache | None = None,
  199. **kwargs: Unpack[TransformersKwargs],
  200. ) -> torch.Tensor:
  201. residual = hidden_states
  202. hidden_states = self.input_layernorm(hidden_states)
  203. hidden_states, _ = self.self_attn(
  204. hidden_states=hidden_states,
  205. position_embeddings=position_embeddings,
  206. attention_mask=attention_mask,
  207. position_ids=position_ids,
  208. past_key_values=past_key_values,
  209. **kwargs,
  210. )
  211. hidden_states = residual + hidden_states
  212. residual = hidden_states
  213. hidden_states = self.post_attention_layernorm(hidden_states)
  214. hidden_states = self.mlp(hidden_states)
  215. hidden_states = residual + hidden_states
  216. return hidden_states
  217. class MixtralPreTrainedModel(MistralPreTrainedModel):
  218. _can_record_outputs = {
  219. "router_logits": OutputRecorder(MixtralTopKRouter, index=0),
  220. "hidden_states": MixtralDecoderLayer,
  221. "attentions": MixtralAttention,
  222. }
  223. @torch.no_grad()
  224. def _init_weights(self, module):
  225. PreTrainedModel._init_weights(self, module)
  226. std = self.config.initializer_range
  227. if isinstance(module, MixtralExperts):
  228. init.normal_(module.gate_up_proj, mean=0.0, std=std)
  229. init.normal_(module.down_proj, mean=0.0, std=std)
  230. elif isinstance(module, MixtralTopKRouter):
  231. init.normal_(module.weight, mean=0.0, std=std)
  232. class MixtralModel(MistralModel):
  233. def forward(
  234. self,
  235. input_ids: torch.LongTensor | None = None,
  236. attention_mask: torch.Tensor | None = None,
  237. position_ids: torch.LongTensor | None = None,
  238. past_key_values: Cache | None = None,
  239. inputs_embeds: torch.FloatTensor | None = None,
  240. use_cache: bool | None = None,
  241. **kwargs: Unpack[TransformersKwargs],
  242. ) -> MoeModelOutputWithPast:
  243. if (input_ids is None) ^ (inputs_embeds is not None):
  244. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  245. if use_cache and past_key_values is None:
  246. past_key_values = DynamicCache(config=self.config)
  247. if inputs_embeds is None:
  248. inputs_embeds = self.embed_tokens(input_ids)
  249. if position_ids is None:
  250. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  251. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  252. position_ids = position_ids.unsqueeze(0)
  253. mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
  254. causal_mask = mask_function(
  255. config=self.config,
  256. inputs_embeds=inputs_embeds,
  257. attention_mask=attention_mask,
  258. past_key_values=past_key_values,
  259. position_ids=position_ids,
  260. )
  261. hidden_states = inputs_embeds
  262. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  263. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  264. hidden_states = decoder_layer(
  265. hidden_states,
  266. attention_mask=causal_mask,
  267. position_ids=position_ids,
  268. past_key_values=past_key_values,
  269. use_cache=use_cache,
  270. position_embeddings=position_embeddings,
  271. **kwargs,
  272. )
  273. hidden_states = self.norm(hidden_states)
  274. return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE
  275. last_hidden_state=hidden_states,
  276. past_key_values=past_key_values,
  277. )
  278. class MixtralForCausalLM(MistralForCausalLM):
  279. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  280. def __init__(self, config):
  281. super().__init__(config)
  282. self.model = MixtralModel(config)
  283. self.router_aux_loss_coef = config.router_aux_loss_coef
  284. self.num_experts = config.num_local_experts
  285. self.num_experts_per_tok = config.num_experts_per_tok
  286. def forward(
  287. self,
  288. input_ids: torch.LongTensor | None = None,
  289. attention_mask: torch.Tensor | None = None,
  290. position_ids: torch.LongTensor | None = None,
  291. past_key_values: Cache | None = None,
  292. inputs_embeds: torch.FloatTensor | None = None,
  293. labels: torch.LongTensor | None = None,
  294. use_cache: bool | None = None,
  295. output_router_logits: bool | None = None,
  296. logits_to_keep: int | torch.Tensor = 0,
  297. **kwargs: Unpack[TransformersKwargs],
  298. ) -> MoeCausalLMOutputWithPast:
  299. r"""
  300. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  301. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  302. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  303. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  304. Example:
  305. ```python
  306. >>> from transformers import AutoTokenizer, MixtralForCausalLM
  307. >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
  308. >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
  309. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  310. >>> inputs = tokenizer(prompt, return_tensors="pt")
  311. >>> # Generate
  312. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  313. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  314. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  315. ```"""
  316. output_router_logits = (
  317. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  318. )
  319. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  320. outputs: MoeModelOutputWithPast = self.model(
  321. input_ids=input_ids,
  322. attention_mask=attention_mask,
  323. position_ids=position_ids,
  324. past_key_values=past_key_values,
  325. inputs_embeds=inputs_embeds,
  326. use_cache=use_cache,
  327. output_router_logits=output_router_logits,
  328. **kwargs,
  329. )
  330. hidden_states = outputs.last_hidden_state
  331. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  332. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  333. logits = self.lm_head(hidden_states[:, slice_indices, :])
  334. loss = None
  335. if labels is not None:
  336. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  337. aux_loss = None
  338. if output_router_logits:
  339. aux_loss = load_balancing_loss_func(
  340. outputs.router_logits,
  341. self.num_experts,
  342. self.num_experts_per_tok,
  343. attention_mask,
  344. )
  345. if labels is not None:
  346. loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
  347. return MoeCausalLMOutputWithPast(
  348. loss=loss,
  349. aux_loss=aux_loss,
  350. logits=logits,
  351. past_key_values=outputs.past_key_values,
  352. hidden_states=outputs.hidden_states,
  353. attentions=outputs.attentions,
  354. router_logits=outputs.router_logits,
  355. )
# Mixtral-named subclass of MistralForSequenceClassification; the empty body adds no behavior of its own.
class MixtralForSequenceClassification(MistralForSequenceClassification):
    pass
# Mixtral-named subclass of MistralForTokenClassification; the empty body adds no behavior of its own.
class MixtralForTokenClassification(MistralForTokenClassification):
    pass
# Mixtral-named subclass of MistralForQuestionAnswering; the empty body adds no behavior of its own.
class MixtralForQuestionAnswering(MistralForQuestionAnswering):
    pass
# Public names of this module: the model classes only. The MoE building blocks, the pass-through
# Mistral subclasses and `load_balancing_loss_func` are not listed.
__all__ = [
    "MixtralForCausalLM",
    "MixtralForQuestionAnswering",
    "MixtralModel",
    "MixtralPreTrainedModel",
    "MixtralForSequenceClassification",
    "MixtralForTokenClassification",
]