modeling_olmo3.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/olmo3/modular_olmo3.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_olmo3.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 the HuggingFace Team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from typing import Optional
  22. import torch
  23. import torch.nn as nn
  24. from ...activations import ACT2FN
  25. from ...cache_utils import Cache, DynamicCache
  26. from ...generation import GenerationMixin
  27. from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
  28. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  29. from ...modeling_layers import GradientCheckpointingLayer
  30. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  31. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...utils import auto_docstring, can_return_tuple
  35. from ...utils.generic import TransformersKwargs, maybe_autocast, merge_with_config_defaults
  36. from ...utils.output_capturing import capture_outputs
  37. from .configuration_olmo3 import Olmo3Config
  @use_kernel_forward_from_hub("RMSNorm")
  class Olmo3RMSNorm(nn.Module):
      def __init__(self, hidden_size, eps: float = 1e-6) -> None:
          """
          Root-mean-square normalisation over the last dimension with a learned per-channel scale (initialised to
          ones) and no bias or mean subtraction. Olmo3RMSNorm is equivalent to T5LayerNorm.

          Args:
              hidden_size: size of the last dimension; also the length of the learned `weight` vector.
              eps: constant added to the mean square before the reciprocal square root.
          """
          super().__init__()
          self.weight = nn.Parameter(torch.ones(hidden_size))
          self.variance_epsilon = eps

      def forward(self, hidden_states) -> torch.Tensor:
          # Statistics are computed in float32 regardless of the input dtype; the result is cast back at the end.
          orig_dtype = hidden_states.dtype
          x32 = hidden_states.to(torch.float32)
          mean_sq = x32.pow(2).mean(-1, keepdim=True)
          normed = x32 * torch.rsqrt(mean_sq + self.variance_epsilon)
          # The learned scale is applied to the float32 activations *before* the downcast.
          return (self.weight * normed).to(orig_dtype)

      def extra_repr(self):
          weight_shape = tuple(self.weight.shape)
          return f"{weight_shape}, eps={self.variance_epsilon}"
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
      """
      Repeat each key/value head `n_rep` times along dim 1, i.e. the same result as
      `torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)`.

      The input must be 4-D, `(batch, num_key_value_heads, seqlen, head_dim)` (the shape is unpacked up front, so
      other ranks raise even when `n_rep == 1`). The output is `(batch, num_key_value_heads * n_rep, seqlen, head_dim)`.
      When `n_rep == 1` the input tensor object itself is returned.
      """
      bsz, n_kv_heads, seq_len, dim = hidden_states.shape
      if n_rep == 1:
          return hidden_states
      # Insert a repeat axis right after the kv-head axis, broadcast it (no copy), then fold it into the head axis.
      tiled = hidden_states.unsqueeze(2).expand(bsz, n_kv_heads, n_rep, seq_len, dim)
      return tiled.reshape(bsz, n_kv_heads * n_rep, seq_len, dim)
  def eager_attention_forward(
      module: nn.Module,
      query: torch.Tensor,
      key: torch.Tensor,
      value: torch.Tensor,
      attention_mask: torch.Tensor | None,
      scaling: float,
      dropout: float = 0.0,
      **kwargs: Unpack[TransformersKwargs],
  ):
      """
      Reference (non-fused) scaled dot-product attention.

      Keys and values are repeated `module.num_key_value_groups` times along the head axis (via `repeat_kv`) so they
      match the number of query heads. `attention_mask`, when given, is *added* to the scaled scores, so it must be
      an additive mask (presumably 0 for visible positions and a large negative value for hidden ones — confirm
      against the mask builders). Softmax runs in float32 and is cast back to `query.dtype`; dropout is applied only
      while `module.training` is True.

      Returns:
          `(attn_output, attn_weights)`: `attn_output` has dims 1 and 2 swapped back and is made contiguous;
          `attn_weights` are the post-dropout attention probabilities.
      """
      groups = module.num_key_value_groups
      keys = repeat_kv(key, groups)
      values = repeat_kv(value, groups)

      scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
      if attention_mask is not None:
          scores = scores + attention_mask

      probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
      probs = nn.functional.dropout(probs, p=dropout, training=module.training)

      context = torch.matmul(probs, values).transpose(1, 2).contiguous()
      return context, probs
  def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
      """Applies Rotary Position Embedding to the query and key tensors.

      Each tensor `t` is mapped to `t * cos + rotate_half(t) * sin`, after `cos`/`sin` have been unsqueezed at
      `unsqueeze_dim` so they broadcast against `t`. The results are cast back to the original dtypes of `q` and `k`.

      Args:
          q (`torch.Tensor`): The query tensor.
          k (`torch.Tensor`): The key tensor.
          cos (`torch.Tensor`): The cosine part of the rotary embedding.
          sin (`torch.Tensor`): The sine part of the rotary embedding.
          unsqueeze_dim (`int`, *optional*, defaults to 1):
              Axis at which a singleton dimension is inserted into `cos` and `sin`. If `cos`/`sin` are
              `[batch_size, seq_len, head_dim]` and `q`/`k` are `[batch_size, heads, seq_len, head_dim]`, use 1;
              if `q`/`k` are `[batch_size, seq_len, heads, head_dim]`, use 2.

      Returns:
          `tuple(torch.Tensor)`: the rotated `(query, key)` pair, in the input dtypes.
      """
      cos_b = cos.unsqueeze(unsqueeze_dim)
      sin_b = sin.unsqueeze(unsqueeze_dim)

      def _rotate(t):
          # Same formula for queries and keys; only the input tensor differs.
          return (t * cos_b) + (rotate_half(t) * sin_b)

      return _rotate(q).to(q.dtype), _rotate(k).to(k.dtype)
  def rotate_half(x):
      """Rotates half the hidden dims of the input: `[a, b] -> [-b, a]` along the last axis.

      `a` is the first `x.shape[-1] // 2` entries and `b` the rest, so for odd widths `b` is the longer half.
      """
      split = x.shape[-1] // 2
      front, back = x[..., :split], x[..., split:]
      return torch.cat((back.neg(), front), dim=-1)
  @use_kernelized_func(apply_rotary_pos_emb)
  class Olmo3Attention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper

      Specifics visible in this class: RMSNorm is applied to the full query and key projections (all heads together,
      before the split into heads), rotary embeddings are applied to queries and keys, and layers whose
      `config.layer_types[layer_idx]` is `"sliding_attention"` pass `config.sliding_window` to the attention kernel
      (other layers pass `None`).
      """

      def __init__(self, config: Olmo3Config, layer_idx: int):
          super().__init__()
          self.config = config
          self.layer_idx = layer_idx
          # Resolve head_dim exactly like Olmo3RotaryEmbedding.compute_default_rope_parameters: a missing *or* falsy
          # (e.g. None) `config.head_dim` falls back to hidden_size // num_attention_heads. Previously a config with
          # `head_dim=None` crashed here (`None ** -0.5`) while the rotary tables silently used the fallback.
          self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
          # Number of query heads that share each key/value head (grouped-query attention).
          self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
          self.scaling = self.head_dim**-0.5
          self.attention_dropout = config.attention_dropout
          self.is_causal = True
          self.q_proj = nn.Linear(
              config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
          )
          self.k_proj = nn.Linear(
              config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
          )
          self.v_proj = nn.Linear(
              config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
          )
          self.o_proj = nn.Linear(
              config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
          )
          # The norms span the whole projected width, not a single head.
          self.q_norm = Olmo3RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
          self.k_norm = Olmo3RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
          self.attention_type = config.layer_types[layer_idx]
          self.sliding_window = config.sliding_window if self.attention_type == "sliding_attention" else None

      def forward(
          self,
          hidden_states: torch.Tensor,
          position_embeddings: tuple[torch.Tensor, torch.Tensor],
          attention_mask: torch.Tensor | None,
          past_key_values: Cache | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[torch.Tensor, torch.Tensor | None]:
          """
          Args:
              hidden_states: input activations. All dims but the last are treated as leading dims, and the head split
                  transposes dims 1 and 2, so presumably `(batch, seq_len, hidden_size)` — confirm against callers.
              position_embeddings: `(cos, sin)` pair passed to `apply_rotary_pos_emb`.
              attention_mask: forwarded unchanged to the selected attention implementation.
              past_key_values: optional cache. When given, this layer's new keys/values go through
                  `past_key_values.update(..., self.layer_idx)`, and the tensors it returns are attended over.
              **kwargs: forwarded to the attention implementation.

          Returns:
              `(attn_output, attn_weights)`. `attn_output` keeps the input's leading shape with `config.hidden_size`
              last; `attn_weights` is whatever the attention implementation returns (annotated as possibly `None`).
          """
          input_shape = hidden_states.shape[:-1]
          hidden_shape = (*input_shape, -1, self.head_dim)

          # Q/K are normalised over the full projected width *before* being split into heads.
          query_states = self.q_norm(self.q_proj(hidden_states))
          key_states = self.k_norm(self.k_proj(hidden_states))
          value_states = self.v_proj(hidden_states)

          # (*leading, heads * head_dim) -> (*leading, heads, head_dim), then swap dims 1 and 2.
          query_states = query_states.view(hidden_shape).transpose(1, 2)
          key_states = key_states.view(hidden_shape).transpose(1, 2)
          value_states = value_states.view(hidden_shape).transpose(1, 2)

          cos, sin = position_embeddings
          query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

          if past_key_values is not None:
              key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

          # Pick the configured attention backend, falling back to the eager reference implementation.
          attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
              self.config._attn_implementation, eager_attention_forward
          )
          attn_output, attn_weights = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              attention_mask,
              dropout=0.0 if not self.training else self.attention_dropout,
              scaling=self.scaling,
              sliding_window=self.sliding_window,
              **kwargs,
          )
          # Merge heads back into one trailing axis, then project to hidden_size.
          attn_output = attn_output.reshape(*input_shape, -1).contiguous()
          attn_output = self.o_proj(attn_output)
          return attn_output, attn_weights
  class Olmo3MLP(nn.Module):
      """Gated feed-forward block computing `down_proj(act_fn(gate_proj(x)) * up_proj(x))`; all projections are bias-free."""

      def __init__(self, config):
          super().__init__()
          self.config = config
          self.hidden_size = config.hidden_size
          self.intermediate_size = config.intermediate_size
          d_model, d_ff = self.hidden_size, self.intermediate_size
          self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
          self.up_proj = nn.Linear(d_model, d_ff, bias=False)
          self.down_proj = nn.Linear(d_ff, d_model, bias=False)
          # Activation is looked up by name from `config.hidden_act`.
          self.act_fn = ACT2FN[config.hidden_act]

      def forward(self, x):
          gate = self.act_fn(self.gate_proj(x))
          up = self.up_proj(x)
          return self.down_proj(gate * up)
  class Olmo3DecoderLayer(GradientCheckpointingLayer):
      """
      One decoder block: self-attention followed by the gated MLP. Each sub-layer's *output* is RMS-normalised
      before being added to the residual stream; there is no norm on the sub-layer inputs.
      """

      def __init__(self, config: Olmo3Config, layer_idx: int):
          super().__init__()
          self.hidden_size = config.hidden_size
          self.self_attn = Olmo3Attention(config=config, layer_idx=layer_idx)
          self.mlp = Olmo3MLP(config)
          norm_eps = config.rms_norm_eps
          self.post_attention_layernorm = Olmo3RMSNorm(config.hidden_size, eps=norm_eps)
          self.post_feedforward_layernorm = Olmo3RMSNorm(config.hidden_size, eps=norm_eps)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          use_cache: bool | None = False,
          position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> torch.Tensor:
          # Self-attention sub-layer; the attention weights are discarded.
          attn_out, _ = self.self_attn(
              hidden_states=hidden_states,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              use_cache=use_cache,
              position_embeddings=position_embeddings,
              **kwargs,
          )
          hidden_states = hidden_states + self.post_attention_layernorm(attn_out)

          # Fully connected sub-layer.
          ffn_out = self.mlp(hidden_states)
          hidden_states = hidden_states + self.post_feedforward_layernorm(ffn_out)
          return hidden_states
  class Olmo3RotaryEmbedding(nn.Module):
      """
      Builds the rotary-embedding `(cos, sin)` tables for a batch of position ids.

      Inverse frequencies come from `compute_default_rope_parameters` when
      `config.rope_parameters["rope_type"] == "default"`, otherwise from `ROPE_INIT_FUNCTIONS[rope_type]`. They are
      registered as non-persistent buffers (`inv_freq`, plus an initial copy in `original_inv_freq`), so they are
      not part of the state dict.
      """

      inv_freq: torch.Tensor  # fix linting for `register_buffer`

      def __init__(self, config: Olmo3Config, device=None):
          super().__init__()
          # `original_max_seq_len` / `original_inv_freq` keep the initial values; presumably read by
          # `dynamic_rope_update` for advanced RoPE types — confirm in modeling_rope_utils.
          self.max_seq_len_cached = config.max_position_embeddings
          self.original_max_seq_len = config.max_position_embeddings
          self.config = config
          self.rope_type = self.config.rope_parameters["rope_type"]

          if self.rope_type == "default":
              init_fn: Callable = self.compute_default_rope_parameters
          else:
              init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
          inv_freq, self.attention_scaling = init_fn(self.config, device)

          self.register_buffer("inv_freq", inv_freq, persistent=False)
          self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

      @staticmethod
      def compute_default_rope_parameters(
          config: Olmo3Config | None = None,
          device: Optional["torch.device"] = None,
          seq_len: int | None = None,
      ) -> tuple["torch.Tensor", float]:
          """
          Inverse frequencies of the original RoPE formulation: `1 / theta ** (2i / dim)` for `i = 0 .. dim/2 - 1`.

          Args:
              config ([`~transformers.PreTrainedConfig`]):
                  Model configuration; reads `rope_parameters["rope_theta"]` and `head_dim` (or, when missing/falsy,
                  `hidden_size // num_attention_heads`).
              device (`torch.device`):
                  Device on which the inverse frequencies are created.
              seq_len (`int`, *optional*):
                  Accepted for signature compatibility; unused for this RoPE type.

          Returns:
              Tuple of (`torch.Tensor`, `float`): the inverse frequencies and the cos/sin post-scaling factor, which
              is always `1.0` for this RoPE type.
          """
          theta = config.rope_parameters["rope_theta"]
          rot_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
          # Even indices 0, 2, 4, ... built as int64, then moved to `device` as float and scaled into [0, 1).
          exponents = torch.arange(0, rot_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / rot_dim
          inv_freq = 1.0 / (theta**exponents)
          unused_scaling = 1.0
          return inv_freq, unused_scaling

      @torch.no_grad()
      @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
      def forward(self, x, position_ids):
          # In this body `x` only supplies the target device and output dtype; its values are never read.
          batch = position_ids.shape[0]
          freq_col = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)  # (batch, dim/2, 1)
          pos_row = position_ids[:, None, :].float()  # (batch, 1, seq)
          dev_type = x.device.type
          autocast_device = dev_type if isinstance(dev_type, str) and dev_type != "mps" else "cpu"
          # Autocast is disabled so the angle computation stays in float32.
          with maybe_autocast(device_type=autocast_device, enabled=False):
              angles = (freq_col.float() @ pos_row.float()).transpose(1, 2)  # (batch, seq, dim/2)
              emb = torch.cat((angles, angles), dim=-1)
              cos = emb.cos() * self.attention_scaling
              sin = emb.sin() * self.attention_scaling
          return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  @auto_docstring
  class Olmo3PreTrainedModel(PreTrainedModel):
      config: Olmo3Config
      base_model_prefix = "model"

      # Training / device-placement hints; presumably consumed by the PreTrainedModel machinery (not visible here).
      supports_gradient_checkpointing = True
      _no_split_modules = ["Olmo3DecoderLayer"]
      _skip_keys_device_placement = ["past_key_values"]

      # Attention backends and compilation support advertised by this model family.
      _supports_flash_attn = True
      _supports_sdpa = True
      _supports_flex_attn = True
      _supports_attention_backend = True
      _can_compile_fullgraph = True

      # Module classes whose outputs can be recorded as `hidden_states` / `attentions`; presumably used by
      # `capture_outputs` on Olmo3Model.forward — confirm in utils.output_capturing.
      _can_record_outputs = {
          "hidden_states": Olmo3DecoderLayer,
          "attentions": Olmo3Attention,
      }
  @auto_docstring
  class Olmo3Model(Olmo3PreTrainedModel):
      # Bare decoder stack: token embedding -> Olmo3DecoderLayer x num_hidden_layers -> final RMSNorm.
      # (Comments rather than docstrings are used here so `@auto_docstring` output is unaffected.)

      def __init__(self, config: Olmo3Config):
          super().__init__(config)
          self.padding_idx = config.pad_token_id
          self.vocab_size = config.vocab_size
          self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
          self.layers = nn.ModuleList(
              Olmo3DecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
          )
          self.norm = Olmo3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          # One rotary module shared by every layer.
          self.rotary_emb = Olmo3RotaryEmbedding(config=config)
          self.gradient_checkpointing = False
          # Initialize weights and apply final processing
          self.post_init()

      @merge_with_config_defaults
      @capture_outputs
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          use_cache: bool | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> BaseModelOutputWithPast:
          # Exactly one of `input_ids` / `inputs_embeds` must be supplied.
          if (input_ids is None) == (inputs_embeds is None):
              raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
          if inputs_embeds is None:
              inputs_embeds = self.embed_tokens(input_ids)

          # Create a cache on demand when caching is requested but none was passed in.
          if use_cache and past_key_values is None:
              past_key_values = DynamicCache(config=self.config)

          # Default positions continue after the tokens the cache has already seen.
          if position_ids is None:
              offset = 0 if past_key_values is None else past_key_values.get_seq_length()
              position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + offset
              position_ids = position_ids.unsqueeze(0)

          # `attention_mask` may already be a {layer_type: mask} mapping (e.g. prepared by `generate`);
          # otherwise build one mask per attention type from the raw inputs.
          causal_mask_mapping = attention_mask
          if not isinstance(causal_mask_mapping, dict):
              mask_kwargs = {
                  "config": self.config,
                  "inputs_embeds": inputs_embeds,
                  "attention_mask": attention_mask,
                  "past_key_values": past_key_values,
                  "position_ids": position_ids,
              }
              causal_mask_mapping = {
                  "full_attention": create_causal_mask(**mask_kwargs),
                  "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
              }

          hidden_states = inputs_embeds
          # cos/sin are computed once and reused by every layer.
          position_embeddings = self.rotary_emb(hidden_states, position_ids)
          layer_types = self.config.layer_types
          for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
              hidden_states = decoder_layer(
                  hidden_states,
                  attention_mask=causal_mask_mapping[layer_types[layer_idx]],
                  position_ids=position_ids,
                  past_key_values=past_key_values,
                  position_embeddings=position_embeddings,
                  **kwargs,
              )

          return BaseModelOutputWithPast(
              last_hidden_state=self.norm(hidden_states),
              past_key_values=past_key_values,
          )
  @auto_docstring
  class Olmo3ForCausalLM(Olmo3PreTrainedModel, GenerationMixin):
      _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
      _tp_plan = {"lm_head": "colwise_gather_output"}
      _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

      def __init__(self, config):
          super().__init__(config)
          self.model = Olmo3Model(config)
          self.vocab_size = config.vocab_size
          # Bias-free vocabulary projection; its weight is declared tied to the input embedding above.
          self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
          # Initialize weights and apply final processing
          self.post_init()

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          use_cache: bool | None = None,
          logits_to_keep: int | torch.Tensor = 0,
          **kwargs: Unpack[TransformersKwargs],
      ) -> CausalLMOutputWithPast:
          r"""
          Example:

          ```python
          >>> from transformers import AutoTokenizer, Olmo3ForCausalLM

          >>> model = Olmo3ForCausalLM.from_pretrained("meta-olmo3/Olmo3-2-7b-hf")
          >>> tokenizer = AutoTokenizer.from_pretrained("meta-olmo3/Olmo3-2-7b-hf")

          >>> prompt = "Hey, are you conscious? Can you talk to me?"
          >>> inputs = tokenizer(prompt, return_tensors="pt")

          >>> # Generate
          >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
          >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
          "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
          ```"""
          # NOTE(review): "meta-olmo3/Olmo3-2-7b-hf" in the example above looks like a mechanical rename of a Llama
          # example id rather than a published OLMo-3 checkpoint — confirm and fix in modular_olmo3.py.
          outputs: BaseModelOutputWithPast = self.model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              **kwargs,
          )
          hidden_states = outputs.last_hidden_state

          # Project only the requested positions. An int `n` keeps the last `n` positions (0 keeps all, since
          # slice(-0, None) == slice(0, None)); any other value is used directly as the sequence-axis index.
          if isinstance(logits_to_keep, int):
              keep = slice(-logits_to_keep, None)
          else:
              keep = logits_to_keep
          logits = self.lm_head(hidden_states[:, keep, :])

          loss = (
              None
              if labels is None
              else self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
          )

          return CausalLMOutputWithPast(
              loss=loss,
              logits=logits,
              past_key_values=outputs.past_key_values,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  # Public API of this module.
  __all__ = [
      "Olmo3ForCausalLM",
      "Olmo3Model",
      "Olmo3PreTrainedModel",
  ]