from typing import Final, Optional, Type

import torch
from torch import nn as nn
from torch.nn import functional as F

from ._fx import register_notrace_function
from .config import use_fused_attn
from .pos_embed_sincos import apply_rot_embed_cat

__all__ = ['Attention', 'AttentionRope', 'maybe_add_mask', 'resolve_self_attn_mask']


@torch.fx.wrap
@register_notrace_function
def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Add an additive attention bias to ``scores`` when one is provided.

    Args:
        scores: Raw attention logits.
        attn_mask: Optional additive bias, broadcastable to ``scores``.

    Returns:
        ``scores`` unchanged if ``attn_mask`` is None, otherwise ``scores + attn_mask``.
    """
    return scores if attn_mask is None else scores + attn_mask


@torch.fx.wrap
@register_notrace_function
def resolve_self_attn_mask(
        seq_len: int,
        attn: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
) -> Optional[torch.Tensor]:
    """Turn SDPA-style mask arguments into an additive bias for manual self-attention.

    Args:
        seq_len: Sequence length N; sizes the (N, N) causal bias.
        attn: Attention logits. Only its dtype is used (and, for the causal
            bias, its device, via ``new_full``).
        attn_mask: Optional mask. A boolean mask marks positions that MAY be
            attended with True; any other dtype is returned unchanged and
            treated as an additive bias.
        is_causal: If True, build an upper-triangular causal bias. This takes
            precedence and ``attn_mask`` is ignored.

    Returns:
        An additive bias (0 where attention is allowed, -inf where blocked),
        or None when neither a mask nor causal masking is requested.
    """
    # Build additive bias matching SDPA semantics for self-attention.
    # is_causal and attn_mask are treated as mutually exclusive here (is_causal wins).
    if is_causal:
        # -inf strictly above the diagonal, 0 on/below it -> each position sees itself and the past.
        attn_bias = attn.new_full((seq_len, seq_len), float('-inf')).triu_(1)
    elif attn_mask is None:
        attn_bias = None
    elif attn_mask.dtype == torch.bool:
        # Boolean True = keep, so blocked positions are where the mask is False.
        attn_bias = torch.zeros_like(attn_mask, dtype=attn.dtype)
        attn_bias.masked_fill_(~attn_mask, float('-inf'))
    else:
        attn_bias = attn_mask
    return attn_bias


class Attention(nn.Module):
    """Standard Multi-head Self Attention module with QKV projection.

    This module implements the standard multi-head attention mechanism used in transformers.
    It supports both the fused attention implementation (scaled_dot_product_attention) for
    efficiency when available, and a manual implementation otherwise. The module includes
    options for QK normalization, attention output (scale) normalization, attention dropout,
    and projection dropout.
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            attn_head_dim: Optional[int] = None,
            dim_out: Optional[int] = None,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            scale_norm: bool = False,
            proj_bias: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: Optional[Type[nn.Module]] = None,
            device=None,
            dtype=None,
    ) -> None:
        """Initialize the Attention module.

        Args:
            dim: Input dimension of the token embeddings.
            num_heads: Number of attention heads.
            attn_head_dim: Dimension of each attention head. If None, computed as dim // num_heads.
            dim_out: Output dimension. If None, same as dim.
            qkv_bias: Whether to use bias in the query, key, value projections.
            qk_norm: Whether to apply normalization to query and key vectors.
            scale_norm: Whether to apply normalization to attention output before projection.
            proj_bias: Whether to use bias in the output projection.
            attn_drop: Dropout rate applied to the attention weights.
            proj_drop: Dropout rate applied after the output projection.
            norm_layer: Normalization layer constructor for QK / scale normalization if enabled.
            device: Device for parameter creation.
            dtype: Dtype for parameter creation.
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        dim_out = dim_out or dim
        head_dim = attn_head_dim
        if head_dim is None:
            assert dim % num_heads == 0, 'dim should be divisible by num_heads'
            head_dim = dim // num_heads
        if qk_norm or scale_norm:
            assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'

        self.num_heads = num_heads
        self.head_dim = head_dim
        # attn_dim may differ from dim when attn_head_dim is given explicitly.
        self.attn_dim = num_heads * head_dim
        self.scale = head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, self.attn_dim * 3, bias=qkv_bias, **dd)
        self.q_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.norm = norm_layer(self.attn_dim, **dd) if scale_norm else nn.Identity()
        self.proj = nn.Linear(self.attn_dim, dim_out, bias=proj_bias, **dd)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            is_causal: bool = False,
    ) -> torch.Tensor:
        """Run multi-head self-attention.

        Args:
            x: Input tensor of shape (B, N, dim).
            attn_mask: Optional mask broadcastable to (B, num_heads, N, N); boolean
                (True = attend) or additive float bias.
            is_causal: If True, apply causal (autoregressive) masking. NOTE: the fused
                path passes both arguments to SDPA as-is (PyTorch documents setting
                both as an error there); the manual path lets is_causal take precedence.

        Returns:
            Tensor of shape (B, N, dim_out).
        """
        B, N, _ = x.shape
        # (B, N, 3*attn_dim) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            # SDPA's default scale (1/sqrt of q's last dim) equals self.scale here.
            x = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
                is_causal=is_causal,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn_bias = resolve_self_attn_mask(N, attn, attn_mask, is_causal)
            attn = maybe_add_mask(attn, attn_bias)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        # (B, num_heads, N, head_dim) -> (B, N, attn_dim)
        x = x.transpose(1, 2).reshape(B, N, self.attn_dim)
        x = self.norm(x)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class AttentionRope(nn.Module):
    """ A Self Attention module with ROPE support.

    Includes options for:
     * QK normalization option
     * Attention output (scale) normalization
     * Fused or unfused QKV projection support
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            dim_out: Optional[int] = None,
            qkv_bias: bool = True,
            qkv_fused: bool = True,
            num_prefix_tokens: int = 1,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            attn_head_dim: Optional[int] = None,
            norm_layer: Optional[Type[nn.Module]] = None,
            qk_norm: bool = False,
            scale_norm: bool = False,
            proj_bias: bool = True,
            rotate_half: bool = False,
            device=None,
            dtype=None,
    ) -> None:
        """Initialize the AttentionRope module.

        Args:
            dim: Input dimension of the token embeddings
            num_heads: Number of attention heads
            dim_out: Output dimension. If None, same as dim.
            qkv_bias: Whether to add a bias term to the query, key, and value projections
            qkv_fused: Whether to use fused QKV projection (single linear) or separate projections
            num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that
                should not have position embeddings applied
            attn_drop: Dropout rate for attention weights
            proj_drop: Dropout rate for the output projection
            attn_head_dim: Dimension of each attention head. If None, computed as dim // num_heads.
            norm_layer: Normalization layer constructor to use for QK and scale normalization
            qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
            scale_norm: Enable normalization (scaling) of attention output with norm_layer
            proj_bias: Whether to use bias in the output projection
            rotate_half: Use 'half' ROPE layout instead of default 'interleaved'
            device: Device for parameter creation.
            dtype: Dtype for parameter creation.
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        dim_out = dim_out or dim
        head_dim = attn_head_dim
        if head_dim is None:
            assert dim % num_heads == 0, 'dim should be divisible by num_heads'
            head_dim = dim // num_heads
        if scale_norm or qk_norm:
            assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'

        self.num_heads = num_heads
        self.head_dim = head_dim
        self.attn_dim = head_dim * num_heads
        self.scale = head_dim ** -0.5
        self.num_prefix_tokens = num_prefix_tokens
        self.fused_attn = use_fused_attn()
        self.rotate_half = rotate_half

        if qkv_fused:
            self.qkv = nn.Linear(dim, self.attn_dim * 3, bias=qkv_bias, **dd)
            self.q_proj = self.k_proj = self.v_proj = None
        else:
            self.qkv = None
            self.q_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd)
            self.k_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd)
            self.v_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd)

        self.q_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.norm = norm_layer(self.attn_dim, **dd) if scale_norm else nn.Identity()
        self.proj = nn.Linear(self.attn_dim, dim_out, bias=proj_bias, **dd)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
            self,
            x: torch.Tensor,
            rope: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
            is_causal: bool = False,
    ) -> torch.Tensor:
        """Forward pass for the attention module.

        Args:
            x: Input tensor of shape (batch_size, sequence_length, embedding_dim)
            rope: Rotary position embeddings tensor for position-aware attention. Applied only
                to tokens after the first num_prefix_tokens; its expected layout is whatever
                apply_rot_embed_cat accepts (not visible here).
            attn_mask: Optional attention mask to apply during attention computation
            is_causal: If True, use causal (autoregressive) masking. On the manual
                (non-fused) path this takes precedence over attn_mask.

        Returns:
            Tensor of shape (batch_size, sequence_length, dim_out)
        """
        B, N, _ = x.shape

        if self.qkv is not None:
            qkv = self.qkv(x)
            qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim
        else:
            q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
            k = self.k_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
            v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)

        q, k = self.q_norm(q), self.k_norm(k)

        if rope is not None:
            npt = self.num_prefix_tokens
            # getattr tolerates instances without a rotate_half attribute
            # (presumably modules created before it existed) -- confirm.
            half = getattr(self, 'rotate_half', False)
            # Prefix tokens pass through unrotated; results are cast back to v's dtype.
            q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope, half=half)], dim=2).type_as(v)
            k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope, half=half)], dim=2).type_as(v)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
                is_causal=is_causal,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn_bias = resolve_self_attn_mask(N, attn, attn_mask, is_causal)
            attn = maybe_add_mask(attn, attn_bias)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, self.attn_dim)
        x = self.norm(x)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x