| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953 |
- """ DaViT: Dual Attention Vision Transformers
- As described in https://arxiv.org/abs/2204.03645
- Input size invariant transformer architecture that combines channel and spacial
- attention in each block. The attention mechanisms used are linear in complexity.
- DaViT model defs and weights adapted from https://github.com/dingmyu/davit, original copyright below
- """
- # Copyright (c) 2022 Mingyu Ding
- # All rights reserved.
- # This source code is licensed under the MIT license
- from functools import partial
- from typing import List, Optional, Tuple, Type, Union
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch import Tensor
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.layers import DropPath, calculate_drop_path_rates, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn
- from timm.layers import NormMlpClassifierHead, ClassifierHead
- from ._builder import build_model_with_cfg
- from ._features import feature_take_indices
- from ._features_fx import register_notrace_function
- from ._manipulate import checkpoint, checkpoint_seq
- from ._registry import generate_default_cfgs, register_model
- __all__ = ['DaVit']
class ConvPosEnc(nn.Module):
    """Convolutional positional encoding applied as a residual.

    Computes ``x + act(depthwise_conv(x))`` where the convolution has one ``k x k``
    filter per channel (``groups=dim``), stride 1 and padding ``k // 2``. For odd
    ``k`` the spatial size of the input is preserved.

    Args:
        dim: Number of input (and output) channels.
        k: Kernel size of the depthwise convolution.
        act: If True, apply GELU to the convolution output before adding it back.
        device: Device used when creating parameters.
        dtype: Dtype used when creating parameters.
    """

    def __init__(
            self,
            dim: int,
            k: int = 3,
            act: bool = False,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        # groups == dim -> depthwise convolution (channels are not mixed).
        self.proj = nn.Conv2d(
            dim, dim,
            kernel_size=k,
            stride=1,
            padding=k // 2,
            groups=dim,
            **factory_kwargs,
        )
        self.act = nn.GELU() if act else nn.Identity()

    def forward(self, x: Tensor):
        encoding = self.act(self.proj(x))
        return x + encoding
class Stem(nn.Module):
    """Size-agnostic 2D image-to-patch embedding.

    The input is zero-padded on the right/bottom so its height and width are
    multiples of the stride, then passed through a 7x7 strided convolution and
    ``norm_layer``. Because padding is computed per call, the input resolution may
    change between forward passes.

    Args:
        in_chs: Number of input image channels.
        out_chs: Number of output embedding channels.
        stride: Patch stride (int or pair); only a stride of 4 (first element) is supported.
        norm_layer: Norm layer applied to the convolution output.
        device: Device used when creating parameters.
        dtype: Dtype used when creating parameters.
    """

    def __init__(
            self,
            in_chs: int = 3,
            out_chs: int = 96,
            stride: int = 4,
            norm_layer: Type[nn.Module] = LayerNorm2d,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        self.stride = to_2tuple(stride)
        self.in_chs = in_chs
        self.out_chs = out_chs
        assert self.stride[0] == 4  # only setup for stride==4

        self.conv = nn.Conv2d(
            in_chs, out_chs,
            kernel_size=7,
            stride=self.stride,
            padding=3,
            **factory_kwargs,
        )
        self.norm = norm_layer(out_chs, **factory_kwargs)

    def forward(self, x: Tensor):
        _, _, height, width = x.shape
        stride_h, stride_w = self.stride
        # Pad right/bottom up to the next multiple of the stride (0 when already aligned).
        pad_right = (stride_w - width % stride_w) % stride_w
        pad_bottom = (stride_h - height % stride_h) % stride_h
        padded = F.pad(x, (0, pad_right, 0, pad_bottom))
        return self.norm(self.conv(padded))
class Downsample(nn.Module):
    """Inter-stage downsampling: norm followed by a stride-2 convolution.

    Odd kernel sizes use symmetric ``kernel_size // 2`` padding. Even kernel sizes
    use no conv padding; instead the normalized input is zero-padded on the
    right/bottom up to a multiple of the kernel size.

    Args:
        in_chs: Number of input channels.
        out_chs: Number of output channels.
        kernel_size: Convolution kernel size.
        norm_layer: Norm layer applied before the convolution (on ``in_chs``).
        device: Device used when creating parameters.
        dtype: Dtype used when creating parameters.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            norm_layer: Type[nn.Module] = LayerNorm2d,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        self.in_chs = in_chs
        self.out_chs = out_chs
        self.norm = norm_layer(in_chs, **factory_kwargs)
        self.even_k = kernel_size % 2 == 0
        conv_padding = 0 if self.even_k else kernel_size // 2
        self.conv = nn.Conv2d(
            in_chs, out_chs,
            kernel_size=kernel_size,
            stride=2,
            padding=conv_padding,
            **factory_kwargs,
        )

    def forward(self, x: Tensor):
        _, _, height, width = x.shape
        x = self.norm(x)
        if self.even_k:
            k_h, k_w = self.conv.kernel_size
            x = F.pad(x, (0, (k_w - width % k_w) % k_w, 0, (k_h - height % k_h) % k_h))
        return self.conv(x)
class ChannelAttentionV2(nn.Module):
    """Grouped channel attention, the variant selected by ``channel_attn_v2=True``.

    Tokens are split into ``num_heads`` channel groups. Within each group, attention
    is computed between channels (a ``head_dim x head_dim`` map from ``q^T k``) and
    applied to ``v``. Queries are scaled by ``N ** -0.5`` (N = number of tokens) when
    ``dynamic_scale`` is set, otherwise by ``head_dim ** -0.5``.

    Args:
        dim: Token embedding dimension.
        num_heads: Number of channel groups.
        qkv_bias: If True, the qkv projection has a bias.
        dynamic_scale: Scale by token count instead of head dim.
        device: Device used when creating parameters.
        dtype: Dtype used when creating parameters.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = True,
            dynamic_scale: bool = True,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        self.groups = num_heads
        self.head_dim = dim // num_heads
        self.dynamic_scale = dynamic_scale
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **factory_kwargs)
        self.proj = nn.Linear(dim, dim, **factory_kwargs)

    def forward(self, x):
        B, N, C = x.shape
        group_dim = C // self.groups
        # (3, B, groups, N, group_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.groups, group_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        scale = N ** -0.5 if self.dynamic_scale else self.head_dim ** -0.5
        q = q * scale

        # Channel-to-channel attention map: (B, groups, group_dim, group_dim).
        attn = (q.transpose(-1, -2) @ k).softmax(dim=-1)
        out = (attn @ v.transpose(-1, -2)).transpose(-1, -2)
        out = out.transpose(1, 2).reshape(B, N, C)
        return self.proj(out)
class ChannelAttention(nn.Module):
    """Channel attention in the original DaViT formulation.

    Per head, keys are scaled by ``head_dim ** -0.5``. A channel-to-channel map is
    built from ``k^T v`` and softmax-normalized, then applied to ``q``. This
    ordering (k/v build the map, q is attended) is kept exactly so pretrained weights
    remain compatible.

    Args:
        dim: Token embedding dimension.
        num_heads: Number of heads (channel groups).
        qkv_bias: If True, the qkv projection has a bias.
        device: Device used when creating parameters.
        dtype: Dtype used when creating parameters.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **factory_kwargs)
        self.proj = nn.Linear(dim, dim, **factory_kwargs)

    def forward(self, x: Tensor):
        B, N, C = x.shape
        per_head = C // self.num_heads
        # (3, B, heads, N, per_head)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, per_head).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = ((k * self.scale).transpose(-1, -2) @ v).softmax(dim=-1)
        out = (attn @ q.transpose(-1, -2)).transpose(-1, -2)
        out = out.transpose(1, 2).reshape(B, N, C)
        return self.proj(out)
class ChannelBlock(nn.Module):
    """DaViT channel-attention block operating on NCHW feature maps.

    Sequence: conv pos-enc -> norm -> channel attention (+ residual) ->
    conv pos-enc -> optional norm -> MLP (+ residual). Attention and MLP run on
    flattened ``(B, H*W, C)`` tokens; the output is reshaped back to ``(B, C, H, W)``.

    Args:
        dim: Number of channels.
        num_heads: Number of channel-attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        qkv_bias: If True, the qkv projection has a bias.
        drop_path: Stochastic depth rate for both residual branches.
        act_layer: Activation used in the MLP.
        norm_layer: Channels-last norm layer.
        ffn: If False, the MLP branch (and its norm/drop path) is omitted.
        cpe_act: If True, the conv positional encodings apply GELU.
        v2: Use ``ChannelAttentionV2`` instead of ``ChannelAttention``.
        device: Device used when creating parameters.
        dtype: Dtype used when creating parameters.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            drop_path: float = 0.,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            ffn: bool = True,
            cpe_act: bool = False,
            v2: bool = False,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        # NOTE: submodules are created in a fixed order; it determines parameter
        # init order (RNG consumption) and state_dict ordering.
        self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act, **factory_kwargs)
        self.ffn = ffn
        self.norm1 = norm_layer(dim, **factory_kwargs)
        attn_cls = ChannelAttentionV2 if v2 else ChannelAttention
        self.attn = attn_cls(dim, num_heads=num_heads, qkv_bias=qkv_bias, **factory_kwargs)
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act, **factory_kwargs)

        if not self.ffn:
            self.norm2 = None
            self.mlp = None
            self.drop_path2 = None
            return
        self.norm2 = norm_layer(dim, **factory_kwargs)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            **factory_kwargs,
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x: Tensor):
        B, C, H, W = x.shape
        tokens = self.cpe1(x).flatten(2).transpose(1, 2)  # (B, H*W, C)
        tokens = tokens + self.drop_path1(self.attn(self.norm1(tokens)))
        x = self.cpe2(tokens.transpose(1, 2).view(B, C, H, W))

        if self.mlp is not None:
            tokens = x.flatten(2).transpose(1, 2)
            tokens = tokens + self.drop_path2(self.mlp(self.norm2(tokens)))
            x = tokens.transpose(1, 2).view(B, C, H, W)
        return x
def window_partition(x: Tensor, window_size: Tuple[int, int]):
    """Split a channels-last feature map into non-overlapping windows.

    Args:
        x: Tensor of shape (B, H, W, C); H and W must be divisible by the window size.
        window_size: (window_h, window_w).

    Returns:
        Tensor of shape (num_windows * B, window_h, window_w, C), windows ordered
        batch-major, then row-major over the window grid.
    """
    batch, height, width, channels = x.shape
    win_h, win_w = window_size
    grid = x.view(batch, height // win_h, win_h, width // win_w, win_w, channels)
    # Bring the two window-grid axes next to the batch axis, then flatten them.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, win_h, win_w, channels)
@register_notrace_function  # reason: int argument is a Proxy
def window_reverse(windows: Tensor, window_size: Tuple[int, int], H: int, W: int):
    """Reassemble windows produced by ``window_partition`` into a feature map.

    Args:
        windows: Tensor of shape (num_windows * B, window_h, window_w, C).
        window_size: (window_h, window_w).
        H: Height of the (padded) feature map; must be divisible by window_h.
        W: Width of the (padded) feature map; must be divisible by window_w.

    Returns:
        Tensor of shape (B, H, W, C).
    """
    win_h, win_w = window_size
    channels = windows.shape[-1]
    grid = windows.view(-1, H // win_h, W // win_w, win_h, win_w, channels)
    # Interleave window-grid and in-window axes back into full rows/columns.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, H, W, channels)
class WindowAttention(nn.Module):
    r""" Multi-head self-attention over the tokens of each window.

    Input is ``(num_windows * B, N, C)`` with N tokens per window. No relative
    position bias and no window shifting are applied here. Uses
    ``F.scaled_dot_product_attention`` when fused attention is enabled, otherwise an
    explicit softmax(q k^T * scale) v.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            window_size: Tuple[int, int],
            num_heads: int,
            qkv_bias: bool = True,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **factory_kwargs)
        self.proj = nn.Linear(dim, dim, **factory_kwargs)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: Tensor):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each (B_, heads, N, head_dim)

        if not self.fused_attn:
            attn = self.softmax((q * self.scale) @ k.transpose(-2, -1))
            x = attn @ v
        else:
            x = F.scaled_dot_product_attention(q, k, v)

        return self.proj(x.transpose(1, 2).reshape(B_, N, C))
class SpatialBlock(nn.Module):
    r""" Windowed spatial-attention block operating on NCHW feature maps.

    Sequence: conv pos-enc -> norm -> window attention (+ residual) ->
    conv pos-enc -> optional norm -> MLP (+ residual). The normalized tokens are
    zero-padded on the right/bottom to a multiple of the window size, attended per
    window, and cropped back to (H, W).

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        ffn (bool, optional): If False, the MLP branch (norm2 / mlp / drop_path2) is omitted. Default: True
        cpe_act (bool, optional): If True, the conv positional encodings apply GELU. Default: False
    """

    def __init__(
            self,
            dim: int,
            num_heads: int,
            window_size: int = 7,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            drop_path: float = 0.,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            ffn: bool = True,
            cpe_act: bool = False,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.dim = dim
        self.ffn = ffn
        self.num_heads = num_heads
        self.window_size = to_2tuple(window_size)
        self.mlp_ratio = mlp_ratio

        self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act, **dd)
        self.norm1 = norm_layer(dim, **dd)
        self.attn = WindowAttention(
            dim,
            self.window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            **dd,
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act, **dd)

        if self.ffn:
            self.norm2 = norm_layer(dim, **dd)
            mlp_hidden_dim = int(dim * mlp_ratio)
            self.mlp = Mlp(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
                **dd,
            )
            self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        else:
            self.norm2 = None
            self.mlp = None
            # Only the MLP-branch drop path is disabled; drop_path1 is still used by
            # the attention residual in forward() and must not be cleared.
            self.drop_path2 = None

    def forward(self, x: Tensor):
        B, C, H, W = x.shape

        shortcut = self.cpe1(x).flatten(2).transpose(1, 2)  # (B, H*W, C)
        x = self.norm1(shortcut)
        x = x.view(B, H, W, C)

        # Zero-pad right/bottom so H and W are multiples of the window size.
        pad_l = pad_t = 0
        pad_r = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
        pad_b = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        x_windows = window_partition(x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)

        # W-MSA
        attn_windows = self.attn(x_windows)

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
        x = window_reverse(attn_windows, self.window_size, Hp, Wp)

        # Crop away any padding (a no-op slice when no padding was added).
        x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path1(x)

        x = self.cpe2(x.transpose(1, 2).view(B, C, H, W))

        if self.mlp is not None:
            x = x.flatten(2).transpose(1, 2)
            x = x + self.drop_path2(self.mlp(self.norm2(x)))
            x = x.transpose(1, 2).view(B, C, H, W)

        return x
class DaVitStage(nn.Module):
    """One DaViT stage: optional downsample followed by ``depth`` dual-attention blocks.

    Each dual-attention block is an ``nn.Sequential`` containing one sub-block per
    entry of ``attn_types`` ('spatial' -> SpatialBlock, 'channel' -> ChannelBlock), in
    that order. The default is (spatial -> channel) x depth.

    Args:
        in_chs: Input channels (consumed by the downsample layer).
        out_chs: Channels of every block in this stage.
        depth: Number of dual-attention blocks.
        downsample: If True, prepend a ``Downsample`` layer; otherwise identity.
        attn_types: Sequence of 'spatial' / 'channel' describing one dual-attention block.
        num_heads: Attention heads for both block types.
        window_size: Spatial attention window size.
        mlp_ratio: MLP hidden width as a multiple of ``out_chs``.
        qkv_bias: If True, qkv projections have a bias.
        drop_path_rates: Per-block stochastic depth rates (indexed by block index).
        norm_layer: Channels-first norm used by the downsample layer.
        norm_layer_cl: Channels-last norm used inside the blocks.
        ffn: If False, blocks omit their MLP branch.
        cpe_act: If True, conv positional encodings apply GELU.
        down_kernel_size: Kernel size of the downsample conv.
        named_blocks: Name sub-blocks 'spatial_block' / 'channel_block' instead of 0, 1, ...
        channel_attn_v2: Use ChannelAttentionV2 in channel blocks.

    Raises:
        ValueError: If ``attn_types`` contains an unknown type, or contains a repeated
            type while ``named_blocks`` is True (names would collide and silently drop a block).
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            depth: int = 1,
            downsample: bool = True,
            attn_types: Tuple[str, ...] = ('spatial', 'channel'),
            num_heads: int = 3,
            window_size: int = 7,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            drop_path_rates: Tuple[float, ...] = (0, 0),
            norm_layer: Type[nn.Module] = LayerNorm2d,
            norm_layer_cl: Type[nn.Module] = nn.LayerNorm,
            ffn: bool = True,
            cpe_act: bool = False,
            down_kernel_size: int = 2,
            named_blocks: bool = False,
            channel_attn_v2: bool = False,
            device=None,
            dtype=None,
    ):
        from collections import OrderedDict

        # Validate before creating any modules. Previously unknown types were silently
        # skipped, producing a model with missing blocks.
        unknown = [t for t in attn_types if t not in ('spatial', 'channel')]
        if unknown:
            raise ValueError(f"Unknown attn_types {unknown}, expected 'spatial' or 'channel'.")
        if named_blocks and len(set(attn_types)) != len(attn_types):
            raise ValueError(f"named_blocks=True requires unique attn_types, got {tuple(attn_types)}.")

        dd = {'device': device, 'dtype': dtype}
        super().__init__()

        self.grad_checkpointing = False

        # downsample embedding layer at the beginning of each stage
        if downsample:
            self.downsample = Downsample(in_chs, out_chs, kernel_size=down_kernel_size, norm_layer=norm_layer, **dd)
        else:
            self.downsample = nn.Identity()

        # Repeating alternating attention blocks in each stage.
        # NOTE: potential opportunity to integrate with a more general version of
        # ByobNet/ByoaNet since the logic is similar.
        stage_blocks = []
        for block_idx in range(depth):
            dual_attention_block = []
            for attn_type in attn_types:
                if attn_type == 'spatial':
                    dual_attention_block.append(('spatial_block', SpatialBlock(
                        dim=out_chs,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        drop_path=drop_path_rates[block_idx],
                        norm_layer=norm_layer_cl,
                        ffn=ffn,
                        cpe_act=cpe_act,
                        window_size=window_size,
                        **dd,
                    )))
                else:  # 'channel' (validated above)
                    dual_attention_block.append(('channel_block', ChannelBlock(
                        dim=out_chs,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        drop_path=drop_path_rates[block_idx],
                        norm_layer=norm_layer_cl,
                        ffn=ffn,
                        cpe_act=cpe_act,
                        v2=channel_attn_v2,
                        **dd,
                    )))
            if named_blocks:
                stage_blocks.append(nn.Sequential(OrderedDict(dual_attention_block)))
            else:
                stage_blocks.append(nn.Sequential(*[b[1] for b in dual_attention_block]))
        self.blocks = nn.Sequential(*stage_blocks)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    def forward(self, x: Tensor):
        x = self.downsample(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x
class DaVit(nn.Module):
    r""" DaViT
        A PyTorch implementation of `DaViT: Dual Attention Vision Transformers`  - https://arxiv.org/abs/2204.03645
        Supports arbitrary input sizes and pyramid feature extraction

    Args:
        in_chans (int): Number of input image channels. Default: 3
        depths (tuple(int)): Number of dual-attention blocks in each stage. Default: (1, 1, 3, 1)
        embed_dims (tuple(int)): Channel width of each stage. Default: (96, 192, 384, 768)
        num_heads (tuple(int)): Number of attention heads in each stage. Default: (3, 6, 12, 24)
        window_size (int): Spatial attention window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        norm_layer (str): Channels-first norm layer name, resolved via get_norm_layer. Default: 'layernorm2d'
        norm_layer_cl (str): Channels-last norm layer name used inside blocks. Default: 'layernorm'
        norm_eps (float): Epsilon passed to both norm layers. Default: 1e-5
        attn_types (tuple(str)): Block types making up one dual-attention block. Default: ('spatial', 'channel')
        ffn (bool): If True, blocks include an MLP branch. Default: True
        cpe_act (bool): If True, conv positional encodings apply GELU. Default: False
        down_kernel_size (int): Kernel size of the inter-stage downsample conv. Default: 2
        channel_attn_v2 (bool): Use ChannelAttentionV2 in channel blocks. Default: False
        named_blocks (bool): Name sub-blocks 'spatial_block' / 'channel_block'. Default: False
        drop_rate (float): Dropout rate passed to the classifier head. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        num_classes (int): Number of classes for classification head. Default: 1000
        global_pool (str): Global pooling type for the head. Default: 'avg'
        head_norm_first (bool): If True, norm -> pool -> fc head ordering; otherwise
            pool -> norm -> fc (the default DaViT order). Default: False
    """

    def __init__(
            self,
            in_chans: int = 3,
            depths: Tuple[int, ...] = (1, 1, 3, 1),
            embed_dims: Tuple[int, ...] = (96, 192, 384, 768),
            num_heads: Tuple[int, ...] = (3, 6, 12, 24),
            window_size: int = 7,
            mlp_ratio: float = 4,
            qkv_bias: bool = True,
            norm_layer: str = 'layernorm2d',
            norm_layer_cl: str = 'layernorm',
            norm_eps: float = 1e-5,
            attn_types: Tuple[str, ...] = ('spatial', 'channel'),
            ffn: bool = True,
            cpe_act: bool = False,
            down_kernel_size: int = 2,
            channel_attn_v2: bool = False,
            named_blocks: bool = False,
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            head_norm_first: bool = False,
            device=None,
            dtype=None,
    ):
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        num_stages = len(embed_dims)
        assert num_stages == len(num_heads) == len(depths)
        norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
        norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.num_features = self.head_hidden_size = embed_dims[-1]
        self.drop_rate = drop_rate
        self.grad_checkpointing = False
        self.feature_info = []

        self.stem = Stem(in_chans, embed_dims[0], norm_layer=norm_layer, **dd)
        in_chs = embed_dims[0]

        dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
        stages = []
        for i in range(num_stages):
            out_chs = embed_dims[i]
            stage = DaVitStage(
                in_chs,
                out_chs,
                depth=depths[i],
                downsample=i > 0,
                attn_types=attn_types,
                num_heads=num_heads[i],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path_rates=dpr[i],
                norm_layer=norm_layer,
                norm_layer_cl=norm_layer_cl,
                ffn=ffn,
                cpe_act=cpe_act,
                down_kernel_size=down_kernel_size,
                channel_attn_v2=channel_attn_v2,
                named_blocks=named_blocks,
                **dd,
            )
            in_chs = out_chs
            stages.append(stage)
            # Stem has stride 4 and every later stage downsamples by 2 -> 2 ** (i + 2).
            self.feature_info += [dict(num_chs=out_chs, reduction=2**(i+2), module=f'stages.{i}')]

        self.stages = nn.Sequential(*stages)

        # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
        # otherwise pool -> norm -> fc, the default DaViT order, similar to ConvNeXt
        # FIXME generalize this structure to ClassifierHead
        if head_norm_first:
            self.norm_pre = norm_layer(self.num_features, **dd)
            self.head = ClassifierHead(
                self.num_features,
                num_classes,
                pool_type=global_pool,
                drop_rate=self.drop_rate,
                **dd,
            )
        else:
            self.norm_pre = nn.Identity()
            self.head = NormMlpClassifierHead(
                self.num_features,
                num_classes,
                pool_type=global_pool,
                drop_rate=self.drop_rate,
                norm_layer=norm_layer,
                **dd,
            )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal init for every Linear weight; zero its bias if present.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^stem',  # stem and embed
            blocks=r'^stages\.(\d+)' if coarse else [
                (r'^stages\.(\d+)\.downsample', (0,)),
                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
                (r'^norm_pre', (99999,)),
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable
        for stage in self.stages:
            stage.set_grad_checkpointing(enable=enable)

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n stages if int, all if None, select matching indices if sequence
            norm: Apply norm_pre to the last-stage intermediate (if it is taken)
            stop_early: Stop iterating over stages when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs (only 'NCHW' supported)
            intermediates_only: Only return intermediate features

        Returns:
            If ``intermediates_only``: list of selected stage outputs (NCHW).
            Otherwise: tuple of (final features, list of selected stage outputs). The
            final features have ``norm_pre`` applied only when the last stage was run.
        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stages), indices)

        # forward pass
        x = self.stem(x)
        last_idx = len(self.stages) - 1
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.stages
        else:
            stages = self.stages[:max_index + 1]

        for feat_idx, stage in enumerate(stages):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(stage, x)
            else:
                x = stage(x)
            if feat_idx in take_indices:
                if norm and feat_idx == last_idx:
                    x_inter = self.norm_pre(x)  # applying final norm to last intermediate
                else:
                    x_inter = x
                intermediates.append(x_inter)

        if intermediates_only:
            return intermediates

        if feat_idx == last_idx:
            x = self.norm_pre(x)

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        self.stages = self.stages[:max_index + 1]  # truncate stages (stem is separate, stage 0 is index 0)
        if prune_norm:
            self.norm_pre = nn.Identity()
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        x = self.norm_pre(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=True) if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
- def _convert_florence2(state_dict, model, prefix='vision_tower.'):
- import re
- out_dict = {}
- for k, v in state_dict.items():
- if k.startswith(prefix):
- k = k.replace(prefix, '')
- else:
- continue
- k = re.sub(r'convs.([0-9]+)', r'stages.\1.downsample', k)
- k = re.sub(r'blocks.([0-9]+)', r'stages.\1.blocks', k)
- k = k.replace('downsample.proj', 'downsample.conv')
- k = k.replace('stages.0.downsample', 'stem')
- #k = k.replace('head.', 'head.fc.')
- #k = k.replace('norms.', 'head.norm.')
- k = k.replace('window_attn.norm.', 'norm1.')
- k = k.replace('window_attn.fn.', 'attn.')
- k = k.replace('channel_attn.norm.', 'norm1.')
- k = k.replace('channel_attn.fn.', 'attn.')
- k = k.replace('ffn.norm.', 'norm2.')
- k = k.replace('ffn.fn.net.', 'mlp.')
- k = k.replace('conv1.fn.dw', 'cpe1.proj')
- k = k.replace('conv2.fn.dw', 'cpe2.proj')
- out_dict[k] = v
- return out_dict
def checkpoint_filter_fn(state_dict, model):
    """ Remap MSFT checkpoints -> timm.

    Handles three inputs:
      * timm-format checkpoints (optionally nested under 'state_dict') are returned as-is;
      * Florence-2 checkpoints are delegated to ``_convert_florence2``;
      * original MSFT DaViT checkpoints have their keys renamed to timm names.

    Args:
        state_dict: Loaded checkpoint mapping (possibly wrapping weights under 'state_dict').
        model: Target model (forwarded to the Florence-2 converter).

    Returns:
        A state dict using timm ``DaVit`` key names.
    """
    import re

    # Unwrap first so a wrapped timm checkpoint is recognised below instead of being
    # re-mapped (e.g. 'head.fc.weight' -> 'head.fc.fc.weight').
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']

    # timm-native keys. 'stem.conv.weight' also catches headless timm checkpoints,
    # whose 'head.norm.*' keys would otherwise be mangled to 'head.fc.norm.*'.
    if 'head.fc.weight' in state_dict or 'stem.conv.weight' in state_dict:
        return state_dict  # non-MSFT checkpoint

    if 'vision_tower.convs.0.proj.weight' in state_dict:
        return _convert_florence2(state_dict, model)

    out_dict = {}
    for k, v in state_dict.items():
        k = re.sub(r'patch_embeds\.([0-9]+)', r'stages.\1.downsample', k)
        k = re.sub(r'main_blocks\.([0-9]+)', r'stages.\1.blocks', k)
        k = k.replace('downsample.proj', 'downsample.conv')
        k = k.replace('stages.0.downsample', 'stem')
        k = k.replace('head.', 'head.fc.')
        k = k.replace('norms.', 'head.norm.')
        k = k.replace('cpe.0', 'cpe1')
        k = k.replace('cpe.1', 'cpe2')
        out_dict[k] = v
    return out_dict
def _create_davit(variant, pretrained=False, **kwargs):
    """Build a ``DaVit`` model (optionally pretrained) via ``build_model_with_cfg``.

    ``out_indices`` defaults to every stage (one index per entry of ``depths``).
    ``pretrained_strict`` is consumed here; it is forced to False for '*_fl'
    (Florence-2) variants, whose checkpoints lack some head weights.
    """
    depths = kwargs.get('depths', (1, 1, 3, 1))
    all_stage_indices = tuple(idx for idx, _ in enumerate(depths))
    out_indices = kwargs.pop('out_indices', all_stage_indices)

    requested_strict = kwargs.pop('pretrained_strict', True)
    # FIXME cleaner approach to missing head norm?
    strict = False if variant.endswith('_fl') else requested_strict

    return build_model_with_cfg(
        DaVit,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        pretrained_strict=strict,
        **kwargs,
    )
def _cfg(url='', **kwargs):
    """Return the default pretrained config dict for DaViT; ``kwargs`` override/extend it."""
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=(7, 7),
        crop_pct=0.95,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='stem.conv',
        classifier='head.fc',
        license='apache-2.0',
    )
    cfg.update(kwargs)
    return cfg
# TODO contact authors to get larger pretrained models
default_cfgs = generate_default_cfgs({
    # official microsoft weights from https://github.com/dingmyu/davit
    **{
        f'davit_{size}.msft_in1k': _cfg(hf_hub_id='timm/')
        for size in ('tiny', 'small', 'base')
    },
    # architectures without released weights
    'davit_large': _cfg(),
    'davit_huge': _cfg(),
    'davit_giant': _cfg(),
    # Florence-2 vision towers (headless, 768x768 input)
    'davit_base_fl.msft_florence2': _cfg(
        hf_hub_id='microsoft/Florence-2-base',
        num_classes=0,
        input_size=(3, 768, 768),
    ),
    'davit_huge_fl.msft_florence2': _cfg(
        hf_hub_id='microsoft/Florence-2-large',
        num_classes=0,
        input_size=(3, 768, 768),
    ),
})
@register_model
def davit_tiny(pretrained=False, **kwargs) -> DaVit:
    """DaViT-Tiny: depths (1, 1, 3, 1), dims 96-768, heads 3-24."""
    model_args = dict(
        depths=(1, 1, 3, 1),
        embed_dims=(96, 192, 384, 768),
        num_heads=(3, 6, 12, 24),
    )
    model_args.update(kwargs)
    return _create_davit('davit_tiny', pretrained=pretrained, **model_args)
@register_model
def davit_small(pretrained=False, **kwargs) -> DaVit:
    """DaViT-Small: depths (1, 1, 9, 1), dims 96-768, heads 3-24."""
    model_args = dict(
        depths=(1, 1, 9, 1),
        embed_dims=(96, 192, 384, 768),
        num_heads=(3, 6, 12, 24),
    )
    model_args.update(kwargs)
    return _create_davit('davit_small', pretrained=pretrained, **model_args)
@register_model
def davit_base(pretrained=False, **kwargs) -> DaVit:
    """DaViT-Base: depths (1, 1, 9, 1), dims 128-1024, heads 4-32."""
    model_args = dict(
        depths=(1, 1, 9, 1),
        embed_dims=(128, 256, 512, 1024),
        num_heads=(4, 8, 16, 32),
    )
    model_args.update(kwargs)
    return _create_davit('davit_base', pretrained=pretrained, **model_args)
@register_model
def davit_large(pretrained=False, **kwargs) -> DaVit:
    """DaViT-Large: depths (1, 1, 9, 1), dims 192-1536, heads 6-48."""
    model_args = dict(
        depths=(1, 1, 9, 1),
        embed_dims=(192, 384, 768, 1536),
        num_heads=(6, 12, 24, 48),
    )
    model_args.update(kwargs)
    return _create_davit('davit_large', pretrained=pretrained, **model_args)
@register_model
def davit_huge(pretrained=False, **kwargs) -> DaVit:
    """DaViT-Huge: depths (1, 1, 9, 1), dims 256-2048, heads 8-64."""
    model_args = dict(
        depths=(1, 1, 9, 1),
        embed_dims=(256, 512, 1024, 2048),
        num_heads=(8, 16, 32, 64),
    )
    model_args.update(kwargs)
    return _create_davit('davit_huge', pretrained=pretrained, **model_args)
@register_model
def davit_giant(pretrained=False, **kwargs) -> DaVit:
    """DaViT-Giant: depths (1, 1, 12, 3), dims 384-3072, heads 12-96."""
    model_args = dict(
        depths=(1, 1, 12, 3),
        embed_dims=(384, 768, 1536, 3072),
        num_heads=(12, 24, 48, 96),
    )
    model_args.update(kwargs)
    return _create_davit('davit_giant', pretrained=pretrained, **model_args)
@register_model
def davit_base_fl(pretrained=False, **kwargs) -> DaVit:
    """DaViT-Base Florence-2 vision tower: base widths, 12x12 windows, k=3 downsample,
    ChannelAttentionV2 and named sub-blocks."""
    model_args = dict(
        depths=(1, 1, 9, 1),
        embed_dims=(128, 256, 512, 1024),
        num_heads=(4, 8, 16, 32),
        window_size=12,
        down_kernel_size=3,
        channel_attn_v2=True,
        named_blocks=True,
    )
    model_args.update(kwargs)
    return _create_davit('davit_base_fl', pretrained=pretrained, **model_args)
@register_model
def davit_huge_fl(pretrained=False, **kwargs) -> DaVit:
    """DaViT-Huge Florence-2 vision tower: huge widths, 12x12 windows, k=3 downsample,
    ChannelAttentionV2 and named sub-blocks."""
    # NOTE: huge image tower used in 'large' Florence2 model
    model_args = dict(
        depths=(1, 1, 9, 1),
        embed_dims=(256, 512, 1024, 2048),
        num_heads=(8, 16, 32, 64),
        window_size=12,
        down_kernel_size=3,
        channel_attn_v2=True,
        named_blocks=True,
    )
    model_args.update(kwargs)
    return _create_davit('davit_huge_fl', pretrained=pretrained, **model_args)
|