# modeling_phimoe.py (38 KB)
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/phimoe/modular_phimoe.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_phimoe.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from typing import Optional
  22. import torch
  23. from torch import nn
  24. from ... import initialization as init
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...integrations import use_experts_implementation, use_kernel_func_from_hub, use_kernelized_func
  29. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  30. from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
  31. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  32. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  33. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  34. from ...processing_utils import Unpack
  35. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  36. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  37. from ...utils.output_capturing import OutputRecorder, capture_outputs
  38. from .configuration_phimoe import PhimoeConfig
class PhimoeRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) tables for Phimoe attention.

    The RoPE variant comes from `config.rope_parameters["rope_type"]`: "default" uses
    `compute_default_rope_parameters`, any other type is looked up in `ROPE_INIT_FUNCTIONS`.
    `forward` returns the `(cos, sin)` pair that `apply_rotary_pos_emb` applies to queries and keys.
    For non-default types, cos/sin are scaled by `long_mscale` or `short_mscale` depending on whether
    the requested positions exceed `original_max_position_embeddings`.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: PhimoeConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        self.rope_type = self.config.rope_parameters["rope_type"]
        self.rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)

        # Non-persistent: neither buffer is written to the state dict. `original_inv_freq` is an untouched copy.
        # NOTE(review): presumably read by `dynamic_rope_update` to restore `inv_freq` -- confirm.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: PhimoeConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation

        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.

        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        # `head_dim` may be missing or None on the config; fall back to hidden_size // num_attention_heads
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies: 1 / base ** (k / dim) for k = 0, 2, 4, ... < dim
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids=None, layer_type=None):
        """Compute the rotary cos/sin tables for `position_ids`.

        Args:
            x (`torch.Tensor`): Only its device and dtype are used.
            position_ids (`torch.Tensor`): Position indices. They are indexed as `position_ids[:, None, :]`, so a
                2-D `(batch, seq_len)` tensor is expected. Required despite the `None` default.
            layer_type: Not supported; must be None.

        Returns:
            `tuple[torch.Tensor, torch.Tensor]`: `(cos, sin)` of shape `(batch, seq_len, 2 * len(inv_freq))`,
            computed in float32 and cast to `x.dtype`.

        Raises:
            ValueError: If `layer_type` is not None.
        """
        if layer_type is not None:
            raise ValueError(
                f"{self.__class__.__name__} does not support layer types, but got `layer_type={layer_type}`"
            )

        mscale = None
        # Highest requested position + 1 (a 0-d tensor); selects the long vs. short mscale below
        seq_len = torch.max(position_ids) + 1
        if self.config.rope_parameters["rope_type"] != "default" and seq_len:
            mscale = (
                self.config.rope_parameters["long_mscale"]
                if seq_len > self.config.rope_parameters["original_max_position_embeddings"]
                else self.config.rope_parameters["short_mscale"]
            )
        # inv_freq is recomputed on every call (the `inv_freq` buffer is not read in this method body);
        # seq_len is passed through, and the default init function ignores it.
        inv_freq, attention_scaling = self.rope_init_fn(self.config, x.device, seq_len)
        # When no mscale was chosen above (e.g. default RoPE type), use the init function's attention scaling
        mscale = attention_scaling if mscale is None else mscale
        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            # (batch, dim/2, 1) @ (batch, 1, seq_len) -> (batch, dim/2, seq_len) -> (batch, seq_len, dim/2)
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * mscale
            sin = emb.sin() * mscale

        return cos.to(x.dtype), sin.to(x.dtype)
  106. def rotate_half(x):
  107. """Rotates half the hidden dims of the input."""
  108. x1 = x[..., : x.shape[-1] // 2]
  109. x2 = x[..., x.shape[-1] // 2 :]
  110. return torch.cat((-x2, x1), dim=-1)
  111. @use_kernel_func_from_hub("rotary_pos_emb")
  112. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  113. """Applies Rotary Position Embedding to the query and key tensors.
  114. Args:
  115. q (`torch.Tensor`): The query tensor.
  116. k (`torch.Tensor`): The key tensor.
  117. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  118. sin (`torch.Tensor`): The sine part of the rotary embedding.
  119. unsqueeze_dim (`int`, *optional*, defaults to 1):
  120. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  121. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  122. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  123. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  124. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  125. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  126. Returns:
  127. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  128. """
  129. cos = cos.unsqueeze(unsqueeze_dim)
  130. sin = sin.unsqueeze(unsqueeze_dim)
  131. q_embed = (q * cos) + (rotate_half(q) * sin)
  132. k_embed = (k * cos) + (rotate_half(k) * sin)
  133. return q_embed, k_embed
  134. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  135. """
  136. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  137. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  138. """
  139. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  140. if n_rep == 1:
  141. return hidden_states
  142. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  143. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  144. def eager_attention_forward(
  145. module: nn.Module,
  146. query: torch.Tensor,
  147. key: torch.Tensor,
  148. value: torch.Tensor,
  149. attention_mask: torch.Tensor | None,
  150. scaling: float,
  151. dropout: float = 0.0,
  152. **kwargs: Unpack[TransformersKwargs],
  153. ):
  154. key_states = repeat_kv(key, module.num_key_value_groups)
  155. value_states = repeat_kv(value, module.num_key_value_groups)
  156. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  157. if attention_mask is not None:
  158. attn_weights = attn_weights + attention_mask
  159. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  160. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  161. attn_output = torch.matmul(attn_weights, value_states)
  162. attn_output = attn_output.transpose(1, 2).contiguous()
  163. return attn_output, attn_weights
  164. @use_kernelized_func(apply_rotary_pos_emb)
  165. class PhimoeAttention(nn.Module):
  166. """Multi-headed attention from 'Attention Is All You Need' paper"""
  167. def __init__(self, config: PhimoeConfig, layer_idx: int):
  168. super().__init__()
  169. self.config = config
  170. self.layer_idx = layer_idx
  171. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  172. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  173. self.scaling = self.head_dim**-0.5
  174. self.attention_dropout = config.attention_dropout
  175. self.is_causal = True
  176. self.q_proj = nn.Linear(
  177. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  178. )
  179. self.k_proj = nn.Linear(
  180. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  181. )
  182. self.v_proj = nn.Linear(
  183. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  184. )
  185. self.o_proj = nn.Linear(
  186. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  187. )
  188. def forward(
  189. self,
  190. hidden_states: torch.Tensor,
  191. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  192. attention_mask: torch.Tensor | None = None,
  193. past_key_values: Cache | None = None,
  194. **kwargs: Unpack[TransformersKwargs],
  195. ) -> tuple[torch.Tensor, torch.Tensor]:
  196. input_shape = hidden_states.shape[:-1]
  197. hidden_shape = (*input_shape, -1, self.head_dim)
  198. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  199. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  200. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  201. cos, sin = position_embeddings
  202. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  203. if past_key_values is not None:
  204. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  205. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  206. self.config._attn_implementation, eager_attention_forward
  207. )
  208. attn_output, attn_weights = attention_interface(
  209. self,
  210. query_states,
  211. key_states,
  212. value_states,
  213. attention_mask,
  214. dropout=0.0 if not self.training else self.attention_dropout,
  215. scaling=self.scaling,
  216. **kwargs,
  217. )
  218. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  219. attn_output = self.o_proj(attn_output)
  220. return attn_output, attn_weights
  221. class PhimoeMultiplier(torch.autograd.Function):
  222. @staticmethod
  223. def forward(
  224. ctx,
  225. scores: torch.Tensor,
  226. multiplier: torch.Tensor,
  227. selected_experts: torch.Tensor,
  228. masked_gates: torch.Tensor,
  229. mask_for_one: torch.Tensor,
  230. ):
  231. """
  232. Forward pass for the custom autograd function.
  233. Args:
  234. ctx: Context object to save information for backward computation.
  235. scores (torch.Tensor): Input scores tensor.
  236. multiplier (torch.Tensor): Multiplier tensor.
  237. selected_experts (torch.Tensor): Tensor of selected experts.
  238. masked_gates (torch.Tensor): Masked gates tensor.
  239. mask_for_one (torch.Tensor): Mask for one tensor.
  240. Returns:
  241. torch.Tensor: Result of the forward pass.
  242. """
  243. ctx.save_for_backward(multiplier, selected_experts, masked_gates)
  244. return multiplier * mask_for_one
  245. @staticmethod
  246. def backward(
  247. ctx,
  248. grad_at_output: torch.Tensor,
  249. ):
  250. """
  251. Backward pass for the custom autograd function.
  252. Args:
  253. ctx: Context object with saved tensors from the forward pass.
  254. grad_at_output (torch.Tensor): Gradient at the output.
  255. Returns:
  256. tuple[torch.Tensor, None, None, None, None]: Gradients for the inputs.
  257. """
  258. multiplier, selected_experts, masked_gates = ctx.saved_tensors
  259. grad_at_output = grad_at_output * multiplier
  260. grad_at_scores_expanded = masked_gates * grad_at_output.mul(-1)
  261. grad_at_scores_expanded.scatter_add_(
  262. dim=-1,
  263. index=selected_experts,
  264. src=grad_at_output,
  265. )
  266. return (
  267. grad_at_scores_expanded,
  268. None,
  269. None,
  270. None,
  271. None,
  272. )
@use_experts_implementation
class PhimoeExperts(nn.Module):
    """Collection of expert weights stored as 3D tensors.

    Each expert is a gated MLP: `down_proj(act_fn(gate) * up)`, where `gate` and `up` are the two halves of the
    `gate_up_proj` output. Expert weights are stacked along dim 0:
    `gate_up_proj` is `(num_experts, 2 * intermediate_size, hidden_size)` and
    `down_proj` is `(num_experts, hidden_size, intermediate_size)`.
    """

    def __init__(self, config: PhimoeConfig):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.intermediate_size
        # Allocated uninitialized (torch.empty); values are filled in by PhimoePreTrainedModel._init_weights
        # or when a checkpoint is loaded.
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Reference loop over the experts that received at least one token.

        Args:
            hidden_states: Token features indexed as `hidden_states[token_idx]`, i.e. `(num_tokens, hidden_size)`
                (the MoE block flattens batch and sequence before calling).
            top_k_index: `(num_tokens, top_k)` expert ids chosen for each token.
            top_k_weights: `(num_tokens, top_k)` weights multiplied into the matching expert outputs.

        Returns:
            Tensor shaped like `hidden_states`: for each token, the weighted sum of its experts' outputs.
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            # one_hot: (num_tokens, top_k, num_experts) -> permute -> (num_experts, top_k, num_tokens)
            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            # Indices of experts selected by at least one (token, slot) pair; skip idle experts entirely
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            # NOTE(review): one_hot above uses num_classes=self.num_experts, so an index equal to num_experts
            # cannot reach this point via this path; the guard looks defensive -- confirm intended use.
            if expert_idx == self.num_experts:
                continue
            # Which top-k slot and which token routed to this expert
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
            current_hidden_states = self.act_fn(gate) * up
            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
            # Accumulate (a token can receive contributions from several experts)
            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

        return final_hidden_states
  307. def sparsemixer(scores, jitter_eps, training, top_k=2):
  308. """
  309. Sparse mixer function to select top-k experts and compute multipliers.
  310. Based on the paper: https://huggingface.co/papers/2409.12136
  311. We first replace the TopK(·) function as random sampling of discrete variables
  312. in model training. Then, following Liu et al. (2023a) and Liu et al. (2023b), we apply Heun's
  313. third order method to approximate the expert routing gradient and construct a modified
  314. back-propagation to give a mathematically sound gradient estimation for expert routing.
  315. Args:
  316. scores (torch.Tensor): Input scores tensor.
  317. jitter_eps (float): Jitter epsilon for numerical stability.
  318. training (bool): Flag indicating if the model is in training mode.
  319. top_k (int): Number of top experts to select.
  320. Returns:
  321. tuple[torch.Tensor, torch.Tensor]: Multiplier and selected experts tensors.
  322. """
  323. with torch.no_grad():
  324. # Compute mask for sparsity
  325. mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True)
  326. factor = scores.abs().clamp(min=mask_logits_threshold)
  327. mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (2 * jitter_eps)
  328. # Apply mask
  329. masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf"))
  330. if training:
  331. selected_experts = (
  332. (
  333. masked_gates
  334. - torch.empty_like(masked_gates, memory_format=torch.legacy_contiguous_format).exponential_().log()
  335. )
  336. .max(dim=-1)[1]
  337. .unsqueeze(-1)
  338. ) # Gumbel sampling, more robust than the multinomial method
  339. else:
  340. selected_experts = max_ind
  341. # Compute scores for gradients
  342. masked_gates = torch.softmax(masked_gates, dim=-1)
  343. multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)
  344. if training:
  345. # Compute midpoint mask
  346. max_scores, max_ind = masked_gates.max(dim=-1, keepdim=True)
  347. mask_for_one = torch.logical_or(
  348. selected_experts == max_ind,
  349. torch.rand_like(max_scores) > 0.75, # Heun's third-order method
  350. )
  351. # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5
  352. mask_for_one = torch.add(0.3333, mask_for_one, alpha=0.6667).type_as(masked_gates)
  353. multiplier = PhimoeMultiplier.apply(
  354. scores,
  355. multiplier_o,
  356. selected_experts,
  357. masked_gates,
  358. mask_for_one,
  359. )
  360. else:
  361. multiplier = multiplier_o
  362. # Masked out first expert
  363. masked_scores = torch.scatter(
  364. scores,
  365. -1,
  366. selected_experts,
  367. float("-inf"),
  368. )
  369. with torch.no_grad():
  370. # Compute mask for sparsity
  371. mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True)
  372. factor = scores.abs().clamp(min=mask_logits_threshold)
  373. mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (2 * jitter_eps)
  374. # Apply mask
  375. masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, float("-inf"))
  376. if training:
  377. selected_experts_top2 = (
  378. (
  379. masked_gates_top2
  380. - torch.empty_like(masked_gates_top2, memory_format=torch.legacy_contiguous_format)
  381. .exponential_()
  382. .log()
  383. )
  384. .max(dim=-1)[1]
  385. .unsqueeze(-1)
  386. ) # Gumbel sampling, more robust than the multinomial method
  387. else:
  388. selected_experts_top2 = max_ind
  389. # Compute scores for gradients
  390. masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1)
  391. multiplier_top2_o = masked_gates_top2.gather(dim=-1, index=selected_experts_top2)
  392. if training:
  393. # Compute midpoint mask
  394. max_scores, max_ind = masked_gates_top2.max(dim=-1, keepdim=True)
  395. mask_for_one_top2 = torch.logical_or(
  396. selected_experts_top2 == max_ind,
  397. torch.rand_like(max_scores).uniform_() > 0.75, # Heun's third-order method
  398. )
  399. # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5
  400. mask_for_one_top2 = torch.add(0.3333, mask_for_one_top2, alpha=0.6667).type_as(masked_gates_top2)
  401. multiplier_top2 = PhimoeMultiplier.apply(
  402. scores,
  403. multiplier_top2_o,
  404. selected_experts_top2,
  405. masked_gates_top2,
  406. mask_for_one_top2,
  407. )
  408. else:
  409. multiplier_top2 = multiplier_top2_o
  410. multiplier = torch.concat((multiplier, multiplier_top2), dim=-1)
  411. selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1)
  412. return (
  413. multiplier,
  414. selected_experts,
  415. )
  416. class PhimoeTopKRouter(nn.Linear):
  417. def __init__(self, config: PhimoeConfig):
  418. super().__init__(config.hidden_size, config.num_local_experts, bias=False)
  419. self.router_jitter_noise = config.router_jitter_noise
  420. self.input_jitter_noise = config.input_jitter_noise
  421. self.top_k = config.num_experts_per_tok
  422. def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  423. if self.training and self.input_jitter_noise > 0:
  424. hidden_states *= torch.empty_like(hidden_states).uniform_(
  425. 1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise
  426. )
  427. router_logits = super().forward(hidden_states)
  428. routing_weights, selected_experts = sparsemixer(
  429. router_logits, jitter_eps=self.router_jitter_noise, training=self.training, top_k=self.top_k
  430. )
  431. return router_logits, routing_weights, selected_experts
  432. class PhimoeSparseMoeBlock(nn.Module):
  433. """
  434. This implementation is
  435. strictly equivalent to standard MoE with full capacity (no
  436. dropped tokens). It's faster since it formulates MoE operations
  437. in terms of block-sparse operations to accommodate imbalanced
  438. assignments of tokens to experts, whereas standard MoE either
  439. (1) drop tokens at the cost of reduced performance or (2) set
  440. capacity factor to number of experts and thus waste computation
  441. and memory on padding.
  442. """
  443. def __init__(self, config):
  444. super().__init__()
  445. self.hidden_dim = config.hidden_size
  446. self.ffn_dim = config.intermediate_size
  447. self.num_experts = config.num_local_experts
  448. self.top_k = config.num_experts_per_tok
  449. self.router = PhimoeTopKRouter(config)
  450. self.experts = PhimoeExperts(config)
  451. self.input_jitter_noise = config.input_jitter_noise
  452. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  453. batch_size, sequence_length, hidden_dim = hidden_states.shape
  454. if self.training and self.input_jitter_noise > 0:
  455. hidden_states *= torch.empty_like(hidden_states).uniform_(
  456. 1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise
  457. )
  458. batch_size, sequence_length, hidden_dim = hidden_states.shape
  459. hidden_states = hidden_states.reshape(-1, hidden_dim)
  460. _, routing_weights, selected_experts = self.router(hidden_states)
  461. final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
  462. return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
  463. class PhimoeDecoderLayer(GradientCheckpointingLayer):
  464. def __init__(self, config: PhimoeConfig, layer_idx: int):
  465. super().__init__()
  466. self.hidden_size = config.hidden_size
  467. self.self_attn = PhimoeAttention(config, layer_idx)
  468. self.mlp = PhimoeSparseMoeBlock(config)
  469. # Phimoe uses nn.LayerNorm
  470. self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True)
  471. self.post_attention_layernorm = nn.LayerNorm(
  472. config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True
  473. )
  474. def forward(
  475. self,
  476. hidden_states: torch.Tensor,
  477. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  478. attention_mask: torch.Tensor | None = None,
  479. position_ids: torch.LongTensor | None = None,
  480. past_key_values: Cache | None = None,
  481. **kwargs: Unpack[TransformersKwargs],
  482. ) -> torch.Tensor:
  483. residual = hidden_states
  484. hidden_states = self.input_layernorm(hidden_states)
  485. hidden_states, _ = self.self_attn(
  486. hidden_states=hidden_states,
  487. position_embeddings=position_embeddings,
  488. attention_mask=attention_mask,
  489. position_ids=position_ids,
  490. past_key_values=past_key_values,
  491. **kwargs,
  492. )
  493. hidden_states = residual + hidden_states
  494. residual = hidden_states
  495. hidden_states = self.post_attention_layernorm(hidden_states)
  496. hidden_states = self.mlp(hidden_states)
  497. hidden_states = residual + hidden_states
  498. return hidden_states
@auto_docstring
class PhimoePreTrainedModel(PreTrainedModel):
    config: PhimoeConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Each decoder layer must stay on a single device when the model is sharded
    _no_split_modules = ["PhimoeDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Module outputs recorded for the model output: router logits are element 0 of the tuple returned by
    # PhimoeTopKRouter.forward; hidden states / attentions come from each decoder layer / attention module.
    _can_record_outputs = {
        "router_logits": OutputRecorder(PhimoeTopKRouter, index=0),
        "hidden_states": PhimoeDecoderLayer,
        "attentions": PhimoeAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize `module`'s parameters.

        Runs the base-class initialization, then draws the stacked expert weights (allocated with
        `torch.empty` in PhimoeExperts) and the router weight from a normal distribution with mean 0 and
        standard deviation `config.initializer_range`.
        """
        super()._init_weights(module)
        std = self.config.initializer_range
        if isinstance(module, PhimoeExperts):
            init.normal_(module.gate_up_proj, mean=0.0, std=std)
            init.normal_(module.down_proj, mean=0.0, std=std)
        elif isinstance(module, PhimoeTopKRouter):
            init.normal_(module.weight, mean=0.0, std=std)
@auto_docstring
class PhimoeModel(PhimoePreTrainedModel):
    def __init__(self, config: PhimoeConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [PhimoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # Final norm is an affine LayerNorm like the decoder layers' norms; eps is read from rms_norm_eps
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True)
        # One rotary module for the whole stack: cos/sin are computed once per forward and shared by all layers
        self.rotary_emb = PhimoeRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        # Exactly one of input_ids / inputs_embeds must be provided
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # Create a cache on demand when caching is requested but none was passed in
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if position_ids is None:
            # Default positions continue right after the tokens already held in the cache; shape (1, seq_len)
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0)

        # Sliding-window causal mask when config.sliding_window is set, plain causal mask otherwise
        mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
        causal_mask = mask_function(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        # cos/sin tables computed once and passed to every decoder layer
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

        # The slice bounds the loop by config.num_hidden_layers (the list was built with that count in __init__)
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        # Per-layer hidden states / attentions / router logits are not assembled here.
        # NOTE(review): presumably `capture_outputs` collects them via `_can_record_outputs` -- confirm.
        return MoeModelOutputWithPast(  # only diff with Mistral is the output type, we need MoE
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
  588. def load_balancing_loss_func(
  589. gate_logits: torch.Tensor | tuple[torch.Tensor] | None,
  590. num_experts: int | None = None,
  591. top_k=2,
  592. attention_mask: torch.Tensor | None = None,
  593. ) -> torch.Tensor | int:
  594. r"""
  595. Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
  596. See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
  597. function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
  598. experts is too unbalanced.
  599. Args:
  600. gate_logits:
  601. Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
  602. shape [batch_size X sequence_length, num_experts].
  603. num_experts:
  604. Number of experts
  605. top_k:
  606. The number of experts to route per-token, can be also interpreted as the `top-k` routing
  607. parameter.
  608. attention_mask (`torch.Tensor`, *optional*):
  609. The attention_mask used in forward function
  610. shape [batch_size X sequence_length] if not None.
  611. Returns:
  612. The auxiliary loss.
  613. """
  614. if gate_logits is None or not isinstance(gate_logits, tuple):
  615. return 0
  616. if isinstance(gate_logits, tuple):
  617. compute_device = gate_logits[0].device
  618. concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
  619. routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
  620. _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
  621. expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
  622. if attention_mask is None:
  623. # Compute the percentage of tokens routed to each experts
  624. tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
  625. # Compute the average probability of routing to these experts
  626. router_prob_per_expert = torch.mean(routing_weights, dim=0)
  627. else:
  628. batch_size, sequence_length = attention_mask.shape
  629. num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
  630. # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
  631. expert_attention_mask = (
  632. attention_mask[None, :, :, None, None]
  633. .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
  634. .reshape(-1, top_k, num_experts)
  635. .to(compute_device)
  636. )
  637. # Compute the percentage of tokens routed to each experts
  638. tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
  639. expert_attention_mask, dim=0
  640. )
  641. # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
  642. router_per_expert_attention_mask = (
  643. attention_mask[None, :, :, None]
  644. .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
  645. .reshape(-1, num_experts)
  646. .to(compute_device)
  647. )
  648. # Compute the average probability of routing to these experts
  649. router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
  650. router_per_expert_attention_mask, dim=0
  651. )
  652. overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
  653. return overall_loss * num_experts
  654. @auto_docstring
  655. class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin):
  656. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  657. _tp_plan = {"lm_head": "colwise_gather_output"}
  658. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  659. def __init__(self, config):
  660. super().__init__(config)
  661. self.model = PhimoeModel(config)
  662. self.vocab_size = config.vocab_size
  663. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias)
  664. self.router_aux_loss_coef = config.router_aux_loss_coef
  665. self.num_experts = config.num_local_experts
  666. self.num_experts_per_tok = config.num_experts_per_tok
  667. # Initialize weights and apply final processing
  668. self.post_init()
  669. @can_return_tuple
  670. @auto_docstring
  671. def forward(
  672. self,
  673. input_ids: torch.LongTensor | None = None,
  674. attention_mask: torch.Tensor | None = None,
  675. position_ids: torch.LongTensor | None = None,
  676. past_key_values: Cache | None = None,
  677. inputs_embeds: torch.FloatTensor | None = None,
  678. labels: torch.LongTensor | None = None,
  679. use_cache: bool | None = None,
  680. output_router_logits: bool | None = None,
  681. logits_to_keep: int | torch.Tensor = 0,
  682. **kwargs: Unpack[TransformersKwargs],
  683. ) -> MoeCausalLMOutputWithPast:
  684. r"""
  685. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  686. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  687. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  688. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  689. Example:
  690. ```python
  691. >>> from transformers import AutoTokenizer, PhimoeForCausalLM
  692. >>> model = PhimoeForCausalLM.from_pretrained("mistralai/Phimoe-8x7B-v0.1")
  693. >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Phimoe-8x7B-v0.1")
  694. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  695. >>> inputs = tokenizer(prompt, return_tensors="pt")
  696. >>> # Generate
  697. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  698. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  699. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  700. ```"""
  701. output_router_logits = (
  702. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  703. )
  704. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  705. outputs: MoeModelOutputWithPast = self.model(
  706. input_ids=input_ids,
  707. attention_mask=attention_mask,
  708. position_ids=position_ids,
  709. past_key_values=past_key_values,
  710. inputs_embeds=inputs_embeds,
  711. use_cache=use_cache,
  712. output_router_logits=output_router_logits,
  713. **kwargs,
  714. )
  715. hidden_states = outputs.last_hidden_state
  716. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  717. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  718. logits = self.lm_head(hidden_states[:, slice_indices, :])
  719. loss = None
  720. if labels is not None:
  721. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  722. aux_loss = None
  723. if output_router_logits:
  724. aux_loss = load_balancing_loss_func(
  725. outputs.router_logits,
  726. self.num_experts,
  727. self.num_experts_per_tok,
  728. attention_mask,
  729. )
  730. if labels is not None:
  731. loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
  732. return MoeCausalLMOutputWithPast(
  733. loss=loss,
  734. aux_loss=aux_loss,
  735. logits=logits,
  736. past_key_values=outputs.past_key_values,
  737. hidden_states=outputs.hidden_states,
  738. attentions=outputs.attentions,
  739. router_logits=outputs.router_logits,
  740. )
  741. # Copied from transformers.models.phi3.modeling_phi3.Phi3ForCausalLM.prepare_inputs_for_generation
  742. def prepare_inputs_for_generation(
  743. self,
  744. input_ids,
  745. past_key_values=None,
  746. attention_mask=None,
  747. inputs_embeds=None,
  748. position_ids=None,
  749. use_cache=True,
  750. logits_to_keep=None,
  751. **kwargs,
  752. ):
  753. # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
  754. # process
  755. # When the first time input length reached long and short factor switching point, enforce re-compute cache
  756. # It will cause downside of slower at this single token position, however, better than current failure.
  757. if (
  758. past_key_values
  759. and hasattr(self.config, "original_max_position_embeddings")
  760. and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
  761. ):
  762. past_length = past_key_values.get_seq_length()
  763. if past_length <= self.config.original_max_position_embeddings:
  764. past_key_values = None
  765. model_inputs = super().prepare_inputs_for_generation(
  766. input_ids=input_ids,
  767. past_key_values=past_key_values,
  768. attention_mask=attention_mask,
  769. inputs_embeds=inputs_embeds,
  770. position_ids=position_ids,
  771. use_cache=use_cache,
  772. logits_to_keep=logits_to_keep,
  773. **kwargs,
  774. )
  775. return model_inputs
  776. class PhimoeForSequenceClassification(GenericForSequenceClassification, PhimoePreTrainedModel): ...
  777. __all__ = ["PhimoePreTrainedModel", "PhimoeModel", "PhimoeForCausalLM", "PhimoeForSequenceClassification"]