modeling_olmo2.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/olmo2/modular_olmo2.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_olmo2.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 HuggingFace Inc. team. All rights reserved.
  8. #
  9. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  10. # and OPT implementations in this library. It has been modified from its
  11. # original forms to accommodate minor architectural differences compared
  12. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  13. #
  14. # Licensed under the Apache License, Version 2.0 (the "License");
  15. # you may not use this file except in compliance with the License.
  16. # You may obtain a copy of the License at
  17. #
  18. # http://www.apache.org/licenses/LICENSE-2.0
  19. #
  20. # Unless required by applicable law or agreed to in writing, software
  21. # distributed under the License is distributed on an "AS IS" BASIS,
  22. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  23. # See the License for the specific language governing permissions and
  24. # limitations under the License.
  25. from collections.abc import Callable
  26. from typing import Optional
  27. import torch
  28. import torch.nn as nn
  29. from transformers.utils.generic import TransformersKwargs
  30. from ...activations import ACT2FN
  31. from ...cache_utils import Cache, DynamicCache
  32. from ...generation import GenerationMixin
  33. from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
  34. from ...masking_utils import create_causal_mask
  35. from ...modeling_layers import GradientCheckpointingLayer
  36. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  37. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import auto_docstring, can_return_tuple
  41. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  42. from ...utils.output_capturing import capture_outputs
  43. from .configuration_olmo2 import Olmo2Config
  44. @use_kernel_forward_from_hub("RMSNorm")
  45. class Olmo2RMSNorm(nn.Module):
  46. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  47. """
  48. Olmo2RMSNorm is equivalent to T5LayerNorm
  49. """
  50. super().__init__()
  51. self.weight = nn.Parameter(torch.ones(hidden_size))
  52. self.variance_epsilon = eps
  53. def forward(self, hidden_states) -> torch.Tensor:
  54. input_dtype = hidden_states.dtype
  55. hidden_states = hidden_states.to(torch.float32)
  56. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  57. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  58. return (self.weight * hidden_states).to(input_dtype)
  59. def extra_repr(self):
  60. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  61. class Olmo2RotaryEmbedding(nn.Module):
  62. inv_freq: torch.Tensor # fix linting for `register_buffer`
  63. def __init__(self, config: Olmo2Config, device=None):
  64. super().__init__()
  65. self.max_seq_len_cached = config.max_position_embeddings
  66. self.original_max_seq_len = config.max_position_embeddings
  67. self.config = config
  68. self.rope_type = self.config.rope_parameters["rope_type"]
  69. rope_init_fn: Callable = self.compute_default_rope_parameters
  70. if self.rope_type != "default":
  71. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  72. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  73. self.register_buffer("inv_freq", inv_freq, persistent=False)
  74. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  75. @staticmethod
  76. def compute_default_rope_parameters(
  77. config: Olmo2Config | None = None,
  78. device: Optional["torch.device"] = None,
  79. seq_len: int | None = None,
  80. ) -> tuple["torch.Tensor", float]:
  81. """
  82. Computes the inverse frequencies according to the original RoPE implementation
  83. Args:
  84. config ([`~transformers.PreTrainedConfig`]):
  85. The model configuration.
  86. device (`torch.device`):
  87. The device to use for initialization of the inverse frequencies.
  88. seq_len (`int`, *optional*):
  89. The current sequence length. Unused for this type of RoPE.
  90. Returns:
  91. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  92. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  93. """
  94. base = config.rope_parameters["rope_theta"]
  95. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  96. attention_factor = 1.0 # Unused in this type of RoPE
  97. # Compute the inverse frequencies
  98. inv_freq = 1.0 / (
  99. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  100. )
  101. return inv_freq, attention_factor
  102. @torch.no_grad()
  103. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  104. def forward(self, x, position_ids):
  105. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  106. position_ids_expanded = position_ids[:, None, :].float()
  107. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  108. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  109. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  110. emb = torch.cat((freqs, freqs), dim=-1)
  111. cos = emb.cos() * self.attention_scaling
  112. sin = emb.sin() * self.attention_scaling
  113. return cos, sin
  114. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  115. """
  116. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  117. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  118. """
  119. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  120. if n_rep == 1:
  121. return hidden_states
  122. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  123. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  124. def eager_attention_forward(
  125. module: nn.Module,
  126. query: torch.Tensor,
  127. key: torch.Tensor,
  128. value: torch.Tensor,
  129. attention_mask: torch.Tensor | None,
  130. scaling: float,
  131. dropout: float = 0.0,
  132. **kwargs: Unpack[TransformersKwargs],
  133. ):
  134. key_states = repeat_kv(key, module.num_key_value_groups)
  135. value_states = repeat_kv(value, module.num_key_value_groups)
  136. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  137. if attention_mask is not None:
  138. attn_weights = attn_weights + attention_mask
  139. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  140. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  141. attn_output = torch.matmul(attn_weights, value_states)
  142. attn_output = attn_output.transpose(1, 2).contiguous()
  143. return attn_output, attn_weights
  144. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  145. """Applies Rotary Position Embedding to the query and key tensors.
  146. Args:
  147. q (`torch.Tensor`): The query tensor.
  148. k (`torch.Tensor`): The key tensor.
  149. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  150. sin (`torch.Tensor`): The sine part of the rotary embedding.
  151. unsqueeze_dim (`int`, *optional*, defaults to 1):
  152. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  153. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  154. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  155. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  156. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  157. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  158. Returns:
  159. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  160. """
  161. q_type, k_type = q.dtype, k.dtype
  162. cos = cos.unsqueeze(unsqueeze_dim)
  163. sin = sin.unsqueeze(unsqueeze_dim)
  164. q_embed = (q * cos) + (rotate_half(q) * sin)
  165. k_embed = (k * cos) + (rotate_half(k) * sin)
  166. return q_embed.to(q_type), k_embed.to(k_type)
  167. def rotate_half(x):
  168. """Rotates half the hidden dims of the input."""
  169. x1 = x[..., : x.shape[-1] // 2]
  170. x2 = x[..., x.shape[-1] // 2 :]
  171. return torch.cat((-x2, x1), dim=-1)
  172. @use_kernelized_func(apply_rotary_pos_emb)
  173. class Olmo2Attention(nn.Module):
  174. """Multi-headed attention from 'Attention Is All You Need' paper"""
  175. def __init__(self, config: Olmo2Config, layer_idx: int | None = None):
  176. super().__init__()
  177. self.config = config
  178. self.layer_idx = layer_idx
  179. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  180. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  181. self.scaling = self.head_dim**-0.5
  182. self.attention_dropout = config.attention_dropout
  183. self.is_causal = True
  184. self.q_proj = nn.Linear(
  185. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  186. )
  187. self.k_proj = nn.Linear(
  188. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  189. )
  190. self.v_proj = nn.Linear(
  191. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  192. )
  193. self.o_proj = nn.Linear(
  194. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  195. )
  196. self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
  197. self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
  198. def forward(
  199. self,
  200. hidden_states: torch.Tensor,
  201. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  202. attention_mask: torch.Tensor | None,
  203. past_key_values: Cache | None = None,
  204. **kwargs: Unpack[TransformersKwargs],
  205. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  206. input_shape = hidden_states.shape[:-1]
  207. hidden_shape = (*input_shape, -1, self.head_dim)
  208. query_states = self.q_norm(self.q_proj(hidden_states))
  209. key_states = self.k_norm(self.k_proj(hidden_states))
  210. value_states = self.v_proj(hidden_states)
  211. query_states = query_states.view(hidden_shape).transpose(1, 2)
  212. key_states = key_states.view(hidden_shape).transpose(1, 2)
  213. value_states = value_states.view(hidden_shape).transpose(1, 2)
  214. cos, sin = position_embeddings
  215. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  216. if past_key_values is not None:
  217. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  218. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  219. self.config._attn_implementation, eager_attention_forward
  220. )
  221. attn_output, attn_weights = attention_interface(
  222. self,
  223. query_states,
  224. key_states,
  225. value_states,
  226. attention_mask,
  227. dropout=0.0 if not self.training else self.attention_dropout,
  228. scaling=self.scaling,
  229. **kwargs,
  230. )
  231. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  232. attn_output = self.o_proj(attn_output)
  233. return attn_output, attn_weights
  234. class Olmo2MLP(nn.Module):
  235. def __init__(self, config):
  236. super().__init__()
  237. self.config = config
  238. self.hidden_size = config.hidden_size
  239. self.intermediate_size = config.intermediate_size
  240. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  241. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  242. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  243. self.act_fn = ACT2FN[config.hidden_act]
  244. def forward(self, x):
  245. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  246. return down_proj
  247. class Olmo2DecoderLayer(GradientCheckpointingLayer):
  248. def __init__(self, config: Olmo2Config, layer_idx: int):
  249. super().__init__()
  250. self.hidden_size = config.hidden_size
  251. self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx)
  252. self.mlp = Olmo2MLP(config)
  253. self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  254. self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  255. def forward(
  256. self,
  257. hidden_states: torch.Tensor,
  258. attention_mask: torch.Tensor | None = None,
  259. position_ids: torch.LongTensor | None = None,
  260. past_key_values: Cache | None = None,
  261. use_cache: bool | None = False,
  262. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  263. **kwargs: Unpack[TransformersKwargs],
  264. ) -> torch.Tensor:
  265. residual = hidden_states
  266. hidden_states, _ = self.self_attn(
  267. hidden_states=hidden_states,
  268. attention_mask=attention_mask,
  269. position_ids=position_ids,
  270. past_key_values=past_key_values,
  271. use_cache=use_cache,
  272. position_embeddings=position_embeddings,
  273. **kwargs,
  274. )
  275. hidden_states = self.post_attention_layernorm(hidden_states)
  276. hidden_states = residual + hidden_states
  277. # Fully Connected
  278. residual = hidden_states
  279. hidden_states = self.mlp(hidden_states)
  280. hidden_states = self.post_feedforward_layernorm(hidden_states)
  281. hidden_states = residual + hidden_states
  282. return hidden_states
  283. @auto_docstring
  284. class Olmo2PreTrainedModel(PreTrainedModel):
  285. config: Olmo2Config
  286. base_model_prefix = "model"
  287. supports_gradient_checkpointing = True
  288. _no_split_modules = ["Olmo2DecoderLayer"]
  289. _skip_keys_device_placement = ["past_key_values"]
  290. _supports_flash_attn = True
  291. _supports_sdpa = True
  292. _supports_flex_attn = True
  293. _can_compile_fullgraph = True
  294. _supports_attention_backend = True
  295. _can_record_outputs = {
  296. "hidden_states": Olmo2DecoderLayer,
  297. "attentions": Olmo2Attention,
  298. }
  299. @auto_docstring
  300. class Olmo2Model(Olmo2PreTrainedModel):
  301. def __init__(self, config: Olmo2Config):
  302. super().__init__(config)
  303. self.padding_idx = config.pad_token_id
  304. self.vocab_size = config.vocab_size
  305. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  306. self.layers = nn.ModuleList(
  307. [Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  308. )
  309. self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  310. self.rotary_emb = Olmo2RotaryEmbedding(config=config)
  311. self.gradient_checkpointing = False
  312. # Initialize weights and apply final processing
  313. self.post_init()
  314. @merge_with_config_defaults
  315. @capture_outputs
  316. @auto_docstring
  317. def forward(
  318. self,
  319. input_ids: torch.LongTensor | None = None,
  320. attention_mask: torch.Tensor | None = None,
  321. position_ids: torch.LongTensor | None = None,
  322. past_key_values: Cache | None = None,
  323. inputs_embeds: torch.FloatTensor | None = None,
  324. use_cache: bool | None = None,
  325. **kwargs: Unpack[TransformersKwargs],
  326. ) -> BaseModelOutputWithPast:
  327. if (input_ids is None) ^ (inputs_embeds is not None):
  328. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  329. if inputs_embeds is None:
  330. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  331. if use_cache and past_key_values is None:
  332. past_key_values = DynamicCache(config=self.config)
  333. if position_ids is None:
  334. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  335. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  336. position_ids = position_ids.unsqueeze(0)
  337. causal_mask = create_causal_mask(
  338. config=self.config,
  339. inputs_embeds=inputs_embeds,
  340. attention_mask=attention_mask,
  341. past_key_values=past_key_values,
  342. position_ids=position_ids,
  343. )
  344. hidden_states = inputs_embeds
  345. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  346. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  347. hidden_states = decoder_layer(
  348. hidden_states,
  349. attention_mask=causal_mask,
  350. position_embeddings=position_embeddings,
  351. position_ids=position_ids,
  352. past_key_values=past_key_values,
  353. use_cache=use_cache,
  354. **kwargs,
  355. )
  356. hidden_states = self.norm(hidden_states)
  357. return BaseModelOutputWithPast(
  358. last_hidden_state=hidden_states,
  359. past_key_values=past_key_values,
  360. )
  361. @auto_docstring
  362. class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin):
  363. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  364. _tp_plan = {"lm_head": "colwise_gather_output"}
  365. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  366. def __init__(self, config):
  367. super().__init__(config)
  368. self.model = Olmo2Model(config)
  369. self.vocab_size = config.vocab_size
  370. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  371. # Initialize weights and apply final processing
  372. self.post_init()
  373. @can_return_tuple
  374. @auto_docstring
  375. def forward(
  376. self,
  377. input_ids: torch.LongTensor | None = None,
  378. attention_mask: torch.Tensor | None = None,
  379. position_ids: torch.LongTensor | None = None,
  380. past_key_values: Cache | None = None,
  381. inputs_embeds: torch.FloatTensor | None = None,
  382. labels: torch.LongTensor | None = None,
  383. use_cache: bool | None = None,
  384. logits_to_keep: int | torch.Tensor = 0,
  385. **kwargs: Unpack[TransformersKwargs],
  386. ) -> CausalLMOutputWithPast:
  387. r"""
  388. Example:
  389. ```python
  390. >>> from transformers import AutoTokenizer, Olmo2ForCausalLM
  391. >>> model = Olmo2ForCausalLM.from_pretrained("meta-olmo2/Olmo2-2-7b-hf")
  392. >>> tokenizer = AutoTokenizer.from_pretrained("meta-olmo2/Olmo2-2-7b-hf")
  393. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  394. >>> inputs = tokenizer(prompt, return_tensors="pt")
  395. >>> # Generate
  396. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  397. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  398. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  399. ```"""
  400. outputs: BaseModelOutputWithPast = self.model(
  401. input_ids=input_ids,
  402. attention_mask=attention_mask,
  403. position_ids=position_ids,
  404. past_key_values=past_key_values,
  405. inputs_embeds=inputs_embeds,
  406. use_cache=use_cache,
  407. **kwargs,
  408. )
  409. hidden_states = outputs.last_hidden_state
  410. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  411. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  412. logits = self.lm_head(hidden_states[:, slice_indices, :])
  413. loss = None
  414. if labels is not None:
  415. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  416. return CausalLMOutputWithPast(
  417. loss=loss,
  418. logits=logits,
  419. past_key_values=outputs.past_key_values,
  420. hidden_states=outputs.hidden_states,
  421. attentions=outputs.attentions,
  422. )
  423. __all__ = ["Olmo2ForCausalLM", "Olmo2Model", "Olmo2PreTrainedModel"]