| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175 |
- """ Lambda Layer
- Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
- - https://arxiv.org/abs/2102.08602
- @misc{2102.08602,
- Author = {Irwan Bello},
- Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention},
- Year = {2021},
- }
- Status:
- This impl is a WIP. Code snippets in the paper were used as reference but
- good chance some details are missing/wrong.
- I've only implemented local lambda conv based pos embeddings.
- For a PyTorch impl that includes other embedding options check out
- https://github.com/lucidrains/lambda-networks
- Hacked together by / Copyright 2021 Ross Wightman
- """
- from typing import Optional, Tuple
- import torch
- from torch import nn
- import torch.nn.functional as F
- from .grid import ndgrid
- from .helpers import to_2tuple, make_divisible
- from .weight_init import trunc_normal_
def rel_pos_indices(size, device=None):
    """Build the relative-position lookup indices for a 2D grid.

    Args:
        size: grid size, passed through ``to_2tuple`` and read as (size[0], size[1]).
        device: device on which the index tensors are created.

    Returns:
        Long tensor of shape (2, size[0] * size[1], size[0] * size[1]). Entry
        ``[d, i, j]`` is the coordinate of flattened position ``j`` minus that
        of position ``i`` along axis ``d``, shifted by ``size[d] - 1``.
    """
    grid_hw = to_2tuple(size)
    axis0 = torch.arange(grid_hw[0], device=device, dtype=torch.long)
    axis1 = torch.arange(grid_hw[1], device=device, dtype=torch.long)
    coords = torch.stack(ndgrid(axis0, axis1)).flatten(1)  # 2, H * W
    # pairwise differences: target coordinate (last dim) minus source coordinate (middle dim)
    deltas = coords.unsqueeze(1) - coords.unsqueeze(2)
    # shift each axis by (extent - 1) so the smallest possible difference maps to 0
    deltas[0].add_(grid_hw[0] - 1)
    deltas[1].add_(grid_hw[1] - 1)
    return deltas  # 2, H * W, H * W
class LambdaLayer(nn.Module):
    """Lambda Layer

    Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
        - https://arxiv.org/abs/2102.08602

    NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add.

    The internal dimensions of the lambda module are controlled via the interaction of several arguments.
      * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
      * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
      * the query (q) and key (k) dimension are determined by
        * dim_head if it is set (non-zero / not None), otherwise
        * make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
        * q = num_heads * dim_head, k = dim_head
      * as seen above, qk_ratio determines the ratio of q and k relative to the output if dim_head not set

    The module output has num_heads * (dim_out // num_heads) == dim_out channels, which may differ
    from the input channel count `dim`.

    Args:
        dim: input dimension to the module
        dim_out: output dimension of the module, same as dim if not set
        feat_size: size of input feature_map (H, W); required for the relative pos variant (r is None)
        stride: output stride of the module, avg pool used if stride == 2
        num_heads: parallel attention heads.
        dim_head: dimension of query and key heads, calculated from dim_out * qk_ratio // num_heads
            if not set. (default: 16)
        r: local lambda convolution radius. Use lambda conv if set, else relative pos if None. (default: 9)
        qk_ratio: ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
        qkv_bias: add bias to q, k, and v projections
        device: device passed to the created parameters and buffers
        dtype: dtype passed to the created parameters
    """

    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            feat_size: Optional[Tuple[int, int]] = None,
            stride: int = 1,
            num_heads: int = 4,
            dim_head: Optional[int] = 16,
            r: Optional[int] = 9,
            qk_ratio: float = 1.0,
            qkv_bias: bool = False,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        dim_out = dim_out or dim
        assert dim_out % num_heads == 0, \
            f'dim_out ({dim_out}) should be divisible by num_heads ({num_heads})'
        self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
        self.num_heads = num_heads
        self.dim_v = dim_out // num_heads

        # single 1x1 conv producing q (num_heads * dim_qk), k (dim_qk) and v (dim_v) in one pass
        self.qkv = nn.Conv2d(
            dim,
            num_heads * self.dim_qk + self.dim_qk + self.dim_v,
            kernel_size=1,
            bias=qkv_bias,
            **dd,
        )
        self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk, **dd)
        self.norm_v = nn.BatchNorm2d(self.dim_v, **dd)

        if r is not None:
            # local lambda convolution for pos
            self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0), **dd)
            self.pos_emb = None
            self.rel_pos_indices = None
            self.feat_size = None
        else:
            # relative pos embedding
            assert feat_size is not None
            feat_size = to_2tuple(feat_size)
            self.feat_size = feat_size
            rel_size = [2 * s - 1 for s in feat_size]
            M = feat_size[0] * feat_size[1]
            self.conv_lambda = None
            self.pos_emb = nn.Parameter(torch.empty(rel_size[0], rel_size[1], self.dim_qk, **dd))
            # allocated uninitialized here, filled in by _init_buffers(); not saved in the state dict
            self.register_buffer(
                'rel_pos_indices',
                torch.empty((2, M, M), device=device, dtype=torch.long),
                persistent=False,
            )

        self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

        # TODO: skip init when on meta device when safe to do so
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize parameters and buffers."""
        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)  # fan-in
        if self.conv_lambda is not None:
            trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5)
        if self.pos_emb is not None:
            trunc_normal_(self.pos_emb, std=.02)
        self._init_buffers()

    def _init_buffers(self) -> None:
        """Compute and fill non-persistent buffer values."""
        if self.rel_pos_indices is not None:
            self.rel_pos_indices.copy_(
                rel_pos_indices(self.feat_size, device=self.rel_pos_indices.device)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the lambda layer.

        Args:
            x: input feature map of shape (B, dim, H, W).

        Returns:
            Tensor with num_heads * dim_v channels and spatial size (H, W), passed through
            ``self.pool`` (2x2 average pool when the layer was built with stride == 2).
        """
        B, _, H, W = x.shape
        M = H * W
        qkv = self.qkv(x)
        q, k, v = torch.split(qkv, [
            self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1)
        q = self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2)  # B, num_heads, M, K
        v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2)  # B, M, V
        k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1)  # B, K, M (normalized over positions)

        content_lam = k @ v  # B, K, V
        content_out = q @ content_lam.unsqueeze(1)  # B, num_heads, M, V

        if self.pos_emb is None:
            position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v))  # B, K, H, W, V
            position_lam = position_lam.reshape(B, 1, self.dim_qk, M, self.dim_v).transpose(2, 3)  # B, 1, M, K, V
        else:
            # FIXME relative pos embedding path not fully verified
            pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1)
            position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(1)  # B, 1, M, K, V
        position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2)  # B, num_heads, M, V

        out = (content_out + position_out).transpose(-1, -2)  # B, num_heads, V, M
        # The output channel count is num_heads * dim_v (== dim_out), not the input channel
        # count; reshaping with the input C fails whenever dim_out != dim.
        out = out.reshape(B, self.num_heads * self.dim_v, H, W)  # B, num_heads * V, H, W
        out = self.pool(out)
        return out

    def init_non_persistent_buffers(self) -> None:
        """Initialize non-persistent buffers."""
        self._init_buffers()
|