modular_phi3.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Phi-3 model."""
  15. from collections.abc import Callable
  16. import torch
  17. from torch import nn
  18. from ...activations import ACT2FN
  19. from ...cache_utils import Cache
  20. from ...generation import GenerationMixin
  21. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  22. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  23. from ...processing_utils import Unpack
  24. from ...utils import logging
  25. from ..mistral.modeling_mistral import (
  26. MistralDecoderLayer,
  27. MistralForCausalLM,
  28. MistralForSequenceClassification,
  29. MistralForTokenClassification,
  30. MistralPreTrainedModel,
  31. eager_attention_forward,
  32. rotate_half,
  33. )
  34. from ..phi.modeling_phi import PhiRotaryEmbedding
  35. from .configuration_phi3 import Phi3Config
  36. logger = logging.get_logger(__name__)
  37. _CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
  38. _CONFIG_FOR_DOC = "Phi3Config"
  39. class Phi3MLP(nn.Module):
  40. def __init__(self, config):
  41. super().__init__()
  42. self.config = config
  43. self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
  44. self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
  45. self.activation_fn = ACT2FN[config.hidden_act]
  46. def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
  47. up_states = self.gate_up_proj(hidden_states)
  48. gate, up_states = up_states.chunk(2, dim=-1)
  49. up_states = up_states * self.activation_fn(gate)
  50. return self.down_proj(up_states)
  51. class Phi3RotaryEmbedding(PhiRotaryEmbedding):
  52. pass
  53. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  54. """Applies Rotary Position Embedding to the query and key tensors.
  55. Args:
  56. q (`torch.Tensor`): The query tensor.
  57. k (`torch.Tensor`): The key tensor.
  58. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  59. sin (`torch.Tensor`): The sine part of the rotary embedding.
  60. unsqueeze_dim (`int`, *optional*, defaults to 1):
  61. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  62. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  63. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  64. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  65. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  66. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  67. Returns:
  68. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  69. """
  70. cos = cos.unsqueeze(unsqueeze_dim)
  71. sin = sin.unsqueeze(unsqueeze_dim)
  72. rotary_dim = cos.shape[-1]
  73. q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
  74. k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
  75. q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
  76. k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1)
  77. return q_embed, k_embed
  78. class Phi3Attention(nn.Module):
  79. """Multi-headed attention from 'Attention Is All You Need' paper"""
  80. def __init__(self, config: Phi3Config, layer_idx: int | None = None):
  81. super().__init__()
  82. self.config = config
  83. self.layer_idx = layer_idx
  84. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  85. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  86. self.num_key_value_heads = config.num_key_value_heads
  87. self.scaling = self.head_dim**-0.5
  88. self.attention_dropout = config.attention_dropout
  89. self.is_causal = True
  90. op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim)
  91. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  92. self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False)
  93. def forward(
  94. self,
  95. hidden_states: torch.Tensor,
  96. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  97. attention_mask: torch.Tensor | None,
  98. past_key_values: Cache | None = None,
  99. **kwargs: Unpack[FlashAttentionKwargs],
  100. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  101. input_shape = hidden_states.shape[:-1]
  102. hidden_shape = (*input_shape, -1, self.head_dim)
  103. qkv = self.qkv_proj(hidden_states)
  104. query_pos = self.config.num_attention_heads * self.head_dim
  105. query_states = qkv[..., :query_pos]
  106. key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
  107. value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
  108. query_states = query_states.view(hidden_shape).transpose(1, 2)
  109. key_states = key_states.view(hidden_shape).transpose(1, 2)
  110. value_states = value_states.view(hidden_shape).transpose(1, 2)
  111. cos, sin = position_embeddings
  112. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  113. if past_key_values is not None:
  114. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  115. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  116. self.config._attn_implementation, eager_attention_forward
  117. )
  118. attn_output, attn_weights = attention_interface(
  119. self,
  120. query_states,
  121. key_states,
  122. value_states,
  123. attention_mask,
  124. dropout=0.0 if not self.training else self.attention_dropout,
  125. scaling=self.scaling,
  126. sliding_window=getattr(self.config, "sliding_window", None),
  127. **kwargs,
  128. )
  129. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  130. attn_output = self.o_proj(attn_output)
  131. return attn_output, attn_weights
  132. class Phi3DecoderLayer(MistralDecoderLayer):
  133. def __init__(self, config: Phi3Config, layer_idx: int):
  134. super().__init__(config, layer_idx)
  135. self.config = config
  136. self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx)
  137. self.mlp = Phi3MLP(config)
  138. self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
  139. self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
  140. def forward(
  141. self,
  142. hidden_states: torch.Tensor,
  143. attention_mask: torch.Tensor | None = None,
  144. position_ids: torch.LongTensor | None = None,
  145. past_key_values: Cache | None = None,
  146. use_cache: bool | None = False,
  147. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  148. **kwargs: Unpack[FlashAttentionKwargs],
  149. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  150. residual = hidden_states
  151. hidden_states = self.input_layernorm(hidden_states)
  152. hidden_states, self_attn_weights = self.self_attn(
  153. hidden_states=hidden_states,
  154. attention_mask=attention_mask,
  155. position_ids=position_ids,
  156. past_key_values=past_key_values,
  157. use_cache=use_cache,
  158. position_embeddings=position_embeddings,
  159. **kwargs,
  160. )
  161. hidden_states = residual + self.resid_attn_dropout(hidden_states) # main diff with Llama
  162. residual = hidden_states
  163. hidden_states = self.post_attention_layernorm(hidden_states)
  164. hidden_states = self.mlp(hidden_states)
  165. hidden_states = residual + self.resid_mlp_dropout(hidden_states) # main diff with Llama
  166. return hidden_states
  167. class Phi3PreTrainedModel(MistralPreTrainedModel):
  168. _version = "0.0.5"
  169. class Phi3ForCausalLM(MistralForCausalLM):
  170. def prepare_inputs_for_generation(
  171. self,
  172. input_ids,
  173. past_key_values=None,
  174. attention_mask=None,
  175. inputs_embeds=None,
  176. position_ids=None,
  177. use_cache=True,
  178. logits_to_keep=None,
  179. **kwargs,
  180. ):
  181. # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
  182. # process
  183. # When the first time input length reached long and short factor switching point, enforce re-compute cache
  184. # It will cause downside of slower at this single token position, however, better than current failure.
  185. if (
  186. past_key_values
  187. and hasattr(self.config, "original_max_position_embeddings")
  188. and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
  189. ):
  190. past_length = past_key_values.get_seq_length()
  191. if past_length <= self.config.original_max_position_embeddings:
  192. past_key_values = None
  193. model_inputs = GenerationMixin.prepare_inputs_for_generation(
  194. self,
  195. input_ids=input_ids,
  196. past_key_values=past_key_values,
  197. attention_mask=attention_mask,
  198. inputs_embeds=inputs_embeds,
  199. position_ids=position_ids,
  200. use_cache=use_cache,
  201. logits_to_keep=logits_to_keep,
  202. **kwargs,
  203. )
  204. return model_inputs
  205. class Phi3ForSequenceClassification(MistralForSequenceClassification):
  206. pass
  207. class Phi3ForTokenClassification(MistralForTokenClassification):
  208. pass
  209. __all__ = [
  210. "Phi3PreTrainedModel",
  211. "Phi3Model", # noqa: F822
  212. "Phi3ForCausalLM",
  213. "Phi3ForSequenceClassification",
  214. "Phi3ForTokenClassification",
  215. ]