| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/diffllama/modular_diffllama.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_diffllama.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved.
- #
- # This code is based on Llama implementations in this library and Microsoft's
- # Differential Transformer implementations.
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- from collections.abc import Callable
- from typing import Optional
- import torch
- from torch import nn
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache, StaticCache
- from ...generation import GenerationMixin
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
- from ...masking_utils import create_causal_mask
- from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
- from ...modeling_layers import (
- GenericForQuestionAnswering,
- GenericForSequenceClassification,
- GenericForTokenClassification,
- GradientCheckpointingLayer,
- )
- from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
- from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
- from ...modeling_utils import PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
- from ...utils.generic import maybe_autocast, merge_with_config_defaults
- from ...utils.output_capturing import capture_outputs
- from .configuration_diffllama import DiffLlamaConfig
# Module-level logger; this module uses it for one-time warnings (`logger.warning_once`),
# e.g. when an attention layer is built without `layer_idx` or when flash-attention inputs
# arrive upcast to float32.
logger = logging.get_logger(__name__)
class DiffLlamaMLP(nn.Module):
    """Gated feed-forward block computing ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    All three projections are bias-free linear layers; widths and the activation
    function are taken from ``config``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        width, inner_width = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(width, inner_width, bias=False)
        self.up_proj = nn.Linear(width, inner_width, bias=False)
        self.down_proj = nn.Linear(inner_width, width, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # The activated gate branch scales the linear "up" branch element-wise.
        gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
class DiffLlamaRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) table generator.

    Keeps an inverse-frequency buffer and, for a batch of position ids, returns the
    ``(cos, sin)`` pair consumed by ``apply_rotary_pos_emb``. The RoPE flavour comes
    from ``config.rope_parameters["rope_type"]``: ``"default"`` uses
    ``compute_default_rope_parameters``; anything else is looked up in
    ``ROPE_INIT_FUNCTIONS``.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: DiffLlamaConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_type = self.config.rope_parameters["rope_type"]

        if self.rope_type == "default":
            rope_init_fn: Callable = self.compute_default_rope_parameters
        else:
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Both buffers are non-persistent (not written to the state dict);
        # `original_inv_freq` keeps an untouched copy of the initial frequencies.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: DiffLlamaConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Compute the inverse frequencies of the original (unscaled) RoPE formulation.

        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration; `rope_parameters["rope_theta"]` is the base and the
                rotary dimension is `head_dim` (or `hidden_size // num_attention_heads`).
            device (`torch.device`):
                Device on which the inverse frequencies are created.
            seq_len (`int`, *optional*):
                Current sequence length. Ignored by this RoPE type.
        Returns:
            Tuple of (`torch.Tensor`, `float`): the inverse frequencies and the cos/sin
            post-processing scale factor (always 1.0 here).
        """
        base = config.rope_parameters["rope_theta"]
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        # Exponents 0, 2/dim, 4/dim, ... computed in float32 on the requested device.
        exponents = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim
        inv_freq = 1.0 / (base ** exponents)

        attention_factor = 1.0  # Unused in this type of RoPE
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        batch_size = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # "mps" (and non-string device types) are mapped to "cpu" for the autocast context.
        raw_type = x.device.type
        device_type = raw_type if isinstance(raw_type, str) and raw_type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
    """Rotate half the hidden dims of the input.

    The last dimension is split into ``(first, second)`` at ``size // 2`` and
    recombined as ``cat(-second, first)``.
    """
    split = x.shape[-1] // 2
    first, second = x[..., :split], x[..., split:]
    return torch.cat((-second, first), dim=-1)
@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which a singleton dimension is inserted into `cos` and `sin` so they
            broadcast against `q` and `k`. With `cos`/`sin` shaped
            [batch_size, seq_len, head_dim], use 1 when `q`/`k` are
            [batch_size, heads, seq_len, head_dim] and 2 when they are
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors, in that order.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # x * cos + rotate_half(x) * sin, the standard RoPE rotation.
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``: maps
    (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_key_value_heads * n_rep, seqlen, head_dim), repeating each KV head
    ``n_rep`` times consecutively.
    """
    # Unpacking first also validates that the input is 4-D, even when n_rep == 1.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # expand() creates a broadcast view; reshape() materialises the interleaved heads.
    tiled = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return tiled.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def lambda_init_fn(layer_idx):
    """Depth-dependent initial lambda: ``0.8 - 0.6 * exp(-0.3 * layer_idx)``.

    Starts at 0.2 for layer 0 and approaches 0.8 as ``layer_idx`` grows.
    """
    decay = math.exp(-0.3 * layer_idx)
    return 0.8 - 0.6 * decay
class DiffLlamaAttention(nn.Module):
    """Differential multi-head attention (eager implementation).

    Attention probabilities are computed per head as in standard scaled dot-product
    attention. The heads are then split into two halves whose outputs are subtracted,
    ``out = out_1 - lambda_full * out_2``, with
    ``lambda_full = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init``.
    The difference is normalised by a parameter-free RMSNorm over ``2 * head_dim`` and
    scaled by ``1 - lambda_init`` before the output projection.
    """

    def __init__(self, config: DiffLlamaConfig, layer_idx: int | None = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        # under this are not used
        self.max_position_embeddings = config.max_position_embeddings
        self.is_causal = True

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

        # `lambda_init_fn` needs a numeric depth. Construction without `layer_idx` is allowed
        # (only a warning above), so fall back to depth 0 instead of raising a TypeError on None.
        self.lambda_init = lambda_init_fn(layer_idx if layer_idx is not None else 0)
        self.lambda_q1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
        self.lambda_k1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
        self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
        self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
        # Normalises each differential head output (width 2 * head_dim) without a learnable scale.
        self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool = False,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run eager differential attention.

        Args:
            hidden_states: input of shape ``(batch, seq_len, hidden_size)``.
            position_embeddings: ``(cos, sin)`` pair passed to ``apply_rotary_pos_emb``.
            attention_mask: optional additive mask added to the raw attention scores.
            position_ids: unused by this implementation.
            past_key_values: if given, ``past_key_values.update`` is called with this layer's
                keys/values and its returned tensors are attended over.
            use_cache: unused by this implementation.

        Returns:
            ``(attn_output, attn_weights)``: ``attn_output`` has shape
            ``(batch, seq_len, hidden_size)``; ``attn_weights`` are the post-softmax,
            post-dropout probabilities of shape ``(batch, num_heads, seq_len, kv_len)``.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        # Pair value head i with head i + num_heads/2 along the feature axis, then tile so both
        # halves of the query heads see values of shape (bsz, num_heads, kv_len, 2 * head_dim).
        value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
        value_states = value_states.repeat(1, 2, 1, 1)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)

        # Scalar lambdas are reduced in float32, then cast to the activation dtype.
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        attn_output = torch.matmul(attn_weights, value_states)
        # Differential step: first half of the heads minus lambda-weighted second half.
        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
        attn_output = attn_output1 - lambda_full * attn_output2
        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights
class DiffLlamaFlashAttention2(DiffLlamaAttention):
    """
    DiffLlama flash attention module. This module inherits from `DiffLlamaAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment,
        # that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
        # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1)
        # produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool = False,
        **kwargs,
    ) -> tuple[torch.Tensor, None]:
        """Run differential attention through two flash-attention calls.

        ``**kwargs`` is accepted (like the eager and SDPA variants) because the decoder
        layer always forwards its extra keyword arguments to ``self_attn``; they are not
        used here.

        Returns:
            ``(attn_output, None)``: flash attention does not expose attention weights.

        Raises:
            ValueError: if ``past_key_values`` is a ``StaticCache``.
        """
        if isinstance(past_key_values, StaticCache):
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Project to (bsz, heads, q_len, head_dim) so RoPE and the KV cache see the same layout as the
        # eager path; the (bsz, seq_len, heads, head_dim) layout flash attention needs is restored by
        # the transposes after the cache update.
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout
        # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (DiffLlamaRMSNorm handles it correctly)
        input_dtype = query_states.dtype
        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled(device_type):
                target_dtype = torch.get_autocast_dtype(device_type)
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_is_quantized"):
                target_dtype = self.config.dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Split the value heads into two halves and tile each half back to the full KV head count;
        # each flash-attention call below attends with one half of the values.
        value_states1, value_states2 = torch.chunk(value_states, 2, dim=2)
        value_states1 = value_states1.repeat(1, 1, 2, 1)
        value_states2 = value_states2.repeat(1, 1, 2, 1)
        attn_output1 = _flash_attention_forward(
            query_states,
            key_states,
            value_states1,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )

        attn_output2 = _flash_attention_forward(
            query_states,
            key_states,
            value_states2,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )

        # Join the two results on the feature axis, then split the heads into halves
        # for the differential step.
        attn_output = torch.cat([attn_output1, attn_output2], dim=-1)
        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=2)

        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        attn_output = attn_output1 - lambda_full * attn_output2
        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        return attn_output, None
class DiffLlamaSdpaAttention(DiffLlamaAttention):
    """
    DiffLlama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `DiffLlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from DiffLlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool = False,
        **kwargs,
    ) -> tuple[torch.Tensor, None]:
        """Run differential attention through a single SDPA call.

        Returns:
            ``(attn_output, None)``; ``attn_output`` has shape ``(batch, seq_len, hidden_size)``.
            SDPA does not expose attention weights, so the second element is always ``None``.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        # Same value pairing/tiling as the eager path: values become (bsz, num_heads, kv_len, 2 * head_dim).
        value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
        value_states = value_states.repeat(1, 2, 1, 1)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an
        # inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph options.
        # An inline conditional prevents dynamic shapes from compiling.
        is_causal = causal_mask is None and q_len > 1

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)

        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        attn_output = attn_output1 - lambda_full * attn_output2
        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output, None
@use_kernel_forward_from_hub("RMSNorm")
class DiffLlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        """
        Root-mean-square layer norm with a learnable per-channel scale
        (equivalent to T5LayerNorm).
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Normalise in float32, cast back to the incoming dtype, then apply the scale.
        orig_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(as_fp32.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        normalized = (as_fp32 * inv_rms).to(orig_dtype)
        return self.weight * normalized

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
# Maps `config._attn_implementation` to the attention class a decoder layer instantiates.
DIFFLLAMA_ATTENTION_CLASSES = dict(
    eager=DiffLlamaAttention,
    flash_attention_2=DiffLlamaFlashAttention2,
    sdpa=DiffLlamaSdpaAttention,
)
class DiffLlamaDecoderLayer(GradientCheckpointingLayer):
    """Pre-norm decoder block.

    RMSNorm -> differential self-attention -> residual add, then
    RMSNorm -> gated MLP -> residual add.
    """

    def __init__(self, config: DiffLlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        attention_cls = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(config=config, layer_idx=layer_idx)

        self.mlp = DiffLlamaMLP(config)
        self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        # Attention sub-block (the attention weights returned alongside are discarded).
        attn_out, _ = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out
@auto_docstring
class DiffLlamaPreTrainedModel(PreTrainedModel):
    config: DiffLlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DiffLlamaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = False
    _can_compile_fullgraph = True
    _supports_attention_backend = False
    _can_record_outputs = {
        "hidden_states": DiffLlamaDecoderLayer,
        "attentions": DiffLlamaAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        if not isinstance(module, DiffLlamaAttention):
            return
        # Re-draw the differential lambda vectors from N(0, lambda_std_dev), the same
        # distribution the attention module uses when it creates them.
        std = self.config.lambda_std_dev
        for param_name in ("lambda_q1", "lambda_k1", "lambda_q2", "lambda_k2"):
            init.normal_(getattr(module, param_name), 0, std)
@auto_docstring
class DiffLlamaModel(DiffLlamaPreTrainedModel):
    def __init__(self, config: DiffLlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            DiffLlamaDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
        )
        self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = DiffLlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one of the two inputs must be given (both-missing and both-present are rejected).
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if position_ids is None:
            # Continue numbering after whatever the cache has already seen.
            offset = 0 if past_key_values is None else past_key_values.get_seq_length()
            positions = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
            position_ids = (positions + offset).unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        # RoPE tables are computed once and shared by every layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

        for layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = layer(
                hidden_states,
                attention_mask=causal_mask,
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                **kwargs,
            )

        return BaseModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=past_key_values,
        )
@auto_docstring
class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = DiffLlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, DiffLlamaForCausalLM

        >>> model = DiffLlamaForCausalLM.from_pretrained("google/diffllama-7b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/diffllama-7b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
        # An int keeps the last `logits_to_keep` positions (0 keeps all); a tensor is used as an index.
        if isinstance(logits_to_keep, int):
            slice_indices = slice(-logits_to_keep, None)
        else:
            slice_indices = logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = (
            None
            if labels is None
            else self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
        )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Sequence-classification head: all behaviour comes from `GenericForSequenceClassification`;
# `DiffLlamaPreTrainedModel` supplies the DiffLlama config class and weight initialisation.
class DiffLlamaForSequenceClassification(GenericForSequenceClassification, DiffLlamaPreTrainedModel):
    pass
# Question-answering head; behaviour comes from `GenericForQuestionAnswering`.
class DiffLlamaForQuestionAnswering(GenericForQuestionAnswering, DiffLlamaPreTrainedModel):
    base_model_prefix = "transformer"  # For BC, where `transformer` was used instead of `model`
# Token-classification head: all behaviour comes from `GenericForTokenClassification`;
# `DiffLlamaPreTrainedModel` supplies the DiffLlama config class and weight initialisation.
class DiffLlamaForTokenClassification(GenericForTokenClassification, DiffLlamaPreTrainedModel):
    pass
# Public API of this module: the pretrained base class, the bare decoder model and the task heads.
__all__ = [
    "DiffLlamaPreTrainedModel",
    "DiffLlamaModel",
    "DiffLlamaForCausalLM",
    "DiffLlamaForSequenceClassification",
    "DiffLlamaForQuestionAnswering",
    "DiffLlamaForTokenClassification",
]
|