halo_attn.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. """ Halo Self Attention
  2. Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
  3. - https://arxiv.org/abs/2103.12731
  4. @misc{2103.12731,
  5. Author = {Ashish Vaswani and Prajit Ramachandran and Aravind Srinivas and Niki Parmar and Blake Hechtman and
  6. Jonathon Shlens},
  7. Title = {Scaling Local Self-Attention for Parameter Efficient Visual Backbones},
  8. Year = {2021},
  9. }
  10. Status:
  11. This impl is a WIP, there is no official ref impl and some details in paper weren't clear to me.
  12. The attention mechanism works but it's slow as implemented.
  13. Hacked together by / Copyright 2021 Ross Wightman
  14. """
  15. from typing import List, Optional, Tuple, Union
  16. import torch
  17. from torch import nn
  18. import torch.nn.functional as F
  19. from .helpers import make_divisible
  20. from .weight_init import trunc_normal_
  21. from .trace_utils import _assert
  22. def rel_logits_1d(q, rel_k, permute_mask: List[int]):
  23. """ Compute relative logits along one dimension
  24. As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
  25. Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
  26. Args:
  27. q: (batch, height, width, dim)
  28. rel_k: (2 * window - 1, dim)
  29. permute_mask: permute output dim according to this
  30. """
  31. B, H, W, dim = q.shape
  32. rel_size = rel_k.shape[0]
  33. win_size = (rel_size + 1) // 2
  34. x = (q @ rel_k.transpose(-1, -2))
  35. x = x.reshape(-1, W, rel_size)
  36. # pad to shift from relative to absolute indexing
  37. x_pad = F.pad(x, [0, 1]).flatten(1)
  38. x_pad = F.pad(x_pad, [0, rel_size - W])
  39. # reshape and slice out the padded elements
  40. x_pad = x_pad.reshape(-1, W + 1, rel_size)
  41. x = x_pad[:, :W, win_size - 1:]
  42. # reshape and tile
  43. x = x.reshape(B, H, 1, W, win_size).expand(-1, -1, win_size, -1, -1)
  44. return x.permute(permute_mask)
  45. class PosEmbedRel(nn.Module):
  46. """ Relative Position Embedding
  47. As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
  48. Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
  49. """
  50. def __init__(
  51. self,
  52. block_size: int,
  53. win_size: int,
  54. dim_head: int,
  55. scale: float,
  56. device=None,
  57. dtype=None,
  58. ):
  59. """
  60. Args:
  61. block_size: block size
  62. win_size: neighbourhood window size
  63. dim_head: attention head dim
  64. scale: scale factor (for init)
  65. """
  66. dd = {'device': device, 'dtype': dtype}
  67. super().__init__()
  68. self.block_size = block_size
  69. self.dim_head = dim_head
  70. self.scale = scale
  71. self.height_rel = nn.Parameter(torch.empty(win_size * 2 - 1, dim_head, **dd))
  72. self.width_rel = nn.Parameter(torch.empty(win_size * 2 - 1, dim_head, **dd))
  73. self.reset_parameters()
  74. def reset_parameters(self):
  75. torch.nn.init.normal_(self.height_rel, std=self.scale)
  76. torch.nn.init.normal_(self.width_rel, std=self.scale)
  77. def forward(self, q):
  78. B, BB, HW, _ = q.shape
  79. # relative logits in width dimension.
  80. q = q.reshape(-1, self.block_size, self.block_size, self.dim_head)
  81. rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4))
  82. # relative logits in height dimension.
  83. q = q.transpose(1, 2)
  84. rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2))
  85. rel_logits = rel_logits_h + rel_logits_w
  86. rel_logits = rel_logits.reshape(B, BB, HW, -1)
  87. return rel_logits
  88. class HaloAttn(nn.Module):
  89. """ Halo Attention
  90. Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
  91. - https://arxiv.org/abs/2103.12731
  92. The internal dimensions of the attention module are controlled by the interaction of several arguments.
  93. * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
  94. * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
  95. * the query and key (qk) dimensions are determined by
  96. * num_heads * dim_head if dim_head is not None
  97. * num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None
  98. * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used
  99. Args:
  100. dim (int): input dimension to the module
  101. dim_out (int): output dimension of the module, same as dim if not set
  102. feat_size (Tuple[int, int]): size of input feature_map (not used, for arg compat with bottle/lambda)
  103. stride: output stride of the module, query downscaled if > 1 (default: 1).
  104. num_heads: parallel attention heads (default: 8).
  105. dim_head: dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set
  106. block_size (int): size of blocks. (default: 8)
  107. halo_size (int): size of halo overlap. (default: 3)
  108. qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
  109. qkv_bias (bool) : add bias to q, k, and v projections
  110. avg_down (bool): use average pool downsample instead of strided query blocks
  111. scale_pos_embed (bool): scale the position embedding as well as Q @ K
  112. """
  113. def __init__(
  114. self,
  115. dim: int,
  116. dim_out: Optional[int] = None,
  117. feat_size: Optional[Tuple[int, int]] = None,
  118. stride: int = 1,
  119. num_heads: int = 8,
  120. dim_head: Optional[int] = None,
  121. block_size: int = 8,
  122. halo_size: int = 3,
  123. qk_ratio: float = 1.0,
  124. qkv_bias: bool = False,
  125. avg_down: bool = False,
  126. scale_pos_embed: bool = False,
  127. device=None,
  128. dtype=None,
  129. ):
  130. dd = {'device': device, 'dtype': dtype}
  131. super().__init__()
  132. dim_out = dim_out or dim
  133. assert dim_out % num_heads == 0
  134. assert stride in (1, 2)
  135. self.num_heads = num_heads
  136. self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
  137. self.dim_head_v = dim_out // self.num_heads
  138. self.dim_out_qk = num_heads * self.dim_head_qk
  139. self.dim_out_v = num_heads * self.dim_head_v
  140. self.scale = self.dim_head_qk ** -0.5
  141. self.scale_pos_embed = scale_pos_embed
  142. self.block_size = self.block_size_ds = block_size
  143. self.halo_size = halo_size
  144. self.win_size = block_size + halo_size * 2 # neighbourhood window size
  145. self.block_stride = 1
  146. use_avg_pool = False
  147. if stride > 1:
  148. use_avg_pool = avg_down or block_size % stride != 0
  149. self.block_stride = 1 if use_avg_pool else stride
  150. self.block_size_ds = self.block_size // self.block_stride
  151. # FIXME not clear if this stride behaviour is what the paper intended
  152. # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
  153. # data in unfolded block form. I haven't wrapped my head around how that'd look.
  154. self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.block_stride, bias=qkv_bias, **dd)
  155. self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias, **dd)
  156. self.pos_embed = PosEmbedRel(
  157. block_size=self.block_size_ds,
  158. win_size=self.win_size,
  159. dim_head=self.dim_head_qk,
  160. scale=self.scale,
  161. **dd,
  162. )
  163. self.pool = nn.AvgPool2d(2, 2) if use_avg_pool else nn.Identity()
  164. self.reset_parameters()
  165. def reset_parameters(self):
  166. std = self.q.weight.shape[1] ** -0.5 # fan-in
  167. trunc_normal_(self.q.weight, std=std)
  168. trunc_normal_(self.kv.weight, std=std)
  169. trunc_normal_(self.pos_embed.height_rel, std=self.scale)
  170. trunc_normal_(self.pos_embed.width_rel, std=self.scale)
  171. def forward(self, x):
  172. B, C, H, W = x.shape
  173. _assert(H % self.block_size == 0, '')
  174. _assert(W % self.block_size == 0, '')
  175. num_h_blocks = H // self.block_size
  176. num_w_blocks = W // self.block_size
  177. num_blocks = num_h_blocks * num_w_blocks
  178. q = self.q(x)
  179. # unfold
  180. q = q.reshape(
  181. -1, self.dim_head_qk,
  182. num_h_blocks, self.block_size_ds, num_w_blocks, self.block_size_ds).permute(0, 1, 3, 5, 2, 4)
  183. # B, num_heads * dim_head * block_size ** 2, num_blocks
  184. q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3)
  185. # B * num_heads, num_blocks, block_size ** 2, dim_head
  186. kv = self.kv(x)
  187. # Generate overlapping windows for kv. This approach is good for GPU and CPU. However, unfold() is not
  188. # lowered for PyTorch XLA so it will be very slow. See code at bottom of file for XLA friendly approach.
  189. # FIXME figure out how to switch impl between this and conv2d if XLA being used.
  190. kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size])
  191. kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape(
  192. B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1)
  193. k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1)
  194. # B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v
  195. if self.scale_pos_embed:
  196. attn = (q @ k.transpose(-1, -2) + self.pos_embed(q)) * self.scale
  197. else:
  198. attn = (q @ k.transpose(-1, -2)) * self.scale + self.pos_embed(q)
  199. # B * num_heads, num_blocks, block_size ** 2, win_size ** 2
  200. attn = attn.softmax(dim=-1)
  201. out = (attn @ v).transpose(1, 3) # B * num_heads, dim_head_v, block_size ** 2, num_blocks
  202. # fold
  203. out = out.reshape(-1, self.block_size_ds, self.block_size_ds, num_h_blocks, num_w_blocks)
  204. out = out.permute(0, 3, 1, 4, 2).contiguous().view(
  205. B, self.dim_out_v, H // self.block_stride, W // self.block_stride)
  206. # B, dim_out, H // block_stride, W // block_stride
  207. out = self.pool(out)
  208. return out
  209. """ Three alternatives for overlapping windows.
  210. `.unfold().unfold()` is same speed as stride tricks with similar clarity as F.unfold()
  211. if is_xla:
  212. # This code achieves haloing on PyTorch XLA with reasonable runtime trade-off, it is
  213. # EXTREMELY slow for backward on a GPU though so I need a way of selecting based on environment.
  214. WW = self.win_size ** 2
  215. pw = torch.eye(WW, dtype=x.dtype, device=x.device).reshape(WW, 1, self.win_size, self.win_size)
  216. kv = F.conv2d(kv.reshape(-1, 1, H, W), pw, stride=self.block_size, padding=self.halo_size)
  217. elif self.stride_tricks:
  218. kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
  219. kv = kv.as_strided((
  220. B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
  221. stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
  222. else:
  223. kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
  224. kv = kv.reshape(
  225. B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3)
  226. """