| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711 |
- """ MaxVit and CoAtNet Vision Transformer - CNN Hybrids in PyTorch
- This is a from-scratch implementation of both CoAtNet and MaxVit in PyTorch.
- 99% of the implementation was done from papers, however last minute some adjustments were made
- based on the (as yet unfinished?) public code release https://github.com/google-research/maxvit
- There are multiple sets of models defined for both architectures. Typically, names with a
- `_rw` suffix are my own original configs prior to referencing https://github.com/google-research/maxvit.
- These configs work well and appear to be a bit faster / lower resource than the paper.
- The models without an extra prefix / suffix (coatnet_0_224, maxvit_tiny_224, etc.) are intended to
- match paper, BUT, without any official pretrained weights it's difficult to confirm a 100% match.
- Papers:
- MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697
- @article{tu2022maxvit,
- title={MaxViT: Multi-Axis Vision Transformer},
- author={Tu, Zhengzhong and Talebi, Hossein and Zhang, Han and Yang, Feng and Milanfar, Peyman and Bovik, Alan and Li, Yinxiao},
- journal={ECCV},
- year={2022},
- }
- CoAtNet: Marrying Convolution and Attention for All Data Sizes - https://arxiv.org/abs/2106.04803
- @article{DBLP:journals/corr/abs-2106-04803,
- author = {Zihang Dai and Hanxiao Liu and Quoc V. Le and Mingxing Tan},
- title = {CoAtNet: Marrying Convolution and Attention for All Data Sizes},
- journal = {CoRR},
- volume = {abs/2106.04803},
- year = {2021}
- }
- Hacked together by / Copyright 2022, Ross Wightman
- """
- import math
- from collections import OrderedDict
- from dataclasses import dataclass, replace, field
- from functools import partial
- from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
- import torch
- from torch import nn
- from torch.jit import Final
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.layers import (
- Mlp,
- ConvMlp,
- DropPath,
- calculate_drop_path_rates,
- LayerNorm,
- LayerScale,
- LayerScale2d,
- ClassifierHead,
- NormMlpClassifierHead,
- create_attn,
- get_act_layer,
- get_norm_layer,
- get_norm_act_layer,
- create_conv2d,
- create_pool2d,
- trunc_normal_tf_,
- to_2tuple,
- extend_tuple,
- make_divisible,
- _assert,
- RelPosMlp,
- RelPosBias,
- RelPosBiasTf,
- use_fused_attn,
- resize_rel_pos_bias_table,
- )
- from ._builder import build_model_with_cfg
- from ._features import feature_take_indices
- from ._features_fx import register_notrace_function
- from ._manipulate import named_apply, checkpoint_seq
- from ._registry import generate_default_cfgs, register_model
- __all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit']
@dataclass
class MaxxVitTransformerCfg:
    """Settings consumed by the transformer (attention) blocks of a MaxxVit model.

    After construction ``grid_size`` and ``window_size`` are normalised with
    ``to_2tuple`` when they were supplied, and ``grid_size`` falls back to
    ``window_size`` when only the latter was given.
    """
    dim_head: int = 32  # channels per attention head
    head_first: bool = True  # head ordering in qkv channel dim
    expand_ratio: float = 4.0  # MLP hidden width as a multiple of the block output width
    expand_first: bool = True  # forwarded to the attention modules as `expand_first`
    shortcut_bias: bool = True  # bias of the 1x1 conv in the strided shortcut
    attn_bias: bool = True  # bias for the qkv and output projections
    attn_drop: float = 0.
    proj_drop: float = 0.
    pool_type: str = 'avg2'  # Downsample2d mode used by strided blocks
    rel_pos_type: str = 'bias'
    rel_pos_dim: int = 512  # for relative position types w/ MLP
    partition_ratio: int = 32
    window_size: Optional[Tuple[int, int]] = None
    grid_size: Optional[Tuple[int, int]] = None
    no_block_attn: bool = False  # disable window block attention for maxvit (ie only grid)
    use_nchw_attn: bool = False  # for MaxViT variants (not used for CoAt), keep tensors in NCHW order
    init_values: Optional[float] = None  # layer-scale init; a falsy value disables layer scale
    act_layer: str = 'gelu'
    norm_layer: str = 'layernorm2d'
    norm_layer_cl: str = 'layernorm'
    norm_eps: float = 1e-6

    def __post_init__(self):
        """Normalise the partition sizes and default the grid size to the window size."""
        if self.grid_size is not None:
            self.grid_size = to_2tuple(self.grid_size)
        if self.window_size is None:
            # Nothing to normalise and nothing to inherit the grid size from.
            return
        self.window_size = to_2tuple(self.window_size)
        if self.grid_size is None:
            self.grid_size = self.window_size
@dataclass
class MaxxVitConvCfg:
    """Settings consumed by the convolutional blocks of a MaxxVit model.

    Fields left at their "unset" value (``''`` / ``None``) are filled in by
    ``__post_init__`` with defaults that depend on ``block_type``.
    """
    block_type: str = 'mbconv'  # 'mbconv' or 'convnext'
    expand_ratio: float = 4.0
    expand_output: bool = True  # calculate expansion channels from output (vs input chs)
    kernel_size: int = 3
    group_size: int = 1  # 1 == depthwise
    pre_norm_act: bool = False  # activation after pre-norm
    output_bias: bool = True  # bias for shortcut + final 1x1 projection conv
    stride_mode: str = 'dw'  # stride done via one of 'pool', '1x1', 'dw'
    pool_type: str = 'avg2'
    downsample_pool_type: str = 'avg2'  # falls back to `pool_type` when empty
    padding: str = ''
    attn_early: bool = False  # apply attn between conv2 and norm2, instead of after norm2
    attn_layer: str = 'se'
    attn_act_layer: str = 'silu'
    attn_ratio: float = 0.25
    init_values: Optional[float] = 1e-6  # for ConvNeXt block, ignored by MBConv
    act_layer: str = 'gelu'
    norm_layer: str = ''  # '' -> chosen from block_type in __post_init__
    norm_layer_cl: str = ''  # '' -> 'layernorm' for convnext blocks
    norm_eps: Optional[float] = None  # None -> chosen from block_type in __post_init__

    def __post_init__(self):
        """Resolve the block-type dependent defaults.

        mbconv vs convnext blocks have different defaults; they are set here so that
        configs do not need to spell them out explicitly.
        """
        assert self.block_type in ('mbconv', 'convnext')
        is_mbconv = self.block_type == 'mbconv'

        if not self.norm_layer:
            self.norm_layer = 'batchnorm2d' if is_mbconv else 'layernorm2d'
        if not is_mbconv and not self.norm_layer_cl:
            self.norm_layer_cl = 'layernorm'
        if self.norm_eps is None:
            self.norm_eps = 1e-5 if is_mbconv else 1e-6
        if not self.downsample_pool_type:
            self.downsample_pool_type = self.pool_type
@dataclass
class MaxxVitCfg:
    """Top-level architecture description of a MaxxVit model.

    The nested conv / transformer configs are built through ``default_factory``
    so that every ``MaxxVitCfg`` instance gets its own sub-config objects.
    """
    embed_dim: Tuple[int, ...] = (96, 192, 384, 768)
    depths: Tuple[int, ...] = (2, 3, 5, 2)
    block_type: Tuple[Union[str, Tuple[str, ...]], ...] = ('C', 'C', 'T', 'T')
    stem_width: Union[int, Tuple[int, int]] = 64
    stem_bias: bool = False
    conv_cfg: MaxxVitConvCfg = field(default_factory=MaxxVitConvCfg)
    transformer_cfg: MaxxVitTransformerCfg = field(default_factory=MaxxVitTransformerCfg)
    head_hidden_size: Optional[int] = None
    weight_init: str = 'vit_eff'
class Attention2d(nn.Module):
    """Multi-head self-attention for 2D NCHW tensors.

    The qkv and output projections are 1x1 convolutions, so the input never
    leaves ``(B, C, H, W)`` layout outside of the attention product itself.
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            dim_head: int = 32,
            bias: bool = True,
            expand_first: bool = True,
            head_first: bool = True,
            rel_pos_cls: Optional[Callable] = None,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            device=None,
            dtype=None,
    ):
        """
        Args:
            dim: Input dimension.
            dim_out: Output dimension (defaults to input dimension).
            dim_head: Dimension per attention head.
            bias: Whether to use bias in qkv and projection.
            expand_first: If True the attention width equals ``dim_out`` (channels are
                expanded by the qkv conv), otherwise it equals ``dim`` (expanded by the
                output projection).
            head_first: Whether heads are first in tensor layout.
            rel_pos_cls: Relative position class to use; called with ``num_heads`` and
                the device / dtype keyword arguments.
            attn_drop: Attention dropout rate.
            proj_drop: Projection dropout rate.
            device: Device for created parameters.
            dtype: Dtype for created parameters.

        Raises:
            AssertionError: If the attention width is not a multiple of ``dim_head``.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        dim_out = dim_out or dim
        dim_attn = dim_out if expand_first else dim
        # Fail early (as AttentionCl does) instead of silently flooring the head count and
        # hitting an obscure view / reshape error in forward().
        assert dim_attn % dim_head == 0, 'attn dim should be divisible by head_dim'
        self.num_heads = dim_attn // dim_head
        self.dim_head = dim_head
        self.head_first = head_first
        self.scale = dim_head ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias, **dd)
        self.rel_pos = rel_pos_cls(num_heads=self.num_heads, **dd) if rel_pos_cls else None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias, **dd)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, shared_rel_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply self-attention over all H * W positions of ``x``.

        Args:
            x: Input of shape (B, C, H, W).
            shared_rel_pos: Optional position bias added to the attention logits; only
                used when the module has no ``rel_pos`` of its own.

        Returns:
            Tensor of shape (B, dim_out, H, W).
        """
        B, C, H, W = x.shape

        # q, k, v each end up as (B, num_heads, dim_head, H * W)
        if self.head_first:
            q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2)
        else:
            q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)

        if self.fused_attn:
            attn_bias = None
            if self.rel_pos is not None:
                attn_bias = self.rel_pos.get_bias()
            elif shared_rel_pos is not None:
                attn_bias = shared_rel_pos

            # SDPA expects (..., seq, dim_head), hence the transposes in and back out
            x = torch.nn.functional.scaled_dot_product_attention(
                q.transpose(-1, -2).contiguous(),
                k.transpose(-1, -2).contiguous(),
                v.transpose(-1, -2).contiguous(),
                attn_mask=attn_bias,
                dropout_p=self.attn_drop.p if self.training else 0.,
            ).transpose(-1, -2).reshape(B, -1, H, W)
        else:
            q = q * self.scale
            attn = q.transpose(-2, -1) @ k
            if self.rel_pos is not None:
                attn = self.rel_pos(attn)
            elif shared_rel_pos is not None:
                attn = attn + shared_rel_pos
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class AttentionCl(nn.Module):
    """Multi-head self-attention for channels-last tensors of shape (B, ..., C).

    qkv and the output projection are ``nn.Linear`` layers acting on the last dim;
    all dims between batch and channel are flattened into the sequence axis.
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            dim_head: int = 32,
            bias: bool = True,
            expand_first: bool = True,
            head_first: bool = True,
            rel_pos_cls: Optional[Callable] = None,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            device=None,
            dtype=None,
    ):
        """
        Args:
            dim: Input dimension.
            dim_out: Output dimension (defaults to input dimension).
            dim_head: Dimension per attention head.
            bias: Whether to use bias in qkv and projection.
            expand_first: Use ``dim_out`` as the attention width, but only when it is
                larger than ``dim``; otherwise the attention width is ``dim``.
            head_first: Whether heads are first in tensor layout.
            rel_pos_cls: Relative position class to use.
            attn_drop: Attention dropout rate.
            proj_drop: Projection dropout rate.
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        dim_out = dim_out or dim
        widen_in_qkv = expand_first and dim_out > dim
        dim_attn = dim_out if widen_in_qkv else dim
        assert dim_attn % dim_head == 0, 'attn dim should be divisible by head_dim'
        self.num_heads = dim_attn // dim_head
        self.dim_head = dim_head
        self.head_first = head_first
        self.scale = dim_head ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias, **dd)
        self.rel_pos = rel_pos_cls(num_heads=self.num_heads, **dd) if rel_pos_cls else None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim_attn, dim_out, bias=bias, **dd)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, shared_rel_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply self-attention across all non-batch, non-channel positions.

        Args:
            x: Input of shape (B, ..., C).
            shared_rel_pos: Optional position bias for the attention logits.

        Returns:
            Tensor with the leading dims of ``x`` and ``dim_out`` channels.
        """
        B = x.shape[0]
        restore_shape = x.shape[:-1]  # everything but channels, reapplied after attention

        # q, k, v each end up as (B, num_heads, N, dim_head)
        qkv = self.qkv(x)
        if self.head_first:
            q, k, v = qkv.view(B, -1, self.num_heads, self.dim_head * 3).transpose(1, 2).chunk(3, dim=3)
        else:
            q, k, v = qkv.reshape(B, -1, 3, self.num_heads, self.dim_head).transpose(1, 3).unbind(2)

        if self.fused_attn:
            attn_bias = None
            if self.rel_pos is not None:
                attn_bias = self.rel_pos.get_bias()
            elif shared_rel_pos is not None:
                attn_bias = shared_rel_pos

            x = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_bias,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            attn = (q * self.scale) @ k.transpose(-2, -1)
            if self.rel_pos is not None:
                attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
            elif shared_rel_pos is not None:
                attn = attn + shared_rel_pos
            attn = self.attn_drop(attn.softmax(dim=-1))
            x = attn @ v

        # (B, num_heads, N, dim_head) -> (B, ..., num_heads * dim_head)
        x = x.transpose(1, 2).reshape(restore_shape + (-1,))
        x = self.proj_drop(self.proj(x))
        return x
class Downsample2d(nn.Module):
    """Spatial downsampling by pooling, optionally followed by a 1x1 channel projection.

    Supported ``pool_type`` values:
      * 'max'  - MaxPool2d w/ kernel_size 3, stride 2, padding 1
      * 'max2' - MaxPool2d w/ kernel_size = stride = 2
      * 'avg'  - AvgPool2d w/ kernel_size 3, stride 2, padding 1
      * 'avg2' - AvgPool2d w/ kernel_size = stride = 2
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            pool_type: str = 'avg2',
            padding: str = '',
            bias: bool = True,
            device=None,
            dtype=None,
    ):
        """
        Args:
            dim: Input dimension.
            dim_out: Output dimension.
            pool_type: Type of pooling operation (see class docstring).
            padding: Padding mode; when empty the numeric default for the chosen
                pool type is used (1 for the 3x3 pools, 0 for the 2x2 pools).
            bias: Whether to use bias in expansion conv.
            device: Device for the expansion conv parameters.
            dtype: Dtype for the expansion conv parameters.
        """
        super().__init__()
        assert pool_type in ('max', 'max2', 'avg', 'avg2')

        if pool_type == 'max':
            pool = create_pool2d('max', kernel_size=3, stride=2, padding=padding or 1)
        elif pool_type == 'max2':
            pool = create_pool2d('max', 2, padding=padding or 0)  # kernel_size == stride == 2
        elif pool_type == 'avg':
            pool = create_pool2d(
                'avg', kernel_size=3, stride=2, count_include_pad=False, padding=padding or 1)
        else:
            pool = create_pool2d('avg', 2, padding=padding or 0)
        self.pool = pool

        # A projection is only needed when the channel count changes.
        if dim == dim_out:
            self.expand = nn.Identity()
        else:
            self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` spatially, then project channels to ``dim_out``."""
        return self.expand(self.pool(x))
def _init_transformer(module: nn.Module, name: str, scheme: str = '') -> None:
    """Initialize the weights of one transformer sub-module in place.

    Only ``nn.Conv2d`` and ``nn.Linear`` modules are touched; anything else is ignored.

    Args:
        module: Module to initialize.
        name: Qualified module name; in the default scheme, biases of modules whose
            name contains ``'mlp'`` get a tiny normal init instead of zeros.
        scheme: One of 'normal', 'trunc_normal', 'xavier_normal'; any other value
            selects the ViT-like default (xavier uniform weights).
    """
    if not isinstance(module, (nn.Conv2d, nn.Linear)):
        return

    weight, bias = module.weight, module.bias
    if scheme == 'normal':
        nn.init.normal_(weight, std=.02)
    elif scheme == 'trunc_normal':
        trunc_normal_tf_(weight, std=.02)
    elif scheme == 'xavier_normal':
        nn.init.xavier_normal_(weight)
    else:
        # vit like
        nn.init.xavier_uniform_(weight)
        if bias is not None and 'mlp' in name:
            nn.init.normal_(bias, std=1e-6)
            return

    # Every remaining case zeroes the bias (when there is one).
    if bias is not None:
        nn.init.zeros_(bias)
class TransformerBlock2d(nn.Module):
    """Pre-norm transformer block (attention + MLP) operating on NCHW tensors.

    With ``stride == 2`` the block also downsamples: both the shortcut and the
    normalised attention input are pooled, the shortcut additionally being
    projected to ``dim_out`` channels.

    Some gains can be seen on GPU using a 1D / CL block, BUT w/ the need to switch back/forth to NCHW
    for spatial pooling, the benefit is minimal so ended up using just this variant for CoAt configs.

    This impl was faster on TPU w/ PT XLA than the 1D experiment.
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            stride: int = 1,
            rel_pos_cls: Optional[Callable] = None,
            cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            drop_path: float = 0.,
            device=None,
            dtype=None,
    ):
        """
        Args:
            dim: Input dimension.
            dim_out: Output dimension; must equal ``dim`` unless ``stride == 2``.
            stride: Stride for downsampling.
            rel_pos_cls: Relative position class, forwarded to the attention module.
            cfg: Transformer block configuration.
            drop_path: Drop path rate.
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
        act_layer = get_act_layer(cfg.act_layer)

        def _layer_scale() -> nn.Module:
            # Layer scale is only instantiated for a truthy init value.
            if cfg.init_values:
                return LayerScale2d(dim_out, init_values=cfg.init_values, **dd)
            return nn.Identity()

        def _drop_path() -> nn.Module:
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        if stride != 2:
            assert dim == dim_out
            self.shortcut = nn.Identity()
            self.norm1 = norm_layer(dim, **dd)
        else:
            self.shortcut = Downsample2d(dim, dim_out, pool_type=cfg.pool_type, bias=cfg.shortcut_bias, **dd)
            # Normalise at the input resolution, then pool before attention.
            self.norm1 = nn.Sequential(OrderedDict([
                ('norm', norm_layer(dim, **dd)),
                ('down', Downsample2d(dim, dim, pool_type=cfg.pool_type, **dd)),
            ]))

        self.attn = Attention2d(
            dim,
            dim_out,
            dim_head=cfg.dim_head,
            expand_first=cfg.expand_first,
            bias=cfg.attn_bias,
            rel_pos_cls=rel_pos_cls,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop,
            **dd,
        )
        self.ls1 = _layer_scale()
        self.drop_path1 = _drop_path()

        self.norm2 = norm_layer(dim_out, **dd)
        self.mlp = ConvMlp(
            in_features=dim_out,
            hidden_features=int(dim_out * cfg.expand_ratio),
            act_layer=act_layer,
            drop=cfg.proj_drop,
            **dd,
        )
        self.ls2 = _layer_scale()
        self.drop_path2 = _drop_path()

    def init_weights(self, scheme: str = '') -> None:
        """Apply ``_init_transformer`` with the given scheme to every sub-module."""
        named_apply(partial(_init_transformer, scheme=scheme), self)

    def forward(self, x: torch.Tensor, shared_rel_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the attention and MLP residual branches.

        Args:
            x: Input NCHW tensor.
            shared_rel_pos: Optional position bias forwarded to the attention module.
        """
        residual = self.shortcut(x)
        attn_out = self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos)
        x = residual + self.drop_path1(self.ls1(attn_out))

        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path2(self.ls2(mlp_out))
        return x
def _init_conv(module: nn.Module, name: str, scheme: str = '') -> None:
    """Initialize the weights of one convolution module in place.

    Only ``nn.Conv2d`` modules are touched. The bias, when present, is always zeroed.

    Args:
        module: Module to initialize.
        name: Qualified module name (unused; required by the ``named_apply`` callback signature).
        scheme: One of 'normal', 'trunc_normal', 'xavier_normal'; any other value selects
            the EfficientNet-like default, a normal init with std ``sqrt(2 / fan_out)``.
    """
    if not isinstance(module, nn.Conv2d):
        return

    if scheme == 'normal':
        nn.init.normal_(module.weight, std=.02)
    elif scheme == 'trunc_normal':
        trunc_normal_tf_(module.weight, std=.02)
    elif scheme == 'xavier_normal':
        nn.init.xavier_normal_(module.weight)
    else:
        # efficientnet like: fan-out counted per group
        kernel_h, kernel_w = module.kernel_size[0], module.kernel_size[1]
        fan_out = (kernel_h * kernel_w * module.out_channels) // module.groups
        nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))

    if module.bias is not None:
        nn.init.zeros_(module.bias)
def num_groups(group_size: Optional[int], channels: int) -> int:
    """Return the conv ``groups`` value that gives ``group_size`` channels per group.

    Args:
        group_size: Channels per group; ``0`` or ``None`` requests a normal
            (single group) convolution, ``1`` a depthwise convolution.
        channels: Total number of channels; must be divisible by ``group_size``.
    """
    if group_size:
        # NOTE group_size == 1 -> depthwise conv
        assert channels % group_size == 0
        return channels // group_size
    # 0 or None -> normal conv with 1 group
    return 1
class MbConvBlock(nn.Module):
    """Pre-norm inverted-bottleneck conv block: 1x1 expand -> kxk grouped conv -> 1x1 project.

    A channel attention layer (e.g. SE) is applied around the second norm, either
    before it (``cfg.attn_early``) or after it.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
            drop_path: float = 0.,
            device=None,
            dtype=None,
    ):
        """
        Args:
            in_chs: Input channels.
            out_chs: Output channels.
            stride: Stride for conv.
            dilation: Dilation for conv.
            cfg: Convolution block configuration.
            drop_path: Drop path rate.
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        norm_act_layer = partial(get_norm_act_layer(cfg.norm_layer, cfg.act_layer), eps=cfg.norm_eps)
        expand_base = out_chs if cfg.expand_output else in_chs
        mid_chs = make_divisible(expand_base * cfg.expand_ratio)
        groups = num_groups(cfg.group_size, mid_chs)

        if stride == 2:
            self.shortcut = Downsample2d(
                in_chs, out_chs, pool_type=cfg.pool_type, bias=cfg.output_bias, padding=cfg.padding, **dd)
        else:
            self.shortcut = nn.Identity()

        # Exactly one of the pool / 1x1 conv / kxk conv carries the stride.
        mode = cfg.stride_mode
        assert mode in ('pool', '1x1', 'dw')
        # 'pool': not described in paper, experiment to find faster option that doesn't stride in 1x1
        # FIXME handle dilation of avg pool
        stride_pool = stride if mode == 'pool' else 1
        # '1x1': option described in paper, 1x1 w/ stride throws info away
        stride_1 = stride if mode == '1x1' else 1
        strided_before_kxk = mode in ('pool', '1x1')
        stride_2 = 1 if strided_before_kxk else stride
        dilation_2 = dilation[1] if strided_before_kxk else dilation[0]

        self.pre_norm = norm_act_layer(in_chs, apply_act=cfg.pre_norm_act, **dd)
        if stride_pool > 1:
            self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type, padding=cfg.padding, **dd)
        else:
            self.down = nn.Identity()
        self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=stride_1, **dd)
        self.norm1 = norm_act_layer(mid_chs, **dd)

        self.conv2_kxk = create_conv2d(
            mid_chs,
            mid_chs,
            cfg.kernel_size,
            stride=stride_2,
            dilation=dilation_2,
            groups=groups,
            padding=cfg.padding,
            **dd,
        )

        attn_kwargs = {}
        if isinstance(cfg.attn_layer, str) and cfg.attn_layer in ('se', 'eca'):
            attn_kwargs['act_layer'] = cfg.attn_act_layer
            attn_kwargs['rd_channels'] = int(cfg.attn_ratio * (out_chs if cfg.expand_output else mid_chs))

        # two different orderings for SE and norm2 (due to some weights and trials using SE before norm2)
        self.se_early = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs, **dd) if cfg.attn_early else None
        self.norm2 = norm_act_layer(mid_chs, **dd)
        self.se = None if cfg.attn_early else create_attn(cfg.attn_layer, mid_chs, **attn_kwargs, **dd)

        self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=cfg.output_bias, **dd)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def init_weights(self, scheme: str = '') -> None:
        """Apply ``_init_conv`` with the given scheme to every sub-module."""
        named_apply(partial(_init_conv, scheme=scheme), self)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the conv branch and add it to the (possibly downsampled) shortcut."""
        shortcut = self.shortcut(x)
        x = self.down(self.pre_norm(x))

        # 1x1 expansion conv & norm-act
        x = self.norm1(self.conv1_1x1(x))

        # depthwise / grouped 3x3 conv w/ SE (or other) channel attention & norm-act
        x = self.conv2_kxk(x)
        if self.se_early is not None:
            x = self.se_early(x)
        x = self.norm2(x)
        if self.se is not None:
            x = self.se(x)

        # 1x1 linear projection to output width
        x = self.conv3_1x1(x)
        return self.drop_path(x) + shortcut
class ConvNeXtBlock(nn.Module):
    """ConvNeXt block: depthwise conv -> norm -> MLP -> layer scale, plus a residual shortcut.

    The norm / MLP / layer-scale stage runs either directly on NCHW tensors
    (``conv_mlp=True``) or on a channels-last permutation of them.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: Optional[int] = None,
            kernel_size: int = 7,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
            conv_mlp: bool = True,
            drop_path: float = 0.,
            device=None,
            dtype=None,
    ):
        """
        Args:
            in_chs: Input channels.
            out_chs: Output channels (defaults to ``in_chs``).
            kernel_size: Kernel size for depthwise conv.
            stride: Stride for conv.
            dilation: Dilation for conv.
            cfg: Convolution block configuration.
            conv_mlp: Whether to use convolutional MLP.
            drop_path: Drop path rate.
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        out_chs = out_chs or in_chs
        act_layer = get_act_layer(cfg.act_layer)

        # Pick the layer flavours matching the tensor layout of the MLP stage.
        if conv_mlp:
            norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
            mlp_layer = ConvMlp
            scale_layer = LayerScale2d
        else:
            assert 'layernorm' in cfg.norm_layer
            norm_layer = LayerNorm
            mlp_layer = Mlp
            scale_layer = LayerScale
        self.use_conv_mlp = conv_mlp

        if stride == 2:
            self.shortcut = Downsample2d(in_chs, out_chs, **dd)
        elif in_chs != out_chs:
            self.shortcut = nn.Conv2d(in_chs, out_chs, kernel_size=1, bias=cfg.output_bias, **dd)
        else:
            self.shortcut = nn.Identity()

        # The stride is carried either by a pooling layer or by the depthwise conv.
        assert cfg.stride_mode in ('pool', 'dw')
        # FIXME handle dilation?
        pool_strided = cfg.stride_mode == 'pool'
        stride_pool = stride if pool_strided else 1
        stride_dw = 1 if pool_strided else stride

        if stride_pool == 2:
            self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type, **dd)
        else:
            self.down = nn.Identity()

        self.conv_dw = create_conv2d(
            in_chs,
            out_chs,
            kernel_size=kernel_size,
            stride=stride_dw,
            dilation=dilation[1],
            depthwise=True,
            bias=cfg.output_bias,
            **dd,
        )
        self.norm = norm_layer(out_chs, **dd)
        self.mlp = mlp_layer(
            out_chs,
            int(cfg.expand_ratio * out_chs),
            bias=cfg.output_bias,
            act_layer=act_layer,
            **dd,
        )
        self.ls = scale_layer(out_chs, cfg.init_values, **dd) if cfg.init_values else nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the conv branch and add it to the shortcut."""
        shortcut = self.shortcut(x)
        x = self.conv_dw(self.down(x))

        channels_last = not self.use_conv_mlp
        if channels_last:
            x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC for LayerNorm / Mlp
        x = self.ls(self.mlp(self.norm(x)))
        if channels_last:
            x = x.permute(0, 3, 1, 2)  # back to NCHW

        return self.drop_path(x) + shortcut
def window_partition(x: torch.Tensor, window_size: List[int]) -> torch.Tensor:
    """Split a (B, H, W, C) tensor into non-overlapping windows.

    Args:
        x: Input tensor, unpacked as (B, H, W, C).
        window_size: Window (height, width); H and W must be divisible by them.

    Returns:
        Tensor of shape (-1, window_size[0], window_size[1], C), windows ordered row-major per image.
    """
    B, H, W, C = x.shape
    win_h, win_w = window_size[0], window_size[1]
    _assert(H % win_h == 0, f'height ({H}) must be divisible by window ({win_h})')
    _assert(W % win_w == 0, f'width ({W}) must be divisible by window ({win_w})')
    tiled = x.view(B, H // win_h, win_h, W // win_w, win_w, C)
    # bring the two window-index dims together before flattening them into the batch dim
    return tiled.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_h, win_w, C)
@register_notrace_function  # reason: int argument is a Proxy
def window_reverse(windows: torch.Tensor, window_size: List[int], img_size: List[int]) -> torch.Tensor:
    """Reassemble windows shaped (-1, win_h, win_w, C) into a (-1, H, W, C) tensor.

    Args:
        windows: Window tensor with channels in the last dim.
        window_size: Window (height, width) used for the partition.
        img_size: Target (H, W) of the reassembled feature map.

    Returns:
        Tensor of shape (-1, H, W, C).
    """
    H, W = img_size
    win_h, win_w = window_size[0], window_size[1]
    channels = windows.shape[-1]
    tiled = windows.view(-1, H // win_h, W // win_w, win_h, win_w, channels)
    # interleave window-index and in-window dims back into full rows / columns
    return tiled.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, channels)
def grid_partition(x: torch.Tensor, grid_size: List[int]) -> torch.Tensor:
    """Split a (B, H, W, C) tensor into strided (grid) windows.

    Each output window gathers ``grid_size[0] x grid_size[1]`` elements spaced
    ``H // grid_size[0]`` rows and ``W // grid_size[1]`` columns apart, so the
    windows are disjoint and together cover the whole input.

    Args:
        x: Input tensor, unpacked as (B, H, W, C).
        grid_size: Grid (height, width); H and W must be divisible by them.

    Returns:
        Tensor of shape (-1, grid_size[0], grid_size[1], C).
    """
    B, H, W, C = x.shape
    grid_h, grid_w = grid_size[0], grid_size[1]
    _assert(H % grid_h == 0, f'height {H} must be divisible by grid {grid_h}')
    _assert(W % grid_w == 0, f'width {W} must be divisible by grid {grid_w}')
    strided = x.view(B, grid_h, H // grid_h, grid_w, W // grid_w, C)
    # move the within-stride offsets next to batch so each offset pair becomes one window
    return strided.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, grid_h, grid_w, C)
@register_notrace_function  # reason: int argument is a Proxy
def grid_reverse(windows: torch.Tensor, grid_size: List[int], img_size: List[int]) -> torch.Tensor:
    """Reassemble grid windows shaped (-1, grid_h, grid_w, C) into a (-1, H, W, C) tensor.

    Args:
        windows: Grid-window tensor with channels in the last dim.
        grid_size: Grid (height, width) used for the partition.
        img_size: Target (H, W) of the reassembled feature map.

    Returns:
        Tensor of shape (-1, H, W, C).
    """
    H, W = img_size
    grid_h, grid_w = grid_size[0], grid_size[1]
    channels = windows.shape[-1]
    strided = windows.view(-1, H // grid_h, W // grid_w, grid_h, grid_w, channels)
    # put grid dims ahead of their stride-offset dims so flattening restores rows / columns
    return strided.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, H, W, channels)
def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size: Tuple[int, int]) -> Optional[Callable]:
    """Return a relative-position module factory selected by ``cfg.rel_pos_type``.

    Args:
        cfg: Transformer configuration; ``rel_pos_type`` selects the implementation and
            ``rel_pos_dim`` is used as the hidden dim of the 'mlp' variant.
        window_size: Window size bound into the returned factory.

    Returns:
        A ``partial`` with ``window_size`` bound for 'mlp', 'bias' and 'bias_tf',
        or None for any other ``rel_pos_type``.
    """
    kind = cfg.rel_pos_type
    if kind == 'mlp':
        return partial(RelPosMlp, window_size=window_size, hidden_dim=cfg.rel_pos_dim)
    if kind == 'bias':
        return partial(RelPosBias, window_size=window_size)
    if kind == 'bias_tf':
        return partial(RelPosBiasTf, window_size=window_size)
    return None
class PartitionAttentionCl(nn.Module):
    """Grid or Block partition + Attn + FFN.

    NxC 'channels last' tensor layout.
    """

    def __init__(
            self,
            dim: int,
            partition_type: str = 'block',
            cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            drop_path: float = 0.,
            device=None,
            dtype=None,
    ):
        """
        Args:
            dim: Input / output dimension.
            partition_type: 'block' selects window partitioning, any other value grid partitioning.
            cfg: Transformer block configuration.
            drop_path: Drop path rate for both residual branches.
            device: Device forwarded to created layers.
            dtype: Dtype forwarded to created layers.
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps)  # channels-last norm
        act_layer = get_act_layer(cfg.act_layer)

        def _make_layer_scale() -> nn.Module:
            if cfg.init_values:
                return LayerScale(dim, init_values=cfg.init_values, **factory_kwargs)
            return nn.Identity()

        def _make_drop_path() -> nn.Module:
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.partition_block = partition_type == 'block'
        self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size)
        rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size)

        # attention branch
        self.norm1 = norm_layer(dim, **factory_kwargs)
        self.attn = AttentionCl(
            dim,
            dim,
            dim_head=cfg.dim_head,
            bias=cfg.attn_bias,
            head_first=cfg.head_first,
            rel_pos_cls=rel_pos_cls,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop,
            **factory_kwargs,
        )
        self.ls1 = _make_layer_scale()
        self.drop_path1 = _make_drop_path()

        # feed-forward branch
        self.norm2 = norm_layer(dim, **factory_kwargs)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * cfg.expand_ratio),
            act_layer=act_layer,
            drop=cfg.proj_drop,
            **factory_kwargs,
        )
        self.ls2 = _make_layer_scale()
        self.drop_path2 = _make_drop_path()

    def _partition_attn(self, x):
        """Partition ``x``, run attention on the partitions and restore the original layout."""
        img_size = x.shape[1:3]
        if self.partition_block:
            windows = window_partition(x, self.partition_size)
            windows = self.attn(windows)
            return window_reverse(windows, self.partition_size, img_size)
        windows = grid_partition(x, self.partition_size)
        windows = self.attn(windows)
        return grid_reverse(windows, self.partition_size, img_size)

    def forward(self, x):
        """Apply the pre-norm attention and MLP residual branches in sequence."""
        attn_out = self._partition_attn(self.norm1(x))
        x = x + self.drop_path1(self.ls1(attn_out))
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path2(self.ls2(mlp_out))
        return x
class ParallelPartitionAttention(nn.Module):
    """Experimental. Grid and Block partition + single FFN.

    NxC tensor layout. Window and grid attention each produce ``dim // 2`` channels
    which are concatenated along the last dim.
    """

    def __init__(
            self,
            dim: int,
            cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            drop_path: float = 0.,
            device=None,
            dtype=None,
    ):
        """
        Args:
            dim: Input dimension, must be even.
            cfg: Transformer block configuration; ``window_size`` must equal ``grid_size``.
            drop_path: Drop path rate.
            device: Device forwarded to created layers.
            dtype: Dtype forwarded to created layers.
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        assert dim % 2 == 0
        norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps)  # channels-last norm
        act_layer = get_act_layer(cfg.act_layer)

        assert cfg.window_size == cfg.grid_size
        self.partition_size = to_2tuple(cfg.window_size)
        rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size)

        def _make_half_attn() -> nn.Module:
            # each parallel branch outputs half of the channels
            return AttentionCl(
                dim,
                dim // 2,
                dim_head=cfg.dim_head,
                bias=cfg.attn_bias,
                head_first=cfg.head_first,
                rel_pos_cls=rel_pos_cls,
                attn_drop=cfg.attn_drop,
                proj_drop=cfg.proj_drop,
                **factory_kwargs,
            )

        def _make_layer_scale() -> nn.Module:
            if cfg.init_values:
                return LayerScale(dim, init_values=cfg.init_values, **factory_kwargs)
            return nn.Identity()

        def _make_drop_path() -> nn.Module:
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm1 = norm_layer(dim, **factory_kwargs)
        self.attn_block = _make_half_attn()
        self.attn_grid = _make_half_attn()
        self.ls1 = _make_layer_scale()
        self.drop_path1 = _make_drop_path()

        self.norm2 = norm_layer(dim, **factory_kwargs)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * cfg.expand_ratio),
            out_features=dim,
            act_layer=act_layer,
            drop=cfg.proj_drop,
            **factory_kwargs,
        )
        self.ls2 = _make_layer_scale()
        self.drop_path2 = _make_drop_path()

    def _partition_attn(self, x: torch.Tensor) -> torch.Tensor:
        """Run window and grid attention on ``x`` and concatenate both results on the last dim."""
        img_size = x.shape[1:3]

        block_windows = self.attn_block(window_partition(x, self.partition_size))
        x_window = window_reverse(block_windows, self.partition_size, img_size)

        grid_windows = self.attn_grid(grid_partition(x, self.partition_size))
        x_grid = grid_reverse(grid_windows, self.partition_size, img_size)

        return torch.cat([x_window, x_grid], dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the pre-norm parallel attention and MLP residual branches in sequence."""
        attn_out = self._partition_attn(self.norm1(x))
        x = x + self.drop_path1(self.ls1(attn_out))
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path2(self.ls2(mlp_out))
        return x
def window_partition_nchw(x: torch.Tensor, window_size: List[int]) -> torch.Tensor:
    """Split a (B, C, H, W) tensor into non-overlapping windows.

    Args:
        x: Input tensor, unpacked as (B, C, H, W).
        window_size: Window (height, width); H and W must be divisible by them.

    Returns:
        Tensor of shape (-1, C, window_size[0], window_size[1]).
    """
    B, C, H, W = x.shape
    win_h, win_w = window_size[0], window_size[1]
    _assert(H % win_h == 0, f'height ({H}) must be divisible by window ({win_h})')
    _assert(W % win_w == 0, f'width ({W}) must be divisible by window ({win_w})')
    tiled = x.view(B, C, H // win_h, win_h, W // win_w, win_w)
    # window-index dims go next to batch, channels stay ahead of the in-window dims
    return tiled.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, win_h, win_w)
@register_notrace_function  # reason: int argument is a Proxy
def window_reverse_nchw(windows: torch.Tensor, window_size: List[int], img_size: List[int]) -> torch.Tensor:
    """Reassemble windows shaped (-1, C, win_h, win_w) into a (-1, C, H, W) tensor.

    Args:
        windows: Window tensor with channels in dim 1.
        window_size: Window (height, width) used for the partition.
        img_size: Target (H, W) of the reassembled feature map.

    Returns:
        Tensor of shape (-1, C, H, W).
    """
    H, W = img_size
    win_h, win_w = window_size[0], window_size[1]
    channels = windows.shape[1]
    tiled = windows.view(-1, H // win_h, W // win_w, channels, win_h, win_w)
    # channels first, then (window row, in-window row) and (window col, in-window col)
    return tiled.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, channels, H, W)
def grid_partition_nchw(x: torch.Tensor, grid_size: List[int]) -> torch.Tensor:
    """Split a (B, C, H, W) tensor into strided (grid) windows.

    Args:
        x: Input tensor, unpacked as (B, C, H, W).
        grid_size: Grid (height, width); H and W must be divisible by them.

    Returns:
        Tensor of shape (-1, C, grid_size[0], grid_size[1]).
    """
    B, C, H, W = x.shape
    grid_h, grid_w = grid_size[0], grid_size[1]
    _assert(H % grid_h == 0, f'height {H} must be divisible by grid {grid_h}')
    _assert(W % grid_w == 0, f'width {W} must be divisible by grid {grid_w}')
    strided = x.view(B, C, grid_h, H // grid_h, grid_w, W // grid_w)
    # stride-offset dims go next to batch so each offset pair becomes one window
    return strided.permute(0, 3, 5, 1, 2, 4).contiguous().view(-1, C, grid_h, grid_w)
@register_notrace_function  # reason: int argument is a Proxy
def grid_reverse_nchw(windows: torch.Tensor, grid_size: List[int], img_size: List[int]) -> torch.Tensor:
    """Reassemble grid windows shaped (-1, C, grid_h, grid_w) into a (-1, C, H, W) tensor.

    Args:
        windows: Grid-window tensor with channels in dim 1.
        grid_size: Grid (height, width) used for the partition.
        img_size: Target (H, W) of the reassembled feature map.

    Returns:
        Tensor of shape (-1, C, H, W).
    """
    H, W = img_size
    grid_h, grid_w = grid_size[0], grid_size[1]
    channels = windows.shape[1]
    strided = windows.view(-1, H // grid_h, W // grid_w, channels, grid_h, grid_w)
    # channels first, then (grid row, row offset) and (grid col, col offset)
    return strided.permute(0, 3, 4, 1, 5, 2).contiguous().view(-1, channels, H, W)
class PartitionAttention2d(nn.Module):
    """Grid or Block partition + Attn + FFN.

    '2D' NCHW tensor layout.
    """

    def __init__(
            self,
            dim: int,
            partition_type: str = 'block',
            cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            drop_path: float = 0.,
            device=None,
            dtype=None,
    ):
        """
        Args:
            dim: Input dimension.
            partition_type: Partition type, 'block' for windows, any other value selects grid.
            cfg: Transformer block configuration.
            drop_path: Drop path rate.
            device: Device forwarded to created layers.
            dtype: Dtype forwarded to created layers.
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        # NOTE this block keeps the NCHW layout, hence cfg.norm_layer rather than cfg.norm_layer_cl
        norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
        act_layer = get_act_layer(cfg.act_layer)

        def _make_layer_scale() -> nn.Module:
            if cfg.init_values:
                return LayerScale2d(dim, init_values=cfg.init_values, **factory_kwargs)
            return nn.Identity()

        def _make_drop_path() -> nn.Module:
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.partition_block = partition_type == 'block'
        self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size)
        rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size)

        # attention branch
        self.norm1 = norm_layer(dim, **factory_kwargs)
        self.attn = Attention2d(
            dim,
            dim,
            dim_head=cfg.dim_head,
            bias=cfg.attn_bias,
            head_first=cfg.head_first,
            rel_pos_cls=rel_pos_cls,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop,
            **factory_kwargs,
        )
        self.ls1 = _make_layer_scale()
        self.drop_path1 = _make_drop_path()

        # feed-forward branch
        self.norm2 = norm_layer(dim, **factory_kwargs)
        self.mlp = ConvMlp(
            in_features=dim,
            hidden_features=int(dim * cfg.expand_ratio),
            act_layer=act_layer,
            drop=cfg.proj_drop,
            **factory_kwargs,
        )
        self.ls2 = _make_layer_scale()
        self.drop_path2 = _make_drop_path()

    def _partition_attn(self, x: torch.Tensor) -> torch.Tensor:
        """Partition ``x``, run attention on the partitions and restore the original layout."""
        img_size = x.shape[-2:]
        if self.partition_block:
            windows = window_partition_nchw(x, self.partition_size)
            windows = self.attn(windows)
            return window_reverse_nchw(windows, self.partition_size, img_size)
        windows = grid_partition_nchw(x, self.partition_size)
        windows = self.attn(windows)
        return grid_reverse_nchw(windows, self.partition_size, img_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the pre-norm attention and MLP residual branches in sequence."""
        attn_out = self._partition_attn(self.norm1(x))
        x = x + self.drop_path1(self.ls1(attn_out))
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path2(self.ls2(mlp_out))
        return x
class MaxxVitBlock(nn.Module):
    """MaxVit conv, window partition + FFN , grid partition + FFN."""

    def __init__(
            self,
            dim: int,
            dim_out: int,
            stride: int = 1,
            conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
            transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            drop_path: float = 0.,
            device=None,
            dtype=None,
    ):
        """Initialize MaxxVitBlock.

        Args:
            dim: Input channel dimension.
            dim_out: Output channel dimension.
            stride: Stride for downsampling, applied by the conv block.
            conv_cfg: Configuration for convolutional blocks.
            transformer_cfg: Configuration for transformer blocks.
            drop_path: Drop path rate.
            device: Device forwarded to created layers.
            dtype: Dtype forwarded to created layers.
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.nchw_attn = transformer_cfg.use_nchw_attn

        if conv_cfg.block_type == 'convnext':
            conv_cls = ConvNeXtBlock
        else:
            conv_cls = MbConvBlock
        self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path, **factory_kwargs)

        partition_layer = PartitionAttention2d if self.nchw_attn else PartitionAttentionCl
        make_attn = partial(
            partition_layer,
            dim=dim_out,
            cfg=transformer_cfg,
            drop_path=drop_path,
            **factory_kwargs,
        )
        # block (window) attention is optional, grid attention is always present
        self.attn_block = None if transformer_cfg.no_block_attn else make_attn()
        self.attn_grid = make_attn(partition_type='grid')

    def init_weights(self, scheme=''):
        """Initialize attention sub-blocks with ``_init_transformer`` and the conv block with ``_init_conv``."""
        init_transformer = partial(_init_transformer, scheme=scheme)
        if self.attn_block is not None:
            named_apply(init_transformer, self.attn_block)
        named_apply(init_transformer, self.attn_grid)
        named_apply(partial(_init_conv, scheme=scheme), self.conv)

    def forward(self, x):
        """Run conv then attention; input and output are NCHW."""
        x = self.conv(x)
        channels_last = not self.nchw_attn
        if channels_last:
            x = x.permute(0, 2, 3, 1)  # to NHWC (channels-last)
        if self.attn_block is not None:
            x = self.attn_block(x)
        x = self.attn_grid(x)
        if channels_last:
            x = x.permute(0, 3, 1, 2)  # back to NCHW
        return x
class ParallelMaxxVitBlock(nn.Module):
    """MaxVit block with parallel cat(window + grid), one FF.

    Experimental timm block.
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            stride: int = 1,
            num_conv: int = 2,
            conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
            transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            drop_path: float = 0.,
            device=None,
            dtype=None,
    ):
        """
        Args:
            dim: Input dimension.
            dim_out: Output dimension.
            stride: Stride for first conv block.
            num_conv: Number of convolution blocks.
            conv_cfg: Convolution block configuration.
            transformer_cfg: Transformer block configuration.
            drop_path: Drop path rate.
            device: Device forwarded to created layers.
            dtype: Dtype forwarded to created layers.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()

        conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
        if num_conv > 1:
            convs = [conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path, **dd)]
            # Build a distinct module per extra block. Repeating a one-element list with `*`
            # would place the same module instance (and thus shared weights) in every slot.
            convs += [
                conv_cls(dim_out, dim_out, cfg=conv_cfg, drop_path=drop_path, **dd)
                for _ in range(num_conv - 1)
            ]
            self.conv = nn.Sequential(*convs)
        else:
            self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path, **dd)
        self.attn = ParallelPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path, **dd)

    def init_weights(self, scheme: str = '') -> None:
        """Initialize the attention block with ``_init_transformer`` and the conv block(s) with ``_init_conv``."""
        named_apply(partial(_init_transformer, scheme=scheme), self.attn)
        named_apply(partial(_init_conv, scheme=scheme), self.conv)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run conv block(s) in NCHW, attention in the permuted channels-last layout, return NCHW."""
        x = self.conv(x)
        x = x.permute(0, 2, 3, 1)
        x = self.attn(x)
        x = x.permute(0, 3, 1, 2)
        return x
class MaxxVitStage(nn.Module):
    """MaxxVit stage consisting of mixed convolution and transformer blocks."""

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int = 2,
            depth: int = 4,
            feat_size: Tuple[int, int] = (14, 14),
            block_types: Union[str, Tuple[str]] = 'C',
            transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
            drop_path: Union[float, List[float]] = 0.,
            device=None,
            dtype=None,
    ):
        """
        Args:
            in_chs: Input channels.
            out_chs: Output channels.
            stride: Stride for first block, remaining blocks use stride 1.
            depth: Number of blocks in stage.
            feat_size: Feature map size, used as the relative position window size of 'T' blocks.
            block_types: Block type(s), extended to ``depth`` entries: 'C' conv block,
                'T' TransformerBlock2d, 'M' MaxxVitBlock, 'PM' ParallelMaxxVitBlock.
            transformer_cfg: Transformer block configuration.
            conv_cfg: Convolution block configuration.
            drop_path: Drop path rate(s). A scalar is applied to every block, a sequence
                is indexed per block.
            device: Device forwarded to created layers.
            dtype: Dtype forwarded to created layers.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.grad_checkpointing = False

        block_types = extend_tuple(block_types, depth)
        if isinstance(drop_path, (int, float)):
            # The signature allows (and defaults to) a scalar rate; broadcast it so the
            # per-block indexing below works instead of raising TypeError.
            drop_path = [drop_path] * len(block_types)

        blocks = []
        for i, t in enumerate(block_types):
            block_stride = stride if i == 0 else 1
            assert t in ('C', 'T', 'M', 'PM')
            if t == 'C':
                conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
                blocks += [conv_cls(
                    in_chs,
                    out_chs,
                    stride=block_stride,
                    cfg=conv_cfg,
                    drop_path=drop_path[i],
                    **dd,
                )]
            elif t == 'T':
                rel_pos_cls = get_rel_pos_cls(transformer_cfg, feat_size)
                blocks += [TransformerBlock2d(
                    in_chs,
                    out_chs,
                    stride=block_stride,
                    rel_pos_cls=rel_pos_cls,
                    cfg=transformer_cfg,
                    drop_path=drop_path[i],
                    **dd,
                )]
            elif t == 'M':
                blocks += [MaxxVitBlock(
                    in_chs,
                    out_chs,
                    stride=block_stride,
                    conv_cfg=conv_cfg,
                    transformer_cfg=transformer_cfg,
                    drop_path=drop_path[i],
                    **dd,
                )]
            elif t == 'PM':
                blocks += [ParallelMaxxVitBlock(
                    in_chs,
                    out_chs,
                    stride=block_stride,
                    conv_cfg=conv_cfg,
                    transformer_cfg=transformer_cfg,
                    drop_path=drop_path[i],
                    **dd,
                )]
            # every block after the first maps out_chs -> out_chs
            in_chs = out_chs
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run all blocks, via ``checkpoint_seq`` when gradient checkpointing is enabled and not scripting."""
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x
class Stem(nn.Module):
    """Two-conv stem: stride-2 conv, norm + activation, then a stride-1 conv."""

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            padding: str = '',
            bias: bool = False,
            act_layer: str = 'gelu',
            norm_layer: str = 'batchnorm2d',
            norm_eps: float = 1e-5,
            device=None,
            dtype=None,
    ):
        """
        Args:
            in_chs: Input channels.
            out_chs: Output channels; a list / tuple gives per-conv widths, any other
                value is expanded with ``to_2tuple``.
            kernel_size: Kernel size for both convolutions.
            padding: Padding mode passed to ``create_conv2d``.
            bias: Whether the convolutions use a bias.
            act_layer: Activation layer name.
            norm_layer: Normalization layer name.
            norm_eps: Normalization epsilon.
            device: Device forwarded to created layers.
            dtype: Dtype forwarded to created layers.
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        if not isinstance(out_chs, (list, tuple)):
            out_chs = to_2tuple(out_chs)
        norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)

        self.out_chs = out_chs[-1]
        self.stride = 2  # only the first conv is strided

        self.conv1 = create_conv2d(
            in_chs,
            out_chs[0],
            kernel_size,
            stride=2,
            padding=padding,
            bias=bias,
            **factory_kwargs,
        )
        self.norm1 = norm_act_layer(out_chs[0], **factory_kwargs)
        self.conv2 = create_conv2d(
            out_chs[0],
            out_chs[1],
            kernel_size,
            stride=1,
            padding=padding,
            bias=bias,
            **factory_kwargs,
        )

    def init_weights(self, scheme: str = '') -> None:
        """Apply ``_init_conv`` with the given scheme to this module and its children."""
        init_fn = partial(_init_conv, scheme=scheme)
        named_apply(init_fn, self)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``conv2(norm1(conv1(x)))``."""
        return self.conv2(self.norm1(self.conv1(x)))
def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]) -> MaxxVitTransformerCfg:
    """Fill in window / grid size from the image size when the config leaves them unset.

    Args:
        cfg: Transformer configuration.
        img_size: Image (height, width).

    Returns:
        ``cfg`` itself when ``window_size`` is already set (``grid_size`` is asserted to be
        set too), otherwise a copy whose ``window_size`` and ``grid_size`` both equal
        ``img_size`` floor-divided by ``cfg.partition_ratio``.
    """
    if cfg.window_size is None:
        ratio = cfg.partition_ratio
        derived = (img_size[0] // ratio, img_size[1] // ratio)
        return replace(cfg, window_size=derived, grid_size=derived)
    assert cfg.grid_size
    return cfg
- def _overlay_kwargs(cfg: MaxxVitCfg, **kwargs: Any) -> MaxxVitCfg:
- """Overlay keyword arguments onto configuration."""
- transformer_kwargs = {}
- conv_kwargs = {}
- base_kwargs = {}
- for k, v in kwargs.items():
- if k.startswith('transformer_'):
- transformer_kwargs[k.replace('transformer_', '')] = v
- elif k.startswith('conv_'):
- conv_kwargs[k.replace('conv_', '')] = v
- else:
- base_kwargs[k] = v
- cfg = replace(
- cfg,
- transformer_cfg=replace(cfg.transformer_cfg, **transformer_kwargs),
- conv_cfg=replace(cfg.conv_cfg, **conv_kwargs),
- **base_kwargs
- )
- return cfg
class MaxxVit(nn.Module):
    """CoaTNet + MaxVit base model.

    Highly configurable for different block compositions, tensor layouts, pooling types.
    """

    def __init__(
            self,
            cfg: MaxxVitCfg,
            img_size: Union[int, Tuple[int, int]] = 224,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            device=None,
            dtype=None,
            **kwargs: Any,
    ):
        """
        Args:
            cfg: Model configuration.
            img_size: Input image size.
            in_chans: Number of input channels.
            num_classes: Number of classification classes.
            global_pool: Global pooling type.
            drop_rate: Dropout rate.
            drop_path_rate: Drop path rate.
            device: Device forwarded to created layers.
            dtype: Dtype forwarded to created layers.
            **kwargs: Additional keyword arguments to overlay on config.
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        img_size = to_2tuple(img_size)
        if kwargs:
            cfg = _overlay_kwargs(cfg, **kwargs)
        # resolve window / grid size from the image size when the config leaves them unset
        transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size)
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.global_pool = global_pool
        self.num_features = self.embed_dim = cfg.embed_dim[-1]
        self.drop_rate = drop_rate
        self.grad_checkpointing = False
        self.feature_info = []

        self.stem = Stem(
            in_chs=in_chans,
            out_chs=cfg.stem_width,
            padding=cfg.conv_cfg.padding,
            bias=cfg.stem_bias,
            act_layer=cfg.conv_cfg.act_layer,
            norm_layer=cfg.conv_cfg.norm_layer,
            norm_eps=cfg.conv_cfg.norm_eps,
            **dd,
        )
        stride = self.stem.stride
        self.feature_info += [dict(num_chs=self.stem.out_chs, reduction=2, module='stem')]
        feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))])

        num_stages = len(cfg.embed_dim)
        assert len(cfg.depths) == num_stages
        dpr = calculate_drop_path_rates(drop_path_rate, cfg.depths, stagewise=True)
        in_chs = self.stem.out_chs
        stages = []
        for i in range(num_stages):
            stage_stride = 2
            out_chs = cfg.embed_dim[i]
            # feature map size after this stage's stride (rounded up)
            feat_size = tuple([(r - 1) // stage_stride + 1 for r in feat_size])
            stages += [MaxxVitStage(
                in_chs,
                out_chs,
                depth=cfg.depths[i],
                block_types=cfg.block_type[i],
                conv_cfg=cfg.conv_cfg,
                transformer_cfg=transformer_cfg,
                feat_size=feat_size,
                drop_path=dpr[i],
                **dd,
            )]
            stride *= stage_stride
            in_chs = out_chs
            self.feature_info += [dict(num_chs=out_chs, reduction=stride, module=f'stages.{i}')]
        self.stages = nn.Sequential(*stages)

        final_norm_layer = partial(get_norm_layer(cfg.transformer_cfg.norm_layer), eps=cfg.transformer_cfg.norm_eps)
        if cfg.head_hidden_size:
            # the norm layer is handed to the head here, so the model-level norm is a no-op
            self.norm = nn.Identity()
            self.head_hidden_size = cfg.head_hidden_size
            self.head = NormMlpClassifierHead(
                self.num_features,
                num_classes,
                hidden_size=self.head_hidden_size,
                pool_type=global_pool,
                drop_rate=drop_rate,
                norm_layer=final_norm_layer,
                **dd,
            )
        else:
            # standard classifier head w/ norm, pooling, fc classifier
            self.head_hidden_size = self.num_features
            self.norm = final_norm_layer(self.num_features, **dd)
            self.head = ClassifierHead(
                self.num_features,
                num_classes,
                pool_type=global_pool,
                drop_rate=drop_rate,
                **dd,
            )

        # Weight init (default PyTorch init works well for AdamW if scheme not set)
        assert cfg.weight_init in ('', 'normal', 'trunc_normal', 'xavier_normal', 'vit_eff')
        if cfg.weight_init:
            named_apply(partial(self._init_weights, scheme=cfg.weight_init), self)

    def _init_weights(self, module: nn.Module, name: str, scheme: str = '') -> None:
        """Call ``module.init_weights`` if the module defines it, passing ``scheme`` when accepted."""
        if hasattr(module, 'init_weights'):
            try:
                module.init_weights(scheme=scheme)
            except TypeError:
                # fall back for init_weights() implementations without a scheme argument
                module.init_weights()

    @torch.jit.ignore
    def no_weight_decay(self) -> Set[str]:
        """Return names of relative position parameters to exclude from weight decay."""
        return {
            k for k, _ in self.named_parameters()
            if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])}

    @torch.jit.ignore
    def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
        """Return regex patterns grouping parameters into stem and per-stage blocks."""
        matcher = dict(
            stem=r'^stem',  # stem and embed
            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True) -> None:
        """Enable or disable gradient checkpointing on every stage."""
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        """Return the final classifier layer of the head."""
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
        """Reset the classifier head for a new number of classes and optional pooling type."""
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features

        Returns:
            The list of selected intermediates if ``intermediates_only`` is set, otherwise a
            tuple of (features after the last executed stage, list of selected intermediates).
            The final norm is applied to the features only when the last stage was executed.
        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)

        # forward pass
        feat_idx = 0  # stem is index 0
        x = self.stem(x)
        if feat_idx in take_indices:
            intermediates.append(x)

        last_idx = len(self.stages)
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.stages
        else:
            stages = self.stages[:max_index]
        for stage in stages:
            feat_idx += 1
            x = stage(x)
            if feat_idx in take_indices:
                if norm and feat_idx == last_idx:
                    x_inter = self.norm(x)  # applying final norm to last intermediate
                else:
                    x_inter = x
                intermediates.append(x_inter)

        if intermediates_only:
            return intermediates

        if feat_idx == last_idx:
            x = self.norm(x)

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ) -> Tuple[int, ...]:
        """Prune layers not required for specified intermediates.

        Args:
            indices: Indices of intermediate layers to keep.
            prune_norm: Replace the final norm with identity.
            prune_head: Reset the classifier head to zero classes and no pooling.

        Returns:
            The indices of the intermediates that are kept.
        """
        take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
        self.stages = self.stages[:max_index]  # truncate blocks w/ stem as idx 0
        if prune_norm:
            self.norm = nn.Identity()
        if prune_head:
            # reset_classifier() mutates the head in place and returns None, so its result
            # must not be assigned back to self.head.
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run stem, stages and final norm."""
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm(x)
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        """Run the classifier head, forwarding ``pre_logits`` only when it is set."""
        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the full model: features followed by the classifier head."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
def _rw_coat_cfg(
        stride_mode: str = 'pool',
        pool_type: str = 'avg2',
        conv_output_bias: bool = False,
        conv_attn_early: bool = False,
        conv_attn_act_layer: str = 'relu',
        conv_norm_layer: str = '',
        transformer_shortcut_bias: bool = True,
        transformer_norm_layer: str = 'layernorm2d',
        transformer_norm_layer_cl: str = 'layernorm',
        init_values: Optional[float] = None,
        rel_pos_type: str = 'bias',
        rel_pos_dim: int = 512,
) -> Dict[str, Any]:
    """RW variant configuration for CoAtNet models.

    These models were created and trained before seeing https://github.com/google-research/maxvit

    Common differences for initial timm models:
    - pre-norm layer in MBConv included an activation after norm
    - mbconv expansion calculated from input instead of output chs
    - mbconv shortcut and final 1x1 conv did not have a bias
    - SE act layer was relu, not silu
    - mbconv uses silu in timm, not gelu
    - expansion in attention block done via output proj, not input proj

    Variable differences (evolved over training initial models):
    - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat)
    - SE attention was between conv2 and norm/act
    - default to avg pool for mbconv downsample instead of 1x1 or dw conv
    - transformer block shortcut has no bias

    Returns:
        Dict with ``conv_cfg`` and ``transformer_cfg`` entries.
    """
    conv_cfg = MaxxVitConvCfg(
        stride_mode=stride_mode,
        pool_type=pool_type,
        pre_norm_act=True,
        expand_output=False,
        output_bias=conv_output_bias,
        attn_early=conv_attn_early,
        attn_act_layer=conv_attn_act_layer,
        act_layer='silu',
        norm_layer=conv_norm_layer,
    )
    transformer_cfg = MaxxVitTransformerCfg(
        expand_first=False,
        shortcut_bias=transformer_shortcut_bias,
        pool_type=pool_type,
        init_values=init_values,
        norm_layer=transformer_norm_layer,
        norm_layer_cl=transformer_norm_layer_cl,
        rel_pos_type=rel_pos_type,
        rel_pos_dim=rel_pos_dim,
    )
    return {'conv_cfg': conv_cfg, 'transformer_cfg': transformer_cfg}
def _rw_max_cfg(
        stride_mode: str = 'dw',
        pool_type: str = 'avg2',
        conv_output_bias: bool = False,
        conv_attn_ratio: float = 1 / 16,
        conv_norm_layer: str = '',
        transformer_norm_layer: str = 'layernorm2d',
        transformer_norm_layer_cl: str = 'layernorm',
        window_size: Optional[Tuple[int, int]] = None,
        dim_head: int = 32,
        init_values: Optional[float] = None,
        rel_pos_type: str = 'bias',
        rel_pos_dim: int = 512,
) -> Dict[str, Any]:
    """RW variant configuration for MaxViT models.

    These models were created and trained before seeing https://github.com/google-research/maxvit

    Differences of initial timm models:
    - mbconv expansion calculated from input instead of output chs
    - mbconv shortcut and final 1x1 conv did not have a bias
    - mbconv uses silu in timm, not gelu
    - expansion in attention block done via output proj, not input proj

    Returns:
        Dict with ``conv_cfg`` and ``transformer_cfg`` entries.
    """
    conv_cfg = MaxxVitConvCfg(
        stride_mode=stride_mode,
        pool_type=pool_type,
        expand_output=False,
        output_bias=conv_output_bias,
        attn_ratio=conv_attn_ratio,
        act_layer='silu',
        norm_layer=conv_norm_layer,
    )
    transformer_cfg = MaxxVitTransformerCfg(
        expand_first=False,
        pool_type=pool_type,
        dim_head=dim_head,
        window_size=window_size,
        init_values=init_values,
        norm_layer=transformer_norm_layer,
        norm_layer_cl=transformer_norm_layer_cl,
        rel_pos_type=rel_pos_type,
        rel_pos_dim=rel_pos_dim,
    )
    return {'conv_cfg': conv_cfg, 'transformer_cfg': transformer_cfg}
def _next_cfg(
        stride_mode: str = 'dw',
        pool_type: str = 'avg2',
        conv_norm_layer: str = 'layernorm2d',
        conv_norm_layer_cl: str = 'layernorm',
        transformer_norm_layer: str = 'layernorm2d',
        transformer_norm_layer_cl: str = 'layernorm',
        window_size: Optional[Tuple[int, int]] = None,
        no_block_attn: bool = False,
        init_values: Union[float, Tuple[float, float]] = 1e-6,
        rel_pos_type: str = 'mlp',  # MLP by default for maxxvit
        rel_pos_dim: int = 512,
) -> Dict[str, Any]:
    """Configuration for experimental ConvNeXt-based MaxxViT models.

    ``init_values`` is expanded with ``to_2tuple``; the first entry is used for the
    conv config and the second for the transformer config.

    Returns:
        Dict with ``conv_cfg`` and ``transformer_cfg`` entries.
    """
    init_values = to_2tuple(init_values)
    conv_cfg = MaxxVitConvCfg(
        block_type='convnext',
        stride_mode=stride_mode,
        pool_type=pool_type,
        expand_output=False,
        init_values=init_values[0],
        norm_layer=conv_norm_layer,
        norm_layer_cl=conv_norm_layer_cl,
    )
    transformer_cfg = MaxxVitTransformerCfg(
        expand_first=False,
        pool_type=pool_type,
        window_size=window_size,
        no_block_attn=no_block_attn,  # enabled for MaxxViT-V2
        init_values=init_values[1],
        norm_layer=transformer_norm_layer,
        norm_layer_cl=transformer_norm_layer_cl,
        rel_pos_type=rel_pos_type,
        rel_pos_dim=rel_pos_dim,
    )
    return {'conv_cfg': conv_cfg, 'transformer_cfg': transformer_cfg}
def _tf_cfg() -> Dict[str, Any]:
    """Configuration matching TensorFlow MaxViT models.

    Returns:
        Dict with ``conv_cfg`` and ``transformer_cfg`` entries.
    """
    conv_cfg = MaxxVitConvCfg(
        norm_eps=1e-3,
        act_layer='gelu_tanh',
        padding='same',
    )
    transformer_cfg = MaxxVitTransformerCfg(
        norm_eps=1e-5,
        act_layer='gelu_tanh',
        head_first=False,  # heads are interleaved (q_nh, q_hdim, k_nh, q_hdim, ....)
        rel_pos_type='bias_tf',
    )
    return {'conv_cfg': conv_cfg, 'transformer_cfg': transformer_cfg}
# Architecture configs keyed by config name; `_create_maxxvit` looks entries up here.
model_cfgs = {
    # timm specific CoAtNet configs
    'coatnet_pico_rw': MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 3, 5, 2),
        stem_width=(32, 64),
        # using newer max defaults here
        **_rw_max_cfg(conv_output_bias=True, conv_attn_ratio=0.25),
    ),
    'coatnet_nano_rw': MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(3, 4, 6, 3),
        stem_width=(32, 64),
        # using newer max defaults here
        **_rw_max_cfg(stride_mode='pool', conv_output_bias=True, conv_attn_ratio=0.25),
    ),
    'coatnet_0_rw': MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 7, 2),  # deeper than paper '0' model
        stem_width=(32, 64),
        **_rw_coat_cfg(conv_attn_early=True, transformer_shortcut_bias=False),
    ),
    'coatnet_1_rw': MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        stem_width=(32, 64),
        **_rw_coat_cfg(stride_mode='dw', conv_attn_early=True, transformer_shortcut_bias=False),
    ),
    'coatnet_2_rw': MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        stem_width=(64, 128),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_act_layer='silu',
            #init_values=1e-6,
        ),
    ),
    'coatnet_3_rw': MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 6, 14, 2),
        stem_width=(96, 192),
        **_rw_coat_cfg(stride_mode='dw', conv_attn_act_layer='silu', init_values=1e-6),
    ),

    # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
    'coatnet_bn_0_rw': MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 7, 2),  # deeper than paper '0' model
        stem_width=(32, 64),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_early=True,
            transformer_shortcut_bias=False,
            transformer_norm_layer='batchnorm2d',
        ),
    ),
    'coatnet_rmlp_nano_rw': MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(3, 4, 6, 3),
        stem_width=(32, 64),
        **_rw_max_cfg(
            conv_output_bias=True,
            conv_attn_ratio=0.25,
            rel_pos_type='mlp',
            rel_pos_dim=384,
        ),
    ),
    'coatnet_rmlp_0_rw': MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 7, 2),  # deeper than paper '0' model
        stem_width=(32, 64),
        **_rw_coat_cfg(stride_mode='dw', rel_pos_type='mlp'),
    ),
    'coatnet_rmlp_1_rw': MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        stem_width=(32, 64),
        **_rw_coat_cfg(
            pool_type='max',
            conv_attn_early=True,
            transformer_shortcut_bias=False,
            rel_pos_type='mlp',
            rel_pos_dim=384,  # was supposed to be 512, woops
        ),
    ),
    'coatnet_rmlp_1_rw2': MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        stem_width=(32, 64),
        **_rw_coat_cfg(
            stride_mode='dw',
            rel_pos_type='mlp',
            rel_pos_dim=512,  # the intended 512 (coatnet_rmlp_1_rw above has 384)
        ),
    ),
    'coatnet_rmlp_2_rw': MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        stem_width=(64, 128),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_act_layer='silu',
            init_values=1e-6,
            rel_pos_type='mlp',
        ),
    ),
    'coatnet_rmlp_3_rw': MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 6, 14, 2),
        stem_width=(96, 192),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_act_layer='silu',
            init_values=1e-6,
            rel_pos_type='mlp',
        ),
    ),

    'coatnet_nano_cc': MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(3, 4, 6, 3),
        stem_width=(32, 64),
        block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
        **_rw_coat_cfg(),
    ),
    'coatnext_nano_rw': MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(3, 4, 6, 3),
        stem_width=(32, 64),
        weight_init='normal',
        **_next_cfg(rel_pos_type='bias', init_values=(1e-5, None)),
    ),

    # Trying to be like the CoAtNet paper configs
    'coatnet_0': MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 5, 2),
        stem_width=64,
        head_hidden_size=768,
    ),
    'coatnet_1': MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        stem_width=64,
        head_hidden_size=768,
    ),
    'coatnet_2': MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        stem_width=128,
        head_hidden_size=1024,
    ),
    'coatnet_3': MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 6, 14, 2),
        stem_width=192,
        head_hidden_size=1536,
    ),
    'coatnet_4': MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 12, 28, 2),
        stem_width=192,
        head_hidden_size=1536,
    ),
    'coatnet_5': MaxxVitCfg(
        embed_dim=(256, 512, 1280, 2048),
        depths=(2, 12, 28, 2),
        stem_width=192,
        head_hidden_size=2048,
    ),

    # Experimental MaxVit configs
    'maxvit_pico_rw': MaxxVitCfg(
        embed_dim=(32, 64, 128, 256),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(24, 32),
        **_rw_max_cfg(),
    ),
    'maxvit_nano_rw': MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(1, 2, 3, 1),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(),
    ),
    'maxvit_tiny_rw': MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(),
    ),
    'maxvit_tiny_pm': MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('PM',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(),
    ),

    'maxvit_rmlp_pico_rw': MaxxVitCfg(
        embed_dim=(32, 64, 128, 256),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(24, 32),
        **_rw_max_cfg(rel_pos_type='mlp'),
    ),
    'maxvit_rmlp_nano_rw': MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(1, 2, 3, 1),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(rel_pos_type='mlp'),
    ),
    'maxvit_rmlp_tiny_rw': MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(rel_pos_type='mlp'),
    ),
    'maxvit_rmlp_small_rw': MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(rel_pos_type='mlp', init_values=1e-6),
    ),
    'maxvit_rmlp_base_rw': MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        head_hidden_size=768,
        **_rw_max_cfg(rel_pos_type='mlp'),
    ),

    'maxxvit_rmlp_nano_rw': MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(1, 2, 3, 1),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        weight_init='normal',
        **_next_cfg(),
    ),
    'maxxvit_rmlp_tiny_rw': MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_next_cfg(),
    ),
    'maxxvit_rmlp_small_rw': MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(48, 96),
        **_next_cfg(),
    ),

    'maxxvitv2_nano_rw': MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(1, 2, 3, 1),
        block_type=('M',) * 4,
        stem_width=(48, 96),
        weight_init='normal',
        **_next_cfg(no_block_attn=True, rel_pos_type='bias'),
    ),
    'maxxvitv2_rmlp_base_rw': MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 12, 2),
        block_type=('M',) * 4,
        stem_width=(64, 128),
        **_next_cfg(no_block_attn=True),
    ),
    'maxxvitv2_rmlp_large_rw': MaxxVitCfg(
        embed_dim=(160, 320, 640, 1280),
        depths=(2, 6, 16, 2),
        block_type=('M',) * 4,
        stem_width=(80, 160),
        head_hidden_size=1280,
        **_next_cfg(no_block_attn=True),
    ),

    # Trying to be like the MaxViT paper configs
    'maxvit_tiny_tf': MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=64,
        stem_bias=True,
        head_hidden_size=512,
        **_tf_cfg(),
    ),
    'maxvit_small_tf': MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=64,
        stem_bias=True,
        head_hidden_size=768,
        **_tf_cfg(),
    ),
    'maxvit_base_tf': MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        block_type=('M',) * 4,
        stem_width=64,
        stem_bias=True,
        head_hidden_size=768,
        **_tf_cfg(),
    ),
    'maxvit_large_tf': MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        block_type=('M',) * 4,
        stem_width=128,
        stem_bias=True,
        head_hidden_size=1024,
        **_tf_cfg(),
    ),
    'maxvit_xlarge_tf': MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 6, 14, 2),
        block_type=('M',) * 4,
        stem_width=192,
        stem_bias=True,
        head_hidden_size=1536,
        **_tf_cfg(),
    ),
}
def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
    """Adapt a checkpoint state dict so its tensors fit `model`.

    Two adjustments are applied per entry, in this order:

    1. Keys ending in 'relative_position_bias_table' are passed through
       `resize_rel_pos_bias_table` when the checkpoint table's shape differs from
       the owning submodule's table, or when that submodule's window is non-square.
    2. A tensor whose rank differs from the model's tensor of the same name, but
       whose element count matches, is reshaped to the model's shape
       (adapts between conv2d / linear layers).

    Args:
        state_dict: Checkpoint tensors keyed by parameter/buffer name.
        model: Target module; supplies reference shapes and submodule lookup.

    Returns:
        New dict holding every key of `state_dict`, in the same order.
    """
    table_suffix = 'relative_position_bias_table'
    # Also strip the '.' that separates the submodule path from the attribute name.
    strip_len = len(table_suffix) + 1
    reference = model.state_dict()

    filtered = {}
    for name, tensor in state_dict.items():
        if name.endswith(table_suffix):
            owner = model.get_submodule(name[:-strip_len])
            target_shape = owner.relative_position_bias_table.shape
            if tensor.shape != target_shape or owner.window_size[0] != owner.window_size[1]:
                tensor = resize_rel_pos_bias_table(
                    tensor,
                    new_window_size=owner.window_size,
                    new_bias_shape=owner.relative_position_bias_table.shape,
                )

        if name in reference:
            expected = reference[name]
            if tensor.ndim != expected.ndim and tensor.numel() == expected.numel():
                # adapt between conv2d / linear layers
                assert tensor.ndim in (2, 4)
                tensor = tensor.reshape(expected.shape)

        filtered[name] = tensor
    return filtered
def _create_maxxvit(variant: str, cfg_variant: Optional[str] = None, pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Build a MaxxVit model for `variant` through `build_model_with_cfg`.

    Args:
        variant: Registered model name, e.g. 'coatnet_0_rw_224'.
        cfg_variant: Key into `model_cfgs`. When None, `variant` itself is used if
            it is a key; otherwise everything before the last '_' of `variant`
            is used (e.g. 'coatnet_0_rw_224' -> 'coatnet_0_rw').
        pretrained: Forwarded positionally to `build_model_with_cfg`.
        **kwargs: Forwarded unchanged to `build_model_with_cfg`.

    Returns:
        Whatever `build_model_with_cfg` returns for the `MaxxVit` class.

    Raises:
        KeyError: If the resolved config name is not in `model_cfgs`.
    """
    if cfg_variant is None:
        # rpartition keeps everything before the final '_' ('' when there is none)
        cfg_variant = variant if variant in model_cfgs else variant.rpartition('_')[0]

    return build_model_with_cfg(
        MaxxVit,
        variant,
        pretrained,
        model_cfg=model_cfgs[cfg_variant],
        feature_cfg={'flatten_sequential': True},
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs,
    )
- def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
- """Create a default configuration dict."""
- return {
- 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
- 'crop_pct': 0.95, 'interpolation': 'bicubic',
- 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
- 'first_conv': 'stem.conv1', 'classifier': 'head.fc',
- 'fixed_input_size': True,
- 'license': 'apache-2.0', **kwargs
- }
# Pretrained configs keyed '<model_name>.<tag>'. Each entry starts from the `_cfg`
# defaults (224x224 input, 7x7 pool, 0.5 mean/std) and overrides what differs.
default_cfgs = generate_default_cfgs({
    # timm specific CoAtNet configs, ImageNet-1k pretrain, fixed rel-pos
    'coatnet_pico_rw_224.untrained': _cfg(url=''),
    'coatnet_nano_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth',
        crop_pct=0.9),
    'coatnet_0_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'),
    'coatnet_1_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
    ),

    # timm specific CoAtNet configs, ImageNet-12k pretrain w/ 1k fine-tune, fixed rel-pos
    'coatnet_2_rw_224.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    #'coatnet_3_rw_224.untrained': _cfg(url=''),

    # Experimental CoAtNet configs w/ ImageNet-12k pretrain -> 1k fine-tune (different norm layers, MLP rel-pos)
    'coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),

    # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
    'coatnet_bn_0_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        crop_pct=0.95),
    'coatnet_rmlp_nano_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth',
        crop_pct=0.9),
    'coatnet_rmlp_0_rw_224.untrained': _cfg(url=''),
    'coatnet_rmlp_1_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
    'coatnet_rmlp_2_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'),
    'coatnet_rmlp_3_rw_224.untrained': _cfg(url=''),
    'coatnet_nano_cc_224.untrained': _cfg(url=''),
    'coatnext_nano_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth',
        crop_pct=0.9),

    # ImageNet-12k pretrain CoAtNet
    'coatnet_2_rw_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821),
    'coatnet_3_rw_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821),
    'coatnet_rmlp_1_rw2_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821),
    'coatnet_rmlp_2_rw_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821),

    # Trying to be like the CoAtNet paper configs (will adapt if 'tf' weights are ever released)
    'coatnet_0_224.untrained': _cfg(url=''),
    'coatnet_1_224.untrained': _cfg(url=''),
    'coatnet_2_224.untrained': _cfg(url=''),
    'coatnet_3_224.untrained': _cfg(url=''),
    'coatnet_4_224.untrained': _cfg(url=''),
    'coatnet_5_224.untrained': _cfg(url=''),

    # timm specific MaxVit configs, ImageNet-1k pretrain or untrained
    'maxvit_pico_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_nano_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_tiny_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'),
    'maxvit_tiny_rw_256.untrained': _cfg(
        url='',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_tiny_pm_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),

    # timm specific MaxVit w/ MLP rel-pos, ImageNet-1k pretrain
    'maxvit_rmlp_pico_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_rmlp_tiny_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_rmlp_small_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth',
        crop_pct=0.9,
    ),
    'maxvit_rmlp_small_rw_256.untrained': _cfg(
        url='',
        input_size=(3, 256, 256), pool_size=(8, 8)),

    # timm specific MaxVit w/ ImageNet-12k pretrain and 1k fine-tune
    'maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),

    # timm specific MaxVit w/ ImageNet-12k pretrain
    'maxvit_rmlp_base_rw_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821,
    ),

    # timm MaxxViT configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks)
    'maxxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxxvit_rmlp_tiny_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxxvit_rmlp_small_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),

    # timm MaxxViT-V2 configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks, more width, no block attn)
    'maxxvitv2_nano_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxxvitv2_rmlp_large_rw_224.untrained': _cfg(url=''),

    'maxxvitv2_rmlp_base_rw_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821),

    # MaxViT models ported from official Tensorflow impl
    'maxvit_tiny_tf_224.in1k': _cfg(
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'maxvit_tiny_tf_384.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_tiny_tf_512.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
    'maxvit_small_tf_224.in1k': _cfg(
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'maxvit_small_tf_384.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_small_tf_512.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
    'maxvit_base_tf_224.in1k': _cfg(
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'maxvit_base_tf_384.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_base_tf_512.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
    'maxvit_large_tf_224.in1k': _cfg(
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'maxvit_large_tf_384.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_large_tf_512.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),

    'maxvit_base_tf_224.in21k': _cfg(
        hf_hub_id='timm/',
        num_classes=21843),
    'maxvit_base_tf_384.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_base_tf_512.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
    'maxvit_large_tf_224.in21k': _cfg(
        hf_hub_id='timm/',
        num_classes=21843),
    'maxvit_large_tf_384.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_large_tf_512.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        # pool_size was missing here and fell back to the (7, 7) default;
        # every other 512x512 entry in this table uses (16, 16).
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
    'maxvit_xlarge_tf_224.in21k': _cfg(
        hf_hub_id='timm/',
        num_classes=21843),
    'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
})
@register_model
def coatnet_pico_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the timm-specific ('rw') CoAtNet Pico model; default config uses 224x224 input."""
    variant = 'coatnet_pico_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_nano_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the timm-specific ('rw') CoAtNet Nano model; default config uses 224x224 input."""
    variant = 'coatnet_nano_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_0_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the timm-specific ('rw') CoAtNet-0 model; default config uses 224x224 input."""
    variant = 'coatnet_0_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_1_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the timm-specific ('rw') CoAtNet-1 model; default config uses 224x224 input."""
    variant = 'coatnet_1_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_2_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the timm-specific ('rw') CoAtNet-2 model; default config uses 224x224 input."""
    variant = 'coatnet_2_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_3_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the timm-specific ('rw') CoAtNet-3 model; default config uses 224x224 input."""
    variant = 'coatnet_3_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_bn_0_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the 'rw' CoAtNet-0 model whose transformer norm layer is 'batchnorm2d' (224x224)."""
    variant = 'coatnet_bn_0_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_nano_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the 'rw' CoAtNet Nano model with MLP relative position (rel_pos_dim=384), 224x224."""
    variant = 'coatnet_rmlp_nano_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_0_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the 'rw' CoAtNet-0 model with MLP relative position; default config uses 224x224 input."""
    variant = 'coatnet_rmlp_0_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_1_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the 'rw' CoAtNet-1 model with MLP relative position (rel_pos_dim=384), 224x224."""
    variant = 'coatnet_rmlp_1_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_1_rw2_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the 'rw2' CoAtNet-1 model with MLP relative position (rel_pos_dim=512), 224x224."""
    variant = 'coatnet_rmlp_1_rw2_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_2_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the 'rw' CoAtNet-2 model with MLP relative position; default config uses 224x224 input."""
    variant = 'coatnet_rmlp_2_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_2_rw_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the 'rw' CoAtNet-2 model with MLP relative position; default config uses 384x384 input."""
    variant = 'coatnet_rmlp_2_rw_384'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_3_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the 'rw' CoAtNet-3 model with MLP relative position; default config uses 224x224 input."""
    variant = 'coatnet_rmlp_3_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_nano_cc_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the CoAtNet Nano 'cc' model, whose last two stages use ('C', 'T') block pairs (224x224)."""
    variant = 'coatnet_nano_cc_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnext_nano_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the CoAtNeXt Nano 'rw' model (ConvNeXt-based conv config, 'bias' rel-pos), 224x224."""
    variant = 'coatnext_nano_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_0_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the paper-style CoAtNet-0 model; default config uses 224x224 input."""
    variant = 'coatnet_0_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_1_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the paper-style CoAtNet-1 model; default config uses 224x224 input."""
    variant = 'coatnet_1_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_2_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the paper-style CoAtNet-2 model; default config uses 224x224 input."""
    variant = 'coatnet_2_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_3_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the paper-style CoAtNet-3 model; default config uses 224x224 input."""
    variant = 'coatnet_3_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_4_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the paper-style CoAtNet-4 model; default config uses 224x224 input."""
    variant = 'coatnet_4_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_5_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the paper-style CoAtNet-5 model; default config uses 224x224 input."""
    variant = 'coatnet_5_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_pico_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the timm-specific ('rw') MaxViT Pico model; default config uses 256x256 input."""
    variant = 'maxvit_pico_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_nano_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the timm-specific ('rw') MaxViT Nano model; default config uses 256x256 input."""
    variant = 'maxvit_nano_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_tiny_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the timm-specific ('rw') MaxViT Tiny model; default config uses 224x224 input."""
    variant = 'maxvit_tiny_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_tiny_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the timm-specific ('rw') MaxViT Tiny model; default config uses 256x256 input."""
    variant = 'maxvit_tiny_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_pico_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the 'rw' MaxViT Pico model with MLP relative position; default config uses 256x256 input."""
    variant = 'maxvit_rmlp_pico_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_nano_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the 'rw' MaxViT Nano model with MLP relative position; default config uses 256x256 input."""
    variant = 'maxvit_rmlp_nano_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_tiny_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the 'rw' MaxViT Tiny model with MLP relative position; default config uses 256x256 input."""
    variant = 'maxvit_rmlp_tiny_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_small_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the 'rw' MaxViT Small model with MLP relative position; default config uses 224x224 input."""
    variant = 'maxvit_rmlp_small_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_small_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the 'rw' MaxViT Small model with MLP relative position; default config uses 256x256 input."""
    variant = 'maxvit_rmlp_small_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_base_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the 'rw' MaxViT Base model with MLP relative position; default config uses 224x224 input."""
    variant = 'maxvit_rmlp_base_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_base_rw_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the 'rw' MaxViT Base model with MLP relative position; default config uses 384x384 input."""
    variant = 'maxvit_rmlp_base_rw_384'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_tiny_pm_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the MaxViT Tiny model built from 'PM' (parallel) blocks; default config uses 256x256 input."""
    variant = 'maxvit_tiny_pm_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxxvit_rmlp_nano_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the MaxxViT Nano 'rw' model (ConvNeXt-based conv config, MLP rel-pos), 256x256."""
    variant = 'maxxvit_rmlp_nano_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxxvit_rmlp_tiny_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the MaxxViT Tiny 'rw' model (ConvNeXt-based conv config, MLP rel-pos), 256x256."""
    variant = 'maxxvit_rmlp_tiny_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxxvit_rmlp_small_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the MaxxViT Small 'rw' model (ConvNeXt-based conv config, MLP rel-pos), 256x256."""
    variant = 'maxxvit_rmlp_small_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxxvitv2_nano_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the MaxxViT-V2 Nano 'rw' model (no block attention, 'bias' rel-pos), 256x256."""
    variant = 'maxxvitv2_nano_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxxvitv2_rmlp_base_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the MaxxViT-V2 Base 'rw' model (no block attention, MLP rel-pos), 224x224."""
    variant = 'maxxvitv2_rmlp_base_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxxvitv2_rmlp_base_rw_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the MaxxViT-V2 Base 'rw' model (no block attention, MLP rel-pos), 384x384."""
    variant = 'maxxvitv2_rmlp_base_rw_384'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxxvitv2_rmlp_large_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create the MaxxViT-V2 Large 'rw' model (no block attention, MLP rel-pos), 224x224."""
    variant = 'maxxvitv2_rmlp_large_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_tiny_tf_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create MaxViT Tiny from the TensorFlow-matching 'maxvit_tiny_tf' config; 224x224 default input."""
    variant = 'maxvit_tiny_tf_224'
    return _create_maxxvit(variant, cfg_variant='maxvit_tiny_tf', pretrained=pretrained, **kwargs)
@register_model
def maxvit_tiny_tf_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create MaxViT Tiny from the TensorFlow-matching 'maxvit_tiny_tf' config; 384x384 default input."""
    variant = 'maxvit_tiny_tf_384'
    return _create_maxxvit(variant, cfg_variant='maxvit_tiny_tf', pretrained=pretrained, **kwargs)
@register_model
def maxvit_tiny_tf_512(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create MaxViT Tiny from the TensorFlow-matching 'maxvit_tiny_tf' config; 512x512 default input."""
    variant = 'maxvit_tiny_tf_512'
    return _create_maxxvit(variant, cfg_variant='maxvit_tiny_tf', pretrained=pretrained, **kwargs)
@register_model
def maxvit_small_tf_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create MaxViT Small from the TensorFlow-matching 'maxvit_small_tf' config; 224x224 default input."""
    variant = 'maxvit_small_tf_224'
    return _create_maxxvit(variant, cfg_variant='maxvit_small_tf', pretrained=pretrained, **kwargs)
@register_model
def maxvit_small_tf_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create MaxViT Small from the TensorFlow-matching 'maxvit_small_tf' config; 384x384 default input."""
    variant = 'maxvit_small_tf_384'
    return _create_maxxvit(variant, cfg_variant='maxvit_small_tf', pretrained=pretrained, **kwargs)
@register_model
def maxvit_small_tf_512(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create MaxViT Small from the TensorFlow-matching 'maxvit_small_tf' config; 512x512 default input."""
    variant = 'maxvit_small_tf_512'
    return _create_maxxvit(variant, cfg_variant='maxvit_small_tf', pretrained=pretrained, **kwargs)
@register_model
def maxvit_base_tf_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create MaxViT Base from the TensorFlow-matching 'maxvit_base_tf' config; 224x224 default input."""
    variant = 'maxvit_base_tf_224'
    return _create_maxxvit(variant, cfg_variant='maxvit_base_tf', pretrained=pretrained, **kwargs)
@register_model
def maxvit_base_tf_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create MaxViT Base from the TensorFlow-matching 'maxvit_base_tf' config; 384x384 default input."""
    variant = 'maxvit_base_tf_384'
    return _create_maxxvit(variant, cfg_variant='maxvit_base_tf', pretrained=pretrained, **kwargs)
@register_model
def maxvit_base_tf_512(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create MaxViT Base from the TensorFlow-matching 'maxvit_base_tf' config; 512x512 default input."""
    variant = 'maxvit_base_tf_512'
    return _create_maxxvit(variant, cfg_variant='maxvit_base_tf', pretrained=pretrained, **kwargs)
@register_model
def maxvit_large_tf_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create MaxViT Large from the TensorFlow-matching 'maxvit_large_tf' config; 224x224 default input."""
    variant = 'maxvit_large_tf_224'
    return _create_maxxvit(variant, cfg_variant='maxvit_large_tf', pretrained=pretrained, **kwargs)
@register_model
def maxvit_large_tf_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create MaxViT Large from the TensorFlow-matching 'maxvit_large_tf' config; 384x384 default input."""
    variant = 'maxvit_large_tf_384'
    return _create_maxxvit(variant, cfg_variant='maxvit_large_tf', pretrained=pretrained, **kwargs)
@register_model
def maxvit_large_tf_512(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create MaxViT Large from the TensorFlow-matching 'maxvit_large_tf' config; 512x512 default input."""
    variant = 'maxvit_large_tf_512'
    return _create_maxxvit(variant, cfg_variant='maxvit_large_tf', pretrained=pretrained, **kwargs)
@register_model
def maxvit_xlarge_tf_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create MaxViT XLarge from the TensorFlow-matching 'maxvit_xlarge_tf' config; 224x224 default input."""
    variant = 'maxvit_xlarge_tf_224'
    return _create_maxxvit(variant, cfg_variant='maxvit_xlarge_tf', pretrained=pretrained, **kwargs)
@register_model
def maxvit_xlarge_tf_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create MaxViT XLarge from the TensorFlow-matching 'maxvit_xlarge_tf' config; 384x384 default input."""
    variant = 'maxvit_xlarge_tf_384'
    return _create_maxxvit(variant, cfg_variant='maxvit_xlarge_tf', pretrained=pretrained, **kwargs)
@register_model
def maxvit_xlarge_tf_512(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
    """Create MaxViT XLarge from the TensorFlow-matching 'maxvit_xlarge_tf' config; 512x512 default input."""
    variant = 'maxvit_xlarge_tf_512'
    return _create_maxxvit(variant, cfg_variant='maxvit_xlarge_tf', pretrained=pretrained, **kwargs)
|