modular_olmoe.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. # Licensed under the Apache License, Version 2.0 (the "License");
  2. # you may not use this file except in compliance with the License.
  3. # You may obtain a copy of the License at
  4. #
  5. # http://www.apache.org/licenses/LICENSE-2.0
  6. #
  7. # Unless required by applicable law or agreed to in writing, software
  8. # distributed under the License is distributed on an "AS IS" BASIS,
  9. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. # See the License for the specific language governing permissions and
  11. # limitations under the License.
  12. """PyTorch OLMoE model."""
  13. from collections.abc import Callable
  14. import torch
  15. from torch import nn
  16. from ... import initialization as init
  17. from ...cache_utils import Cache, DynamicCache
  18. from ...generation import GenerationMixin
  19. from ...masking_utils import create_causal_mask
  20. from ...modeling_outputs import MoeModelOutputWithPast
  21. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  22. from ...processing_utils import Unpack
  23. from ...utils import TransformersKwargs, auto_docstring, logging
  24. from ...utils.output_capturing import OutputRecorder
  25. from ..gemma.modeling_gemma import GemmaMLP
  26. from ..llama.modeling_llama import (
  27. LlamaAttention,
  28. LlamaDecoderLayer,
  29. LlamaRMSNorm,
  30. LlamaRotaryEmbedding,
  31. apply_rotary_pos_emb,
  32. eager_attention_forward,
  33. )
  34. from ..mixtral.modeling_mixtral import MixtralExperts, MixtralForCausalLM, MixtralModel
  35. from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeTopKRouter
  36. from .configuration_olmoe import OlmoeConfig
  37. logger = logging.get_logger(__name__)
  38. class OlmoeRMSNorm(LlamaRMSNorm):
  39. def __init__(self, hidden_size, eps=1e-5):
  40. super().__init__(hidden_size, eps)
  41. class OlmoeRotaryEmbedding(LlamaRotaryEmbedding):
  42. pass
  43. class OlmoeMLP(GemmaMLP):
  44. pass
  45. class OlmoeAttention(LlamaAttention):
  46. def __init__(self, config: OlmoeConfig, layer_idx: int | None = None):
  47. super().__init__(config, layer_idx)
  48. self.q_norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  49. self.k_norm = OlmoeRMSNorm(
  50. (config.hidden_size // config.num_attention_heads) * config.num_key_value_heads, eps=config.rms_norm_eps
  51. )
  52. def forward(
  53. self,
  54. hidden_states: torch.Tensor,
  55. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  56. attention_mask: torch.Tensor | None,
  57. past_key_values: Cache | None = None,
  58. **kwargs: Unpack[TransformersKwargs],
  59. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  60. input_shape = hidden_states.shape[:-1]
  61. hidden_shape = (*input_shape, -1, self.head_dim)
  62. query_states = self.q_norm(self.q_proj(hidden_states))
  63. key_states = self.k_norm(self.k_proj(hidden_states))
  64. value_states = self.v_proj(hidden_states)
  65. if self.config.clip_qkv is not None: # Diff with llama
  66. query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  67. key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  68. value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  69. query_states = query_states.view(*hidden_shape).transpose(1, 2)
  70. key_states = key_states.view(*hidden_shape).transpose(1, 2)
  71. value_states = value_states.view(*hidden_shape).transpose(1, 2)
  72. cos, sin = position_embeddings
  73. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  74. if past_key_values is not None:
  75. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  76. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  77. self.config._attn_implementation, eager_attention_forward
  78. )
  79. attn_output, attn_weights = attention_interface(
  80. self,
  81. query_states,
  82. key_states,
  83. value_states,
  84. attention_mask,
  85. dropout=0.0 if not self.training else self.attention_dropout,
  86. scaling=self.scaling,
  87. sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
  88. **kwargs,
  89. )
  90. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  91. attn_output = self.o_proj(attn_output)
  92. return attn_output, attn_weights
  93. class OlmoeExperts(MixtralExperts):
  94. pass
  95. class OlmoeTopKRouter(Qwen2MoeTopKRouter):
  96. pass
  97. class OlmoeSparseMoeBlock(nn.Module):
  98. def __init__(self, config):
  99. super().__init__()
  100. self.gate = OlmoeTopKRouter(config)
  101. self.experts = OlmoeExperts(config)
  102. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  103. batch_size, sequence_length, hidden_dim = hidden_states.shape
  104. hidden_states = hidden_states.view(-1, hidden_dim)
  105. _, top_k_weights, top_k_index = self.gate(hidden_states)
  106. final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape(
  107. batch_size, sequence_length, hidden_dim
  108. )
  109. return final_hidden_states
  110. class OlmoeDecoderLayer(LlamaDecoderLayer):
  111. def __init__(self, config: OlmoeConfig, layer_idx: int):
  112. super().__init__(config, layer_idx)
  113. self.hidden_size = config.hidden_size
  114. self.self_attn = OlmoeAttention(config=config, layer_idx=layer_idx)
  115. self.mlp = OlmoeSparseMoeBlock(config)
  116. self.input_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  117. self.post_attention_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  118. @auto_docstring
  119. class OlmoePreTrainedModel(PreTrainedModel):
  120. config: OlmoeConfig
  121. base_model_prefix = "model"
  122. supports_gradient_checkpointing = True
  123. _no_split_modules = ["OlmoeDecoderLayer"]
  124. _skip_keys_device_placement = ["past_key_values"]
  125. _supports_flash_attn = True
  126. _supports_sdpa = True
  127. _can_record_outputs = {
  128. "router_logits": OutputRecorder(OlmoeTopKRouter, index=0),
  129. "hidden_states": OlmoeDecoderLayer,
  130. "attentions": OlmoeAttention,
  131. }
  132. _supports_attention_backend = True
  133. @torch.no_grad()
  134. def _init_weights(self, module):
  135. PreTrainedModel._init_weights(self, module)
  136. if isinstance(module, OlmoeExperts):
  137. init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
  138. init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
  139. elif isinstance(module, OlmoeTopKRouter):
  140. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  141. @auto_docstring
  142. class OlmoeModel(MixtralModel):
  143. def __init__(self, config: OlmoeConfig):
  144. super().__init__(config)
  145. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  146. self.layers = nn.ModuleList(
  147. [OlmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  148. )
  149. self.norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  150. self.rotary_emb = OlmoeRotaryEmbedding(config=config)
  151. def forward(
  152. self,
  153. input_ids: torch.LongTensor | None = None,
  154. attention_mask: torch.Tensor | None = None,
  155. position_ids: torch.LongTensor | None = None,
  156. past_key_values: Cache | None = None,
  157. inputs_embeds: torch.FloatTensor | None = None,
  158. use_cache: bool | None = None,
  159. **kwargs: Unpack[TransformersKwargs],
  160. ) -> MoeModelOutputWithPast:
  161. if (input_ids is None) ^ (inputs_embeds is not None):
  162. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  163. if use_cache and past_key_values is None:
  164. past_key_values = DynamicCache(config=self.config)
  165. if inputs_embeds is None:
  166. inputs_embeds = self.embed_tokens(input_ids)
  167. if position_ids is None:
  168. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  169. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  170. position_ids = position_ids.unsqueeze(0)
  171. causal_mask = create_causal_mask( # diff with mixtral: no sliding
  172. config=self.config,
  173. inputs_embeds=inputs_embeds,
  174. attention_mask=attention_mask,
  175. past_key_values=past_key_values,
  176. position_ids=position_ids,
  177. )
  178. hidden_states = inputs_embeds
  179. # create position embeddings to be shared across the decoder layers
  180. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  181. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  182. hidden_states = decoder_layer(
  183. hidden_states,
  184. position_embeddings=position_embeddings,
  185. attention_mask=causal_mask,
  186. position_ids=position_ids,
  187. past_key_values=past_key_values,
  188. use_cache=use_cache,
  189. **kwargs,
  190. )
  191. hidden_states = self.norm(hidden_states)
  192. return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE
  193. last_hidden_state=hidden_states,
  194. past_key_values=past_key_values,
  195. )
  196. class OlmoeForCausalLM(MixtralForCausalLM, GenerationMixin):
  197. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  198. def __init__(self, config):
  199. super().__init__(config)
  200. self.model = OlmoeModel(config)
  201. self.num_experts = config.num_experts
  202. def forward(self, **super_kwargs):
  203. r"""
  204. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  205. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  206. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  207. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  208. Example:
  209. ```python
  210. >>> from transformers import AutoTokenizer, OlmoeForCausalLM
  211. >>> model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924")
  212. >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
  213. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  214. >>> inputs = tokenizer(prompt, return_tensors="pt")
  215. >>> # Generate
  216. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  217. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  218. 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
  219. ```
  220. """
  221. return super().forward(**super_kwargs)
  222. __all__ = ["OlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel"]