modeling_diffllama.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/diffllama/modular_diffllama.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_diffllama.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # This code is based on Llama implementations in this library and Microsoft's
  10. # Differential Transformer implementations.
  11. # Licensed under the Apache License, Version 2.0 (the "License");
  12. # you may not use this file except in compliance with the License.
  13. # You may obtain a copy of the License at
  14. #
  15. # http://www.apache.org/licenses/LICENSE-2.0
  16. #
  17. # Unless required by applicable law or agreed to in writing, software
  18. # distributed under the License is distributed on an "AS IS" BASIS,
  19. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  20. # See the License for the specific language governing permissions and
  21. # limitations under the License.
  22. import math
  23. from collections.abc import Callable
  24. from typing import Optional
  25. import torch
  26. from torch import nn
  27. from ... import initialization as init
  28. from ...activations import ACT2FN
  29. from ...cache_utils import Cache, DynamicCache, StaticCache
  30. from ...generation import GenerationMixin
  31. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
  32. from ...masking_utils import create_causal_mask
  33. from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
  34. from ...modeling_layers import (
  35. GenericForQuestionAnswering,
  36. GenericForSequenceClassification,
  37. GenericForTokenClassification,
  38. GradientCheckpointingLayer,
  39. )
  40. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  41. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  42. from ...modeling_utils import PreTrainedModel
  43. from ...processing_utils import Unpack
  44. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  45. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  46. from ...utils.output_capturing import capture_outputs
  47. from .configuration_diffllama import DiffLlamaConfig
# Module-level logger obtained from the library's logging helper, keyed by this module's name.
logger = logging.get_logger(__name__)
  49. class DiffLlamaMLP(nn.Module):
  50. def __init__(self, config):
  51. super().__init__()
  52. self.config = config
  53. self.hidden_size = config.hidden_size
  54. self.intermediate_size = config.intermediate_size
  55. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  56. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  57. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  58. self.act_fn = ACT2FN[config.hidden_act]
  59. def forward(self, x):
  60. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  61. return down_proj
  62. class DiffLlamaRotaryEmbedding(nn.Module):
  63. inv_freq: torch.Tensor # fix linting for `register_buffer`
  64. def __init__(self, config: DiffLlamaConfig, device=None):
  65. super().__init__()
  66. self.max_seq_len_cached = config.max_position_embeddings
  67. self.original_max_seq_len = config.max_position_embeddings
  68. self.config = config
  69. self.rope_type = self.config.rope_parameters["rope_type"]
  70. rope_init_fn: Callable = self.compute_default_rope_parameters
  71. if self.rope_type != "default":
  72. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  73. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  74. self.register_buffer("inv_freq", inv_freq, persistent=False)
  75. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  76. @staticmethod
  77. def compute_default_rope_parameters(
  78. config: DiffLlamaConfig | None = None,
  79. device: Optional["torch.device"] = None,
  80. seq_len: int | None = None,
  81. ) -> tuple["torch.Tensor", float]:
  82. """
  83. Computes the inverse frequencies according to the original RoPE implementation
  84. Args:
  85. config ([`~transformers.PreTrainedConfig`]):
  86. The model configuration.
  87. device (`torch.device`):
  88. The device to use for initialization of the inverse frequencies.
  89. seq_len (`int`, *optional*):
  90. The current sequence length. Unused for this type of RoPE.
  91. Returns:
  92. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  93. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  94. """
  95. base = config.rope_parameters["rope_theta"]
  96. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  97. attention_factor = 1.0 # Unused in this type of RoPE
  98. # Compute the inverse frequencies
  99. inv_freq = 1.0 / (
  100. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  101. )
  102. return inv_freq, attention_factor
  103. @torch.no_grad()
  104. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  105. def forward(self, x, position_ids):
  106. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  107. position_ids_expanded = position_ids[:, None, :].float()
  108. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  109. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  110. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  111. emb = torch.cat((freqs, freqs), dim=-1)
  112. cos = emb.cos() * self.attention_scaling
  113. sin = emb.sin() * self.attention_scaling
  114. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  115. def rotate_half(x):
  116. """Rotates half the hidden dims of the input."""
  117. x1 = x[..., : x.shape[-1] // 2]
  118. x2 = x[..., x.shape[-1] // 2 :]
  119. return torch.cat((-x2, x1), dim=-1)
  120. @use_kernel_func_from_hub("rotary_pos_emb")
  121. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  122. """Applies Rotary Position Embedding to the query and key tensors.
  123. Args:
  124. q (`torch.Tensor`): The query tensor.
  125. k (`torch.Tensor`): The key tensor.
  126. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  127. sin (`torch.Tensor`): The sine part of the rotary embedding.
  128. unsqueeze_dim (`int`, *optional*, defaults to 1):
  129. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  130. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  131. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  132. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  133. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  134. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  135. Returns:
  136. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  137. """
  138. cos = cos.unsqueeze(unsqueeze_dim)
  139. sin = sin.unsqueeze(unsqueeze_dim)
  140. q_embed = (q * cos) + (rotate_half(q) * sin)
  141. k_embed = (k * cos) + (rotate_half(k) * sin)
  142. return q_embed, k_embed
  143. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  144. """
  145. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  146. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  147. """
  148. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  149. if n_rep == 1:
  150. return hidden_states
  151. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  152. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  153. def lambda_init_fn(layer_idx):
  154. return 0.8 - 0.6 * math.exp(-0.3 * layer_idx)
  155. class DiffLlamaAttention(nn.Module):
  156. """Multi-headed attention from 'Attention Is All You Need' paper"""
  157. def __init__(self, config: DiffLlamaConfig, layer_idx: int | None = None):
  158. super().__init__()
  159. self.config = config
  160. self.layer_idx = layer_idx
  161. if layer_idx is None:
  162. logger.warning_once(
  163. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  164. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  165. "when creating this class."
  166. )
  167. self.attention_dropout = config.attention_dropout
  168. self.hidden_size = config.hidden_size
  169. self.num_heads = config.num_attention_heads
  170. self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
  171. self.num_key_value_heads = config.num_key_value_heads
  172. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  173. # under this are not used
  174. self.max_position_embeddings = config.max_position_embeddings
  175. self.is_causal = True
  176. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
  177. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  178. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  179. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
  180. self.lambda_init = lambda_init_fn(layer_idx)
  181. self.lambda_q1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
  182. self.lambda_k1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
  183. self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
  184. self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
  185. self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False)
  186. def forward(
  187. self,
  188. hidden_states: torch.Tensor,
  189. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  190. attention_mask: torch.Tensor | None = None,
  191. position_ids: torch.LongTensor | None = None,
  192. past_key_values: Cache | None = None,
  193. use_cache: bool = False,
  194. **kwargs,
  195. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  196. bsz, target_len, _ = hidden_states.size()
  197. q_len = target_len
  198. query_states = self.q_proj(hidden_states)
  199. key_states = self.k_proj(hidden_states)
  200. value_states = self.v_proj(hidden_states)
  201. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  202. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  203. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  204. cos, sin = position_embeddings
  205. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  206. if past_key_values is not None:
  207. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  208. key_states = repeat_kv(key_states, self.num_key_value_groups)
  209. value_states = repeat_kv(value_states, self.num_key_value_groups)
  210. value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
  211. value_states = value_states.repeat(1, 2, 1, 1)
  212. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
  213. if attention_mask is not None:
  214. attn_weights = attn_weights + attention_mask
  215. # upcast attention to fp32
  216. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
  217. attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
  218. lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
  219. query_states.dtype
  220. )
  221. lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
  222. query_states.dtype
  223. )
  224. lambda_full = lambda_1 - lambda_2 + self.lambda_init
  225. attn_output = torch.matmul(attn_weights, value_states)
  226. attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
  227. attn_output = attn_output1 - lambda_full * attn_output2
  228. attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
  229. attn_output = attn_output.transpose(1, 2).contiguous()
  230. attn_output = attn_output.reshape(bsz, q_len, -1)
  231. attn_output = self.o_proj(attn_output)
  232. return attn_output, attn_weights
  233. class DiffLlamaFlashAttention2(DiffLlamaAttention):
  234. """
  235. DiffLlama flash attention module. This module inherits from `DiffLlamaAttention` as the weights of the module stays
  236. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  237. flash attention and deal with padding tokens in case the input contains any of them.
  238. """
  239. def __init__(self, *args, **kwargs):
  240. super().__init__(*args, **kwargs)
  241. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  242. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  243. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  244. self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
  245. def forward(
  246. self,
  247. hidden_states: torch.Tensor,
  248. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  249. attention_mask: torch.LongTensor | None = None,
  250. position_ids: torch.LongTensor | None = None,
  251. past_key_values: Cache | None = None,
  252. use_cache: bool = False,
  253. ) -> tuple[torch.Tensor, None]:
  254. if isinstance(past_key_values, StaticCache):
  255. raise ValueError(
  256. "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
  257. "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
  258. )
  259. bsz, q_len, _ = hidden_states.size()
  260. query_states = self.q_proj(hidden_states)
  261. key_states = self.k_proj(hidden_states)
  262. value_states = self.v_proj(hidden_states)
  263. # Flash attention requires the input to have the shape
  264. # batch_size x seq_length x head_dim x hidden_dim
  265. # therefore we just need to keep the original shape
  266. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  267. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  268. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  269. cos, sin = position_embeddings
  270. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  271. if past_key_values is not None:
  272. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  273. # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
  274. # to be able to avoid many of these transpose/reshape/view.
  275. query_states = query_states.transpose(1, 2)
  276. key_states = key_states.transpose(1, 2)
  277. value_states = value_states.transpose(1, 2)
  278. dropout_rate = self.attention_dropout if self.training else 0.0
  279. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  280. # therefore the input hidden states gets silently casted in float32. Hence, we need
  281. # cast them back in the correct dtype just to be sure everything works as expected.
  282. # This might slowdown training & inference so it is recommended to not cast the LayerNorms
  283. # in fp32. (DiffLlamaRMSNorm handles it correctly)
  284. input_dtype = query_states.dtype
  285. device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
  286. if input_dtype == torch.float32:
  287. if torch.is_autocast_enabled(device_type):
  288. target_dtype = torch.get_autocast_dtype(device_type)
  289. # Handle the case where the model is quantized
  290. elif hasattr(self.config, "_is_quantized"):
  291. target_dtype = self.config.dtype
  292. else:
  293. target_dtype = self.q_proj.weight.dtype
  294. logger.warning_once(
  295. f"The input hidden states seems to be silently casted in float32, this might be related to"
  296. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  297. f" {target_dtype}."
  298. )
  299. query_states = query_states.to(target_dtype)
  300. key_states = key_states.to(target_dtype)
  301. value_states = value_states.to(target_dtype)
  302. value_states1, value_states2 = torch.chunk(value_states, 2, dim=2)
  303. value_states1 = value_states1.repeat(1, 1, 2, 1)
  304. value_states2 = value_states2.repeat(1, 1, 2, 1)
  305. attn_output1 = _flash_attention_forward(
  306. query_states,
  307. key_states,
  308. value_states1,
  309. attention_mask,
  310. q_len,
  311. position_ids=position_ids,
  312. dropout=dropout_rate,
  313. sliding_window=getattr(self, "sliding_window", None),
  314. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  315. is_causal=self.is_causal,
  316. )
  317. attn_output2 = _flash_attention_forward(
  318. query_states,
  319. key_states,
  320. value_states2,
  321. attention_mask,
  322. q_len,
  323. position_ids=position_ids,
  324. dropout=dropout_rate,
  325. sliding_window=getattr(self, "sliding_window", None),
  326. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  327. is_causal=self.is_causal,
  328. )
  329. attn_output = torch.cat([attn_output1, attn_output2], dim=-1)
  330. attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=2)
  331. lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
  332. query_states.dtype
  333. )
  334. lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
  335. query_states.dtype
  336. )
  337. lambda_full = lambda_1 - lambda_2 + self.lambda_init
  338. attn_output = attn_output1 - lambda_full * attn_output2
  339. attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
  340. attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
  341. attn_output = self.o_proj(attn_output)
  342. return attn_output, None
  343. class DiffLlamaSdpaAttention(DiffLlamaAttention):
  344. """
  345. DiffLlama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
  346. `DiffLlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
  347. SDPA API.
  348. """
  349. # Adapted from DiffLlamaAttention.forward
  350. def forward(
  351. self,
  352. hidden_states: torch.Tensor,
  353. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  354. attention_mask: torch.Tensor | None = None,
  355. position_ids: torch.LongTensor | None = None,
  356. past_key_values: Cache | None = None,
  357. use_cache: bool = False,
  358. **kwargs,
  359. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  360. bsz, q_len, _ = hidden_states.size()
  361. query_states = self.q_proj(hidden_states)
  362. key_states = self.k_proj(hidden_states)
  363. value_states = self.v_proj(hidden_states)
  364. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  365. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  366. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  367. cos, sin = position_embeddings
  368. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  369. if past_key_values is not None:
  370. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  371. key_states = repeat_kv(key_states, self.num_key_value_groups)
  372. value_states = repeat_kv(value_states, self.num_key_value_groups)
  373. value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
  374. value_states = value_states.repeat(1, 2, 1, 1)
  375. causal_mask = attention_mask
  376. if attention_mask is not None:
  377. causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
  378. # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
  379. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
  380. is_causal = causal_mask is None and q_len > 1
  381. attn_output = torch.nn.functional.scaled_dot_product_attention(
  382. query_states,
  383. key_states,
  384. value_states,
  385. attn_mask=causal_mask,
  386. dropout_p=self.attention_dropout if self.training else 0.0,
  387. is_causal=is_causal,
  388. )
  389. attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
  390. lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
  391. query_states.dtype
  392. )
  393. lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
  394. query_states.dtype
  395. )
  396. lambda_full = lambda_1 - lambda_2 + self.lambda_init
  397. attn_output = attn_output1 - lambda_full * attn_output2
  398. attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
  399. attn_output = attn_output.transpose(1, 2).contiguous()
  400. attn_output = attn_output.view(bsz, q_len, -1)
  401. attn_output = self.o_proj(attn_output)
  402. return attn_output, None
  403. @use_kernel_forward_from_hub("RMSNorm")
  404. class DiffLlamaRMSNorm(nn.Module):
  405. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  406. """
  407. DiffLlamaRMSNorm is equivalent to T5LayerNorm
  408. """
  409. super().__init__()
  410. self.weight = nn.Parameter(torch.ones(hidden_size))
  411. self.variance_epsilon = eps
  412. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  413. input_dtype = hidden_states.dtype
  414. hidden_states = hidden_states.to(torch.float32)
  415. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  416. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  417. return self.weight * hidden_states.to(input_dtype)
  418. def extra_repr(self):
  419. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
# Maps `config._attn_implementation` to the attention class that `DiffLlamaDecoderLayer` instantiates.
DIFFLLAMA_ATTENTION_CLASSES = {
    "eager": DiffLlamaAttention,
    "flash_attention_2": DiffLlamaFlashAttention2,
    "sdpa": DiffLlamaSdpaAttention,
}
  425. class DiffLlamaDecoderLayer(GradientCheckpointingLayer):
  426. def __init__(self, config: DiffLlamaConfig, layer_idx: int):
  427. super().__init__()
  428. self.hidden_size = config.hidden_size
  429. self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
  430. self.mlp = DiffLlamaMLP(config)
  431. self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  432. self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  433. def forward(
  434. self,
  435. hidden_states: torch.Tensor,
  436. attention_mask: torch.Tensor | None = None,
  437. position_ids: torch.LongTensor | None = None,
  438. past_key_values: Cache | None = None,
  439. use_cache: bool | None = False,
  440. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  441. **kwargs: Unpack[TransformersKwargs],
  442. ) -> torch.Tensor:
  443. residual = hidden_states
  444. hidden_states = self.input_layernorm(hidden_states)
  445. # Self Attention
  446. hidden_states, _ = self.self_attn(
  447. hidden_states=hidden_states,
  448. attention_mask=attention_mask,
  449. position_ids=position_ids,
  450. past_key_values=past_key_values,
  451. use_cache=use_cache,
  452. position_embeddings=position_embeddings,
  453. **kwargs,
  454. )
  455. hidden_states = residual + hidden_states
  456. # Fully Connected
  457. residual = hidden_states
  458. hidden_states = self.post_attention_layernorm(hidden_states)
  459. hidden_states = self.mlp(hidden_states)
  460. hidden_states = residual + hidden_states
  461. return hidden_states
  462. @auto_docstring
  463. class DiffLlamaPreTrainedModel(PreTrainedModel):
  464. config: DiffLlamaConfig
  465. base_model_prefix = "model"
  466. supports_gradient_checkpointing = True
  467. _no_split_modules = ["DiffLlamaDecoderLayer"]
  468. _skip_keys_device_placement = ["past_key_values"]
  469. _supports_flash_attn = True
  470. _supports_sdpa = True
  471. _supports_flex_attn = False
  472. _can_compile_fullgraph = True
  473. _supports_attention_backend = False
  474. _can_record_outputs = {
  475. "hidden_states": DiffLlamaDecoderLayer,
  476. "attentions": DiffLlamaAttention,
  477. }
  478. @torch.no_grad()
  479. def _init_weights(self, module):
  480. super()._init_weights(module)
  481. if isinstance(module, DiffLlamaAttention):
  482. init.normal_(module.lambda_q1, 0, self.config.lambda_std_dev)
  483. init.normal_(module.lambda_k1, 0, self.config.lambda_std_dev)
  484. init.normal_(module.lambda_q2, 0, self.config.lambda_std_dev)
  485. init.normal_(module.lambda_k2, 0, self.config.lambda_std_dev)
  486. @auto_docstring
  487. class DiffLlamaModel(DiffLlamaPreTrainedModel):
  488. def __init__(self, config: DiffLlamaConfig):
  489. super().__init__(config)
  490. self.padding_idx = config.pad_token_id
  491. self.vocab_size = config.vocab_size
  492. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  493. self.layers = nn.ModuleList(
  494. [DiffLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  495. )
  496. self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  497. self.rotary_emb = DiffLlamaRotaryEmbedding(config=config)
  498. self.gradient_checkpointing = False
  499. # Initialize weights and apply final processing
  500. self.post_init()
  501. @merge_with_config_defaults
  502. @capture_outputs
  503. @auto_docstring
  504. def forward(
  505. self,
  506. input_ids: torch.LongTensor | None = None,
  507. attention_mask: torch.Tensor | None = None,
  508. position_ids: torch.LongTensor | None = None,
  509. past_key_values: Cache | None = None,
  510. inputs_embeds: torch.FloatTensor | None = None,
  511. use_cache: bool | None = None,
  512. **kwargs: Unpack[TransformersKwargs],
  513. ) -> BaseModelOutputWithPast:
  514. if (input_ids is None) ^ (inputs_embeds is not None):
  515. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  516. if inputs_embeds is None:
  517. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  518. if use_cache and past_key_values is None:
  519. past_key_values = DynamicCache(config=self.config)
  520. if position_ids is None:
  521. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  522. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  523. position_ids = position_ids.unsqueeze(0)
  524. causal_mask = create_causal_mask(
  525. config=self.config,
  526. inputs_embeds=inputs_embeds,
  527. attention_mask=attention_mask,
  528. past_key_values=past_key_values,
  529. position_ids=position_ids,
  530. )
  531. hidden_states = inputs_embeds
  532. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  533. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  534. hidden_states = decoder_layer(
  535. hidden_states,
  536. attention_mask=causal_mask,
  537. position_embeddings=position_embeddings,
  538. position_ids=position_ids,
  539. past_key_values=past_key_values,
  540. use_cache=use_cache,
  541. **kwargs,
  542. )
  543. hidden_states = self.norm(hidden_states)
  544. return BaseModelOutputWithPast(
  545. last_hidden_state=hidden_states,
  546. past_key_values=past_key_values,
  547. )
  548. @auto_docstring
  549. class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
  550. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  551. _tp_plan = {"lm_head": "colwise_gather_output"}
  552. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  553. def __init__(self, config):
  554. super().__init__(config)
  555. self.model = DiffLlamaModel(config)
  556. self.vocab_size = config.vocab_size
  557. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  558. # Initialize weights and apply final processing
  559. self.post_init()
  560. @can_return_tuple
  561. @auto_docstring
  562. def forward(
  563. self,
  564. input_ids: torch.LongTensor | None = None,
  565. attention_mask: torch.Tensor | None = None,
  566. position_ids: torch.LongTensor | None = None,
  567. past_key_values: Cache | None = None,
  568. inputs_embeds: torch.FloatTensor | None = None,
  569. labels: torch.LongTensor | None = None,
  570. use_cache: bool | None = None,
  571. logits_to_keep: int | torch.Tensor = 0,
  572. **kwargs: Unpack[TransformersKwargs],
  573. ) -> CausalLMOutputWithPast:
  574. r"""
  575. Example:
  576. ```python
  577. >>> from transformers import AutoTokenizer, DiffLlamaForCausalLM
  578. >>> model = DiffLlamaForCausalLM.from_pretrained("google/diffllama-7b")
  579. >>> tokenizer = AutoTokenizer.from_pretrained("google/diffllama-7b")
  580. >>> prompt = "What is your favorite condiment?"
  581. >>> inputs = tokenizer(prompt, return_tensors="pt")
  582. >>> # Generate
  583. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  584. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  585. "What is your favorite condiment?"
  586. ```"""
  587. outputs: BaseModelOutputWithPast = self.model(
  588. input_ids=input_ids,
  589. attention_mask=attention_mask,
  590. position_ids=position_ids,
  591. past_key_values=past_key_values,
  592. inputs_embeds=inputs_embeds,
  593. use_cache=use_cache,
  594. **kwargs,
  595. )
  596. hidden_states = outputs.last_hidden_state
  597. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  598. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  599. logits = self.lm_head(hidden_states[:, slice_indices, :])
  600. loss = None
  601. if labels is not None:
  602. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  603. return CausalLMOutputWithPast(
  604. loss=loss,
  605. logits=logits,
  606. past_key_values=outputs.past_key_values,
  607. hidden_states=outputs.hidden_states,
  608. attentions=outputs.attentions,
  609. )
  610. class DiffLlamaForSequenceClassification(GenericForSequenceClassification, DiffLlamaPreTrainedModel):
  611. pass
  612. class DiffLlamaForQuestionAnswering(GenericForQuestionAnswering, DiffLlamaPreTrainedModel):
  613. base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model`
  614. class DiffLlamaForTokenClassification(GenericForTokenClassification, DiffLlamaPreTrainedModel):
  615. pass
# Public API of this module: the pretrained base class, the bare model and the task heads.
__all__ = [
    "DiffLlamaPreTrainedModel",
    "DiffLlamaModel",
    "DiffLlamaForCausalLM",
    "DiffLlamaForSequenceClassification",
    "DiffLlamaForQuestionAnswering",
    "DiffLlamaForTokenClassification",
]