modeling_apertus.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/apertus/modular_apertus.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_apertus.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 the HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved.
  8. #
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from collections.abc import Callable
  22. from typing import Optional
  23. import torch
  24. from torch import nn
  25. from ...activations import ACT2CLS, ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
  29. from ...masking_utils import create_causal_mask
  30. from ...modeling_layers import GenericForTokenClassification, GradientCheckpointingLayer
  31. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  32. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  33. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  34. from ...processing_utils import Unpack
  35. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  36. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  37. from ...utils.output_capturing import capture_outputs
  38. from .configuration_apertus import ApertusConfig
  39. class ApertusMLP(nn.Module):
  40. def __init__(self, config):
  41. super().__init__()
  42. self.config = config
  43. self.hidden_size = config.hidden_size
  44. self.intermediate_size = config.intermediate_size
  45. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  46. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  47. self.act_fn = ACT2FN[config.hidden_act]
  48. if config.hidden_act == "xielu":
  49. self.act_fn = ACT2CLS["xielu"](dtype=config.dtype)
  50. def forward(self, x):
  51. return self.down_proj(self.act_fn(self.up_proj(x)))
  52. @use_kernel_forward_from_hub("RMSNorm")
  53. class ApertusRMSNorm(nn.Module):
  54. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  55. """
  56. ApertusRMSNorm is equivalent to T5LayerNorm
  57. """
  58. super().__init__()
  59. self.weight = nn.Parameter(torch.ones(hidden_size))
  60. self.variance_epsilon = eps
  61. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  62. input_dtype = hidden_states.dtype
  63. hidden_states = hidden_states.to(torch.float32)
  64. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  65. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  66. return self.weight * hidden_states.to(input_dtype)
  67. def extra_repr(self):
  68. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  69. class ApertusRotaryEmbedding(nn.Module):
  70. inv_freq: torch.Tensor # fix linting for `register_buffer`
  71. def __init__(self, config: ApertusConfig, device=None):
  72. super().__init__()
  73. self.max_seq_len_cached = config.max_position_embeddings
  74. self.original_max_seq_len = config.max_position_embeddings
  75. self.config = config
  76. self.rope_type = self.config.rope_parameters["rope_type"]
  77. rope_init_fn: Callable = self.compute_default_rope_parameters
  78. if self.rope_type != "default":
  79. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  80. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  81. self.register_buffer("inv_freq", inv_freq, persistent=False)
  82. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  83. @staticmethod
  84. def compute_default_rope_parameters(
  85. config: ApertusConfig | None = None,
  86. device: Optional["torch.device"] = None,
  87. seq_len: int | None = None,
  88. ) -> tuple["torch.Tensor", float]:
  89. """
  90. Computes the inverse frequencies according to the original RoPE implementation
  91. Args:
  92. config ([`~transformers.PreTrainedConfig`]):
  93. The model configuration.
  94. device (`torch.device`):
  95. The device to use for initialization of the inverse frequencies.
  96. seq_len (`int`, *optional*):
  97. The current sequence length. Unused for this type of RoPE.
  98. Returns:
  99. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  100. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  101. """
  102. base = config.rope_parameters["rope_theta"]
  103. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  104. attention_factor = 1.0 # Unused in this type of RoPE
  105. # Compute the inverse frequencies
  106. inv_freq = 1.0 / (
  107. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  108. )
  109. return inv_freq, attention_factor
  110. @torch.no_grad()
  111. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  112. def forward(self, x, position_ids):
  113. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  114. position_ids_expanded = position_ids[:, None, :].float()
  115. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  116. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  117. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  118. emb = torch.cat((freqs, freqs), dim=-1)
  119. cos = emb.cos() * self.attention_scaling
  120. sin = emb.sin() * self.attention_scaling
  121. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  122. def rotate_half(x):
  123. """Rotates half the hidden dims of the input."""
  124. x1 = x[..., : x.shape[-1] // 2]
  125. x2 = x[..., x.shape[-1] // 2 :]
  126. return torch.cat((-x2, x1), dim=-1)
  127. @use_kernel_func_from_hub("rotary_pos_emb")
  128. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  129. """Applies Rotary Position Embedding to the query and key tensors.
  130. Args:
  131. q (`torch.Tensor`): The query tensor.
  132. k (`torch.Tensor`): The key tensor.
  133. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  134. sin (`torch.Tensor`): The sine part of the rotary embedding.
  135. unsqueeze_dim (`int`, *optional*, defaults to 1):
  136. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  137. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  138. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  139. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  140. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  141. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  142. Returns:
  143. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  144. """
  145. cos = cos.unsqueeze(unsqueeze_dim)
  146. sin = sin.unsqueeze(unsqueeze_dim)
  147. q_embed = (q * cos) + (rotate_half(q) * sin)
  148. k_embed = (k * cos) + (rotate_half(k) * sin)
  149. return q_embed, k_embed
  150. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  151. """
  152. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  153. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  154. """
  155. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  156. if n_rep == 1:
  157. return hidden_states
  158. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  159. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  160. def eager_attention_forward(
  161. module: nn.Module,
  162. query: torch.Tensor,
  163. key: torch.Tensor,
  164. value: torch.Tensor,
  165. attention_mask: torch.Tensor | None,
  166. scaling: float,
  167. dropout: float = 0.0,
  168. **kwargs: Unpack[TransformersKwargs],
  169. ):
  170. key_states = repeat_kv(key, module.num_key_value_groups)
  171. value_states = repeat_kv(value, module.num_key_value_groups)
  172. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  173. if attention_mask is not None:
  174. attn_weights = attn_weights + attention_mask
  175. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  176. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  177. attn_output = torch.matmul(attn_weights, value_states)
  178. attn_output = attn_output.transpose(1, 2).contiguous()
  179. return attn_output, attn_weights
  180. @use_kernelized_func(apply_rotary_pos_emb)
  181. class ApertusAttention(nn.Module):
  182. """Multi-headed attention from 'Attention Is All You Need' paper"""
  183. def __init__(self, config: ApertusConfig, layer_idx: int | None = None):
  184. super().__init__()
  185. self.config = config
  186. self.layer_idx = layer_idx
  187. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  188. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  189. self.scaling = self.head_dim**-0.5
  190. self.attention_dropout = config.attention_dropout
  191. self.is_causal = True
  192. self.q_proj = nn.Linear(
  193. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  194. )
  195. self.k_proj = nn.Linear(
  196. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  197. )
  198. self.v_proj = nn.Linear(
  199. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  200. )
  201. self.o_proj = nn.Linear(
  202. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  203. )
  204. self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)
  205. self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)
  206. def forward(
  207. self,
  208. hidden_states: torch.Tensor,
  209. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  210. attention_mask: torch.Tensor | None,
  211. past_key_values: Cache | None = None,
  212. **kwargs: Unpack[TransformersKwargs],
  213. ) -> tuple[torch.Tensor, torch.Tensor]:
  214. input_shape = hidden_states.shape[:-1]
  215. hidden_shape = (*input_shape, -1, self.head_dim)
  216. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  217. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  218. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  219. query_states = self.q_norm(query_states)
  220. key_states = self.k_norm(key_states)
  221. cos, sin = position_embeddings
  222. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  223. if past_key_values is not None:
  224. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  225. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  226. self.config._attn_implementation, eager_attention_forward
  227. )
  228. attn_output, attn_weights = attention_interface(
  229. self,
  230. query_states,
  231. key_states,
  232. value_states,
  233. attention_mask,
  234. dropout=0.0 if not self.training else self.attention_dropout,
  235. scaling=self.scaling,
  236. **kwargs,
  237. )
  238. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  239. attn_output = self.o_proj(attn_output)
  240. return attn_output, attn_weights
  241. class ApertusDecoderLayer(GradientCheckpointingLayer):
  242. def __init__(self, config: ApertusConfig, layer_idx: int):
  243. super().__init__()
  244. self.hidden_size = config.hidden_size
  245. self.self_attn = ApertusAttention(config=config, layer_idx=layer_idx)
  246. self.mlp = ApertusMLP(config)
  247. self.attention_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  248. self.feedforward_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  249. def forward(
  250. self,
  251. hidden_states: torch.Tensor,
  252. attention_mask: torch.Tensor | None = None,
  253. position_ids: torch.LongTensor | None = None,
  254. past_key_values: Cache | None = None,
  255. use_cache: bool | None = False,
  256. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  257. **kwargs: Unpack[TransformersKwargs],
  258. ) -> torch.Tensor:
  259. residual = hidden_states
  260. hidden_states = self.attention_layernorm(hidden_states)
  261. hidden_states, _ = self.self_attn(
  262. hidden_states=hidden_states,
  263. attention_mask=attention_mask,
  264. position_ids=position_ids,
  265. past_key_values=past_key_values,
  266. use_cache=use_cache,
  267. position_embeddings=position_embeddings,
  268. **kwargs,
  269. )
  270. hidden_states = residual + hidden_states
  271. # Fully Connected
  272. residual = hidden_states
  273. hidden_states = self.feedforward_layernorm(hidden_states)
  274. hidden_states = self.mlp(hidden_states)
  275. hidden_states = residual + hidden_states
  276. return hidden_states
  277. @auto_docstring
  278. class ApertusPreTrainedModel(PreTrainedModel):
  279. config: ApertusConfig
  280. base_model_prefix = "model"
  281. supports_gradient_checkpointing = True
  282. _no_split_modules = ["ApertusDecoderLayer"]
  283. _skip_keys_device_placement = ["past_key_values"]
  284. _supports_flash_attn = True
  285. _supports_sdpa = True
  286. _supports_flex_attn = True
  287. _can_compile_fullgraph = True
  288. _supports_attention_backend = True
  289. _can_record_outputs = {
  290. "hidden_states": ApertusDecoderLayer,
  291. "attentions": ApertusAttention,
  292. }
  293. @auto_docstring
  294. class ApertusModel(ApertusPreTrainedModel):
  295. def __init__(self, config: ApertusConfig):
  296. super().__init__(config)
  297. self.padding_idx = config.pad_token_id
  298. self.vocab_size = config.vocab_size
  299. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  300. self.layers = nn.ModuleList(
  301. [ApertusDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  302. )
  303. self.norm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  304. self.rotary_emb = ApertusRotaryEmbedding(config=config)
  305. self.gradient_checkpointing = False
  306. # Initialize weights and apply final processing
  307. self.post_init()
  308. @merge_with_config_defaults
  309. @capture_outputs
  310. @auto_docstring
  311. def forward(
  312. self,
  313. input_ids: torch.LongTensor | None = None,
  314. attention_mask: torch.Tensor | None = None,
  315. position_ids: torch.LongTensor | None = None,
  316. past_key_values: Cache | None = None,
  317. inputs_embeds: torch.FloatTensor | None = None,
  318. use_cache: bool | None = None,
  319. **kwargs: Unpack[TransformersKwargs],
  320. ) -> BaseModelOutputWithPast:
  321. if (input_ids is None) ^ (inputs_embeds is not None):
  322. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  323. if inputs_embeds is None:
  324. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  325. if use_cache and past_key_values is None:
  326. past_key_values = DynamicCache(config=self.config)
  327. if position_ids is None:
  328. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  329. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  330. position_ids = position_ids.unsqueeze(0)
  331. causal_mask = create_causal_mask(
  332. config=self.config,
  333. inputs_embeds=inputs_embeds,
  334. attention_mask=attention_mask,
  335. past_key_values=past_key_values,
  336. position_ids=position_ids,
  337. )
  338. hidden_states = inputs_embeds
  339. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  340. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  341. hidden_states = decoder_layer(
  342. hidden_states,
  343. attention_mask=causal_mask,
  344. position_embeddings=position_embeddings,
  345. position_ids=position_ids,
  346. past_key_values=past_key_values,
  347. use_cache=use_cache,
  348. **kwargs,
  349. )
  350. hidden_states = self.norm(hidden_states)
  351. return BaseModelOutputWithPast(
  352. last_hidden_state=hidden_states,
  353. past_key_values=past_key_values,
  354. )
  355. @auto_docstring
  356. class ApertusForCausalLM(ApertusPreTrainedModel, GenerationMixin):
  357. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  358. _tp_plan = {"lm_head": "colwise_gather_output"}
  359. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  360. def __init__(self, config):
  361. super().__init__(config)
  362. self.model = ApertusModel(config)
  363. self.vocab_size = config.vocab_size
  364. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  365. # Initialize weights and apply final processing
  366. self.post_init()
  367. @can_return_tuple
  368. @auto_docstring
  369. def forward(
  370. self,
  371. input_ids: torch.LongTensor | None = None,
  372. attention_mask: torch.Tensor | None = None,
  373. position_ids: torch.LongTensor | None = None,
  374. past_key_values: Cache | None = None,
  375. inputs_embeds: torch.FloatTensor | None = None,
  376. labels: torch.LongTensor | None = None,
  377. use_cache: bool | None = None,
  378. logits_to_keep: int | torch.Tensor = 0,
  379. **kwargs: Unpack[TransformersKwargs],
  380. ) -> CausalLMOutputWithPast:
  381. r"""
  382. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  383. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  384. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  385. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  386. Example:
  387. ```python
  388. >>> from transformers import AutoTokenizer, ApertusForCausalLM
  389. >>> model = ApertusForCausalLM.from_pretrained("swiss-ai/Apertus-8B-Instruct-2509")
  390. >>> tokenizer = AutoTokenizer.from_pretrained("swiss-ai/Apertus-8B-Instruct-2509")
  391. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  392. >>> inputs = tokenizer(prompt, return_tensors="pt")
  393. >>> # Generate
  394. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  395. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  396. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  397. ```"""
  398. outputs: BaseModelOutputWithPast = self.model(
  399. input_ids=input_ids,
  400. attention_mask=attention_mask,
  401. position_ids=position_ids,
  402. past_key_values=past_key_values,
  403. inputs_embeds=inputs_embeds,
  404. use_cache=use_cache,
  405. **kwargs,
  406. )
  407. hidden_states = outputs.last_hidden_state
  408. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  409. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  410. logits = self.lm_head(hidden_states[:, slice_indices, :])
  411. loss = None
  412. if labels is not None:
  413. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  414. return CausalLMOutputWithPast(
  415. loss=loss,
  416. logits=logits,
  417. past_key_values=outputs.past_key_values,
  418. hidden_states=outputs.hidden_states,
  419. attentions=outputs.attentions,
  420. )
  421. class ApertusForTokenClassification(GenericForTokenClassification, ApertusPreTrainedModel):
  422. pass
  423. __all__ = ["ApertusModel", "ApertusForCausalLM", "ApertusForTokenClassification", "ApertusPreTrainedModel"]