modular_granitemoe.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. # Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import torch
  16. from torch import nn
  17. from ... import initialization as init
  18. from ...activations import ACT2FN
  19. from ...cache_utils import Cache, DynamicCache
  20. from ...masking_utils import create_causal_mask
  21. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  22. from ...modeling_utils import PreTrainedModel
  23. from ...processing_utils import Unpack
  24. from ...utils import TransformersKwargs, auto_docstring
  25. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  26. from ...utils.output_capturing import capture_outputs
  27. from ..granite.modeling_granite import GraniteRMSNorm, GraniteRotaryEmbedding
  28. from ..jetmoe.modeling_jetmoe import JetMoeParallelExperts, JetMoeTopKGating
  29. from ..llama.modeling_llama import LlamaAttention, LlamaPreTrainedModel
  30. from ..mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel, load_balancing_loss_func
  31. from .configuration_granitemoe import GraniteMoeConfig
  32. class GraniteMoeRMSNorm(GraniteRMSNorm):
  33. pass
  34. class GraniteMoeRotaryEmbedding(GraniteRotaryEmbedding):
  35. pass
  36. class GraniteMoeParallelExperts(JetMoeParallelExperts):
  37. pass
  38. class GraniteMoeTopKGating(JetMoeTopKGating):
  39. pass
  40. class GraniteMoeMoE(nn.Module):
  41. """
  42. A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
  43. Args:
  44. config:
  45. Configuration object with model hyperparameters.
  46. """
  47. def __init__(self, config: GraniteMoeConfig):
  48. super().__init__()
  49. self.input_size = config.hidden_size
  50. self.hidden_size = config.intermediate_size
  51. self.activation = ACT2FN[config.hidden_act]
  52. self.input_linear = GraniteMoeParallelExperts(config.num_local_experts, self.input_size, self.hidden_size * 2)
  53. self.output_linear = GraniteMoeParallelExperts(config.num_local_experts, self.hidden_size, self.input_size)
  54. self.router = GraniteMoeTopKGating(
  55. input_size=self.input_size,
  56. num_experts=config.num_local_experts,
  57. top_k=config.num_experts_per_tok,
  58. )
  59. def forward(self, layer_input):
  60. bsz, length, emb_size = layer_input.size()
  61. layer_input = layer_input.reshape(-1, emb_size)
  62. _, batch_index, batch_gates, expert_size, _ = self.router(layer_input)
  63. expert_inputs = layer_input[batch_index]
  64. hidden_states = self.input_linear(expert_inputs, expert_size)
  65. chunked_hidden_states = hidden_states.chunk(2, dim=-1)
  66. hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]
  67. expert_outputs = self.output_linear(hidden_states, expert_size)
  68. expert_outputs = expert_outputs * batch_gates[:, None]
  69. zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
  70. layer_output = zeros.index_add(0, batch_index, expert_outputs)
  71. layer_output = layer_output.view(bsz, length, self.input_size)
  72. return layer_output
  73. class GraniteMoeAttention(LlamaAttention):
  74. def __init__(self, config: GraniteMoeConfig, layer_idx: int):
  75. super().__init__(self, config, layer_idx)
  76. self.scaling = config.attention_multiplier # Only diff with llama
  77. class GraniteMoeDecoderLayer(MixtralDecoderLayer):
  78. def __init__(self, config: GraniteMoeConfig, layer_idx: int):
  79. super().__init__(config, layer_idx)
  80. self.self_attn = GraniteMoeAttention(config=config, layer_idx=layer_idx)
  81. self.block_sparse_moe = GraniteMoeMoE(config)
  82. self.input_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  83. self.post_attention_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  84. del self.mlp
  85. self.block_sparse_moe = GraniteMoeMoE(config)
  86. self.residual_multiplier = config.residual_multiplier # Only diff with mixtral!
  87. def forward(
  88. self,
  89. hidden_states: torch.Tensor,
  90. attention_mask: torch.Tensor | None = None,
  91. past_key_values: Cache | None = None,
  92. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  93. **kwargs,
  94. ) -> torch.Tensor:
  95. residual = hidden_states
  96. hidden_states = self.input_layernorm(hidden_states)
  97. hidden_states, _ = self.self_attn(
  98. hidden_states=hidden_states,
  99. attention_mask=attention_mask,
  100. past_key_values=past_key_values,
  101. position_embeddings=position_embeddings,
  102. **kwargs,
  103. )
  104. hidden_states = residual + hidden_states * self.residual_multiplier # diff
  105. residual = hidden_states
  106. hidden_states = self.post_attention_layernorm(hidden_states)
  107. hidden_states = self.block_sparse_moe(hidden_states)
  108. hidden_states = residual + hidden_states * self.residual_multiplier # diff
  109. return hidden_states
  110. @auto_docstring
  111. class GraniteMoePreTrainedModel(LlamaPreTrainedModel, PreTrainedModel):
  112. config: GraniteMoeConfig
  113. base_model_prefix = "model"
  114. supports_gradient_checkpointing = True
  115. _no_split_modules = ["GraniteMoeDecoderLayer"]
  116. _skip_keys_device_placement = ["past_key_values"]
  117. _supports_flash_attn = True
  118. _supports_sdpa = True
  119. _can_compile_fullgraph = False # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
  120. @torch.no_grad()
  121. def _init_weights(self, module):
  122. PreTrainedModel._init_weights(self, module)
  123. if isinstance(module, GraniteMoeParallelExperts):
  124. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  125. @auto_docstring
  126. class GraniteMoeModel(MixtralModel):
  127. def __init__(self, config: GraniteMoeConfig):
  128. super().__init__(config)
  129. self.layers = nn.ModuleList(
  130. [GraniteMoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  131. )
  132. self.norm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  133. self.embedding_multiplier = config.embedding_multiplier
  134. @merge_with_config_defaults
  135. @capture_outputs
  136. @auto_docstring
  137. def forward(
  138. self,
  139. input_ids: torch.LongTensor | None = None,
  140. attention_mask: torch.Tensor | None = None,
  141. position_ids: torch.LongTensor | None = None,
  142. past_key_values: Cache | None = None,
  143. inputs_embeds: torch.FloatTensor | None = None,
  144. use_cache: bool | None = None,
  145. **kwargs: Unpack[TransformersKwargs],
  146. ) -> MoeModelOutputWithPast:
  147. if (input_ids is None) ^ (inputs_embeds is not None):
  148. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  149. if use_cache and past_key_values is None:
  150. past_key_values = DynamicCache(config=self.config)
  151. if inputs_embeds is None:
  152. inputs_embeds = self.embed_tokens(input_ids)
  153. if position_ids is None:
  154. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  155. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  156. position_ids = position_ids.unsqueeze(0)
  157. causal_mask = create_causal_mask( # ONLY DIFF WITH MIXTRAL: NO SLIDING
  158. config=self.config,
  159. inputs_embeds=inputs_embeds,
  160. attention_mask=attention_mask,
  161. past_key_values=past_key_values,
  162. position_ids=position_ids,
  163. )
  164. inputs_embeds = inputs_embeds * self.embedding_multiplier
  165. hidden_states = inputs_embeds
  166. # create position embeddings to be shared across the decoder layers
  167. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  168. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  169. hidden_states = decoder_layer(
  170. hidden_states,
  171. position_embeddings=position_embeddings,
  172. attention_mask=causal_mask,
  173. position_ids=position_ids,
  174. past_key_values=past_key_values,
  175. use_cache=use_cache,
  176. **kwargs,
  177. )
  178. hidden_states = self.norm(hidden_states)
  179. return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE
  180. last_hidden_state=hidden_states,
  181. past_key_values=past_key_values,
  182. )
  183. class GraniteMoeForCausalLM(MixtralForCausalLM):
  184. def __init__(self, config: GraniteMoeConfig):
  185. super().__init__(config)
  186. self.model = GraniteMoeModel(config)
  187. self.logits_scaling = config.logits_scaling
  188. @auto_docstring
  189. @can_return_tuple
  190. def forward(
  191. self,
  192. input_ids: torch.LongTensor | None = None,
  193. attention_mask: torch.Tensor | None = None,
  194. position_ids: torch.LongTensor | None = None,
  195. past_key_values: Cache | None = None,
  196. inputs_embeds: torch.FloatTensor | None = None,
  197. labels: torch.LongTensor | None = None,
  198. output_router_logits: bool | None = None,
  199. logits_to_keep: int | torch.Tensor = 0,
  200. **kwargs,
  201. ) -> tuple | MoeCausalLMOutputWithPast:
  202. r"""
  203. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  204. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  205. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  206. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  207. Example:
  208. ```python
  209. >>> from transformers import AutoTokenizer, GraniteMoeForCausalLM
  210. >>> model = GraniteMoeForCausalLM.from_pretrained("ibm/PowerMoE-3b")
  211. >>> tokenizer = AutoTokenizer.from_pretrained("ibm/PowerMoE-3b")
  212. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  213. >>> inputs = tokenizer(prompt, return_tensors="pt")
  214. >>> # Generate
  215. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  216. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  217. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  218. ```"""
  219. output_router_logits = (
  220. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  221. )
  222. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  223. outputs = self.model(
  224. input_ids=input_ids,
  225. attention_mask=attention_mask,
  226. position_ids=position_ids,
  227. past_key_values=past_key_values,
  228. inputs_embeds=inputs_embeds,
  229. **kwargs,
  230. )
  231. # Only compute necessary logits
  232. hidden_states = outputs.last_hidden_state
  233. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  234. logits = self.lm_head(hidden_states[:, slice_indices, :])
  235. logits = logits / self.config.logits_scaling
  236. loss = None
  237. if labels is not None:
  238. # Flatten the tokens
  239. loss = self.loss_function(
  240. logits,
  241. labels,
  242. vocab_size=self.config.vocab_size,
  243. **kwargs,
  244. )
  245. aux_loss = None
  246. if output_router_logits:
  247. aux_loss = load_balancing_loss_func(
  248. outputs.router_logits,
  249. self.num_experts,
  250. self.num_experts_per_tok,
  251. attention_mask,
  252. )
  253. if labels is not None:
  254. loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
  255. return MoeCausalLMOutputWithPast(
  256. loss=loss,
  257. aux_loss=aux_loss,
  258. logits=logits,
  259. past_key_values=outputs.past_key_values,
  260. hidden_states=outputs.hidden_states,
  261. attentions=outputs.attentions,
  262. router_logits=outputs.router_logits,
  263. )
  264. __all__ = ["GraniteMoeForCausalLM", "GraniteMoeModel", "GraniteMoePreTrainedModel"]