modeling_phi.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/phi/modular_phi.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_phi.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. from collections.abc import Callable
  8. from typing import Optional
  9. import torch
  10. import torch.nn as nn
  11. from ...activations import ACT2FN
  12. from ...cache_utils import Cache, DynamicCache
  13. from ...generation import GenerationMixin
  14. from ...integrations import use_kernel_func_from_hub, use_kernelized_func
  15. from ...masking_utils import create_causal_mask
  16. from ...modeling_layers import (
  17. GenericForSequenceClassification,
  18. GenericForTokenClassification,
  19. GradientCheckpointingLayer,
  20. )
  21. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  22. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  23. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  24. from ...processing_utils import Unpack
  25. from ...utils import TransformersKwargs, auto_docstring
  26. from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults
  27. from ...utils.output_capturing import capture_outputs
  28. from .configuration_phi import PhiConfig
  29. class PhiRotaryEmbedding(nn.Module):
  30. inv_freq: torch.Tensor # fix linting for `register_buffer`
  31. def __init__(self, config: PhiConfig, device=None):
  32. super().__init__()
  33. self.max_seq_len_cached = config.max_position_embeddings
  34. self.original_max_seq_len = config.max_position_embeddings
  35. self.config = config
  36. self.rope_type = self.config.rope_parameters["rope_type"]
  37. rope_init_fn: Callable = self.compute_default_rope_parameters
  38. if self.rope_type != "default":
  39. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  40. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  41. self.register_buffer("inv_freq", inv_freq, persistent=False)
  42. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  43. @staticmethod
  44. def compute_default_rope_parameters(
  45. config: PhiConfig | None = None,
  46. device: Optional["torch.device"] = None,
  47. seq_len: int | None = None,
  48. ) -> tuple["torch.Tensor", float]:
  49. """
  50. Computes the inverse frequencies according to the original RoPE implementation
  51. Args:
  52. config ([`~transformers.PreTrainedConfig`]):
  53. The model configuration.
  54. device (`torch.device`):
  55. The device to use for initialization of the inverse frequencies.
  56. seq_len (`int`, *optional*):
  57. The current sequence length. Unused for this type of RoPE.
  58. Returns:
  59. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  60. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  61. """
  62. base = config.rope_parameters["rope_theta"]
  63. partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
  64. head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  65. dim = int(head_dim * partial_rotary_factor)
  66. attention_factor = 1.0 # Unused in this type of RoPE
  67. # Compute the inverse frequencies
  68. inv_freq = 1.0 / (
  69. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  70. )
  71. return inv_freq, attention_factor
  72. @torch.no_grad()
  73. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  74. def forward(self, x, position_ids):
  75. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  76. position_ids_expanded = position_ids[:, None, :].float()
  77. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  78. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  79. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  80. emb = torch.cat((freqs, freqs), dim=-1)
  81. cos = emb.cos() * self.attention_scaling
  82. sin = emb.sin() * self.attention_scaling
  83. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  84. def rotate_half(x):
  85. """Rotates half the hidden dims of the input."""
  86. x1 = x[..., : x.shape[-1] // 2]
  87. x2 = x[..., x.shape[-1] // 2 :]
  88. return torch.cat((-x2, x1), dim=-1)
  89. @use_kernel_func_from_hub("rotary_pos_emb")
  90. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  91. """Applies Rotary Position Embedding to the query and key tensors.
  92. Args:
  93. q (`torch.Tensor`): The query tensor.
  94. k (`torch.Tensor`): The key tensor.
  95. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  96. sin (`torch.Tensor`): The sine part of the rotary embedding.
  97. unsqueeze_dim (`int`, *optional*, defaults to 1):
  98. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  99. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  100. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  101. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  102. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  103. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  104. Returns:
  105. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  106. """
  107. cos = cos.unsqueeze(unsqueeze_dim)
  108. sin = sin.unsqueeze(unsqueeze_dim)
  109. q_embed = (q * cos) + (rotate_half(q) * sin)
  110. k_embed = (k * cos) + (rotate_half(k) * sin)
  111. return q_embed, k_embed
  112. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  113. """
  114. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  115. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  116. """
  117. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  118. if n_rep == 1:
  119. return hidden_states
  120. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  121. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  122. def eager_attention_forward(
  123. module: nn.Module,
  124. query: torch.Tensor,
  125. key: torch.Tensor,
  126. value: torch.Tensor,
  127. attention_mask: torch.Tensor | None,
  128. scaling: float,
  129. dropout: float = 0.0,
  130. **kwargs: Unpack[TransformersKwargs],
  131. ):
  132. key_states = repeat_kv(key, module.num_key_value_groups)
  133. value_states = repeat_kv(value, module.num_key_value_groups)
  134. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  135. if attention_mask is not None:
  136. attn_weights = attn_weights + attention_mask
  137. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  138. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  139. attn_output = torch.matmul(attn_weights, value_states)
  140. attn_output = attn_output.transpose(1, 2).contiguous()
  141. return attn_output, attn_weights
  142. @use_kernelized_func(apply_rotary_pos_emb)
  143. class PhiAttention(nn.Module):
  144. """Multi-headed attention from 'Attention Is All You Need' paper"""
  145. def __init__(self, config: PhiConfig, layer_idx: int):
  146. super().__init__()
  147. self.config = config
  148. self.layer_idx = layer_idx
  149. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  150. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  151. self.scaling = self.head_dim**-0.5
  152. self.attention_dropout = config.attention_dropout
  153. self.is_causal = True
  154. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
  155. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
  156. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
  157. self.dense = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True)
  158. self.rotary_ndims = int(self.head_dim * config.rope_parameters["partial_rotary_factor"])
  159. self.qk_layernorm = config.qk_layernorm
  160. if self.qk_layernorm:
  161. self.q_layernorm = nn.LayerNorm(
  162. config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True
  163. )
  164. self.k_layernorm = nn.LayerNorm(
  165. config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True
  166. )
  167. def forward(
  168. self,
  169. hidden_states: torch.Tensor,
  170. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  171. attention_mask: torch.Tensor | None,
  172. past_key_values: Cache | None = None,
  173. **kwargs,
  174. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  175. input_shape = hidden_states.shape[:-1]
  176. hidden_shape = (*input_shape, -1, self.head_dim)
  177. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  178. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  179. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  180. if self.qk_layernorm:
  181. query_states = self.q_layernorm(query_states)
  182. key_states = self.k_layernorm(key_states)
  183. cos, sin = position_embeddings
  184. # Partial rotary embedding
  185. query_rot, query_pass = (
  186. query_states[..., : self.rotary_ndims],
  187. query_states[..., self.rotary_ndims :],
  188. )
  189. key_rot, key_pass = (
  190. key_states[..., : self.rotary_ndims],
  191. key_states[..., self.rotary_ndims :],
  192. )
  193. # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
  194. query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
  195. # [batch_size, seq_length, num_heads, head_dim]
  196. query_states = torch.cat((query_rot, query_pass), dim=-1)
  197. key_states = torch.cat((key_rot, key_pass), dim=-1)
  198. if past_key_values is not None:
  199. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  200. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  201. self.config._attn_implementation, eager_attention_forward
  202. )
  203. attn_output, attn_weights = attention_interface(
  204. self,
  205. query_states,
  206. key_states,
  207. value_states,
  208. attention_mask,
  209. dropout=0.0 if not self.training else self.attention_dropout,
  210. scaling=self.scaling,
  211. **kwargs,
  212. )
  213. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  214. attn_output = self.dense(attn_output)
  215. return attn_output, attn_weights
  216. class PhiMLP(nn.Module):
  217. def __init__(self, config):
  218. super().__init__()
  219. self.config = config
  220. self.activation_fn = ACT2FN[config.hidden_act]
  221. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  222. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  223. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  224. hidden_states = self.fc1(hidden_states)
  225. hidden_states = self.activation_fn(hidden_states)
  226. hidden_states = self.fc2(hidden_states)
  227. return hidden_states
  228. class PhiDecoderLayer(GradientCheckpointingLayer):
  229. def __init__(self, config: PhiConfig, layer_idx: int):
  230. super().__init__()
  231. self.self_attn = PhiAttention(config, layer_idx=layer_idx)
  232. self.mlp = PhiMLP(config)
  233. self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  234. self.resid_dropout = nn.Dropout(config.resid_pdrop)
  235. def forward(
  236. self,
  237. hidden_states: torch.Tensor,
  238. attention_mask: torch.Tensor | None = None,
  239. position_ids: torch.LongTensor | None = None,
  240. past_key_values: Cache | None = None,
  241. use_cache: bool | None = False,
  242. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  243. **kwargs: Unpack[TransformersKwargs],
  244. ) -> torch.Tensor:
  245. residual = hidden_states
  246. hidden_states = self.input_layernorm(hidden_states)
  247. attn_outputs, _ = self.self_attn(
  248. hidden_states=hidden_states,
  249. attention_mask=attention_mask,
  250. position_ids=position_ids,
  251. past_key_values=past_key_values,
  252. use_cache=use_cache,
  253. position_embeddings=position_embeddings,
  254. **kwargs,
  255. )
  256. attn_outputs = self.resid_dropout(attn_outputs)
  257. feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
  258. hidden_states = attn_outputs + feed_forward_hidden_states + residual
  259. return hidden_states
  260. @auto_docstring
  261. class PhiPreTrainedModel(PreTrainedModel):
  262. config: PhiConfig
  263. base_model_prefix = "model"
  264. supports_gradient_checkpointing = True
  265. _no_split_modules = ["PhiDecoderLayer"]
  266. _skip_keys_device_placement = ["past_key_values"]
  267. _supports_flash_attn = True
  268. _supports_sdpa = True
  269. _supports_flex_attn = True
  270. _can_compile_fullgraph = True
  271. _supports_attention_backend = True
  272. _can_record_outputs = {
  273. "hidden_states": PhiDecoderLayer,
  274. "attentions": PhiAttention,
  275. }
  276. @auto_docstring
  277. class PhiModel(PhiPreTrainedModel):
  278. def __init__(self, config: PhiConfig):
  279. super().__init__(config)
  280. self.padding_idx = config.pad_token_id
  281. self.vocab_size = config.vocab_size
  282. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  283. self.layers = nn.ModuleList(
  284. [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  285. )
  286. self.rotary_emb = PhiRotaryEmbedding(config=config)
  287. self.gradient_checkpointing = False
  288. self.embed_dropout = nn.Dropout(config.embd_pdrop)
  289. self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  290. # Initialize weights and apply final processing
  291. self.post_init()
  292. @merge_with_config_defaults
  293. @capture_outputs
  294. @auto_docstring
  295. def forward(
  296. self,
  297. input_ids: torch.LongTensor | None = None,
  298. attention_mask: torch.Tensor | None = None,
  299. position_ids: torch.LongTensor | None = None,
  300. past_key_values: Cache | None = None,
  301. inputs_embeds: torch.FloatTensor | None = None,
  302. use_cache: bool | None = None,
  303. **kwargs: Unpack[TransformersKwargs],
  304. ) -> BaseModelOutputWithPast:
  305. if (input_ids is None) ^ (inputs_embeds is not None):
  306. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  307. if inputs_embeds is None:
  308. inputs_embeds = self.embed_tokens(input_ids)
  309. if use_cache and past_key_values is None:
  310. past_key_values = DynamicCache(config=self.config)
  311. if position_ids is None:
  312. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  313. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  314. position_ids = position_ids.unsqueeze(0)
  315. causal_mask = create_causal_mask(
  316. config=self.config,
  317. inputs_embeds=inputs_embeds,
  318. attention_mask=attention_mask,
  319. past_key_values=past_key_values,
  320. position_ids=position_ids,
  321. )
  322. inputs_embeds = self.embed_dropout(inputs_embeds)
  323. hidden_states = inputs_embeds
  324. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  325. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  326. hidden_states = decoder_layer(
  327. hidden_states,
  328. attention_mask=causal_mask,
  329. position_ids=position_ids,
  330. past_key_values=past_key_values,
  331. use_cache=use_cache,
  332. position_embeddings=position_embeddings,
  333. **kwargs,
  334. )
  335. hidden_states = self.final_layernorm(hidden_states)
  336. return BaseModelOutputWithPast(
  337. last_hidden_state=hidden_states,
  338. past_key_values=past_key_values,
  339. )
  340. @auto_docstring
  341. class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin):
  342. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  343. _tp_plan = {"lm_head": "colwise_gather_output"}
  344. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  345. def __init__(self, config):
  346. super().__init__(config)
  347. self.model = PhiModel(config)
  348. self.vocab_size = config.vocab_size
  349. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
  350. # Initialize weights and apply final processing
  351. self.post_init()
  352. @can_return_tuple
  353. @auto_docstring
  354. def forward(
  355. self,
  356. input_ids: torch.LongTensor | None = None,
  357. attention_mask: torch.Tensor | None = None,
  358. position_ids: torch.LongTensor | None = None,
  359. past_key_values: Cache | None = None,
  360. inputs_embeds: torch.FloatTensor | None = None,
  361. labels: torch.LongTensor | None = None,
  362. use_cache: bool | None = None,
  363. logits_to_keep: int | torch.Tensor = 0,
  364. **kwargs: Unpack[TransformersKwargs],
  365. ) -> CausalLMOutputWithPast:
  366. r"""
  367. Example:
  368. ```python
  369. >>> from transformers import AutoTokenizer, PhiForCausalLM
  370. >>> model = PhiForCausalLM.from_pretrained("meta-phi/Phi-2-7b-hf")
  371. >>> tokenizer = AutoTokenizer.from_pretrained("meta-phi/Phi-2-7b-hf")
  372. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  373. >>> inputs = tokenizer(prompt, return_tensors="pt")
  374. >>> # Generate
  375. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  376. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  377. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  378. ```"""
  379. outputs: BaseModelOutputWithPast = self.model(
  380. input_ids=input_ids,
  381. attention_mask=attention_mask,
  382. position_ids=position_ids,
  383. past_key_values=past_key_values,
  384. inputs_embeds=inputs_embeds,
  385. use_cache=use_cache,
  386. **kwargs,
  387. )
  388. hidden_states = outputs.last_hidden_state
  389. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  390. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  391. logits = self.lm_head(hidden_states[:, slice_indices, :])
  392. loss = None
  393. if labels is not None:
  394. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  395. return CausalLMOutputWithPast(
  396. loss=loss,
  397. logits=logits,
  398. past_key_values=outputs.past_key_values,
  399. hidden_states=outputs.hidden_states,
  400. attentions=outputs.attentions,
  401. )
  402. class PhiForSequenceClassification(GenericForSequenceClassification, PhiPreTrainedModel):
  403. pass
  404. class PhiForTokenClassification(GenericForTokenClassification, PhiPreTrainedModel):
  405. pass
  406. __all__ = [
  407. "PhiPreTrainedModel",
  408. "PhiModel",
  409. "PhiForCausalLM",
  410. "PhiForSequenceClassification",
  411. "PhiForTokenClassification",
  412. ]