modeling_olmo.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/olmo/modular_olmo.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_olmo.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 HuggingFace Inc. team. All rights reserved.
  8. #
  9. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  10. # and OPT implementations in this library. It has been modified from its
  11. # original forms to accommodate minor architectural differences compared
  12. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  13. #
  14. # Licensed under the Apache License, Version 2.0 (the "License");
  15. # you may not use this file except in compliance with the License.
  16. # You may obtain a copy of the License at
  17. #
  18. # http://www.apache.org/licenses/LICENSE-2.0
  19. #
  20. # Unless required by applicable law or agreed to in writing, software
  21. # distributed under the License is distributed on an "AS IS" BASIS,
  22. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  23. # See the License for the specific language governing permissions and
  24. # limitations under the License.
  25. from collections.abc import Callable
  26. from typing import Optional
  27. import torch
  28. import torch.nn as nn
  29. import torch.nn.functional as F
  30. from ...activations import ACT2FN
  31. from ...cache_utils import Cache, DynamicCache
  32. from ...generation import GenerationMixin
  33. from ...integrations import use_kernelized_func
  34. from ...masking_utils import create_causal_mask
  35. from ...modeling_layers import GradientCheckpointingLayer
  36. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  37. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  41. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  42. from ...utils.output_capturing import capture_outputs
  43. from .configuration_olmo import OlmoConfig
  44. class OlmoLayerNorm(nn.Module):
  45. """LayerNorm but with no learnable weight or bias."""
  46. def __init__(self, hidden_size: int) -> None:
  47. super().__init__()
  48. self.normalized_shape = (hidden_size,)
  49. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  50. orig_dtype = hidden_states.dtype
  51. return F.layer_norm(hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=1e-5).to(
  52. orig_dtype
  53. )
  54. class OlmoMLP(nn.Module):
  55. def __init__(self, config):
  56. super().__init__()
  57. self.config = config
  58. self.hidden_size = config.hidden_size
  59. self.intermediate_size = config.intermediate_size
  60. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  61. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  62. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  63. self.act_fn = ACT2FN[config.hidden_act]
  64. def forward(self, x):
  65. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  66. return down_proj
  67. class OlmoRotaryEmbedding(nn.Module):
  68. inv_freq: torch.Tensor # fix linting for `register_buffer`
  69. def __init__(self, config: OlmoConfig, device=None):
  70. super().__init__()
  71. self.max_seq_len_cached = config.max_position_embeddings
  72. self.original_max_seq_len = config.max_position_embeddings
  73. self.config = config
  74. self.rope_type = self.config.rope_parameters["rope_type"]
  75. rope_init_fn: Callable = self.compute_default_rope_parameters
  76. if self.rope_type != "default":
  77. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  78. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  79. self.register_buffer("inv_freq", inv_freq, persistent=False)
  80. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  81. @staticmethod
  82. def compute_default_rope_parameters(
  83. config: OlmoConfig | None = None,
  84. device: Optional["torch.device"] = None,
  85. seq_len: int | None = None,
  86. ) -> tuple["torch.Tensor", float]:
  87. """
  88. Computes the inverse frequencies according to the original RoPE implementation
  89. Args:
  90. config ([`~transformers.PreTrainedConfig`]):
  91. The model configuration.
  92. device (`torch.device`):
  93. The device to use for initialization of the inverse frequencies.
  94. seq_len (`int`, *optional*):
  95. The current sequence length. Unused for this type of RoPE.
  96. Returns:
  97. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  98. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  99. """
  100. base = config.rope_parameters["rope_theta"]
  101. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  102. attention_factor = 1.0 # Unused in this type of RoPE
  103. # Compute the inverse frequencies
  104. inv_freq = 1.0 / (
  105. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  106. )
  107. return inv_freq, attention_factor
  108. @torch.no_grad()
  109. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  110. def forward(self, x, position_ids):
  111. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  112. position_ids_expanded = position_ids[:, None, :].float()
  113. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  114. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  115. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  116. emb = torch.cat((freqs, freqs), dim=-1)
  117. cos = emb.cos() * self.attention_scaling
  118. sin = emb.sin() * self.attention_scaling
  119. return cos, sin
  120. def rotate_half(x):
  121. """Rotates half the hidden dims of the input."""
  122. x1 = x[..., : x.shape[-1] // 2]
  123. x2 = x[..., x.shape[-1] // 2 :]
  124. return torch.cat((-x2, x1), dim=-1)
  125. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  126. """
  127. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  128. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  129. """
  130. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  131. if n_rep == 1:
  132. return hidden_states
  133. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  134. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  135. def eager_attention_forward(
  136. module: nn.Module,
  137. query: torch.Tensor,
  138. key: torch.Tensor,
  139. value: torch.Tensor,
  140. attention_mask: torch.Tensor | None,
  141. scaling: float,
  142. dropout: float = 0.0,
  143. **kwargs: Unpack[TransformersKwargs],
  144. ):
  145. key_states = repeat_kv(key, module.num_key_value_groups)
  146. value_states = repeat_kv(value, module.num_key_value_groups)
  147. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  148. if attention_mask is not None:
  149. attn_weights = attn_weights + attention_mask
  150. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  151. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  152. attn_output = torch.matmul(attn_weights, value_states)
  153. attn_output = attn_output.transpose(1, 2).contiguous()
  154. return attn_output, attn_weights
  155. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  156. """Applies Rotary Position Embedding to the query and key tensors.
  157. Args:
  158. q (`torch.Tensor`): The query tensor.
  159. k (`torch.Tensor`): The key tensor.
  160. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  161. sin (`torch.Tensor`): The sine part of the rotary embedding.
  162. unsqueeze_dim (`int`, *optional*, defaults to 1):
  163. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  164. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  165. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  166. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  167. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  168. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  169. Returns:
  170. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  171. """
  172. q_type, k_type = q.dtype, k.dtype
  173. cos = cos.unsqueeze(unsqueeze_dim)
  174. sin = sin.unsqueeze(unsqueeze_dim)
  175. q_embed = (q * cos) + (rotate_half(q) * sin)
  176. k_embed = (k * cos) + (rotate_half(k) * sin)
  177. return q_embed.to(q_type), k_embed.to(k_type)
  178. @use_kernelized_func(apply_rotary_pos_emb)
  179. class OlmoAttention(nn.Module):
  180. """Multi-headed attention from 'Attention Is All You Need' paper"""
  181. def __init__(self, config: OlmoConfig, layer_idx: int):
  182. super().__init__()
  183. self.config = config
  184. self.layer_idx = layer_idx
  185. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  186. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  187. self.scaling = self.head_dim**-0.5
  188. self.attention_dropout = config.attention_dropout
  189. self.is_causal = True
  190. self.q_proj = nn.Linear(
  191. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  192. )
  193. self.k_proj = nn.Linear(
  194. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  195. )
  196. self.v_proj = nn.Linear(
  197. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  198. )
  199. self.o_proj = nn.Linear(
  200. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  201. )
  202. def forward(
  203. self,
  204. hidden_states: torch.Tensor,
  205. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  206. attention_mask: torch.Tensor | None,
  207. past_key_values: Cache | None = None,
  208. **kwargs,
  209. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  210. input_shape = hidden_states.shape[:-1]
  211. hidden_shape = (*input_shape, -1, self.head_dim)
  212. query_states = self.q_proj(hidden_states)
  213. key_states = self.k_proj(hidden_states)
  214. value_states = self.v_proj(hidden_states)
  215. if self.config.clip_qkv is not None:
  216. query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  217. key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  218. value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  219. query_states = query_states.view(hidden_shape).transpose(1, 2)
  220. key_states = key_states.view(hidden_shape).transpose(1, 2)
  221. value_states = value_states.view(hidden_shape).transpose(1, 2)
  222. cos, sin = position_embeddings
  223. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  224. if past_key_values is not None:
  225. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  226. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  227. self.config._attn_implementation, eager_attention_forward
  228. )
  229. attn_output, attn_weights = attention_interface(
  230. self,
  231. query_states,
  232. key_states,
  233. value_states,
  234. attention_mask,
  235. dropout=0.0 if not self.training else self.attention_dropout,
  236. scaling=self.scaling,
  237. **kwargs,
  238. )
  239. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  240. attn_output = self.o_proj(attn_output)
  241. return attn_output, attn_weights
  242. class OlmoDecoderLayer(GradientCheckpointingLayer):
  243. def __init__(self, config: OlmoConfig, layer_idx: int):
  244. super().__init__()
  245. self.hidden_size = config.hidden_size
  246. self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx)
  247. self.mlp = OlmoMLP(config)
  248. self.input_layernorm = OlmoLayerNorm(config.hidden_size)
  249. self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
  250. def forward(
  251. self,
  252. hidden_states: torch.Tensor,
  253. attention_mask: torch.Tensor | None = None,
  254. position_ids: torch.LongTensor | None = None,
  255. past_key_values: Cache | None = None,
  256. use_cache: bool | None = False,
  257. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  258. **kwargs: Unpack[TransformersKwargs],
  259. ) -> torch.Tensor:
  260. residual = hidden_states
  261. hidden_states = self.input_layernorm(hidden_states)
  262. # Self Attention
  263. hidden_states, _ = self.self_attn(
  264. hidden_states=hidden_states,
  265. attention_mask=attention_mask,
  266. position_ids=position_ids,
  267. past_key_values=past_key_values,
  268. use_cache=use_cache,
  269. position_embeddings=position_embeddings,
  270. **kwargs,
  271. )
  272. hidden_states = residual + hidden_states
  273. # Fully Connected
  274. residual = hidden_states
  275. hidden_states = self.post_attention_layernorm(hidden_states)
  276. hidden_states = self.mlp(hidden_states)
  277. hidden_states = residual + hidden_states
  278. return hidden_states
  279. @auto_docstring
  280. class OlmoPreTrainedModel(PreTrainedModel):
  281. config: OlmoConfig
  282. base_model_prefix = "model"
  283. supports_gradient_checkpointing = True
  284. _no_split_modules = ["OlmoDecoderLayer"]
  285. _skip_keys_device_placement = ["past_key_values"]
  286. _supports_flash_attn = True
  287. _supports_sdpa = True
  288. _supports_flex_attn = True
  289. _can_compile_fullgraph = True
  290. _supports_attention_backend = True
  291. _can_record_outputs = {
  292. "hidden_states": OlmoDecoderLayer,
  293. "attentions": OlmoAttention,
  294. }
  295. @auto_docstring
  296. class OlmoModel(OlmoPreTrainedModel):
  297. def __init__(self, config: OlmoConfig):
  298. super().__init__(config)
  299. self.padding_idx = config.pad_token_id
  300. self.vocab_size = config.vocab_size
  301. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  302. self.layers = nn.ModuleList(
  303. [OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  304. )
  305. self.norm = OlmoLayerNorm(config.hidden_size)
  306. self.rotary_emb = OlmoRotaryEmbedding(config=config)
  307. self.gradient_checkpointing = False
  308. # Initialize weights and apply final processing
  309. self.post_init()
  310. @merge_with_config_defaults
  311. @capture_outputs
  312. @auto_docstring
  313. def forward(
  314. self,
  315. input_ids: torch.LongTensor | None = None,
  316. attention_mask: torch.Tensor | None = None,
  317. position_ids: torch.LongTensor | None = None,
  318. past_key_values: Cache | None = None,
  319. inputs_embeds: torch.FloatTensor | None = None,
  320. use_cache: bool | None = None,
  321. **kwargs: Unpack[TransformersKwargs],
  322. ) -> BaseModelOutputWithPast:
  323. if (input_ids is None) ^ (inputs_embeds is not None):
  324. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  325. if inputs_embeds is None:
  326. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  327. if use_cache and past_key_values is None:
  328. past_key_values = DynamicCache(config=self.config)
  329. if position_ids is None:
  330. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  331. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  332. position_ids = position_ids.unsqueeze(0)
  333. causal_mask = create_causal_mask(
  334. config=self.config,
  335. inputs_embeds=inputs_embeds,
  336. attention_mask=attention_mask,
  337. past_key_values=past_key_values,
  338. position_ids=position_ids,
  339. )
  340. hidden_states = inputs_embeds
  341. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  342. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  343. hidden_states = decoder_layer(
  344. hidden_states,
  345. attention_mask=causal_mask,
  346. position_embeddings=position_embeddings,
  347. position_ids=position_ids,
  348. past_key_values=past_key_values,
  349. use_cache=use_cache,
  350. **kwargs,
  351. )
  352. hidden_states = self.norm(hidden_states)
  353. return BaseModelOutputWithPast(
  354. last_hidden_state=hidden_states,
  355. past_key_values=past_key_values,
  356. )
  357. @auto_docstring
  358. class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
  359. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  360. _tp_plan = {"lm_head": "colwise_gather_output"}
  361. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  362. def __init__(self, config):
  363. super().__init__(config)
  364. self.model = OlmoModel(config)
  365. self.vocab_size = config.vocab_size
  366. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  367. # Initialize weights and apply final processing
  368. self.post_init()
  369. @can_return_tuple
  370. @auto_docstring
  371. def forward(
  372. self,
  373. input_ids: torch.LongTensor | None = None,
  374. attention_mask: torch.Tensor | None = None,
  375. position_ids: torch.LongTensor | None = None,
  376. past_key_values: Cache | None = None,
  377. inputs_embeds: torch.FloatTensor | None = None,
  378. labels: torch.LongTensor | None = None,
  379. use_cache: bool | None = None,
  380. logits_to_keep: int | torch.Tensor = 0,
  381. **kwargs: Unpack[TransformersKwargs],
  382. ) -> CausalLMOutputWithPast:
  383. r"""
  384. Example:
  385. ```python
  386. >>> from transformers import AutoTokenizer, OlmoForCausalLM
  387. >>> model = OlmoForCausalLM.from_pretrained("meta-olmo/Olmo-2-7b-hf")
  388. >>> tokenizer = AutoTokenizer.from_pretrained("meta-olmo/Olmo-2-7b-hf")
  389. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  390. >>> inputs = tokenizer(prompt, return_tensors="pt")
  391. >>> # Generate
  392. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  393. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  394. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  395. ```"""
  396. outputs: BaseModelOutputWithPast = self.model(
  397. input_ids=input_ids,
  398. attention_mask=attention_mask,
  399. position_ids=position_ids,
  400. past_key_values=past_key_values,
  401. inputs_embeds=inputs_embeds,
  402. use_cache=use_cache,
  403. **kwargs,
  404. )
  405. hidden_states = outputs.last_hidden_state
  406. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  407. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  408. logits = self.lm_head(hidden_states[:, slice_indices, :])
  409. loss = None
  410. if labels is not None:
  411. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  412. return CausalLMOutputWithPast(
  413. loss=loss,
  414. logits=logits,
  415. past_key_values=outputs.past_key_values,
  416. hidden_states=outputs.hidden_states,
  417. attentions=outputs.attentions,
  418. )
  419. __all__ = ["OlmoForCausalLM", "OlmoModel", "OlmoPreTrainedModel"]