| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807 |
- # FastViT for PyTorch
- #
- # Original implementation and weights from https://github.com/apple/ml-fastvit
- #
- # For licensing see accompanying LICENSE file at https://github.com/apple/ml-fastvit/tree/main
- # Original work is copyright (C) 2023 Apple Inc. All Rights Reserved.
- #
- import os
- from functools import partial
- from typing import List, Optional, Tuple, Type, Union
- import torch
- import torch.nn as nn
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
- from timm.layers import (
- DropPath,
- calculate_drop_path_rates,
- trunc_normal_,
- create_conv2d,
- ConvNormAct,
- SqueezeExcite,
- use_fused_attn,
- ClassifierHead,
- LayerNorm2d,
- )
- from ._builder import build_model_with_cfg
- from ._features import feature_take_indices
- from ._manipulate import checkpoint_seq
- from ._registry import register_model, generate_default_cfgs
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ['FastVit']
def num_groups(group_size, channels):
    """Translate a per-group channel count into a conv ``groups`` argument.

    Args:
        group_size: Channels per group. A falsy value (``0`` or ``None``) means an
            ordinary, ungrouped convolution; ``1`` means depthwise.
        channels: Total number of input channels; must be divisible by ``group_size``.

    Returns:
        Number of groups to pass to the convolution.
    """
    if group_size:
        # group_size == 1 -> one channel per group, i.e. a depthwise conv
        assert channels % group_size == 0
        return channels // group_size
    # 0 / None -> a single group (regular conv)
    return 1
class MobileOneBlock(nn.Module):
    """MobileOne building block.

    This block has a multi-branched architecture at train-time
    and plain-CNN style architecture at inference time.
    For more details, please refer to our paper:
    `An Improved One millisecond Mobile Backbone` -
    https://arxiv.org/pdf/2206.04040.pdf
    """

    def __init__(
        self,
        in_chs: int,
        out_chs: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        group_size: int = 0,
        inference_mode: bool = False,
        use_se: bool = False,
        use_act: bool = True,
        use_scale_branch: bool = True,
        num_conv_branches: int = 1,
        act_layer: Type[nn.Module] = nn.GELU,
        device=None,
        dtype=None,
    ) -> None:
        """Construct a MobileOneBlock module.

        Args:
            in_chs: Number of channels in the input.
            out_chs: Number of channels produced by the block.
            kernel_size: Size of the convolution kernel.
            stride: Stride size.
            dilation: Kernel dilation factor. Applied to the kxk train-time branches
                and to the fused conv so both produce the same output.
            group_size: Convolution group size (``0``/``None`` -> 1 group, ``1`` -> depthwise).
            inference_mode: If True, instantiates model in inference mode.
            use_se: Whether to use SE-ReLU activations.
            use_act: Whether to use activation. Default: ``True``
            use_scale_branch: Whether to use scale branch. Default: ``True``
            num_conv_branches: Number of linear conv branches.
            act_layer: Activation layer class, instantiated when ``use_act`` is True.
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.inference_mode = inference_mode
        self.groups = num_groups(group_size, in_chs)
        self.stride = stride
        self.dilation = dilation
        self.kernel_size = kernel_size
        self.in_chs = in_chs
        self.out_chs = out_chs
        self.num_conv_branches = num_conv_branches

        # Check if SE-ReLU is requested
        self.se = SqueezeExcite(out_chs, rd_divisor=1, **dd) if use_se else nn.Identity()

        if inference_mode:
            self.reparam_conv = create_conv2d(
                in_chs,
                out_chs,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                groups=self.groups,
                bias=True,
                **dd,
            )
        else:
            # Re-parameterizable skip connection (only possible when the shape is preserved)
            self.reparam_conv = None
            self.identity = (
                nn.BatchNorm2d(num_features=in_chs, **dd)
                if out_chs == in_chs and stride == 1
                else None
            )

            # Re-parameterizable conv branches. They must use the same dilation as the
            # fused conv built in `reparameterize`, otherwise fusion changes the output.
            if num_conv_branches > 0:
                self.conv_kxk = nn.ModuleList([
                    ConvNormAct(
                        self.in_chs,
                        self.out_chs,
                        kernel_size=kernel_size,
                        stride=self.stride,
                        dilation=self.dilation,
                        groups=self.groups,
                        apply_act=False,
                        **dd,
                    ) for _ in range(self.num_conv_branches)
                ])
            else:
                self.conv_kxk = None

            # Re-parameterizable scale branch (1x1, padded to kxk when fused)
            self.conv_scale = None
            if kernel_size > 1 and use_scale_branch:
                self.conv_scale = ConvNormAct(
                    self.in_chs,
                    self.out_chs,
                    kernel_size=1,
                    stride=self.stride,
                    groups=self.groups,
                    apply_act=False,
                    **dd,
                )

        self.act = act_layer() if use_act else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward pass."""
        # Inference mode forward pass.
        if self.reparam_conv is not None:
            return self.act(self.se(self.reparam_conv(x)))

        # Multi-branched train-time forward pass.
        # Identity branch output
        identity_out = 0
        if self.identity is not None:
            identity_out = self.identity(x)

        # Scale branch output
        scale_out = 0
        if self.conv_scale is not None:
            scale_out = self.conv_scale(x)

        # Other kxk conv branches
        out = scale_out + identity_out
        if self.conv_kxk is not None:
            for rc in self.conv_kxk:
                out += rc(x)

        return self.act(self.se(out))

    def reparameterize(self):
        """Fuse all train-time branches into a single conv for inference.

        Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
        architecture used at training time to obtain a plain CNN-like structure
        for inference. Calling this again after fusion is a no-op.
        """
        if self.reparam_conv is not None:
            return

        kernel, bias = self._get_kernel_bias()
        # Build the fused conv directly on the device/dtype of the fused weights.
        self.reparam_conv = create_conv2d(
            in_channels=self.in_chs,
            out_channels=self.out_chs,
            kernel_size=self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            groups=self.groups,
            bias=True,
            device=kernel.device,
            dtype=kernel.dtype,
        )
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias

        # Delete un-used branches
        for name, para in self.named_parameters():
            if 'reparam_conv' in name:
                continue
            para.detach_()
        self.__delattr__("conv_kxk")
        self.__delattr__("conv_scale")
        if hasattr(self, "identity"):
            self.__delattr__("identity")

        self.inference_mode = True

    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to obtain re-parameterized kernel and bias.

        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83

        Returns:
            Tuple of (kernel, bias) after fusing branches.
        """
        # get weights and bias of scale branch
        kernel_scale = 0
        bias_scale = 0
        if self.conv_scale is not None:
            kernel_scale, bias_scale = self._fuse_bn_tensor(self.conv_scale)
            # Pad scale branch kernel to match conv branch kernel size.
            pad = self.kernel_size // 2
            kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])

        # get weights and bias of skip branch
        kernel_identity = 0
        bias_identity = 0
        if self.identity is not None:
            kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)

        # get weights and bias of conv branches
        kernel_conv = 0
        bias_conv = 0
        if self.conv_kxk is not None:
            for ix in range(self.num_conv_branches):
                _kernel, _bias = self._fuse_bn_tensor(self.conv_kxk[ix])
                kernel_conv += _kernel
                bias_conv += _bias

        kernel_final = kernel_conv + kernel_scale + kernel_identity
        bias_final = bias_conv + bias_scale + bias_identity
        return kernel_final, bias_final

    def _fuse_bn_tensor(
            self,
            branch: Union[ConvNormAct, nn.BatchNorm2d]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to fuse batchnorm layer with preceding conv layer.

        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95

        Args:
            branch: A conv+bn branch, or a bare BatchNorm2d for the identity branch.

        Returns:
            Tuple of (kernel, bias) after fusing batchnorm.
        """
        if isinstance(branch, ConvNormAct):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, "id_tensor"):
                # Identity expressed as a kxk (grouped) kernel with a 1 at the centre tap.
                input_dim = self.in_chs // self.groups
                kernel_value = torch.zeros(
                    (self.in_chs, input_dim, self.kernel_size, self.kernel_size),
                    dtype=branch.weight.dtype,
                    device=branch.weight.device,
                )
                for i in range(self.in_chs):
                    kernel_value[
                        i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2
                    ] = 1
                self.id_tensor = kernel_value
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std
class ReparamLargeKernelConv(nn.Module):
    """Building Block of RepLKNet

    This class defines overparameterized large kernel conv block
    introduced in `RepLKNet <https://arxiv.org/abs/2203.06717>`_

    Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
    """

    def __init__(
        self,
        in_chs: int,
        out_chs: int,
        kernel_size: int,
        stride: int,
        group_size: int,
        small_kernel: Optional[int] = None,
        use_se: bool = False,
        act_layer: Optional[Type[nn.Module]] = None,
        inference_mode: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        """Construct a ReparamLargeKernelConv module.

        Args:
            in_chs: Number of input channels.
            out_chs: Number of output channels.
            kernel_size: Kernel size of the large kernel conv branch.
            stride: Stride size.
            group_size: Group size (``0``/``None`` -> 1 group, ``1`` -> depthwise).
            small_kernel: Kernel size of the optional small kernel conv branch.
                ``None`` disables that branch. Must not exceed ``kernel_size``.
            use_se: Whether to apply SqueezeExcite after the branch sum.
            act_layer: Activation layer class. ``None`` means no activation.
            inference_mode: If True, instantiates model in inference mode. Default: ``False``
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.stride = stride
        self.groups = num_groups(group_size, in_chs)
        self.in_chs = in_chs
        self.out_chs = out_chs

        self.kernel_size = kernel_size
        self.small_kernel = small_kernel
        if inference_mode:
            self.reparam_conv = create_conv2d(
                in_chs,
                out_chs,
                kernel_size=kernel_size,
                stride=stride,
                dilation=1,
                groups=self.groups,
                bias=True,
                **dd,
            )
        else:
            self.reparam_conv = None
            self.large_conv = ConvNormAct(
                in_chs,
                out_chs,
                kernel_size=kernel_size,
                stride=self.stride,
                groups=self.groups,
                apply_act=False,
                **dd,
            )
            # Always define `small_conv` so forward() can test it; it stays None
            # when no small kernel branch is requested.
            self.small_conv = None
            if small_kernel is not None:
                assert (
                    small_kernel <= kernel_size
                ), "The kernel size for re-param cannot be larger than the large kernel!"
                self.small_conv = ConvNormAct(
                    in_chs,
                    out_chs,
                    kernel_size=small_kernel,
                    stride=self.stride,
                    groups=self.groups,
                    apply_act=False,
                    **dd,
                )

        self.se = SqueezeExcite(out_chs, rd_ratio=0.25, **dd) if use_se else nn.Identity()
        # FIXME output of this act was not used in original impl, likely due to bug
        self.act = act_layer() if act_layer is not None else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.reparam_conv is not None:
            out = self.reparam_conv(x)
        else:
            out = self.large_conv(x)
            if self.small_conv is not None:
                out = out + self.small_conv(x)
        out = self.se(out)
        out = self.act(out)
        return out

    def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to obtain re-parameterized kernel and bias.

        Reference: https://github.com/DingXiaoH/RepLKNet-pytorch

        Returns:
            Tuple of (kernel, bias) after fusing branches.
        """
        eq_k, eq_b = self._fuse_bn(self.large_conv.conv, self.large_conv.bn)
        small_conv = getattr(self, "small_conv", None)
        if small_conv is not None:
            small_k, small_b = self._fuse_bn(small_conv.conv, small_conv.bn)
            eq_b = eq_b + small_b
            # Centre the small kernel inside the large one.
            pad = (self.kernel_size - self.small_kernel) // 2
            eq_k = eq_k + nn.functional.pad(small_k, [pad] * 4)
        return eq_k, eq_b

    def reparameterize(self) -> None:
        """Fuse the large and small kernel branches into one conv for inference.

        Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
        architecture used at training time to obtain a plain CNN-like structure
        for inference. Calling this again after fusion is a no-op.
        """
        if self.reparam_conv is not None:
            return

        eq_k, eq_b = self.get_kernel_bias()
        self.reparam_conv = create_conv2d(
            self.in_chs,
            self.out_chs,
            kernel_size=self.kernel_size,
            stride=self.stride,
            groups=self.groups,
            bias=True,
            device=eq_k.device,
            dtype=eq_k.dtype,
        )
        self.reparam_conv.weight.data = eq_k
        self.reparam_conv.bias.data = eq_b
        self.__delattr__("large_conv")
        if hasattr(self, "small_conv"):
            self.__delattr__("small_conv")

    @staticmethod
    def _fuse_bn(
            conv: nn.Conv2d,
            bn: nn.BatchNorm2d
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to fuse batchnorm layer with conv layer.

        Args:
            conv: Convolution layer whose weight is scaled.
            bn: Batchnorm 2d layer.

        Returns:
            Tuple of (kernel, bias) after fusing batchnorm.
        """
        kernel = conv.weight
        running_mean = bn.running_mean
        running_var = bn.running_var
        gamma = bn.weight
        beta = bn.bias
        eps = bn.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std
def convolutional_stem(
        in_chs: int,
        out_chs: int,
        act_layer: Type[nn.Module] = nn.GELU,
        inference_mode: bool = False,
        use_scale_branch: bool = True,
        device=None,
        dtype=None,
) -> nn.Sequential:
    """Create the three-block MobileOne stem.

    The stem is a dense 3x3 stride-2 block, then a depthwise 3x3 stride-2 block,
    then a 1x1 stride-1 block. The overall stride is therefore 4.

    Args:
        in_chs: Number of input channels.
        out_chs: Number of output channels (used by every block).
        act_layer: Activation layer class passed to each block.
        inference_mode: Flag to instantiate model in inference mode. Default: ``False``
        use_scale_branch: Whether each block has the 1x1 scale branch.
        device: Device for created parameters.
        dtype: Dtype for created parameters.

    Returns:
        nn.Sequential object with stem elements.
    """
    shared = dict(
        act_layer=act_layer,
        inference_mode=inference_mode,
        use_scale_branch=use_scale_branch,
        device=device,
        dtype=dtype,
    )
    # Blocks are built in order (arguments evaluate left to right), matching init order.
    return nn.Sequential(
        MobileOneBlock(in_chs=in_chs, out_chs=out_chs, kernel_size=3, stride=2, **shared),
        MobileOneBlock(in_chs=out_chs, out_chs=out_chs, kernel_size=3, stride=2, group_size=1, **shared),
        MobileOneBlock(in_chs=out_chs, out_chs=out_chs, kernel_size=1, stride=1, **shared),
    )
class Attention(nn.Module):
    """Multi-headed self-attention over the spatial positions of a feature map.

    Source modified from:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            head_dim: int = 32,
            qkv_bias: bool = False,
            attn_drop: float = 0.0,
            proj_drop: float = 0.0,
            device=None,
            dtype=None,
    ) -> None:
        """Build the MHSA module for 4D ``(B, C, H, W)`` input tensors.

        Args:
            dim: Number of embedding dimensions (``C``); must be divisible by ``head_dim``.
            head_dim: Number of hidden dimensions per head. Default: ``32``
            qkv_bias: Whether the fused qkv projection has a bias. Default: ``False``
            attn_drop: Dropout rate applied to the attention weights.
            proj_drop: Dropout rate applied after the output projection.
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        super().__init__()
        assert dim % head_dim == 0, "dim should be divisible by head_dim"
        factory = {'device': device, 'dtype': dtype}
        self.head_dim = head_dim
        self.num_heads = dim // head_dim
        self.scale = head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **factory)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, **factory)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        N = H * W
        # (B, C, H, W) -> (B, N, C): one token per spatial position
        tokens = x.flatten(2).transpose(-2, -1)
        # (B, N, 3C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(tokens).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        if not self.fused_attn:
            attn = (q * self.scale) @ k.transpose(-2, -1)
            attn = self.attn_drop(attn.softmax(dim=-1))
            out = attn @ v
        else:
            out = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )

        # (B, num_heads, N, head_dim) -> (B, N, C)
        out = out.transpose(1, 2).reshape(B, N, C)
        out = self.proj_drop(self.proj(out))
        # (B, N, C) -> (B, C, H, W)
        return out.transpose(-2, -1).reshape(B, C, H, W)
class PatchEmbed(nn.Module):
    """Convolutional patch embedding: a large-kernel reparam conv followed by a 1x1 MobileOne block."""

    def __init__(
            self,
            patch_size: int,
            stride: int,
            in_chs: int,
            embed_dim: int,
            act_layer: Type[nn.Module] = nn.GELU,
            lkc_use_act: bool = False,
            use_se: bool = False,
            inference_mode: bool = False,
            device=None,
            dtype=None,
    ) -> None:
        """Build patch embedding layer.

        Args:
            patch_size: Kernel size of the large kernel conv.
            stride: Stride for convolutional embedding layer.
            in_chs: Number of channels of input tensor.
            embed_dim: Number of embedding dimensions.
            act_layer: Activation layer class.
            lkc_use_act: Whether the large kernel conv applies ``act_layer``.
            use_se: Whether the large kernel conv uses SqueezeExcite.
            inference_mode: Flag to instantiate model in inference mode. Default: ``False``
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        factory = {'device': device, 'dtype': dtype}
        super().__init__()
        # NOTE original weights didn't use this act
        large_kernel_act = act_layer if lkc_use_act else None
        large_kernel_conv = ReparamLargeKernelConv(
            in_chs=in_chs,
            out_chs=embed_dim,
            kernel_size=patch_size,
            stride=stride,
            group_size=1,
            small_kernel=3,
            use_se=use_se,
            act_layer=large_kernel_act,
            inference_mode=inference_mode,
            **factory,
        )
        pointwise_block = MobileOneBlock(
            in_chs=embed_dim,
            out_chs=embed_dim,
            kernel_size=1,
            stride=1,
            use_se=False,
            act_layer=act_layer,
            inference_mode=inference_mode,
            **factory,
        )
        self.proj = nn.Sequential(large_kernel_conv, pointwise_block)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)
class LayerScale2d(nn.Module):
    """Learnable per-channel scale for tensors whose channel dim is third from last.

    ``gamma`` has shape ``(dim, 1, 1)`` and is initialised to ``init_values``, so it
    broadcasts over e.g. a ``(B, dim, H, W)`` input.
    """

    def __init__(
            self,
            dim: int,
            init_values: float = 1e-5,
            inplace: bool = False,
            device=None,
            dtype=None,
    ):
        super().__init__()
        self.inplace = inplace
        unit = torch.ones(dim, 1, 1, device=device, dtype=dtype)
        self.gamma = nn.Parameter(init_values * unit)

    def forward(self, x):
        if not self.inplace:
            return x * self.gamma
        # In-place variant mutates and returns the input tensor itself.
        return x.mul_(self.gamma)
class RepMixer(nn.Module):
    """Reparameterizable token mixer.

    Train-time output is ``x + layer_scale(mixer(x) - norm(x))``. After
    :meth:`reparameterize` this collapses into a single depthwise conv.

    For more details, please refer to our paper:
    `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
    """

    def __init__(
            self,
            dim: int,
            kernel_size: int = 3,
            layer_scale_init_value: Optional[float] = 1e-5,
            inference_mode: bool = False,
            device=None,
            dtype=None,
    ):
        """Build RepMixer Module.

        Args:
            dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
            kernel_size: Kernel size for spatial mixing. Default: 3
            layer_scale_init_value: Initial value for layer scale; ``None`` disables it. Default: 1e-5
            inference_mode: If True, instantiates model in inference mode. Default: ``False``
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.inference_mode = inference_mode

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                self.dim,
                self.dim,
                kernel_size=self.kernel_size,
                stride=1,
                padding=self.kernel_size // 2,
                groups=self.dim,
                bias=True,
                **dd,
            )
        else:
            self.reparam_conv = None
            # BN-only MobileOne block (identity branch only): the "norm" that is subtracted.
            self.norm = MobileOneBlock(
                dim,
                dim,
                kernel_size,
                group_size=1,
                use_act=False,
                use_scale_branch=False,
                num_conv_branches=0,
                **dd,
            )
            self.mixer = MobileOneBlock(
                dim,
                dim,
                kernel_size,
                group_size=1,
                use_act=False,
                **dd,
            )
            if layer_scale_init_value is not None:
                self.layer_scale = LayerScale2d(dim, layer_scale_init_value, **dd)
            else:
                self.layer_scale = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.reparam_conv is not None:
            x = self.reparam_conv(x)
        else:
            x = x + self.layer_scale(self.mixer(x) - self.norm(x))
        return x

    def reparameterize(self) -> None:
        """Reparameterize mixer and norm into a single convolutional layer for efficient inference.

        The module is marked as being in inference mode afterwards, so repeated
        calls return immediately instead of touching the deleted branches.
        """
        if self.inference_mode:
            return

        self.mixer.reparameterize()
        self.norm.reparameterize()

        if isinstance(self.layer_scale, LayerScale2d):
            gamma = self.layer_scale.gamma  # (dim, 1, 1)
            # The residual `x + ...` is folded in through the mixer's identity kernel.
            w = self.mixer.id_tensor + gamma.unsqueeze(-1) * (
                self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
            )
            b = gamma.reshape(-1) * (
                self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
            )
        else:
            w = (
                self.mixer.id_tensor
                + self.mixer.reparam_conv.weight
                - self.norm.reparam_conv.weight
            )
            b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias

        self.reparam_conv = create_conv2d(
            self.dim,
            self.dim,
            kernel_size=self.kernel_size,
            stride=1,
            groups=self.dim,
            bias=True,
            device=w.device,
            dtype=w.dtype,
        )
        self.reparam_conv.weight.data = w
        self.reparam_conv.bias.data = b

        for name, para in self.named_parameters():
            if 'reparam_conv' in name:
                continue
            para.detach_()
        self.__delattr__("mixer")
        self.__delattr__("norm")
        self.__delattr__("layer_scale")

        self.inference_mode = True
class ConvMlp(nn.Module):
    """Convolutional FFN Module: depthwise 7x7 conv+BN, then 1x1 expand, act, 1x1 project."""

    def __init__(
            self,
            in_chs: int,
            hidden_channels: Optional[int] = None,
            out_chs: Optional[int] = None,
            act_layer: Type[nn.Module] = nn.GELU,
            drop: float = 0.0,
            device=None,
            dtype=None,
    ) -> None:
        """Build convolutional FFN module.

        Args:
            in_chs: Number of input channels.
            hidden_channels: Number of channels after expansion. Default: None (-> ``in_chs``)
            out_chs: Number of output channels. Default: None (-> ``in_chs``)
            act_layer: Activation layer. Default: ``GELU``
            drop: Dropout rate. Default: ``0.0``.
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        out_chs = out_chs or in_chs
        hidden_channels = hidden_channels or in_chs
        # Depthwise spatial conv keeps `in_chs` channels: `fc1` consumes its output,
        # so mapping to `out_chs` here would break any `out_chs != in_chs` config.
        self.conv = ConvNormAct(
            in_chs,
            in_chs,
            kernel_size=7,
            groups=in_chs,
            apply_act=False,
            **dd,
        )
        self.fc1 = nn.Conv2d(in_chs, hidden_channels, kernel_size=1, **dd)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_channels, out_chs, kernel_size=1, **dd)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Truncated-normal init (std 0.02) for conv weights, zero conv biases."""
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class RepConditionalPosEnc(nn.Module):
    """Implementation of conditional positional encoding.

    For more details refer to paper:
    `Conditional Positional Encodings for Vision Transformers <https://arxiv.org/pdf/2102.10882.pdf>`_

    In our implementation, we can reparameterize this module to eliminate a skip connection.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
        inference_mode: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        """Build reparameterizable conditional positional encoding

        Args:
            dim: Number of input channels.
            dim_out: Number of output channels. Default: ``None`` (-> ``dim``). The
                residual add in ``forward`` requires this to equal ``dim``.
            spatial_shape: Spatial shape of kernel for positional encoding, as an int
                or a 2-element tuple/list. Default: (7, 7)
            inference_mode: Flag to instantiate block in inference mode. Default: ``False``
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        if isinstance(spatial_shape, int):
            spatial_shape = (spatial_shape, spatial_shape)
        assert isinstance(spatial_shape, (tuple, list)), (
            f'"spatial_shape" must be a sequence or int, '
            f"got {type(spatial_shape)} instead."
        )
        spatial_shape = tuple(spatial_shape)
        assert len(spatial_shape) == 2, (
            f'Length of "spatial_shape" should be 2, '
            f"got {len(spatial_shape)} instead."
        )

        self.spatial_shape = spatial_shape
        self.dim = dim
        self.dim_out = dim_out or dim
        self.groups = dim
        # Per-axis "same" padding so the output matches the input for non-square
        # kernels too. NOTE: assumes odd kernel sizes (even ones grow the output by 1).
        padding = (spatial_shape[0] // 2, spatial_shape[1] // 2)

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                self.dim,
                self.dim_out,
                kernel_size=self.spatial_shape,
                stride=1,
                padding=padding,
                groups=self.groups,
                bias=True,
                **dd,
            )
        else:
            self.reparam_conv = None
            self.pos_enc = nn.Conv2d(
                self.dim,
                self.dim_out,
                kernel_size=spatial_shape,
                stride=1,
                padding=padding,
                groups=self.groups,
                bias=True,
                **dd,
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.reparam_conv is not None:
            x = self.reparam_conv(x)
        else:
            x = self.pos_enc(x) + x
        return x

    def reparameterize(self) -> None:
        """Fold the skip connection into the conv weights. A no-op if already fused."""
        if self.reparam_conv is not None:
            return

        weight = self.pos_enc.weight
        kh, kw = self.spatial_shape

        # Build equivalent Id tensor: a 1 at the kernel centre for each channel.
        input_dim = self.dim // self.groups
        id_tensor = torch.zeros(
            (self.dim, input_dim, kh, kw),
            dtype=weight.dtype,
            device=weight.device,
        )
        for i in range(self.dim):
            id_tensor[i, i % input_dim, kh // 2, kw // 2] = 1

        # Reparameterize Id tensor and conv
        w_final = id_tensor + weight
        b_final = self.pos_enc.bias

        # Introduce reparam conv
        self.reparam_conv = nn.Conv2d(
            self.dim,
            self.dim_out,
            kernel_size=self.spatial_shape,
            stride=1,
            padding=(kh // 2, kw // 2),
            groups=self.groups,
            bias=True,
            device=weight.device,
            dtype=weight.dtype,
        )
        self.reparam_conv.weight.data = w_final
        self.reparam_conv.bias.data = b_final

        for name, para in self.named_parameters():
            if 'reparam_conv' in name:
                continue
            para.detach_()
        self.__delattr__("pos_enc")
class RepMixerBlock(nn.Module):
    """Metaformer block whose token mixer is a RepMixer.

    Forward computes ``y = token_mixer(x)`` followed by the residual MLP update
    ``y + drop_path(layer_scale(mlp(y)))``. For the Metaformer structure see:
    `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
    """

    def __init__(
            self,
            dim: int,
            kernel_size: int = 3,
            mlp_ratio: float = 4.0,
            act_layer: Type[nn.Module] = nn.GELU,
            proj_drop: float = 0.0,
            drop_path: float = 0.0,
            layer_scale_init_value: float = 1e-5,
            inference_mode: bool = False,
            device=None,
            dtype=None,
    ):
        """Build RepMixer Block.

        Args:
            dim: Number of embedding dimensions.
            kernel_size: Kernel size for repmixer. Default: 3
            mlp_ratio: MLP hidden-width multiplier. Default: 4.0
            act_layer: Activation layer. Default: ``nn.GELU``
            proj_drop: Dropout rate inside the MLP. Default: 0.0
            drop_path: Drop path rate; ``0`` disables it. Default: 0.0
            layer_scale_init_value: Layer scale init value; ``None`` disables layer scale. Default: 1e-5
            inference_mode: Flag to instantiate block in inference mode. Default: ``False``
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        factory = dict(device=device, dtype=dtype)
        super().__init__()

        self.token_mixer = RepMixer(
            dim,
            kernel_size=kernel_size,
            layer_scale_init_value=layer_scale_init_value,
            inference_mode=inference_mode,
            **factory,
        )
        mlp_hidden = int(dim * mlp_ratio)
        self.mlp = ConvMlp(
            in_chs=dim,
            hidden_channels=mlp_hidden,
            act_layer=act_layer,
            drop=proj_drop,
            **factory,
        )
        self.layer_scale = (
            nn.Identity() if layer_scale_init_value is None
            else LayerScale2d(dim, layer_scale_init_value, **factory)
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        mixed = self.token_mixer(x)
        mlp_update = self.layer_scale(self.mlp(mixed))
        return mixed + self.drop_path(mlp_update)
class AttentionBlock(nn.Module):
    """Implementation of metaformer block with MHSA as token mixer.

    For more details on Metaformer structure, please refer to:
    `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
    """

    def __init__(
            self,
            dim: int,
            mlp_ratio: float = 4.0,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            proj_drop: float = 0.0,
            drop_path: float = 0.0,
            layer_scale_init_value: Optional[float] = 1e-5,
            device=None,
            dtype=None,
    ):
        """Build Attention Block.

        Args:
            dim: Number of embedding dimensions.
            mlp_ratio: MLP expansion ratio. Default: 4.0
            act_layer: Activation layer. Default: ``nn.GELU``
            norm_layer: Normalization layer applied before attention. Default: ``nn.BatchNorm2d``
            proj_drop: Dropout rate. Default: 0.0
            drop_path: Drop path rate. Default: 0.0
            layer_scale_init_value: Layer scale value at initialization, or ``None`` to
                replace both layer scales with ``nn.Identity``. Default: 1e-5
            device: Device on which parameters are created.
            dtype: Data type of created parameters.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()

        # Attention branch (pre-normalized).
        self.norm = norm_layer(dim, **dd)
        self.token_mixer = Attention(dim=dim, **dd)
        if layer_scale_init_value is not None:
            self.layer_scale_1 = LayerScale2d(dim, layer_scale_init_value, **dd)
        else:
            self.layer_scale_1 = nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # MLP branch (no norm applied before it in this block).
        self.mlp = ConvMlp(
            in_chs=dim,
            hidden_channels=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
            **dd,
        )
        if layer_scale_init_value is not None:
            self.layer_scale_2 = LayerScale2d(dim, layer_scale_init_value, **dd)
        else:
            self.layer_scale_2 = nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path1(self.layer_scale_1(self.token_mixer(self.norm(x))))
        x = x + self.drop_path2(self.layer_scale_2(self.mlp(x)))
        return x
class FastVitStage(nn.Module):
    """A single FastViT stage.

    Runs an optional downsampling ``PatchEmbed``, an optional positional embedding,
    then ``depth`` RepMixer or attention blocks.
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            depth: int,
            token_mixer_type: str,
            downsample: bool = True,
            se_downsample: bool = False,
            down_patch_size: int = 7,
            down_stride: int = 2,
            pos_emb_layer: Optional[nn.Module] = None,
            kernel_size: int = 3,
            mlp_ratio: float = 4.0,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            proj_drop_rate: float = 0.0,
            drop_path_rate: Union[List[float], float] = 0.0,
            layer_scale_init_value: Optional[float] = 1e-5,
            lkc_use_act: bool = False,
            inference_mode: bool = False,
            device=None,
            dtype=None,
    ):
        """FastViT stage.

        Args:
            dim: Number of input channels.
            dim_out: Number of embedding dimensions (output channels) of the stage.
            depth: Number of blocks in stage.
            token_mixer_type: Token mixer type, ``"repmixer"`` or ``"attention"``.
            downsample: Prepend a ``PatchEmbed`` mapping ``dim`` -> ``dim_out``. When
                ``False``, ``dim`` must equal ``dim_out``.
            se_downsample: Passed as ``use_se`` to the downsample ``PatchEmbed``.
            down_patch_size: Patch size of the downsample ``PatchEmbed``.
            down_stride: Stride of the downsample ``PatchEmbed``.
            pos_emb_layer: Optional factory called as
                ``pos_emb_layer(dim_out, inference_mode=..., device=..., dtype=...)``.
            kernel_size: Kernel size for repmixer.
            mlp_ratio: MLP expansion ratio.
            act_layer: Activation layer.
            norm_layer: Normalization layer (used by attention blocks).
            proj_drop_rate: Dropout rate.
            drop_path_rate: Drop path rate: a single value shared by every block, or a
                per-block sequence with at least ``depth`` entries.
            layer_scale_init_value: Layer scale value at initialization, or ``None``.
            lkc_use_act: Passed through to the downsample ``PatchEmbed``.
            inference_mode: Flag to instantiate block in inference mode.
            device: Device on which parameters are created.
            dtype: Data type of created parameters.

        Raises:
            ValueError: If ``token_mixer_type`` is not supported.
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        # Validate before building any sub-modules so a bad config fails fast,
        # even when depth == 0.
        if token_mixer_type not in ("repmixer", "attention"):
            raise ValueError(
                "Token mixer type: {} not supported".format(token_mixer_type)
            )
        self.grad_checkpointing = False

        if downsample:
            self.downsample = PatchEmbed(
                patch_size=down_patch_size,
                stride=down_stride,
                in_chs=dim,
                embed_dim=dim_out,
                use_se=se_downsample,
                act_layer=act_layer,
                lkc_use_act=lkc_use_act,
                inference_mode=inference_mode,
                **dd,
            )
        else:
            assert dim == dim_out
            self.downsample = nn.Identity()

        if pos_emb_layer is not None:
            self.pos_emb = pos_emb_layer(dim_out, inference_mode=inference_mode, **dd)
        else:
            self.pos_emb = nn.Identity()

        # A scalar default (e.g. 0.0) used to be indexed directly and raised
        # TypeError; normalize to one rate per block first.
        drop_path_rates = self._drop_path_list(drop_path_rate, depth)

        blocks = []
        for block_idx in range(depth):
            if token_mixer_type == "repmixer":
                block = RepMixerBlock(
                    dim_out,
                    kernel_size=kernel_size,
                    mlp_ratio=mlp_ratio,
                    act_layer=act_layer,
                    proj_drop=proj_drop_rate,
                    drop_path=drop_path_rates[block_idx],
                    layer_scale_init_value=layer_scale_init_value,
                    inference_mode=inference_mode,
                    **dd,
                )
            else:  # "attention" (validated above)
                block = AttentionBlock(
                    dim_out,
                    mlp_ratio=mlp_ratio,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    proj_drop=proj_drop_rate,
                    drop_path=drop_path_rates[block_idx],
                    layer_scale_init_value=layer_scale_init_value,
                    **dd,
                )
            blocks.append(block)
        self.blocks = nn.Sequential(*blocks)

    @staticmethod
    def _drop_path_list(drop_path_rate, depth):
        """Return a list of drop path rates, broadcasting a scalar to ``depth`` entries."""
        if isinstance(drop_path_rate, (int, float)):
            return [float(drop_path_rate)] * depth
        return list(drop_path_rate)

    def forward(self, x):
        x = self.downsample(x)
        x = self.pos_emb(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x
class FastVit(nn.Module):
    """
    This class implements `FastViT architecture <https://arxiv.org/pdf/2303.14189.pdf>`_
    """
    # NOTE: declared after the docstring so the string above is the class __doc__.
    fork_feat: torch.jit.Final[bool]

    def __init__(
            self,
            in_chans: int = 3,
            layers: Tuple[int, ...] = (2, 2, 6, 2),
            token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"),
            embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
            mlp_ratios: Tuple[float, ...] = (4,) * 4,
            downsamples: Tuple[bool, ...] = (False, True, True, True),
            se_downsamples: Tuple[bool, ...] = (False, False, False, False),
            repmixer_kernel_size: int = 3,
            num_classes: int = 1000,
            pos_embs: Tuple[Optional[nn.Module], ...] = (None,) * 4,
            down_patch_size: int = 7,
            down_stride: int = 2,
            drop_rate: float = 0.0,
            proj_drop_rate: float = 0.0,
            drop_path_rate: float = 0.0,
            layer_scale_init_value: Optional[float] = 1e-5,
            lkc_use_act: bool = False,
            stem_use_scale_branch: bool = True,
            fork_feat: bool = False,
            cls_ratio: float = 2.0,
            global_pool: str = 'avg',
            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            act_layer: Type[nn.Module] = nn.GELU,
            inference_mode: bool = False,
            device=None,
            dtype=None,
    ) -> None:
        """Build a FastViT model.

        Args:
            in_chans: Number of input image channels.
            layers: Number of blocks in each stage.
            token_mixers: Token mixer type per stage (``"repmixer"`` or ``"attention"``).
            embed_dims: Output channels of each stage.
            mlp_ratios: MLP expansion ratio per stage.
            downsamples: Whether each stage starts with a downsample; a stage is also
                downsampled whenever its input and output dims differ.
            se_downsamples: Per-stage ``se_downsample`` flag for the stage downsample.
            repmixer_kernel_size: Kernel size for RepMixer token mixers.
            num_classes: Number of classes (forced to 0 when ``fork_feat`` is set).
            pos_embs: Optional positional-embedding factory per stage.
            down_patch_size: Patch size of stage downsample layers.
            down_stride: Stride of stage downsample layers.
            drop_rate: Dropout rate passed to the classifier head.
            proj_drop_rate: Dropout rate inside blocks.
            drop_path_rate: Stochastic depth rate, spread across blocks by
                ``calculate_drop_path_rates``.
            layer_scale_init_value: Layer scale value at initialization, or ``None``.
            lkc_use_act: Passed through to the stage downsample layers.
            stem_use_scale_branch: Passed as ``use_scale_branch`` to the stem.
            fork_feat: Return normalized per-stage features instead of logits.
            cls_ratio: Channel multiplier of ``final_conv`` relative to ``embed_dims[-1]``.
            global_pool: Pooling type of the classifier head.
            norm_layer: Normalization layer (attention blocks and fork_feat output norms).
            act_layer: Activation layer.
            inference_mode: Flag to instantiate blocks in inference mode.
            device: Device on which parameters are created.
            dtype: Data type of created parameters.
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        self.num_classes = 0 if fork_feat else num_classes
        self.fork_feat = fork_feat
        self.global_pool = global_pool
        self.feature_info = []

        # Convolutional stem
        self.stem = convolutional_stem(
            in_chans,
            embed_dims[0],
            act_layer,
            inference_mode,
            use_scale_branch=stem_use_scale_branch,
            **dd,
        )

        # Build the main stages of the network architecture
        prev_dim = embed_dims[0]
        scale = 1
        dpr = calculate_drop_path_rates(drop_path_rate, layers, stagewise=True)
        stages = []
        for i in range(len(layers)):
            # A channel change requires a projecting downsample layer.
            downsample = downsamples[i] or prev_dim != embed_dims[i]
            stage = FastVitStage(
                dim=prev_dim,
                dim_out=embed_dims[i],
                depth=layers[i],
                downsample=downsample,
                se_downsample=se_downsamples[i],
                down_patch_size=down_patch_size,
                down_stride=down_stride,
                pos_emb_layer=pos_embs[i],
                token_mixer_type=token_mixers[i],
                kernel_size=repmixer_kernel_size,
                mlp_ratio=mlp_ratios[i],
                act_layer=act_layer,
                norm_layer=norm_layer,
                proj_drop_rate=proj_drop_rate,
                drop_path_rate=dpr[i],
                layer_scale_init_value=layer_scale_init_value,
                lkc_use_act=lkc_use_act,
                inference_mode=inference_mode,
                **dd,
            )
            stages.append(stage)
            prev_dim = embed_dims[i]
            if downsample:
                scale *= 2
            # `4 * scale` assumes the stem reduces resolution by 4x.
            self.feature_info.append(dict(num_chs=prev_dim, reduction=4 * scale, module=f'stages.{i}'))
        self.stages = nn.Sequential(*stages)
        self.num_stages = len(self.stages)
        self.num_features = self.head_hidden_size = prev_dim

        # For segmentation and detection, extract intermediate output
        if self.fork_feat:
            # Add a norm layer for each output. self.stages is slightly different than self.network
            # in the original code, the PatchEmbed layer is part of self.stages in this code where
            # it was part of self.network in the original code. So we do not need to skip out indices.
            # NOTE(review): out_indices is fixed to 4 entries, so 5-stage configs only get
            # norms / outputs for their first four stages.
            self.out_indices = [0, 1, 2, 3]
            for i_emb, i_layer in enumerate(self.out_indices):
                if i_emb == 0 and os.environ.get("FORK_LAST3", None):
                    # For RetinaNet, `start_level=1`, so the first norm layer is not used.
                    # cmd: `FORK_LAST3=1 python -m torch.distributed.launch ...`
                    layer = nn.Identity()
                else:
                    layer = norm_layer(embed_dims[i_emb], **dd)
                self.add_module(f"norm{i_layer}", layer)
        else:
            # Classifier head
            self.num_features = self.head_hidden_size = final_features = int(embed_dims[-1] * cls_ratio)
            self.final_conv = MobileOneBlock(
                in_chs=embed_dims[-1],
                out_chs=final_features,
                kernel_size=3,
                stride=1,
                group_size=1,
                inference_mode=inference_mode,
                use_se=True,
                act_layer=act_layer,
                num_conv_branches=1,
                **dd,
            )
            self.head = ClassifierHead(
                final_features,
                num_classes,
                pool_type=global_pool,
                drop_rate=drop_rate,
                **dd,
            )

        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Init. for classification: truncated-normal Linear weights, zero Linear biases."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return set()

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^stem',  # stem and embed
            blocks=r'^stages\.(\d+)' if coarse else [
                (r'^stages\.(\d+).downsample', (0,)),
                (r'^stages\.(\d+).pos_emb', (0,)),
                (r'^stages\.(\d+)\.\w+\.(\d+)', None),
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates (currently unused here)
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features

        Returns:
            The list of selected stage outputs if ``intermediates_only``, otherwise a tuple
            ``(x, intermediates)`` where ``x`` has gone through ``final_conv`` only when the
            last stage ran and the model is not in ``fork_feat`` mode.
        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stages), indices)

        # forward pass
        x = self.stem(x)
        last_idx = self.num_stages - 1
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.stages
        else:
            stages = self.stages[:max_index + 1]
        feat_idx = 0
        for feat_idx, stage in enumerate(stages):
            x = stage(x)
            if feat_idx in take_indices:
                intermediates.append(x)

        if intermediates_only:
            return intermediates

        # final_conv is only built in classification mode (not fork_feat).
        if feat_idx == last_idx and not self.fork_feat:
            x = self.final_conv(x)

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.

        NOTE: ``num_stages`` is not updated here, so after truncation
        ``forward_intermediates`` no longer applies ``final_conv``.
        ``prune_norm`` is accepted for API compatibility and unused.
        """
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
        # fork_feat models never build `self.head`, so there is nothing to reset.
        if prune_head and not self.fork_feat:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        # input embedding
        x = self.stem(x)
        outs = []
        for idx, block in enumerate(self.stages):
            x = block(x)
            if self.fork_feat:
                if idx in self.out_indices:
                    norm_layer = getattr(self, f"norm{idx}")
                    x_out = norm_layer(x)
                    outs.append(x_out)
        if self.fork_feat:
            # output the features of four stages for dense prediction
            return outs
        x = self.final_conv(x)
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
        return self.head(x, pre_logits=True) if pre_logits else self.head(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.forward_features(x)
        if self.fork_feat:
            return x
        x = self.forward_head(x)
        return x
def _cfg(url="", **kwargs):
    """Return the default pretrained-config dict for a FastViT variant.

    Any keyword arguments are merged in last and override the defaults.
    """
    cfg = {
        "url": url,
        "num_classes": 1000,
        "input_size": (3, 256, 256),
        "pool_size": (8, 8),
        "crop_pct": 0.9,
        "interpolation": "bicubic",
        "mean": IMAGENET_DEFAULT_MEAN,
        "license": "fastvit-license",
        "std": IMAGENET_DEFAULT_STD,
        "first_conv": ("stem.0.conv_kxk.0.conv", "stem.0.conv_scale.conv"),
        "classifier": "head.fc",
    }
    cfg.update(kwargs)
    return cfg
# Pretrained configs keyed by '<model_name>.<pretrained_tag>'.
default_cfgs = generate_default_cfgs({
    # Apple ImageNet-1k weights (MA36 uses a larger crop).
    "fastvit_t8.apple_in1k": _cfg(hf_hub_id='timm/'),
    "fastvit_t12.apple_in1k": _cfg(hf_hub_id='timm/'),
    "fastvit_s12.apple_in1k": _cfg(hf_hub_id='timm/'),
    "fastvit_sa12.apple_in1k": _cfg(hf_hub_id='timm/'),
    "fastvit_sa24.apple_in1k": _cfg(hf_hub_id='timm/'),
    "fastvit_sa36.apple_in1k": _cfg(hf_hub_id='timm/'),
    "fastvit_ma36.apple_in1k": _cfg(hf_hub_id='timm/', crop_pct=0.95),

    # Apple distilled ImageNet-1k weights (MA36 uses a larger crop).
    "fastvit_t8.apple_dist_in1k": _cfg(hf_hub_id='timm/'),
    "fastvit_t12.apple_dist_in1k": _cfg(hf_hub_id='timm/'),
    "fastvit_s12.apple_dist_in1k": _cfg(hf_hub_id='timm/'),
    "fastvit_sa12.apple_dist_in1k": _cfg(hf_hub_id='timm/'),
    "fastvit_sa24.apple_dist_in1k": _cfg(hf_hub_id='timm/'),
    "fastvit_sa36.apple_dist_in1k": _cfg(hf_hub_id='timm/'),
    "fastvit_ma36.apple_dist_in1k": _cfg(hf_hub_id='timm/', crop_pct=0.95),

    # MobileCLIP image towers: num_classes is the CLIP projection dim and inputs use
    # identity normalization (mean 0, std 1).
    "fastvit_mci0.apple_mclip": _cfg(
        hf_hub_id='apple/mobileclip_s0_timm',
        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s0.pt',
        crop_pct=0.95,
        num_classes=512,
        mean=(0., 0., 0.),
        std=(1., 1., 1.),
        license='apple-amlr',
    ),
    "fastvit_mci1.apple_mclip": _cfg(
        hf_hub_id='apple/mobileclip_s1_timm',
        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s1.pt',
        crop_pct=0.95,
        num_classes=512,
        mean=(0., 0., 0.),
        std=(1., 1., 1.),
        license='apple-amlr',
    ),
    "fastvit_mci2.apple_mclip": _cfg(
        hf_hub_id='apple/mobileclip_s2_timm',
        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s2.pt',
        crop_pct=0.95,
        num_classes=512,
        mean=(0., 0., 0.),
        std=(1., 1., 1.),
        license='apple-amlr',
    ),

    # MobileCLIP2 (DFNDR-2B) image towers.
    "fastvit_mci0.apple_mclip2_dfndr2b": _cfg(
        hf_hub_id='timm/',
        crop_pct=1.0,
        num_classes=512,
        mean=(0., 0., 0.),
        std=(1., 1., 1.),
        license='apple-amlr',
    ),
    "fastvit_mci2.apple_mclip2_dfndr2b": _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95,
        num_classes=512,
        mean=(0., 0., 0.),
        std=(1., 1., 1.),
        license='apple-amlr',
    ),
    # The 5-stage MCi3/MCi4 towers: 768-dim projection, OpenAI CLIP normalization,
    # 4x4 pool size, and a stem without the scale branch (single first conv).
    "fastvit_mci3.apple_mclip2_dfndr2b": _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95,
        num_classes=768,
        mean=OPENAI_CLIP_MEAN,
        std=OPENAI_CLIP_STD,
        pool_size=(4, 4),
        first_conv='stem.0.conv_kxk.0.conv',
        license='apple-amlr',
    ),
    "fastvit_mci4.apple_mclip2_dfndr2b": _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95,
        num_classes=768,
        mean=OPENAI_CLIP_MEAN,
        std=OPENAI_CLIP_STD,
        pool_size=(4, 4),
        first_conv='stem.0.conv_kxk.0.conv',
        license='apple-amlr',
    ),
})
def checkpoint_filter_fn(state_dict, model):
    """ Remap original checkpoints -> timm

    An outer ``{'state_dict': ...}`` wrapper is removed first. Then:
      * dicts already using timm key names are returned unchanged;
      * dicts whose image tower lives under ``module.visual.trunk.`` have that prefix
        stripped and all other keys dropped;
      * original FastViT / MobileCLIP (``image_encoder.model.`` prefixed) checkpoints
        have layers renamed and the flat ``network.N`` sequence split into ``stages.S``.

    Args:
        state_dict: Checkpoint state dict, optionally wrapped under a ``'state_dict'`` key.
        model: Target model; ``model.head.fc`` (if present) decides how a ``head.proj``
            projection is mapped.

    Returns:
        A state dict with timm key names.
    """
    import bisect
    import re

    # Unwrap first so wrapped timm / `module.visual.trunk.` checkpoints are detected too.
    state_dict = state_dict.get('state_dict', state_dict)
    if 'stem.0.conv_kxk.0.conv.weight' in state_dict:
        return state_dict  # non-original checkpoint, no remapping needed
    if 'module.visual.trunk.stem.0.conv_kxk.0.conv.weight' in state_dict:
        return {k.replace('module.visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('module.visual.trunk')}

    if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
        # remap MobileCLIP checkpoints
        prefix = 'image_encoder.model.'
    else:
        prefix = ''

    downsample_re = re.compile(r'^(.*?)network\.(\d+)\.proj.*')
    network_re = re.compile(r'^network\.(\d+)')
    layer_scale_re = re.compile(r'layer_scale_([0-9])')

    # find stage ends by locating downsample layers
    stage_ends = sorted({int(m.group(2)) for m in map(downsample_re.match, state_dict) if m})

    # Resolve the target classifier once; head-less (fork_feat) models have no `head`.
    head_fc = getattr(getattr(model, 'head', None), 'fc', None)

    out_dict = {}
    for k, v in state_dict.items():
        if prefix:
            if prefix not in k:
                continue
            k = k.replace(prefix, '')

        # remap renamed layers
        k = k.replace('patch_embed', 'stem')
        k = k.replace('rbr_conv', 'conv_kxk')
        k = k.replace('rbr_scale', 'conv_scale')
        k = k.replace('rbr_skip', 'identity')
        k = k.replace('conv_exp', 'final_conv')  # to match byobnet, regnet, nfnet
        k = k.replace('lkb_origin', 'large_conv')
        k = k.replace('convffn', 'mlp')
        k = k.replace('se.reduce', 'se.fc1')
        k = k.replace('se.expand', 'se.fc2')
        k = layer_scale_re.sub(r'layer_scale_\1.gamma', k)
        if k.endswith('layer_scale'):
            k = k.replace('layer_scale', 'layer_scale.gamma')
        k = k.replace('dist_head', 'head_dist')
        if k.startswith('head.'):
            if k == 'head.proj' and isinstance(head_fc, nn.Linear):
                # if CLIP projection, map to head.fc w/ bias = zeros
                k = 'head.fc.weight'
                v = v.T
                # match the projection's dtype/device rather than default float32/CPU
                out_dict['head.fc.bias'] = v.new_zeros(v.shape[0])
            else:
                k = k.replace('head.', 'head.fc.')

        # remap flat sequential network to stages
        match = network_re.match(k)
        if match:
            net_idx = int(match.group(1))
            # number of downsample layers at or before net_idx == stage index
            stage_idx = bisect.bisect_right(stage_ends, net_idx)
            net_prefix = f'network.{net_idx}'
            stage_prefix = f'stages.{stage_idx}'
            if net_prefix + '.proj' in k:
                k = k.replace(net_prefix + '.proj', stage_prefix + '.downsample.proj')
            elif net_prefix + '.pe' in k:
                k = k.replace(net_prefix + '.pe', stage_prefix + '.pos_emb.pos_enc')
            else:
                k = k.replace(net_prefix, stage_prefix + '.blocks')

        out_dict[k] = v
    return out_dict
def _create_fastvit(variant, pretrained=False, **kwargs):
    """Build a FastVit ``variant`` through ``build_model_with_cfg``.

    ``out_indices`` (default ``(0, 1, 2, 3)``) is consumed here for the feature config;
    the remaining kwargs are passed on to ``build_model_with_cfg``.
    """
    feature_cfg = dict(
        flatten_sequential=True,
        out_indices=kwargs.pop('out_indices', (0, 1, 2, 3)),
    )
    return build_model_with_cfg(
        FastVit,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=feature_cfg,
        **kwargs,
    )
@register_model
def fastvit_t8(pretrained=False, **kwargs):
    """Instantiate FastViT-T8: four RepMixer stages, MLP ratio 3."""
    model_args = {
        'layers': (2, 2, 4, 2),
        'embed_dims': (48, 96, 192, 384),
        'mlp_ratios': (3,) * 4,
        'token_mixers': ("repmixer",) * 4,
    }
    merged = {**model_args, **kwargs}
    return _create_fastvit('fastvit_t8', pretrained=pretrained, **merged)
@register_model
def fastvit_t12(pretrained=False, **kwargs):
    """Instantiate FastViT-T12: four RepMixer stages, MLP ratio 3."""
    model_args = {
        'layers': (2, 2, 6, 2),
        'embed_dims': (64, 128, 256, 512),
        'mlp_ratios': (3,) * 4,
        'token_mixers': ("repmixer",) * 4,
    }
    merged = {**model_args, **kwargs}
    return _create_fastvit('fastvit_t12', pretrained=pretrained, **merged)
@register_model
def fastvit_s12(pretrained=False, **kwargs):
    """Instantiate FastViT-S12: four RepMixer stages, MLP ratio 4."""
    model_args = {
        'layers': (2, 2, 6, 2),
        'embed_dims': (64, 128, 256, 512),
        'mlp_ratios': (4,) * 4,
        'token_mixers': ("repmixer",) * 4,
    }
    merged = {**model_args, **kwargs}
    return _create_fastvit('fastvit_s12', pretrained=pretrained, **merged)
@register_model
def fastvit_sa12(pretrained=False, **kwargs):
    """Instantiate FastViT-SA12: three RepMixer stages plus a final attention stage."""
    model_args = {
        'layers': (2, 2, 6, 2),
        'embed_dims': (64, 128, 256, 512),
        'mlp_ratios': (4,) * 4,
        'pos_embs': (None,) * 3 + (partial(RepConditionalPosEnc, spatial_shape=(7, 7)),),
        'token_mixers': ("repmixer",) * 3 + ("attention",),
    }
    merged = {**model_args, **kwargs}
    return _create_fastvit('fastvit_sa12', pretrained=pretrained, **merged)
@register_model
def fastvit_sa24(pretrained=False, **kwargs):
    """Instantiate FastViT-SA24: three RepMixer stages plus a final attention stage."""
    model_args = {
        'layers': (4, 4, 12, 4),
        'embed_dims': (64, 128, 256, 512),
        'mlp_ratios': (4,) * 4,
        'pos_embs': (None,) * 3 + (partial(RepConditionalPosEnc, spatial_shape=(7, 7)),),
        'token_mixers': ("repmixer",) * 3 + ("attention",),
    }
    merged = {**model_args, **kwargs}
    return _create_fastvit('fastvit_sa24', pretrained=pretrained, **merged)
@register_model
def fastvit_sa36(pretrained=False, **kwargs):
    """Instantiate FastViT-SA36: three RepMixer stages plus a final attention stage."""
    model_args = {
        'layers': (6, 6, 18, 6),
        'embed_dims': (64, 128, 256, 512),
        'mlp_ratios': (4,) * 4,
        'pos_embs': (None,) * 3 + (partial(RepConditionalPosEnc, spatial_shape=(7, 7)),),
        'token_mixers': ("repmixer",) * 3 + ("attention",),
    }
    merged = {**model_args, **kwargs}
    return _create_fastvit('fastvit_sa36', pretrained=pretrained, **merged)
@register_model
def fastvit_ma36(pretrained=False, **kwargs):
    """Instantiate FastViT-MA36: SA36 depths with wider (76-based) embedding dims."""
    model_args = {
        'layers': (6, 6, 18, 6),
        'embed_dims': (76, 152, 304, 608),
        'mlp_ratios': (4,) * 4,
        'pos_embs': (None,) * 3 + (partial(RepConditionalPosEnc, spatial_shape=(7, 7)),),
        'token_mixers': ("repmixer",) * 3 + ("attention",),
    }
    merged = {**model_args, **kwargs}
    return _create_fastvit('fastvit_ma36', pretrained=pretrained, **merged)
@register_model
def fastvit_mci0(pretrained=False, **kwargs):
    """Instantiate MCi0 (MobileCLIP image tower, smallest four-stage variant)."""
    model_args = {
        'layers': (2, 6, 10, 2),
        'embed_dims': (64, 128, 256, 512),
        'mlp_ratios': (3,) * 4,
        'se_downsamples': (False, False, True, True),
        'pos_embs': (None,) * 3 + (partial(RepConditionalPosEnc, spatial_shape=(7, 7)),),
        'token_mixers': ("repmixer",) * 3 + ("attention",),
        'lkc_use_act': True,
    }
    merged = {**model_args, **kwargs}
    return _create_fastvit('fastvit_mci0', pretrained=pretrained, **merged)
@register_model
def fastvit_mci1(pretrained=False, **kwargs):
    """Instantiate MCi1 (MobileCLIP image tower, deeper MCi0 layout)."""
    model_args = {
        'layers': (4, 12, 20, 4),
        'embed_dims': (64, 128, 256, 512),
        'mlp_ratios': (3,) * 4,
        'se_downsamples': (False, False, True, True),
        'pos_embs': (None,) * 3 + (partial(RepConditionalPosEnc, spatial_shape=(7, 7)),),
        'token_mixers': ("repmixer",) * 3 + ("attention",),
        'lkc_use_act': True,
    }
    merged = {**model_args, **kwargs}
    return _create_fastvit('fastvit_mci1', pretrained=pretrained, **merged)
@register_model
def fastvit_mci2(pretrained=False, **kwargs):
    """Instantiate MCi2 (MobileCLIP image tower, 80-based embedding dims)."""
    model_args = {
        'layers': (4, 12, 24, 4),
        'embed_dims': (80, 160, 320, 640),
        'mlp_ratios': (3,) * 4,
        'se_downsamples': (False, False, True, True),
        'pos_embs': (None,) * 3 + (partial(RepConditionalPosEnc, spatial_shape=(7, 7)),),
        'token_mixers': ("repmixer",) * 3 + ("attention",),
        'lkc_use_act': True,
    }
    merged = {**model_args, **kwargs}
    return _create_fastvit('fastvit_mci2', pretrained=pretrained, **merged)
@register_model
def fastvit_mci3(pretrained=False, **kwargs):
    """Instantiate MCi3, the five-stage 'L' variant (two attention stages, LayerNorm2d)."""
    model_args = {
        'layers': (2, 12, 24, 4, 2),
        'embed_dims': (96, 192, 384, 768, 1536),
        'mlp_ratios': (4,) * 5,
        'se_downsamples': (False,) * 5,
        'downsamples': (False,) + (True,) * 4,
        'pos_embs': (
            None,
            None,
            None,
            partial(RepConditionalPosEnc, spatial_shape=(7, 7)),
            partial(RepConditionalPosEnc, spatial_shape=(7, 7)),
        ),
        'token_mixers': ("repmixer",) * 3 + ("attention",) * 2,
        'lkc_use_act': True,
        'norm_layer': partial(LayerNorm2d, eps=1e-5),
        'stem_use_scale_branch': False,
    }
    merged = {**model_args, **kwargs}
    return _create_fastvit('fastvit_mci3', pretrained=pretrained, **merged)
@register_model
def fastvit_mci4(pretrained=False, **kwargs):
    """Instantiate MCi4, the five-stage 'XL' variant (two attention stages, LayerNorm2d)."""
    model_args = {
        'layers': (2, 12, 24, 4, 4),
        'embed_dims': (128, 256, 512, 1024, 2048),
        'mlp_ratios': (4,) * 5,
        'se_downsamples': (False,) * 5,
        'downsamples': (False,) + (True,) * 4,
        'pos_embs': (
            None,
            None,
            None,
            partial(RepConditionalPosEnc, spatial_shape=(7, 7)),
            partial(RepConditionalPosEnc, spatial_shape=(7, 7)),
        ),
        'token_mixers': ("repmixer",) * 3 + ("attention",) * 2,
        'lkc_use_act': True,
        'norm_layer': partial(LayerNorm2d, eps=1e-5),
        'stem_use_scale_branch': False,
    }
    merged = {**model_args, **kwargs}
    return _create_fastvit('fastvit_mci4', pretrained=pretrained, **merged)
|