| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156 |
- """ Multi-Scale Vision Transformer v2
- @inproceedings{li2021improved,
- title={MViTv2: Improved multiscale vision transformers for classification and detection},
- author={Li, Yanghao and Wu, Chao-Yuan and Fan, Haoqi and Mangalam, Karttikeya and Xiong, Bo and Malik, Jitendra and Feichtenhofer, Christoph},
- booktitle={CVPR},
- year={2022}
- }
- Code adapted from original Apache 2.0 licensed impl at https://github.com/facebookresearch/mvit
- Original copyright below.
- Modifications and timm support by / Copyright 2022, Ross Wightman
- """
- # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. All Rights Reserved.
- import operator
- from collections import OrderedDict
- from dataclasses import dataclass
- from functools import partial, reduce
- from typing import Union, List, Tuple, Optional, Any, Type
- import torch
- from torch import nn
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.layers import Mlp, DropPath, calculate_drop_path_rates, trunc_normal_tf_, get_norm_layer, to_2tuple
- from ._builder import build_model_with_cfg
- from ._features import feature_take_indices
- from ._features_fx import register_notrace_function
- from ._manipulate import checkpoint
- from ._registry import register_model, generate_default_cfgs
# Names exported by `from <module> import *`.
__all__ = ['MultiScaleVit', 'MultiScaleVitCfg']  # model_registry will add each entrypoint fn to this
@dataclass
class MultiScaleVitCfg:
    """Architecture configuration for ``MultiScaleVit``.

    ``__post_init__`` normalizes the per-stage fields:

    * ``embed_dim`` / ``num_heads`` given as a scalar are expanded to one value per stage,
      doubling at every stage (``v, 2v, 4v, ...``); sequences are converted to tuples.
    * When ``stride_kv`` is ``None`` and ``stride_kv_adaptive`` is set, a per-stage
      key/value pooling stride is derived. A running stride starts at
      ``stride_kv_adaptive``. Every stage whose query stride is > 1 in all dims divides
      it by that query stride, floored at 1.
    """
    depths: Tuple[int, ...] = (2, 3, 16, 3)
    embed_dim: Union[int, Tuple[int, ...]] = 96
    num_heads: Union[int, Tuple[int, ...]] = 1
    mlp_ratio: float = 4.
    pool_first: bool = False
    expand_attn: bool = True
    qkv_bias: bool = True
    use_cls_token: bool = False
    use_abs_pos: bool = False
    residual_pooling: bool = True
    mode: str = 'conv'
    kernel_qkv: Tuple[int, int] = (3, 3)
    # One (h, w) query pooling stride per stage.
    stride_q: Optional[Tuple[Tuple[int, int], ...]] = ((1, 1), (2, 2), (2, 2), (2, 2))
    # One (h, w) key/value pooling stride per stage; derived from stride_kv_adaptive when None.
    stride_kv: Optional[Tuple[Tuple[int, int], ...]] = None
    stride_kv_adaptive: Optional[Tuple[int, int]] = (4, 4)
    patch_kernel: Tuple[int, int] = (7, 7)
    patch_stride: Tuple[int, int] = (4, 4)
    patch_padding: Tuple[int, int] = (3, 3)
    pool_type: str = 'max'
    rel_pos_type: str = 'spatial'
    act_layer: Union[str, Tuple[str, str]] = 'gelu'
    norm_layer: Union[str, Tuple[str, str]] = 'layernorm'
    norm_eps: float = 1e-6

    def __post_init__(self):
        num_stages = len(self.depths)

        # Expand scalars to per-stage tuples; normalize user-provided sequences to tuples
        # so the resulting config type does not depend on how the caller spelled it.
        if isinstance(self.embed_dim, (tuple, list)):
            self.embed_dim = tuple(self.embed_dim)
        else:
            self.embed_dim = tuple(self.embed_dim * 2 ** i for i in range(num_stages))
        assert len(self.embed_dim) == num_stages

        if isinstance(self.num_heads, (tuple, list)):
            self.num_heads = tuple(self.num_heads)
        else:
            self.num_heads = tuple(self.num_heads * 2 ** i for i in range(num_stages))
        assert len(self.num_heads) == num_stages

        if self.stride_kv_adaptive is not None and self.stride_kv is None:
            kv_stride = tuple(self.stride_kv_adaptive)
            per_stage = []
            for i in range(num_stages):
                q_stride = self.stride_q[i]
                # A downsampling stage shrinks the running kv stride by its query stride.
                if min(q_stride) > 1:
                    kv_stride = tuple(max(s // q, 1) for s, q in zip(kv_stride, q_stride))
                per_stage.append(kv_stride)
            self.stride_kv = tuple(per_stage)
def prod(iterable):
    """Return the product of every item in ``iterable`` (``1`` for an empty iterable)."""
    result = 1
    for item in iterable:
        result = result * item
    return result
class PatchEmbed(nn.Module):
    """Image-to-token stem.

    A single ``nn.Conv2d`` maps the image to ``dim_out`` channels. The resulting
    ``(B, C, H, W)`` map is flattened into a ``(B, H*W, C)`` token sequence, and the
    ``(H, W)`` grid size is returned alongside it.
    """

    def __init__(
            self,
            dim_in: int = 3,
            dim_out: int = 768,
            kernel: Tuple[int, int] = (7, 7),
            stride: Tuple[int, int] = (4, 4),
            padding: Tuple[int, int] = (3, 3),
            device=None,
            dtype=None,
    ):
        super().__init__()
        self.proj = nn.Conv2d(
            dim_in, dim_out,
            kernel_size=kernel, stride=stride, padding=padding,
            device=device, dtype=dtype,
        )

    def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
        feat = self.proj(x)
        grid = feat.shape[-2:]
        # (B, C, H, W) -> (B, H*W, C)
        tokens = feat.flatten(2).transpose(1, 2)
        return tokens, grid
@register_notrace_function
def reshape_pre_pool(
        x,
        feat_size: List[int],
        has_cls_token: bool = True
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Lay out grouped tokens as an image batch ready for 2D pooling.

    ``x`` is rank 4, indexed as ``(batch, groups, tokens, channels)`` by the attention
    modules in this file. When ``has_cls_token`` is set, token 0 is split off and returned
    separately. The remaining ``H * W`` tokens (``H, W = feat_size``) are reshaped to
    ``(batch * groups, channels, H, W)``.

    Returns:
        The pooling-ready tensor and the split-off class token (``None`` if absent).
    """
    height, width = feat_size
    if has_cls_token:
        cls_tok = x[:, :, :1, :]
        x = x[:, :, 1:, :]
    else:
        cls_tok = None
    channels = x.shape[-1]
    grid = x.reshape(-1, height, width, channels)
    return grid.permute(0, 3, 1, 2).contiguous(), cls_tok
@register_notrace_function
def reshape_post_pool(
        x,
        num_heads: int,
        cls_tok: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, List[int]]:
    """Undo ``reshape_pre_pool`` after pooling.

    A pooled map ``(batch * num_heads, channels, H', W')`` becomes a token tensor
    ``(batch, num_heads, H' * W', channels)``. ``cls_tok`` is prepended along the token
    axis when given.

    Returns:
        The token tensor and the pooled grid size ``[H', W']``.
    """
    channels = x.shape[1]
    out_h, out_w = x.shape[2], x.shape[3]
    x = x.reshape(-1, num_heads, channels, out_h * out_w).transpose(2, 3)
    if cls_tok is not None:
        x = torch.cat((cls_tok, x), dim=2)
    return x, [out_h, out_w]
@register_notrace_function
def cal_rel_pos_type(
        attn: torch.Tensor,
        q: torch.Tensor,
        has_cls_token: bool,
        q_size: List[int],
        k_size: List[int],
        rel_pos_h: torch.Tensor,
        rel_pos_w: torch.Tensor,
):
    """
    Spatial Relative Positional Embeddings.

    Adds decomposed relative-position terms (one along height, one along width) to the
    spatial query/key entries of ``attn``. When ``has_cls_token`` is set, the class-token
    row and column (index 0) are left unmodified. ``attn`` is written in place and also
    returned.

    Args:
        attn: attention logits; the spatial block ``attn[:, :, sp:, sp:]`` is viewed as
            ``(B, -1, q_h, q_w, k_h, k_w)``.
        q: queries, unpacked as ``(B, n_head, q_N, dim)``.
        has_cls_token: whether token 0 of ``q`` / ``attn`` is a class token.
        q_size: query grid size ``[q_h, q_w]``.
        k_size: key grid size ``[k_h, k_w]``.
        rel_pos_h: table indexed by shifted height distance; its last dim must equal
            ``q``'s last dim (required by the einsum below).
        rel_pos_w: table indexed by shifted width distance; same last-dim requirement.

    Returns:
        ``attn`` with the relative-position bias added.
    """
    sp_idx = 1 if has_cls_token else 0
    q_h, q_w = q_size
    k_h, k_w = k_size

    # Scale up rel pos if shapes for q and k are different.
    # The coarser grid's coordinates are stretched by the size ratio, so both grids are
    # compared in the finer grid's units (the ratio for the finer grid is 1.0).
    q_h_ratio = max(k_h / q_h, 1.0)
    k_h_ratio = max(q_h / k_h, 1.0)
    dist_h = (
        torch.arange(q_h, device=q.device, dtype=torch.long).unsqueeze(-1) * q_h_ratio -
        torch.arange(k_h, device=q.device, dtype=torch.long).unsqueeze(0) * k_h_ratio
    )
    # Shift so the most negative distance, -(k_h - 1) * k_h_ratio, lands on index 0.
    dist_h += (k_h - 1) * k_h_ratio
    q_w_ratio = max(k_w / q_w, 1.0)
    k_w_ratio = max(q_w / k_w, 1.0)
    dist_w = (
        torch.arange(q_w, device=q.device, dtype=torch.long).unsqueeze(-1) * q_w_ratio -
        torch.arange(k_w, device=q.device, dtype=torch.long).unsqueeze(0) * k_w_ratio
    )
    dist_w += (k_w - 1) * k_w_ratio

    # Gather per-(query, key) embedding vectors: (q_h, k_h, C) and (q_w, k_w, C).
    rel_h = rel_pos_h[dist_h.long()]
    rel_w = rel_pos_w[dist_w.long()]

    B, n_head, q_N, dim = q.shape

    # Spatial queries only (class token excluded), laid out on the query grid.
    r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim)
    # Dot each query with the height / width embeddings: (B, heads, q_h, q_w, k_h|k_w).
    rel_h = torch.einsum("byhwc,hkc->byhwk", r_q, rel_h)
    rel_w = torch.einsum("byhwc,wkc->byhwk", r_q, rel_w)

    # Broadcast the height term over k_w and the width term over k_h, then flatten back.
    attn[:, :, sp_idx:, sp_idx:] = (
        attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w)
        + rel_h.unsqueeze(-1)
        + rel_w.unsqueeze(-2)
    ).view(B, -1, q_h * q_w, k_h * k_w)

    return attn
class MultiScaleAttentionPoolFirst(nn.Module):
    """Multi-scale attention that pools *before* the q/k/v projections.

    Input tokens are folded into ``num_heads`` groups (a single group for
    ``mode='conv_unshared'``) and pooled separately for q, k and v. Only then are they
    projected by independent ``q`` / ``k`` / ``v`` linear layers. After that the math
    matches ``MultiScaleAttention``: scaled dot-product attention with an optional
    decomposed spatial relative-position bias and optional residual pooling (the pooled
    query is added to the attention output).

    Args:
        dim: input channel count.
        dim_out: output channel count (also the q/k/v projection width).
        feat_size: input token grid ``(H, W)``; must be square when ``rel_pos_type='spatial'``.
        num_heads: number of attention heads.
        qkv_bias: add bias to the q/k/v projections.
        mode: ``'avg'`` / ``'max'`` pooling, or ``'conv'`` / ``'conv_unshared'`` depthwise conv pooling.
        kernel_q, kernel_kv: pooling kernels; pooling is skipped when kernel and stride are all ones.
        stride_q, stride_kv: pooling strides.
        has_cls_token: whether token 0 is a class token (excluded from pooling).
        rel_pos_type: ``'spatial'`` enables the decomposed relative-position bias.
        residual_pooling: add the pooled query to the attention output.
        norm_layer: norm applied after each conv pooling op.

    Raises:
        NotImplementedError: for an unknown ``mode``.
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            feat_size: Tuple[int, int],
            num_heads: int = 8,
            qkv_bias: bool = True,
            mode: str = "conv",
            kernel_q: Tuple[int, int] = (1, 1),
            kernel_kv: Tuple[int, int] = (1, 1),
            stride_q: Tuple[int, int] = (1, 1),
            stride_kv: Tuple[int, int] = (1, 1),
            has_cls_token: bool = True,
            rel_pos_type: str = 'spatial',
            residual_pooling: bool = True,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.num_heads = num_heads
        self.dim_out = dim_out
        self.head_dim = dim_out // num_heads
        self.scale = self.head_dim ** -0.5
        self.has_cls_token = has_cls_token
        padding_q = tuple([int(q // 2) for q in kernel_q])
        padding_kv = tuple([int(kv // 2) for kv in kernel_kv])

        self.q = nn.Linear(dim, dim_out, bias=qkv_bias, **dd)
        self.k = nn.Linear(dim, dim_out, bias=qkv_bias, **dd)
        self.v = nn.Linear(dim, dim_out, bias=qkv_bias, **dd)
        self.proj = nn.Linear(dim_out, dim_out, **dd)

        # Skip pooling with kernel and stride size of (1, 1, 1).
        if prod(kernel_q) == 1 and prod(stride_q) == 1:
            kernel_q = None
        if prod(kernel_kv) == 1 and prod(stride_kv) == 1:
            kernel_kv = None
        self.mode = mode
        self.unshared = mode == 'conv_unshared'
        self.pool_q, self.pool_k, self.pool_v = None, None, None
        self.norm_q, self.norm_k, self.norm_v = None, None, None
        if mode in ("avg", "max"):
            pool_op = nn.MaxPool2d if mode == "max" else nn.AvgPool2d
            if kernel_q:
                self.pool_q = pool_op(kernel_q, stride_q, padding_q)
            if kernel_kv:
                self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv)
                self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv)
        elif mode == "conv" or mode == "conv_unshared":
            # Pooling runs on the un-projected input, so channels are dim (split per head for 'conv').
            dim_conv = dim // num_heads if mode == "conv" else dim
            if kernel_q:
                self.pool_q = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_q,
                    stride=stride_q,
                    padding=padding_q,
                    groups=dim_conv,
                    bias=False,
                    **dd,
                )
                self.norm_q = norm_layer(dim_conv, **dd)
            if kernel_kv:
                self.pool_k = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                    **dd,
                )
                self.norm_k = norm_layer(dim_conv, **dd)
                self.pool_v = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                    **dd,
                )
                self.norm_v = norm_layer(dim_conv, **dd)
        else:
            raise NotImplementedError(f"Unsupported model {mode}")

        # relative pos embedding
        self.rel_pos_type = rel_pos_type
        if self.rel_pos_type == 'spatial':
            assert feat_size[0] == feat_size[1]
            size = feat_size[0]
            q_size = size // stride_q[1] if len(stride_q) > 0 else size
            kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
            rel_sp_dim = 2 * max(q_size, kv_size) - 1

            self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim, **dd))
            self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim, **dd))
            trunc_normal_tf_(self.rel_pos_h, std=0.02)
            trunc_normal_tf_(self.rel_pos_w, std=0.02)

        self.residual_pooling = residual_pooling

    def forward(self, x, feat_size: List[int]):
        """Attend over ``x`` of shape ``(B, N, dim)`` laid out on grid ``feat_size``.

        Returns:
            ``(out, q_size)``: ``out`` is ``(B, q_N, dim_out)`` and ``q_size`` is the
            (possibly downsampled) query grid.
        """
        B, N, _ = x.shape

        # Fold channels into groups before pooling: one per head, or a single group when unshared.
        fold_dim = 1 if self.unshared else self.num_heads
        x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3)
        q = k = v = x

        # Unfold with fold_dim (not num_heads) so the 'conv_unshared' single-group layout round-trips.
        if self.pool_q is not None:
            q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token)
            q = self.pool_q(q)
            q, q_size = reshape_post_pool(q, fold_dim, q_tok)
        else:
            q_size = feat_size
        if self.norm_q is not None:
            q = self.norm_q(q)

        if self.pool_k is not None:
            k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token)
            k = self.pool_k(k)
            k, k_size = reshape_post_pool(k, fold_dim, k_tok)
        else:
            k_size = feat_size
        if self.norm_k is not None:
            k = self.norm_k(k)

        if self.pool_v is not None:
            v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token)
            v = self.pool_v(v)
            v, v_size = reshape_post_pool(v, fold_dim, v_tok)
        else:
            v_size = feat_size
        if self.norm_v is not None:
            v = self.norm_v(v)

        # Unfold back to token sequences, project, and split into heads: (B, heads, N', head_dim).
        q_N = q_size[0] * q_size[1] + int(self.has_cls_token)
        q = q.transpose(1, 2).reshape(B, q_N, -1)
        q = self.q(q).reshape(B, q_N, self.num_heads, -1).transpose(1, 2)

        k_N = k_size[0] * k_size[1] + int(self.has_cls_token)
        k = k.transpose(1, 2).reshape(B, k_N, -1)
        k = self.k(k).reshape(B, k_N, self.num_heads, -1).transpose(1, 2)

        v_N = v_size[0] * v_size[1] + int(self.has_cls_token)
        v = v.transpose(1, 2).reshape(B, v_N, -1)
        v = self.v(v).reshape(B, v_N, self.num_heads, -1).transpose(1, 2)

        # (B, heads, q_N, head_dim) @ (B, heads, head_dim, k_N) -> (B, heads, q_N, k_N)
        attn = (q * self.scale) @ k.transpose(-2, -1)
        if self.rel_pos_type == 'spatial':
            attn = cal_rel_pos_type(
                attn,
                q,
                self.has_cls_token,
                q_size,
                k_size,
                self.rel_pos_h,
                self.rel_pos_w,
            )
        attn = attn.softmax(dim=-1)
        x = attn @ v

        if self.residual_pooling:
            x = x + q

        x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
        x = self.proj(x)

        return x, q_size
class MultiScaleAttention(nn.Module):
    """Multi-scale (pooling) attention with a fused qkv projection.

    ``x`` is projected to q/k/v with one ``nn.Linear`` and split into heads. q and k/v
    are then optionally pooled: with ``nn.MaxPool2d`` / ``nn.AvgPool2d`` for
    ``mode='max'`` / ``'avg'``, or with depthwise (``groups == channels``) ``nn.Conv2d``
    followed by ``norm_layer`` for ``'conv'`` / ``'conv_unshared'``. Pooling is skipped
    for q (or k/v) when both its kernel and stride are all ones. Attention is scaled
    dot-product with an optional decomposed spatial relative-position bias
    (``rel_pos_type='spatial'``) and optional residual pooling (the pooled query is added
    to the attention output).

    Args:
        dim: input channel count.
        dim_out: output channel count (each of q/k/v has ``dim_out`` channels).
        feat_size: input token grid ``(H, W)``; asserted square when ``rel_pos_type='spatial'``.
        num_heads: number of attention heads.
        qkv_bias: add bias to the fused qkv projection.
        mode: ``'avg'``, ``'max'``, ``'conv'`` or ``'conv_unshared'``.
        kernel_q, kernel_kv: pooling kernels for q and for k/v.
        stride_q, stride_kv: pooling strides for q and for k/v.
        has_cls_token: whether token 0 is a class token (excluded from pooling).
        rel_pos_type: ``'spatial'`` enables the relative-position tables.
        residual_pooling: add the pooled query to the attention output.
        norm_layer: norm applied after each conv pooling op.

    Raises:
        NotImplementedError: for an unknown ``mode``.
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            feat_size: Tuple[int, int],
            num_heads: int = 8,
            qkv_bias: bool = True,
            mode: str = "conv",
            kernel_q: Tuple[int, int] = (1, 1),
            kernel_kv: Tuple[int, int] = (1, 1),
            stride_q: Tuple[int, int] = (1, 1),
            stride_kv: Tuple[int, int] = (1, 1),
            has_cls_token: bool = True,
            rel_pos_type: str = 'spatial',
            residual_pooling: bool = True,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.num_heads = num_heads
        self.dim_out = dim_out
        self.head_dim = dim_out // num_heads
        self.scale = self.head_dim ** -0.5
        self.has_cls_token = has_cls_token
        # "same"-style padding for odd kernels
        padding_q = tuple([int(q // 2) for q in kernel_q])
        padding_kv = tuple([int(kv // 2) for kv in kernel_kv])

        self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias, **dd)
        self.proj = nn.Linear(dim_out, dim_out, **dd)

        # Skip pooling with kernel and stride size of (1, 1, 1).
        if prod(kernel_q) == 1 and prod(stride_q) == 1:
            kernel_q = None
        if prod(kernel_kv) == 1 and prod(stride_kv) == 1:
            kernel_kv = None
        self.mode = mode
        self.unshared = mode == 'conv_unshared'
        self.norm_q, self.norm_k, self.norm_v = None, None, None
        self.pool_q, self.pool_k, self.pool_v = None, None, None
        if mode in ("avg", "max"):
            pool_op = nn.MaxPool2d if mode == "max" else nn.AvgPool2d
            if kernel_q:
                self.pool_q = pool_op(kernel_q, stride_q, padding_q)
            if kernel_kv:
                self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv)
                self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv)
        elif mode == "conv" or mode == "conv_unshared":
            # NOTE(review): forward() always splits q/k/v per head, so each pooled map has
            # head_dim channels. With 'conv_unshared' (dim_conv = dim_out) that only lines
            # up when num_heads == 1 -- confirm whether 'conv_unshared' is meant for pool_first only.
            dim_conv = dim_out // num_heads if mode == "conv" else dim_out
            if kernel_q:
                self.pool_q = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_q,
                    stride=stride_q,
                    padding=padding_q,
                    groups=dim_conv,
                    bias=False,
                    **dd,
                )
                self.norm_q = norm_layer(dim_conv, **dd)
            if kernel_kv:
                self.pool_k = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                    **dd,
                )
                self.norm_k = norm_layer(dim_conv, **dd)
                self.pool_v = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                    **dd,
                )
                self.norm_v = norm_layer(dim_conv, **dd)
        else:
            raise NotImplementedError(f"Unsupported model {mode}")

        # relative pos embedding
        self.rel_pos_type = rel_pos_type
        if self.rel_pos_type == 'spatial':
            assert feat_size[0] == feat_size[1]
            size = feat_size[0]
            q_size = size // stride_q[1] if len(stride_q) > 0 else size
            kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
            # Enough entries to cover every signed distance on the larger of the q / kv grids.
            rel_sp_dim = 2 * max(q_size, kv_size) - 1

            self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim, **dd))
            self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim, **dd))
            trunc_normal_tf_(self.rel_pos_h, std=0.02)
            trunc_normal_tf_(self.rel_pos_w, std=0.02)

        self.residual_pooling = residual_pooling

    def forward(self, x, feat_size: List[int]):
        """Attend over ``x`` of shape ``(B, N, dim)`` laid out on grid ``feat_size``.

        Returns:
            ``(out, q_size)``: ``out`` is ``(B, q_N, dim_out)`` and ``q_size`` is the
            (possibly downsampled) query grid.
        """
        B, N, _ = x.shape
        # (B, N, 3 * dim_out) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(dim=0)

        if self.pool_q is not None:
            q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token)
            q = self.pool_q(q)
            q, q_size = reshape_post_pool(q, self.num_heads, q_tok)
        else:
            q_size = feat_size
        if self.norm_q is not None:
            q = self.norm_q(q)

        if self.pool_k is not None:
            k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token)
            k = self.pool_k(k)
            k, k_size = reshape_post_pool(k, self.num_heads, k_tok)
        else:
            k_size = feat_size
        if self.norm_k is not None:
            k = self.norm_k(k)

        if self.pool_v is not None:
            v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token)
            v = self.pool_v(v)
            v, _ = reshape_post_pool(v, self.num_heads, v_tok)
        if self.norm_v is not None:
            v = self.norm_v(v)

        # (B, heads, q_N, head_dim) @ (B, heads, head_dim, k_N) -> (B, heads, q_N, k_N)
        attn = (q * self.scale) @ k.transpose(-2, -1)
        if self.rel_pos_type == 'spatial':
            attn = cal_rel_pos_type(
                attn,
                q,
                self.has_cls_token,
                q_size,
                k_size,
                self.rel_pos_h,
                self.rel_pos_w,
            )
        attn = attn.softmax(dim=-1)
        x = attn @ v

        if self.residual_pooling:
            x = x + q

        # merge heads: (B, heads, q_N, head_dim) -> (B, q_N, dim_out)
        x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
        x = self.proj(x)

        return x, q_size
class MultiScaleBlock(nn.Module):
    """One MViTv2 block: a pooled-attention residual branch followed by an MLP residual branch.

    Both branches are pre-norm. When ``stride_q`` downsamples the token grid, the attention
    shortcut is max-pooled with the same stride so the residual sum lines up. A width change
    (``dim != dim_out``) is absorbed by a linear shortcut projection on the attention
    branch if ``expand_attn`` is set, otherwise on the MLP branch.
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            num_heads: int,
            feat_size: Tuple[int, int],
            mlp_ratio: float = 4.0,
            qkv_bias: bool = True,
            drop_path: float = 0.0,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            kernel_q: Tuple[int, int] = (1, 1),
            kernel_kv: Tuple[int, int] = (1, 1),
            stride_q: Tuple[int, int] = (1, 1),
            stride_kv: Tuple[int, int] = (1, 1),
            mode: str = "conv",
            has_cls_token: bool = True,
            expand_attn: bool = False,
            pool_first: bool = False,
            rel_pos_type: str = 'spatial',
            residual_pooling: bool = True,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        changes_width = dim != dim_out
        self.dim = dim
        self.dim_out = dim_out
        self.has_cls_token = has_cls_token

        # --- attention branch (submodule creation order kept stable for init / state_dict) ---
        self.norm1 = norm_layer(dim, **dd)
        if changes_width and expand_attn:
            self.shortcut_proj_attn = nn.Linear(dim, dim_out, **dd)
        else:
            self.shortcut_proj_attn = None

        if stride_q and prod(stride_q) > 1:
            # kernel is stride + 1 along each downsampled axis, with half-kernel padding
            skip_kernel = [s + 1 if s > 1 else s for s in stride_q]
            skip_padding = [int(k // 2) for k in skip_kernel]
            self.shortcut_pool_attn = nn.MaxPool2d(skip_kernel, stride_q, skip_padding)
        else:
            self.shortcut_pool_attn = None

        attn_dim = dim_out if expand_attn else dim
        attn_cls = MultiScaleAttentionPoolFirst if pool_first else MultiScaleAttention
        self.attn = attn_cls(
            dim,
            attn_dim,
            num_heads=num_heads,
            feat_size=feat_size,
            qkv_bias=qkv_bias,
            kernel_q=kernel_q,
            kernel_kv=kernel_kv,
            stride_q=stride_q,
            stride_kv=stride_kv,
            norm_layer=norm_layer,
            has_cls_token=has_cls_token,
            mode=mode,
            rel_pos_type=rel_pos_type,
            residual_pooling=residual_pooling,
            **dd,
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # --- MLP branch ---
        self.norm2 = norm_layer(attn_dim, **dd)
        if changes_width and not expand_attn:
            self.shortcut_proj_mlp = nn.Linear(dim, dim_out, **dd)
        else:
            self.shortcut_proj_mlp = None
        self.mlp = Mlp(
            in_features=attn_dim,
            hidden_features=int(attn_dim * mlp_ratio),
            out_features=dim_out,
            **dd,
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def _shortcut_pool(self, x, feat_size: List[int]):
        """Max-pool the spatial tokens of ``x`` (B, N, C) on grid ``feat_size``; the class token passes through."""
        if self.shortcut_pool_attn is None:
            return x
        if self.has_cls_token:
            cls_tok = x[:, :1, :]
            x = x[:, 1:, :]
        else:
            cls_tok = None
        batch, _, chans = x.shape
        height, width = feat_size
        grid = x.reshape(batch, height, width, chans).permute(0, 3, 1, 2).contiguous()
        pooled = self.shortcut_pool_attn(grid)
        x = pooled.reshape(batch, chans, -1).transpose(1, 2)
        if cls_tok is not None:
            x = torch.cat((cls_tok, x), dim=1)
        return x

    def forward(self, x, feat_size: List[int]):
        normed = self.norm1(x)
        # NOTE as per the original impl, this seems odd, but shortcut uses un-normalized input if no proj
        if self.shortcut_proj_attn is None:
            residual = x
        else:
            residual = self.shortcut_proj_attn(normed)
        residual = self._shortcut_pool(residual, feat_size)
        attn_out, new_feat_size = self.attn(normed, feat_size)
        x = residual + self.drop_path1(attn_out)

        normed = self.norm2(x)
        if self.shortcut_proj_mlp is None:
            residual = x
        else:
            residual = self.shortcut_proj_mlp(normed)
        x = residual + self.drop_path2(self.mlp(normed))
        return x, new_feat_size
class MultiScaleVitStage(nn.Module):
    """A run of ``depth`` ``MultiScaleBlock``s.

    Only the first block uses ``stride_q``; later blocks use a unit query stride and see
    the reduced grid. The width grows from ``dim`` to ``dim_out`` in the first block when
    ``expand_attn`` is set, otherwise in the last block. ``self.feat_size`` records the
    token grid size the stage produces.
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            depth: int,
            num_heads: int,
            feat_size: Tuple[int, int],
            mlp_ratio: float = 4.0,
            qkv_bias: bool = True,
            kernel_q: Tuple[int, int] = (1, 1),
            kernel_kv: Tuple[int, int] = (1, 1),
            stride_q: Tuple[int, int] = (1, 1),
            stride_kv: Tuple[int, int] = (1, 1),
            mode: str = "conv",
            has_cls_token: bool = True,
            expand_attn: bool = False,
            pool_first: bool = False,
            rel_pos_type: str = 'spatial',
            residual_pooling: bool = True,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            drop_path: Union[float, List[float]] = 0.0,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.grad_checkpointing = False

        self.blocks = nn.ModuleList()
        out_dims = (dim_out,) * depth if expand_attn else (dim,) * (depth - 1) + (dim_out,)

        block_in = dim
        for idx in range(depth):
            if isinstance(drop_path, (list, tuple)):
                block_dpr = drop_path[idx]
            else:
                block_dpr = drop_path
            block = MultiScaleBlock(
                dim=block_in,
                dim_out=out_dims[idx],
                num_heads=num_heads,
                feat_size=feat_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                kernel_q=kernel_q,
                kernel_kv=kernel_kv,
                stride_q=stride_q if idx == 0 else (1, 1),
                stride_kv=stride_kv,
                mode=mode,
                has_cls_token=has_cls_token,
                pool_first=pool_first,
                rel_pos_type=rel_pos_type,
                residual_pooling=residual_pooling,
                expand_attn=expand_attn,
                norm_layer=norm_layer,
                drop_path=block_dpr,
                **dd,
            )
            self.blocks.append(block)
            block_in = out_dims[idx]
            if idx == 0:
                # the first block downsamples; every later block sees the reduced grid
                feat_size = tuple(size // step for size, step in zip(feat_size, stride_q))

        self.feat_size = feat_size

    def forward(self, x, feat_size: List[int]):
        for block in self.blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x, feat_size = checkpoint(block, x, feat_size)
            else:
                x, feat_size = block(x, feat_size)
        return x, feat_size
class MultiScaleVit(nn.Module):
    """
    Improved Multiscale Vision Transformers for Classification and Detection
    Yanghao Li*, Chao-Yuan Wu*, Haoqi Fan, Karttikeya Mangalam, Bo Xiong, Jitendra Malik,
        Christoph Feichtenhofer*
    https://arxiv.org/abs/2112.01526

    Multiscale Vision Transformers
    Haoqi Fan*, Bo Xiong*, Karttikeya Mangalam*, Yanghao Li*, Zhicheng Yan, Jitendra Malik,
        Christoph Feichtenhofer*
    https://arxiv.org/abs/2104.11227

    Args:
        cfg: architecture config (per-stage depths / widths / heads / strides).
        img_size: input image size, used to size the optional absolute position embedding.
        in_chans: input image channels.
        global_pool: ``'token'``, ``'avg'`` or ``''``; ``None`` picks ``'token'`` when the
            config uses a class token, else ``'avg'``.
        num_classes: classifier outputs (``0`` for an identity head).
        drop_path_rate: stochastic depth rate, distributed over blocks by ``calculate_drop_path_rates``.
        drop_rate: dropout before the classifier.
    """

    def __init__(
            self,
            cfg: MultiScaleVitCfg,
            img_size: Tuple[int, int] = (224, 224),
            in_chans: int = 3,
            global_pool: Optional[str] = None,
            num_classes: int = 1000,
            drop_path_rate: float = 0.,
            drop_rate: float = 0.,
            device=None,
            dtype=None,
    ):
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        img_size = to_2tuple(img_size)
        norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.drop_rate = drop_rate
        if global_pool is None:
            global_pool = 'token' if cfg.use_cls_token else 'avg'
        self.global_pool = global_pool
        # depths / expand_attn are read by checkpoint_filter_fn to remap original checkpoints
        self.depths = tuple(cfg.depths)
        self.expand_attn = cfg.expand_attn

        embed_dim = cfg.embed_dim[0]
        self.patch_embed = PatchEmbed(
            dim_in=in_chans,
            dim_out=embed_dim,
            kernel=cfg.patch_kernel,
            stride=cfg.patch_stride,
            padding=cfg.patch_padding,
            **dd,
        )
        patch_dims = (img_size[0] // cfg.patch_stride[0], img_size[1] // cfg.patch_stride[1])
        num_patches = prod(patch_dims)

        if cfg.use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd))
            self.num_prefix_tokens = 1
            pos_embed_dim = num_patches + 1
        else:
            self.num_prefix_tokens = 0
            self.cls_token = None
            pos_embed_dim = num_patches

        if cfg.use_abs_pos:
            self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_dim, embed_dim, **dd))
        else:
            self.pos_embed = None

        num_stages = len(cfg.embed_dim)
        feat_size = patch_dims
        curr_stride = max(cfg.patch_stride)
        dpr = calculate_drop_path_rates(drop_path_rate, cfg.depths, stagewise=True)
        self.stages = nn.ModuleList()
        self.feature_info = []
        for i in range(num_stages):
            if cfg.expand_attn:
                dim_out = cfg.embed_dim[i]
            else:
                dim_out = cfg.embed_dim[min(i + 1, num_stages - 1)]
            stage = MultiScaleVitStage(
                dim=embed_dim,
                dim_out=dim_out,
                depth=cfg.depths[i],
                num_heads=cfg.num_heads[i],
                feat_size=feat_size,
                mlp_ratio=cfg.mlp_ratio,
                qkv_bias=cfg.qkv_bias,
                mode=cfg.mode,
                pool_first=cfg.pool_first,
                expand_attn=cfg.expand_attn,
                kernel_q=cfg.kernel_qkv,
                kernel_kv=cfg.kernel_qkv,
                stride_q=cfg.stride_q[i],
                stride_kv=cfg.stride_kv[i],
                has_cls_token=cfg.use_cls_token,
                rel_pos_type=cfg.rel_pos_type,
                residual_pooling=cfg.residual_pooling,
                norm_layer=norm_layer,
                drop_path=dpr[i],
                **dd,
            )
            curr_stride *= max(cfg.stride_q[i])
            # module path must match the attribute the stage is registered under (self.stages)
            self.feature_info += [dict(module=f'stages.{i}', num_chs=dim_out, reduction=curr_stride)]
            embed_dim = dim_out
            feat_size = stage.feat_size
            self.stages.append(stage)

        self.num_features = self.head_hidden_size = embed_dim
        self.norm = norm_layer(embed_dim, **dd)
        self.head = nn.Sequential(OrderedDict([
            ('drop', nn.Dropout(self.drop_rate)),
            ('fc', nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity())
        ]))

        if self.pos_embed is not None:
            trunc_normal_tf_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            trunc_normal_tf_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal (std 0.02) weights and zero bias for every ``nn.Linear``."""
        if isinstance(m, nn.Linear):
            trunc_normal_tf_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {k for k, _ in self.named_parameters()
                if any(n in k for n in ["pos_embed", "rel_pos_h", "rel_pos_w", "cls_token"])}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        matcher = dict(
            stem=r'^patch_embed',  # stem and embed
            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        """Replace the classifier with a new ``num_classes``-way head on the same device / dtype."""
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        device = self.head.fc.weight.device if hasattr(self.head.fc, 'weight') else None
        dtype = self.head.fc.weight.dtype if hasattr(self.head.fc, 'weight') else None
        self.head = nn.Sequential(OrderedDict([
            ('drop', nn.Dropout(self.drop_rate)),
            ('fc', nn.Linear(self.num_features, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity())
        ]))

    def _embed_tokens(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[int]]:
        """Patchify ``x``, prepend the class token (if any) and add the absolute position embedding (if any)."""
        x, feat_size = self.patch_embed(x)
        if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        return x, feat_size

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n stages if int, all if None, select matching indices if sequence
            norm: Apply the final norm layer to the last-stage intermediate (earlier stages have
                no matching norm and are returned un-normalized)
            stop_early: Stop iterating over stages once the last requested intermediate is produced
            output_fmt: Shape of intermediate feature outputs, 'NCHW' or 'NLC'
                (class tokens are dropped for 'NCHW')
            intermediates_only: Only return intermediate features

        Returns:
            ``intermediates`` if ``intermediates_only``, else ``(features, intermediates)``.
            ``features`` has the final norm applied only when every stage was run.
        """
        assert output_fmt in ('NCHW', 'NLC'), 'Output shape must be NCHW or NLC.'
        reshape = output_fmt == 'NCHW'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stages), indices)

        # forward pass
        x, feat_size = self._embed_tokens(x)
        B = x.shape[0]

        last_idx = len(self.stages) - 1
        if torch.jit.is_scripting() or not stop_early:  # can't slice ModuleList in torchscript
            stages = self.stages
        else:
            stages = self.stages[:max_index + 1]

        for feat_idx, stage in enumerate(stages):
            x, feat_size = stage(x, feat_size)
            if feat_idx in take_indices:
                if norm and feat_idx == last_idx:
                    x_inter = self.norm(x)  # applying final norm last intermediate
                else:
                    x_inter = x
                if reshape:
                    if self.cls_token is not None:
                        # possible to allow return of class tokens, TBD
                        x_inter = x_inter[:, 1:]
                    x_inter = x_inter.reshape(B, feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2)
                intermediates.append(x_inter)

        if intermediates_only:
            return intermediates

        # The final norm matches the last stage's width only; a truncated run is returned as-is.
        if len(stages) == len(self.stages):
            x = self.norm(x)

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.

        Only the final norm (``prune_norm``) and classifier head (``prune_head``) are
        removed; stages are currently kept.
        """
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        # FIXME add stage pruning
        # self.stages = self.stages[:max_index]  # truncate blocks w/ stem as idx 0
        if prune_norm:
            self.norm = nn.Identity()
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        x, feat_size = self._embed_tokens(x)
        for stage in self.stages:
            x, feat_size = stage(x, feat_size)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            if self.global_pool == 'avg':
                x = x[:, self.num_prefix_tokens:].mean(1)
            else:
                x = x[:, 0]
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
def checkpoint_filter_fn(state_dict, model):
    """Adapt a checkpoint ``state_dict`` to ``model``.

    Two layouts are handled:

    * timm-native (detected by ``'stages.0.blocks.0.norm1.weight'``): any ``rel_pos`` table
      whose first dimension differs from the model's is linearly interpolated to the
      model's length. ``state_dict`` is updated in place and returned. Keys the model
      does not have are left untouched.
    * original facebookresearch/mvit (optionally wrapped in ``'model_state'``): flat
      ``blocks.N`` keys are remapped to ``stages.S.blocks.B`` using ``model.depths``.
      Per-block ``proj`` shortcuts become ``shortcut_proj_attn`` (if ``model.expand_attn``)
      or ``shortcut_proj_mlp``, and ``head.projection`` becomes ``head.fc``.
    """
    import re

    if 'stages.0.blocks.0.norm1.weight' in state_dict:
        # native checkpoint, look for rel_pos interpolations
        model_sd = model.state_dict()  # build once, not once per rel_pos key
        for k in state_dict.keys():
            if 'rel_pos' not in k or k not in model_sd:
                continue
            rel_pos = state_dict[k]
            dest_len = model_sd[k].shape[0]
            if rel_pos.shape[0] != dest_len:
                # (L, C) -> (1, C, L) so interpolation runs along the distance axis, then back to (L', C)
                rel_pos_resized = torch.nn.functional.interpolate(
                    rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
                    size=dest_len,
                    mode="linear",
                )
                state_dict[k] = rel_pos_resized.reshape(-1, dest_len).permute(1, 0)
        return state_dict

    if 'model_state' in state_dict:
        state_dict = state_dict['model_state']

    depths = getattr(model, 'depths', None)
    expand_attn = getattr(model, 'expand_attn', True)
    assert depths is not None, 'model requires depth attribute to remap checkpoints'

    # flat block index -> (stage index, block index within that stage)
    depth_map = {}
    block_idx = 0
    for stage_idx, d in enumerate(depths):
        depth_map.update({i: (stage_idx, i - block_idx) for i in range(block_idx, block_idx + d)})
        block_idx += d

    def _remap_block(match):
        stage_idx, local_idx = depth_map[int(match.group(1))]
        return f'stages.{stage_idx}.blocks.{local_idx}'

    shortcut_name = 'shortcut_proj_attn' if expand_attn else 'shortcut_proj_mlp'
    out_dict = {}
    for k, v in state_dict.items():
        k = re.sub(r'blocks\.(\d+)', _remap_block, k)
        k = re.sub(r'stages\.(\d+)\.blocks\.(\d+)\.proj', rf'stages.\1.blocks.\2.{shortcut_name}', k)
        if 'head' in k:
            k = k.replace('head.projection', 'head.fc')
        out_dict[k] = v
    return out_dict
# Architecture configs, keyed by the registered model name.
# Fields not listed here keep MultiScaleVitCfg's defaults.
model_cfgs = {
    'mvitv2_tiny': MultiScaleVitCfg(
        depths=(1, 2, 5, 2),
    ),
    'mvitv2_small': MultiScaleVitCfg(
        depths=(1, 2, 11, 2),
    ),
    'mvitv2_base': MultiScaleVitCfg(
        depths=(2, 3, 16, 3),
    ),
    'mvitv2_large': MultiScaleVitCfg(
        depths=(2, 6, 36, 4),
        embed_dim=144,
        num_heads=2,
        # NOTE(review): differs from mvitv2_large_cls (expand_attn=True);
        # presumably matches the respective upstream checkpoints - confirm.
        expand_attn=False,
    ),
    # Variants below set use_cls_token=True.
    'mvitv2_small_cls': MultiScaleVitCfg(
        depths=(1, 2, 11, 2),
        use_cls_token=True,
    ),
    'mvitv2_base_cls': MultiScaleVitCfg(
        depths=(2, 3, 16, 3),
        use_cls_token=True,
    ),
    'mvitv2_large_cls': MultiScaleVitCfg(
        depths=(2, 6, 36, 4),
        embed_dim=144,
        num_heads=2,
        use_cls_token=True,
        expand_attn=True,
    ),
    'mvitv2_huge_cls': MultiScaleVitCfg(
        depths=(4, 8, 60, 8),
        embed_dim=192,
        num_heads=3,
        use_cls_token=True,
        expand_attn=True,
    ),
}
def _create_mvitv2(variant, cfg_variant=None, pretrained=False, **kwargs):
    """Build a ``MultiScaleVit`` for ``variant`` through ``build_model_with_cfg``.

    Args:
        variant: registered model name. It is also the ``model_cfgs`` key
            unless ``cfg_variant`` is given.
        cfg_variant: alternate ``model_cfgs`` key, used when truthy.
        pretrained: forwarded to ``build_model_with_cfg``.
        **kwargs: forwarded to ``build_model_with_cfg``. ``out_indices``
            (default 4) is removed and placed in the feature config instead.
    """
    feature_cfg = dict(out_indices=kwargs.pop('out_indices', 4), feature_cls='getter')
    arch_cfg = model_cfgs[cfg_variant or variant]
    return build_model_with_cfg(
        MultiScaleVit,
        variant,
        pretrained,
        model_cfg=arch_cfg,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=feature_cfg,
        **kwargs,
    )
def _cfg(url='', **kwargs):
    """Return a pretrained-config dict with this family's defaults.

    Any ``kwargs`` are merged last. They override matching defaults and add
    new keys.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=.9,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='patch_embed.proj',
        classifier='head.fc',
        fixed_input_size=True,
        license='apache-2.0',
    )
    cfg.update(kwargs)
    return cfg
# Pretrained configs, keyed by model name with an optional '.<tag>' suffix.
default_cfgs = generate_default_cfgs({
    'mvitv2_tiny.fb_in1k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_T_in1k.pyth',
        hf_hub_id='timm/',
    ),
    'mvitv2_small.fb_in1k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_S_in1k.pyth',
        hf_hub_id='timm/',
    ),
    'mvitv2_base.fb_in1k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in1k.pyth',
        hf_hub_id='timm/',
    ),
    'mvitv2_large.fb_in1k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in1k.pyth',
        hf_hub_id='timm/',
    ),
    # Empty url: no weight source is configured for this variant.
    'mvitv2_small_cls': _cfg(url=''),
    # The entries below use 19168-class heads.
    'mvitv2_base_cls.fb_inw21k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in21k.pyth',
        hf_hub_id='timm/',
        num_classes=19168,
    ),
    'mvitv2_large_cls.fb_inw21k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in21k.pyth',
        hf_hub_id='timm/',
        num_classes=19168,
    ),
    'mvitv2_huge_cls.fb_inw21k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_H_in21k.pyth',
        hf_hub_id='timm/',
        num_classes=19168,
    ),
})
@register_model
def mvitv2_tiny(pretrained=False, **kwargs) -> MultiScaleVit:
    """MViTv2-Tiny, built from ``model_cfgs['mvitv2_tiny']`` (depths=(1, 2, 5, 2))."""
    return _create_mvitv2(variant='mvitv2_tiny', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_small(pretrained=False, **kwargs) -> MultiScaleVit:
    """MViTv2-Small, built from ``model_cfgs['mvitv2_small']`` (depths=(1, 2, 11, 2))."""
    return _create_mvitv2(variant='mvitv2_small', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_base(pretrained=False, **kwargs) -> MultiScaleVit:
    """MViTv2-Base, built from ``model_cfgs['mvitv2_base']`` (depths=(2, 3, 16, 3))."""
    return _create_mvitv2(variant='mvitv2_base', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_large(pretrained=False, **kwargs) -> MultiScaleVit:
    """MViTv2-Large, built from ``model_cfgs['mvitv2_large']``
    (depths=(2, 6, 36, 4), embed_dim=144, num_heads=2, expand_attn=False)."""
    return _create_mvitv2(variant='mvitv2_large', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_small_cls(pretrained=False, **kwargs) -> MultiScaleVit:
    """MViTv2-Small with a class token, built from ``model_cfgs['mvitv2_small_cls']``."""
    return _create_mvitv2(variant='mvitv2_small_cls', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_base_cls(pretrained=False, **kwargs) -> MultiScaleVit:
    """MViTv2-Base with a class token, built from ``model_cfgs['mvitv2_base_cls']``."""
    return _create_mvitv2(variant='mvitv2_base_cls', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_large_cls(pretrained=False, **kwargs) -> MultiScaleVit:
    """MViTv2-Large with a class token, built from ``model_cfgs['mvitv2_large_cls']``."""
    return _create_mvitv2(variant='mvitv2_large_cls', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_huge_cls(pretrained=False, **kwargs) -> MultiScaleVit:
    """MViTv2-Huge with a class token, built from ``model_cfgs['mvitv2_huge_cls']``."""
    return _create_mvitv2(variant='mvitv2_huge_cls', pretrained=pretrained, **kwargs)
|