attention_pool2d.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. """ Attention Pool 2D
  2. Implementations of 2D spatial feature pooling using multi-head attention instead of average pool.
  3. Based on idea in CLIP by OpenAI, licensed Apache 2.0
  4. https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
  5. Hacked together by / Copyright 2021 Ross Wightman
  6. """
  7. from typing import Optional, Union, Tuple
  8. import torch
  9. import torch.nn as nn
  10. from .config import use_fused_attn
  11. from .helpers import to_2tuple
  12. from .pos_embed import resample_abs_pos_embed
  13. from .pos_embed_sincos import apply_rot_embed_cat, create_rope_embed
  14. from .weight_init import trunc_normal_
  15. class RotAttentionPool2d(nn.Module):
  16. """ Attention based 2D feature pooling w/ rotary (relative) pos embedding.
  17. This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.
  18. Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed.
  19. https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
  20. NOTE: While this impl does not require a fixed feature size, performance at differeing resolutions from
  21. train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
  22. Setting out_features=0 disables the output projection (pre_logits mode).
  23. """
  24. fused_attn: torch.jit.Final[bool]
  25. def __init__(
  26. self,
  27. in_features: int,
  28. out_features: Optional[int] = None,
  29. ref_feat_size: Union[int, Tuple[int, int]] = 7,
  30. embed_dim: Optional[int] = None,
  31. head_dim: Optional[int] = 64,
  32. num_heads: Optional[int] = None,
  33. qkv_bias: bool = True,
  34. qkv_separate: bool = False,
  35. pool_type: str = 'token',
  36. class_token: bool = False,
  37. drop_rate: float = 0.,
  38. rope_type: str = 'cat',
  39. device=None,
  40. dtype=None,
  41. ):
  42. dd = {'device': device, 'dtype': dtype}
  43. super().__init__()
  44. assert pool_type in ('', 'token')
  45. self.embed_dim = embed_dim = embed_dim or in_features
  46. self.in_features = in_features
  47. if out_features is None:
  48. self.out_features = in_features
  49. elif out_features > 0:
  50. self.out_features = out_features
  51. else:
  52. self.out_features = embed_dim # out_features=0 disables projection
  53. ref_feat_size = to_2tuple(ref_feat_size)
  54. if num_heads is not None:
  55. assert embed_dim % num_heads == 0
  56. head_dim = embed_dim // num_heads
  57. else:
  58. assert embed_dim % head_dim == 0
  59. num_heads = embed_dim // head_dim
  60. self.num_heads = num_heads
  61. self.head_dim = head_dim
  62. self.pool_type = pool_type.lower()
  63. self.scale = self.head_dim ** -0.5
  64. self.fused_attn = use_fused_attn()
  65. self.rope_type = rope_type
  66. if class_token:
  67. self.cls_token = nn.Parameter(torch.zeros(1, embed_dim, **dd))
  68. else:
  69. self.cls_token = None
  70. if qkv_separate:
  71. self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd)
  72. self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd)
  73. self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd)
  74. self.qkv = None
  75. else:
  76. self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias, **dd)
  77. self.drop = nn.Dropout(drop_rate)
  78. self.proj = nn.Linear(embed_dim, self.out_features, **dd) if out_features != 0 else nn.Identity()
  79. self.pos_embed = create_rope_embed(
  80. rope_type=rope_type,
  81. dim=embed_dim,
  82. num_heads=num_heads,
  83. in_pixels=False,
  84. ref_feat_shape=ref_feat_size,
  85. rotate_half=False,
  86. **dd,
  87. )
  88. def init_weights(self, zero_init_last: bool = False):
  89. if self.qkv is None:
  90. in_features = self.q.in_features
  91. trunc_normal_(self.q.weight, std=in_features ** -0.5)
  92. nn.init.zeros_(self.q.bias)
  93. trunc_normal_(self.k.weight, std=in_features ** -0.5)
  94. nn.init.zeros_(self.k.bias)
  95. trunc_normal_(self.v.weight, std=in_features ** -0.5)
  96. nn.init.zeros_(self.v.bias)
  97. else:
  98. in_features = self.qkv.in_features
  99. trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
  100. nn.init.zeros_(self.qkv.bias)
  101. def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
  102. # NOTE: this module is being used as a head, so need compatible reset()
  103. if pool_type is not None:
  104. assert pool_type in ('', 'token')
  105. self.pool_type = pool_type
  106. if num_classes is not None:
  107. self.proj = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
  108. self.out_features = num_classes if num_classes > 0 else self.embed_dim
  109. def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
  110. if self.pool_type == 'token':
  111. x = x[:, 0]
  112. else:
  113. # if not pooled, return spatial output without token
  114. x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
  115. return x
  116. def forward(self, x, pre_logits: bool = False):
  117. B, _, H, W = x.shape
  118. N = H * W
  119. x = x.flatten(2).transpose(1, 2)
  120. if self.cls_token is None:
  121. x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
  122. else:
  123. x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
  124. if self.qkv is None:
  125. q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
  126. k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
  127. v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
  128. else:
  129. x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
  130. q, k, v = x.unbind(0)
  131. rope = self.pos_embed.get_embed((H, W))
  132. if isinstance(rope, tuple):
  133. # RotaryEmbedding returns (sin, cos) tuple - concatenate for apply_rot_embed_cat
  134. rope = torch.cat(rope, dim=-1)
  135. q = torch.cat([q[:, :, :1, :], apply_rot_embed_cat(q[:, :, 1:, :], rope)], dim=2).type_as(v)
  136. k = torch.cat([k[:, :, :1, :], apply_rot_embed_cat(k[:, :, 1:, :], rope)], dim=2).type_as(v)
  137. if self.fused_attn:
  138. x = nn.functional.scaled_dot_product_attention(q, k, v)
  139. else:
  140. q = q * self.scale
  141. attn = q @ k.transpose(-2, -1)
  142. attn = attn.softmax(dim=-1)
  143. x = attn @ v
  144. x = x.transpose(1, 2).reshape(B, N + 1, -1)
  145. x = self.drop(x)
  146. if pre_logits:
  147. x = self._pool(x, H, W)
  148. return x
  149. x = self.proj(x)
  150. x = self._pool(x, H, W)
  151. return x
  152. class AttentionPool2d(nn.Module):
  153. """ Attention based 2D feature pooling w/ learned (absolute) pos embedding.
  154. This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.
  155. It was based on impl in CLIP by OpenAI
  156. https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
  157. NOTE: This requires feature size upon construction and well prevent adaptive sizing of the network.
  158. Setting out_features=0 disables the output projection (pre_logits mode).
  159. """
  160. fused_attn: torch.jit.Final[bool]
  161. def __init__(
  162. self,
  163. in_features: int,
  164. feat_size: Union[int, Tuple[int, int]] = 7,
  165. out_features: Optional[int] = None,
  166. embed_dim: Optional[int] = None,
  167. head_dim: Optional[int] = 64,
  168. num_heads: Optional[int] = None,
  169. qkv_bias: bool = True,
  170. qkv_separate: bool = False,
  171. pool_type: str = 'token',
  172. class_token: bool = False,
  173. drop_rate: float = 0.,
  174. device=None,
  175. dtype=None,
  176. ):
  177. dd = {'device': device, 'dtype': dtype}
  178. super().__init__()
  179. assert pool_type in ('', 'token')
  180. self.embed_dim = embed_dim = embed_dim or in_features
  181. self.in_features = in_features
  182. if out_features is None:
  183. self.out_features = in_features
  184. elif out_features > 0:
  185. self.out_features = out_features
  186. else:
  187. self.out_features = embed_dim # out_features=0 disables projection
  188. if num_heads is not None:
  189. assert embed_dim % num_heads == 0
  190. head_dim = embed_dim // num_heads
  191. else:
  192. assert embed_dim % head_dim == 0
  193. num_heads = embed_dim // head_dim
  194. self.feat_size = to_2tuple(feat_size)
  195. self.seq_len = self.feat_size[0] * self.feat_size[1]
  196. self.num_heads = num_heads
  197. self.head_dim = head_dim
  198. self.pool_type = pool_type
  199. self.scale = self.head_dim ** -0.5
  200. self.fused_attn = use_fused_attn()
  201. if class_token:
  202. self.cls_token = nn.Parameter(torch.zeros(1, embed_dim, **dd))
  203. else:
  204. self.cls_token = None
  205. if qkv_separate:
  206. self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd)
  207. self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd)
  208. self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd)
  209. self.qkv = None
  210. else:
  211. self.q = self.k = self.v = None
  212. self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias, **dd)
  213. self.drop = nn.Dropout(drop_rate)
  214. self.proj = nn.Linear(embed_dim, self.out_features, **dd) if out_features != 0 else nn.Identity()
  215. self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features, **dd))
  216. self.init_weights()
  217. def init_weights(self, zero_init_last: bool = False):
  218. if self.qkv is None:
  219. in_features = self.q.in_features
  220. trunc_normal_(self.q.weight, std=in_features ** -0.5)
  221. nn.init.zeros_(self.q.bias)
  222. trunc_normal_(self.k.weight, std=in_features ** -0.5)
  223. nn.init.zeros_(self.k.bias)
  224. trunc_normal_(self.v.weight, std=in_features ** -0.5)
  225. nn.init.zeros_(self.v.bias)
  226. else:
  227. in_features = self.qkv.in_features
  228. trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
  229. nn.init.zeros_(self.qkv.bias)
  230. trunc_normal_(self.pos_embed, std=in_features ** -0.5)
  231. def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
  232. # NOTE: this module is being used as a head, so need compatible reset()
  233. if pool_type is not None:
  234. assert pool_type in ('', 'token')
  235. self.pool_type = pool_type
  236. if num_classes is not None:
  237. self.proj = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
  238. self.out_features = num_classes if num_classes > 0 else self.embed_dim
  239. def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
  240. if self.pool_type == 'token':
  241. x = x[:, 0]
  242. else:
  243. # if not pooled, return spatial output without token
  244. x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
  245. return x
  246. def forward(self, x, pre_logits: bool = False):
  247. B, _, H, W = x.shape
  248. N = H * W
  249. x = x.flatten(2).transpose(1, 2)
  250. if self.cls_token is None:
  251. x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
  252. else:
  253. x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
  254. pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
  255. x = x + pos_embed
  256. if self.qkv is None:
  257. q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
  258. k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
  259. v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
  260. else:
  261. x = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
  262. q, k, v = x.unbind(0)
  263. if self.fused_attn:
  264. x = nn.functional.scaled_dot_product_attention(q, k, v)
  265. else:
  266. q = q * self.scale
  267. attn = q @ k.transpose(-2, -1)
  268. attn = attn.softmax(dim=-1)
  269. x = attn @ v
  270. x = x.transpose(1, 2).reshape(B, N + 1, -1)
  271. x = self.drop(x)
  272. if pre_logits:
  273. x = self._pool(x, H, W)
  274. return x
  275. x = self.proj(x)
  276. x = self._pool(x, H, W)
  277. return x