modular_olmo3.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. # Copyright 2025 the HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Callable
  15. import torch
  16. import torch.nn as nn
  17. from huggingface_hub.dataclasses import strict
  18. from ...cache_utils import Cache, DynamicCache
  19. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  20. from ...modeling_outputs import BaseModelOutputWithPast
  21. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  22. from ...processing_utils import Unpack
  23. from ...utils import auto_docstring
  24. from ...utils.generic import TransformersKwargs
  25. from ..gemma2.modeling_gemma2 import Gemma2RotaryEmbedding
  26. from ..olmo2.configuration_olmo2 import Olmo2Config
  27. from ..olmo2.modeling_olmo2 import (
  28. Olmo2Attention,
  29. Olmo2DecoderLayer,
  30. Olmo2ForCausalLM,
  31. Olmo2Model,
  32. Olmo2PreTrainedModel,
  33. Olmo2RMSNorm,
  34. apply_rotary_pos_emb,
  35. eager_attention_forward,
  36. )
  37. @auto_docstring(checkpoint="allenai/Olmo-3-7B-Instruct")
  38. @strict
  39. class Olmo3Config(Olmo2Config):
  40. r"""
  41. Example:
  42. ```python
  43. >>> from transformers import Olmo3Model, Olmo3Config
  44. >>> # Initializing a Olmo3 7B style configuration
  45. >>> configuration = Olmo3Config()
  46. >>> # Initializing a model from the Olmo3 7B style configuration
  47. >>> model = Olmo3Model(configuration)
  48. >>> # Accessing the model configuration
  49. >>> configuration = model.config
  50. ```
  51. """
  52. model_type = "olmo3"
  53. keys_to_ignore_at_inference = ["past_key_values"]
  54. base_model_tp_plan = {
  55. "layers.*.self_attn.q_proj": "colwise_gather_output", # we need to replicate here due to the added norm on q and k
  56. "layers.*.self_attn.k_proj": "colwise_gather_output", # we need to replicate here due to the added norm on q and k
  57. "layers.*.self_attn.v_proj": "colwise_gather_output", # we need to replicate here due to the added norm on q and k
  58. "layers.*.self_attn.o_proj": "rowwise_split_input", # input is replicated due to the added norm on q and k
  59. "layers.*.mlp.gate_proj": "colwise",
  60. "layers.*.mlp.up_proj": "colwise",
  61. "layers.*.mlp.down_proj": "rowwise",
  62. }
  63. base_model_pp_plan = {
  64. "embed_tokens": (["input_ids"], ["inputs_embeds"]),
  65. "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
  66. "norm": (["hidden_states"], ["hidden_states"]),
  67. }
  68. sliding_window: int | None = 4096
  69. layer_types: list[str] | None = None
  70. def __post_init__(self, **kwargs):
  71. if self.num_key_value_heads is None:
  72. self.num_key_value_heads = self.num_attention_heads
  73. if self.layer_types is None:
  74. self.layer_types = [
  75. "sliding_attention" if (i + 1) % 4 != 0 else "full_attention" for i in range(self.num_hidden_layers)
  76. ]
  77. super().__post_init__(**kwargs)
# RMSNorm taken over unchanged from OLMo 2; this subclass only provides the Olmo3-prefixed name.
class Olmo3RMSNorm(Olmo2RMSNorm):
    pass
  80. # Olmo3 attention is identical to OLMo 2 attention except:
  81. # - Sliding window attention is used for 3 out of 4 layers.
  82. class Olmo3Attention(Olmo2Attention):
  83. def __init__(self, config: Olmo3Config, layer_idx: int):
  84. super().__init__(config, layer_idx=layer_idx)
  85. self.attention_type = config.layer_types[layer_idx]
  86. self.sliding_window = config.sliding_window if self.attention_type == "sliding_attention" else None
  87. def forward(
  88. self,
  89. hidden_states: torch.Tensor,
  90. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  91. attention_mask: torch.Tensor | None,
  92. past_key_values: Cache | None = None,
  93. **kwargs: Unpack[TransformersKwargs],
  94. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  95. input_shape = hidden_states.shape[:-1]
  96. hidden_shape = (*input_shape, -1, self.head_dim)
  97. query_states = self.q_norm(self.q_proj(hidden_states))
  98. key_states = self.k_norm(self.k_proj(hidden_states))
  99. value_states = self.v_proj(hidden_states)
  100. query_states = query_states.view(hidden_shape).transpose(1, 2)
  101. key_states = key_states.view(hidden_shape).transpose(1, 2)
  102. value_states = value_states.view(hidden_shape).transpose(1, 2)
  103. cos, sin = position_embeddings
  104. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  105. if past_key_values is not None:
  106. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  107. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  108. self.config._attn_implementation, eager_attention_forward
  109. )
  110. attn_output, attn_weights = attention_interface(
  111. self,
  112. query_states,
  113. key_states,
  114. value_states,
  115. attention_mask,
  116. dropout=0.0 if not self.training else self.attention_dropout,
  117. scaling=self.scaling,
  118. sliding_window=self.sliding_window,
  119. **kwargs,
  120. )
  121. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  122. attn_output = self.o_proj(attn_output)
  123. return attn_output, attn_weights
# Decoder layer taken over unchanged from OLMo 2.
# NOTE(review): presumably the modular conversion resolves this layer's attention module to
# Olmo3Attention (which adds the sliding window); confirm in the generated modeling file.
class Olmo3DecoderLayer(Olmo2DecoderLayer):
    pass
# Rotary position embedding taken over unchanged from Gemma 2 (not from OLMo 2).
class Olmo3RotaryEmbedding(Gemma2RotaryEmbedding):
    pass
# Shared pretrained-model base class taken over unchanged from OLMo 2.
class Olmo3PreTrainedModel(Olmo2PreTrainedModel):
    pass
  130. # The OLMo 3 model is identical to the OLMo 2 model, except:
  131. # - Sliding window attention is used for 3 out of 4 layers.
  132. # - RoPE scaling is not applied to sliding window attention layers.
  133. class Olmo3Model(Olmo2Model):
  134. def __init__(self, config: Olmo3Config):
  135. super().__init__(config)
  136. self.norm = Olmo3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  137. self.layers = nn.ModuleList(
  138. [Olmo3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  139. )
  140. self.rotary_emb = Olmo3RotaryEmbedding(config=config)
  141. def forward(
  142. self,
  143. input_ids: torch.LongTensor | None = None,
  144. attention_mask: torch.Tensor | None = None,
  145. position_ids: torch.LongTensor | None = None,
  146. past_key_values: Cache | None = None,
  147. inputs_embeds: torch.FloatTensor | None = None,
  148. use_cache: bool | None = None,
  149. **kwargs: Unpack[TransformersKwargs],
  150. ) -> BaseModelOutputWithPast:
  151. if (input_ids is None) ^ (inputs_embeds is not None):
  152. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  153. if inputs_embeds is None:
  154. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  155. if use_cache and past_key_values is None:
  156. past_key_values = DynamicCache(config=self.config)
  157. if position_ids is None:
  158. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  159. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  160. position_ids = position_ids.unsqueeze(0)
  161. # It may already have been prepared by e.g. `generate`
  162. if not isinstance(causal_mask_mapping := attention_mask, dict):
  163. # Prepare mask arguments
  164. mask_kwargs = {
  165. "config": self.config,
  166. "inputs_embeds": inputs_embeds,
  167. "attention_mask": attention_mask,
  168. "past_key_values": past_key_values,
  169. "position_ids": position_ids,
  170. }
  171. # Create the masks
  172. causal_mask_mapping = {
  173. "full_attention": create_causal_mask(**mask_kwargs),
  174. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  175. }
  176. hidden_states = inputs_embeds
  177. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  178. for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
  179. hidden_states = decoder_layer(
  180. hidden_states,
  181. attention_mask=causal_mask_mapping[self.config.layer_types[i]],
  182. position_ids=position_ids,
  183. past_key_values=past_key_values,
  184. position_embeddings=position_embeddings,
  185. **kwargs,
  186. )
  187. hidden_states = self.norm(hidden_states)
  188. return BaseModelOutputWithPast(
  189. last_hidden_state=hidden_states,
  190. past_key_values=past_key_values,
  191. )
# Causal language-modeling head taken over unchanged from OLMo 2.
# NOTE(review): presumably the modular conversion makes it wrap Olmo3Model; confirm in the
# generated modeling file.
class Olmo3ForCausalLM(Olmo2ForCausalLM):
    pass
  194. __all__ = [
  195. "Olmo3Config",
  196. "Olmo3ForCausalLM",
  197. "Olmo3Model",
  198. "Olmo3PreTrainedModel",
  199. ]