lambda_layer.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
""" Lambda Layer

Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
    - https://arxiv.org/abs/2102.08602

@misc{2102.08602,
Author = {Irwan Bello},
Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention},
Year = {2021},
}

Status:
This impl is a WIP. Code snippets in the paper were used as reference but
there is a good chance some details are missing/wrong.

I've only implemented local lambda conv based pos embeddings.

For a PyTorch impl that includes other embedding options, check out
https://github.com/lucidrains/lambda-networks

Hacked together by / Copyright 2021 Ross Wightman
"""
  17. from typing import Optional, Tuple
  18. import torch
  19. from torch import nn
  20. import torch.nn.functional as F
  21. from .grid import ndgrid
  22. from .helpers import to_2tuple, make_divisible
  23. from .weight_init import trunc_normal_
  24. def rel_pos_indices(size, device=None):
  25. size = to_2tuple(size)
  26. pos = torch.stack(ndgrid(
  27. torch.arange(size[0], device=device, dtype=torch.long),
  28. torch.arange(size[1], device=device, dtype=torch.long),
  29. )).flatten(1)
  30. rel_pos = pos[:, None, :] - pos[:, :, None]
  31. rel_pos[0] += size[0] - 1
  32. rel_pos[1] += size[1] - 1
  33. return rel_pos # 2, H * W, H * W
  34. class LambdaLayer(nn.Module):
  35. """Lambda Layer
  36. Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
  37. - https://arxiv.org/abs/2102.08602
  38. NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add.
  39. The internal dimensions of the lambda module are controlled via the interaction of several arguments.
  40. * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
  41. * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
  42. * the query (q) and key (k) dimension are determined by
  43. * dim_head = (dim_out * attn_ratio // num_heads) if dim_head is None
  44. * q = num_heads * dim_head, k = dim_head
  45. * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not set
  46. Args:
  47. dim: input dimension to the module
  48. dim_out: output dimension of the module, same as dim if not set
  49. feat_size: size of input feature_map for relative pos variant H, W
  50. stride: output stride of the module, avg pool used if stride == 2
  51. num_heads: parallel attention heads.
  52. dim_head: dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set
  53. r: local lambda convolution radius. Use lambda conv if set, else relative pos if not. (default: 9)
  54. qk_ratio: ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
  55. qkv_bias: add bias to q, k, and v projections
  56. """
  57. def __init__(
  58. self,
  59. dim: int,
  60. dim_out: Optional[int] = None,
  61. feat_size: Optional[Tuple[int, int]] = None,
  62. stride: int = 1,
  63. num_heads: int = 4,
  64. dim_head: int = 16,
  65. r: int = 9,
  66. qk_ratio: float = 1.0,
  67. qkv_bias: bool = False,
  68. device=None,
  69. dtype=None,
  70. ):
  71. dd = {'device': device, 'dtype': dtype}
  72. super().__init__()
  73. dim_out = dim_out or dim
  74. assert dim_out % num_heads == 0, ' should be divided by num_heads'
  75. self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
  76. self.num_heads = num_heads
  77. self.dim_v = dim_out // num_heads
  78. self.qkv = nn.Conv2d(
  79. dim,
  80. num_heads * self.dim_qk + self.dim_qk + self.dim_v,
  81. kernel_size=1,
  82. bias=qkv_bias,
  83. **dd,
  84. )
  85. self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk, **dd)
  86. self.norm_v = nn.BatchNorm2d(self.dim_v, **dd)
  87. if r is not None:
  88. # local lambda convolution for pos
  89. self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0), **dd)
  90. self.pos_emb = None
  91. self.rel_pos_indices = None
  92. self.feat_size = None
  93. else:
  94. # relative pos embedding
  95. assert feat_size is not None
  96. feat_size = to_2tuple(feat_size)
  97. self.feat_size = feat_size
  98. rel_size = [2 * s - 1 for s in feat_size]
  99. M = feat_size[0] * feat_size[1]
  100. self.conv_lambda = None
  101. self.pos_emb = nn.Parameter(torch.empty(rel_size[0], rel_size[1], self.dim_qk, **dd))
  102. self.register_buffer(
  103. 'rel_pos_indices',
  104. torch.empty((2, M, M), device=device, dtype=torch.long),
  105. persistent=False,
  106. )
  107. self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
  108. # TODO: skip init when on meta device when safe to do so
  109. self.reset_parameters()
  110. def reset_parameters(self) -> None:
  111. """Initialize parameters and buffers."""
  112. trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in
  113. if self.conv_lambda is not None:
  114. trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5)
  115. if self.pos_emb is not None:
  116. trunc_normal_(self.pos_emb, std=.02)
  117. self._init_buffers()
  118. def _init_buffers(self) -> None:
  119. """Compute and fill non-persistent buffer values."""
  120. if self.rel_pos_indices is not None:
  121. self.rel_pos_indices.copy_(
  122. rel_pos_indices(self.feat_size, device=self.rel_pos_indices.device)
  123. )
  124. def forward(self, x):
  125. B, C, H, W = x.shape
  126. M = H * W
  127. qkv = self.qkv(x)
  128. q, k, v = torch.split(qkv, [
  129. self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1)
  130. q = self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2) # B, num_heads, M, K
  131. v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V
  132. k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1) # B, K, M
  133. content_lam = k @ v # B, K, V
  134. content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V
  135. if self.pos_emb is None:
  136. position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K
  137. position_lam = position_lam.reshape(B, 1, self.dim_qk, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V
  138. else:
  139. # FIXME relative pos embedding path not fully verified
  140. pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1)
  141. position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(1) # B, 1, M, K, V
  142. position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2) # B, num_heads, M, V
  143. out = (content_out + position_out).transpose(-1, -2).reshape(B, C, H, W) # B, C (num_heads * V), H, W
  144. out = self.pool(out)
  145. return out
  146. def init_non_persistent_buffers(self) -> None:
  147. """Initialize non-persistent buffers."""
  148. self._init_buffers()