| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281 |
- """ EfficientViT (by MIT Song Han's Lab)
- Paper: `Efficientvit: Enhanced linear attention for high-resolution low-computation visual recognition`
- - https://arxiv.org/abs/2205.14756
- Adapted from official impl at https://github.com/mit-han-lab/efficientvit
- """
# Public model classes exported by `from <this module> import *`.
__all__ = ['EfficientVit', 'EfficientVitLarge']
- from typing import List, Optional, Tuple, Type, Union
- from functools import partial
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
- from ._builder import build_model_with_cfg
- from ._features import feature_take_indices
- from ._features_fx import register_notrace_module
- from ._manipulate import checkpoint_seq
- from ._registry import register_model, generate_default_cfgs
def val2list(x: Union[List, Tuple, object], repeat_time: int = 1) -> list:
    """Return ``x`` as a list.

    Lists and tuples are copied into a new list (``repeat_time`` is ignored for them);
    any other value is wrapped and repeated ``repeat_time`` times.

    Args:
        x: A list, a tuple, or any single value.
        repeat_time: How many copies of a non-sequence ``x`` to put in the list.

    Returns:
        A new list.
    """
    # NOTE: the annotation used to be `list or tuple or any`, which Python evaluates to
    # plain `list`; a real Union documents the accepted inputs correctly.
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x] * repeat_time
def val2tuple(x: Union[List, Tuple, object], min_len: int = 1, idx_repeat: int = -1) -> tuple:
    """Return ``x`` as a tuple with at least ``min_len`` elements.

    A non-sequence ``x`` is first wrapped in a single-element sequence. If the result is
    shorter than ``min_len``, copies of the element at ``idx_repeat`` are inserted at slice
    position ``idx_repeat`` until the length reaches ``min_len``. Longer inputs are returned
    unchanged, and an empty sequence yields ``()``.

    Args:
        x: A list, a tuple, or any single value.
        min_len: Minimum length of the returned tuple.
        idx_repeat: Index of the element to repeat (default: the last one).

    Returns:
        A tuple.
    """
    # Same conversion as `val2list(x)` with repeat_time=1, inlined so this helper stands alone.
    # (The old annotation `list or tuple or any` evaluated to plain `list`.)
    items = list(x) if isinstance(x, (list, tuple)) else [x]
    missing = min_len - len(items)
    # Only index when padding is actually needed, matching the original's lazy evaluation.
    if items and missing > 0:
        items[idx_repeat:idx_repeat] = [items[idx_repeat]] * missing
    return tuple(items)
def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
    """Return the padding ``kernel_size // 2`` for an odd kernel size.

    For a stride-1, dilation-1 convolution this keeps the spatial size unchanged.
    Tuples are handled element-wise and return a tuple.

    Raises:
        AssertionError: If a kernel size is even.
    """
    # NOTE: the old annotation `int or tuple[int, ...]` evaluated to just `int`.
    if isinstance(kernel_size, tuple):
        return tuple(get_same_padding(ks) for ks in kernel_size)
    assert kernel_size % 2 > 0, "kernel size should be odd number"
    return kernel_size // 2
class ConvNormAct(nn.Module):
    """Dropout -> convolution -> optional normalization -> optional activation.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        dilation: Convolution dilation.
        groups: Number of convolution groups.
        bias: Whether the convolution has a bias term.
        dropout: Dropout probability applied to the input before the convolution.
        norm_layer: Normalization layer class (called with ``num_features``); a falsy value
            gives an identity.
        act_layer: Activation layer class (constructed with ``inplace=True``); None gives an
            identity.
        device: Device for parameter creation.
        dtype: Dtype for parameter creation.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[int, Tuple[int, int]] = 3,
            stride: int = 1,
            dilation: int = 1,
            groups: int = 1,
            bias: bool = False,
            dropout: float = 0.,
            norm_layer: Optional[Type[nn.Module]] = nn.BatchNorm2d,
            act_layer: Optional[Type[nn.Module]] = nn.ReLU,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        self.dropout = nn.Dropout(dropout, inplace=False)
        self.conv = create_conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            **factory_kwargs,
        )
        if norm_layer:
            self.norm = norm_layer(num_features=out_channels, **factory_kwargs)
        else:
            self.norm = nn.Identity()
        if act_layer is None:
            self.act = nn.Identity()
        else:
            self.act = act_layer(inplace=True)

    def forward(self, x):
        return self.act(self.norm(self.conv(self.dropout(x))))
class DSConv(nn.Module):
    """Depthwise-separable convolution: a depthwise KxK ConvNormAct, then a pointwise 1x1 ConvNormAct.

    ``use_bias``, ``norm_layer`` and ``act_layer`` may each be one value shared by both convs,
    or a ``(depthwise, pointwise)`` pair. The stride is applied by the depthwise conv.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 3,
            stride: int = 1,
            use_bias: Union[bool, Tuple[bool, bool]] = False,
            norm_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = nn.BatchNorm2d,
            act_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (nn.ReLU6, None),
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm_layer, 2)
        acts = val2tuple(act_layer, 2)

        # groups == in_channels -> one filter per input channel
        self.depth_conv = ConvNormAct(
            in_channels,
            in_channels,
            kernel_size,
            stride,
            groups=in_channels,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
            **factory_kwargs,
        )
        self.point_conv = ConvNormAct(
            in_channels,
            out_channels,
            1,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
            **factory_kwargs,
        )

    def forward(self, x):
        return self.point_conv(self.depth_conv(x))
class ConvBlock(nn.Module):
    """Two stacked KxK ConvNormAct layers; the first carries the stride.

    ``mid_channels`` defaults to ``round(in_channels * expand_ratio)``. ``use_bias``,
    ``norm_layer`` and ``act_layer`` may be one shared value or a per-conv pair.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 3,
            stride: int = 1,
            mid_channels: Optional[int] = None,
            expand_ratio: float = 1,
            use_bias: Union[bool, Tuple[bool, bool]] = False,
            norm_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = nn.BatchNorm2d,
            act_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (nn.ReLU6, None),
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm_layer, 2)
        acts = val2tuple(act_layer, 2)
        hidden = mid_channels or round(in_channels * expand_ratio)

        self.conv1 = ConvNormAct(
            in_channels,
            hidden,
            kernel_size,
            stride,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
            **factory_kwargs,
        )
        self.conv2 = ConvNormAct(
            hidden,
            out_channels,
            kernel_size,
            1,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
            **factory_kwargs,
        )

    def forward(self, x):
        return self.conv2(self.conv1(x))
class MBConv(nn.Module):
    """1x1 expand conv -> depthwise KxK conv (carries the stride) -> 1x1 project conv.

    No residual connection is added here. ``mid_channels`` defaults to
    ``round(in_channels * expand_ratio)``. ``use_bias``, ``norm_layer`` and ``act_layer`` may be
    one shared value or a per-conv ``(expand, depthwise, project)`` tuple.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 3,
            stride: int = 1,
            mid_channels: Optional[int] = None,
            expand_ratio: float = 6,
            use_bias: Union[bool, Tuple[bool, ...]] = False,
            norm_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = nn.BatchNorm2d,
            act_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (nn.ReLU6, nn.ReLU6, None),
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        biases = val2tuple(use_bias, 3)
        norms = val2tuple(norm_layer, 3)
        acts = val2tuple(act_layer, 3)
        hidden = mid_channels or round(in_channels * expand_ratio)

        self.inverted_conv = ConvNormAct(
            in_channels,
            hidden,
            1,
            stride=1,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
            **factory_kwargs,
        )
        self.depth_conv = ConvNormAct(
            hidden,
            hidden,
            kernel_size,
            stride=stride,
            groups=hidden,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
            **factory_kwargs,
        )
        self.point_conv = ConvNormAct(
            hidden,
            out_channels,
            1,
            bias=biases[2],
            norm_layer=norms[2],
            act_layer=acts[2],
            **factory_kwargs,
        )

    def forward(self, x):
        expanded = self.inverted_conv(x)
        mixed = self.depth_conv(expanded)
        return self.point_conv(mixed)
class FusedMBConv(nn.Module):
    """Fused expand conv: a KxK (optionally grouped) conv with the stride, then a 1x1 project conv.

    ``mid_channels`` defaults to ``round(in_channels * expand_ratio)``. ``use_bias``,
    ``norm_layer`` and ``act_layer`` may be one shared value or a per-conv pair.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 3,
            stride: int = 1,
            mid_channels: Optional[int] = None,
            expand_ratio: float = 6,
            groups: int = 1,
            use_bias: Union[bool, Tuple[bool, ...]] = False,
            norm_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = nn.BatchNorm2d,
            act_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (nn.ReLU6, None),
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm_layer, 2)
        acts = val2tuple(act_layer, 2)
        hidden = mid_channels or round(in_channels * expand_ratio)

        self.spatial_conv = ConvNormAct(
            in_channels,
            hidden,
            kernel_size,
            stride=stride,
            groups=groups,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
            **factory_kwargs,
        )
        self.point_conv = ConvNormAct(
            hidden,
            out_channels,
            1,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
            **factory_kwargs,
        )

    def forward(self, x):
        return self.point_conv(self.spatial_conv(x))
class LiteMLA(nn.Module):
    """Lightweight multi-scale linear attention.

    A 1x1 ConvNormAct produces q/k/v. Each entry of ``scales`` adds another q/k/v set, obtained
    by passing that projection through a depthwise ``scale x scale`` conv followed by a grouped
    1x1 conv. All sets are attended with kernel-feature (``kernel_func``) linear attention, computed
    in float32, and the result is projected to ``out_channels`` by a 1x1 ConvNormAct.

    Args:
        in_channels: Input channels.
        out_channels: Output channels of the final projection.
        heads: Number of heads; defaults to ``int(in_channels // dim * heads_ratio)``.
        heads_ratio: Scale on the default head count (ignored when ``heads`` is given).
        dim: Per-head dimension of q, k and v.
        use_bias: ``(qkv/aggregation convs, projection)`` bias flags, or one shared value.
        norm_layer: ``(qkv, projection)`` norm layers, or one shared value.
        act_layer: ``(qkv, projection)`` activation layers, or one shared value.
        kernel_func: Feature map applied to q and k (instantiated with ``inplace=False``).
        scales: Kernel sizes of the extra multi-scale aggregation branches.
        eps: Added to the attention normalizer to avoid division by zero.
        device: Device for parameter creation.
        dtype: Dtype for parameter creation.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            heads: Optional[int] = None,
            heads_ratio: float = 1.0,
            dim: int = 8,
            use_bias: Union[bool, Tuple[bool, ...]] = False,
            norm_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (None, nn.BatchNorm2d),
            act_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (None, None),
            kernel_func: Type[nn.Module] = nn.ReLU,
            scales: Tuple[int, ...] = (5,),
            eps: float = 1e-5,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.eps = eps
        heads = heads or int(in_channels // dim * heads_ratio)
        total_dim = heads * dim
        use_bias = val2tuple(use_bias, 2)
        norm_layer = val2tuple(norm_layer, 2)
        act_layer = val2tuple(act_layer, 2)

        self.dim = dim
        # 1x1 projection to concatenated q, k, v (3 * heads * dim channels)
        self.qkv = ConvNormAct(
            in_channels,
            3 * total_dim,
            1,
            bias=use_bias[0],
            norm_layer=norm_layer[0],
            act_layer=act_layer[0],
            **dd,
        )
        # one extra q/k/v set per scale: depthwise spatial conv, then a 1x1 conv grouped per
        # head-and-q/k/v slice (groups = 3 * heads) so channels only mix within each slice
        self.aggreg = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    3 * total_dim,
                    3 * total_dim,
                    scale,
                    padding=get_same_padding(scale),
                    groups=3 * total_dim,
                    bias=use_bias[0],
                    **dd,
                ),
                nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0], **dd),
            )
            for scale in scales
        ])
        self.kernel_func = kernel_func(inplace=False)

        # attention output of the base set plus every scale is projected back to out_channels
        self.proj = ConvNormAct(
            total_dim * (1 + len(scales)),
            out_channels,
            1,
            bias=use_bias[1],
            norm_layer=norm_layer[1],
            act_layer=act_layer[1],
            **dd,
        )

    def _attn(self, q, k, v):
        # Linear attention in float32. `v` arrives with an extra all-ones last column, so the
        # last column of q @ (k^T @ v) is q @ sum_j(k_j), i.e. the normalizer for each query.
        dtype = v.dtype
        q, k, v = q.float(), k.float(), v.float()
        kv = k.transpose(-1, -2) @ v  # (..., dim, dim + 1)
        out = q @ kv  # (..., N, dim + 1)
        out = out[..., :-1] / (out[..., -1:] + self.eps)
        return out.to(dtype)

    def forward(self, x):
        B, _, H, W = x.shape

        # generate multi-scale q, k, v
        qkv = self.qkv(x)
        multi_scale_qkv = [qkv]
        for op in self.aggreg:
            multi_scale_qkv.append(op(qkv))
        multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)
        # -> (B, heads * (1 + len(scales)), H * W, 3 * dim)
        multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose(-1, -2)
        q, k, v = multi_scale_qkv.chunk(3, dim=-1)

        # lightweight global attention
        q = self.kernel_func(q)
        k = self.kernel_func(k)
        # append a constant-1 column to v so _attn can compute the normalizer in the same matmul
        v = F.pad(v, (0, 1), mode="constant", value=1.)

        # autocast is disabled so _attn's float32 math is not downcast again
        if not torch.jit.is_scripting():
            with torch.autocast(device_type=v.device.type, enabled=False):
                out = self._attn(q, k, v)
        else:
            out = self._attn(q, k, v)

        # final projection: back to (B, total_dim * (1 + len(scales)), H, W)
        out = out.transpose(-1, -2).reshape(B, -1, H, W)
        out = self.proj(out)
        return out


# NOTE(review): presumably registers LiteMLA as a leaf (not traced into) for timm's FX-based
# feature extraction — confirm against timm._features_fx.
register_notrace_module(LiteMLA)
class EfficientVitBlock(nn.Module):
    """EfficientViT block: residual LiteMLA (context) followed by residual MBConv (local).

    Both sub-blocks keep the channel count at ``in_channels`` and use an identity shortcut.
    """

    def __init__(
            self,
            in_channels: int,
            heads_ratio: float = 1.0,
            head_dim: int = 32,
            expand_ratio: float = 4,
            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            act_layer: Type[nn.Module] = nn.Hardswish,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)

        attention = LiteMLA(
            in_channels=in_channels,
            out_channels=in_channels,
            heads_ratio=heads_ratio,
            dim=head_dim,
            norm_layer=(None, norm_layer),
            **factory_kwargs,
        )
        self.context_module = ResidualBlock(attention, nn.Identity())

        # biases on the expand/depthwise convs, norm only on the final projection
        local_conv = MBConv(
            in_channels=in_channels,
            out_channels=in_channels,
            expand_ratio=expand_ratio,
            use_bias=(True, True, False),
            norm_layer=(None, None, norm_layer),
            act_layer=(act_layer, act_layer, None),
            **factory_kwargs,
        )
        self.local_module = ResidualBlock(local_conv, nn.Identity())

    def forward(self, x):
        return self.local_module(self.context_module(x))
class ResidualBlock(nn.Module):
    """Apply ``main`` (after optional ``pre_norm``) and optionally add a ``shortcut`` branch.

    Output is ``main(pre_norm(x)) + shortcut(x)`` when a shortcut is given, otherwise
    ``main(pre_norm(x))``. When ``main`` is None, the block is a pass-through that returns ``x``
    unchanged. Previously a None ``main`` (allowed by the ``Optional`` annotation) crashed with
    "'NoneType' object is not callable".

    Args:
        main: Main branch module, or None for a pass-through block.
        shortcut: Module applied to the raw input and added to the main output, or None.
        pre_norm: Module applied to the input before ``main``; defaults to identity.
    """

    def __init__(
            self,
            main: Optional[nn.Module],
            shortcut: Optional[nn.Module] = None,
            pre_norm: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.pre_norm = pre_norm if pre_norm is not None else nn.Identity()
        self.main = main
        self.shortcut = shortcut

    def forward(self, x):
        if self.main is None:
            return x
        res = self.main(self.pre_norm(x))
        if self.shortcut is not None:
            res = res + self.shortcut(x)
        return res
def build_local_block(
        in_channels: int,
        out_channels: int,
        stride: int,
        expand_ratio: float,
        norm_layer: Type[nn.Module],
        act_layer: Type[nn.Module],
        fewer_norm: bool = False,
        block_type: str = "default",
        device=None,
        dtype=None,
):
    """Build the local (convolutional) block used by the stems and stages.

    Selection:
        * ``expand_ratio == 1``: ``DSConv`` for ``"default"``, otherwise ``ConvBlock``.
        * ``expand_ratio != 1``: ``MBConv`` for ``"default"``, otherwise ``FusedMBConv``.

    With ``fewer_norm`` the inner convs get biases and no norm; only the last conv is normalized.
    The activation is applied to every conv except the last.

    Args:
        in_channels: Input channels.
        out_channels: Output channels.
        stride: Spatial stride of the block.
        expand_ratio: Channel expansion ratio (1 selects the non-expanding variants).
        norm_layer: Normalization layer class (or partial).
        act_layer: Activation layer class.
        fewer_norm: Use biases instead of norms on all but the last conv.
        block_type: One of ``"default"``, ``"large"``, ``"fused"``.
        device: Device for parameter creation.
        dtype: Dtype for parameter creation.

    Returns:
        The constructed ``nn.Module``.
    """
    # NOTE: norm_layer/act_layer were previously annotated `str`, but callers pass layer
    # classes (or functools.partial of them) that are instantiated downstream.
    dd = {'device': device, 'dtype': dtype}
    assert block_type in ["default", "large", "fused"]
    if expand_ratio == 1:
        if block_type == "default":
            block = DSConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                use_bias=(True, False) if fewer_norm else False,
                norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
                act_layer=(act_layer, None),
                **dd,
            )
        else:
            block = ConvBlock(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                use_bias=(True, False) if fewer_norm else False,
                norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
                act_layer=(act_layer, None),
                **dd,
            )
    else:
        if block_type == "default":
            block = MBConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                expand_ratio=expand_ratio,
                use_bias=(True, True, False) if fewer_norm else False,
                norm_layer=(None, None, norm_layer) if fewer_norm else norm_layer,
                act_layer=(act_layer, act_layer, None),
                **dd,
            )
        else:
            block = FusedMBConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                expand_ratio=expand_ratio,
                use_bias=(True, False) if fewer_norm else False,
                norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
                act_layer=(act_layer, None),
                **dd,
            )
    return block
class Stem(nn.Sequential):
    """Network stem: a stride-2 3x3 ConvNormAct followed by ``depth`` residual local blocks.

    The residual blocks use ``expand_ratio=1`` and stride 1, so the stem's total stride is 2
    (exposed as ``self.stride``). Submodules are named ``in_conv``, ``res0``, ``res1``, ...
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            depth: int,
            norm_layer: Type[nn.Module],
            act_layer: Type[nn.Module],
            block_type: str = 'default',
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        self.stride = 2

        first_conv = ConvNormAct(
            in_chs,
            out_chs,
            kernel_size=3,
            stride=2,
            norm_layer=norm_layer,
            act_layer=act_layer,
            **factory_kwargs,
        )
        self.add_module('in_conv', first_conv)

        for block_idx in range(depth):
            local = build_local_block(
                in_channels=out_chs,
                out_channels=out_chs,
                stride=1,
                expand_ratio=1,
                norm_layer=norm_layer,
                act_layer=act_layer,
                block_type=block_type,
                **factory_kwargs,
            )
            self.add_module(f'res{block_idx}', ResidualBlock(local, nn.Identity()))
class EfficientVitStage(nn.Module):
    """One EfficientViT (B-series) stage.

    Starts with a stride-2 local block (no shortcut) mapping ``in_chs`` to ``out_chs``. ViT stages
    then append ``depth`` :class:`EfficientVitBlock` s (``depth + 1`` blocks in total). Conv-only
    stages append ``depth - 1`` residual local blocks (``depth`` blocks in total).
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            depth: int,
            norm_layer: Type[nn.Module],
            act_layer: Type[nn.Module],
            expand_ratio: float,
            head_dim: int,
            vit_stage: bool = False,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)

        downsample = build_local_block(
            in_channels=in_chs,
            out_channels=out_chs,
            stride=2,
            expand_ratio=expand_ratio,
            norm_layer=norm_layer,
            act_layer=act_layer,
            fewer_norm=vit_stage,
            **factory_kwargs,
        )
        layers = [ResidualBlock(downsample, None)]

        if vit_stage:
            # attention stages (stage 3, 4)
            layers.extend(
                EfficientVitBlock(
                    in_channels=out_chs,
                    head_dim=head_dim,
                    expand_ratio=expand_ratio,
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    **factory_kwargs,
                )
                for _ in range(depth)
            )
        else:
            # conv-only stages (stage 1, 2); the downsample counts towards `depth`
            layers.extend(
                ResidualBlock(
                    build_local_block(
                        in_channels=out_chs,
                        out_channels=out_chs,
                        stride=1,
                        expand_ratio=expand_ratio,
                        norm_layer=norm_layer,
                        act_layer=act_layer,
                        **factory_kwargs,
                    ),
                    nn.Identity(),
                )
                for _ in range(depth - 1)
            )

        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        return self.blocks(x)
class EfficientVitLargeStage(nn.Module):
    """One EfficientViT-Large stage.

    Starts with a stride-2 local block (no shortcut; expand ratio 24 for ViT stages, else 16),
    then ``depth`` more blocks. ViT stages use :class:`EfficientVitBlock` with expand ratio 6.
    Other stages use residual local blocks with expand ratio 4. Local blocks are fused
    (``block_type='fused'``) unless ``fewer_norm`` selects the default MBConv variant.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            depth: int,
            norm_layer: Type[nn.Module],
            act_layer: Type[nn.Module],
            head_dim: int,
            vit_stage: bool = False,
            fewer_norm: bool = False,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        local_type = 'default' if fewer_norm else 'fused'

        downsample = build_local_block(
            in_channels=in_chs,
            out_channels=out_chs,
            stride=2,
            expand_ratio=24 if vit_stage else 16,
            norm_layer=norm_layer,
            act_layer=act_layer,
            fewer_norm=vit_stage or fewer_norm,
            block_type=local_type,
            **factory_kwargs,
        )
        layers = [ResidualBlock(downsample, None)]

        for _ in range(depth):
            if vit_stage:
                # attention stage (stage 4)
                layer = EfficientVitBlock(
                    in_channels=out_chs,
                    head_dim=head_dim,
                    expand_ratio=6,
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    **factory_kwargs,
                )
            else:
                # conv-only stages (stage 1, 2, 3)
                layer = ResidualBlock(
                    build_local_block(
                        in_channels=out_chs,
                        out_channels=out_chs,
                        stride=1,
                        expand_ratio=4,
                        norm_layer=norm_layer,
                        act_layer=act_layer,
                        fewer_norm=fewer_norm,
                        block_type=local_type,
                        **factory_kwargs,
                    ),
                    nn.Identity(),
                )
            layers.append(layer)

        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        return self.blocks(x)
class ClassifierHead(nn.Module):
    """Classification head.

    Pipeline: 1x1 ConvNormAct (``in_channels -> widths[0]``) -> global pool ->
    Linear(``widths[0] -> widths[1]``, no bias) -> LayerNorm -> activation -> dropout ->
    Linear(``widths[1] -> num_classes``). The final Linear is an identity when
    ``num_classes <= 0``.

    Args:
        in_channels: Channels of the incoming feature map.
        widths: Head widths; ``widths[0]`` and ``widths[1]`` are used as above, and
            ``widths[-1]`` is reported as ``num_features``.
        num_classes: Number of output classes.
        dropout: Dropout probability before the final Linear.
        norm_layer: Norm layer of the 1x1 conv.
        act_layer: Activation class (constructed with ``inplace=True``), or None for identity.
        pool_type: ``SelectAdaptivePool2d`` pool type; must be non-empty.
        norm_eps: Epsilon of the LayerNorm.
        device: Device for parameter creation.
        dtype: Dtype for parameter creation.
    """

    def __init__(
            self,
            in_channels: int,
            widths: List[int],
            num_classes: int = 1000,
            dropout: float = 0.,
            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            act_layer: Optional[Type[nn.Module]] = nn.Hardswish,
            pool_type: str = 'avg',
            norm_eps: float = 1e-5,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.widths = widths
        self.num_features = widths[-1]

        assert pool_type, 'Cannot disable pooling'
        self.in_conv = ConvNormAct(in_channels, widths[0], 1, norm_layer=norm_layer, act_layer=act_layer, **dd)
        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True)
        self.classifier = nn.Sequential(
            nn.Linear(widths[0], widths[1], bias=False, **dd),
            nn.LayerNorm(widths[1], eps=norm_eps, **dd),
            act_layer(inplace=True) if act_layer is not None else nn.Identity(),
            nn.Dropout(dropout, inplace=False),
            nn.Linear(widths[1], num_classes, bias=True, **dd) if num_classes > 0 else nn.Identity(),
        )

    def reset(self, num_classes: int, pool_type: Optional[str] = None):
        """Replace the final classifier layer and optionally the pooling type.

        Args:
            num_classes: New number of classes; ``<= 0`` installs an identity.
            pool_type: New pool type, or None to keep the current pooling.
        """
        if pool_type is not None:
            assert pool_type, 'Cannot disable pooling'
            self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True,)
        if num_classes > 0:
            # Size the new layer from what the preceding Linear actually outputs (widths[1];
            # `num_features` is widths[-1] and only coincides for 2-element widths), and create
            # it on the head's current device/dtype so forward keeps working after the model
            # has been moved or cast.
            hidden = self.classifier[0]
            self.classifier[-1] = nn.Linear(
                hidden.out_features,
                num_classes,
                bias=True,
                device=hidden.weight.device,
                dtype=hidden.weight.dtype,
            )
        else:
            self.classifier[-1] = nn.Identity()

    def forward(self, x, pre_logits: bool = False):
        x = self.in_conv(x)
        x = self.global_pool(x)
        if pre_logits:
            # cannot slice or iterate with torchscript, so apply everything but the final Linear
            # one layer at a time
            x = self.classifier[0](x)
            x = self.classifier[1](x)
            x = self.classifier[2](x)
            x = self.classifier[3](x)
        else:
            x = self.classifier(x)
        return x
class EfficientVit(nn.Module):
    """EfficientViT (B-series) classification backbone.

    Layout: :class:`Stem` (stride 2), then one :class:`EfficientVitStage` per entry of
    ``widths[1:]`` / ``depths[1:]`` (each doubles the stride; the third and later stages are ViT
    stages), then :class:`ClassifierHead`.

    Args:
        in_chans: Number of input image channels.
        widths: Stem width followed by one width per stage.
        depths: Stem depth followed by one depth per stage.
        head_dim: Per-head dimension of LiteMLA in the ViT stages.
        expand_ratio: Expansion ratio of the stage conv blocks.
        norm_layer: Normalization layer class.
        act_layer: Activation layer class.
        global_pool: Pool type of the classifier head.
        head_widths: Hidden widths of the classifier head.
        drop_rate: Dropout before the final classifier Linear.
        num_classes: Number of classes (0 disables the final Linear).
        device: Device for parameter creation.
        dtype: Dtype for parameter creation.
    """

    def __init__(
            self,
            in_chans: int = 3,
            widths: Tuple[int, ...] = (),
            depths: Tuple[int, ...] = (),
            head_dim: int = 32,
            expand_ratio: float = 4,
            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            act_layer: Type[nn.Module] = nn.Hardswish,
            global_pool: str = 'avg',
            head_widths: Tuple[int, ...] = (),
            drop_rate: float = 0.0,
            num_classes: int = 1000,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.grad_checkpointing = False
        self.global_pool = global_pool
        self.num_classes = num_classes
        self.in_chans = in_chans

        # input stem
        self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer, **dd)
        stride = self.stem.stride

        # stages
        self.feature_info = []
        self.stages = nn.Sequential()
        in_channels = widths[0]
        for i, (w, d) in enumerate(zip(widths[1:], depths[1:])):
            self.stages.append(EfficientVitStage(
                in_channels,
                w,
                depth=d,
                norm_layer=norm_layer,
                act_layer=act_layer,
                expand_ratio=expand_ratio,
                head_dim=head_dim,
                vit_stage=i >= 2,
                **dd,
            ))
            stride *= 2
            in_channels = w
            self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{i}')]

        self.num_features = in_channels
        self.head = ClassifierHead(
            self.num_features,
            widths=head_widths,
            num_classes=num_classes,
            dropout=drop_rate,
            pool_type=self.global_pool,
            **dd,
        )
        self.head_hidden_size = self.head.num_features

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        matcher = dict(
            stem=r'^stem',
            blocks=r'^stages\.(\d+)' if coarse else [
                (r'^stages\.(\d+)\.downsample', (0,)),
                (r'^stages\.(\d+)\.\w+\.(\d+)', None),
            ]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head.classifier[-1]

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates (this model has no final norm,
                so it is accepted for API compatibility only)
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs (only 'NCHW' is supported)
            intermediates_only: Only return intermediate features
        Returns:
            The list of selected stage outputs if ``intermediates_only``, otherwise a tuple of
            (output of the last stage run, list of selected stage outputs).
        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stages), indices)

        # forward pass
        x = self.stem(x)

        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.stages
        else:
            stages = self.stages[:max_index + 1]

        for feat_idx, stage in enumerate(stages):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # checkpoint only the current stage; passing all `stages` here would re-run
                # every stage on each loop iteration
                x = checkpoint_seq([stage], x)
            else:
                x = stage(x)
            if feat_idx in take_indices:
                intermediates.append(x)

        if intermediates_only:
            return intermediates

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        self.stages = self.stages[:max_index + 1]  # truncate stages (index 0 is the first stage after the stem)
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=pre_logits)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
class EfficientVitLarge(nn.Module):
    """EfficientViT-Large classification backbone.

    Layout: :class:`Stem` (stride 2, ``block_type='large'``), then one
    :class:`EfficientVitLargeStage` per entry of ``widths[1:]`` / ``depths[1:]`` (each doubles the
    stride; the third and later stages use ``fewer_norm`` and the fourth and later are ViT
    stages), then :class:`ClassifierHead`. ``norm_eps`` is applied to every norm layer and to the
    head's LayerNorm.

    Args:
        in_chans: Number of input image channels.
        widths: Stem width followed by one width per stage.
        depths: Stem depth followed by one depth per stage.
        head_dim: Per-head dimension of LiteMLA in the ViT stages.
        norm_layer: Normalization layer class (wrapped with ``eps=norm_eps``).
        act_layer: Activation layer class.
        global_pool: Pool type of the classifier head.
        head_widths: Hidden widths of the classifier head.
        drop_rate: Dropout before the final classifier Linear.
        num_classes: Number of classes (0 disables the final Linear).
        norm_eps: Epsilon for all normalization layers.
        device: Device for parameter creation.
        dtype: Dtype for parameter creation.
    """

    def __init__(
            self,
            in_chans: int = 3,
            widths: Tuple[int, ...] = (),
            depths: Tuple[int, ...] = (),
            head_dim: int = 32,
            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            act_layer: Type[nn.Module] = GELUTanh,
            global_pool: str = 'avg',
            head_widths: Tuple[int, ...] = (),
            drop_rate: float = 0.0,
            num_classes: int = 1000,
            norm_eps: float = 1e-7,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.grad_checkpointing = False
        self.global_pool = global_pool
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.norm_eps = norm_eps
        norm_layer = partial(norm_layer, eps=self.norm_eps)

        # input stem
        self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer, block_type='large', **dd)
        stride = self.stem.stride

        # stages
        self.feature_info = []
        self.stages = nn.Sequential()
        in_channels = widths[0]
        for i, (w, d) in enumerate(zip(widths[1:], depths[1:])):
            self.stages.append(EfficientVitLargeStage(
                in_channels,
                w,
                depth=d,
                norm_layer=norm_layer,
                act_layer=act_layer,
                head_dim=head_dim,
                vit_stage=i >= 3,
                fewer_norm=i >= 2,
                **dd,
            ))
            stride *= 2
            in_channels = w
            self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{i}')]

        self.num_features = in_channels
        self.head = ClassifierHead(
            self.num_features,
            widths=head_widths,
            num_classes=num_classes,
            dropout=drop_rate,
            pool_type=self.global_pool,
            act_layer=act_layer,
            norm_eps=self.norm_eps,
            **dd,
        )
        self.head_hidden_size = self.head.num_features

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        matcher = dict(
            stem=r'^stem',
            blocks=r'^stages\.(\d+)' if coarse else [
                (r'^stages\.(\d+)\.downsample', (0,)),
                (r'^stages\.(\d+)\.\w+\.(\d+)', None),
            ]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head.classifier[-1]

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates (this model has no final norm,
                so it is accepted for API compatibility only)
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs (only 'NCHW' is supported)
            intermediates_only: Only return intermediate features
        Returns:
            The list of selected stage outputs if ``intermediates_only``, otherwise a tuple of
            (output of the last stage run, list of selected stage outputs).
        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stages), indices)

        # forward pass
        x = self.stem(x)

        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.stages
        else:
            stages = self.stages[:max_index + 1]

        for feat_idx, stage in enumerate(stages):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # checkpoint only the current stage; passing all `stages` here would re-run
                # every stage on each loop iteration
                x = checkpoint_seq([stage], x)
            else:
                x = stage(x)
            if feat_idx in take_indices:
                intermediates.append(x)

        if intermediates_only:
            return intermediates

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        self.stages = self.stages[:max_index + 1]  # truncate stages (index 0 is the first stage after the stem)
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=pre_logits)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
def _cfg(url='', **kwargs):
    """Return a pretrained-config dict with the EfficientViT defaults.

    Keyword arguments override (or add to) the defaults.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='stem.in_conv.conv',
        classifier='head.classifier.4',
        crop_pct=0.95,
        license='apache-2.0',
        input_size=(3, 224, 224),
        pool_size=(7, 7),
    )
    cfg.update(kwargs)
    return cfg
# Pretrained weight configurations keyed by '<model_name>.<tag>'. Entries only override the
# `_cfg` defaults where they differ (larger input/pool sizes for higher-resolution weights).
default_cfgs = generate_default_cfgs({
    'efficientvit_b0.r224_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'efficientvit_b1.r224_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'efficientvit_b1.r256_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
    ),
    'efficientvit_b1.r288_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
    ),
    'efficientvit_b2.r224_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'efficientvit_b2.r256_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
    ),
    'efficientvit_b2.r288_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
    ),
    'efficientvit_b3.r224_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'efficientvit_b3.r256_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
    ),
    'efficientvit_b3.r288_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
    ),
    'efficientvit_l1.r224_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=1.0,
    ),
    'efficientvit_l2.r224_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=1.0,
    ),
    'efficientvit_l2.r256_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
    ),
    'efficientvit_l2.r288_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
    ),
    'efficientvit_l2.r384_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
    ),
    'efficientvit_l3.r224_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=1.0,
    ),
    'efficientvit_l3.r256_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
    ),
    'efficientvit_l3.r320_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0,
    ),
    'efficientvit_l3.r384_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
    ),
    # Disabled SAM backbone configs (no classifier head, 512x512 input):
    # 'efficientvit_l0_sam.sam': _cfg(
    #     # hf_hub_id='timm/',
    #     input_size=(3, 512, 512), crop_pct=1.0,
    #     num_classes=0,
    # ),
    # 'efficientvit_l1_sam.sam': _cfg(
    #     # hf_hub_id='timm/',
    #     input_size=(3, 512, 512), crop_pct=1.0,
    #     num_classes=0,
    # ),
    # 'efficientvit_l2_sam.sam': _cfg(
    #     # hf_hub_id='timm/',
    #     input_size=(3, 512, 512), crop_pct=1.0,
    #     num_classes=0,
    # ),
})
def _create_efficientvit(variant, pretrained=False, **kwargs):
    """Build an ``EfficientVit`` model through ``build_model_with_cfg``.

    ``out_indices`` (default ``(0, 1, 2, 3)``) is popped from ``kwargs`` and
    placed in the feature config; every remaining kwarg is forwarded unchanged.
    """
    feature_cfg = dict(
        flatten_sequential=True,
        out_indices=kwargs.pop('out_indices', (0, 1, 2, 3)),
    )
    return build_model_with_cfg(
        EfficientVit, variant, pretrained, feature_cfg=feature_cfg, **kwargs)
def _create_efficientvit_large(variant, pretrained=False, **kwargs):
    """Build an ``EfficientVitLarge`` model through ``build_model_with_cfg``.

    ``out_indices`` (default ``(0, 1, 2, 3)``) is popped from ``kwargs`` and
    placed in the feature config; every remaining kwarg is forwarded unchanged.
    """
    feature_cfg = dict(
        flatten_sequential=True,
        out_indices=kwargs.pop('out_indices', (0, 1, 2, 3)),
    )
    return build_model_with_cfg(
        EfficientVitLarge, variant, pretrained, feature_cfg=feature_cfg, **kwargs)
@register_model
def efficientvit_b0(pretrained=False, **kwargs):
    """Create EfficientViT-B0; caller kwargs override the architecture defaults below."""
    arch_defaults = {
        'widths': (8, 16, 32, 64, 128),
        'depths': (1, 2, 2, 2, 2),
        'head_dim': 16,
        'head_widths': (1024, 1280),
    }
    return _create_efficientvit(
        'efficientvit_b0', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def efficientvit_b1(pretrained=False, **kwargs):
    """Create EfficientViT-B1; caller kwargs override the architecture defaults below."""
    arch_defaults = {
        'widths': (16, 32, 64, 128, 256),
        'depths': (1, 2, 3, 3, 4),
        'head_dim': 16,
        'head_widths': (1536, 1600),
    }
    return _create_efficientvit(
        'efficientvit_b1', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def efficientvit_b2(pretrained=False, **kwargs):
    """Create EfficientViT-B2; caller kwargs override the architecture defaults below."""
    arch_defaults = {
        'widths': (24, 48, 96, 192, 384),
        'depths': (1, 3, 4, 4, 6),
        'head_dim': 32,
        'head_widths': (2304, 2560),
    }
    return _create_efficientvit(
        'efficientvit_b2', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def efficientvit_b3(pretrained=False, **kwargs):
    """Create EfficientViT-B3; caller kwargs override the architecture defaults below."""
    arch_defaults = {
        'widths': (32, 64, 128, 256, 512),
        'depths': (1, 4, 6, 6, 9),
        'head_dim': 32,
        'head_widths': (2304, 2560),
    }
    return _create_efficientvit(
        'efficientvit_b3', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def efficientvit_l1(pretrained=False, **kwargs):
    """Create EfficientViT-L1 (large variant); caller kwargs override the defaults below."""
    arch_defaults = {
        'widths': (32, 64, 128, 256, 512),
        'depths': (1, 1, 1, 6, 6),
        'head_dim': 32,
        'head_widths': (3072, 3200),
    }
    return _create_efficientvit_large(
        'efficientvit_l1', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def efficientvit_l2(pretrained=False, **kwargs):
    """Create EfficientViT-L2 (large variant); caller kwargs override the defaults below."""
    arch_defaults = {
        'widths': (32, 64, 128, 256, 512),
        'depths': (1, 2, 2, 8, 8),
        'head_dim': 32,
        'head_widths': (3072, 3200),
    }
    return _create_efficientvit_large(
        'efficientvit_l2', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def efficientvit_l3(pretrained=False, **kwargs):
    """Create EfficientViT-L3 (large variant); caller kwargs override the defaults below."""
    arch_defaults = {
        'widths': (64, 128, 256, 512, 1024),
        'depths': (1, 2, 2, 8, 8),
        'head_dim': 32,
        'head_widths': (6144, 6400),
    }
    return _create_efficientvit_large(
        'efficientvit_l3', pretrained=pretrained, **{**arch_defaults, **kwargs})
- # FIXME: re-enable these SAM backbone entrypoints once the pending v2 SAM models are available.
- # @register_model
- # def efficientvit_l0_sam(pretrained=False, **kwargs):
- # # only backbone for segment-anything-model weights
- # model_args = dict(
- # widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 4, 4), head_dim=32, num_classes=0, norm_eps=1e-6)
- # return _create_efficientvit_large('efficientvit_l0_sam', pretrained=pretrained, **dict(model_args, **kwargs))
- #
- #
- # @register_model
- # def efficientvit_l1_sam(pretrained=False, **kwargs):
- # # only backbone for segment-anything-model weights
- # model_args = dict(
- # widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 6, 6), head_dim=32, num_classes=0, norm_eps=1e-6)
- # return _create_efficientvit_large('efficientvit_l1_sam', pretrained=pretrained, **dict(model_args, **kwargs))
- #
- #
- # @register_model
- # def efficientvit_l2_sam(pretrained=False, **kwargs):
- # # only backbone for segment-anything-model weights
- # model_args = dict(
- # widths=(32, 64, 128, 256, 512), depths=(1, 2, 2, 8, 8), head_dim=32, num_classes=0, norm_eps=1e-6)
- # return _create_efficientvit_large('efficientvit_l2_sam', pretrained=pretrained, **dict(model_args, **kwargs))
|