| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448 |
- # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
- #
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
- # and OPT implementations in this library. It has been modified from its
- # original forms to accommodate minor architectural differences compared
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch Mixtral model."""
- import torch
- import torch.nn.functional as F
- from torch import nn
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache
- from ...integrations import use_experts_implementation
- from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
- from ...modeling_utils import PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, logging
- from ...utils.output_capturing import OutputRecorder
- from ..mistral.modeling_mistral import (
- MistralAttention,
- MistralForCausalLM,
- MistralForQuestionAnswering,
- MistralForSequenceClassification,
- MistralForTokenClassification,
- MistralModel,
- MistralPreTrainedModel,
- MistralRMSNorm,
- MistralRotaryEmbedding,
- )
- from .configuration_mixtral import MixtralConfig
# Module-level logger obtained from the library's logging helper and named after this module.
logger = logging.get_logger(__name__)
- def load_balancing_loss_func(
- gate_logits: torch.Tensor | tuple[torch.Tensor] | None,
- num_experts: int | None = None,
- top_k=2,
- attention_mask: torch.Tensor | None = None,
- ) -> torch.Tensor | int:
- r"""
- Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
- See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
- function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
- experts is too unbalanced.
- Args:
- gate_logits:
- Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
- shape [batch_size X sequence_length, num_experts].
- num_experts:
- Number of experts
- top_k:
- The number of experts to route per-token, can be also interpreted as the `top-k` routing
- parameter.
- attention_mask (`torch.Tensor`, *optional*):
- The attention_mask used in forward function
- shape [batch_size X sequence_length] if not None.
- Returns:
- The auxiliary loss.
- """
- if gate_logits is None or not isinstance(gate_logits, tuple):
- return 0
- if isinstance(gate_logits, tuple):
- compute_device = gate_logits[0].device
- concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
- routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
- _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
- if attention_mask is None:
- # Compute the percentage of tokens routed to each experts
- tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
- # Compute the average probability of routing to these experts
- router_prob_per_expert = torch.mean(routing_weights, dim=0)
- else:
- batch_size, sequence_length = attention_mask.shape
- num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
- # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
- expert_attention_mask = (
- attention_mask[None, :, :, None, None]
- .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
- .reshape(-1, top_k, num_experts)
- .to(compute_device)
- )
- # Compute the percentage of tokens routed to each experts
- tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
- expert_attention_mask, dim=0
- )
- # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
- router_per_expert_attention_mask = (
- attention_mask[None, :, :, None]
- .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
- .reshape(-1, num_experts)
- .to(compute_device)
- )
- # Compute the average probability of routing to these experts
- router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
- router_per_expert_attention_mask, dim=0
- )
- overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
- return overall_loss * num_experts
@use_experts_implementation
class MixtralExperts(nn.Module):
    """Feed-forward weights of all experts, packed into 3D parameters with one slice per expert."""

    def __init__(self, config: MixtralConfig):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.intermediate_size

        # Gate and up projections share one fused weight; the halves are split again in `forward`.
        fused_shape = (self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)
        down_shape = (self.num_experts, self.hidden_dim, self.intermediate_dim)
        self.gate_up_proj = nn.Parameter(torch.empty(*fused_shape))
        self.down_proj = nn.Parameter(torch.empty(*down_shape))
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Run every token through its selected experts and accumulate the weighted results.

        `top_k_index` and `top_k_weights` are indexed as `[token, top_k_slot]`; rows of
        `hidden_states` are indexed by the same token index. The output has the shape and
        dtype of `hidden_states`.
        """
        output = torch.zeros_like(hidden_states)

        # Routing bookkeeping is index-only, so it is built without tracking gradients.
        with torch.no_grad():
            routing_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            # Reorder to (expert, top_k_slot, token) so each expert's assignments form one slice.
            routing_mask = routing_mask.permute(2, 1, 0)
            active_experts = torch.greater(routing_mask.sum(dim=(-1, -2)), 0).nonzero()

        for hit in active_experts:
            expert = hit[0]
            # Guard kept from the original: skip an index equal to `num_experts`.
            if expert == self.num_experts:
                continue

            slot, rows = torch.where(routing_mask[expert])
            fused = nn.functional.linear(hidden_states[rows], self.gate_up_proj[expert])
            gate, up = fused.chunk(2, dim=-1)
            expert_out = nn.functional.linear(self.act_fn(gate) * up, self.down_proj[expert])
            # Scale each token's result by the routing weight of the slot that chose this expert.
            expert_out = expert_out * top_k_weights[rows, slot, None]
            output.index_add_(0, rows, expert_out.to(output.dtype))

        return output
class MixtralTopKRouter(nn.Module):
    """Linear router that scores each token against every expert and keeps the `top_k` best."""

    def __init__(self, config):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_local_experts
        self.hidden_dim = config.hidden_size
        self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))

    def forward(self, hidden_states):
        """Return `(expert_probabilities, top_k_weights, top_k_indices)` for the flattened tokens.

        The first element is the float softmax over all experts (not the raw linear scores);
        the second holds the selected probabilities renormalized to sum to one per token.
        """
        flat_tokens = hidden_states.reshape(-1, self.hidden_dim)
        expert_scores = F.linear(flat_tokens, self.weight)  # (tokens, num_experts)
        expert_probs = torch.nn.functional.softmax(expert_scores.float(), dim=-1)
        top_probs, top_experts = torch.topk(expert_probs, self.top_k, dim=-1)  # (tokens, top_k)
        # Renormalize so the kept experts' weights sum to one for every token.
        top_probs /= top_probs.sum(dim=-1, keepdim=True)
        return expert_probs, top_probs, top_experts
class MixtralSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts layer: a top-k router followed by the packed expert weights."""

    def __init__(self, config):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.jitter_noise = config.router_jitter_noise
        self.gate = MixtralTopKRouter(config)
        self.experts = MixtralExperts(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route every token to its selected experts and return the combined output.

        Args:
            hidden_states: 3D tensor unpacked as `(batch_size, sequence_length, hidden_dim)`.

        Returns:
            A tensor reshaped back to `(batch_size, sequence_length, hidden_dim)`.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape

        if self.training and self.jitter_noise > 0:
            # Multiplicative jitter, applied out-of-place so the caller's tensor is never mutated.
            noise = torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
            hidden_states = hidden_states * noise

        # `reshape` (rather than `view`) also accepts non-contiguous inputs.
        hidden_states = hidden_states.reshape(-1, hidden_dim)
        # The router's first output (full expert probabilities) is not needed here.
        _, top_k_weights, top_k_index = self.gate(hidden_states)
        hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)
        return hidden_states.reshape(batch_size, sequence_length, hidden_dim)
# Empty subclass: inherits everything from MistralRMSNorm; only the class name differs.
class MixtralRMSNorm(MistralRMSNorm):
    pass
# Empty subclass: inherits everything from MistralRotaryEmbedding; only the class name differs.
class MixtralRotaryEmbedding(MistralRotaryEmbedding):
    pass
# Empty subclass: inherits everything from MistralAttention; only the class name differs.
class MixtralAttention(MistralAttention):
    pass
class MixtralDecoderLayer(GradientCheckpointingLayer):
    """One decoder layer: pre-norm self-attention, then a pre-norm sparse MoE block, each with a residual add."""

    def __init__(self, config: MixtralConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = MixtralAttention(config, layer_idx)
        self.mlp = MixtralSparseMoeBlock(config)

        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Apply the attention and MoE sub-layers and return the updated hidden states."""
        # Attention sub-layer; the second value returned by the attention module is discarded.
        attn_output, _ = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            **kwargs,
        )
        after_attention = hidden_states + attn_output

        # Sparse MoE sub-layer, added onto the post-attention stream.
        moe_output = self.mlp(self.post_attention_layernorm(after_attention))
        return after_attention + moe_output
# Mixtral base class: declares which modules' outputs can be recorded and how the
# MoE-specific parameters are initialized.
class MixtralPreTrainedModel(MistralPreTrainedModel):
    _can_record_outputs = {
        # `index=0` selects the first value returned by the router's forward.
        "router_logits": OutputRecorder(MixtralTopKRouter, index=0),
        "hidden_states": MixtralDecoderLayer,
        "attentions": MixtralAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Run the generic initialization, then draw the expert and router weights from a normal distribution."""
        PreTrainedModel._init_weights(self, module)
        std = self.config.initializer_range

        if isinstance(module, MixtralExperts):
            params_to_init = (module.gate_up_proj, module.down_proj)
        elif isinstance(module, MixtralTopKRouter):
            params_to_init = (module.weight,)
        else:
            params_to_init = ()

        for param in params_to_init:
            init.normal_(param, mean=0.0, std=std)
# Mixtral decoder stack; `forward` returns a `MoeModelOutputWithPast`.
class MixtralModel(MistralModel):
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        # Exactly one of the two input forms must be provided: reject both-missing and both-given.
        has_ids = input_ids is not None
        has_embeds = inputs_embeds is not None
        if has_ids == has_embeds:
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if position_ids is None:
            # Continue counting positions after whatever the cache already holds.
            cached_length = 0 if past_key_values is None else past_key_values.get_seq_length()
            new_positions = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
            position_ids = (new_positions + cached_length).unsqueeze(0)

        if self.config.sliding_window is None:
            build_mask = create_causal_mask
        else:
            build_mask = create_sliding_window_causal_mask
        causal_mask = build_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # Rotary embeddings are computed once and shared by every layer.
        position_embeddings = self.rotary_emb(inputs_embeds, position_ids=position_ids)

        hidden_states = inputs_embeds
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        # only diff with Mistral is the output type, we need MoE
        return MoeModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=past_key_values,
        )
# Causal language-modeling head on top of `MixtralModel`, with an optional router
# load-balancing auxiliary loss.
class MixtralForCausalLM(MistralForCausalLM):
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config):
        super().__init__(config)
        self.model = MixtralModel(config)
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_router_logits: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MixtralForCausalLM

        >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
        >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # An explicit argument wins; otherwise fall back to the config default.
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        # Run the decoder stack; the result exposes the last hidden state, cache and any recorded outputs.
        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_router_logits=output_router_logits,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                # `load_balancing_loss_func` returns the plain int 0 when it gets no tuple of router
                # logits; only a tensor can (and needs to) be moved to the device of the loss.
                aux_term = aux_loss.to(loss.device) if isinstance(aux_loss, torch.Tensor) else aux_loss
                loss += self.router_aux_loss_coef * aux_term

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
# Empty subclass: inherits everything from MistralForSequenceClassification; only the class name differs.
class MixtralForSequenceClassification(MistralForSequenceClassification):
    pass
# Empty subclass: inherits everything from MistralForTokenClassification; only the class name differs.
class MixtralForTokenClassification(MistralForTokenClassification):
    pass
# Empty subclass: inherits everything from MistralForQuestionAnswering; only the class name differs.
class MixtralForQuestionAnswering(MistralForQuestionAnswering):
    pass
# Names exported by `from <this module> import *`.
__all__ = [
    "MixtralForCausalLM",
    "MixtralForQuestionAnswering",
    "MixtralModel",
    "MixtralPreTrainedModel",
    "MixtralForSequenceClassification",
    "MixtralForTokenClassification",
]
|