modular_phi.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. from collections.abc import Callable
  2. from typing import Optional
  3. import torch
  4. import torch.nn as nn
  5. from ...cache_utils import Cache, DynamicCache
  6. from ...masking_utils import create_causal_mask
  7. from ...modeling_layers import GradientCheckpointingLayer
  8. from ...modeling_outputs import (
  9. BaseModelOutputWithPast,
  10. )
  11. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  12. from ...processing_utils import Unpack
  13. from ...utils import TransformersKwargs, auto_docstring, logging
  14. from ...utils.generic import merge_with_config_defaults
  15. from ...utils.output_capturing import capture_outputs
  16. from ..clip.modeling_clip import CLIPMLP
  17. from ..llama.modeling_llama import (
  18. LlamaAttention,
  19. LlamaForCausalLM,
  20. LlamaForSequenceClassification,
  21. LlamaForTokenClassification,
  22. LlamaModel,
  23. LlamaPreTrainedModel,
  24. LlamaRotaryEmbedding,
  25. apply_rotary_pos_emb,
  26. eager_attention_forward,
  27. )
  28. from .configuration_phi import PhiConfig
# Module-level logger, keyed by this module's import name.
logger = logging.get_logger(__name__)

# Checkpoint / config names intended for documentation (per the `_FOR_DOC` suffix).
# They are not referenced anywhere else in the visible part of this file.
_CHECKPOINT_FOR_DOC = "microsoft/phi-1"
_CONFIG_FOR_DOC = "PhiConfig"
  32. class PhiRotaryEmbedding(LlamaRotaryEmbedding):
  33. @staticmethod
  34. def compute_default_rope_parameters(
  35. config: PhiConfig | None = None,
  36. device: Optional["torch.device"] = None,
  37. seq_len: int | None = None,
  38. ) -> tuple["torch.Tensor", float]:
  39. """
  40. Computes the inverse frequencies according to the original RoPE implementation
  41. Args:
  42. config ([`~transformers.PreTrainedConfig`]):
  43. The model configuration.
  44. device (`torch.device`):
  45. The device to use for initialization of the inverse frequencies.
  46. seq_len (`int`, *optional*):
  47. The current sequence length. Unused for this type of RoPE.
  48. Returns:
  49. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  50. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  51. """
  52. base = config.rope_parameters["rope_theta"]
  53. partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
  54. head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  55. dim = int(head_dim * partial_rotary_factor)
  56. attention_factor = 1.0 # Unused in this type of RoPE
  57. # Compute the inverse frequencies
  58. inv_freq = 1.0 / (
  59. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  60. )
  61. return inv_freq, attention_factor
  62. class PhiAttention(LlamaAttention):
  63. def __init__(self, config: PhiConfig, layer_idx: int):
  64. super().__init__(config, layer_idx)
  65. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
  66. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
  67. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
  68. self.dense = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True)
  69. del self.o_proj
  70. self.rotary_ndims = int(self.head_dim * config.rope_parameters["partial_rotary_factor"])
  71. self.qk_layernorm = config.qk_layernorm
  72. if self.qk_layernorm:
  73. self.q_layernorm = nn.LayerNorm(
  74. config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True
  75. )
  76. self.k_layernorm = nn.LayerNorm(
  77. config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True
  78. )
  79. def forward(
  80. self,
  81. hidden_states: torch.Tensor,
  82. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  83. attention_mask: torch.Tensor | None,
  84. past_key_values: Cache | None = None,
  85. **kwargs,
  86. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  87. input_shape = hidden_states.shape[:-1]
  88. hidden_shape = (*input_shape, -1, self.head_dim)
  89. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  90. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  91. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  92. if self.qk_layernorm:
  93. query_states = self.q_layernorm(query_states)
  94. key_states = self.k_layernorm(key_states)
  95. cos, sin = position_embeddings
  96. # Partial rotary embedding
  97. query_rot, query_pass = (
  98. query_states[..., : self.rotary_ndims],
  99. query_states[..., self.rotary_ndims :],
  100. )
  101. key_rot, key_pass = (
  102. key_states[..., : self.rotary_ndims],
  103. key_states[..., self.rotary_ndims :],
  104. )
  105. # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
  106. query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
  107. # [batch_size, seq_length, num_heads, head_dim]
  108. query_states = torch.cat((query_rot, query_pass), dim=-1)
  109. key_states = torch.cat((key_rot, key_pass), dim=-1)
  110. if past_key_values is not None:
  111. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  112. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  113. self.config._attn_implementation, eager_attention_forward
  114. )
  115. attn_output, attn_weights = attention_interface(
  116. self,
  117. query_states,
  118. key_states,
  119. value_states,
  120. attention_mask,
  121. dropout=0.0 if not self.training else self.attention_dropout,
  122. scaling=self.scaling,
  123. **kwargs,
  124. )
  125. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  126. attn_output = self.dense(attn_output)
  127. return attn_output, attn_weights
# Phi's feed-forward block: `CLIPMLP` reused as-is, with nothing overridden.
class PhiMLP(CLIPMLP):
    pass
  130. class PhiDecoderLayer(GradientCheckpointingLayer):
  131. def __init__(self, config: PhiConfig, layer_idx: int):
  132. super().__init__()
  133. self.self_attn = PhiAttention(config, layer_idx=layer_idx)
  134. self.mlp = PhiMLP(config)
  135. self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  136. self.resid_dropout = nn.Dropout(config.resid_pdrop)
  137. def forward(
  138. self,
  139. hidden_states: torch.Tensor,
  140. attention_mask: torch.Tensor | None = None,
  141. position_ids: torch.LongTensor | None = None,
  142. past_key_values: Cache | None = None,
  143. use_cache: bool | None = False,
  144. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  145. **kwargs: Unpack[TransformersKwargs],
  146. ) -> torch.Tensor:
  147. residual = hidden_states
  148. hidden_states = self.input_layernorm(hidden_states)
  149. attn_outputs, _ = self.self_attn(
  150. hidden_states=hidden_states,
  151. attention_mask=attention_mask,
  152. position_ids=position_ids,
  153. past_key_values=past_key_values,
  154. use_cache=use_cache,
  155. position_embeddings=position_embeddings,
  156. **kwargs,
  157. )
  158. attn_outputs = self.resid_dropout(attn_outputs)
  159. feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
  160. hidden_states = attn_outputs + feed_forward_hidden_states + residual
  161. return hidden_states
class PhiPreTrainedModel(LlamaPreTrainedModel):
    # Associates an output name with the module class it is tied to. The code that
    # consumes this table is not visible in this file (presumably the
    # `capture_outputs` machinery - TODO confirm).
    _can_record_outputs = {
        "hidden_states": PhiDecoderLayer,
        "attentions": PhiAttention,
    }
  167. class PhiModel(LlamaModel):
  168. def __init__(self, config: PhiConfig):
  169. super().__init__(config)
  170. self.layers = nn.ModuleList(
  171. [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  172. )
  173. self.embed_dropout = nn.Dropout(config.embd_pdrop)
  174. self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  175. del self.norm
  176. @merge_with_config_defaults
  177. @capture_outputs
  178. @auto_docstring
  179. def forward(
  180. self,
  181. input_ids: torch.LongTensor | None = None,
  182. attention_mask: torch.Tensor | None = None,
  183. position_ids: torch.LongTensor | None = None,
  184. past_key_values: Cache | None = None,
  185. inputs_embeds: torch.FloatTensor | None = None,
  186. use_cache: bool | None = None,
  187. **kwargs: Unpack[TransformersKwargs],
  188. ) -> BaseModelOutputWithPast:
  189. if (input_ids is None) ^ (inputs_embeds is not None):
  190. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  191. if inputs_embeds is None:
  192. inputs_embeds = self.embed_tokens(input_ids)
  193. if use_cache and past_key_values is None:
  194. past_key_values = DynamicCache(config=self.config)
  195. if position_ids is None:
  196. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  197. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  198. position_ids = position_ids.unsqueeze(0)
  199. causal_mask = create_causal_mask(
  200. config=self.config,
  201. inputs_embeds=inputs_embeds,
  202. attention_mask=attention_mask,
  203. past_key_values=past_key_values,
  204. position_ids=position_ids,
  205. )
  206. inputs_embeds = self.embed_dropout(inputs_embeds)
  207. hidden_states = inputs_embeds
  208. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  209. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  210. hidden_states = decoder_layer(
  211. hidden_states,
  212. attention_mask=causal_mask,
  213. position_ids=position_ids,
  214. past_key_values=past_key_values,
  215. use_cache=use_cache,
  216. position_embeddings=position_embeddings,
  217. **kwargs,
  218. )
  219. hidden_states = self.final_layernorm(hidden_states)
  220. return BaseModelOutputWithPast(
  221. last_hidden_state=hidden_states,
  222. past_key_values=past_key_values,
  223. )
class PhiForCausalLM(LlamaForCausalLM):
    # Causal language-modeling variant of Phi; the only thing changed here is `lm_head`.
    def __init__(self, config: PhiConfig):
        super().__init__(config)
        # Assigned after the parent constructor so that this biased projection from
        # hidden size to vocabulary size is the `lm_head` in effect (presumably
        # replacing a head created by the parent - TODO confirm).
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
# Sequence-classification variant of Phi: `LlamaForSequenceClassification` with nothing overridden.
class PhiForSequenceClassification(LlamaForSequenceClassification):
    pass
# Token-classification variant of Phi: `LlamaForTokenClassification` with nothing overridden.
class PhiForTokenClassification(LlamaForTokenClassification):
    pass
# Public API of this module. `PhiRotaryEmbedding`, `PhiAttention`, `PhiMLP` and
# `PhiDecoderLayer` are defined above but deliberately not listed here.
__all__ = [
    "PhiPreTrainedModel",
    "PhiModel",
    "PhiForCausalLM",
    "PhiForSequenceClassification",
    "PhiForTokenClassification",
]
  238. ]