modeling_granitemoe.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/granitemoe/modular_granitemoe.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_granitemoe.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from collections.abc import Callable
  22. from typing import Optional
  23. import torch
  24. from torch import nn
  25. from torch.nn import functional as F
  26. from ... import initialization as init
  27. from ...activations import ACT2FN
  28. from ...cache_utils import Cache, DynamicCache
  29. from ...generation import GenerationMixin
  30. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
  31. from ...masking_utils import create_causal_mask
  32. from ...modeling_layers import GradientCheckpointingLayer
  33. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  34. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  35. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  36. from ...processing_utils import Unpack
  37. from ...utils import TransformersKwargs, auto_docstring
  38. from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults
  39. from ...utils.output_capturing import capture_outputs
  40. from .configuration_granitemoe import GraniteMoeConfig
  41. @use_kernel_forward_from_hub("RMSNorm")
  42. class GraniteMoeRMSNorm(nn.Module):
  43. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  44. """
  45. GraniteMoeRMSNorm is equivalent to T5LayerNorm
  46. """
  47. super().__init__()
  48. self.weight = nn.Parameter(torch.ones(hidden_size))
  49. self.variance_epsilon = eps
  50. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  51. input_dtype = hidden_states.dtype
  52. hidden_states = hidden_states.to(torch.float32)
  53. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  54. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  55. return self.weight * hidden_states.to(input_dtype)
  56. def extra_repr(self):
  57. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  58. class GraniteMoeRotaryEmbedding(nn.Module):
  59. inv_freq: torch.Tensor # fix linting for `register_buffer`
  60. def __init__(self, config: GraniteMoeConfig, device=None):
  61. super().__init__()
  62. self.max_seq_len_cached = config.max_position_embeddings
  63. self.original_max_seq_len = config.max_position_embeddings
  64. self.config = config
  65. self.rope_type = self.config.rope_parameters["rope_type"]
  66. rope_init_fn: Callable = self.compute_default_rope_parameters
  67. if self.rope_type != "default":
  68. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  69. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  70. self.register_buffer("inv_freq", inv_freq, persistent=False)
  71. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  72. @staticmethod
  73. def compute_default_rope_parameters(
  74. config: GraniteMoeConfig | None = None,
  75. device: Optional["torch.device"] = None,
  76. seq_len: int | None = None,
  77. ) -> tuple["torch.Tensor", float]:
  78. """
  79. Computes the inverse frequencies according to the original RoPE implementation
  80. Args:
  81. config ([`~transformers.PreTrainedConfig`]):
  82. The model configuration.
  83. device (`torch.device`):
  84. The device to use for initialization of the inverse frequencies.
  85. seq_len (`int`, *optional*):
  86. The current sequence length. Unused for this type of RoPE.
  87. Returns:
  88. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  89. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  90. """
  91. base = config.rope_parameters["rope_theta"]
  92. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  93. attention_factor = 1.0 # Unused in this type of RoPE
  94. # Compute the inverse frequencies
  95. inv_freq = 1.0 / (
  96. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  97. )
  98. return inv_freq, attention_factor
  99. @torch.no_grad()
  100. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  101. def forward(self, x, position_ids):
  102. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  103. position_ids_expanded = position_ids[:, None, :].float()
  104. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  105. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  106. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  107. emb = torch.cat((freqs, freqs), dim=-1)
  108. cos = emb.cos() * self.attention_scaling
  109. sin = emb.sin() * self.attention_scaling
  110. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  111. class GraniteMoeParallelExperts(nn.Module):
  112. def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
  113. """
  114. Initialize the GraniteMoeParallelExperts module.
  115. The experts weights are stored in [num_experts, output_size, input_size] format. Such that it's compatible with
  116. many MoE libraries, such as [Megablock](https://github.com/databricks/megablocks) and
  117. [ScatterMoE](https://github.com/shawntan/scattermoe), as well as the
  118. [MoE kernel](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py)
  119. used in vllm.
  120. Args:
  121. num_experts (int):
  122. Number of experts.
  123. input_size (int):
  124. Size of the input.
  125. output_size (int):
  126. Size of the output.
  127. """
  128. super().__init__()
  129. self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
  130. self.num_experts = num_experts
  131. self.input_size = input_size
  132. self.output_size = output_size
  133. def forward(self, inputs, expert_size):
  134. """
  135. Forward pass of the GraniteMoeParallelExperts module.
  136. Args:
  137. inputs (Tensor):
  138. Input tensor.
  139. expert_size:
  140. Expert size information.
  141. Returns:
  142. Tensor: Output tensor.
  143. """
  144. input_list = inputs.split(expert_size, dim=0)
  145. output_list = []
  146. for i in range(self.num_experts):
  147. output_list.append(F.linear(input_list[i], self.weight[i]))
  148. results = torch.cat(output_list, dim=0)
  149. return results
  150. class GraniteMoeTopKGating(nn.Module):
  151. def __init__(self, input_size: int, num_experts: int, top_k: int):
  152. """
  153. Initialize the top-k gating mechanism.
  154. Args:
  155. input_size (`int`):
  156. Size of the input.
  157. num_experts (`int`):
  158. Number of experts.
  159. top_k (`int`):
  160. Number of top experts to select.
  161. """
  162. super().__init__()
  163. self.num_experts = num_experts
  164. self.input_size = input_size
  165. self.top_k = top_k
  166. self.layer = nn.Linear(input_size, num_experts, bias=False)
  167. def forward(self, hidden_states):
  168. # compute the top_k routing decision
  169. logits = self.layer(hidden_states).float() # [batch_size x seq_len, num_experts]
  170. top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1) # [num_tokens, top_k]
  171. top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states) # [num_tokens, top_k]
  172. # compute number of input given to each expert
  173. zeros = torch.zeros(
  174. [top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device
  175. ) # [num_tokens, num_experts]
  176. gates = zeros.scatter(1, top_k_indices, 1) # [num_tokens, num_experts]
  177. expert_size = gates.long().sum(0) # [num_experts,]
  178. # (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`)
  179. # (and `DataDependentOutputException`)
  180. expert_size = expert_size.tolist()
  181. # sort and group input tokens according to expert assignment
  182. top_k_experts = top_k_indices.flatten() # [num_tokens * top_k]
  183. _, index_sorted_experts = top_k_experts.sort(0) # [num_tokens * top_k]
  184. batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc") # [num_tokens * top_k]
  185. # gather the gate values for grouped input tokens
  186. top_k_gates = top_k_gates.flatten() # [num_tokens * top_k]
  187. batch_gates = top_k_gates[index_sorted_experts] # [num_tokens * top_k]
  188. return index_sorted_experts, batch_index, batch_gates, expert_size, logits
  189. class GraniteMoeMoE(nn.Module):
  190. """
  191. A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
  192. Args:
  193. config:
  194. Configuration object with model hyperparameters.
  195. """
  196. def __init__(self, config: GraniteMoeConfig):
  197. super().__init__()
  198. self.input_size = config.hidden_size
  199. self.hidden_size = config.intermediate_size
  200. self.activation = ACT2FN[config.hidden_act]
  201. self.input_linear = GraniteMoeParallelExperts(config.num_local_experts, self.input_size, self.hidden_size * 2)
  202. self.output_linear = GraniteMoeParallelExperts(config.num_local_experts, self.hidden_size, self.input_size)
  203. self.router = GraniteMoeTopKGating(
  204. input_size=self.input_size,
  205. num_experts=config.num_local_experts,
  206. top_k=config.num_experts_per_tok,
  207. )
  208. def forward(self, layer_input):
  209. bsz, length, emb_size = layer_input.size()
  210. layer_input = layer_input.reshape(-1, emb_size)
  211. _, batch_index, batch_gates, expert_size, _ = self.router(layer_input)
  212. expert_inputs = layer_input[batch_index]
  213. hidden_states = self.input_linear(expert_inputs, expert_size)
  214. chunked_hidden_states = hidden_states.chunk(2, dim=-1)
  215. hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]
  216. expert_outputs = self.output_linear(hidden_states, expert_size)
  217. expert_outputs = expert_outputs * batch_gates[:, None]
  218. zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
  219. layer_output = zeros.index_add(0, batch_index, expert_outputs)
  220. layer_output = layer_output.view(bsz, length, self.input_size)
  221. return layer_output
  222. def rotate_half(x):
  223. """Rotates half the hidden dims of the input."""
  224. x1 = x[..., : x.shape[-1] // 2]
  225. x2 = x[..., x.shape[-1] // 2 :]
  226. return torch.cat((-x2, x1), dim=-1)
  227. @use_kernel_func_from_hub("rotary_pos_emb")
  228. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  229. """Applies Rotary Position Embedding to the query and key tensors.
  230. Args:
  231. q (`torch.Tensor`): The query tensor.
  232. k (`torch.Tensor`): The key tensor.
  233. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  234. sin (`torch.Tensor`): The sine part of the rotary embedding.
  235. unsqueeze_dim (`int`, *optional*, defaults to 1):
  236. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  237. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  238. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  239. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  240. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  241. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  242. Returns:
  243. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  244. """
  245. cos = cos.unsqueeze(unsqueeze_dim)
  246. sin = sin.unsqueeze(unsqueeze_dim)
  247. q_embed = (q * cos) + (rotate_half(q) * sin)
  248. k_embed = (k * cos) + (rotate_half(k) * sin)
  249. return q_embed, k_embed
  250. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  251. """
  252. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  253. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  254. """
  255. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  256. if n_rep == 1:
  257. return hidden_states
  258. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  259. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  260. def eager_attention_forward(
  261. module: nn.Module,
  262. query: torch.Tensor,
  263. key: torch.Tensor,
  264. value: torch.Tensor,
  265. attention_mask: torch.Tensor | None,
  266. scaling: float,
  267. dropout: float = 0.0,
  268. **kwargs: Unpack[TransformersKwargs],
  269. ):
  270. key_states = repeat_kv(key, module.num_key_value_groups)
  271. value_states = repeat_kv(value, module.num_key_value_groups)
  272. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  273. if attention_mask is not None:
  274. attn_weights = attn_weights + attention_mask
  275. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  276. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  277. attn_output = torch.matmul(attn_weights, value_states)
  278. attn_output = attn_output.transpose(1, 2).contiguous()
  279. return attn_output, attn_weights
  280. @use_kernelized_func(apply_rotary_pos_emb)
  281. class GraniteMoeAttention(nn.Module):
  282. """Multi-headed attention from 'Attention Is All You Need' paper"""
  283. def __init__(self, config: GraniteMoeConfig, layer_idx: int):
  284. super().__init__()
  285. self.config = config
  286. self.layer_idx = layer_idx
  287. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  288. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  289. self.scaling = config.attention_multiplier # Only diff with llama
  290. self.attention_dropout = config.attention_dropout
  291. self.is_causal = True
  292. self.q_proj = nn.Linear(
  293. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  294. )
  295. self.k_proj = nn.Linear(
  296. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  297. )
  298. self.v_proj = nn.Linear(
  299. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  300. )
  301. self.o_proj = nn.Linear(
  302. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  303. )
  304. def forward(
  305. self,
  306. hidden_states: torch.Tensor,
  307. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  308. attention_mask: torch.Tensor | None = None,
  309. past_key_values: Cache | None = None,
  310. **kwargs: Unpack[TransformersKwargs],
  311. ) -> tuple[torch.Tensor, torch.Tensor]:
  312. input_shape = hidden_states.shape[:-1]
  313. hidden_shape = (*input_shape, -1, self.head_dim)
  314. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  315. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  316. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  317. cos, sin = position_embeddings
  318. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  319. if past_key_values is not None:
  320. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  321. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  322. self.config._attn_implementation, eager_attention_forward
  323. )
  324. attn_output, attn_weights = attention_interface(
  325. self,
  326. query_states,
  327. key_states,
  328. value_states,
  329. attention_mask,
  330. dropout=0.0 if not self.training else self.attention_dropout,
  331. scaling=self.scaling,
  332. **kwargs,
  333. )
  334. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  335. attn_output = self.o_proj(attn_output)
  336. return attn_output, attn_weights
  337. class GraniteMoeDecoderLayer(GradientCheckpointingLayer):
  338. def __init__(self, config: GraniteMoeConfig, layer_idx: int):
  339. super().__init__()
  340. self.hidden_size = config.hidden_size
  341. self.self_attn = GraniteMoeAttention(config=config, layer_idx=layer_idx)
  342. self.input_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  343. self.post_attention_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  344. self.block_sparse_moe = GraniteMoeMoE(config)
  345. self.residual_multiplier = config.residual_multiplier # Only diff with mixtral!
  346. def forward(
  347. self,
  348. hidden_states: torch.Tensor,
  349. attention_mask: torch.Tensor | None = None,
  350. past_key_values: Cache | None = None,
  351. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  352. **kwargs,
  353. ) -> torch.Tensor:
  354. residual = hidden_states
  355. hidden_states = self.input_layernorm(hidden_states)
  356. hidden_states, _ = self.self_attn(
  357. hidden_states=hidden_states,
  358. attention_mask=attention_mask,
  359. past_key_values=past_key_values,
  360. position_embeddings=position_embeddings,
  361. **kwargs,
  362. )
  363. hidden_states = residual + hidden_states * self.residual_multiplier # diff
  364. residual = hidden_states
  365. hidden_states = self.post_attention_layernorm(hidden_states)
  366. hidden_states = self.block_sparse_moe(hidden_states)
  367. hidden_states = residual + hidden_states * self.residual_multiplier # diff
  368. return hidden_states
  369. @auto_docstring
  370. class GraniteMoePreTrainedModel(PreTrainedModel):
  371. config: GraniteMoeConfig
  372. base_model_prefix = "model"
  373. supports_gradient_checkpointing = True
  374. _no_split_modules = ["GraniteMoeDecoderLayer"]
  375. _skip_keys_device_placement = ["past_key_values"]
  376. _supports_flash_attn = True
  377. _supports_sdpa = True
  378. _supports_flex_attn = True
  379. _can_compile_fullgraph = False # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
  380. _supports_attention_backend = True
  381. _can_record_outputs = {
  382. "hidden_states": GraniteMoeDecoderLayer,
  383. "attentions": GraniteMoeAttention,
  384. }
  385. @torch.no_grad()
  386. def _init_weights(self, module):
  387. super()._init_weights(module)
  388. if isinstance(module, GraniteMoeParallelExperts):
  389. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  390. @auto_docstring
  391. class GraniteMoeModel(GraniteMoePreTrainedModel):
  392. def __init__(self, config: GraniteMoeConfig):
  393. super().__init__(config)
  394. self.padding_idx = config.pad_token_id
  395. self.vocab_size = config.vocab_size
  396. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  397. self.layers = nn.ModuleList(
  398. [GraniteMoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  399. )
  400. self.norm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  401. self.rotary_emb = GraniteMoeRotaryEmbedding(config=config)
  402. self.gradient_checkpointing = False
  403. self.embedding_multiplier = config.embedding_multiplier
  404. # Initialize weights and apply final processing
  405. self.post_init()
  406. @merge_with_config_defaults
  407. @capture_outputs
  408. @auto_docstring
  409. def forward(
  410. self,
  411. input_ids: torch.LongTensor | None = None,
  412. attention_mask: torch.Tensor | None = None,
  413. position_ids: torch.LongTensor | None = None,
  414. past_key_values: Cache | None = None,
  415. inputs_embeds: torch.FloatTensor | None = None,
  416. use_cache: bool | None = None,
  417. **kwargs: Unpack[TransformersKwargs],
  418. ) -> MoeModelOutputWithPast:
  419. if (input_ids is None) ^ (inputs_embeds is not None):
  420. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  421. if use_cache and past_key_values is None:
  422. past_key_values = DynamicCache(config=self.config)
  423. if inputs_embeds is None:
  424. inputs_embeds = self.embed_tokens(input_ids)
  425. if position_ids is None:
  426. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  427. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  428. position_ids = position_ids.unsqueeze(0)
  429. causal_mask = create_causal_mask( # ONLY DIFF WITH MIXTRAL: NO SLIDING
  430. config=self.config,
  431. inputs_embeds=inputs_embeds,
  432. attention_mask=attention_mask,
  433. past_key_values=past_key_values,
  434. position_ids=position_ids,
  435. )
  436. inputs_embeds = inputs_embeds * self.embedding_multiplier
  437. hidden_states = inputs_embeds
  438. # create position embeddings to be shared across the decoder layers
  439. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  440. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  441. hidden_states = decoder_layer(
  442. hidden_states,
  443. position_embeddings=position_embeddings,
  444. attention_mask=causal_mask,
  445. position_ids=position_ids,
  446. past_key_values=past_key_values,
  447. use_cache=use_cache,
  448. **kwargs,
  449. )
  450. hidden_states = self.norm(hidden_states)
  451. return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE
  452. last_hidden_state=hidden_states,
  453. past_key_values=past_key_values,
  454. )
  455. def load_balancing_loss_func(
  456. gate_logits: torch.Tensor | tuple[torch.Tensor] | None,
  457. num_experts: int | None = None,
  458. top_k=2,
  459. attention_mask: torch.Tensor | None = None,
  460. ) -> torch.Tensor | int:
  461. r"""
  462. Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
  463. See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
  464. function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
  465. experts is too unbalanced.
  466. Args:
  467. gate_logits:
  468. Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
  469. shape [batch_size X sequence_length, num_experts].
  470. num_experts:
  471. Number of experts
  472. top_k:
  473. The number of experts to route per-token, can be also interpreted as the `top-k` routing
  474. parameter.
  475. attention_mask (`torch.Tensor`, *optional*):
  476. The attention_mask used in forward function
  477. shape [batch_size X sequence_length] if not None.
  478. Returns:
  479. The auxiliary loss.
  480. """
  481. if gate_logits is None or not isinstance(gate_logits, tuple):
  482. return 0
  483. if isinstance(gate_logits, tuple):
  484. compute_device = gate_logits[0].device
  485. concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
  486. routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
  487. _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
  488. expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
  489. if attention_mask is None:
  490. # Compute the percentage of tokens routed to each experts
  491. tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
  492. # Compute the average probability of routing to these experts
  493. router_prob_per_expert = torch.mean(routing_weights, dim=0)
  494. else:
  495. batch_size, sequence_length = attention_mask.shape
  496. num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
  497. # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
  498. expert_attention_mask = (
  499. attention_mask[None, :, :, None, None]
  500. .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
  501. .reshape(-1, top_k, num_experts)
  502. .to(compute_device)
  503. )
  504. # Compute the percentage of tokens routed to each experts
  505. tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
  506. expert_attention_mask, dim=0
  507. )
  508. # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
  509. router_per_expert_attention_mask = (
  510. attention_mask[None, :, :, None]
  511. .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
  512. .reshape(-1, num_experts)
  513. .to(compute_device)
  514. )
  515. # Compute the average probability of routing to these experts
  516. router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
  517. router_per_expert_attention_mask, dim=0
  518. )
  519. overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
  520. return overall_loss * num_experts
  521. @auto_docstring
  522. class GraniteMoeForCausalLM(GraniteMoePreTrainedModel, GenerationMixin):
  523. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  524. _tp_plan = {"lm_head": "colwise_gather_output"}
  525. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  526. def __init__(self, config: GraniteMoeConfig):
  527. super().__init__(config)
  528. self.model = GraniteMoeModel(config)
  529. self.vocab_size = config.vocab_size
  530. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  531. self.router_aux_loss_coef = config.router_aux_loss_coef
  532. self.num_experts = config.num_local_experts
  533. self.num_experts_per_tok = config.num_experts_per_tok
  534. self.logits_scaling = config.logits_scaling
  535. # Initialize weights and apply final processing
  536. self.post_init()
  537. @auto_docstring
  538. @can_return_tuple
  539. def forward(
  540. self,
  541. input_ids: torch.LongTensor | None = None,
  542. attention_mask: torch.Tensor | None = None,
  543. position_ids: torch.LongTensor | None = None,
  544. past_key_values: Cache | None = None,
  545. inputs_embeds: torch.FloatTensor | None = None,
  546. labels: torch.LongTensor | None = None,
  547. output_router_logits: bool | None = None,
  548. logits_to_keep: int | torch.Tensor = 0,
  549. **kwargs,
  550. ) -> tuple | MoeCausalLMOutputWithPast:
  551. r"""
  552. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  553. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  554. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  555. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  556. Example:
  557. ```python
  558. >>> from transformers import AutoTokenizer, GraniteMoeForCausalLM
  559. >>> model = GraniteMoeForCausalLM.from_pretrained("ibm/PowerMoE-3b")
  560. >>> tokenizer = AutoTokenizer.from_pretrained("ibm/PowerMoE-3b")
  561. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  562. >>> inputs = tokenizer(prompt, return_tensors="pt")
  563. >>> # Generate
  564. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  565. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  566. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  567. ```"""
  568. output_router_logits = (
  569. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  570. )
  571. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  572. outputs = self.model(
  573. input_ids=input_ids,
  574. attention_mask=attention_mask,
  575. position_ids=position_ids,
  576. past_key_values=past_key_values,
  577. inputs_embeds=inputs_embeds,
  578. **kwargs,
  579. )
  580. # Only compute necessary logits
  581. hidden_states = outputs.last_hidden_state
  582. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  583. logits = self.lm_head(hidden_states[:, slice_indices, :])
  584. logits = logits / self.config.logits_scaling
  585. loss = None
  586. if labels is not None:
  587. # Flatten the tokens
  588. loss = self.loss_function(
  589. logits,
  590. labels,
  591. vocab_size=self.config.vocab_size,
  592. **kwargs,
  593. )
  594. aux_loss = None
  595. if output_router_logits:
  596. aux_loss = load_balancing_loss_func(
  597. outputs.router_logits,
  598. self.num_experts,
  599. self.num_experts_per_tok,
  600. attention_mask,
  601. )
  602. if labels is not None:
  603. loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
  604. return MoeCausalLMOutputWithPast(
  605. loss=loss,
  606. aux_loss=aux_loss,
  607. logits=logits,
  608. past_key_values=outputs.past_key_values,
  609. hidden_states=outputs.hidden_states,
  610. attentions=outputs.attentions,
  611. router_logits=outputs.router_logits,
  612. )
  613. __all__ = ["GraniteMoeForCausalLM", "GraniteMoeModel", "GraniteMoePreTrainedModel"]