modular_olmo2.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. # Copyright 2024 HuggingFace Inc. team. All rights reserved.
  2. #
  3. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  4. # and OPT implementations in this library. It has been modified from its
  5. # original forms to accommodate minor architectural differences compared
  6. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. from collections.abc import Callable
  20. import torch
  21. import torch.nn as nn
  22. from huggingface_hub.dataclasses import strict
  23. from transformers.utils.generic import TransformersKwargs
  24. from ...cache_utils import Cache
  25. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  26. from ...processing_utils import Unpack
  27. from ...utils import auto_docstring, logging
  28. from ..llama.modeling_llama import LlamaPreTrainedModel, LlamaRMSNorm, eager_attention_forward
  29. from ..olmo.configuration_olmo import OlmoConfig
  30. from ..olmo.modeling_olmo import (
  31. OlmoAttention,
  32. OlmoDecoderLayer,
  33. OlmoForCausalLM,
  34. OlmoModel,
  35. OlmoRotaryEmbedding,
  36. apply_rotary_pos_emb,
  37. )
# Module-level logger, obtained from the library's `logging` utilities and named after this module.
logger = logging.get_logger(__name__)
  39. @auto_docstring(checkpoint="allenai/Olmo2-7B-1124-hf")
  40. @strict
  41. class Olmo2Config(OlmoConfig):
  42. r"""
  43. Example:
  44. ```python
  45. >>> from transformers import Olmo2Model, Olmo2Config
  46. >>> # Initializing a Olmo2 7B style configuration
  47. >>> configuration = Olmo2Config()
  48. >>> # Initializing a model from the Olmo2 7B style configuration
  49. >>> model = Olmo2Model(configuration)
  50. >>> # Accessing the model configuration
  51. >>> configuration = model.config
  52. ```
  53. """
  54. model_type = "olmo2"
  55. base_model_tp_plan = {
  56. "layers.*.self_attn.q_proj": "colwise_gather_output", # we need to replicate here due to the added norm on q and k
  57. "layers.*.self_attn.k_proj": "colwise_gather_output", # we need to replicate here due to the added norm on q and k
  58. "layers.*.self_attn.v_proj": "colwise_gather_output", # we need to replicate here due to the added norm on q and k
  59. "layers.*.self_attn.o_proj": "rowwise_split_input", # input is replicated due to the added norm on q and k
  60. "layers.*.mlp.gate_proj": "colwise",
  61. "layers.*.mlp.up_proj": "colwise",
  62. "layers.*.mlp.down_proj": "rowwise",
  63. }
  64. base_model_pp_plan = {
  65. "embed_tokens": (["input_ids"], ["inputs_embeds"]),
  66. "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
  67. "norm": (["hidden_states"], ["hidden_states"]),
  68. }
  69. rms_norm_eps: float = 1e-5
  70. clip_qkv = AttributeError()
  71. # OLMo2 RMS norm is identical to Llama RMS norm except:
  72. # - Weight and hidden states are multiplied before converting back to the input dtype, rather than after.
  73. class Olmo2RMSNorm(LlamaRMSNorm):
  74. def forward(self, hidden_states):
  75. input_dtype = hidden_states.dtype
  76. hidden_states = hidden_states.to(torch.float32)
  77. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  78. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  79. return (self.weight * hidden_states).to(input_dtype)
# Olmo2 reuses the OLMo rotary embedding without any change; this subclass only
# provides the Olmo2-prefixed name.
class Olmo2RotaryEmbedding(OlmoRotaryEmbedding):
    pass
  82. def rotate_half(x):
  83. """Rotates half the hidden dims of the input."""
  84. x1 = x[..., : x.shape[-1] // 2]
  85. x2 = x[..., x.shape[-1] // 2 :]
  86. return torch.cat((-x2, x1), dim=-1)
  87. # Olmo2 attention is identical to OLMo attention except:
  88. # - Norm is applied to attention queries and keys.
  89. # - No qkv clipping.
  90. class Olmo2Attention(OlmoAttention):
  91. def __init__(self, config: Olmo2Config, layer_idx: int | None = None):
  92. super().__init__(config, layer_idx=layer_idx)
  93. self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
  94. self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
  95. def forward(
  96. self,
  97. hidden_states: torch.Tensor,
  98. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  99. attention_mask: torch.Tensor | None,
  100. past_key_values: Cache | None = None,
  101. **kwargs: Unpack[TransformersKwargs],
  102. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  103. input_shape = hidden_states.shape[:-1]
  104. hidden_shape = (*input_shape, -1, self.head_dim)
  105. query_states = self.q_norm(self.q_proj(hidden_states))
  106. key_states = self.k_norm(self.k_proj(hidden_states))
  107. value_states = self.v_proj(hidden_states)
  108. query_states = query_states.view(hidden_shape).transpose(1, 2)
  109. key_states = key_states.view(hidden_shape).transpose(1, 2)
  110. value_states = value_states.view(hidden_shape).transpose(1, 2)
  111. cos, sin = position_embeddings
  112. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  113. if past_key_values is not None:
  114. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  115. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  116. self.config._attn_implementation, eager_attention_forward
  117. )
  118. attn_output, attn_weights = attention_interface(
  119. self,
  120. query_states,
  121. key_states,
  122. value_states,
  123. attention_mask,
  124. dropout=0.0 if not self.training else self.attention_dropout,
  125. scaling=self.scaling,
  126. **kwargs,
  127. )
  128. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  129. attn_output = self.o_proj(attn_output)
  130. return attn_output, attn_weights
  131. # The OLMo2 layers are identical to those of the OLMo model except:
  132. # - RMSNorm is used instead of standard layer norm.
  133. # - Norm is applied after attention/feedforward rather than before.
  134. class Olmo2DecoderLayer(OlmoDecoderLayer):
  135. def __init__(self, config: Olmo2Config, layer_idx: int):
  136. super().__init__(config, layer_idx=layer_idx)
  137. self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  138. self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  139. self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx)
  140. del self.input_layernorm
  141. def forward(
  142. self,
  143. hidden_states: torch.Tensor,
  144. attention_mask: torch.Tensor | None = None,
  145. position_ids: torch.LongTensor | None = None,
  146. past_key_values: Cache | None = None,
  147. use_cache: bool | None = False,
  148. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  149. **kwargs: Unpack[TransformersKwargs],
  150. ) -> torch.Tensor:
  151. residual = hidden_states
  152. hidden_states, _ = self.self_attn(
  153. hidden_states=hidden_states,
  154. attention_mask=attention_mask,
  155. position_ids=position_ids,
  156. past_key_values=past_key_values,
  157. use_cache=use_cache,
  158. position_embeddings=position_embeddings,
  159. **kwargs,
  160. )
  161. hidden_states = self.post_attention_layernorm(hidden_states)
  162. hidden_states = residual + hidden_states
  163. # Fully Connected
  164. residual = hidden_states
  165. hidden_states = self.mlp(hidden_states)
  166. hidden_states = self.post_feedforward_layernorm(hidden_states)
  167. hidden_states = residual + hidden_states
  168. return hidden_states
# Olmo2 takes its pretrained-model base class from Llama without modification; this
# subclass only provides the Olmo2-prefixed name.
class Olmo2PreTrainedModel(LlamaPreTrainedModel):
    pass
  171. # The OLMo2 model is identical to the OLMo model, except RMSNorm is used instead of
  172. # standard layer norm for the output norm.
  173. class Olmo2Model(OlmoModel):
  174. def __init__(self, config: Olmo2Config):
  175. super().__init__(config)
  176. self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  177. self.layers = nn.ModuleList(
  178. [Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  179. )
# The head only needs to be redefined so that the model inside is the correct `Olmo2Model`.
class Olmo2ForCausalLM(OlmoForCausalLM):
    pass
# Names exported by `from <this module> import *`.
__all__ = [
    "Olmo2Config",
    "Olmo2ForCausalLM",
    "Olmo2Model",
    "Olmo2PreTrainedModel",
]