modeling_phi3.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/phi3/modular_phi3.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_phi3.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from typing import Optional
  22. import torch
  23. from torch import nn
  24. from ...activations import ACT2FN
  25. from ...cache_utils import Cache, DynamicCache
  26. from ...generation import GenerationMixin
  27. from ...integrations import use_kernel_forward_from_hub
  28. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  29. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  30. from ...modeling_layers import (
  31. GenericForSequenceClassification,
  32. GenericForTokenClassification,
  33. GradientCheckpointingLayer,
  34. )
  35. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  36. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  37. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  38. from ...processing_utils import Unpack
  39. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  40. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  41. from ...utils.output_capturing import capture_outputs
  42. from .configuration_phi3 import Phi3Config
  43. class Phi3MLP(nn.Module):
  44. def __init__(self, config):
  45. super().__init__()
  46. self.config = config
  47. self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
  48. self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
  49. self.activation_fn = ACT2FN[config.hidden_act]
  50. def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
  51. up_states = self.gate_up_proj(hidden_states)
  52. gate, up_states = up_states.chunk(2, dim=-1)
  53. up_states = up_states * self.activation_fn(gate)
  54. return self.down_proj(up_states)
  55. class Phi3RotaryEmbedding(nn.Module):
  56. inv_freq: torch.Tensor # fix linting for `register_buffer`
  57. def __init__(self, config: Phi3Config, device=None):
  58. super().__init__()
  59. self.max_seq_len_cached = config.max_position_embeddings
  60. self.original_max_seq_len = config.max_position_embeddings
  61. self.config = config
  62. self.rope_type = self.config.rope_parameters["rope_type"]
  63. rope_init_fn: Callable = self.compute_default_rope_parameters
  64. if self.rope_type != "default":
  65. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  66. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  67. self.register_buffer("inv_freq", inv_freq, persistent=False)
  68. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  69. @staticmethod
  70. def compute_default_rope_parameters(
  71. config: Phi3Config | None = None,
  72. device: Optional["torch.device"] = None,
  73. seq_len: int | None = None,
  74. ) -> tuple["torch.Tensor", float]:
  75. """
  76. Computes the inverse frequencies according to the original RoPE implementation
  77. Args:
  78. config ([`~transformers.PreTrainedConfig`]):
  79. The model configuration.
  80. device (`torch.device`):
  81. The device to use for initialization of the inverse frequencies.
  82. seq_len (`int`, *optional*):
  83. The current sequence length. Unused for this type of RoPE.
  84. Returns:
  85. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  86. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  87. """
  88. base = config.rope_parameters["rope_theta"]
  89. partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
  90. head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  91. dim = int(head_dim * partial_rotary_factor)
  92. attention_factor = 1.0 # Unused in this type of RoPE
  93. # Compute the inverse frequencies
  94. inv_freq = 1.0 / (
  95. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  96. )
  97. return inv_freq, attention_factor
  98. @torch.no_grad()
  99. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  100. def forward(self, x, position_ids):
  101. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  102. position_ids_expanded = position_ids[:, None, :].float()
  103. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  104. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  105. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  106. emb = torch.cat((freqs, freqs), dim=-1)
  107. cos = emb.cos() * self.attention_scaling
  108. sin = emb.sin() * self.attention_scaling
  109. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  110. def rotate_half(x):
  111. """Rotates half the hidden dims of the input."""
  112. x1 = x[..., : x.shape[-1] // 2]
  113. x2 = x[..., x.shape[-1] // 2 :]
  114. return torch.cat((-x2, x1), dim=-1)
  115. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  116. """
  117. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  118. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  119. """
  120. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  121. if n_rep == 1:
  122. return hidden_states
  123. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  124. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  125. def eager_attention_forward(
  126. module: nn.Module,
  127. query: torch.Tensor,
  128. key: torch.Tensor,
  129. value: torch.Tensor,
  130. attention_mask: torch.Tensor | None,
  131. scaling: float,
  132. dropout: float = 0.0,
  133. **kwargs: Unpack[TransformersKwargs],
  134. ):
  135. key_states = repeat_kv(key, module.num_key_value_groups)
  136. value_states = repeat_kv(value, module.num_key_value_groups)
  137. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  138. if attention_mask is not None:
  139. attn_weights = attn_weights + attention_mask
  140. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  141. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  142. attn_output = torch.matmul(attn_weights, value_states)
  143. attn_output = attn_output.transpose(1, 2).contiguous()
  144. return attn_output, attn_weights
  145. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  146. """Applies Rotary Position Embedding to the query and key tensors.
  147. Args:
  148. q (`torch.Tensor`): The query tensor.
  149. k (`torch.Tensor`): The key tensor.
  150. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  151. sin (`torch.Tensor`): The sine part of the rotary embedding.
  152. unsqueeze_dim (`int`, *optional*, defaults to 1):
  153. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  154. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  155. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  156. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  157. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  158. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  159. Returns:
  160. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  161. """
  162. cos = cos.unsqueeze(unsqueeze_dim)
  163. sin = sin.unsqueeze(unsqueeze_dim)
  164. rotary_dim = cos.shape[-1]
  165. q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
  166. k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
  167. q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
  168. k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1)
  169. return q_embed, k_embed
  170. class Phi3Attention(nn.Module):
  171. """Multi-headed attention from 'Attention Is All You Need' paper"""
  172. def __init__(self, config: Phi3Config, layer_idx: int | None = None):
  173. super().__init__()
  174. self.config = config
  175. self.layer_idx = layer_idx
  176. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  177. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  178. self.num_key_value_heads = config.num_key_value_heads
  179. self.scaling = self.head_dim**-0.5
  180. self.attention_dropout = config.attention_dropout
  181. self.is_causal = True
  182. op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim)
  183. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  184. self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False)
  185. def forward(
  186. self,
  187. hidden_states: torch.Tensor,
  188. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  189. attention_mask: torch.Tensor | None,
  190. past_key_values: Cache | None = None,
  191. **kwargs: Unpack[FlashAttentionKwargs],
  192. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  193. input_shape = hidden_states.shape[:-1]
  194. hidden_shape = (*input_shape, -1, self.head_dim)
  195. qkv = self.qkv_proj(hidden_states)
  196. query_pos = self.config.num_attention_heads * self.head_dim
  197. query_states = qkv[..., :query_pos]
  198. key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
  199. value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
  200. query_states = query_states.view(hidden_shape).transpose(1, 2)
  201. key_states = key_states.view(hidden_shape).transpose(1, 2)
  202. value_states = value_states.view(hidden_shape).transpose(1, 2)
  203. cos, sin = position_embeddings
  204. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  205. if past_key_values is not None:
  206. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  207. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  208. self.config._attn_implementation, eager_attention_forward
  209. )
  210. attn_output, attn_weights = attention_interface(
  211. self,
  212. query_states,
  213. key_states,
  214. value_states,
  215. attention_mask,
  216. dropout=0.0 if not self.training else self.attention_dropout,
  217. scaling=self.scaling,
  218. sliding_window=getattr(self.config, "sliding_window", None),
  219. **kwargs,
  220. )
  221. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  222. attn_output = self.o_proj(attn_output)
  223. return attn_output, attn_weights
  224. @use_kernel_forward_from_hub("RMSNorm")
  225. class Phi3RMSNorm(nn.Module):
  226. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  227. """
  228. Phi3RMSNorm is equivalent to T5LayerNorm
  229. """
  230. super().__init__()
  231. self.weight = nn.Parameter(torch.ones(hidden_size))
  232. self.variance_epsilon = eps
  233. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  234. input_dtype = hidden_states.dtype
  235. hidden_states = hidden_states.to(torch.float32)
  236. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  237. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  238. return self.weight * hidden_states.to(input_dtype)
  239. def extra_repr(self):
  240. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  241. class Phi3DecoderLayer(GradientCheckpointingLayer):
  242. def __init__(self, config: Phi3Config, layer_idx: int):
  243. super().__init__()
  244. self.hidden_size = config.hidden_size
  245. self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx)
  246. self.mlp = Phi3MLP(config)
  247. self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  248. self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  249. self.config = config
  250. self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
  251. self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
  252. def forward(
  253. self,
  254. hidden_states: torch.Tensor,
  255. attention_mask: torch.Tensor | None = None,
  256. position_ids: torch.LongTensor | None = None,
  257. past_key_values: Cache | None = None,
  258. use_cache: bool | None = False,
  259. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  260. **kwargs: Unpack[FlashAttentionKwargs],
  261. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  262. residual = hidden_states
  263. hidden_states = self.input_layernorm(hidden_states)
  264. hidden_states, self_attn_weights = self.self_attn(
  265. hidden_states=hidden_states,
  266. attention_mask=attention_mask,
  267. position_ids=position_ids,
  268. past_key_values=past_key_values,
  269. use_cache=use_cache,
  270. position_embeddings=position_embeddings,
  271. **kwargs,
  272. )
  273. hidden_states = residual + self.resid_attn_dropout(hidden_states) # main diff with Llama
  274. residual = hidden_states
  275. hidden_states = self.post_attention_layernorm(hidden_states)
  276. hidden_states = self.mlp(hidden_states)
  277. hidden_states = residual + self.resid_mlp_dropout(hidden_states) # main diff with Llama
  278. return hidden_states
  279. @auto_docstring
  280. class Phi3PreTrainedModel(PreTrainedModel):
  281. config: Phi3Config
  282. base_model_prefix = "model"
  283. supports_gradient_checkpointing = True
  284. _no_split_modules = ["Phi3DecoderLayer"]
  285. _skip_keys_device_placement = ["past_key_values"]
  286. _supports_flash_attn = True
  287. _supports_sdpa = True
  288. _supports_flex_attn = True
  289. _can_compile_fullgraph = True
  290. _supports_attention_backend = True
  291. _can_record_outputs = {
  292. "hidden_states": Phi3DecoderLayer,
  293. "attentions": Phi3Attention,
  294. }
  295. _version = "0.0.5"
  296. @auto_docstring
  297. class Phi3Model(Phi3PreTrainedModel):
  298. def __init__(self, config: Phi3Config):
  299. super().__init__(config)
  300. self.padding_idx = config.pad_token_id
  301. self.vocab_size = config.vocab_size
  302. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  303. self.layers = nn.ModuleList(
  304. [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  305. )
  306. self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  307. self.rotary_emb = Phi3RotaryEmbedding(config=config)
  308. self.gradient_checkpointing = False
  309. # Initialize weights and apply final processing
  310. self.post_init()
  311. @merge_with_config_defaults
  312. @capture_outputs
  313. @auto_docstring
  314. def forward(
  315. self,
  316. input_ids: torch.LongTensor | None = None,
  317. attention_mask: torch.Tensor | None = None,
  318. position_ids: torch.LongTensor | None = None,
  319. past_key_values: Cache | None = None,
  320. inputs_embeds: torch.FloatTensor | None = None,
  321. use_cache: bool | None = None,
  322. **kwargs: Unpack[TransformersKwargs],
  323. ) -> BaseModelOutputWithPast:
  324. if (input_ids is None) ^ (inputs_embeds is not None):
  325. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  326. if inputs_embeds is None:
  327. inputs_embeds = self.embed_tokens(input_ids)
  328. if use_cache and past_key_values is None:
  329. past_key_values = DynamicCache(config=self.config)
  330. if position_ids is None:
  331. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  332. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  333. position_ids = position_ids.unsqueeze(0)
  334. mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
  335. causal_mask = mask_function(
  336. config=self.config,
  337. inputs_embeds=inputs_embeds,
  338. attention_mask=attention_mask,
  339. past_key_values=past_key_values,
  340. position_ids=position_ids,
  341. )
  342. hidden_states = inputs_embeds
  343. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  344. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  345. hidden_states = decoder_layer(
  346. hidden_states,
  347. attention_mask=causal_mask,
  348. position_ids=position_ids,
  349. past_key_values=past_key_values,
  350. use_cache=use_cache,
  351. position_embeddings=position_embeddings,
  352. **kwargs,
  353. )
  354. hidden_states = self.norm(hidden_states)
  355. return BaseModelOutputWithPast(
  356. last_hidden_state=hidden_states,
  357. past_key_values=past_key_values if use_cache else None,
  358. )
  359. @auto_docstring
  360. class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
  361. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  362. _tp_plan = {"lm_head": "colwise_gather_output"}
  363. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  364. def __init__(self, config):
  365. super().__init__(config)
  366. self.model = Phi3Model(config)
  367. self.vocab_size = config.vocab_size
  368. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  369. # Initialize weights and apply final processing
  370. self.post_init()
  371. @can_return_tuple
  372. @auto_docstring
  373. def forward(
  374. self,
  375. input_ids: torch.LongTensor | None = None,
  376. attention_mask: torch.Tensor | None = None,
  377. position_ids: torch.LongTensor | None = None,
  378. past_key_values: Cache | None = None,
  379. inputs_embeds: torch.FloatTensor | None = None,
  380. labels: torch.LongTensor | None = None,
  381. use_cache: bool | None = None,
  382. logits_to_keep: int | torch.Tensor = 0,
  383. **kwargs: Unpack[TransformersKwargs],
  384. ) -> CausalLMOutputWithPast:
  385. r"""
  386. Example:
  387. ```python
  388. >>> from transformers import AutoTokenizer, Phi3ForCausalLM
  389. >>> model = Phi3ForCausalLM.from_pretrained("meta-phi3/Phi3-2-7b-hf")
  390. >>> tokenizer = AutoTokenizer.from_pretrained("meta-phi3/Phi3-2-7b-hf")
  391. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  392. >>> inputs = tokenizer(prompt, return_tensors="pt")
  393. >>> # Generate
  394. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  395. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  396. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  397. ```"""
  398. outputs: BaseModelOutputWithPast = self.model(
  399. input_ids=input_ids,
  400. attention_mask=attention_mask,
  401. position_ids=position_ids,
  402. past_key_values=past_key_values,
  403. inputs_embeds=inputs_embeds,
  404. use_cache=use_cache,
  405. **kwargs,
  406. )
  407. hidden_states = outputs.last_hidden_state
  408. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  409. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  410. logits = self.lm_head(hidden_states[:, slice_indices, :])
  411. loss = None
  412. if labels is not None:
  413. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  414. return CausalLMOutputWithPast(
  415. loss=loss,
  416. logits=logits,
  417. past_key_values=outputs.past_key_values,
  418. hidden_states=outputs.hidden_states,
  419. attentions=outputs.attentions,
  420. )
  421. def prepare_inputs_for_generation(
  422. self,
  423. input_ids,
  424. past_key_values=None,
  425. attention_mask=None,
  426. inputs_embeds=None,
  427. position_ids=None,
  428. use_cache=True,
  429. logits_to_keep=None,
  430. **kwargs,
  431. ):
  432. # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
  433. # process
  434. # When the first time input length reached long and short factor switching point, enforce re-compute cache
  435. # It will cause downside of slower at this single token position, however, better than current failure.
  436. if (
  437. past_key_values
  438. and hasattr(self.config, "original_max_position_embeddings")
  439. and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
  440. ):
  441. past_length = past_key_values.get_seq_length()
  442. if past_length <= self.config.original_max_position_embeddings:
  443. past_key_values = None
  444. model_inputs = super().prepare_inputs_for_generation(
  445. input_ids=input_ids,
  446. past_key_values=past_key_values,
  447. attention_mask=attention_mask,
  448. inputs_embeds=inputs_embeds,
  449. position_ids=position_ids,
  450. use_cache=use_cache,
  451. logits_to_keep=logits_to_keep,
  452. **kwargs,
  453. )
  454. return model_inputs
  455. class Phi3ForSequenceClassification(GenericForSequenceClassification, Phi3PreTrainedModel):
  456. pass
  457. class Phi3ForTokenClassification(GenericForTokenClassification, Phi3PreTrainedModel):
  458. pass
  459. __all__ = [
  460. "Phi3PreTrainedModel",
  461. "Phi3Model",
  462. "Phi3ForCausalLM",
  463. "Phi3ForSequenceClassification",
  464. "Phi3ForTokenClassification",
  465. ]