| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380 |
- from typing import List, Optional, Type, Union
- import torch
- from torch import nn as nn
- from torch.nn import functional as F
- from .config import use_fused_attn
- from .create_conv2d import create_conv2d
- from .helpers import to_2tuple
- from .pool2d_same import create_pool2d
class MultiQueryAttentionV2(nn.Module):
    """Multi Query Attention.

    Fast Transformer Decoding: One Write-Head is All You Need
    https://arxiv.org/pdf/1911.02150.pdf

    This is an accelerator-optimized version that avoids unnecessary tensor
    transposes by arranging the einsum indices according to two rules:
      1) contracted indices come last, and
      2) all other indices keep the same order in the inputs and the output.

    Compared to V1, this gives a 3x speed up.

    Queries use ``num_heads`` heads, while a single key head and a single value
    head are shared by all query heads.
    """

    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            num_heads: int = 8,
            key_dim: int = 64,
            value_dim: int = 64,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            device=None,
            dtype=None,
    ):
        """Initializer.

        Args:
            dim: Number of input channels.
            dim_out: Number of output channels; defaults to ``dim``.
            num_heads: Number of query heads.
            key_dim: Per-head query size, which is also the size of the shared key.
            value_dim: Size of the shared value head.
            attn_drop: Dropout rate applied to the attention weights.
            proj_drop: Dropout rate applied after the output projection.
            device: Device on which the parameters are created.
            dtype: Dtype of the parameters.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        dim_out = dim_out or dim
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.scale = key_dim ** -0.5

        # Projections are raw parameters consumed directly by einsum in forward().
        self.query_proj = nn.Parameter(torch.empty((self.num_heads, self.key_dim, dim), **dd))
        self.key_proj = nn.Parameter(torch.empty((dim, self.key_dim), **dd))
        self.value_proj = nn.Parameter(torch.empty((dim, self.value_dim), **dd))
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = nn.Parameter(torch.empty((dim_out, self.num_heads, self.value_dim), **dd))
        self.proj_drop = nn.Dropout(proj_drop)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize all projections from a normal distribution.

        The query/key/value projections use std ``dim ** -0.5``, where ``dim`` is
        read from ``key_proj``. The output projection uses ``dim_out ** -0.5``.
        """
        scale = self.key_proj.shape[0] ** -0.5
        nn.init.normal_(self.query_proj, std=scale)
        nn.init.normal_(self.key_proj, std=scale)
        nn.init.normal_(self.value_proj, std=scale)
        nn.init.normal_(self.out_proj, std=self.out_proj.shape[0] ** -0.5)

    def _reshape_input(self, t):
        """Flatten trailing (spatial) dims and move channels last: [B, C, ...] -> [B, N, C]."""
        s = t.shape
        return t.reshape(s[0], s[1], -1).transpose(1, 2)

    def forward(self, x, m: Optional[torch.Tensor] = None):
        """Run layer computation.

        Args:
            x: Query input of shape [B, dim, H, W].
            m: Optional key/value ("memory") input with the same batch size and
                ``dim`` channels; its spatial size may differ from ``x``.
                Defaults to ``x`` (self-attention).

        Returns:
            Tensor of shape [B, dim_out, H, W].
        """
        b, _, h, w = x.shape
        m = m if m is not None else x

        reshaped_x = self._reshape_input(x)  # [B, N, dim], N = H * W
        reshaped_m = self._reshape_input(m)  # [B, M, dim]

        # Multi-head queries against a single shared key head.
        q = torch.einsum('bnd,hkd->bnhk', reshaped_x, self.query_proj)
        k = torch.einsum('bmd,dk->bmk', reshaped_m, self.key_proj)

        attn = torch.einsum('bnhk,bmk->bnhm', q, k) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # A single shared value head is mixed per query head, then projected across heads.
        v = torch.einsum('bmd,dv->bmv', reshaped_m, self.value_proj)
        o = torch.einsum('bnhm,bmv->bnhv', attn, v)
        result = torch.einsum('bnhv,dhv->bdn', o, self.out_proj)
        result = self.proj_drop(result)
        return result.reshape(b, -1, h, w)
class MultiQueryAttention2d(nn.Module):
    """Multi Query Attention with spatial downsampling.

    Two parameters control the spatial downsampling:
      1. kv_stride: downsampling factor applied to keys and values only.
      2. query_strides: horizontal & vertical strides applied to queries only.

    This is an optimized version:
      1. The attention projections are written out explicitly as 1x1 Conv2d layers.
      2. Additional reshapes are introduced to bring up to a 3x speed up.

    All ``num_heads`` query heads share a single key head and a single value head.
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            num_heads: int = 8,
            key_dim: Optional[int] = None,
            value_dim: Optional[int] = None,
            query_strides: int = 1,
            kv_stride: int = 1,
            dw_kernel_size: int = 3,
            dilation: int = 1,
            padding: Union[str, int, List[int]] = '',
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            use_bias: bool = False,
            device=None,
            dtype=None,
    ):
        """Initializer.

        Args:
            dim: Number of input channels.
            dim_out: Number of output channels; defaults to ``dim``.
            num_heads: Number of attention (query) heads.
            key_dim: Size of the attention key dimension; defaults to ``dim // num_heads``.
            value_dim: Size of the attention value dimension; defaults to ``dim // num_heads``.
            query_strides: Query-only stride, either an int or a (height, width) pair;
                applied to both spatial axes via ``to_2tuple``.
            kv_stride: Key and value stride size.
            dw_kernel_size: Spatial size of the depthwise kernel that downsamples keys/values.
            dilation: Dilation of the key/value depthwise downsampling conv.
            padding: Padding spec for the key/value convs. ``'same'`` also selects
                same-padded pooling for the query downsampling.
            attn_drop: Dropout rate applied to the attention weights.
            proj_drop: Dropout rate applied after the output projection.
            norm_layer: Norm layer applied after query / key / value downsampling.
            use_bias: Whether the 1x1 projection convs have a bias.
            device: Device on which the parameters are created.
            dtype: Dtype of the parameters.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        dim_out = dim_out or dim
        self.num_heads = num_heads
        self.key_dim = key_dim or dim // num_heads
        self.value_dim = value_dim or dim // num_heads
        self.query_strides = to_2tuple(query_strides)
        self.kv_stride = kv_stride
        self.has_query_strides = any(s > 1 for s in self.query_strides)
        self.scale = self.key_dim ** -0.5
        self.fused_attn = use_fused_attn()
        # Raw attention dropout rate. Nothing in this class reads it (self.attn_drop below
        # is the module actually used); it is kept as a public attribute for compatibility.
        self.drop = attn_drop

        # Query branch: optional avg-pool downsampling + norm, then 1x1 projection to all heads.
        self.query = nn.Sequential()
        if self.has_query_strides:
            # FIXME dilation
            if padding == 'same':
                self.query.add_module('down_pool', create_pool2d(
                    'avg',
                    kernel_size=self.query_strides,
                    padding='same',
                ))
            else:
                # kernel_size == stride (non-overlapping windows), so no padding unless 'same'
                self.query.add_module('down_pool', nn.AvgPool2d(kernel_size=query_strides))
            self.query.add_module('norm', norm_layer(dim, **dd))
        self.query.add_module('proj', create_conv2d(
            dim,
            self.num_heads * self.key_dim,
            kernel_size=1,
            bias=use_bias,
            **dd,
        ))

        # Key branch: optional strided depthwise conv + norm, then 1x1 projection to a single head.
        self.key = nn.Sequential()
        if kv_stride > 1:
            self.key.add_module('down_conv', create_conv2d(
                dim,
                dim,
                kernel_size=dw_kernel_size,
                stride=kv_stride,
                dilation=dilation,
                padding=padding,
                depthwise=True,
                **dd,
            ))
            self.key.add_module('norm', norm_layer(dim, **dd))
        self.key.add_module('proj', create_conv2d(
            dim,
            self.key_dim,
            kernel_size=1,
            padding=padding,
            bias=use_bias,
            **dd,
        ))

        # Value branch: same structure as the key branch, projecting to a single value head.
        self.value = nn.Sequential()
        if kv_stride > 1:
            self.value.add_module('down_conv', create_conv2d(
                dim,
                dim,
                kernel_size=dw_kernel_size,
                stride=kv_stride,
                dilation=dilation,
                padding=padding,
                depthwise=True,
                **dd,
            ))
            self.value.add_module('norm', norm_layer(dim, **dd))
        self.value.add_module('proj', create_conv2d(
            dim,
            self.value_dim,
            kernel_size=1,
            bias=use_bias,
            **dd,
        ))

        self.attn_drop = nn.Dropout(attn_drop)

        # Output branch: undo the query downsampling (if any), then project heads to dim_out.
        self.output = nn.Sequential()
        if self.has_query_strides:
            self.output.add_module('upsample', nn.Upsample(
                scale_factor=self.query_strides,
                mode='bilinear',
                align_corners=False
            ))
        self.output.add_module('proj', create_conv2d(
            self.value_dim * self.num_heads,
            dim_out,
            kernel_size=1,
            bias=use_bias,
            **dd,
        ))
        self.output.add_module('drop', nn.Dropout(proj_drop))

        # Selects the einsum formulation in forward(); False uses matmul / SDPA.
        self.einsum = False
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform init of the projection convs (and the kv downsampling convs, if present)."""
        # using xavier appeared to improve stability for mobilenetv4 hybrid w/ this layer
        nn.init.xavier_uniform_(self.query.proj.weight)
        nn.init.xavier_uniform_(self.key.proj.weight)
        nn.init.xavier_uniform_(self.value.proj.weight)
        if self.kv_stride > 1:
            nn.init.xavier_uniform_(self.key.down_conv.weight)
            nn.init.xavier_uniform_(self.value.down_conv.weight)
        nn.init.xavier_uniform_(self.output.proj.weight)

    def _reshape_input(self, t: torch.Tensor):
        """Flatten spatial dims of a key/value map, keeping batch and channels.

        [B, C, H, W] -> [B, H*W, C] on the einsum path, or [B, 1, H*W, C] otherwise
        (the singleton head dim lets the shared K/V broadcast across query heads).
        """
        s = t.shape
        t = t.reshape(s[0], s[1], -1).transpose(1, 2)
        if self.einsum:
            return t
        else:
            return t.unsqueeze(1).contiguous()

    def _reshape_projected_query(self, t: torch.Tensor, num_heads: int, key_dim: int):
        """Split the projected query into heads.

        [B, h*k, H, W] -> [B, H*W, h, k] on the einsum path, or [B, h, H*W, k] otherwise.
        """
        s = t.shape
        t = t.reshape(s[0], num_heads, key_dim, -1)
        if self.einsum:
            return t.permute(0, 3, 1, 2).contiguous()
        else:
            return t.transpose(-1, -2).contiguous()

    def _reshape_output(self, t: torch.Tensor, num_heads: int, h_px: int, w_px: int):
        """Merge heads back into an NCHW map.

        [B, H*W, h, v] (einsum path) or [B, h, H*W, v] (otherwise) -> [B, h*v, h_px, w_px].
        """
        s = t.shape
        feat_dim = s[-1] * num_heads
        if not self.einsum:
            t = t.transpose(1, 2)
        return t.reshape(s[0], h_px, w_px, feat_dim).permute(0, 3, 1, 2).contiguous()

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        """Run layer computation.

        Args:
            x: Input of shape [B, dim, H, W].
            attn_mask: Optional mask. It is passed to SDPA on the fused path and
                added to the attention logits otherwise; it is assumed to be float and
                already broadcastable against the logits (no checking is done).

        Returns:
            Tensor with ``dim_out`` channels. The query grid
            (H // query_strides[0], W // query_strides[1]) is upsampled back by
            ``query_strides`` when query strides are used.
        """
        B, _, H, W = x.shape

        q = self.query(x)
        # einsum: [B, L, h, k]; matmul/SDPA: [B, h, L, k], L = query grid size
        q = self._reshape_projected_query(q, self.num_heads, self.key_dim)

        k = self.key(x)
        # einsum: [B, P, k]; matmul/SDPA: [B, 1, P, k], P = (possibly strided) kv grid size
        k = self._reshape_input(k)

        v = self.value(x)
        # einsum: [B, P, v]; matmul/SDPA: [B, 1, P, v]
        v = self._reshape_input(v)

        if self.einsum:
            # logits: [B, L, h, P]
            attn = torch.einsum('blhk,bpk->blhp', q, k) * self.scale
            if attn_mask is not None:
                # NOTE: assumes mask is float and in correct shape
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            o = torch.einsum('blhp,bpk->blhk', attn, v)
        else:
            if self.fused_attn:
                o = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask=attn_mask,
                    dropout_p=self.attn_drop.p if self.training else 0.
                )
            else:
                q = q * self.scale
                # logits: [B, h, L, P]; the singleton head dim of k broadcasts over h
                attn = q @ k.transpose(-1, -2)
                if attn_mask is not None:
                    # NOTE: assumes mask is float and in correct shape
                    attn = attn + attn_mask
                attn = attn.softmax(dim=-1)
                attn = self.attn_drop(attn)
                o = attn @ v

        # [B, h*v, H', W'] on the downsampled query grid
        o = self._reshape_output(o, self.num_heads, H // self.query_strides[0], W // self.query_strides[1])
        x = self.output(o)
        return x
class Attention2d(nn.Module):
    """Multi-head attention for 2D NCHW tensors.

    Queries, keys and values come from a single 1x1 conv. Attention runs over the
    flattened H * W positions, and a final 1x1 conv projects the result to ``dim_out``.
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            num_heads: int = 32,
            bias: bool = True,
            expand_first: bool = False,
            head_first: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            device=None,
            dtype=None,
    ):
        """Initializer.

        Args:
            dim: Number of input channels.
            dim_out: Number of output channels; defaults to ``dim``.
            num_heads: Number of attention heads; must divide the attention width.
            bias: Whether the qkv and output 1x1 convs have a bias.
            expand_first: If True, attention runs at width ``dim_out``; otherwise at ``dim``.
            head_first: Channel layout of the qkv conv output. If True, channels are
                grouped per head as [heads, (q, k, v) * dim_head]; otherwise they are
                grouped as [(q, k, v), heads, dim_head].
            attn_drop: Dropout rate applied to the attention weights.
            proj_drop: Dropout rate applied after the output projection.
            device: Device on which the parameters are created.
            dtype: Dtype of the parameters.

        Raises:
            ValueError: If the attention width is not divisible by ``num_heads``.
        """
        dd = {'device': device, 'dtype': dtype}
        dim_out = dim_out or dim
        dim_attn = dim_out if expand_first else dim
        if dim_attn % num_heads != 0:
            # Otherwise the per-head reshape in forward() cannot reproduce dim_attn channels.
            raise ValueError(
                f'attention dim ({dim_attn}) must be divisible by num_heads ({num_heads})')
        super().__init__()
        self.num_heads = num_heads
        self.dim_head = dim_attn // num_heads
        self.head_first = head_first
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias, **dd)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias, **dd)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        """Run layer computation.

        Args:
            x: Input of shape [B, dim, H, W].
            attn_mask: Optional mask. It is passed to SDPA on the fused path and
                added to the attention logits otherwise; it is assumed to be float and
                already broadcastable (no checking is done).

        Returns:
            Tensor of shape [B, dim_out, H, W].
        """
        B, C, H, W = x.shape

        # q, k, v: [B, heads, dim_head, N], N = H * W
        if self.head_first:
            q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2)
        else:
            q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)

        if self.fused_attn:
            # SDPA expects [..., N, dim_head]
            x = torch.nn.functional.scaled_dot_product_attention(
                q.transpose(-1, -2).contiguous(),
                k.transpose(-1, -2).contiguous(),
                v.transpose(-1, -2).contiguous(),
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
            ).transpose(-1, -2).reshape(B, -1, H, W)
        else:
            q = q.transpose(-1, -2)  # [B, heads, N, dim_head]
            v = v.transpose(-1, -2)  # [B, heads, N, dim_head]
            # k stays [B, heads, dim_head, N], so q @ k gives logits [B, heads, N, N]
            attn = (q @ k) * (q.size(-1) ** -0.5)
            if attn_mask is not None:
                # NOTE: assumes mask is float and in correct shape
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(-1, -2).reshape(B, -1, H, W)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|