| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230 |
- # Copyright 2024 HuggingFace Inc. team. All rights reserved.
- #
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
- # and OPT implementations in this library. It has been modified from its
- # original forms to accommodate minor architectural differences compared
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from collections.abc import Callable
- import torch
- import torch.nn as nn
- from huggingface_hub.dataclasses import strict
- from transformers.utils.generic import TransformersKwargs
- from ...cache_utils import Cache
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
- from ...processing_utils import Unpack
- from ...utils import auto_docstring, logging
- from ..llama.modeling_llama import LlamaPreTrainedModel, LlamaRMSNorm, eager_attention_forward
- from ..olmo.configuration_olmo import OlmoConfig
- from ..olmo.modeling_olmo import (
- OlmoAttention,
- OlmoDecoderLayer,
- OlmoForCausalLM,
- OlmoModel,
- OlmoRotaryEmbedding,
- apply_rotary_pos_emb,
- )
# Module-level logger, named after this module.
logger = logging.get_logger(name=__name__)
@auto_docstring(checkpoint="allenai/Olmo2-7B-1124-hf")
@strict
class Olmo2Config(OlmoConfig):
    r"""
    Example:

    ```python
    >>> from transformers import Olmo2Model, Olmo2Config

    >>> # Initializing an OLMo2 7B style configuration
    >>> configuration = Olmo2Config()

    >>> # Initializing a model from the OLMo2 7B style configuration
    >>> model = Olmo2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    # Configuration for OLMo2: everything is inherited from OlmoConfig except the
    # attributes overridden (or removed) below.
    model_type = "olmo2"
    # Tensor-parallel sharding plan for the base model's submodules.
    # q_norm/k_norm in Olmo2Attention normalize over the whole q_proj/k_proj output
    # (num_heads * head_dim) before it is split into heads, so those projections
    # cannot be left sharded per rank; their outputs are gathered instead.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise_gather_output",  # gathered: q_norm needs the full projection width
        "layers.*.self_attn.k_proj": "colwise_gather_output",  # gathered: k_norm needs the full projection width
        "layers.*.self_attn.v_proj": "colwise_gather_output",  # gathered like q/k, presumably to keep q, k, v laid out consistently (confirm)
        "layers.*.self_attn.o_proj": "rowwise_split_input",  # its input (the attention output) is presumably replicated because q/k/v were gathered
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # Pipeline-parallel plan: submodule name -> (input names, output names).
    # NOTE(review): format inferred from the entries themselves; confirm against the PP docs.
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }
    # Epsilon used by every Olmo2RMSNorm (q/k norms, post-attention/post-feedforward norms, final norm).
    rms_norm_eps: float = 1e-5
    # NOTE(review): presumably marks the `clip_qkv` attribute inherited from OlmoConfig as removed
    # for the modular-conversion tooling; Olmo2Attention.forward performs no qkv clipping. Confirm.
    clip_qkv = AttributeError()
- # OLMo2 RMS norm is identical to Llama RMS norm except:
- # - Weight and hidden states are multiplied before converting back to the input dtype, rather than after.
class Olmo2RMSNorm(LlamaRMSNorm):
    """RMS normalization that scales the float32-upcast activations before casting back.

    Identical to `LlamaRMSNorm` except that `self.weight` multiplies the normalized
    activations while they are still float32, and only the product is cast back to the
    caller's dtype.
    """

    def forward(self, hidden_states):
        target_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_sq = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_sq + self.variance_epsilon)
        scaled = self.weight * normalized
        return scaled.to(target_dtype)
class Olmo2RotaryEmbedding(OlmoRotaryEmbedding):
    """Rotary position embedding for OLMo2; reuses `OlmoRotaryEmbedding` without changes."""
def rotate_half(x):
    """Split the last dimension of ``x`` into halves ``(a, b)`` and return ``(-b, a)``.

    The split point is ``x.shape[-1] // 2``, so for an odd-sized last dimension the first
    half is the shorter one.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((back.neg(), front), dim=-1)
- # Olmo2 attention is identical to OLMo attention except:
- # - Norm is applied to attention queries and keys.
- # - No qkv clipping.
class Olmo2Attention(OlmoAttention):
    """OLMo attention with RMS normalization on the projected queries and keys.

    Compared with `OlmoAttention`, the outputs of `q_proj`/`k_proj` pass through
    `q_norm`/`k_norm` before being split into heads, and no qkv clipping is applied.
    """

    def __init__(self, config: Olmo2Config, layer_idx: int | None = None):
        super().__init__(config, layer_idx=layer_idx)
        eps = config.rms_norm_eps
        # The norms span the flattened projection width (num_heads * head_dim): they are
        # applied before the per-head split, so they normalize across all heads together.
        q_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim
        self.q_norm = Olmo2RMSNorm(q_width, eps)
        self.k_norm = Olmo2RMSNorm(kv_width, eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run attention over `hidden_states`.

        Args:
            hidden_states: input activations; the attention output is reshaped back to
                `hidden_states.shape[:-1] + (-1,)` before `o_proj`.
            position_embeddings: `(cos, sin)` pair passed to `apply_rotary_pos_emb`.
            attention_mask: forwarded unchanged to the selected attention backend.
            past_key_values: optional cache; when given, the new keys/values go through
                `past_key_values.update(...)` and the returned tensors are attended over.
            **kwargs: forwarded to the attention backend.

        Returns:
            `(attn_output, attn_weights)`: the `o_proj` output and whatever weights the
            backend returns.
        """
        lead_shape = hidden_states.shape[:-1]
        split_heads_shape = (*lead_shape, -1, self.head_dim)

        # Normalize q/k over the full projection width first, then split into heads.
        # NOTE(review): transpose(1, 2) assumes hidden_states is (batch, seq, hidden) — confirm.
        q = self.q_norm(self.q_proj(hidden_states)).view(split_heads_shape).transpose(1, 2)
        k = self.k_norm(self.k_proj(hidden_states)).view(split_heads_shape).transpose(1, 2)
        v = self.v_proj(hidden_states).view(split_heads_shape).transpose(1, 2)

        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        if past_key_values is not None:
            k, v = past_key_values.update(k, v, self.layer_idx)

        # Falls back to eager_attention_forward as the default interface.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        # Dropout only while training.
        dropout_p = self.attention_dropout if self.training else 0.0
        attn_output, attn_weights = attention_interface(
            self, q, k, v, attention_mask, dropout=dropout_p, scaling=self.scaling, **kwargs
        )

        merged = attn_output.reshape(*lead_shape, -1).contiguous()
        return self.o_proj(merged), attn_weights
- # The OLMo2 layers are identical to those of the OLMo model except:
- # - RMSNorm is used instead of standard layer norm.
- # - Norm is applied after attention/feedforward rather than before.
class Olmo2DecoderLayer(OlmoDecoderLayer):
    """OLMo2 decoder block with post-norm residual attention and MLP sub-layers.

    Unlike `OlmoDecoderLayer`, the RMS norms are applied to the outputs of attention and
    of the MLP, before those outputs are added to the residual stream. There is no
    input layer norm.
    """

    def __init__(self, config: Olmo2Config, layer_idx: int):
        super().__init__(config, layer_idx=layer_idx)
        eps = config.rms_norm_eps
        width = config.hidden_size
        self.post_attention_layernorm = Olmo2RMSNorm(width, eps=eps)
        self.post_feedforward_layernorm = Olmo2RMSNorm(width, eps=eps)
        self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx)
        # The parent's pre-attention norm has no role in the post-norm layout.
        del self.input_layernorm

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Apply attention and then the MLP, each as `x + norm(sublayer(x))`.

        All arguments other than `hidden_states` are forwarded to `self.self_attn`.
        The attention weights it returns are discarded. Returns the updated hidden states.
        """
        attn_out, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        stream = hidden_states + self.post_attention_layernorm(attn_out)
        stream = stream + self.post_feedforward_layernorm(self.mlp(stream))
        return stream
class Olmo2PreTrainedModel(LlamaPreTrainedModel):
    """Base class for OLMo2 models; reuses `LlamaPreTrainedModel` without changes."""
- # The OLMo2 model is identical to the OLMo model, except RMSNorm is used instead of
- # standard layer norm for the output norm.
class Olmo2Model(OlmoModel):
    """OLMo2 decoder stack.

    Builds on `OlmoModel`, then swaps in an RMS-norm final norm and a fresh list of
    `Olmo2DecoderLayer`s.
    """

    def __init__(self, config: Olmo2Config):
        super().__init__(config)
        self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.layers = nn.ModuleList(
            Olmo2DecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
        )
# The causal-LM head is inherited unchanged; only its inner model needs to be the OLMo2 one (`Olmo2Model`).
class Olmo2ForCausalLM(OlmoForCausalLM):
    """Causal language-modeling model for OLMo2; inherits its implementation from `OlmoForCausalLM`.

    NOTE(review): presumably the modular converter rewires the inner model to `Olmo2Model`
    in the generated modeling file. Confirm.
    """
# Public API of this module.
__all__ = ["Olmo2Config", "Olmo2ForCausalLM", "Olmo2Model", "Olmo2PreTrainedModel"]
|