| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279 |
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch OLMoE model."""
- from collections.abc import Callable
- import torch
- from torch import nn
- from ... import initialization as init
- from ...cache_utils import Cache, DynamicCache
- from ...generation import GenerationMixin
- from ...masking_utils import create_causal_mask
- from ...modeling_outputs import MoeModelOutputWithPast
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, logging
- from ...utils.output_capturing import OutputRecorder
- from ..gemma.modeling_gemma import GemmaMLP
- from ..llama.modeling_llama import (
- LlamaAttention,
- LlamaDecoderLayer,
- LlamaRMSNorm,
- LlamaRotaryEmbedding,
- apply_rotary_pos_emb,
- eager_attention_forward,
- )
- from ..mixtral.modeling_mixtral import MixtralExperts, MixtralForCausalLM, MixtralModel
- from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeTopKRouter
- from .configuration_olmoe import OlmoeConfig
# Module-level logger from the library's logging utilities, named after this module.
# NOTE(review): not referenced by any code visible in this file — confirm it is still needed.
logger = logging.get_logger(__name__)
class OlmoeRMSNorm(LlamaRMSNorm):
    """RMSNorm used throughout OLMoE.

    Reuses `LlamaRMSNorm` unchanged; the only local difference is this constructor's default
    `eps=1e-5`.
    """

    def __init__(self, hidden_size, eps=1e-5):
        # Positional forwarding to the parent constructor (normalized size, epsilon).
        super().__init__(hidden_size, eps)
# Rotary position embedding for OLMoE: an unmodified alias of `LlamaRotaryEmbedding`, given an
# OLMoE-specific name so it can be referenced as an OLMoE component.
class OlmoeRotaryEmbedding(LlamaRotaryEmbedding):
    pass
# Dense feed-forward block: an unmodified alias of `GemmaMLP`.
# NOTE(review): nothing visible in this file instantiates it (decoder layers use OlmoeSparseMoeBlock);
# presumably kept for API/export parity — confirm.
class OlmoeMLP(GemmaMLP):
    pass
class OlmoeAttention(LlamaAttention):
    """Llama-style self-attention with OLMoE's additions.

    Differences from `LlamaAttention`:
    - RMSNorm is applied to the full query and key projections (`q_norm`, `k_norm`) before they are
      split into heads.
    - When `config.clip_qkv` is not None, the query, key and value projections are clamped in place
      to `[-clip_qkv, clip_qkv]`.
    """

    def __init__(self, config: OlmoeConfig, layer_idx: int | None = None):
        super().__init__(config, layer_idx)
        # The norms act on the flattened projection outputs, which `forward` later splits into heads of
        # width `self.head_dim`. Sizing them as `heads * head_dim` keeps them in step with that split.
        # This equals the previous `hidden_size` / `(hidden_size // heads) * kv_heads` sizes whenever
        # head_dim == hidden_size // num_attention_heads.
        # NOTE(review): assumes q_proj/k_proj emit num_attention_heads * head_dim and
        # num_key_value_heads * head_dim features, as in LlamaAttention — confirm.
        self.q_norm = OlmoeRMSNorm(config.num_attention_heads * self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = OlmoeRMSNorm(config.num_key_value_heads * self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Apply OLMoE self-attention to `hidden_states`.

        Args:
            hidden_states: input activations. The last dimension is the feature dimension; all leading
                dimensions are preserved in the output.
            position_embeddings: `(cos, sin)` pair passed to `apply_rotary_pos_emb`.
            attention_mask: mask forwarded unchanged to the attention backend.
            past_key_values: optional cache. When given, this layer's new keys/values are passed to
                `past_key_values.update`, and the keys/values it returns are used for attention.
            **kwargs: extra arguments forwarded to the attention backend.

        Returns:
            `(attn_output, attn_weights)`:
            - `attn_output` is the `o_proj` output, with the same leading dimensions as `hidden_states`.
            - `attn_weights` is whatever the attention backend returned as its second output.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Diff with Llama: q/k are normalized over the whole projection, before the head split.
        query_states = self.q_norm(self.q_proj(hidden_states))
        key_states = self.k_norm(self.k_proj(hidden_states))
        value_states = self.v_proj(hidden_states)

        # Diff with Llama: optional in-place clipping of the q/k/v activations.
        if self.config.clip_qkv is not None:
            query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
            key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
            value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)

        # Split the last dimension into (heads, head_dim) and move the head axis before the sequence axis.
        query_states = query_states.view(*hidden_shape).transpose(1, 2)
        key_states = key_states.view(*hidden_shape).transpose(1, 2)
        value_states = value_states.view(*hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            # `None` unless the config defines a `sliding_window` attribute.
            sliding_window=getattr(self.config, "sliding_window", None),
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
# Expert bank for the sparse MoE block: an unmodified alias of `MixtralExperts`.
# OlmoeSparseMoeBlock calls it as `experts(hidden_states, top_k_index, top_k_weights)`.
# OlmoePreTrainedModel._init_weights re-initializes its `gate_up_proj` and `down_proj` parameters.
class OlmoeExperts(MixtralExperts):
    pass
# Top-k router for the sparse MoE block: an unmodified alias of `Qwen2MoeTopKRouter`.
# Its output is unpacked into three values by OlmoeSparseMoeBlock, and output index 0 is what
# OlmoePreTrainedModel records as `router_logits`.
class OlmoeTopKRouter(Qwen2MoeTopKRouter):
    pass
class OlmoeSparseMoeBlock(nn.Module):
    """Mixture-of-experts feed-forward: a top-k router followed by the expert bank.

    Tokens are flattened to a `(tokens, hidden_dim)` matrix for routing and expert computation, and
    the combined expert output is reshaped back to `(batch_size, sequence_length, hidden_dim)`.
    """

    def __init__(self, config):
        super().__init__()
        # Submodule registration order (gate first, then experts) is preserved.
        self.gate = OlmoeTopKRouter(config)
        self.experts = OlmoeExperts(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Tuple unpacking doubles as a check that the input is 3-D.
        bsz, seq_len, dim = hidden_states.shape
        tokens = hidden_states.view(-1, dim)

        # The router returns three values; its first output (the logits) is not needed here.
        _logits, expert_weights, expert_ids = self.gate(tokens)

        mixed = self.experts(tokens, expert_ids, expert_weights)
        return mixed.reshape(bsz, seq_len, dim)
class OlmoeDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer wired with OLMoE attention, a sparse MoE feed-forward and OLMoE RMSNorms."""

    def __init__(self, config: OlmoeConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        width = config.hidden_size
        eps = config.rms_norm_eps
        self.hidden_size = width
        # Assign the OLMoE submodules; these presumably override ones set by
        # LlamaDecoderLayer.__init__ — confirm against the parent.
        self.self_attn = OlmoeAttention(config=config, layer_idx=layer_idx)
        self.mlp = OlmoeSparseMoeBlock(config)
        self.input_layernorm = OlmoeRMSNorm(width, eps=eps)
        self.post_attention_layernorm = OlmoeRMSNorm(width, eps=eps)
@auto_docstring
class OlmoePreTrainedModel(PreTrainedModel):
    # Shared base for OLMoE models: config binding, capability flags, output-recording hooks and
    # weight initialization.
    config: OlmoeConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["OlmoeDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    # Which module outputs are captured under each output name
    # (router logits are output index 0 of the router).
    _can_record_outputs = {
        "router_logits": OutputRecorder(OlmoeTopKRouter, index=0),
        "hidden_states": OlmoeDecoderLayer,
        "attentions": OlmoeAttention,
    }
    _supports_attention_backend = True

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights of `module`.

        The generic `PreTrainedModel._init_weights` runs first. Then the OLMoE-specific parameters are
        redrawn from a normal distribution with mean 0 and std `config.initializer_range`:
        - for experts, `gate_up_proj` and `down_proj`;
        - for the router, `weight`.
        """
        PreTrainedModel._init_weights(self, module)

        if isinstance(module, OlmoeExperts):
            params = (module.gate_up_proj, module.down_proj)
        elif isinstance(module, OlmoeTopKRouter):
            params = (module.weight,)
        else:
            return

        for param in params:
            init.normal_(param, mean=0.0, std=self.config.initializer_range)
@auto_docstring
class OlmoeModel(MixtralModel):
    # OLMoE backbone: token embeddings -> stack of OlmoeDecoderLayer -> final RMSNorm.
    # It uses a plain causal mask (no sliding window).

    def __init__(self, config: OlmoeConfig):
        super().__init__(config)
        # `self.padding_idx` comes from the state prepared by the parent constructor.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            OlmoeDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
        )
        self.norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = OlmoeRotaryEmbedding(config=config)

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        # Exactly one input source is allowed: reject "both given" as well as "neither given".
        has_ids = input_ids is not None
        has_embeds = inputs_embeds is not None
        if has_ids == has_embeds:
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if position_ids is None:
            # Continue numbering after any tokens already held in the cache.
            offset = 0 if past_key_values is None else past_key_values.get_seq_length()
            seq_len = inputs_embeds.shape[1]
            position_ids = (torch.arange(seq_len, device=inputs_embeds.device) + offset).unsqueeze(0)

        # Plain causal mask: unlike Mixtral, no sliding-window variant is built.
        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        # Rotary (cos, sin) is computed once and shared by every decoder layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        active_layers = self.layers[: self.config.num_hidden_layers]
        for layer in active_layers:
            hidden_states = layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                **kwargs,
            )

        # MoE-flavoured output container.
        return MoeModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=past_key_values,
        )
# Causal language-modeling head on top of OlmoeModel. The forward logic is inherited from
# MixtralForCausalLM; GenerationMixin is also listed explicitly as a base.
class OlmoeForCausalLM(MixtralForCausalLM, GenerationMixin):
    # Maps the LM head weight to the input embedding weight.
    # Presumably used for weight tying when the config enables it — confirm.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config):
        super().__init__(config)
        # Use the OLMoE backbone in place of the one set up by the parent constructor.
        self.model = OlmoeModel(config)
        self.num_experts = config.num_experts

    # Delegates entirely to MixtralForCausalLM.forward; only the docstring below is OLMoE-specific.
    # NOTE(review): the bare `**super_kwargs` form looks like a deliberate convention for inheriting
    # the parent signature — keep it as-is.
    def forward(self, **super_kwargs):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        Example:
        ```python
        >>> from transformers import AutoTokenizer, OlmoeForCausalLM
        >>> model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924")
        >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")
        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
        ```
        """
        return super().forward(**super_kwargs)
# Public names exported by this module.
__all__ = ["OlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel"]
|