bottleneck_attn.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. """ Bottleneck Self Attention (Bottleneck Transformers)
  2. Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605
  3. @misc{2101.11605,
  4. Author = {Aravind Srinivas and Tsung-Yi Lin and Niki Parmar and Jonathon Shlens and Pieter Abbeel and Ashish Vaswani},
  5. Title = {Bottleneck Transformers for Visual Recognition},
  6. Year = {2021},
  7. }
  8. Based on ref gist at: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
  9. This impl is a WIP, but given that it is based on the ref gist, it is likely not too far off.
  10. Hacked together by / Copyright 2021 Ross Wightman
  11. """
  12. from typing import List, Optional, Tuple
  13. import torch
  14. import torch.nn as nn
  15. import torch.nn.functional as F
  16. from .helpers import to_2tuple, make_divisible
  17. from .weight_init import trunc_normal_
  18. from .trace_utils import _assert
  19. def rel_logits_1d(q, rel_k, permute_mask: List[int]):
  20. """ Compute relative logits along one dimension
  21. As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
  22. Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
  23. Args:
  24. q: (batch, heads, height, width, dim)
  25. rel_k: (2 * width - 1, dim)
  26. permute_mask: permute output dim according to this
  27. """
  28. B, H, W, dim = q.shape
  29. x = (q @ rel_k.transpose(-1, -2))
  30. x = x.reshape(-1, W, 2 * W -1)
  31. # pad to shift from relative to absolute indexing
  32. x_pad = F.pad(x, [0, 1]).flatten(1)
  33. x_pad = F.pad(x_pad, [0, W - 1])
  34. # reshape and slice out the padded elements
  35. x_pad = x_pad.reshape(-1, W + 1, 2 * W - 1)
  36. x = x_pad[:, :W, W - 1:]
  37. # reshape and tile
  38. x = x.reshape(B, H, 1, W, W).expand(-1, -1, H, -1, -1)
  39. return x.permute(permute_mask)
  40. class PosEmbedRel(nn.Module):
  41. """ Relative Position Embedding
  42. As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
  43. Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
  44. """
  45. def __init__(
  46. self,
  47. feat_size: Tuple[int, int],
  48. dim_head: int,
  49. scale: float,
  50. device=None,
  51. dtype=None,
  52. ):
  53. dd = {'device': device, 'dtype': dtype}
  54. super().__init__()
  55. self.height, self.width = to_2tuple(feat_size)
  56. self.dim_head = dim_head
  57. self.scale = scale
  58. self.height_rel = nn.Parameter(torch.empty(self.height * 2 - 1, dim_head, **dd))
  59. self.width_rel = nn.Parameter(torch.empty(self.width * 2 - 1, dim_head, **dd))
  60. self.reset_parameters()
  61. def reset_parameters(self):
  62. torch.nn.init.normal_(self.height_rel, std=self.scale)
  63. torch.nn.init.normal_(self.width_rel, std=self.scale)
  64. def forward(self, q):
  65. B, HW, _ = q.shape
  66. # relative logits in width dimension.
  67. q = q.reshape(B, self.height, self.width, -1)
  68. rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4))
  69. # relative logits in height dimension.
  70. q = q.transpose(1, 2)
  71. rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2))
  72. rel_logits = rel_logits_h + rel_logits_w
  73. rel_logits = rel_logits.reshape(B, HW, HW)
  74. return rel_logits
  75. class BottleneckAttn(nn.Module):
  76. """ Bottleneck Attention
  77. Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605
  78. The internal dimensions of the attention module are controlled by the interaction of several arguments.
  79. * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
  80. * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
  81. * the query and key (qk) dimensions are determined by
  82. * num_heads * dim_head if dim_head is not None
  83. * num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None
  84. * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used
  85. Args:
  86. dim (int): input dimension to the module
  87. dim_out (int): output dimension of the module, same as dim if not set
  88. stride (int): output stride of the module, avg pool used if stride == 2 (default: 1).
  89. num_heads (int): parallel attention heads (default: 4)
  90. dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set
  91. qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
  92. qkv_bias (bool): add bias to q, k, and v projections
  93. scale_pos_embed (bool): scale the position embedding as well as Q @ K
  94. """
  95. def __init__(
  96. self,
  97. dim: int,
  98. dim_out: Optional[int] = None,
  99. feat_size: Optional[Tuple[int, int]] = None,
  100. stride: int = 1,
  101. num_heads: int = 4,
  102. dim_head: Optional[int] = None,
  103. qk_ratio: float = 1.0,
  104. qkv_bias: bool = False,
  105. scale_pos_embed: bool = False,
  106. device=None,
  107. dtype=None,
  108. ):
  109. dd = {'device': device, 'dtype': dtype}
  110. super().__init__()
  111. assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required'
  112. dim_out = dim_out or dim
  113. assert dim_out % num_heads == 0
  114. self.num_heads = num_heads
  115. self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
  116. self.dim_head_v = dim_out // self.num_heads
  117. self.dim_out_qk = num_heads * self.dim_head_qk
  118. self.dim_out_v = num_heads * self.dim_head_v
  119. self.scale = self.dim_head_qk ** -0.5
  120. self.scale_pos_embed = scale_pos_embed
  121. self.qkv = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias, **dd)
  122. # NOTE I'm only supporting relative pos embedding for now
  123. self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale, **dd)
  124. self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
  125. self.reset_parameters()
  126. def reset_parameters(self):
  127. trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in
  128. trunc_normal_(self.pos_embed.height_rel, std=self.scale)
  129. trunc_normal_(self.pos_embed.width_rel, std=self.scale)
  130. def forward(self, x):
  131. B, C, H, W = x.shape
  132. _assert(H == self.pos_embed.height, '')
  133. _assert(W == self.pos_embed.width, '')
  134. x = self.qkv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
  135. # NOTE head vs channel split ordering in qkv projection was decided before I allowed qk to differ from v
  136. # So, this is more verbose than if heads were before qkv splits, but throughput is not impacted.
  137. q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1)
  138. q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2)
  139. k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k
  140. v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2)
  141. if self.scale_pos_embed:
  142. attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W
  143. else:
  144. attn = (q @ k) * self.scale + self.pos_embed(q)
  145. attn = attn.softmax(dim=-1)
  146. out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W
  147. out = self.pool(out)
  148. return out