| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587 |
- """ Relative position embedding modules and functions
- Hacked together by / Copyright 2022 Ross Wightman
- """
- import math
- import os
- from typing import Optional, Tuple
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from .grid import ndgrid
- from .interpolate import RegularGridInterpolator
- from .mlp import Mlp
- from .weight_init import trunc_normal_
- _USE_SCIPY = int(os.environ.get('TIMM_USE_SCIPY_INTERP', 0)) > 0
- def gen_relative_position_index(
- q_size: Tuple[int, int],
- k_size: Optional[Tuple[int, int]] = None,
- class_token: bool = False,
- device=None,
- ) -> torch.Tensor:
- # Adapted with significant modifications from Swin / BeiT codebases
- # get pair-wise relative position index for each token inside the window
- assert k_size is None, 'Different q & k sizes not currently supported' # FIXME
- coords = torch.stack(ndgrid(
- torch.arange(q_size[0], device=device),
- torch.arange(q_size[1], device=device),
- )).flatten(1) # 2, Wh, Ww
- relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww
- relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
- relative_coords[:, :, 0] += q_size[0] - 1 # shift to start from 0
- relative_coords[:, :, 1] += q_size[1] - 1
- relative_coords[:, :, 0] *= 2 * q_size[1] - 1
- num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1)
- # else:
- # # FIXME different q vs k sizes is a WIP, need to better offset the two grids?
- # q_coords = torch.stack(
- # ndgrid(
- # torch.arange(q_size[0]),
- # torch.arange(q_size[1])
- # )
- # ).flatten(1) # 2, Wh, Ww
- # k_coords = torch.stack(
- # ndgrid(
- # torch.arange(k_size[0]),
- # torch.arange(k_size[1])
- # )
- # ).flatten(1)
- # relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
- # relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
- # relative_coords[:, :, 0] += max(q_size[0], k_size[0]) - 1 # shift to start from 0
- # relative_coords[:, :, 1] += max(q_size[1], k_size[1]) - 1
- # relative_coords[:, :, 0] *= k_size[1] + q_size[1] - 1
- # relative_position_index = relative_coords.sum(-1) # Qh*Qw, Kh*Kw
- # num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + k_size[1] - 1) + 3
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
- if class_token:
- # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
- # NOTE not intended or tested with MLP log-coords
- relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
- relative_position_index[0, 0:] = num_relative_distance
- relative_position_index[0:, 0] = num_relative_distance + 1
- relative_position_index[0, 0] = num_relative_distance + 2
- return relative_position_index.contiguous()
- def resize_rel_pos_bias_table_simple(
- rel_pos_bias,
- new_window_size: Tuple[int, int],
- new_bias_shape: Tuple[int, ...],
- ):
- dst_size = (new_window_size[0] * 2 - 1, new_window_size[1] * 2 - 1)
- if rel_pos_bias.ndim == 3:
- # TF maxvit style (num_heads, H, W) bias shape, no extra tokens currently supported
- _, dst_h, dst_w = new_bias_shape
- num_attn_heads, src_h, src_w = rel_pos_bias.shape
- assert dst_h == dst_size[0] and dst_w == dst_size[1]
- if src_h != dst_h or src_w != dst_w:
- rel_pos_bias = torch.nn.functional.interpolate(
- rel_pos_bias.unsqueeze(0),
- size=dst_size,
- mode="bicubic",
- align_corners=False,
- ).squeeze(0)
- else:
- assert rel_pos_bias.ndim == 2
- # (num_pos, num_heads) (aka flat) bias shape
- dst_num_pos, _ = new_bias_shape
- src_num_pos, num_attn_heads = rel_pos_bias.shape
- num_extra_tokens = dst_num_pos - (dst_size[0] * dst_size[1])
- src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
- src_size = (src_size, src_size) # FIXME could support non-equal src if argument passed
- if src_size[0] != dst_size[0] or src_size[1] != dst_size[1]:
- if num_extra_tokens:
- extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
- rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
- else:
- extra_tokens = None
- rel_pos_bias = torch.nn.functional.interpolate(
- rel_pos_bias.transpose(1, 0).reshape((1, -1, src_size[0], src_size[1])),
- size=dst_size,
- mode="bicubic",
- align_corners=False,
- ).view(-1, dst_num_pos - num_extra_tokens).transpose(0, 1)
- if extra_tokens is not None:
- rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
- return rel_pos_bias
- def resize_rel_pos_bias_table_levit(
- position_bias_table,
- new_size,
- interpolation: str = 'bicubic',
- antialias: bool = True,
- ):
- """
- Resample relative position bias table suggested in LeVit
- Adapted from: https://github.com/microsoft/Cream/blob/main/TinyViT/utils.py
- """
- L1, nH1 = position_bias_table.size()
- L2, nH2 = new_size
- assert nH1 == nH2
- if L1 != L2:
- orig_dtype = position_bias_table.dtype
- position_bias_table = position_bias_table.float()
- # bicubic interpolate relative_position_bias_table if not match
- S1 = int(L1 ** 0.5)
- S2 = int(L2 ** 0.5)
- relative_position_bias_table_resized = F.interpolate(
- position_bias_table.permute(1, 0).view(1, nH1, S1, S1),
- size=(S2, S2),
- mode=interpolation,
- antialias=antialias,
- )
- relative_position_bias_table_resized = relative_position_bias_table_resized.view(nH2, L2).permute(1, 0)
- relative_position_bias_table_resized.to(orig_dtype)
- return relative_position_bias_table_resized
- else:
- return position_bias_table
- def resize_rel_pos_bias_table(
- rel_pos_bias,
- new_window_size: Tuple[int, int],
- new_bias_shape: Tuple[int, ...],
- ):
- """ Resize relative position bias table using more advanced interpolation.
- Modified from code in Microsoft Unilm (https://github.com/microsoft/unilm) repo (BeiT, BeiT-v2, etc).
- https://github.com/microsoft/unilm/blob/5255d52de86dad642810f5849dd357769346c1d7/beit/run_class_finetuning.py#L351
- Args:
- rel_pos_bias:
- new_window_size:
- new_bias_shape:
- Returns:
- """
- if _USE_SCIPY:
- from scipy import interpolate
- dst_size = (new_window_size[0] * 2 - 1, new_window_size[1] * 2 - 1)
- if rel_pos_bias.ndim == 3:
- # TF maxvit style (num_heads, H, W) bias shape, no extra tokens currently supported
- num_extra_tokens = 0
- _, dst_h, dst_w = new_bias_shape
- assert dst_h == dst_size[0] and dst_w == dst_size[1]
- num_attn_heads, src_h, src_w = rel_pos_bias.shape
- src_size = (src_h, src_w)
- has_flat_shape = False
- else:
- assert rel_pos_bias.ndim == 2
- # (num_pos, num_heads) (aka flat) bias shape
- dst_num_pos, _ = new_bias_shape
- src_num_pos, num_attn_heads = rel_pos_bias.shape
- num_extra_tokens = dst_num_pos - (dst_size[0] * dst_size[1])
- src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
- src_size = (src_size, src_size)
- has_flat_shape = True
- if src_size[0] != dst_size[0] or src_size[1] != dst_size[1]:
- # print("Interpolating position from %dx%d to %dx%d" % (src_size[0], src_size[1], dst_size[0], dst_size[1]))
- if num_extra_tokens:
- extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
- rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
- else:
- extra_tokens = None
- def geometric_progression(a, r, n):
- return a * (1.0 - r ** n) / (1.0 - r)
- def _calc(src, dst):
- left, right = 1.01, 1.5
- while right - left > 1e-6:
- q = (left + right) / 2.0
- gp = geometric_progression(1, q, src // 2)
- if gp > dst // 2:
- right = q
- else:
- left = q
- dis = []
- cur = 1
- for i in range(src // 2):
- dis.append(cur)
- cur += q ** (i + 1)
- r_ids = [-_ for _ in reversed(dis)]
- return r_ids + [0] + dis
- y = _calc(src_size[0], dst_size[0])
- x = _calc(src_size[1], dst_size[1])
- yx = [torch.tensor(y), torch.tensor(x)]
- # print("Original positions = %s" % str(x))
- ty = dst_size[0] // 2.0
- tx = dst_size[1] // 2.0
- dy = torch.arange(-ty, ty + 0.1, 1.0)
- dx = torch.arange(-tx, tx + 0.1, 1.0)
- dyx = ndgrid(dy, dx)
- # print("Target positions = %s" % str(dx))
- all_rel_pos_bias = []
- for i in range(num_attn_heads):
- if has_flat_shape:
- z = rel_pos_bias[:, i].view(src_size[0], src_size[1]).float()
- else:
- z = rel_pos_bias[i, :, :].float()
- if _USE_SCIPY:
- # Original beit code uses scipy w/ cubic interpolation
- f = interpolate.interp2d(x, y, z.numpy(), kind='cubic')
- r = torch.Tensor(f(dx, dy)).contiguous().to(rel_pos_bias.device)
- else:
- # Without scipy dependency, I've found a reasonably simple impl
- # that supports uneven spaced interpolation pts with 'linear' interp.
- # Results are comparable to scipy for model accuracy in most cases.
- f = RegularGridInterpolator(yx, z)
- r = f(dyx).contiguous().to(rel_pos_bias.device)
- if has_flat_shape:
- r = r.view(-1, 1)
- all_rel_pos_bias.append(r)
- if has_flat_shape:
- rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
- else:
- rel_pos_bias = torch.cat(all_rel_pos_bias, dim=0)
- if extra_tokens is not None:
- assert has_flat_shape
- rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
- return rel_pos_bias
- class RelPosBias(nn.Module):
- """ Relative Position Bias
- Adapted from Swin-V1 relative position bias impl, modularized.
- """
- def __init__(
- self,
- window_size: Tuple[int, int],
- num_heads: int,
- prefix_tokens: int = 0,
- device=None,
- dtype=None,
- ):
- dd = {'device': device, 'dtype': dtype}
- super().__init__()
- assert prefix_tokens <= 1
- self.window_size = window_size
- self.window_area = window_size[0] * window_size[1]
- self.prefix_tokens = prefix_tokens
- self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,)
- num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens
- self.relative_position_bias_table = nn.Parameter(torch.empty(num_relative_distance, num_heads, **dd))
- index_size = (self.window_area + prefix_tokens) ** 2
- self.register_buffer(
- "relative_position_index",
- torch.empty(index_size, device=device, dtype=torch.long),
- persistent=False,
- )
- # TODO: skip init when on meta device when safe to do so
- self.reset_parameters()
- def reset_parameters(self) -> None:
- """Initialize parameters and buffers."""
- trunc_normal_(self.relative_position_bias_table, std=.02)
- self._init_buffers()
- def _init_buffers(self) -> None:
- """Compute and fill non-persistent buffer values."""
- self.relative_position_index.copy_(
- gen_relative_position_index(
- self.window_size,
- class_token=self.prefix_tokens > 0,
- device=self.relative_position_index.device,
- ).view(-1)
- )
- def get_bias(self) -> torch.Tensor:
- relative_position_bias = self.relative_position_bias_table[self.relative_position_index]
- # win_h * win_w, win_h * win_w, num_heads
- relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1)
- return relative_position_bias.unsqueeze(0).contiguous()
- def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
- return attn + self.get_bias()
- def init_non_persistent_buffers(self) -> None:
- """Initialize non-persistent buffers."""
- self._init_buffers()
- def gen_relative_log_coords(
- win_size: Tuple[int, int],
- pretrained_win_size: Tuple[int, int] = (0, 0),
- mode='swin',
- device=None,
- dtype=None,
- ):
- assert mode in ('swin', 'cr')
- # as per official swin-v2 impl, supporting timm specific 'cr' log coords as well
- relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], device=device).to(torch.float32)
- relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], device=device).to(torch.float32)
- relative_coords_table = torch.stack(ndgrid(relative_coords_h, relative_coords_w))
- relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2
- if mode == 'swin':
- if pretrained_win_size[0] > 0:
- relative_coords_table[:, :, 0] /= (pretrained_win_size[0] - 1)
- relative_coords_table[:, :, 1] /= (pretrained_win_size[1] - 1)
- else:
- relative_coords_table[:, :, 0] /= (win_size[0] - 1)
- relative_coords_table[:, :, 1] /= (win_size[1] - 1)
- relative_coords_table *= 8 # normalize to -8, 8
- relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
- 1.0 + relative_coords_table.abs()) / math.log2(8)
- else:
- # mode == 'cr'
- relative_coords_table = torch.sign(relative_coords_table) * torch.log(
- 1.0 + relative_coords_table.abs())
- return relative_coords_table.to(dtype)
- class RelPosMlp(nn.Module):
- """ Log-Coordinate Relative Position MLP
- Based on ideas presented in Swin-V2 paper (https://arxiv.org/abs/2111.09883)
- This impl covers the 'swin' implementation as well as two timm specific modes ('cr', and 'rw')
- """
- def __init__(
- self,
- window_size: Tuple[int, int],
- num_heads: int = 8,
- hidden_dim: int = 128,
- prefix_tokens: int = 0,
- mode: str = 'cr',
- pretrained_window_size: Tuple[int, int] = (0, 0),
- device=None,
- dtype=None,
- ):
- dd = {'device': device, 'dtype': dtype}
- super().__init__()
- self.window_size = window_size
- self.window_area = self.window_size[0] * self.window_size[1]
- self.prefix_tokens = prefix_tokens
- self.num_heads = num_heads
- self.bias_shape = (self.window_area,) * 2 + (num_heads,)
- self.mode = mode
- self.pretrained_window_size = pretrained_window_size
- if mode == 'swin':
- self.bias_act = nn.Sigmoid()
- self.bias_gain = 16
- mlp_bias = (True, False)
- else:
- self.bias_act = nn.Identity()
- self.bias_gain = None
- mlp_bias = True
- self.mlp = Mlp(
- 2, # x, y
- hidden_features=hidden_dim,
- out_features=num_heads,
- act_layer=nn.ReLU,
- bias=mlp_bias,
- drop=(0.125, 0.),
- **dd,
- )
- index_size = self.window_area ** 2
- rel_coords_shape = (2 * window_size[0] - 1, 2 * window_size[1] - 1, 2)
- self.register_buffer(
- "relative_position_index",
- torch.empty(index_size, device=device, dtype=torch.long),
- persistent=False,
- )
- self.register_buffer(
- "rel_coords_log",
- torch.empty(rel_coords_shape, **dd),
- persistent=False,
- )
- # TODO: skip init when on meta device when safe to do so
- self.reset_parameters()
- def get_bias(self) -> torch.Tensor:
- relative_position_bias = self.mlp(self.rel_coords_log)
- if self.relative_position_index is not None:
- relative_position_bias = relative_position_bias.view(-1, self.num_heads)[self.relative_position_index]
- relative_position_bias = relative_position_bias.view(self.bias_shape)
- relative_position_bias = relative_position_bias.permute(2, 0, 1)
- relative_position_bias = self.bias_act(relative_position_bias)
- if self.bias_gain is not None:
- relative_position_bias = self.bias_gain * relative_position_bias
- if self.prefix_tokens:
- relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0])
- return relative_position_bias.unsqueeze(0).contiguous()
- def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
- return attn + self.get_bias()
- def reset_parameters(self) -> None:
- """Initialize parameters and buffers."""
- self._init_buffers()
- def _init_buffers(self) -> None:
- """Compute and fill non-persistent buffer values."""
- device = self.relative_position_index.device
- dtype = self.rel_coords_log.dtype
- self.relative_position_index.copy_(
- gen_relative_position_index(self.window_size, device=device).view(-1)
- )
- self.rel_coords_log.copy_(
- gen_relative_log_coords(
- self.window_size,
- self.pretrained_window_size,
- mode=self.mode,
- device=device,
- dtype=dtype,
- )
- )
- def init_non_persistent_buffers(self) -> None:
- """Initialize non-persistent buffers."""
- self._init_buffers()
- def generate_lookup_tensor(
- length: int,
- max_relative_position: Optional[int] = None,
- device=None,
- dtype=None,
- ):
- """Generate a one_hot lookup tensor to reindex embeddings along one dimension.
- Args:
- length: the length to reindex to.
- max_relative_position: the maximum relative position to consider.
- Relative position embeddings for distances above this threshold
- are zeroed out.
- Returns:
- a lookup Tensor of size [length, length, vocab_size] that satisfies
- ret[n,m,v] = 1{m - n + max_relative_position = v}.
- """
- if max_relative_position is None:
- max_relative_position = length - 1
- # Return the cached lookup tensor, otherwise compute it and cache it.
- vocab_size = 2 * max_relative_position + 1
- ret = torch.zeros(length, length, vocab_size, device=device, dtype=dtype)
- for i in range(length):
- for x in range(length):
- v = x - i + max_relative_position
- if abs(x - i) > max_relative_position:
- continue
- ret[i, x, v] = 1
- return ret
- def reindex_2d_einsum_lookup(
- relative_position_tensor,
- height: int,
- width: int,
- height_lookup: torch.Tensor,
- width_lookup: torch.Tensor,
- ) -> torch.Tensor:
- """Reindex 2d relative position bias with 2 independent einsum lookups.
- Adapted from:
- https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py
- Args:
- relative_position_tensor: tensor of shape
- [..., vocab_height, vocab_width, ...].
- height: height to reindex to.
- width: width to reindex to.
- height_lookup: one-hot height lookup
- width_lookup: one-hot width lookup
- Returns:
- reindexed_tensor: a Tensor of shape
- [..., height * width, height * width, ...]
- """
- reindexed_tensor = torch.einsum('nhw,ixh->nixw', relative_position_tensor, height_lookup)
- reindexed_tensor = torch.einsum('nixw,jyw->nijxy', reindexed_tensor, width_lookup)
- area = height * width
- return reindexed_tensor.reshape(relative_position_tensor.shape[0], area, area)
- class RelPosBiasTf(nn.Module):
- """ Relative Position Bias Impl (Compatible with Tensorflow MaxViT models)
- Adapted from:
- https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py
- """
- def __init__(
- self,
- window_size: Tuple[int, int],
- num_heads: int,
- prefix_tokens: int = 0,
- device=None,
- dtype=None,
- ):
- dd = {'device': device, 'dtype': dtype}
- super().__init__()
- assert prefix_tokens <= 1
- self.window_size = window_size
- self.window_area = window_size[0] * window_size[1]
- self.num_heads = num_heads
- vocab_height = 2 * window_size[0] - 1
- vocab_width = 2 * window_size[1] - 1
- self.bias_shape = (self.num_heads, vocab_height, vocab_width)
- self.relative_position_bias_table = nn.Parameter(torch.empty(self.bias_shape, **dd))
- height_lookup_shape = (window_size[0], window_size[0], vocab_height)
- width_lookup_shape = (window_size[1], window_size[1], vocab_width)
- self.register_buffer('height_lookup', torch.empty(height_lookup_shape, **dd), persistent=False)
- self.register_buffer('width_lookup', torch.empty(width_lookup_shape, **dd), persistent=False)
- # TODO: skip init when on meta device when safe to do so
- self.reset_parameters()
- def reset_parameters(self) -> None:
- """Initialize parameters and buffers."""
- nn.init.normal_(self.relative_position_bias_table, std=.02)
- self._init_buffers()
- def _init_buffers(self) -> None:
- """Compute and fill non-persistent buffer values."""
- device = self.height_lookup.device
- dtype = self.height_lookup.dtype
- self.height_lookup.copy_(generate_lookup_tensor(self.window_size[0], device=device, dtype=dtype))
- self.width_lookup.copy_(generate_lookup_tensor(self.window_size[1], device=device, dtype=dtype))
- def get_bias(self) -> torch.Tensor:
- # FIXME change to not use one-hot/einsum?
- return reindex_2d_einsum_lookup(
- self.relative_position_bias_table,
- self.window_size[0],
- self.window_size[1],
- self.height_lookup,
- self.width_lookup
- )
- def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
- return attn + self.get_bias()
- def init_non_persistent_buffers(self) -> None:
- """Initialize non-persistent buffers."""
- self._init_buffers()
|