| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282 |
- from typing import Final, Optional, Type
- import torch
- from torch import nn as nn
- from torch.nn import functional as F
- from ._fx import register_notrace_function
- from .config import use_fused_attn
- from .pos_embed_sincos import apply_rot_embed_cat
- __all__ = ['Attention', 'AttentionRope', 'maybe_add_mask', 'resolve_self_attn_mask']
@torch.fx.wrap
@register_notrace_function
def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
    """Add an optional additive bias to attention logits.

    Wrapped with ``torch.fx.wrap`` so FX symbolic tracing records this as a single leaf call
    instead of tracing through the ``None`` check. ``register_notrace_function`` presumably
    registers it with the project's own tracing helpers (defined elsewhere; confirm there).

    Args:
        scores: Attention logits.
        attn_mask: Optional additive bias combined with ``scores`` via ``+``.

    Returns:
        ``scores`` itself (same object, no copy) when ``attn_mask`` is ``None``,
        otherwise ``scores + attn_mask``.
    """
    if attn_mask is not None:
        return scores + attn_mask
    return scores
@torch.fx.wrap
@register_notrace_function
def resolve_self_attn_mask(
        seq_len: int,
        attn: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
) -> Optional[torch.Tensor]:
    """Turn SDPA-style mask arguments into an additive bias for the manual attention path.

    The result is meant to be added to the attention logits (see ``maybe_add_mask``), so the
    unfused path interprets ``attn_mask`` / ``is_causal`` the way
    ``F.scaled_dot_product_attention`` does for self-attention.

    Args:
        seq_len: Sequence length N; the causal bias has shape (N, N).
        attn: Attention logits. The causal bias takes both dtype and device from this tensor;
            the boolean-mask conversion takes only its dtype.
        attn_mask: Optional mask. A boolean mask marks allowed positions with True and blocked
            positions with False. Any other dtype is assumed to already be an additive bias and
            is returned unchanged.
        is_causal: Build a causal bias. This takes precedence: ``attn_mask`` is ignored when True.

    Returns:
        A bias with 0 for allowed and -inf for blocked positions, the non-boolean mask as given,
        or ``None`` when there is nothing to apply.
    """
    neg_inf = float('-inf')

    if is_causal:
        # All -inf, then triu_(1) zeroes everything on/below the diagonal:
        # row i keeps -inf only for columns j > i (future positions).
        return attn.new_full((seq_len, seq_len), neg_inf).triu_(1)

    if attn_mask is None:
        return None

    if attn_mask.dtype != torch.bool:
        # Float (additive) masks pass straight through.
        return attn_mask

    # Boolean mask: start from zeros in the logits' dtype and block every False entry.
    bias = torch.zeros_like(attn_mask, dtype=attn.dtype)
    bias.masked_fill_(attn_mask.logical_not(), neg_inf)
    return bias
class Attention(nn.Module):
    """Multi-head self-attention driven by one fused QKV projection.

    A single ``nn.Linear`` produces queries, keys and values for every head at once. The
    attention itself runs either through ``F.scaled_dot_product_attention`` (when
    ``use_fused_attn()`` returned True at construction) or through an explicit
    ``softmax(q * scale @ k^T + bias) @ v`` fallback.

    Optional extras:
      * per-head normalization of queries and keys;
      * normalization of the merged heads before the output projection;
      * dropout on the attention weights and after the output projection.
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            attn_head_dim: Optional[int] = None,
            dim_out: Optional[int] = None,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            scale_norm: bool = False,
            proj_bias: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: Optional[Type[nn.Module]] = None,
            device=None,
            dtype=None,
    ) -> None:
        """Create the projections, optional norms and dropouts.

        Args:
            dim: Channel size of the incoming tokens.
            num_heads: How many attention heads to use.
            attn_head_dim: Channels per head. When omitted it becomes ``dim // num_heads`` and
                ``dim`` must then be divisible by ``num_heads``.
            dim_out: Channel size of the result. Falls back to ``dim``.
            qkv_bias: Give the fused QKV projection a bias.
            qk_norm: Normalize queries and keys (per head) with ``norm_layer``.
            scale_norm: Normalize the merged head outputs with ``norm_layer`` before ``proj``.
            proj_bias: Give the output projection a bias.
            attn_drop: Dropout probability on attention weights; the fused path passes it to
                SDPA as ``dropout_p`` in training mode only.
            proj_drop: Dropout probability after the output projection.
            norm_layer: Norm constructor; must be given when ``qk_norm`` or ``scale_norm`` is set.
            device: Device used when creating parameters.
            dtype: Dtype used when creating parameters.
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        out_features = dim_out or dim

        if attn_head_dim is None:
            assert dim % num_heads == 0, 'dim should be divisible by num_heads'
            head_dim = dim // num_heads
        else:
            head_dim = attn_head_dim

        if qk_norm or scale_norm:
            assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'

        self.num_heads = num_heads
        self.head_dim = head_dim
        self.attn_dim = num_heads * head_dim
        self.scale = head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        # Creation order below determines the order in which parameter initialization draws
        # random numbers, so it is deliberately left as qkv -> norms -> proj.
        self.qkv = nn.Linear(dim, 3 * self.attn_dim, bias=qkv_bias, **factory_kwargs)
        if qk_norm:
            self.q_norm = norm_layer(head_dim, **factory_kwargs)
            self.k_norm = norm_layer(head_dim, **factory_kwargs)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.norm = norm_layer(self.attn_dim, **factory_kwargs) if scale_norm else nn.Identity()
        self.proj = nn.Linear(self.attn_dim, out_features, bias=proj_bias, **factory_kwargs)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            is_causal: bool = False,
    ) -> torch.Tensor:
        """Run self-attention over a token sequence.

        Args:
            x: Tokens of shape (B, N, dim).
            attn_mask: Optional mask. It is handed unchanged to SDPA on the fused path, and
                converted to an additive bias by ``resolve_self_attn_mask`` otherwise.
            is_causal: Restrict each position to attend only to itself and earlier positions.

        Returns:
            Tensor of shape (B, N, dim_out).
        """
        batch, seq_len, _ = x.shape

        packed = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
        # (3, B, num_heads, N, head_dim) -> three (B, num_heads, N, head_dim) tensors
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)

        if not self.fused_attn:
            scores = (q * self.scale) @ k.transpose(-2, -1)
            bias = resolve_self_attn_mask(seq_len, scores, attn_mask, is_causal)
            scores = maybe_add_mask(scores, bias)
            weights = self.attn_drop(scores.softmax(dim=-1))
            out = weights @ v
        else:
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
                is_causal=is_causal,
            )

        # Merge heads back into the channel dimension.
        out = out.transpose(1, 2).reshape(batch, seq_len, self.attn_dim)
        out = self.norm(out)
        out = self.proj(out)
        return self.proj_drop(out)
class AttentionRope(nn.Module):
    """ A Self Attention module with ROPE support.

    Includes options for:
     * QK normalization option
     * Attention output (scale) normalization
     * Fused or unfused QKV projection support

    When ``forward`` receives a rotary embedding, it is applied to queries and keys of every
    token except the leading ``num_prefix_tokens`` (e.g. class / register tokens).
    """
    # typing.Final (same as Attention); torch.jit.Final aliases it, so scripting is unaffected.
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            dim_out: Optional[int] = None,
            qkv_bias: bool = True,
            qkv_fused: bool = True,
            num_prefix_tokens: int = 1,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            attn_head_dim: Optional[int] = None,
            norm_layer: Optional[Type[nn.Module]] = None,
            qk_norm: bool = False,
            scale_norm: bool = False,
            proj_bias: bool = True,
            rotate_half: bool = False,
            device=None,
            dtype=None,
    ) -> None:
        """Initialize the AttentionRope module.

        Args:
            dim: Input dimension of the token embeddings
            num_heads: Number of attention heads
            dim_out: Output dimension. If None, same as dim.
            qkv_bias: Whether to add a bias term to the query, key, and value projections
            qkv_fused: Whether to use fused QKV projection (single linear) or separate projections
            num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that
                should not have position embeddings applied
            attn_drop: Dropout rate for attention weights
            proj_drop: Dropout rate for the output projection
            attn_head_dim: Dimension of each attention head. If None, computed as dim // num_heads
                (dim must then be divisible by num_heads).
            norm_layer: Normalization layer constructor to use for QK and scale normalization.
                Required when qk_norm or scale_norm is True.
            qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
            scale_norm: Enable normalization (scaling) of attention output with norm_layer
            proj_bias: Whether to use bias in the output projection
            rotate_half: Use 'half' ROPE layout instead of default 'interleaved'
            device: Device on which to create parameters.
            dtype: Dtype with which to create parameters.
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        dim_out = dim_out or dim
        head_dim = attn_head_dim
        if head_dim is None:
            assert dim % num_heads == 0, 'dim should be divisible by num_heads'
            head_dim = dim // num_heads
        if scale_norm or qk_norm:
            assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'

        self.num_heads = num_heads
        self.head_dim = head_dim
        self.attn_dim = head_dim * num_heads
        self.scale = head_dim ** -0.5
        self.num_prefix_tokens = num_prefix_tokens
        self.fused_attn = use_fused_attn()
        self.rotate_half = rotate_half

        # Exactly one of `qkv` or (`q_proj`, `k_proj`, `v_proj`) is populated; the other is None.
        if qkv_fused:
            self.qkv = nn.Linear(dim, self.attn_dim * 3, bias=qkv_bias, **dd)
            self.q_proj = self.k_proj = self.v_proj = None
        else:
            self.qkv = None
            self.q_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd)
            self.k_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd)
            self.v_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd)

        self.q_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.norm = norm_layer(self.attn_dim, **dd) if scale_norm else nn.Identity()
        self.proj = nn.Linear(self.attn_dim, dim_out, bias=proj_bias, **dd)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
            self,
            x: torch.Tensor,
            rope: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
            is_causal: bool = False,
    ) -> torch.Tensor:
        """Forward pass for the attention module.

        Args:
            x: Input tensor of shape (batch_size, sequence_length, embedding_dim)
            rope: Rotary position embeddings tensor for position-aware attention, in the format
                expected by ``apply_rot_embed_cat``. It is applied only to positions after the
                first ``num_prefix_tokens`` (presumably sized for those N - num_prefix_tokens
                positions; confirm against callers).
            attn_mask: Optional attention mask to apply during attention computation
            is_causal: If True, use causal (autoregressive) masking

        Returns:
            Tensor of shape (batch_size, sequence_length, dim_out)
        """
        B, N, C = x.shape

        if self.qkv is not None:
            qkv = self.qkv(x)
            qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim
        else:
            q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
            k = self.k_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
            v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)

        q, k = self.q_norm(q), self.k_norm(k)

        if rope is not None:
            npt = self.num_prefix_tokens
            # getattr fallback tolerates instances without `rotate_half`
            # (presumably modules created/unpickled before the attribute existed).
            half = getattr(self, 'rotate_half', False)
            # Prefix tokens are passed through; rotary embedding is applied to the rest, and the
            # result is cast back to v's dtype.
            q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope, half=half)], dim=2).type_as(v)
            k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope, half=half)], dim=2).type_as(v)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
                is_causal=is_causal,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            # Convert mask args into an additive bias mirroring SDPA semantics.
            attn_bias = resolve_self_attn_mask(N, attn, attn_mask, is_causal)
            attn = maybe_add_mask(attn, attn_bias)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, self.attn_dim)
        x = self.norm(x)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|