| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826 |
- """ EfficientViT (by MSRA)
- Paper: `EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention`
- - https://arxiv.org/abs/2305.07027
- Adapted from official impl at https://github.com/microsoft/Cream/tree/main/EfficientViT
- """
- __all__ = ['EfficientVitMsra']
- import itertools
- from collections import OrderedDict
- from functools import partial
- from typing import Dict, List, Optional, Tuple, Type, Union
- import torch
- import torch.nn as nn
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert
- from ._builder import build_model_with_cfg
- from ._features import feature_take_indices
- from ._manipulate import checkpoint, checkpoint_seq
- from ._registry import register_model, generate_default_cfgs
- class ConvNorm(torch.nn.Sequential):
- def __init__(
- self,
- in_chs: int,
- out_chs: int,
- ks: int = 1,
- stride: int = 1,
- pad: int = 0,
- dilation: int = 1,
- groups: int = 1,
- bn_weight_init: float = 1,
- device=None,
- dtype=None,
- ):
- dd = {'device': device, 'dtype': dtype}
- super().__init__()
- self.conv = nn.Conv2d(in_chs, out_chs, ks, stride, pad, dilation, groups, bias=False, **dd)
- self.bn = nn.BatchNorm2d(out_chs, **dd)
- torch.nn.init.constant_(self.bn.weight, bn_weight_init)
- @torch.no_grad()
- def fuse(self):
- c, bn = self.conv, self.bn
- w = bn.weight / (bn.running_var + bn.eps)**0.5
- w = c.weight * w[:, None, None, None]
- b = bn.bias - bn.running_mean * bn.weight / \
- (bn.running_var + bn.eps)**0.5
- m = torch.nn.Conv2d(
- w.size(1) * self.conv.groups, w.size(0), w.shape[2:],
- stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups)
- m.weight.data.copy_(w)
- m.bias.data.copy_(b)
- return m
- class NormLinear(torch.nn.Sequential):
- def __init__(
- self,
- in_features: int,
- out_features: int,
- bias: bool = True,
- std: float = 0.02,
- drop: float = 0.,
- device=None,
- dtype=None,
- ):
- dd = {'device': device, 'dtype': dtype}
- super().__init__()
- self.bn = nn.BatchNorm1d(in_features, **dd)
- self.drop = nn.Dropout(drop)
- self.linear = nn.Linear(in_features, out_features, bias=bias, **dd)
- trunc_normal_(self.linear.weight, std=std)
- if self.linear.bias is not None:
- nn.init.constant_(self.linear.bias, 0)
- @torch.no_grad()
- def fuse(self):
- bn, linear = self.bn, self.linear
- w = bn.weight / (bn.running_var + bn.eps)**0.5
- b = bn.bias - self.bn.running_mean * \
- self.bn.weight / (bn.running_var + bn.eps)**0.5
- w = linear.weight * w[None, :]
- if linear.bias is None:
- b = b @ self.linear.weight.T
- else:
- b = (linear.weight @ b[:, None]).view(-1) + self.linear.bias
- m = torch.nn.Linear(w.size(1), w.size(0))
- m.weight.data.copy_(w)
- m.bias.data.copy_(b)
- return m
- class PatchMerging(torch.nn.Module):
- def __init__(
- self,
- dim: int,
- out_dim: int,
- device=None,
- dtype=None,
- ):
- dd = {'device': device, 'dtype': dtype}
- super().__init__()
- hid_dim = int(dim * 4)
- self.conv1 = ConvNorm(dim, hid_dim, 1, 1, 0, **dd)
- self.act = torch.nn.ReLU()
- self.conv2 = ConvNorm(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim, **dd)
- self.se = SqueezeExcite(hid_dim, .25, **dd)
- self.conv3 = ConvNorm(hid_dim, out_dim, 1, 1, 0, **dd)
- def forward(self, x):
- x = self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x))))))
- return x
- class ResidualDrop(torch.nn.Module):
- def __init__(self, m: nn.Module, drop: float = 0.):
- super().__init__()
- self.m = m
- self.drop = drop
- def forward(self, x):
- if self.training and self.drop > 0:
- return x + self.m(x) * torch.rand(
- x.size(0), 1, 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach()
- else:
- return x + self.m(x)
- class ConvMlp(torch.nn.Module):
- def __init__(
- self,
- ed: int,
- h: int,
- device=None,
- dtype=None,
- ):
- dd = {'device': device, 'dtype': dtype}
- super().__init__()
- self.pw1 = ConvNorm(ed, h, **dd)
- self.act = torch.nn.ReLU()
- self.pw2 = ConvNorm(h, ed, bn_weight_init=0, **dd)
- def forward(self, x):
- x = self.pw2(self.act(self.pw1(x)))
- return x
- class CascadedGroupAttention(torch.nn.Module):
- attention_bias_cache: Dict[str, torch.Tensor]
- r""" Cascaded Group Attention.
- Args:
- dim (int): Number of input channels.
- key_dim (int): The dimension for query and key.
- num_heads (int): Number of attention heads.
- attn_ratio (int): Multiplier for the query dim for value dimension.
- resolution (int): Input resolution, correspond to the window size.
- kernels (List[int]): The kernel size of the dw conv on query.
- """
- def __init__(
- self,
- dim: int,
- key_dim: int,
- num_heads: int = 8,
- attn_ratio: int = 4,
- resolution: int = 14,
- kernels: Tuple[int, ...] = (5, 5, 5, 5),
- device=None,
- dtype=None,
- ):
- dd = {'device': device, 'dtype': dtype}
- super().__init__()
- self.num_heads = num_heads
- self.scale = key_dim ** -0.5
- self.key_dim = key_dim
- self.val_dim = int(attn_ratio * key_dim)
- self.attn_ratio = attn_ratio
- qkvs = []
- dws = []
- for i in range(num_heads):
- qkvs.append(ConvNorm(dim // num_heads, self.key_dim * 2 + self.val_dim, **dd))
- dws.append(ConvNorm(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim, **dd))
- self.qkvs = torch.nn.ModuleList(qkvs)
- self.dws = torch.nn.ModuleList(dws)
- self.proj = torch.nn.Sequential(
- torch.nn.ReLU(),
- ConvNorm(self.val_dim * num_heads, dim, bn_weight_init=0, **dd)
- )
- self.resolution = resolution
- N = resolution * resolution
- # Number of unique offsets: abs differences range from 0 to resolution-1 for each dim
- num_offsets = resolution * resolution
- self.attention_biases = torch.nn.Parameter(torch.empty(num_heads, num_offsets, **dd))
- self.register_buffer(
- 'attention_bias_idxs',
- torch.empty((N, N), device=device, dtype=torch.long),
- persistent=False,
- )
- self.attention_bias_cache = {}
- # TODO: skip init when on meta device when safe to do so
- self.reset_parameters()
- def reset_parameters(self) -> None:
- """Initialize parameters and buffers."""
- torch.nn.init.zeros_(self.attention_biases)
- self._init_buffers()
- def _init_buffers(self) -> None:
- """Compute and fill non-persistent buffer values."""
- points = list(itertools.product(range(self.resolution), range(self.resolution)))
- attention_offsets = {}
- idxs = []
- for p1 in points:
- for p2 in points:
- offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
- if offset not in attention_offsets:
- attention_offsets[offset] = len(attention_offsets)
- idxs.append(attention_offsets[offset])
- self.attention_bias_idxs.copy_(torch.tensor(idxs, dtype=torch.long).view(len(points), len(points)))
- def init_non_persistent_buffers(self) -> None:
- """Initialize non-persistent buffers."""
- self._init_buffers()
- @torch.no_grad()
- def train(self, mode=True):
- super().train(mode)
- if mode and self.attention_bias_cache:
- self.attention_bias_cache = {} # clear ab cache
- def get_attention_biases(self, device: torch.device) -> torch.Tensor:
- if torch.jit.is_tracing() or self.training:
- return self.attention_biases[:, self.attention_bias_idxs]
- else:
- device_key = str(device)
- if device_key not in self.attention_bias_cache:
- self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
- return self.attention_bias_cache[device_key]
- def forward(self, x):
- B, C, H, W = x.shape
- feats_in = x.chunk(len(self.qkvs), dim=1)
- feats_out = []
- feat = feats_in[0]
- attn_bias = self.get_attention_biases(x.device)
- for head_idx, (qkv, dws) in enumerate(zip(self.qkvs, self.dws)):
- if head_idx > 0:
- feat = feat + feats_in[head_idx]
- feat = qkv(feat)
- q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.val_dim], dim=1)
- q = dws(q)
- q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
- q = q * self.scale
- attn = q.transpose(-2, -1) @ k
- attn = attn + attn_bias[head_idx]
- attn = attn.softmax(dim=-1)
- feat = v @ attn.transpose(-2, -1)
- feat = feat.view(B, self.val_dim, H, W)
- feats_out.append(feat)
- x = self.proj(torch.cat(feats_out, 1))
- return x
- class LocalWindowAttention(torch.nn.Module):
- r""" Local Window Attention.
- Args:
- dim (int): Number of input channels.
- key_dim (int): The dimension for query and key.
- num_heads (int): Number of attention heads.
- attn_ratio (int): Multiplier for the query dim for value dimension.
- resolution (int): Input resolution.
- window_resolution (int): Local window resolution.
- kernels (List[int]): The kernel size of the dw conv on query.
- """
- def __init__(
- self,
- dim: int,
- key_dim: int,
- num_heads: int = 8,
- attn_ratio: int = 4,
- resolution: int = 14,
- window_resolution: int = 7,
- kernels: Tuple[int, ...] = (5, 5, 5, 5),
- device=None,
- dtype=None,
- ):
- dd = {'device': device, 'dtype': dtype}
- super().__init__()
- self.dim = dim
- self.num_heads = num_heads
- self.resolution = resolution
- assert window_resolution > 0, 'window_size must be greater than 0'
- self.window_resolution = window_resolution
- window_resolution = min(window_resolution, resolution)
- self.attn = CascadedGroupAttention(
- dim, key_dim, num_heads,
- attn_ratio=attn_ratio,
- resolution=window_resolution,
- kernels=kernels,
- **dd,
- )
- def forward(self, x):
- H = W = self.resolution
- B, C, H_, W_ = x.shape
- # Only check this for classification models
- _assert(H == H_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}')
- _assert(W == W_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}')
- if H <= self.window_resolution and W <= self.window_resolution:
- x = self.attn(x)
- else:
- x = x.permute(0, 2, 3, 1)
- pad_b = (self.window_resolution - H % self.window_resolution) % self.window_resolution
- pad_r = (self.window_resolution - W % self.window_resolution) % self.window_resolution
- x = torch.nn.functional.pad(x, (0, 0, 0, pad_r, 0, pad_b))
- pH, pW = H + pad_b, W + pad_r
- nH = pH // self.window_resolution
- nW = pW // self.window_resolution
- # window partition, BHWC -> B(nHh)(nWw)C -> BnHnWhwC -> (BnHnW)hwC -> (BnHnW)Chw
- x = x.view(B, nH, self.window_resolution, nW, self.window_resolution, C).transpose(2, 3)
- x = x.reshape(B * nH * nW, self.window_resolution, self.window_resolution, C).permute(0, 3, 1, 2)
- x = self.attn(x)
- # window reverse, (BnHnW)Chw -> (BnHnW)hwC -> BnHnWhwC -> B(nHh)(nWw)C -> BHWC
- x = x.permute(0, 2, 3, 1).view(B, nH, nW, self.window_resolution, self.window_resolution, C)
- x = x.transpose(2, 3).reshape(B, pH, pW, C)
- x = x[:, :H, :W].contiguous()
- x = x.permute(0, 3, 1, 2)
- return x
- class EfficientVitBlock(torch.nn.Module):
- """ A basic EfficientVit building block.
- Args:
- dim (int): Number of input channels.
- key_dim (int): Dimension for query and key in the token mixer.
- num_heads (int): Number of attention heads.
- attn_ratio (int): Multiplier for the query dim for value dimension.
- resolution (int): Input resolution.
- window_resolution (int): Local window resolution.
- kernels (List[int]): The kernel size of the dw conv on query.
- """
- def __init__(
- self,
- dim: int,
- key_dim: int,
- num_heads: int = 8,
- attn_ratio: int = 4,
- resolution: int = 14,
- window_resolution: int = 7,
- kernels: List[int] = [5, 5, 5, 5],
- device=None,
- dtype=None,
- ):
- dd = {'device': device, 'dtype': dtype}
- super().__init__()
- self.dw0 = ResidualDrop(ConvNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0., **dd))
- self.ffn0 = ResidualDrop(ConvMlp(dim, int(dim * 2), **dd))
- self.mixer = ResidualDrop(
- LocalWindowAttention(
- dim, key_dim, num_heads,
- attn_ratio=attn_ratio,
- resolution=resolution,
- window_resolution=window_resolution,
- kernels=kernels,
- **dd,
- ),
- )
- self.dw1 = ResidualDrop(ConvNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0., **dd))
- self.ffn1 = ResidualDrop(ConvMlp(dim, int(dim * 2), **dd))
- def forward(self, x):
- return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x)))))
- class EfficientVitStage(torch.nn.Module):
- def __init__(
- self,
- in_dim: int,
- out_dim: int,
- key_dim: int,
- downsample: Tuple[str, int] = ('', 1),
- num_heads: int = 8,
- attn_ratio: int = 4,
- resolution: int = 14,
- window_resolution: int = 7,
- kernels: List[int] = [5, 5, 5, 5],
- depth: int = 1,
- device=None,
- dtype=None,
- ):
- dd = {'device': device, 'dtype': dtype}
- super().__init__()
- if downsample[0] == 'subsample':
- self.resolution = (resolution - 1) // downsample[1] + 1
- down_blocks = []
- down_blocks.append((
- 'res1',
- torch.nn.Sequential(
- ResidualDrop(ConvNorm(in_dim, in_dim, 3, 1, 1, groups=in_dim, **dd)),
- ResidualDrop(ConvMlp(in_dim, int(in_dim * 2), **dd)),
- )
- ))
- down_blocks.append(('patchmerge', PatchMerging(in_dim, out_dim, **dd)))
- down_blocks.append((
- 'res2',
- torch.nn.Sequential(
- ResidualDrop(ConvNorm(out_dim, out_dim, 3, 1, 1, groups=out_dim, **dd)),
- ResidualDrop(ConvMlp(out_dim, int(out_dim * 2), **dd)),
- )
- ))
- self.downsample = nn.Sequential(OrderedDict(down_blocks))
- else:
- assert in_dim == out_dim
- self.downsample = nn.Identity()
- self.resolution = resolution
- blocks = []
- for d in range(depth):
- blocks.append(EfficientVitBlock(
- out_dim,
- key_dim,
- num_heads,
- attn_ratio,
- self.resolution,
- window_resolution,
- kernels,
- **dd,
- ))
- self.blocks = nn.Sequential(*blocks)
- def forward(self, x):
- x = self.downsample(x)
- x = self.blocks(x)
- return x
- class PatchEmbedding(torch.nn.Sequential):
- def __init__(
- self,
- in_chans: int,
- dim: int,
- device=None,
- dtype=None,
- ):
- super().__init__()
- dd = {'device': device, 'dtype': dtype}
- self.add_module('conv1', ConvNorm(in_chans, dim // 8, 3, 2, 1, **dd))
- self.add_module('relu1', torch.nn.ReLU())
- self.add_module('conv2', ConvNorm(dim // 8, dim // 4, 3, 2, 1, **dd))
- self.add_module('relu2', torch.nn.ReLU())
- self.add_module('conv3', ConvNorm(dim // 4, dim // 2, 3, 2, 1, **dd))
- self.add_module('relu3', torch.nn.ReLU())
- self.add_module('conv4', ConvNorm(dim // 2, dim, 3, 2, 1, **dd))
- self.patch_size = 16
- class EfficientVitMsra(nn.Module):
- def __init__(
- self,
- img_size: int = 224,
- in_chans: int = 3,
- num_classes: int = 1000,
- embed_dim: Tuple[int, ...] = (64, 128, 192),
- key_dim: Tuple[int, ...] = (16, 16, 16),
- depth: Tuple[int, ...] = (1, 2, 3),
- num_heads: Tuple[int, ...] = (4, 4, 4),
- window_size: Tuple[int, ...] = (7, 7, 7),
- kernels: Tuple[int, ...] = (5, 5, 5, 5),
- down_ops: Tuple[Tuple[str, int], ...] = (('', 1), ('subsample', 2), ('subsample', 2)),
- global_pool: str = 'avg',
- drop_rate: float = 0.,
- device=None,
- dtype=None,
- ):
- super().__init__()
- dd = {'device': device, 'dtype': dtype}
- self.grad_checkpointing = False
- self.num_classes = num_classes
- self.in_chans = in_chans
- self.drop_rate = drop_rate
- # Patch embedding
- self.patch_embed = PatchEmbedding(in_chans, embed_dim[0], **dd)
- stride = self.patch_embed.patch_size
- resolution = img_size // self.patch_embed.patch_size
- attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))]
- # Build EfficientVit blocks
- self.feature_info = []
- stages = []
- pre_ed = embed_dim[0]
- for i, (ed, kd, dpth, nh, ar, wd, do) in enumerate(
- zip(embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)):
- stage = EfficientVitStage(
- in_dim=pre_ed,
- out_dim=ed,
- key_dim=kd,
- downsample=do,
- num_heads=nh,
- attn_ratio=ar,
- resolution=resolution,
- window_resolution=wd,
- kernels=kernels,
- depth=dpth,
- **dd,
- )
- pre_ed = ed
- if do[0] == 'subsample' and i != 0:
- stride *= do[1]
- resolution = stage.resolution
- stages.append(stage)
- self.feature_info += [dict(num_chs=ed, reduction=stride, module=f'stages.{i}')]
- self.stages = nn.Sequential(*stages)
- if global_pool == 'avg':
- self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
- else:
- assert num_classes == 0
- self.global_pool = nn.Identity()
- self.num_features = self.head_hidden_size = embed_dim[-1]
- self.head = NormLinear(
- self.num_features, num_classes, drop=self.drop_rate, **dd) if num_classes > 0 else torch.nn.Identity()
- # TODO: skip init when on meta device when safe to do so
- self.init_weights(needs_reset=False)
- def init_weights(self, needs_reset: bool = True):
- self.apply(partial(self._init_weights, needs_reset=needs_reset))
- def _init_weights(self, m: nn.Module, needs_reset: bool = True) -> None:
- if needs_reset and hasattr(m, 'reset_parameters'):
- m.reset_parameters()
- @torch.jit.ignore
- def no_weight_decay(self):
- return {x for x in self.state_dict().keys() if 'attention_biases' in x}
- @torch.jit.ignore
- def group_matcher(self, coarse=False):
- matcher = dict(
- stem=r'^patch_embed',
- blocks=r'^stages\.(\d+)' if coarse else [
- (r'^stages\.(\d+).downsample', (0,)),
- (r'^stages\.(\d+)\.\w+\.(\d+)', None),
- ]
- )
- return matcher
- @torch.jit.ignore
- def set_grad_checkpointing(self, enable=True):
- self.grad_checkpointing = enable
- @torch.jit.ignore
- def get_classifier(self) -> nn.Module:
- return self.head.linear
- def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
- self.num_classes = num_classes
- if global_pool is not None:
- if global_pool == 'avg':
- self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
- else:
- assert num_classes == 0
- self.global_pool = nn.Identity()
- self.head = NormLinear(
- self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity()
- def forward_intermediates(
- self,
- x: torch.Tensor,
- indices: Optional[Union[int, List[int]]] = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
- ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
- """ Forward features that returns intermediates.
- Args:
- x: Input image tensor
- indices: Take last n blocks if int, all if None, select matching indices if sequence
- norm: Apply norm layer to compatible intermediates
- stop_early: Stop iterating over blocks when last desired intermediate hit
- output_fmt: Shape of intermediate feature outputs
- intermediates_only: Only return intermediate features
- Returns:
- """
- assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
- intermediates = []
- take_indices, max_index = feature_take_indices(len(self.stages), indices)
- # forward pass
- x = self.patch_embed(x)
- if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
- stages = self.stages
- else:
- stages = self.stages[:max_index + 1]
- for feat_idx, stage in enumerate(stages):
- if self.grad_checkpointing and not torch.jit.is_scripting():
- x = checkpoint(stage, x)
- else:
- x = stage(x)
- if feat_idx in take_indices:
- intermediates.append(x)
- if intermediates_only:
- return intermediates
- return x, intermediates
- def prune_intermediate_layers(
- self,
- indices: Union[int, List[int]] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
- ):
- """ Prune layers not required for specified intermediates.
- """
- take_indices, max_index = feature_take_indices(len(self.stages), indices)
- self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
- if prune_head:
- self.reset_classifier(0, '')
- return take_indices
- def forward_features(self, x):
- x = self.patch_embed(x)
- if self.grad_checkpointing and not torch.jit.is_scripting():
- x = checkpoint_seq(self.stages, x)
- else:
- x = self.stages(x)
- return x
- def forward_head(self, x, pre_logits: bool = False):
- x = self.global_pool(x)
- return x if pre_logits else self.head(x)
- def forward(self, x):
- x = self.forward_features(x)
- x = self.forward_head(x)
- return x
- # def checkpoint_filter_fn(state_dict, model):
- # if 'model' in state_dict.keys():
- # state_dict = state_dict['model']
- # tmp_dict = {}
- # out_dict = {}
- # target_keys = model.state_dict().keys()
- # target_keys = [k for k in target_keys if k.startswith('stages.')]
- #
- # for k, v in state_dict.items():
- # if 'attention_bias_idxs' in k:
- # continue
- # k = k.split('.')
- # if k[-2] == 'c':
- # k[-2] = 'conv'
- # if k[-2] == 'l':
- # k[-2] = 'linear'
- # k = '.'.join(k)
- # tmp_dict[k] = v
- #
- # for k, v in tmp_dict.items():
- # if k.startswith('patch_embed'):
- # k = k.split('.')
- # k[1] = 'conv' + str(int(k[1]) // 2 + 1)
- # k = '.'.join(k)
- # elif k.startswith('blocks'):
- # kw = '.'.join(k.split('.')[2:])
- # find_kw = [a for a in list(sorted(tmp_dict.keys())) if kw in a]
- # idx = find_kw.index(k)
- # k = [a for a in target_keys if kw in a][idx]
- # out_dict[k] = v
- #
- # return out_dict
- def _cfg(url='', **kwargs):
- return {
- 'url': url,
- 'num_classes': 1000,
- 'mean': IMAGENET_DEFAULT_MEAN,
- 'std': IMAGENET_DEFAULT_STD,
- 'first_conv': 'patch_embed.conv1.conv',
- 'classifier': 'head.linear',
- 'fixed_input_size': True,
- 'pool_size': (4, 4),
- 'license': 'mit',
- **kwargs,
- }
- default_cfgs = generate_default_cfgs({
- 'efficientvit_m0.r224_in1k': _cfg(
- hf_hub_id='timm/',
- #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m0.pth'
- ),
- 'efficientvit_m1.r224_in1k': _cfg(
- hf_hub_id='timm/',
- #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m1.pth'
- ),
- 'efficientvit_m2.r224_in1k': _cfg(
- hf_hub_id='timm/',
- #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m2.pth'
- ),
- 'efficientvit_m3.r224_in1k': _cfg(
- hf_hub_id='timm/',
- #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m3.pth'
- ),
- 'efficientvit_m4.r224_in1k': _cfg(
- hf_hub_id='timm/',
- #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m4.pth'
- ),
- 'efficientvit_m5.r224_in1k': _cfg(
- hf_hub_id='timm/',
- #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m5.pth'
- ),
- })
- def _create_efficientvit_msra(variant, pretrained=False, **kwargs):
- out_indices = kwargs.pop('out_indices', (0, 1, 2))
- model = build_model_with_cfg(
- EfficientVitMsra,
- variant,
- pretrained,
- feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
- **kwargs
- )
- return model
- @register_model
- def efficientvit_m0(pretrained=False, **kwargs):
- model_args = dict(
- img_size=224,
- embed_dim=[64, 128, 192],
- depth=[1, 2, 3],
- num_heads=[4, 4, 4],
- window_size=[7, 7, 7],
- kernels=[5, 5, 5, 5]
- )
- return _create_efficientvit_msra('efficientvit_m0', pretrained=pretrained, **dict(model_args, **kwargs))
- @register_model
- def efficientvit_m1(pretrained=False, **kwargs):
- model_args = dict(
- img_size=224,
- embed_dim=[128, 144, 192],
- depth=[1, 2, 3],
- num_heads=[2, 3, 3],
- window_size=[7, 7, 7],
- kernels=[7, 5, 3, 3]
- )
- return _create_efficientvit_msra('efficientvit_m1', pretrained=pretrained, **dict(model_args, **kwargs))
- @register_model
- def efficientvit_m2(pretrained=False, **kwargs):
- model_args = dict(
- img_size=224,
- embed_dim=[128, 192, 224],
- depth=[1, 2, 3],
- num_heads=[4, 3, 2],
- window_size=[7, 7, 7],
- kernels=[7, 5, 3, 3]
- )
- return _create_efficientvit_msra('efficientvit_m2', pretrained=pretrained, **dict(model_args, **kwargs))
- @register_model
- def efficientvit_m3(pretrained=False, **kwargs):
- model_args = dict(
- img_size=224,
- embed_dim=[128, 240, 320],
- depth=[1, 2, 3],
- num_heads=[4, 3, 4],
- window_size=[7, 7, 7],
- kernels=[5, 5, 5, 5]
- )
- return _create_efficientvit_msra('efficientvit_m3', pretrained=pretrained, **dict(model_args, **kwargs))
- @register_model
- def efficientvit_m4(pretrained=False, **kwargs):
- model_args = dict(
- img_size=224,
- embed_dim=[128, 256, 384],
- depth=[1, 2, 3],
- num_heads=[4, 4, 4],
- window_size=[7, 7, 7],
- kernels=[7, 5, 3, 3]
- )
- return _create_efficientvit_msra('efficientvit_m4', pretrained=pretrained, **dict(model_args, **kwargs))
- @register_model
- def efficientvit_m5(pretrained=False, **kwargs):
- model_args = dict(
- img_size=224,
- embed_dim=[192, 288, 384],
- depth=[1, 3, 4],
- num_heads=[3, 3, 4],
- window_size=[7, 7, 7],
- kernels=[7, 5, 3, 3]
- )
- return _create_efficientvit_msra('efficientvit_m5', pretrained=pretrained, **dict(model_args, **kwargs))
|