attention2d.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. from typing import List, Optional, Type, Union
  2. import torch
  3. from torch import nn as nn
  4. from torch.nn import functional as F
  5. from .config import use_fused_attn
  6. from .create_conv2d import create_conv2d
  7. from .helpers import to_2tuple
  8. from .pool2d_same import create_pool2d
  9. class MultiQueryAttentionV2(nn.Module):
  10. """Multi Query Attention.
  11. Fast Transformer Decoding: One Write-Head is All You Need
  12. https://arxiv.org/pdf/1911.02150.pdf
  13. This is an acceletor optimized version - removing multiple unnecessary
  14. tensor transpose by re-arranging indices according to the following rules: 1)
  15. contracted indices are at the end, 2) other indices have the same order in the
  16. input and output tensores.
  17. Compared to V1, this gives 3x speed up.
  18. """
  19. def __init__(
  20. self,
  21. dim: int,
  22. dim_out: Optional[int] = None,
  23. num_heads: int = 8,
  24. key_dim: int = 64,
  25. value_dim: int = 64,
  26. attn_drop: float = 0.,
  27. proj_drop: float = 0.,
  28. device=None,
  29. dtype=None,
  30. ):
  31. """Initializer."""
  32. dd = {'device': device, 'dtype': dtype}
  33. super().__init__()
  34. dim_out = dim_out or dim
  35. self.num_heads = num_heads
  36. self.key_dim = key_dim
  37. self.value_dim = value_dim
  38. self.scale = key_dim ** -0.5
  39. self.query_proj = nn.Parameter(torch.empty((self.num_heads, self.key_dim, dim), **dd))
  40. self.key_proj = nn.Parameter(torch.empty((dim, self.key_dim), **dd))
  41. self.value_proj = nn.Parameter(torch.empty((dim, self.value_dim), **dd))
  42. self.attn_drop = nn.Dropout(attn_drop)
  43. self.out_proj = nn.Parameter(torch.empty((dim_out, self.num_heads, self.value_dim), **dd))
  44. self.proj_drop = nn.Dropout(proj_drop)
  45. self.reset_parameters()
  46. def reset_parameters(self):
  47. scale = self.key_proj.shape[0] ** -0.5
  48. nn.init.normal_(self.query_proj, std=scale)
  49. nn.init.normal_(self.key_proj, std=scale)
  50. nn.init.normal_(self.value_proj, std=scale)
  51. nn.init.normal_(self.out_proj, std=self.out_proj.shape[0] ** -0.5)
  52. def _reshape_input(self, t):
  53. """Reshapes a tensor to three dimensions, keeping the first and last."""
  54. s = t.shape
  55. # Propagate the shape statically where possible.
  56. #num = t.shape[1:-1].numel()
  57. #return t.reshape(s[0], num, s[-1])
  58. return t.reshape(s[0], s[1], -1).transpose(1, 2)
  59. def forward(self, x, m: Optional[torch.Tensor] = None):
  60. """Run layer computation."""
  61. b, _, h, w = x.shape
  62. m = m if m is not None else x
  63. reshaped_x = self._reshape_input(x)
  64. reshaped_m = self._reshape_input(m)
  65. q = torch.einsum('bnd,hkd->bnhk', reshaped_x, self.query_proj)
  66. k = torch.einsum('bmd,dk->bmk', reshaped_m, self.key_proj)
  67. attn = torch.einsum('bnhk,bmk->bnhm', q, k) * self.scale
  68. attn = attn.softmax(dim=-1)
  69. attn = self.attn_drop(attn)
  70. v = torch.einsum('bmd,dv->bmv', reshaped_m, self.value_proj)
  71. o = torch.einsum('bnhm,bmv->bnhv', attn, v)
  72. result = torch.einsum('bnhv,dhv->bdn', o, self.out_proj)
  73. result = self.proj_drop(result)
  74. return result.reshape(b, -1, h, w)
  75. class MultiQueryAttention2d(nn.Module):
  76. """Multi Query Attention with spatial downsampling.
  77. 3 parameters are introduced for the spatial downsampling:
  78. 1. kv_stride: downsampling factor on Key and Values only.
  79. 2. query_strides: horizontal & vertical strides on Query only.
  80. This is an optimized version.
  81. 1. Projections in Attention is explicit written out as 1x1 Conv2D.
  82. 2. Additional reshapes are introduced to bring a up to 3x speed up.
  83. """
  84. fused_attn: torch.jit.Final[bool]
  85. def __init__(
  86. self,
  87. dim: int,
  88. dim_out: Optional[int] = None,
  89. num_heads: int = 8,
  90. key_dim: Optional[int] = None,
  91. value_dim: Optional[int] = None,
  92. query_strides: int = 1,
  93. kv_stride: int = 1,
  94. dw_kernel_size: int = 3,
  95. dilation: int = 1,
  96. padding: Union[str, int, List[int]] = '',
  97. attn_drop: float = 0.,
  98. proj_drop: float = 0.,
  99. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  100. use_bias: bool = False,
  101. device=None,
  102. dtype=None,
  103. ):
  104. """Initializer.
  105. Args:
  106. num_heads: Number of attention heads.
  107. key_dim: Size of the attention key dimension.
  108. value_dim: Size of the attention value dimension.
  109. query_strides: Vertical stride size for query only.
  110. kv_stride: Key and value stride size.
  111. dw_kernel_size: Spatial dimension of the depthwise kernel.
  112. """
  113. dd = {'device': device, 'dtype': dtype}
  114. super().__init__()
  115. dim_out = dim_out or dim
  116. self.num_heads = num_heads
  117. self.key_dim = key_dim or dim // num_heads
  118. self.value_dim = value_dim or dim // num_heads
  119. self.query_strides = to_2tuple(query_strides)
  120. self.kv_stride = kv_stride
  121. self.has_query_strides = any([s > 1 for s in self.query_strides])
  122. self.scale = self.key_dim ** -0.5
  123. self.fused_attn = use_fused_attn()
  124. self.drop = attn_drop
  125. self.query = nn.Sequential()
  126. if self.has_query_strides:
  127. # FIXME dilation
  128. if padding == 'same':
  129. self.query.add_module('down_pool', create_pool2d(
  130. 'avg',
  131. kernel_size=self.query_strides,
  132. padding='same',
  133. ))
  134. else:
  135. # no pad if not 'same' as kern=stride=even
  136. self.query.add_module('down_pool', nn.AvgPool2d(kernel_size=query_strides))
  137. self.query.add_module('norm', norm_layer(dim, **dd))
  138. self.query.add_module('proj', create_conv2d(
  139. dim,
  140. self.num_heads * self.key_dim,
  141. kernel_size=1,
  142. bias=use_bias,
  143. **dd,
  144. ))
  145. self.key = nn.Sequential()
  146. if kv_stride > 1:
  147. self.key.add_module('down_conv', create_conv2d(
  148. dim,
  149. dim,
  150. kernel_size=dw_kernel_size,
  151. stride=kv_stride,
  152. dilation=dilation,
  153. padding=padding,
  154. depthwise=True,
  155. **dd,
  156. ))
  157. self.key.add_module('norm', norm_layer(dim, **dd))
  158. self.key.add_module('proj', create_conv2d(
  159. dim,
  160. self.key_dim,
  161. kernel_size=1,
  162. padding=padding,
  163. bias=use_bias,
  164. **dd,
  165. ))
  166. self.value = nn.Sequential()
  167. if kv_stride > 1:
  168. self.value.add_module('down_conv', create_conv2d(
  169. dim,
  170. dim,
  171. kernel_size=dw_kernel_size,
  172. stride=kv_stride,
  173. dilation=dilation,
  174. padding=padding,
  175. depthwise=True,
  176. **dd,
  177. ))
  178. self.value.add_module('norm', norm_layer(dim, **dd))
  179. self.value.add_module('proj', create_conv2d(
  180. dim,
  181. self.value_dim,
  182. kernel_size=1,
  183. bias=use_bias,
  184. **dd,
  185. ))
  186. self.attn_drop = nn.Dropout(attn_drop)
  187. self.output = nn.Sequential()
  188. if self.has_query_strides:
  189. self.output.add_module('upsample', nn.Upsample(
  190. scale_factor=self.query_strides,
  191. mode='bilinear',
  192. align_corners=False
  193. ))
  194. self.output.add_module('proj', create_conv2d(
  195. self.value_dim * self.num_heads,
  196. dim_out,
  197. kernel_size=1,
  198. bias=use_bias,
  199. **dd,
  200. ))
  201. self.output.add_module('drop', nn.Dropout(proj_drop))
  202. self.einsum = False
  203. self.init_weights()
  204. def init_weights(self):
  205. # using xavier appeared to improve stability for mobilenetv4 hybrid w/ this layer
  206. nn.init.xavier_uniform_(self.query.proj.weight)
  207. nn.init.xavier_uniform_(self.key.proj.weight)
  208. nn.init.xavier_uniform_(self.value.proj.weight)
  209. if self.kv_stride > 1:
  210. nn.init.xavier_uniform_(self.key.down_conv.weight)
  211. nn.init.xavier_uniform_(self.value.down_conv.weight)
  212. nn.init.xavier_uniform_(self.output.proj.weight)
  213. def _reshape_input(self, t: torch.Tensor):
  214. """Reshapes a tensor to three dimensions, keeping the batch and channels."""
  215. s = t.shape
  216. t = t.reshape(s[0], s[1], -1).transpose(1, 2)
  217. if self.einsum:
  218. return t
  219. else:
  220. return t.unsqueeze(1).contiguous()
  221. def _reshape_projected_query(self, t: torch.Tensor, num_heads: int, key_dim: int):
  222. """Reshapes projected query: [b, n, n, h x k] -> [b, n x n, h, k]."""
  223. s = t.shape
  224. t = t.reshape(s[0], num_heads, key_dim, -1)
  225. if self.einsum:
  226. return t.permute(0, 3, 1, 2).contiguous()
  227. else:
  228. return t.transpose(-1, -2).contiguous()
  229. def _reshape_output(self, t: torch.Tensor, num_heads: int, h_px: int, w_px: int):
  230. """Reshape output:[b, n x n x h, k] -> [b, n, n, hk]."""
  231. s = t.shape
  232. feat_dim = s[-1] * num_heads
  233. if not self.einsum:
  234. t = t.transpose(1, 2)
  235. return t.reshape(s[0], h_px, w_px, feat_dim).permute(0, 3, 1, 2).contiguous()
  236. def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
  237. """Run layer computation."""
  238. B, C, H, W = s = x.shape
  239. q = self.query(x)
  240. # desired q shape: [b, h, k, n x n] - [b, l, h, k]
  241. q = self._reshape_projected_query(q, self.num_heads, self.key_dim)
  242. k = self.key(x)
  243. # output shape of k: [b, k, p], p = m x m
  244. k = self._reshape_input(k)
  245. v = self.value(x)
  246. # output shape of v: [ b, p, k], p = m x m
  247. v = self._reshape_input(v)
  248. # desired q shape: [b, n x n, h, k]
  249. # desired k shape: [b, m x m, k]
  250. # desired logits shape: [b, n x n, h, m x m]
  251. if self.einsum:
  252. attn = torch.einsum('blhk,bpk->blhp', q, k) * self.scale
  253. if attn_mask is not None:
  254. # NOTE: assumes mask is float and in correct shape
  255. attn = attn + attn_mask
  256. attn = attn.softmax(dim=-1)
  257. attn = self.attn_drop(attn)
  258. o = torch.einsum('blhp,bpk->blhk', attn, v)
  259. else:
  260. if self.fused_attn:
  261. o = F.scaled_dot_product_attention(
  262. q, k, v,
  263. attn_mask=attn_mask,
  264. dropout_p=self.attn_drop.p if self.training else 0.
  265. )
  266. else:
  267. q = q * self.scale
  268. attn = q @ k.transpose(-1, -2)
  269. if attn_mask is not None:
  270. # NOTE: assumes mask is float and in correct shape
  271. attn = attn + attn_mask
  272. attn = attn.softmax(dim=-1)
  273. attn = self.attn_drop(attn)
  274. o = attn @ v
  275. # reshape o into [b, hk, n, n,]
  276. o = self._reshape_output(o, self.num_heads, H // self.query_strides[0], W // self.query_strides[1])
  277. x = self.output(o)
  278. return x
  279. class Attention2d(nn.Module):
  280. fused_attn: torch.jit.Final[bool]
  281. """ multi-head attention for 2D NCHW tensors"""
  282. def __init__(
  283. self,
  284. dim: int,
  285. dim_out: Optional[int] = None,
  286. num_heads: int = 32,
  287. bias: bool = True,
  288. expand_first: bool = False,
  289. head_first: bool = False,
  290. attn_drop: float = 0.,
  291. proj_drop: float = 0.,
  292. device=None,
  293. dtype=None,
  294. ):
  295. dd = {'device': device, 'dtype': dtype}
  296. super().__init__()
  297. dim_out = dim_out or dim
  298. dim_attn = dim_out if expand_first else dim
  299. self.num_heads = num_heads
  300. self.dim_head = dim_attn // num_heads
  301. self.head_first = head_first
  302. self.fused_attn = use_fused_attn()
  303. self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias, **dd)
  304. self.attn_drop = nn.Dropout(attn_drop)
  305. self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias, **dd)
  306. self.proj_drop = nn.Dropout(proj_drop)
  307. def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
  308. B, C, H, W = x.shape
  309. if self.head_first:
  310. q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2)
  311. else:
  312. q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)
  313. if self.fused_attn:
  314. x = torch.nn.functional.scaled_dot_product_attention(
  315. q.transpose(-1, -2).contiguous(),
  316. k.transpose(-1, -2).contiguous(),
  317. v.transpose(-1, -2).contiguous(),
  318. attn_mask=attn_mask,
  319. dropout_p=self.attn_drop.p if self.training else 0.,
  320. ).transpose(-1, -2).reshape(B, -1, H, W)
  321. else:
  322. q = q.transpose(-1, -2)
  323. v = v.transpose(-1, -2)
  324. attn = q @ k * q.size(-1) ** -0.5
  325. if attn_mask is not None:
  326. # NOTE: assumes mask is float and in correct shape
  327. attn = attn + attn_mask
  328. attn = attn.softmax(dim=-1)
  329. attn = self.attn_drop(attn)
  330. x = (attn @ v).transpose(-1, -2).reshape(B, -1, H, W)
  331. x = self.proj(x)
  332. x = self.proj_drop(x)
  333. return x