| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410 |
- # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch Phimoe model."""
- from collections.abc import Callable
- import torch
- from torch import nn
- from ...modeling_layers import (
- GenericForSequenceClassification,
- )
- from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
- from ...utils.generic import maybe_autocast
- from ...utils.output_capturing import OutputRecorder
- from ..llama.modeling_llama import LlamaAttention
- from ..mixtral.modeling_mixtral import (
- MixtralDecoderLayer,
- MixtralExperts,
- MixtralForCausalLM,
- MixtralModel,
- MixtralPreTrainedModel,
- MixtralRotaryEmbedding,
- )
- from .configuration_phimoe import PhimoeConfig
class PhimoeRotaryEmbedding(MixtralRotaryEmbedding):
    """
    Rotary position embedding for Phimoe.

    Compared with the Mixtral parent, `forward` recomputes the inverse frequencies from the current
    sequence length on every call and, for non-default rope types, scales the cos/sin tables by
    `long_mscale` or `short_mscale` from `config.rope_parameters`, depending on whether the largest
    position exceeds `original_max_position_embeddings`.
    """

    def __init__(self, config: PhimoeConfig, device=None):
        # Initialise only the base nn.Module (MixtralRotaryEmbedding.__init__ is bypassed); every attribute
        # this class uses is assigned explicitly below. The unbound call must receive `self` explicitly.
        nn.Module.__init__(self)
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_type = self.config.rope_parameters["rope_type"]
        self.rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)

        # Non-persistent buffers move with the module across devices but are not written to the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    def forward(self, x, position_ids=None, layer_type=None):
        """
        Compute the rotary `(cos, sin)` tables for `position_ids`.

        Args:
            x (torch.Tensor): Only its `device` and `dtype` are used.
            position_ids (torch.Tensor): 2-D tensor of positions. Required despite the `None` default.
            layer_type: Not supported; must be `None`.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: `cos` and `sin` of shape
            `(position_ids.shape[0], position_ids.shape[1], 2 * len(inv_freq))`, multiplied by `mscale`
            and cast to `x.dtype`.

        Raises:
            ValueError: If `layer_type` is not `None`.
        """
        if layer_type is not None:
            raise ValueError(
                f"{self.__class__.__name__} does not support layer types, but got `layer_type={layer_type}`"
            )
        mscale = None
        seq_len = torch.max(position_ids) + 1
        if self.config.rope_parameters["rope_type"] != "default" and seq_len:
            # Pick the long or short scale depending on whether we are past `original_max_position_embeddings`.
            mscale = (
                self.config.rope_parameters["long_mscale"]
                if seq_len > self.config.rope_parameters["original_max_position_embeddings"]
                else self.config.rope_parameters["short_mscale"]
            )
        # Inverse frequencies are recomputed for the current sequence length (the registered buffer is not used).
        inv_freq, attention_scaling = self.rope_init_fn(self.config, x.device, seq_len)
        mscale = attention_scaling if mscale is None else mscale
        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Autocast is disabled on a concrete device type; "mps" falls back to "cpu".
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * mscale
            sin = emb.sin() * mscale
        return cos.to(x.dtype), sin.to(x.dtype)
class PhimoeAttention(LlamaAttention):
    """Phimoe self-attention; reuses `LlamaAttention` without modification."""
class PhimoeMultiplier(torch.autograd.Function):
    """
    Autograd function that emits the SparseMixer routing multiplier with a hand-written gradient.

    The forward value is `multiplier * mask_for_one`. In the backward pass only `scores` receives a
    gradient: with `g = grad_output * multiplier`, it equals `-masked_gates * g` for every expert, plus `g`
    added back at the selected expert's position. All other inputs get `None`.
    """

    @staticmethod
    def forward(
        ctx,
        scores: torch.Tensor,
        multiplier: torch.Tensor,
        selected_experts: torch.Tensor,
        masked_gates: torch.Tensor,
        mask_for_one: torch.Tensor,
    ):
        """
        Args:
            ctx: Autograd context; `multiplier`, `selected_experts` and `masked_gates` are saved on it.
            scores (torch.Tensor): Router logits. Not used to compute the forward value; present so the
                backward pass can return a gradient for it.
            multiplier (torch.Tensor): Gate value of the selected expert (sparsemixer passes the gathered
                gate probability).
            selected_experts (torch.Tensor): Indices of the selected experts along the last dimension.
            masked_gates (torch.Tensor): Gate probabilities over all experts (sparsemixer passes the softmax
                of the masked logits).
            mask_for_one (torch.Tensor): Scale applied to `multiplier` in the forward value.

        Returns:
            torch.Tensor: `multiplier * mask_for_one`.
        """
        scaled = multiplier * mask_for_one
        ctx.save_for_backward(multiplier, selected_experts, masked_gates)
        return scaled

    @staticmethod
    def backward(
        ctx,
        grad_at_output: torch.Tensor,
    ):
        """
        Args:
            ctx: Autograd context holding the tensors saved by `forward`.
            grad_at_output (torch.Tensor): Gradient of the loss with respect to the forward output.

        Returns:
            tuple[torch.Tensor, None, None, None, None]: Gradient for `scores`, then `None` for the other inputs.
        """
        saved_multiplier, expert_index, gate_probs = ctx.saved_tensors

        upstream = grad_at_output * saved_multiplier
        # -gates * g for every expert, then g is added back at the selected expert's column.
        scores_grad = gate_probs * -upstream
        scores_grad.scatter_add_(dim=-1, index=expert_index, src=upstream)

        return (scores_grad, None, None, None, None)
def sparsemixer(scores, jitter_eps, training, top_k=2):
    """
    Sparse mixer function to select top-k experts and compute multipliers.

    Based on the paper: https://huggingface.co/papers/2409.12136
    We first replace the TopK(·) function as random sampling of discrete variables
    in model training. Then, following Liu et al. (2023a) and Liu et al. (2023b), we apply Heun's
    third order method to approximate the expert routing gradient and construct a modified
    back-propagation to give a mathematically sound gradient estimation for expert routing.

    Selection runs in two sequential rounds (the second round masks out the expert picked by the
    first), so only `top_k=2` is supported.

    Args:
        scores (torch.Tensor): Router logits; experts are along the last dimension.
        jitter_eps (float): Relative threshold of the sparsity mask. In each round an expert is excluded
            when `(max_score - score) / max(|score|, max_score) > 2 * jitter_eps`, where `max_score` is
            the largest score still eligible in that round.
        training (bool): If True, experts are drawn by Gumbel-max sampling and gradients flow through
            `PhimoeMultiplier`; otherwise the arg-max expert is taken deterministically.
        top_k (int): Number of experts to select. Must be 2.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: `(multiplier, selected_experts)`, each with the leading dimensions
        of `scores` and a last dimension of size 2.

    Raises:
        ValueError: If `top_k` is not 2.
    """
    if top_k != 2:
        # The two rounds below are hard-wired; any other value would silently still yield two experts.
        raise ValueError(f"sparsemixer only supports top_k=2, got top_k={top_k}")

    # ----- First expert -----
    with torch.no_grad():
        # Compute mask for sparsity
        mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True)
        factor = scores.abs().clamp(min=mask_logits_threshold)
        mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (2 * jitter_eps)

    # Apply mask
    masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf"))
    if training:
        # Gumbel-max sampling (argmax of logits - log(Exp(1) noise)), more robust than the multinomial method
        selected_experts = (
            (
                masked_gates
                - torch.empty_like(masked_gates, memory_format=torch.legacy_contiguous_format).exponential_().log()
            )
            .max(dim=-1)[1]
            .unsqueeze(-1)
        )
    else:
        selected_experts = max_ind

    # Compute scores for gradients
    masked_gates = torch.softmax(masked_gates, dim=-1)
    multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)

    if training:
        # Compute midpoint mask
        max_scores, max_ind = masked_gates.max(dim=-1, keepdim=True)
        mask_for_one = torch.logical_or(
            selected_experts == max_ind,
            torch.rand_like(max_scores) > 0.75,  # Heun's third-order method
        )
        # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5
        mask_for_one = torch.add(0.3333, mask_for_one, alpha=0.6667).type_as(masked_gates)

        multiplier = PhimoeMultiplier.apply(
            scores,
            multiplier_o,
            selected_experts,
            masked_gates,
            mask_for_one,
        )
    else:
        multiplier = multiplier_o

    # ----- Second expert -----
    # Masked out first expert
    masked_scores = torch.scatter(
        scores,
        -1,
        selected_experts,
        float("-inf"),
    )
    with torch.no_grad():
        # Compute mask for sparsity (relative to the best remaining expert, measured on the raw scores)
        mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True)
        factor = scores.abs().clamp(min=mask_logits_threshold)
        mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (2 * jitter_eps)

    # Apply mask
    masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, float("-inf"))
    if training:
        # Gumbel-max sampling, more robust than the multinomial method
        selected_experts_top2 = (
            (
                masked_gates_top2
                - torch.empty_like(masked_gates_top2, memory_format=torch.legacy_contiguous_format)
                .exponential_()
                .log()
            )
            .max(dim=-1)[1]
            .unsqueeze(-1)
        )
    else:
        selected_experts_top2 = max_ind

    # Compute scores for gradients
    masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1)
    multiplier_top2_o = masked_gates_top2.gather(dim=-1, index=selected_experts_top2)

    if training:
        # Compute midpoint mask
        max_scores, max_ind = masked_gates_top2.max(dim=-1, keepdim=True)
        mask_for_one_top2 = torch.logical_or(
            selected_experts_top2 == max_ind,
            torch.rand_like(max_scores).uniform_() > 0.75,  # Heun's third-order method
        )
        # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5
        mask_for_one_top2 = torch.add(0.3333, mask_for_one_top2, alpha=0.6667).type_as(masked_gates_top2)

        multiplier_top2 = PhimoeMultiplier.apply(
            scores,
            multiplier_top2_o,
            selected_experts_top2,
            masked_gates_top2,
            mask_for_one_top2,
        )
    else:
        multiplier_top2 = multiplier_top2_o

    multiplier = torch.concat((multiplier, multiplier_top2), dim=-1)
    selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1)

    return (
        multiplier,
        selected_experts,
    )
class PhimoeExperts(MixtralExperts):
    """Phimoe expert bank; identical to `MixtralExperts`."""
class PhimoeTopKRouter(nn.Linear):
    """
    Router that scores every local expert with a bias-free linear projection and picks experts via `sparsemixer`.

    The module itself is `nn.Linear(hidden_size, num_local_experts, bias=False)`.
    """

    def __init__(self, config: PhimoeConfig):
        super().__init__(config.hidden_size, config.num_local_experts, bias=False)
        self.router_jitter_noise = config.router_jitter_noise
        self.input_jitter_noise = config.input_jitter_noise
        self.top_k = config.num_experts_per_tok

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states (torch.Tensor): Router input whose last dimension is `hidden_size`. In training mode
                with `input_jitter_noise > 0` it is multiplied **in place** by uniform noise drawn from
                `[1 - input_jitter_noise, 1 + input_jitter_noise)`, so the caller's tensor is modified.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: `(router_logits, routing_weights, selected_experts)`,
            where `router_logits` is the raw linear projection and the other two come from `sparsemixer`.
        """
        if self.training and self.input_jitter_noise > 0:
            # NOTE(review): when called from PhimoeSparseMoeBlock.forward, that method has already applied the
            # same jitter, so in training the noise is applied twice. Confirm whether this is intended.
            hidden_states *= torch.empty_like(hidden_states).uniform_(
                1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise
            )
        router_logits = super().forward(hidden_states)
        routing_weights, selected_experts = sparsemixer(
            router_logits, jitter_eps=self.router_jitter_noise, training=self.training, top_k=self.top_k
        )
        return router_logits, routing_weights, selected_experts
class PhimoeSparseMoeBlock(nn.Module):
    """
    Sparse mixture-of-experts feed-forward block for Phimoe.

    Tokens are flattened to `(batch_size * sequence_length, hidden_dim)`, routed by `PhimoeTopKRouter`,
    and passed to `PhimoeExperts` together with the selected expert indices and their routing weights.
    The experts' output is reshaped back to the input shape.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

        self.router = PhimoeTopKRouter(config)
        self.experts = PhimoeExperts(config)
        self.input_jitter_noise = config.input_jitter_noise

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states (torch.Tensor): Input of shape `(batch_size, sequence_length, hidden_dim)`. In training
                mode with `input_jitter_noise > 0` it is scaled in place by uniform noise.

        Returns:
            torch.Tensor: Output of shape `(batch_size, sequence_length, hidden_dim)`.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        if self.training and self.input_jitter_noise > 0:
            # NOTE(review): PhimoeTopKRouter.forward applies the same input jitter again, in place on the
            # reshaped tensor below (typically a view of this one), so in training the noise is effectively
            # applied twice. Confirm the intent before changing either site.
            hidden_states *= torch.empty_like(hidden_states).uniform_(
                1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise
            )

        # Route and run experts per token on a flat (tokens, hidden_dim) layout.
        hidden_states = hidden_states.reshape(-1, hidden_dim)
        _, routing_weights, selected_experts = self.router(hidden_states)
        final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
        return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
class PhimoeDecoderLayer(MixtralDecoderLayer):
    """Mixtral decoder layer whose two normalisation layers are learnable `nn.LayerNorm` modules."""

    def __init__(self, config: PhimoeConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Overwrite both norms after the parent constructor runs; eps is read from `config.rms_norm_eps`.
        norm_kwargs = {"eps": config.rms_norm_eps, "elementwise_affine": True}
        self.input_layernorm = nn.LayerNorm(config.hidden_size, **norm_kwargs)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, **norm_kwargs)
class PhimoePreTrainedModel(MixtralPreTrainedModel):
    """Base class for Phimoe models; declares which submodule outputs can be recorded."""

    # Recordable output name -> module class it is taken from. For the router, `index=0` presumably selects the
    # first element of PhimoeTopKRouter's returned tuple, i.e. `router_logits`.
    _can_record_outputs = dict(
        router_logits=OutputRecorder(PhimoeTopKRouter, index=0),
        hidden_states=PhimoeDecoderLayer,
        attentions=PhimoeAttention,
    )
class PhimoeModel(MixtralModel):
    """Mixtral-style decoder stack whose final normalisation is a learnable `nn.LayerNorm`."""

    def __init__(self, config: PhimoeConfig):
        super().__init__(config)
        # Replace `self.norm` after the parent constructor has run.
        final_norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True)
        self.norm = final_norm
class PhimoeForCausalLM(MixtralForCausalLM):
    """Phimoe causal language model: a Mixtral causal LM whose `lm_head` bias follows `config.lm_head_bias`."""

    def __init__(self, config):
        super().__init__(config)
        # Rebuild the output projection so its bias is controlled by `config.lm_head_bias`.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias)

    # Copied from transformers.models.phi3.modeling_phi3.Phi3ForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        position_ids=None,
        use_cache=True,
        logits_to_keep=None,
        **kwargs,
    ):
        """
        Drop the KV cache once when generation first crosses `original_max_position_embeddings`, then defer to
        the parent implementation.

        The cache is discarded (`past_key_values = None`) only when all of the following hold:
        `past_key_values` is truthy, the config has an `original_max_position_embeddings` attribute,
        `input_ids.shape[1] >= original_max_position_embeddings + 1`, and the cache's sequence length is
        `<= original_max_position_embeddings`. All arguments are then forwarded to the parent
        `prepare_inputs_for_generation`, and its result is returned unchanged.
        """
        # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
        # process

        # When the first time input length reached long and short factor switching point, enforce re-compute cache
        # It will cause downside of slower at this single token position, however, better than current failure.
        # NOTE(review): PhimoeRotaryEmbedding reads this limit from config.rope_parameters
        # ["original_max_position_embeddings"]; if PhimoeConfig does not also expose it as a top-level attribute,
        # the `hasattr` check below is False and the cache is never reset. TODO confirm against the config class.
        if (
            past_key_values
            and hasattr(self.config, "original_max_position_embeddings")
            and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
        ):
            past_length = past_key_values.get_seq_length()
            # Only reset when the cache was filled before crossing the switching point.
            if past_length <= self.config.original_max_position_embeddings:
                past_key_values = None

        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
        return model_inputs
class PhimoeForSequenceClassification(GenericForSequenceClassification, PhimoePreTrainedModel):
    """Phimoe model with the generic sequence-classification head; adds no behaviour of its own."""
# Public API of this module (order preserved).
__all__ = ["PhimoePreTrainedModel", "PhimoeModel", "PhimoeForCausalLM", "PhimoeForSequenceClassification"]
|