modular_olmo.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. # Copyright 2024 HuggingFace Inc. team. All rights reserved.
  2. #
  3. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  4. # and OPT implementations in this library. It has been modified from its
  5. # original forms to accommodate minor architectural differences compared
  6. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. from collections.abc import Callable
  20. import torch
  21. import torch.nn as nn
  22. import torch.nn.functional as F
  23. from ...cache_utils import Cache
  24. from ...modeling_rope_utils import dynamic_rope_update
  25. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  26. from ...utils import logging
  27. from ...utils.generic import maybe_autocast
  28. from ..llama.modeling_llama import (
  29. LlamaAttention,
  30. LlamaDecoderLayer,
  31. LlamaForCausalLM,
  32. LlamaMLP,
  33. LlamaModel,
  34. LlamaRotaryEmbedding,
  35. eager_attention_forward,
  36. rotate_half,
  37. )
  38. from .configuration_olmo import OlmoConfig
  39. logger = logging.get_logger(__name__)
  40. class OlmoLayerNorm(nn.Module):
  41. """LayerNorm but with no learnable weight or bias."""
  42. def __init__(self, hidden_size: int) -> None:
  43. super().__init__()
  44. self.normalized_shape = (hidden_size,)
  45. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  46. orig_dtype = hidden_states.dtype
  47. return F.layer_norm(hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=1e-5).to(
  48. orig_dtype
  49. )
  50. class OlmoMLP(LlamaMLP):
  51. def __init__(self, config):
  52. super().__init__(config)
  53. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  54. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  55. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  56. # This is identical to LlamaRotaryEmbedding except the output cos and sin are returned
  57. # as float32 rather than the input type.
  58. class OlmoRotaryEmbedding(LlamaRotaryEmbedding):
  59. @torch.no_grad()
  60. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  61. def forward(self, x, position_ids):
  62. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  63. position_ids_expanded = position_ids[:, None, :].float()
  64. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  65. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  66. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  67. emb = torch.cat((freqs, freqs), dim=-1)
  68. cos = emb.cos() * self.attention_scaling
  69. sin = emb.sin() * self.attention_scaling
  70. return cos, sin
  71. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  72. """Applies Rotary Position Embedding to the query and key tensors.
  73. Args:
  74. q (`torch.Tensor`): The query tensor.
  75. k (`torch.Tensor`): The key tensor.
  76. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  77. sin (`torch.Tensor`): The sine part of the rotary embedding.
  78. unsqueeze_dim (`int`, *optional*, defaults to 1):
  79. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  80. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  81. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  82. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  83. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  84. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  85. Returns:
  86. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  87. """
  88. q_type, k_type = q.dtype, k.dtype
  89. cos = cos.unsqueeze(unsqueeze_dim)
  90. sin = sin.unsqueeze(unsqueeze_dim)
  91. q_embed = (q * cos) + (rotate_half(q) * sin)
  92. k_embed = (k * cos) + (rotate_half(k) * sin)
  93. return q_embed.to(q_type), k_embed.to(k_type)
  94. class OlmoAttention(LlamaAttention):
  95. def forward(
  96. self,
  97. hidden_states: torch.Tensor,
  98. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  99. attention_mask: torch.Tensor | None,
  100. past_key_values: Cache | None = None,
  101. **kwargs,
  102. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  103. input_shape = hidden_states.shape[:-1]
  104. hidden_shape = (*input_shape, -1, self.head_dim)
  105. query_states = self.q_proj(hidden_states)
  106. key_states = self.k_proj(hidden_states)
  107. value_states = self.v_proj(hidden_states)
  108. if self.config.clip_qkv is not None:
  109. query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  110. key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  111. value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  112. query_states = query_states.view(hidden_shape).transpose(1, 2)
  113. key_states = key_states.view(hidden_shape).transpose(1, 2)
  114. value_states = value_states.view(hidden_shape).transpose(1, 2)
  115. cos, sin = position_embeddings
  116. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  117. if past_key_values is not None:
  118. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  119. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  120. self.config._attn_implementation, eager_attention_forward
  121. )
  122. attn_output, attn_weights = attention_interface(
  123. self,
  124. query_states,
  125. key_states,
  126. value_states,
  127. attention_mask,
  128. dropout=0.0 if not self.training else self.attention_dropout,
  129. scaling=self.scaling,
  130. **kwargs,
  131. )
  132. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  133. attn_output = self.o_proj(attn_output)
  134. return attn_output, attn_weights
  135. class OlmoDecoderLayer(LlamaDecoderLayer):
  136. def __init__(self, config: OlmoConfig, layer_idx: int):
  137. super().__init__(config, layer_idx)
  138. self.input_layernorm = OlmoLayerNorm(config.hidden_size)
  139. self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
  140. self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx)
  141. class OlmoModel(LlamaModel):
  142. def __init__(self, config: OlmoConfig):
  143. super().__init__(config)
  144. self.layers = nn.ModuleList(
  145. [OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  146. )
  147. self.norm = OlmoLayerNorm(config.hidden_size)
  148. class OlmoForCausalLM(LlamaForCausalLM):
  149. pass
  150. __all__ = [
  151. "OlmoForCausalLM",
  152. "OlmoModel",
  153. "OlmoPreTrainedModel", # noqa: F822
  154. ]