modeling_helium.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/helium/modular_helium.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_helium.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 The Kyutai and HuggingFace Inc. teams. All rights reserved.
  8. #
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import math
  22. from collections.abc import Callable
  23. from typing import Optional
  24. import torch
  25. import torch.nn as nn
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache, DynamicCache
  28. from ...generation import GenerationMixin
  29. from ...integrations import use_kernelized_func
  30. from ...masking_utils import create_causal_mask
  31. from ...modeling_layers import (
  32. GenericForSequenceClassification,
  33. GenericForTokenClassification,
  34. GradientCheckpointingLayer,
  35. )
  36. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  37. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  41. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  42. from ...utils.output_capturing import capture_outputs
  43. from .configuration_helium import HeliumConfig
  44. class HeliumRMSNorm(nn.Module):
  45. def __init__(self, hidden_size, eps=1e-6):
  46. super().__init__()
  47. self.weight = nn.Parameter(torch.ones(hidden_size))
  48. self.variance_epsilon = eps
  49. def forward(self, hidden_states):
  50. input_dtype = hidden_states.dtype
  51. hidden_states = hidden_states.to(torch.float32)
  52. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  53. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  54. return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)
  55. def extra_repr(self):
  56. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  57. class HeliumRotaryEmbedding(nn.Module):
  58. inv_freq: torch.Tensor # fix linting for `register_buffer`
  59. def __init__(self, config: HeliumConfig, device=None):
  60. super().__init__()
  61. self.max_seq_len_cached = config.max_position_embeddings
  62. self.original_max_seq_len = config.max_position_embeddings
  63. self.config = config
  64. self.rope_type = self.config.rope_parameters["rope_type"]
  65. rope_init_fn: Callable = self.compute_default_rope_parameters
  66. if self.rope_type != "default":
  67. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  68. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  69. self.register_buffer("inv_freq", inv_freq, persistent=False)
  70. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  71. @staticmethod
  72. def compute_default_rope_parameters(
  73. config: HeliumConfig | None = None,
  74. device: Optional["torch.device"] = None,
  75. seq_len: int | None = None,
  76. ) -> tuple["torch.Tensor", float]:
  77. """
  78. Computes the inverse frequencies according to the original RoPE implementation
  79. Args:
  80. config ([`~transformers.PreTrainedConfig`]):
  81. The model configuration.
  82. device (`torch.device`):
  83. The device to use for initialization of the inverse frequencies.
  84. seq_len (`int`, *optional*):
  85. The current sequence length. Unused for this type of RoPE.
  86. Returns:
  87. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  88. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  89. """
  90. base = config.rope_parameters["rope_theta"]
  91. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  92. attention_factor = 1.0 # Unused in this type of RoPE
  93. # Compute the inverse frequencies
  94. inv_freq = 1.0 / (
  95. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  96. )
  97. return inv_freq, attention_factor
  98. @torch.no_grad()
  99. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  100. def forward(self, x, position_ids):
  101. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  102. position_ids_expanded = position_ids[:, None, :].float()
  103. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  104. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  105. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  106. emb = torch.cat((freqs, freqs), dim=-1)
  107. cos = emb.cos() * self.attention_scaling
  108. sin = emb.sin() * self.attention_scaling
  109. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  110. class HeliumMLP(nn.Module):
  111. def __init__(self, config):
  112. super().__init__()
  113. self.config = config
  114. self.hidden_size = config.hidden_size
  115. self.intermediate_size = config.intermediate_size
  116. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  117. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  118. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  119. self.act_fn = ACT2FN[config.hidden_act]
  120. def forward(self, x):
  121. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  122. return down_proj
  123. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  124. """
  125. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  126. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  127. """
  128. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  129. if n_rep == 1:
  130. return hidden_states
  131. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  132. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  133. def eager_attention_forward(
  134. module: nn.Module,
  135. query: torch.Tensor,
  136. key: torch.Tensor,
  137. value: torch.Tensor,
  138. attention_mask: torch.Tensor | None,
  139. scaling: float,
  140. dropout: float = 0.0,
  141. **kwargs: Unpack[TransformersKwargs],
  142. ):
  143. key_states = repeat_kv(key, module.num_key_value_groups)
  144. value_states = repeat_kv(value, module.num_key_value_groups)
  145. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  146. if attention_mask is not None:
  147. attn_weights = attn_weights + attention_mask
  148. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  149. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  150. attn_output = torch.matmul(attn_weights, value_states)
  151. attn_output = attn_output.transpose(1, 2).contiguous()
  152. return attn_output, attn_weights
  153. def rotate_half(x):
  154. """Rotates half the hidden dims of the input."""
  155. x1 = x[..., 0::2]
  156. x2 = x[..., 1::2]
  157. return torch.stack((-x2, x1), dim=-1).flatten(-2)
  158. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  159. """Applies Rotary Position Embedding to the query and key tensors.
  160. Args:
  161. q (`torch.Tensor`): The query tensor.
  162. k (`torch.Tensor`): The key tensor.
  163. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  164. sin (`torch.Tensor`): The sine part of the rotary embedding.
  165. unsqueeze_dim (`int`, *optional*, defaults to 1):
  166. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  167. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  168. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  169. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  170. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  171. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  172. Returns:
  173. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  174. """
  175. cos = cos.unsqueeze(unsqueeze_dim)
  176. sin = sin.unsqueeze(unsqueeze_dim)
  177. # Interleave them instead of usual shape
  178. cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
  179. sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)
  180. q_embed = (q * cos) + (rotate_half(q) * sin)
  181. k_embed = (k * cos) + (rotate_half(k) * sin)
  182. return q_embed, k_embed
  183. @use_kernelized_func(apply_rotary_pos_emb)
  184. class HeliumAttention(nn.Module):
  185. """Multi-headed attention from 'Attention Is All You Need' paper"""
  186. def __init__(self, config: HeliumConfig, layer_idx: int | None = None):
  187. super().__init__()
  188. self.config = config
  189. self.layer_idx = layer_idx
  190. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  191. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  192. self.scaling = 1 / math.sqrt(self.head_dim)
  193. self.attention_dropout = config.attention_dropout
  194. self.is_causal = True
  195. self.q_proj = nn.Linear(
  196. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  197. )
  198. self.k_proj = nn.Linear(
  199. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  200. )
  201. self.v_proj = nn.Linear(
  202. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  203. )
  204. self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
  205. def forward(
  206. self,
  207. hidden_states: torch.Tensor,
  208. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  209. attention_mask: torch.Tensor | None = None,
  210. past_key_values: Cache | None = None,
  211. **kwargs: Unpack[TransformersKwargs],
  212. ) -> tuple[torch.Tensor, torch.Tensor]:
  213. input_shape = hidden_states.shape[:-1]
  214. hidden_shape = (*input_shape, -1, self.head_dim)
  215. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  216. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  217. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  218. cos, sin = position_embeddings
  219. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  220. if past_key_values is not None:
  221. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  222. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  223. self.config._attn_implementation, eager_attention_forward
  224. )
  225. attn_output, attn_weights = attention_interface(
  226. self,
  227. query_states,
  228. key_states,
  229. value_states,
  230. attention_mask,
  231. dropout=0.0 if not self.training else self.attention_dropout,
  232. scaling=self.scaling,
  233. **kwargs,
  234. )
  235. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  236. attn_output = self.o_proj(attn_output)
  237. return attn_output, attn_weights
  238. class HeliumDecoderLayer(GradientCheckpointingLayer):
  239. def __init__(self, config: HeliumConfig, layer_idx: int | None = None):
  240. super().__init__()
  241. self.hidden_size = config.hidden_size
  242. self.self_attn = HeliumAttention(config=config, layer_idx=layer_idx)
  243. self.mlp = HeliumMLP(config)
  244. self.input_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  245. self.post_attention_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  246. def forward(
  247. self,
  248. hidden_states: torch.Tensor,
  249. attention_mask: torch.Tensor | None = None,
  250. position_ids: torch.LongTensor | None = None,
  251. past_key_values: Cache | None = None,
  252. use_cache: bool | None = False,
  253. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  254. **kwargs: Unpack[TransformersKwargs],
  255. ) -> torch.Tensor:
  256. residual = hidden_states
  257. hidden_states = self.input_layernorm(hidden_states)
  258. # Self Attention
  259. hidden_states, _ = self.self_attn(
  260. hidden_states=hidden_states,
  261. attention_mask=attention_mask,
  262. position_ids=position_ids,
  263. past_key_values=past_key_values,
  264. use_cache=use_cache,
  265. position_embeddings=position_embeddings,
  266. **kwargs,
  267. )
  268. hidden_states = residual + hidden_states
  269. # Fully Connected
  270. residual = hidden_states
  271. hidden_states = self.post_attention_layernorm(hidden_states)
  272. hidden_states = self.mlp(hidden_states)
  273. hidden_states = residual + hidden_states
  274. return hidden_states
  275. @auto_docstring
  276. class HeliumPreTrainedModel(PreTrainedModel):
  277. config: HeliumConfig
  278. base_model_prefix = "model"
  279. supports_gradient_checkpointing = True
  280. _no_split_modules = ["HeliumDecoderLayer"]
  281. _skip_keys_device_placement = ["past_key_values"]
  282. _supports_flash_attn = True
  283. _supports_sdpa = True
  284. _supports_flex_attn = True
  285. _can_compile_fullgraph = True
  286. _supports_attention_backend = True
  287. _can_record_outputs = {
  288. "hidden_states": HeliumDecoderLayer,
  289. "attentions": HeliumAttention,
  290. }
  291. @auto_docstring
  292. class HeliumModel(HeliumPreTrainedModel):
  293. def __init__(self, config: HeliumConfig):
  294. super().__init__(config)
  295. self.padding_idx = config.pad_token_id
  296. self.vocab_size = config.vocab_size
  297. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  298. self.layers = nn.ModuleList(
  299. [HeliumDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  300. )
  301. self.norm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  302. self.rotary_emb = HeliumRotaryEmbedding(config=config)
  303. self.gradient_checkpointing = False
  304. # Initialize weights and apply final processing
  305. self.post_init()
  306. @merge_with_config_defaults
  307. @capture_outputs
  308. @auto_docstring
  309. def forward(
  310. self,
  311. input_ids: torch.LongTensor | None = None,
  312. attention_mask: torch.Tensor | None = None,
  313. position_ids: torch.LongTensor | None = None,
  314. past_key_values: Cache | None = None,
  315. inputs_embeds: torch.FloatTensor | None = None,
  316. use_cache: bool | None = None,
  317. **kwargs: Unpack[TransformersKwargs],
  318. ) -> BaseModelOutputWithPast:
  319. if (input_ids is None) ^ (inputs_embeds is not None):
  320. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  321. if inputs_embeds is None:
  322. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  323. if use_cache and past_key_values is None:
  324. past_key_values = DynamicCache(config=self.config)
  325. if position_ids is None:
  326. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  327. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  328. position_ids = position_ids.unsqueeze(0)
  329. causal_mask = create_causal_mask(
  330. config=self.config,
  331. inputs_embeds=inputs_embeds,
  332. attention_mask=attention_mask,
  333. past_key_values=past_key_values,
  334. position_ids=position_ids,
  335. )
  336. hidden_states = inputs_embeds
  337. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  338. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  339. hidden_states = decoder_layer(
  340. hidden_states,
  341. attention_mask=causal_mask,
  342. position_embeddings=position_embeddings,
  343. position_ids=position_ids,
  344. past_key_values=past_key_values,
  345. use_cache=use_cache,
  346. **kwargs,
  347. )
  348. hidden_states = self.norm(hidden_states)
  349. return BaseModelOutputWithPast(
  350. last_hidden_state=hidden_states,
  351. past_key_values=past_key_values,
  352. )
  353. @auto_docstring
  354. class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin):
  355. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  356. _tp_plan = {"lm_head": "colwise_gather_output"}
  357. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  358. def __init__(self, config):
  359. super().__init__(config)
  360. self.model = HeliumModel(config)
  361. self.vocab_size = config.vocab_size
  362. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  363. # Initialize weights and apply final processing
  364. self.post_init()
  365. @can_return_tuple
  366. @auto_docstring
  367. def forward(
  368. self,
  369. input_ids: torch.LongTensor | None = None,
  370. attention_mask: torch.Tensor | None = None,
  371. position_ids: torch.LongTensor | None = None,
  372. past_key_values: Cache | None = None,
  373. inputs_embeds: torch.FloatTensor | None = None,
  374. labels: torch.LongTensor | None = None,
  375. use_cache: bool | None = None,
  376. logits_to_keep: int | torch.Tensor = 0,
  377. **kwargs: Unpack[TransformersKwargs],
  378. ) -> CausalLMOutputWithPast:
  379. r"""
  380. Example:
  381. ```python
  382. >>> from transformers import AutoTokenizer, HeliumForCausalLM
  383. >>> model = HeliumForCausalLM.from_pretrained("google/helium-7b")
  384. >>> tokenizer = AutoTokenizer.from_pretrained("google/helium-7b")
  385. >>> prompt = "What is your favorite condiment?"
  386. >>> inputs = tokenizer(prompt, return_tensors="pt")
  387. >>> # Generate
  388. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  389. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  390. "What is your favorite condiment?"
  391. ```"""
  392. outputs: BaseModelOutputWithPast = self.model(
  393. input_ids=input_ids,
  394. attention_mask=attention_mask,
  395. position_ids=position_ids,
  396. past_key_values=past_key_values,
  397. inputs_embeds=inputs_embeds,
  398. use_cache=use_cache,
  399. **kwargs,
  400. )
  401. hidden_states = outputs.last_hidden_state
  402. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  403. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  404. logits = self.lm_head(hidden_states[:, slice_indices, :])
  405. loss = None
  406. if labels is not None:
  407. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  408. return CausalLMOutputWithPast(
  409. loss=loss,
  410. logits=logits,
  411. past_key_values=outputs.past_key_values,
  412. hidden_states=outputs.hidden_states,
  413. attentions=outputs.attentions,
  414. )
  415. class HeliumForSequenceClassification(GenericForSequenceClassification, HeliumPreTrainedModel):
  416. pass
  417. class HeliumForTokenClassification(GenericForTokenClassification, HeliumPreTrainedModel):
  418. pass
  419. __all__ = [
  420. "HeliumPreTrainedModel",
  421. "HeliumModel",
  422. "HeliumForCausalLM",
  423. "HeliumForSequenceClassification",
  424. "HeliumForTokenClassification",
  425. ]