# attention_pool.py
  1. from typing import Optional, Type
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from .attention import maybe_add_mask
  6. from .config import use_fused_attn
  7. from .mlp import Mlp
  8. from .weight_init import trunc_normal_tf_
  9. class AttentionPoolLatent(nn.Module):
  10. """ Attention pooling w/ latent query
  11. Setting out_features=0 disables the output projection, norm, and MLP layers (pre_logits mode).
  12. """
  13. fused_attn: torch.jit.Final[bool]
  14. def __init__(
  15. self,
  16. in_features: int,
  17. out_features: int = None,
  18. embed_dim: int = None,
  19. num_heads: int = 8,
  20. feat_size: Optional[int] = None,
  21. mlp_ratio: float = 4.0,
  22. qkv_bias: bool = True,
  23. qk_norm: bool = False,
  24. latent_len: int = 1,
  25. latent_dim: int = None,
  26. pos_embed: str = '',
  27. pool_type: str = 'token',
  28. norm_layer: Optional[Type[nn.Module]] = None,
  29. act_layer: Optional[Type[nn.Module]] = nn.GELU,
  30. drop: float = 0.0,
  31. device = None,
  32. dtype = None
  33. ):
  34. dd = {'device': device, 'dtype': dtype}
  35. super().__init__()
  36. embed_dim = embed_dim or in_features
  37. if out_features is None:
  38. out_features = in_features
  39. assert embed_dim % num_heads == 0
  40. self.num_heads = num_heads
  41. self.head_dim = embed_dim // num_heads
  42. self.feat_size = feat_size
  43. self.scale = self.head_dim ** -0.5
  44. self.pool = pool_type
  45. self.fused_attn = use_fused_attn()
  46. if pos_embed == 'abs':
  47. assert feat_size is not None
  48. self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features, **dd))
  49. else:
  50. self.pos_embed = None
  51. self.latent_dim = latent_dim or embed_dim
  52. self.latent_len = latent_len
  53. self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim, **dd))
  54. self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias, **dd)
  55. self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias, **dd)
  56. if qk_norm:
  57. qk_norm_layer = norm_layer or nn.LayerNorm
  58. self.q_norm = qk_norm_layer(self.head_dim, **dd)
  59. self.k_norm = qk_norm_layer(self.head_dim, **dd)
  60. else:
  61. self.q_norm = nn.Identity()
  62. self.k_norm = nn.Identity()
  63. if out_features > 0:
  64. self.proj = nn.Linear(embed_dim, out_features, **dd)
  65. self.proj_drop = nn.Dropout(drop)
  66. self.norm = norm_layer(out_features, **dd) if norm_layer is not None else nn.Identity()
  67. self.mlp = Mlp(out_features, int(out_features * mlp_ratio), out_features=out_features, act_layer=act_layer, **dd)
  68. else:
  69. self.proj = nn.Identity()
  70. self.proj_drop = nn.Dropout(drop)
  71. self.norm = nn.Identity()
  72. self.mlp = None
  73. out_features = embed_dim
  74. self.out_features = out_features
  75. self.init_weights()
  76. def init_weights(self):
  77. if self.pos_embed is not None:
  78. trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
  79. trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5)
  80. def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
  81. B, N, C = x.shape
  82. if self.pos_embed is not None:
  83. # FIXME interpolate
  84. x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
  85. q_latent = self.latent.expand(B, -1, -1)
  86. q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)
  87. kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
  88. k, v = kv.unbind(0)
  89. q, k = self.q_norm(q), self.k_norm(k)
  90. if self.fused_attn:
  91. x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
  92. else:
  93. q = q * self.scale
  94. attn = q @ k.transpose(-2, -1)
  95. attn = maybe_add_mask(attn, attn_mask)
  96. attn = attn.softmax(dim=-1)
  97. x = attn @ v
  98. x = x.transpose(1, 2).reshape(B, self.latent_len, C)
  99. x = self.proj(x)
  100. x = self.proj_drop(x)
  101. if self.mlp is not None:
  102. x = x + self.mlp(self.norm(x))
  103. # optional pool if latent seq_len > 1 and pooled output is desired
  104. if self.pool == 'token':
  105. x = x[:, 0]
  106. elif self.pool == 'avg':
  107. x = x.mean(1)
  108. return x
  109. class AttentionPoolPrr(nn.Module):
  110. """ Patch Representation Refinement (PRR) attention pool.
  111. From "Locality-Attending Vision Transformer" (ICLR 2026).
  112. Parameter-free multi-head self-attention that refines all patch representations
  113. before pooling. No Q/K/V projections — input is reshaped directly into multi-head
  114. format for self-attention.
  115. """
  116. fused_attn: torch.jit.Final[bool]
  117. def __init__(
  118. self,
  119. dim: int,
  120. num_heads: int = 8,
  121. pool_type: str = 'token',
  122. pre_norm: bool = False,
  123. post_norm: bool = False,
  124. norm_layer: Optional[Type[nn.Module]] = None,
  125. device=None,
  126. dtype=None,
  127. ):
  128. dd = {'device': device, 'dtype': dtype}
  129. super().__init__()
  130. assert pool_type in ('token', 'avg'), f"pool_type must be 'token' or 'avg', got '{pool_type}'"
  131. assert dim % num_heads == 0, f"dim ({dim}) must be divisible by num_heads ({num_heads})"
  132. if norm_layer is None and (pre_norm or post_norm):
  133. norm_layer = nn.LayerNorm
  134. self.num_heads = num_heads
  135. self.head_dim = dim // num_heads
  136. self.scale = self.head_dim ** -0.5
  137. self.pool = pool_type
  138. self.fused_attn = use_fused_attn()
  139. self.out_features = dim
  140. self.pre_norm = norm_layer(dim, **dd) if pre_norm else nn.Identity()
  141. self.post_norm = norm_layer(dim, **dd) if post_norm else nn.Identity()
  142. def forward(self, x: torch.Tensor) -> torch.Tensor:
  143. B, N, C = x.shape
  144. x = self.pre_norm(x)
  145. # Parameter-free self-attention: reshape into multi-head format
  146. qkv = x.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) # (B, H, N, D)
  147. if self.fused_attn:
  148. x = F.scaled_dot_product_attention(qkv, qkv, qkv)
  149. else:
  150. attn = (qkv * self.scale) @ qkv.transpose(-2, -1)
  151. attn = attn.softmax(dim=-1)
  152. x = attn @ qkv
  153. x = x.transpose(1, 2).reshape(B, N, C)
  154. x = self.post_norm(x)
  155. # Pool
  156. if self.pool == 'token':
  157. x = x[:, 0]
  158. elif self.pool == 'avg':
  159. x = x.mean(1)
  160. return x