modular_phimoe.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Phimoe model."""
  15. from collections.abc import Callable
  16. import torch
  17. from torch import nn
  18. from ...modeling_layers import (
  19. GenericForSequenceClassification,
  20. )
  21. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
  22. from ...utils.generic import maybe_autocast
  23. from ...utils.output_capturing import OutputRecorder
  24. from ..llama.modeling_llama import LlamaAttention
  25. from ..mixtral.modeling_mixtral import (
  26. MixtralDecoderLayer,
  27. MixtralExperts,
  28. MixtralForCausalLM,
  29. MixtralModel,
  30. MixtralPreTrainedModel,
  31. MixtralRotaryEmbedding,
  32. )
  33. from .configuration_phimoe import PhimoeConfig
  34. class PhimoeRotaryEmbedding(MixtralRotaryEmbedding):
  35. def __init__(self, config: PhimoeConfig, device=None):
  36. nn.Module.__init__()
  37. self.max_seq_len_cached = config.max_position_embeddings
  38. self.original_max_seq_len = config.max_position_embeddings
  39. self.config = config
  40. self.rope_type = self.config.rope_parameters["rope_type"]
  41. self.rope_init_fn: Callable = self.compute_default_rope_parameters
  42. if self.rope_type != "default":
  43. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  44. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  45. self.register_buffer("inv_freq", inv_freq, persistent=False)
  46. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  47. def forward(self, x, position_ids=None, layer_type=None):
  48. if layer_type is not None:
  49. raise ValueError(
  50. f"{self.__class__.__name__} does not support layer types, but got `layer_type={layer_type}`"
  51. )
  52. mscale = None
  53. seq_len = torch.max(position_ids) + 1
  54. if self.config.rope_parameters["rope_type"] != "default" and seq_len:
  55. mscale = (
  56. self.config.rope_parameters["long_mscale"]
  57. if seq_len > self.config.rope_parameters["original_max_position_embeddings"]
  58. else self.config.rope_parameters["short_mscale"]
  59. )
  60. inv_freq, attention_scaling = self.rope_init_fn(self.config, x.device, seq_len)
  61. mscale = attention_scaling if mscale is None else mscale
  62. inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  63. position_ids_expanded = position_ids[:, None, :].float()
  64. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  65. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  66. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  67. emb = torch.cat((freqs, freqs), dim=-1)
  68. cos = emb.cos() * mscale
  69. sin = emb.sin() * mscale
  70. return cos.to(x.dtype), sin.to(x.dtype)
  71. class PhimoeAttention(LlamaAttention):
  72. pass
  73. class PhimoeMultiplier(torch.autograd.Function):
  74. @staticmethod
  75. def forward(
  76. ctx,
  77. scores: torch.Tensor,
  78. multiplier: torch.Tensor,
  79. selected_experts: torch.Tensor,
  80. masked_gates: torch.Tensor,
  81. mask_for_one: torch.Tensor,
  82. ):
  83. """
  84. Forward pass for the custom autograd function.
  85. Args:
  86. ctx: Context object to save information for backward computation.
  87. scores (torch.Tensor): Input scores tensor.
  88. multiplier (torch.Tensor): Multiplier tensor.
  89. selected_experts (torch.Tensor): Tensor of selected experts.
  90. masked_gates (torch.Tensor): Masked gates tensor.
  91. mask_for_one (torch.Tensor): Mask for one tensor.
  92. Returns:
  93. torch.Tensor: Result of the forward pass.
  94. """
  95. ctx.save_for_backward(multiplier, selected_experts, masked_gates)
  96. return multiplier * mask_for_one
  97. @staticmethod
  98. def backward(
  99. ctx,
  100. grad_at_output: torch.Tensor,
  101. ):
  102. """
  103. Backward pass for the custom autograd function.
  104. Args:
  105. ctx: Context object with saved tensors from the forward pass.
  106. grad_at_output (torch.Tensor): Gradient at the output.
  107. Returns:
  108. tuple[torch.Tensor, None, None, None, None]: Gradients for the inputs.
  109. """
  110. multiplier, selected_experts, masked_gates = ctx.saved_tensors
  111. grad_at_output = grad_at_output * multiplier
  112. grad_at_scores_expanded = masked_gates * grad_at_output.mul(-1)
  113. grad_at_scores_expanded.scatter_add_(
  114. dim=-1,
  115. index=selected_experts,
  116. src=grad_at_output,
  117. )
  118. return (
  119. grad_at_scores_expanded,
  120. None,
  121. None,
  122. None,
  123. None,
  124. )
  125. def sparsemixer(scores, jitter_eps, training, top_k=2):
  126. """
  127. Sparse mixer function to select top-k experts and compute multipliers.
  128. Based on the paper: https://huggingface.co/papers/2409.12136
  129. We first replace the TopK(·) function as random sampling of discrete variables
  130. in model training. Then, following Liu et al. (2023a) and Liu et al. (2023b), we apply Heun's
  131. third order method to approximate the expert routing gradient and construct a modified
  132. back-propagation to give a mathematically sound gradient estimation for expert routing.
  133. Args:
  134. scores (torch.Tensor): Input scores tensor.
  135. jitter_eps (float): Jitter epsilon for numerical stability.
  136. training (bool): Flag indicating if the model is in training mode.
  137. top_k (int): Number of top experts to select.
  138. Returns:
  139. tuple[torch.Tensor, torch.Tensor]: Multiplier and selected experts tensors.
  140. """
  141. with torch.no_grad():
  142. # Compute mask for sparsity
  143. mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True)
  144. factor = scores.abs().clamp(min=mask_logits_threshold)
  145. mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (2 * jitter_eps)
  146. # Apply mask
  147. masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf"))
  148. if training:
  149. selected_experts = (
  150. (
  151. masked_gates
  152. - torch.empty_like(masked_gates, memory_format=torch.legacy_contiguous_format).exponential_().log()
  153. )
  154. .max(dim=-1)[1]
  155. .unsqueeze(-1)
  156. ) # Gumbel sampling, more robust than the multinomial method
  157. else:
  158. selected_experts = max_ind
  159. # Compute scores for gradients
  160. masked_gates = torch.softmax(masked_gates, dim=-1)
  161. multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)
  162. if training:
  163. # Compute midpoint mask
  164. max_scores, max_ind = masked_gates.max(dim=-1, keepdim=True)
  165. mask_for_one = torch.logical_or(
  166. selected_experts == max_ind,
  167. torch.rand_like(max_scores) > 0.75, # Heun's third-order method
  168. )
  169. # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5
  170. mask_for_one = torch.add(0.3333, mask_for_one, alpha=0.6667).type_as(masked_gates)
  171. multiplier = PhimoeMultiplier.apply(
  172. scores,
  173. multiplier_o,
  174. selected_experts,
  175. masked_gates,
  176. mask_for_one,
  177. )
  178. else:
  179. multiplier = multiplier_o
  180. # Masked out first expert
  181. masked_scores = torch.scatter(
  182. scores,
  183. -1,
  184. selected_experts,
  185. float("-inf"),
  186. )
  187. with torch.no_grad():
  188. # Compute mask for sparsity
  189. mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True)
  190. factor = scores.abs().clamp(min=mask_logits_threshold)
  191. mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (2 * jitter_eps)
  192. # Apply mask
  193. masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, float("-inf"))
  194. if training:
  195. selected_experts_top2 = (
  196. (
  197. masked_gates_top2
  198. - torch.empty_like(masked_gates_top2, memory_format=torch.legacy_contiguous_format)
  199. .exponential_()
  200. .log()
  201. )
  202. .max(dim=-1)[1]
  203. .unsqueeze(-1)
  204. ) # Gumbel sampling, more robust than the multinomial method
  205. else:
  206. selected_experts_top2 = max_ind
  207. # Compute scores for gradients
  208. masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1)
  209. multiplier_top2_o = masked_gates_top2.gather(dim=-1, index=selected_experts_top2)
  210. if training:
  211. # Compute midpoint mask
  212. max_scores, max_ind = masked_gates_top2.max(dim=-1, keepdim=True)
  213. mask_for_one_top2 = torch.logical_or(
  214. selected_experts_top2 == max_ind,
  215. torch.rand_like(max_scores).uniform_() > 0.75, # Heun's third-order method
  216. )
  217. # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5
  218. mask_for_one_top2 = torch.add(0.3333, mask_for_one_top2, alpha=0.6667).type_as(masked_gates_top2)
  219. multiplier_top2 = PhimoeMultiplier.apply(
  220. scores,
  221. multiplier_top2_o,
  222. selected_experts_top2,
  223. masked_gates_top2,
  224. mask_for_one_top2,
  225. )
  226. else:
  227. multiplier_top2 = multiplier_top2_o
  228. multiplier = torch.concat((multiplier, multiplier_top2), dim=-1)
  229. selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1)
  230. return (
  231. multiplier,
  232. selected_experts,
  233. )
  234. class PhimoeExperts(MixtralExperts):
  235. pass
  236. class PhimoeTopKRouter(nn.Linear):
  237. def __init__(self, config: PhimoeConfig):
  238. super().__init__(config.hidden_size, config.num_local_experts, bias=False)
  239. self.router_jitter_noise = config.router_jitter_noise
  240. self.input_jitter_noise = config.input_jitter_noise
  241. self.top_k = config.num_experts_per_tok
  242. def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  243. if self.training and self.input_jitter_noise > 0:
  244. hidden_states *= torch.empty_like(hidden_states).uniform_(
  245. 1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise
  246. )
  247. router_logits = super().forward(hidden_states)
  248. routing_weights, selected_experts = sparsemixer(
  249. router_logits, jitter_eps=self.router_jitter_noise, training=self.training, top_k=self.top_k
  250. )
  251. return router_logits, routing_weights, selected_experts
  252. class PhimoeSparseMoeBlock(nn.Module):
  253. """
  254. This implementation is
  255. strictly equivalent to standard MoE with full capacity (no
  256. dropped tokens). It's faster since it formulates MoE operations
  257. in terms of block-sparse operations to accommodate imbalanced
  258. assignments of tokens to experts, whereas standard MoE either
  259. (1) drop tokens at the cost of reduced performance or (2) set
  260. capacity factor to number of experts and thus waste computation
  261. and memory on padding.
  262. """
  263. def __init__(self, config):
  264. super().__init__()
  265. self.hidden_dim = config.hidden_size
  266. self.ffn_dim = config.intermediate_size
  267. self.num_experts = config.num_local_experts
  268. self.top_k = config.num_experts_per_tok
  269. self.router = PhimoeTopKRouter(config)
  270. self.experts = PhimoeExperts(config)
  271. self.input_jitter_noise = config.input_jitter_noise
  272. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  273. batch_size, sequence_length, hidden_dim = hidden_states.shape
  274. if self.training and self.input_jitter_noise > 0:
  275. hidden_states *= torch.empty_like(hidden_states).uniform_(
  276. 1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise
  277. )
  278. batch_size, sequence_length, hidden_dim = hidden_states.shape
  279. hidden_states = hidden_states.reshape(-1, hidden_dim)
  280. _, routing_weights, selected_experts = self.router(hidden_states)
  281. final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
  282. return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
  283. class PhimoeDecoderLayer(MixtralDecoderLayer):
  284. def __init__(self, config: PhimoeConfig, layer_idx: int):
  285. super().__init__(config, layer_idx)
  286. # Phimoe uses nn.LayerNorm
  287. self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True)
  288. self.post_attention_layernorm = nn.LayerNorm(
  289. config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True
  290. )
  291. class PhimoePreTrainedModel(MixtralPreTrainedModel):
  292. _can_record_outputs = {
  293. "router_logits": OutputRecorder(PhimoeTopKRouter, index=0),
  294. "hidden_states": PhimoeDecoderLayer,
  295. "attentions": PhimoeAttention,
  296. }
  297. class PhimoeModel(MixtralModel):
  298. def __init__(self, config: PhimoeConfig):
  299. super().__init__(config)
  300. self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True)
class PhimoeForCausalLM(MixtralForCausalLM):
    """Phimoe causal language model built on the Mixtral causal-LM class.

    Re-creates `lm_head` so that its bias follows `config.lm_head_bias`, and overrides
    `prepare_inputs_for_generation` to drop the KV cache the first time generation crosses
    `original_max_position_embeddings` (the short/long rope switching point).
    """

    def __init__(self, config):
        super().__init__(config)
        # Overrides lm_head so the bias is configurable via `config.lm_head_bias`.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias)

    # prepare_inputs_for_generation: the cache is discarded (`past_key_values = None`) when all of these hold:
    #   - a truthy cache is present,
    #   - the config has a top-level `original_max_position_embeddings`,
    #   - `input_ids` is longer than that limit, and
    #   - the cache holds at most that many positions.
    # Everything is then passed through to the parent implementation.
    # NOTE(review): PhimoeRotaryEmbedding reads `original_max_position_embeddings` from
    # `config.rope_parameters`; confirm PhimoeConfig also exposes it as a top-level attribute, otherwise the
    # `hasattr` guard below is always False and the cache is never reset.
    # Copied from transformers.models.phi3.modeling_phi3.Phi3ForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        position_ids=None,
        use_cache=True,
        logits_to_keep=None,
        **kwargs,
    ):
        # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
        # process

        # When the first time input length reached long and short factor switching point, enforce re-compute cache
        # It will cause downside of slower at this single token position, however, better than current failure.
        if (
            past_key_values
            and hasattr(self.config, "original_max_position_embeddings")
            and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
        ):
            past_length = past_key_values.get_seq_length()
            if past_length <= self.config.original_max_position_embeddings:
                past_key_values = None

        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
        return model_inputs
  340. class PhimoeForSequenceClassification(GenericForSequenceClassification, PhimoePreTrainedModel): ...
  341. __all__ = [
  342. "PhimoePreTrainedModel",
  343. "PhimoeModel",
  344. "PhimoeForCausalLM",
  345. "PhimoeForSequenceClassification",
  346. ]