| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264 |
- # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch Phi-3 model."""
- from collections.abc import Callable
- import torch
- from torch import nn
- from ...activations import ACT2FN
- from ...cache_utils import Cache
- from ...generation import GenerationMixin
- from ...modeling_flash_attention_utils import FlashAttentionKwargs
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
- from ...processing_utils import Unpack
- from ...utils import logging
- from ..mistral.modeling_mistral import (
- MistralDecoderLayer,
- MistralForCausalLM,
- MistralForSequenceClassification,
- MistralForTokenClassification,
- MistralPreTrainedModel,
- eager_attention_forward,
- rotate_half,
- )
- from ..phi.modeling_phi import PhiRotaryEmbedding
- from .configuration_phi3 import Phi3Config
# Module-level logger for this file.
logger = logging.get_logger(__name__)
# Checkpoint id and config class name for documentation purposes.
# NOTE(review): presumably consumed by docstring/example helpers elsewhere — confirm; unused in this file.
_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
_CONFIG_FOR_DOC = "Phi3Config"
class Phi3MLP(nn.Module):
    """Gated feed-forward block with a fused gate/up projection.

    A single bias-free linear layer produces ``2 * config.intermediate_size``
    features, which are split in half into a gate and an up-projection. The
    up-projection is multiplied by the activated gate and projected back to
    ``config.hidden_size`` by ``down_proj``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        d_model = config.hidden_size
        d_inner = config.intermediate_size
        # One matmul yields [gate | up] concatenated along the last dimension.
        self.gate_up_proj = nn.Linear(d_model, 2 * d_inner, bias=False)
        self.down_proj = nn.Linear(d_inner, d_model, bias=False)
        self.activation_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the gated MLP along the last dimension of ``hidden_states``."""
        fused = self.gate_up_proj(hidden_states)
        # First half of the last dim is the gate, second half the up-projection.
        gate, up = fused.chunk(2, dim=-1)
        gated = up * self.activation_fn(gate)
        return self.down_proj(gated)
class Phi3RotaryEmbedding(PhiRotaryEmbedding):
    """Rotary position embedding for Phi-3; reuses ``PhiRotaryEmbedding`` unchanged (no overrides)."""

    pass
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply (possibly partial) rotary position embedding to queries and keys.

    Only the leading ``cos.shape[-1]`` features of the last dimension of ``q``
    and ``k`` are rotated. Any remaining features are passed through unchanged
    and re-attached after the rotated part.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension at which a singleton axis is inserted into ``cos`` and
            ``sin`` so they broadcast against ``q`` and ``k``. If ``cos``/``sin``
            are ``[batch_size, seq_len, rot_dim]`` and ``q``/``k`` are
            ``[batch_size, heads, seq_len, head_dim]``, use 1. If ``q``/``k`` are
            ``[batch_size, seq_len, heads, head_dim]``, use 2.

    Returns:
        `tuple(torch.Tensor)`: ``(q_embed, k_embed)``, the rotated query and key tensors.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    n_rot = cos.shape[-1]

    def _embed(x):
        # Rotate the leading n_rot features; keep the tail (if any) untouched.
        head, tail = x[..., :n_rot], x[..., n_rot:]
        rotated = (head * cos) + (rotate_half(head) * sin)
        return torch.cat([rotated, tail], dim=-1)

    return _embed(q), _embed(k)
class Phi3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Phi-3 variant: queries, keys and values are produced by a single fused
    ``qkv_proj`` and split along the last dimension. Keys/values use
    ``config.num_key_value_heads`` heads, which may be fewer than the query
    heads (``num_key_value_groups`` query heads per key/value head).
    """

    def __init__(self, config: Phi3Config, layer_idx: int | None = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        # Fused projection width: all query heads, then key heads, then value heads.
        op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
        self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Input activations. The ``transpose(1, 2)`` below assumes
                exactly two leading dims, i.e. ``(batch, seq_len, hidden_size)``.
            position_embeddings: ``(cos, sin)`` pair passed to ``apply_rotary_pos_emb``.
            attention_mask: Forwarded as-is to the selected attention backend.
            past_key_values: Optional cache; when given, its ``update`` is called
                with this layer's keys/values and the returned tensors are attended over.
            **kwargs: Forwarded to the attention backend.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` has the input's leading
            shape with ``config.hidden_size`` features. ``attn_weights`` is whatever
            the backend returns (possibly ``None``, depending on the backend).
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Split the fused projection into [queries | keys | values] along the last dim.
        qkv = self.qkv_proj(hidden_states)
        query_pos = self.config.num_attention_heads * self.head_dim
        key_end = query_pos + self.num_key_value_heads * self.head_dim
        query_states = qkv[..., :query_pos]
        key_states = qkv[..., query_pos:key_end]
        value_states = qkv[..., key_end:]

        # (*input_shape, heads, head_dim) with axes 1 and 2 swapped; for input_shape == (batch, seq)
        # this yields (batch, heads, seq, head_dim).
        query_states = query_states.view(hidden_shape).transpose(1, 2)
        key_states = key_states.view(hidden_shape).transpose(1, 2)
        value_states = value_states.view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        # Pick the backend named by the config, falling back to the eager implementation.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=getattr(self.config, "sliding_window", None),
            **kwargs,
        )

        # Merge heads back into the feature dimension before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class Phi3DecoderLayer(MistralDecoderLayer):
    """Pre-norm transformer block for Phi-3.

    Uses the base class's ``input_layernorm`` / ``post_attention_layernorm``.
    It swaps in the fused Phi-3 attention and MLP, and applies dropout with
    probability ``config.resid_pdrop`` to each sub-layer output before adding
    it back onto the residual stream.
    """

    def __init__(self, config: Phi3Config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.config = config
        # NOTE(review): the base constructor presumably builds its own self_attn/mlp; they are
        # replaced here by the Phi-3 fused variants.
        self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx)
        self.mlp = Phi3MLP(config)
        self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
        self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.Tensor:
        """Apply attention then MLP, each as a residual branch with dropout.

        ``position_ids`` and ``use_cache`` are passed to ``self_attn`` as keyword
        arguments. ``Phi3Attention.forward`` does not name them, so they travel in
        its ``**kwargs`` to the attention backend.

        Returns:
            The updated hidden states. They have the same shape as the input,
            since each branch is added back to the residual.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Attention weights are not propagated out of the layer.
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + self.resid_attn_dropout(hidden_states)  # main diff with Llama

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.resid_mlp_dropout(hidden_states)  # main diff with Llama
        return hidden_states
class Phi3PreTrainedModel(MistralPreTrainedModel):
    """Base class for Phi-3 models; inherits all behavior from ``MistralPreTrainedModel``."""

    # NOTE(review): meaning of `_version` is not visible in this file; presumably a
    # model/format version marker read elsewhere — confirm.
    _version = "0.0.5"
class Phi3ForCausalLM(MistralForCausalLM):
    """Phi-3 causal language model.

    Overrides ``prepare_inputs_for_generation`` so that the key/value cache is
    dropped once generation crosses ``config.original_max_position_embeddings``.
    """

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        position_ids=None,
        use_cache=True,
        logits_to_keep=None,
        **kwargs,
    ):
        """Build model inputs for one generation step, resetting the cache at the rope switch point.

        This delegates to ``GenerationMixin.prepare_inputs_for_generation``. The
        exception is when ``input_ids`` has reached ``original_max_position_embeddings + 1``
        tokens while the cache still holds at most ``original_max_position_embeddings``
        tokens. In that case ``past_key_values`` is replaced by ``None`` so the
        sequence is recomputed.

        Returns:
            The dict produced by ``GenerationMixin.prepare_inputs_for_generation``.
        """
        # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
        # process.
        # When the input length first reaches the long/short factor switching point, force the cache to be
        # recomputed. This is slower at that single token position, but better than producing wrong outputs.
        # A missing *or None* `original_max_position_embeddings` disables the check (previously, None raised
        # TypeError on `None + 1`).
        switch_point = getattr(self.config, "original_max_position_embeddings", None)
        # NOTE(review): truthiness of the cache object (not `is not None`) is intentional legacy behavior;
        # presumably an empty cache is falsy — confirm against the Cache implementation.
        if past_key_values and switch_point is not None and input_ids.shape[1] >= switch_point + 1:
            past_length = past_key_values.get_seq_length()
            if past_length <= switch_point:
                past_key_values = None

        model_inputs = GenerationMixin.prepare_inputs_for_generation(
            self,
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
        return model_inputs
class Phi3ForSequenceClassification(MistralForSequenceClassification):
    """Phi-3 sequence classification model; reuses ``MistralForSequenceClassification`` unchanged."""

    pass
class Phi3ForTokenClassification(MistralForTokenClassification):
    """Phi-3 token classification model; reuses ``MistralForTokenClassification`` unchanged."""

    pass
# Public API of this module.
__all__ = [
    "Phi3PreTrainedModel",
    # `Phi3Model` is not defined in this file, hence `noqa: F822` (undefined name in __all__).
    # NOTE(review): presumably generated from the Mistral base by the modular converter — confirm.
    "Phi3Model",  # noqa: F822
    "Phi3ForCausalLM",
    "Phi3ForSequenceClassification",
    "Phi3ForTokenClassification",
]
|