other_pool.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
""" Non-Local Attention Pooling Layers

A collection of global pooling layers that go beyond simple avg/max pooling.

* LsePlus2d / LsePlus1d - LogSumExp (LSE) pooling, a smooth, learnable approximation between avg and max pooling
* SimPool2d / SimPool1d - attention-based pooling from 'Keep It SimPool' (ICCV 2023)

Based on implementations from:
* LSE Pooling: custom implementation by Bill Psomas
* SimPool: https://arxiv.org/abs/2309.06891 - 'Keep It SimPool: Who Said Supervised Transformers
  Suffer from Attention Deficit?' by Bill Psomas et al.

Hacked together by / Copyright 2024 Ross Wightman, original code by Bill Psomas
"""
import math
from typing import Optional, Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import use_fused_attn
  16. class LsePlus2d(nn.Module):
  17. """LogSumExp (LSE) Pooling for 2D inputs.
  18. A smooth approximation to max pooling that provides a learnable interpolation between
  19. average and max pooling. When r is large, LSE approaches max pooling; when r is small,
  20. it approaches average pooling.
  21. Implements: (1/r) * log((1/n) * sum(exp(r * (x - x_max)))) + x_max
  22. The x_max subtraction provides numerical stability.
  23. """
  24. def __init__(
  25. self,
  26. r: float = 10.0,
  27. r_learnable: bool = True,
  28. flatten: bool = True,
  29. device=None,
  30. dtype=None,
  31. ):
  32. """
  33. Args:
  34. r: Initial value of the pooling parameter. Higher = closer to max pooling.
  35. r_learnable: If True, r is a learnable parameter.
  36. flatten: If True, flatten spatial dims in output.
  37. """
  38. super().__init__()
  39. if r_learnable:
  40. self.r = nn.Parameter(torch.tensor(r, device=device, dtype=dtype))
  41. else:
  42. self.register_buffer('r', torch.tensor(r, device=device, dtype=dtype))
  43. self.flatten = flatten
  44. def forward(self, x: torch.Tensor) -> torch.Tensor:
  45. x_max = F.adaptive_max_pool2d(x, 1)
  46. exp_x = torch.exp(self.r * (x - x_max))
  47. sum_exp = exp_x.mean(dim=(2, 3), keepdim=True)
  48. out = x_max + (1.0 / self.r) * torch.log(sum_exp)
  49. if self.flatten:
  50. out = out.flatten(1)
  51. return out
  52. class LsePlus1d(nn.Module):
  53. """LogSumExp (LSE) Pooling for sequence (NLC) inputs.
  54. A smooth approximation to max pooling that provides a learnable interpolation between
  55. average and max pooling. When r is large, LSE approaches max pooling; when r is small,
  56. it approaches average pooling.
  57. """
  58. def __init__(
  59. self,
  60. r: float = 10.0,
  61. r_learnable: bool = True,
  62. device=None,
  63. dtype=None,
  64. ):
  65. """
  66. Args:
  67. r: Initial value of the pooling parameter. Higher = closer to max pooling.
  68. r_learnable: If True, r is a learnable parameter.
  69. """
  70. super().__init__()
  71. if r_learnable:
  72. self.r = nn.Parameter(torch.tensor(r, device=device, dtype=dtype))
  73. else:
  74. self.register_buffer('r', torch.tensor(r, device=device, dtype=dtype))
  75. def forward(self, x: torch.Tensor) -> torch.Tensor:
  76. # x: (B, N, C)
  77. x_max = x.max(dim=1, keepdim=True).values
  78. exp_x = torch.exp(self.r * (x - x_max))
  79. sum_exp = exp_x.mean(dim=1, keepdim=True)
  80. out = x_max + (1.0 / self.r) * torch.log(sum_exp)
  81. return out.squeeze(1) # (B, C)
  82. class SimPool2d(nn.Module):
  83. """SimPool: Simple Attention-Based Pooling for 2D (NCHW) inputs.
  84. From 'Keep It SimPool: Who Said Supervised Transformers Suffer from Attention Deficit?'
  85. https://arxiv.org/abs/2309.06891
  86. Uses GAP as query initialization and applies cross-attention between the GAP query
  87. and spatial features to produce a weighted pooled representation.
  88. """
  89. fused_attn: torch.jit.Final[bool]
  90. def __init__(
  91. self,
  92. dim: int,
  93. num_heads: int = 1,
  94. qkv_bias: bool = False,
  95. qk_norm: bool = False,
  96. gamma: Optional[float] = None,
  97. norm_layer: Optional[Type[nn.Module]] = None,
  98. device=None,
  99. dtype=None,
  100. ):
  101. """
  102. Args:
  103. dim: Input feature dimension (number of channels).
  104. num_heads: Number of attention heads.
  105. qkv_bias: If True, add bias to query and key projections.
  106. qk_norm: If True, apply normalization to queries and keys.
  107. gamma: If provided, apply power normalization to values with this exponent.
  108. norm_layer: Normalization layer for patches and optionally qk_norm.
  109. flatten: If True, flatten output to (B, C).
  110. """
  111. super().__init__()
  112. dd = {'device': device, 'dtype': dtype}
  113. assert dim % num_heads == 0, 'dim must be divisible by num_heads'
  114. self.num_heads = num_heads
  115. self.head_dim = dim // num_heads
  116. self.scale = self.head_dim ** -0.5
  117. self.gamma = gamma
  118. self.fused_attn = use_fused_attn()
  119. norm_layer = norm_layer or nn.LayerNorm
  120. self.norm = norm_layer(dim, **dd)
  121. self.q = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  122. self.k = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  123. if qk_norm:
  124. self.q_norm = norm_layer(self.head_dim, **dd)
  125. self.k_norm = norm_layer(self.head_dim, **dd)
  126. else:
  127. self.q_norm = nn.Identity()
  128. self.k_norm = nn.Identity()
  129. def forward(self, x: torch.Tensor) -> torch.Tensor:
  130. B, C, H, W = x.shape
  131. N = H * W
  132. # Reshape to (B, N, C) for attention
  133. x = x.flatten(2).transpose(1, 2) # (B, N, C)
  134. # GAP as query initialization
  135. q = x.mean(dim=1, keepdim=True) # (B, 1, C)
  136. # Normalize patches for keys and values
  137. x_norm = self.norm(x)
  138. # Project query and keys
  139. q = self.q(q).reshape(B, 1, self.num_heads, self.head_dim).transpose(1, 2)
  140. k = self.k(x_norm).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
  141. v = x_norm.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
  142. q, k = self.q_norm(q), self.k_norm(k)
  143. if self.gamma is not None:
  144. # Power normalization on values
  145. v_min = v.amin(dim=-2, keepdim=True)
  146. v_shifted = v - v_min + 1e-6
  147. if self.fused_attn:
  148. attn_out = F.scaled_dot_product_attention(q, k, v_shifted.pow(self.gamma))
  149. else:
  150. attn = (q * self.scale) @ k.transpose(-2, -1)
  151. attn = attn.softmax(dim=-1)
  152. attn_out = attn @ v_shifted.pow(self.gamma)
  153. out = attn_out.pow(1.0 / self.gamma)
  154. else:
  155. if self.fused_attn:
  156. out = F.scaled_dot_product_attention(q, k, v)
  157. else:
  158. attn = (q * self.scale) @ k.transpose(-2, -1)
  159. attn = attn.softmax(dim=-1)
  160. out = attn @ v
  161. # (B, num_heads, 1, head_dim) -> (B, C) or (B, C)
  162. out = out.transpose(1, 2).reshape(B, C)
  163. return out
  164. class SimPool1d(nn.Module):
  165. """SimPool: Simple Attention-Based Pooling for sequence (NLC) inputs.
  166. From 'Keep It SimPool: Who Said Supervised Transformers Suffer from Attention Deficit?'
  167. https://arxiv.org/abs/2309.06891
  168. Uses GAP as query initialization and applies cross-attention between the GAP query
  169. and sequence tokens to produce a weighted pooled representation.
  170. """
  171. fused_attn: torch.jit.Final[bool]
  172. def __init__(
  173. self,
  174. dim: int,
  175. num_heads: int = 1,
  176. qkv_bias: bool = False,
  177. qk_norm: bool = False,
  178. gamma: Optional[float] = None,
  179. norm_layer: Optional[Type[nn.Module]] = None,
  180. device=None,
  181. dtype=None,
  182. ):
  183. """
  184. Args:
  185. dim: Input feature dimension.
  186. num_heads: Number of attention heads.
  187. qkv_bias: If True, add bias to query and key projections.
  188. qk_norm: If True, apply normalization to queries and keys.
  189. gamma: If provided, apply power normalization to values with this exponent.
  190. norm_layer: Normalization layer for tokens and optionally qk_norm.
  191. """
  192. super().__init__()
  193. dd = {'device': device, 'dtype': dtype}
  194. assert dim % num_heads == 0, 'dim must be divisible by num_heads'
  195. self.num_heads = num_heads
  196. self.head_dim = dim // num_heads
  197. self.scale = self.head_dim ** -0.5
  198. self.gamma = gamma
  199. self.fused_attn = use_fused_attn()
  200. norm_layer = norm_layer or nn.LayerNorm
  201. self.norm = norm_layer(dim, **dd)
  202. self.q = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  203. self.k = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  204. if qk_norm:
  205. self.q_norm = norm_layer(self.head_dim, **dd)
  206. self.k_norm = norm_layer(self.head_dim, **dd)
  207. else:
  208. self.q_norm = nn.Identity()
  209. self.k_norm = nn.Identity()
  210. def forward(self, x: torch.Tensor) -> torch.Tensor:
  211. B, N, C = x.shape
  212. # GAP as query initialization
  213. q = x.mean(dim=1, keepdim=True) # (B, 1, C)
  214. # Normalize tokens for keys and values
  215. x_norm = self.norm(x)
  216. # Project query and keys
  217. q = self.q(q).reshape(B, 1, self.num_heads, self.head_dim).transpose(1, 2)
  218. k = self.k(x_norm).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
  219. v = x_norm.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
  220. q, k = self.q_norm(q), self.k_norm(k)
  221. if self.gamma is not None:
  222. # Power normalization on values
  223. v_min = v.amin(dim=-2, keepdim=True)
  224. v_shifted = v - v_min + 1e-6
  225. if self.fused_attn:
  226. attn_out = F.scaled_dot_product_attention(q, k, v_shifted.pow(self.gamma))
  227. else:
  228. attn = (q * self.scale) @ k.transpose(-2, -1)
  229. attn = attn.softmax(dim=-1)
  230. attn_out = attn @ v_shifted.pow(self.gamma)
  231. out = attn_out.pow(1.0 / self.gamma)
  232. else:
  233. if self.fused_attn:
  234. out = F.scaled_dot_product_attention(q, k, v)
  235. else:
  236. attn = (q * self.scale) @ k.transpose(-2, -1)
  237. attn = attn.softmax(dim=-1)
  238. out = attn @ v
  239. # (B, num_heads, 1, head_dim) -> (B, C)
  240. out = out.transpose(1, 2).reshape(B, C)
  241. return out