maxxvit.py 99 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711
  1. """ MaxVit and CoAtNet Vision Transformer - CNN Hybrids in PyTorch
  2. This is a from-scratch implementation of both CoAtNet and MaxVit in PyTorch.
  3. 99% of the implementation was done from papers, however last minute some adjustments were made
  4. based on the (as yet unfinished?) public code release https://github.com/google-research/maxvit
  5. There are multiple sets of models defined for both architectures. Typically, names with a
  6. `_rw` suffix are my own original configs prior to referencing https://github.com/google-research/maxvit.
  7. These configs work well and appear to be a bit faster / lower resource than the paper.
  8. The models without an extra prefix / suffix (coatnet_0_224, maxvit_tiny_224, etc.) are intended to
  9. match paper, BUT, without any official pretrained weights it's difficult to confirm a 100% match.
  10. Papers:
  11. MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697
  12. @article{tu2022maxvit,
  13. title={MaxViT: Multi-Axis Vision Transformer},
  14. author={Tu, Zhengzhong and Talebi, Hossein and Zhang, Han and Yang, Feng and Milanfar, Peyman and Bovik, Alan and Li, Yinxiao},
  15. journal={ECCV},
  16. year={2022},
  17. }
  18. CoAtNet: Marrying Convolution and Attention for All Data Sizes - https://arxiv.org/abs/2106.04803
  19. @article{DBLP:journals/corr/abs-2106-04803,
  20. author = {Zihang Dai and Hanxiao Liu and Quoc V. Le and Mingxing Tan},
  21. title = {CoAtNet: Marrying Convolution and Attention for All Data Sizes},
  22. journal = {CoRR},
  23. volume = {abs/2106.04803},
  24. year = {2021}
  25. }
  26. Hacked together by / Copyright 2022, Ross Wightman
  27. """
  28. import math
  29. from collections import OrderedDict
  30. from dataclasses import dataclass, replace, field
  31. from functools import partial
  32. from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
  33. import torch
  34. from torch import nn
  35. from torch.jit import Final
  36. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  37. from timm.layers import (
  38. Mlp,
  39. ConvMlp,
  40. DropPath,
  41. calculate_drop_path_rates,
  42. LayerNorm,
  43. LayerScale,
  44. LayerScale2d,
  45. ClassifierHead,
  46. NormMlpClassifierHead,
  47. create_attn,
  48. get_act_layer,
  49. get_norm_layer,
  50. get_norm_act_layer,
  51. create_conv2d,
  52. create_pool2d,
  53. trunc_normal_tf_,
  54. to_2tuple,
  55. extend_tuple,
  56. make_divisible,
  57. _assert,
  58. RelPosMlp,
  59. RelPosBias,
  60. RelPosBiasTf,
  61. use_fused_attn,
  62. resize_rel_pos_bias_table,
  63. )
  64. from ._builder import build_model_with_cfg
  65. from ._features import feature_take_indices
  66. from ._features_fx import register_notrace_function
  67. from ._manipulate import named_apply, checkpoint_seq
  68. from ._registry import generate_default_cfgs, register_model
  69. __all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit']
  70. @dataclass
  71. class MaxxVitTransformerCfg:
  72. """Configuration for MaxxVit transformer blocks."""
  73. dim_head: int = 32
  74. head_first: bool = True # head ordering in qkv channel dim
  75. expand_ratio: float = 4.0
  76. expand_first: bool = True
  77. shortcut_bias: bool = True
  78. attn_bias: bool = True
  79. attn_drop: float = 0.
  80. proj_drop: float = 0.
  81. pool_type: str = 'avg2'
  82. rel_pos_type: str = 'bias'
  83. rel_pos_dim: int = 512 # for relative position types w/ MLP
  84. partition_ratio: int = 32
  85. window_size: Optional[Tuple[int, int]] = None
  86. grid_size: Optional[Tuple[int, int]] = None
  87. no_block_attn: bool = False # disable window block attention for maxvit (ie only grid)
  88. use_nchw_attn: bool = False # for MaxViT variants (not used for CoAt), keep tensors in NCHW order
  89. init_values: Optional[float] = None
  90. act_layer: str = 'gelu'
  91. norm_layer: str = 'layernorm2d'
  92. norm_layer_cl: str = 'layernorm'
  93. norm_eps: float = 1e-6
  94. def __post_init__(self):
  95. if self.grid_size is not None:
  96. self.grid_size = to_2tuple(self.grid_size)
  97. if self.window_size is not None:
  98. self.window_size = to_2tuple(self.window_size)
  99. if self.grid_size is None:
  100. self.grid_size = self.window_size
  101. @dataclass
  102. class MaxxVitConvCfg:
  103. """Configuration for MaxxVit convolution blocks."""
  104. block_type: str = 'mbconv'
  105. expand_ratio: float = 4.0
  106. expand_output: bool = True # calculate expansion channels from output (vs input chs)
  107. kernel_size: int = 3
  108. group_size: int = 1 # 1 == depthwise
  109. pre_norm_act: bool = False # activation after pre-norm
  110. output_bias: bool = True # bias for shortcut + final 1x1 projection conv
  111. stride_mode: str = 'dw' # stride done via one of 'pool', '1x1', 'dw'
  112. pool_type: str = 'avg2'
  113. downsample_pool_type: str = 'avg2'
  114. padding: str = ''
  115. attn_early: bool = False # apply attn between conv2 and norm2, instead of after norm2
  116. attn_layer: str = 'se'
  117. attn_act_layer: str = 'silu'
  118. attn_ratio: float = 0.25
  119. init_values: Optional[float] = 1e-6 # for ConvNeXt block, ignored by MBConv
  120. act_layer: str = 'gelu'
  121. norm_layer: str = ''
  122. norm_layer_cl: str = ''
  123. norm_eps: Optional[float] = None
  124. def __post_init__(self):
  125. # mbconv vs convnext blocks have different defaults, set in post_init to avoid explicit config args
  126. assert self.block_type in ('mbconv', 'convnext')
  127. use_mbconv = self.block_type == 'mbconv'
  128. if not self.norm_layer:
  129. self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d'
  130. if not self.norm_layer_cl and not use_mbconv:
  131. self.norm_layer_cl = 'layernorm'
  132. if self.norm_eps is None:
  133. self.norm_eps = 1e-5 if use_mbconv else 1e-6
  134. self.downsample_pool_type = self.downsample_pool_type or self.pool_type
  135. @dataclass
  136. class MaxxVitCfg:
  137. """Configuration for MaxxVit models."""
  138. embed_dim: Tuple[int, ...] = (96, 192, 384, 768)
  139. depths: Tuple[int, ...] = (2, 3, 5, 2)
  140. block_type: Tuple[Union[str, Tuple[str, ...]], ...] = ('C', 'C', 'T', 'T')
  141. stem_width: Union[int, Tuple[int, int]] = 64
  142. stem_bias: bool = False
  143. conv_cfg: MaxxVitConvCfg = field(default_factory=MaxxVitConvCfg)
  144. transformer_cfg: MaxxVitTransformerCfg = field(default_factory=MaxxVitTransformerCfg)
  145. head_hidden_size: Optional[int] = None
  146. weight_init: str = 'vit_eff'
  147. class Attention2d(nn.Module):
  148. """Multi-head attention for 2D NCHW tensors."""
  149. fused_attn: Final[bool]
  150. def __init__(
  151. self,
  152. dim: int,
  153. dim_out: Optional[int] = None,
  154. dim_head: int = 32,
  155. bias: bool = True,
  156. expand_first: bool = True,
  157. head_first: bool = True,
  158. rel_pos_cls: Optional[Callable] = None,
  159. attn_drop: float = 0.,
  160. proj_drop: float = 0.,
  161. device=None,
  162. dtype=None,
  163. ):
  164. """
  165. Args:
  166. dim: Input dimension.
  167. dim_out: Output dimension (defaults to input dimension).
  168. dim_head: Dimension per attention head.
  169. bias: Whether to use bias in qkv and projection.
  170. expand_first: Whether to expand channels before or after qkv.
  171. head_first: Whether heads are first in tensor layout.
  172. rel_pos_cls: Relative position class to use.
  173. attn_drop: Attention dropout rate.
  174. proj_drop: Projection dropout rate.
  175. """
  176. dd = {'device': device, 'dtype': dtype}
  177. super().__init__()
  178. dim_out = dim_out or dim
  179. dim_attn = dim_out if expand_first else dim
  180. self.num_heads = dim_attn // dim_head
  181. self.dim_head = dim_head
  182. self.head_first = head_first
  183. self.scale = dim_head ** -0.5
  184. self.fused_attn = use_fused_attn()
  185. self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias, **dd)
  186. self.rel_pos = rel_pos_cls(num_heads=self.num_heads, **dd) if rel_pos_cls else None
  187. self.attn_drop = nn.Dropout(attn_drop)
  188. self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias, **dd)
  189. self.proj_drop = nn.Dropout(proj_drop)
  190. def forward(self, x: torch.Tensor, shared_rel_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
  191. B, C, H, W = x.shape
  192. if self.head_first:
  193. q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2)
  194. else:
  195. q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)
  196. if self.fused_attn:
  197. attn_bias = None
  198. if self.rel_pos is not None:
  199. attn_bias = self.rel_pos.get_bias()
  200. elif shared_rel_pos is not None:
  201. attn_bias = shared_rel_pos
  202. x = torch.nn.functional.scaled_dot_product_attention(
  203. q.transpose(-1, -2).contiguous(),
  204. k.transpose(-1, -2).contiguous(),
  205. v.transpose(-1, -2).contiguous(),
  206. attn_mask=attn_bias,
  207. dropout_p=self.attn_drop.p if self.training else 0.,
  208. ).transpose(-1, -2).reshape(B, -1, H, W)
  209. else:
  210. q = q * self.scale
  211. attn = q.transpose(-2, -1) @ k
  212. if self.rel_pos is not None:
  213. attn = self.rel_pos(attn)
  214. elif shared_rel_pos is not None:
  215. attn = attn + shared_rel_pos
  216. attn = attn.softmax(dim=-1)
  217. attn = self.attn_drop(attn)
  218. x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
  219. x = self.proj(x)
  220. x = self.proj_drop(x)
  221. return x
  222. class AttentionCl(nn.Module):
  223. """Channels-last multi-head attention (B, ..., C)."""
  224. fused_attn: Final[bool]
  225. def __init__(
  226. self,
  227. dim: int,
  228. dim_out: Optional[int] = None,
  229. dim_head: int = 32,
  230. bias: bool = True,
  231. expand_first: bool = True,
  232. head_first: bool = True,
  233. rel_pos_cls: Optional[Callable] = None,
  234. attn_drop: float = 0.,
  235. proj_drop: float = 0.,
  236. device=None,
  237. dtype=None,
  238. ):
  239. """
  240. Args:
  241. dim: Input dimension.
  242. dim_out: Output dimension (defaults to input dimension).
  243. dim_head: Dimension per attention head.
  244. bias: Whether to use bias in qkv and projection.
  245. expand_first: Whether to expand channels before or after qkv.
  246. head_first: Whether heads are first in tensor layout.
  247. rel_pos_cls: Relative position class to use.
  248. attn_drop: Attention dropout rate.
  249. proj_drop: Projection dropout rate.
  250. """
  251. dd = {'device': device, 'dtype': dtype}
  252. super().__init__()
  253. dim_out = dim_out or dim
  254. dim_attn = dim_out if expand_first and dim_out > dim else dim
  255. assert dim_attn % dim_head == 0, 'attn dim should be divisible by head_dim'
  256. self.num_heads = dim_attn // dim_head
  257. self.dim_head = dim_head
  258. self.head_first = head_first
  259. self.scale = dim_head ** -0.5
  260. self.fused_attn = use_fused_attn()
  261. self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias, **dd)
  262. self.rel_pos = rel_pos_cls(num_heads=self.num_heads, **dd) if rel_pos_cls else None
  263. self.attn_drop = nn.Dropout(attn_drop)
  264. self.proj = nn.Linear(dim_attn, dim_out, bias=bias, **dd)
  265. self.proj_drop = nn.Dropout(proj_drop)
  266. def forward(self, x: torch.Tensor, shared_rel_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
  267. B = x.shape[0]
  268. restore_shape = x.shape[:-1]
  269. if self.head_first:
  270. q, k, v = self.qkv(x).view(B, -1, self.num_heads, self.dim_head * 3).transpose(1, 2).chunk(3, dim=3)
  271. else:
  272. q, k, v = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.dim_head).transpose(1, 3).unbind(2)
  273. if self.fused_attn:
  274. attn_bias = None
  275. if self.rel_pos is not None:
  276. attn_bias = self.rel_pos.get_bias()
  277. elif shared_rel_pos is not None:
  278. attn_bias = shared_rel_pos
  279. x = torch.nn.functional.scaled_dot_product_attention(
  280. q, k, v,
  281. attn_mask=attn_bias,
  282. dropout_p=self.attn_drop.p if self.training else 0.,
  283. )
  284. else:
  285. q = q * self.scale
  286. attn = q @ k.transpose(-2, -1)
  287. if self.rel_pos is not None:
  288. attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
  289. elif shared_rel_pos is not None:
  290. attn = attn + shared_rel_pos
  291. attn = attn.softmax(dim=-1)
  292. attn = self.attn_drop(attn)
  293. x = attn @ v
  294. x = x.transpose(1, 2).reshape(restore_shape + (-1,))
  295. x = self.proj(x)
  296. x = self.proj_drop(x)
  297. return x
  298. class Downsample2d(nn.Module):
  299. """A downsample pooling module supporting several maxpool and avgpool modes.
  300. * 'max' - MaxPool2d w/ kernel_size 3, stride 2, padding 1
  301. * 'max2' - MaxPool2d w/ kernel_size = stride = 2
  302. * 'avg' - AvgPool2d w/ kernel_size 3, stride 2, padding 1
  303. * 'avg2' - AvgPool2d w/ kernel_size = stride = 2
  304. """
  305. def __init__(
  306. self,
  307. dim: int,
  308. dim_out: int,
  309. pool_type: str = 'avg2',
  310. padding: str = '',
  311. bias: bool = True,
  312. device=None,
  313. dtype=None,
  314. ):
  315. """
  316. Args:
  317. dim: Input dimension.
  318. dim_out: Output dimension.
  319. pool_type: Type of pooling operation.
  320. padding: Padding mode.
  321. bias: Whether to use bias in expansion conv.
  322. """
  323. super().__init__()
  324. assert pool_type in ('max', 'max2', 'avg', 'avg2')
  325. if pool_type == 'max':
  326. self.pool = create_pool2d('max', kernel_size=3, stride=2, padding=padding or 1)
  327. elif pool_type == 'max2':
  328. self.pool = create_pool2d('max', 2, padding=padding or 0) # kernel_size == stride == 2
  329. elif pool_type == 'avg':
  330. self.pool = create_pool2d(
  331. 'avg', kernel_size=3, stride=2, count_include_pad=False, padding=padding or 1)
  332. else:
  333. self.pool = create_pool2d('avg', 2, padding=padding or 0)
  334. if dim != dim_out:
  335. self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias, device=device, dtype=dtype)
  336. else:
  337. self.expand = nn.Identity()
  338. def forward(self, x: torch.Tensor) -> torch.Tensor:
  339. x = self.pool(x) # spatial downsample
  340. x = self.expand(x) # expand chs
  341. return x
  342. def _init_transformer(module: nn.Module, name: str, scheme: str = '') -> None:
  343. """Initialize transformer module weights."""
  344. if isinstance(module, (nn.Conv2d, nn.Linear)):
  345. if scheme == 'normal':
  346. nn.init.normal_(module.weight, std=.02)
  347. if module.bias is not None:
  348. nn.init.zeros_(module.bias)
  349. elif scheme == 'trunc_normal':
  350. trunc_normal_tf_(module.weight, std=.02)
  351. if module.bias is not None:
  352. nn.init.zeros_(module.bias)
  353. elif scheme == 'xavier_normal':
  354. nn.init.xavier_normal_(module.weight)
  355. if module.bias is not None:
  356. nn.init.zeros_(module.bias)
  357. else:
  358. # vit like
  359. nn.init.xavier_uniform_(module.weight)
  360. if module.bias is not None:
  361. if 'mlp' in name:
  362. nn.init.normal_(module.bias, std=1e-6)
  363. else:
  364. nn.init.zeros_(module.bias)
  365. class TransformerBlock2d(nn.Module):
  366. """Transformer block with 2D downsampling.
  367. '2D' NCHW tensor layout
  368. Some gains can be seen on GPU using a 1D / CL block, BUT w/ the need to switch back/forth to NCHW
  369. for spatial pooling, the benefit is minimal so ended up using just this variant for CoAt configs.
  370. This impl was faster on TPU w/ PT XLA than the 1D experiment.
  371. """
  372. def __init__(
  373. self,
  374. dim: int,
  375. dim_out: int,
  376. stride: int = 1,
  377. rel_pos_cls: Optional[Callable] = None,
  378. cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
  379. drop_path: float = 0.,
  380. device=None,
  381. dtype=None,
  382. ):
  383. """
  384. Args:
  385. dim: Input dimension.
  386. dim_out: Output dimension.
  387. stride: Stride for downsampling.
  388. rel_pos_cls: Relative position class.
  389. cfg: Transformer block configuration.
  390. drop_path: Drop path rate.
  391. """
  392. dd = {'device': device, 'dtype': dtype}
  393. super().__init__()
  394. norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
  395. act_layer = get_act_layer(cfg.act_layer)
  396. if stride == 2:
  397. self.shortcut = Downsample2d(dim, dim_out, pool_type=cfg.pool_type, bias=cfg.shortcut_bias, **dd)
  398. self.norm1 = nn.Sequential(OrderedDict([
  399. ('norm', norm_layer(dim, **dd)),
  400. ('down', Downsample2d(dim, dim, pool_type=cfg.pool_type, **dd)),
  401. ]))
  402. else:
  403. assert dim == dim_out
  404. self.shortcut = nn.Identity()
  405. self.norm1 = norm_layer(dim, **dd)
  406. self.attn = Attention2d(
  407. dim,
  408. dim_out,
  409. dim_head=cfg.dim_head,
  410. expand_first=cfg.expand_first,
  411. bias=cfg.attn_bias,
  412. rel_pos_cls=rel_pos_cls,
  413. attn_drop=cfg.attn_drop,
  414. proj_drop=cfg.proj_drop,
  415. **dd,
  416. )
  417. self.ls1 = LayerScale2d(dim_out, init_values=cfg.init_values, **dd) if cfg.init_values else nn.Identity()
  418. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  419. self.norm2 = norm_layer(dim_out, **dd)
  420. self.mlp = ConvMlp(
  421. in_features=dim_out,
  422. hidden_features=int(dim_out * cfg.expand_ratio),
  423. act_layer=act_layer,
  424. drop=cfg.proj_drop,
  425. **dd,
  426. )
  427. self.ls2 = LayerScale2d(dim_out, init_values=cfg.init_values, **dd) if cfg.init_values else nn.Identity()
  428. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  429. def init_weights(self, scheme: str = '') -> None:
  430. named_apply(partial(_init_transformer, scheme=scheme), self)
  431. def forward(self, x: torch.Tensor, shared_rel_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
  432. x = self.shortcut(x) + self.drop_path1(self.ls1(self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos)))
  433. x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
  434. return x
  435. def _init_conv(module: nn.Module, name: str, scheme: str = '') -> None:
  436. """Initialize convolution module weights."""
  437. if isinstance(module, nn.Conv2d):
  438. if scheme == 'normal':
  439. nn.init.normal_(module.weight, std=.02)
  440. if module.bias is not None:
  441. nn.init.zeros_(module.bias)
  442. elif scheme == 'trunc_normal':
  443. trunc_normal_tf_(module.weight, std=.02)
  444. if module.bias is not None:
  445. nn.init.zeros_(module.bias)
  446. elif scheme == 'xavier_normal':
  447. nn.init.xavier_normal_(module.weight)
  448. if module.bias is not None:
  449. nn.init.zeros_(module.bias)
  450. else:
  451. # efficientnet like
  452. fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
  453. fan_out //= module.groups
  454. nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
  455. if module.bias is not None:
  456. nn.init.zeros_(module.bias)
  457. def num_groups(group_size: Optional[int], channels: int) -> int:
  458. """Calculate number of groups for grouped convolution."""
  459. if not group_size: # 0 or None
  460. return 1 # normal conv with 1 group
  461. else:
  462. # NOTE group_size == 1 -> depthwise conv
  463. assert channels % group_size == 0
  464. return channels // group_size
  465. class MbConvBlock(nn.Module):
  466. """Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand)."""
  467. def __init__(
  468. self,
  469. in_chs: int,
  470. out_chs: int,
  471. stride: int = 1,
  472. dilation: Tuple[int, int] = (1, 1),
  473. cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
  474. drop_path: float = 0.,
  475. device=None,
  476. dtype=None,
  477. ):
  478. """
  479. Args:
  480. in_chs: Input channels.
  481. out_chs: Output channels.
  482. stride: Stride for conv.
  483. dilation: Dilation for conv.
  484. cfg: Convolution block configuration.
  485. drop_path: Drop path rate.
  486. """
  487. dd = {'device': device, 'dtype': dtype}
  488. super().__init__()
  489. norm_act_layer = partial(get_norm_act_layer(cfg.norm_layer, cfg.act_layer), eps=cfg.norm_eps)
  490. mid_chs = make_divisible((out_chs if cfg.expand_output else in_chs) * cfg.expand_ratio)
  491. groups = num_groups(cfg.group_size, mid_chs)
  492. if stride == 2:
  493. self.shortcut = Downsample2d(
  494. in_chs, out_chs, pool_type=cfg.pool_type, bias=cfg.output_bias, padding=cfg.padding, **dd)
  495. else:
  496. self.shortcut = nn.Identity()
  497. assert cfg.stride_mode in ('pool', '1x1', 'dw')
  498. stride_pool, stride_1, stride_2 = 1, 1, 1
  499. if cfg.stride_mode == 'pool':
  500. # NOTE this is not described in paper, experiment to find faster option that doesn't stride in 1x1
  501. stride_pool, dilation_2 = stride, dilation[1]
  502. # FIXME handle dilation of avg pool
  503. elif cfg.stride_mode == '1x1':
  504. # NOTE I don't like this option described in paper, 1x1 w/ stride throws info away
  505. stride_1, dilation_2 = stride, dilation[1]
  506. else:
  507. stride_2, dilation_2 = stride, dilation[0]
  508. self.pre_norm = norm_act_layer(in_chs, apply_act=cfg.pre_norm_act, **dd)
  509. if stride_pool > 1:
  510. self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type, padding=cfg.padding, **dd)
  511. else:
  512. self.down = nn.Identity()
  513. self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=stride_1, **dd)
  514. self.norm1 = norm_act_layer(mid_chs, **dd)
  515. self.conv2_kxk = create_conv2d(
  516. mid_chs,
  517. mid_chs,
  518. cfg.kernel_size,
  519. stride=stride_2,
  520. dilation=dilation_2,
  521. groups=groups,
  522. padding=cfg.padding,
  523. **dd,
  524. )
  525. attn_kwargs = {}
  526. if isinstance(cfg.attn_layer, str):
  527. if cfg.attn_layer == 'se' or cfg.attn_layer == 'eca':
  528. attn_kwargs['act_layer'] = cfg.attn_act_layer
  529. attn_kwargs['rd_channels'] = int(cfg.attn_ratio * (out_chs if cfg.expand_output else mid_chs))
  530. # two different orderings for SE and norm2 (due to some weights and trials using SE before norm2)
  531. if cfg.attn_early:
  532. self.se_early = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs, **dd)
  533. self.norm2 = norm_act_layer(mid_chs, **dd)
  534. self.se = None
  535. else:
  536. self.se_early = None
  537. self.norm2 = norm_act_layer(mid_chs, **dd)
  538. self.se = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs, **dd)
  539. self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=cfg.output_bias, **dd)
  540. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  541. def init_weights(self, scheme: str = '') -> None:
  542. named_apply(partial(_init_conv, scheme=scheme), self)
  543. def forward(self, x: torch.Tensor) -> torch.Tensor:
  544. shortcut = self.shortcut(x)
  545. x = self.pre_norm(x)
  546. x = self.down(x)
  547. # 1x1 expansion conv & norm-act
  548. x = self.conv1_1x1(x)
  549. x = self.norm1(x)
  550. # depthwise / grouped 3x3 conv w/ SE (or other) channel attention & norm-act
  551. x = self.conv2_kxk(x)
  552. if self.se_early is not None:
  553. x = self.se_early(x)
  554. x = self.norm2(x)
  555. if self.se is not None:
  556. x = self.se(x)
  557. # 1x1 linear projection to output width
  558. x = self.conv3_1x1(x)
  559. x = self.drop_path(x) + shortcut
  560. return x
  561. class ConvNeXtBlock(nn.Module):
  562. """ConvNeXt Block."""
  563. def __init__(
  564. self,
  565. in_chs: int,
  566. out_chs: Optional[int] = None,
  567. kernel_size: int = 7,
  568. stride: int = 1,
  569. dilation: Tuple[int, int] = (1, 1),
  570. cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
  571. conv_mlp: bool = True,
  572. drop_path: float = 0.,
  573. device=None,
  574. dtype=None,
  575. ):
  576. """
  577. Args:
  578. in_chs: Input channels.
  579. out_chs: Output channels.
  580. kernel_size: Kernel size for depthwise conv.
  581. stride: Stride for conv.
  582. dilation: Dilation for conv.
  583. cfg: Convolution block configuration.
  584. conv_mlp: Whether to use convolutional MLP.
  585. drop_path: Drop path rate.
  586. """
  587. dd = {'device': device, 'dtype': dtype}
  588. super().__init__()
  589. out_chs = out_chs or in_chs
  590. act_layer = get_act_layer(cfg.act_layer)
  591. if conv_mlp:
  592. norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
  593. mlp_layer = ConvMlp
  594. else:
  595. assert 'layernorm' in cfg.norm_layer
  596. norm_layer = LayerNorm
  597. mlp_layer = Mlp
  598. self.use_conv_mlp = conv_mlp
  599. if stride == 2:
  600. self.shortcut = Downsample2d(in_chs, out_chs, **dd)
  601. elif in_chs != out_chs:
  602. self.shortcut = nn.Conv2d(in_chs, out_chs, kernel_size=1, bias=cfg.output_bias, **dd)
  603. else:
  604. self.shortcut = nn.Identity()
  605. assert cfg.stride_mode in ('pool', 'dw')
  606. stride_pool, stride_dw = 1, 1
  607. # FIXME handle dilation?
  608. if cfg.stride_mode == 'pool':
  609. stride_pool = stride
  610. else:
  611. stride_dw = stride
  612. if stride_pool == 2:
  613. self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type, **dd)
  614. else:
  615. self.down = nn.Identity()
  616. self.conv_dw = create_conv2d(
  617. in_chs,
  618. out_chs,
  619. kernel_size=kernel_size,
  620. stride=stride_dw,
  621. dilation=dilation[1],
  622. depthwise=True,
  623. bias=cfg.output_bias,
  624. **dd,
  625. )
  626. self.norm = norm_layer(out_chs, **dd)
  627. self.mlp = mlp_layer(
  628. out_chs,
  629. int(cfg.expand_ratio * out_chs),
  630. bias=cfg.output_bias,
  631. act_layer=act_layer,
  632. **dd,
  633. )
  634. if conv_mlp:
  635. self.ls = LayerScale2d(out_chs, cfg.init_values, **dd) if cfg.init_values else nn.Identity()
  636. else:
  637. self.ls = LayerScale(out_chs, cfg.init_values, **dd) if cfg.init_values else nn.Identity()
  638. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  639. def forward(self, x: torch.Tensor) -> torch.Tensor:
  640. shortcut = self.shortcut(x)
  641. x = self.down(x)
  642. x = self.conv_dw(x)
  643. if self.use_conv_mlp:
  644. x = self.norm(x)
  645. x = self.mlp(x)
  646. x = self.ls(x)
  647. else:
  648. x = x.permute(0, 2, 3, 1)
  649. x = self.norm(x)
  650. x = self.mlp(x)
  651. x = self.ls(x)
  652. x = x.permute(0, 3, 1, 2)
  653. x = self.drop_path(x) + shortcut
  654. return x
  655. def window_partition(x: torch.Tensor, window_size: List[int]) -> torch.Tensor:
  656. """Partition into non-overlapping windows."""
  657. B, H, W, C = x.shape
  658. _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})')
  659. _assert(W % window_size[1] == 0, f'width ({W}) must be divisible by window ({window_size[1]})')
  660. x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
  661. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
  662. return windows
  663. @register_notrace_function # reason: int argument is a Proxy
  664. def window_reverse(windows: torch.Tensor, window_size: List[int], img_size: List[int]) -> torch.Tensor:
  665. """Reverse window partition."""
  666. H, W = img_size
  667. C = windows.shape[-1]
  668. x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
  669. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
  670. return x
  671. def grid_partition(x: torch.Tensor, grid_size: List[int]) -> torch.Tensor:
  672. """Partition into overlapping windows with grid striding."""
  673. B, H, W, C = x.shape
  674. _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}')
  675. _assert(W % grid_size[1] == 0, f'width {W} must be divisible by grid {grid_size[1]}')
  676. x = x.view(B, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1], C)
  677. windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, grid_size[0], grid_size[1], C)
  678. return windows
  679. @register_notrace_function # reason: int argument is a Proxy
  680. def grid_reverse(windows: torch.Tensor, grid_size: List[int], img_size: List[int]) -> torch.Tensor:
  681. """Reverse grid partition."""
  682. H, W = img_size
  683. C = windows.shape[-1]
  684. x = windows.view(-1, H // grid_size[0], W // grid_size[1], grid_size[0], grid_size[1], C)
  685. x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, H, W, C)
  686. return x
  687. def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size: Tuple[int, int]) -> Optional[Callable]:
  688. """Get relative position class based on config."""
  689. rel_pos_cls = None
  690. if cfg.rel_pos_type == 'mlp':
  691. rel_pos_cls = partial(RelPosMlp, window_size=window_size, hidden_dim=cfg.rel_pos_dim)
  692. elif cfg.rel_pos_type == 'bias':
  693. rel_pos_cls = partial(RelPosBias, window_size=window_size)
  694. elif cfg.rel_pos_type == 'bias_tf':
  695. rel_pos_cls = partial(RelPosBiasTf, window_size=window_size)
  696. return rel_pos_cls
  697. class PartitionAttentionCl(nn.Module):
  698. """Grid or Block partition + Attn + FFN.
  699. NxC 'channels last' tensor layout.
  700. """
  701. def __init__(
  702. self,
  703. dim: int,
  704. partition_type: str = 'block',
  705. cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
  706. drop_path: float = 0.,
  707. device=None,
  708. dtype=None,
  709. ):
  710. dd = {'device': device, 'dtype': dtype}
  711. super().__init__()
  712. norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) # NOTE this block is channels-last
  713. act_layer = get_act_layer(cfg.act_layer)
  714. self.partition_block = partition_type == 'block'
  715. self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size)
  716. rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size)
  717. self.norm1 = norm_layer(dim, **dd)
  718. self.attn = AttentionCl(
  719. dim,
  720. dim,
  721. dim_head=cfg.dim_head,
  722. bias=cfg.attn_bias,
  723. head_first=cfg.head_first,
  724. rel_pos_cls=rel_pos_cls,
  725. attn_drop=cfg.attn_drop,
  726. proj_drop=cfg.proj_drop,
  727. **dd,
  728. )
  729. self.ls1 = LayerScale(dim, init_values=cfg.init_values, **dd) if cfg.init_values else nn.Identity()
  730. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  731. self.norm2 = norm_layer(dim, **dd)
  732. self.mlp = Mlp(
  733. in_features=dim,
  734. hidden_features=int(dim * cfg.expand_ratio),
  735. act_layer=act_layer,
  736. drop=cfg.proj_drop,
  737. **dd,
  738. )
  739. self.ls2 = LayerScale(dim, init_values=cfg.init_values, **dd) if cfg.init_values else nn.Identity()
  740. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  741. def _partition_attn(self, x):
  742. img_size = x.shape[1:3]
  743. if self.partition_block:
  744. partitioned = window_partition(x, self.partition_size)
  745. else:
  746. partitioned = grid_partition(x, self.partition_size)
  747. partitioned = self.attn(partitioned)
  748. if self.partition_block:
  749. x = window_reverse(partitioned, self.partition_size, img_size)
  750. else:
  751. x = grid_reverse(partitioned, self.partition_size, img_size)
  752. return x
  753. def forward(self, x):
  754. x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x))))
  755. x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
  756. return x
  757. class ParallelPartitionAttention(nn.Module):
  758. """Experimental. Grid and Block partition + single FFN.
  759. NxC tensor layout.
  760. """
  761. def __init__(
  762. self,
  763. dim: int,
  764. cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
  765. drop_path: float = 0.,
  766. device=None,
  767. dtype=None,
  768. ):
  769. """
  770. Args:
  771. dim: Input dimension.
  772. cfg: Transformer block configuration.
  773. drop_path: Drop path rate.
  774. """
  775. dd = {'device': device, 'dtype': dtype}
  776. super().__init__()
  777. assert dim % 2 == 0
  778. norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) # NOTE this block is channels-last
  779. act_layer = get_act_layer(cfg.act_layer)
  780. assert cfg.window_size == cfg.grid_size
  781. self.partition_size = to_2tuple(cfg.window_size)
  782. rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size)
  783. self.norm1 = norm_layer(dim, **dd)
  784. self.attn_block = AttentionCl(
  785. dim,
  786. dim // 2,
  787. dim_head=cfg.dim_head,
  788. bias=cfg.attn_bias,
  789. head_first=cfg.head_first,
  790. rel_pos_cls=rel_pos_cls,
  791. attn_drop=cfg.attn_drop,
  792. proj_drop=cfg.proj_drop,
  793. **dd,
  794. )
  795. self.attn_grid = AttentionCl(
  796. dim,
  797. dim // 2,
  798. dim_head=cfg.dim_head,
  799. bias=cfg.attn_bias,
  800. head_first=cfg.head_first,
  801. rel_pos_cls=rel_pos_cls,
  802. attn_drop=cfg.attn_drop,
  803. proj_drop=cfg.proj_drop,
  804. **dd,
  805. )
  806. self.ls1 = LayerScale(dim, init_values=cfg.init_values, **dd) if cfg.init_values else nn.Identity()
  807. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  808. self.norm2 = norm_layer(dim, **dd)
  809. self.mlp = Mlp(
  810. in_features=dim,
  811. hidden_features=int(dim * cfg.expand_ratio),
  812. out_features=dim,
  813. act_layer=act_layer,
  814. drop=cfg.proj_drop,
  815. **dd,
  816. )
  817. self.ls2 = LayerScale(dim, init_values=cfg.init_values, **dd) if cfg.init_values else nn.Identity()
  818. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  819. def _partition_attn(self, x: torch.Tensor) -> torch.Tensor:
  820. img_size = x.shape[1:3]
  821. partitioned_block = window_partition(x, self.partition_size)
  822. partitioned_block = self.attn_block(partitioned_block)
  823. x_window = window_reverse(partitioned_block, self.partition_size, img_size)
  824. partitioned_grid = grid_partition(x, self.partition_size)
  825. partitioned_grid = self.attn_grid(partitioned_grid)
  826. x_grid = grid_reverse(partitioned_grid, self.partition_size, img_size)
  827. return torch.cat([x_window, x_grid], dim=-1)
  828. def forward(self, x: torch.Tensor) -> torch.Tensor:
  829. x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x))))
  830. x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
  831. return x
  832. def window_partition_nchw(x: torch.Tensor, window_size: List[int]) -> torch.Tensor:
  833. """Partition windows for NCHW tensors."""
  834. B, C, H, W = x.shape
  835. _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})')
  836. _assert(W % window_size[1] == 0, f'width ({W}) must be divisible by window ({window_size[1]})')
  837. x = x.view(B, C, H // window_size[0], window_size[0], W // window_size[1], window_size[1])
  838. windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size[0], window_size[1])
  839. return windows
  840. @register_notrace_function # reason: int argument is a Proxy
  841. def window_reverse_nchw(windows: torch.Tensor, window_size: List[int], img_size: List[int]) -> torch.Tensor:
  842. """Reverse window partition for NCHW tensors."""
  843. H, W = img_size
  844. C = windows.shape[1]
  845. x = windows.view(-1, H // window_size[0], W // window_size[1], C, window_size[0], window_size[1])
  846. x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, C, H, W)
  847. return x
  848. def grid_partition_nchw(x: torch.Tensor, grid_size: List[int]) -> torch.Tensor:
  849. """Grid partition for NCHW tensors."""
  850. B, C, H, W = x.shape
  851. _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}')
  852. _assert(W % grid_size[1] == 0, f'width {W} must be divisible by grid {grid_size[1]}')
  853. x = x.view(B, C, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1])
  854. windows = x.permute(0, 3, 5, 1, 2, 4).contiguous().view(-1, C, grid_size[0], grid_size[1])
  855. return windows
  856. @register_notrace_function # reason: int argument is a Proxy
  857. def grid_reverse_nchw(windows: torch.Tensor, grid_size: List[int], img_size: List[int]) -> torch.Tensor:
  858. """Reverse grid partition for NCHW tensors."""
  859. H, W = img_size
  860. C = windows.shape[1]
  861. x = windows.view(-1, H // grid_size[0], W // grid_size[1], C, grid_size[0], grid_size[1])
  862. x = x.permute(0, 3, 4, 1, 5, 2).contiguous().view(-1, C, H, W)
  863. return x
  864. class PartitionAttention2d(nn.Module):
  865. """Grid or Block partition + Attn + FFN.
  866. '2D' NCHW tensor layout.
  867. """
  868. def __init__(
  869. self,
  870. dim: int,
  871. partition_type: str = 'block',
  872. cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
  873. drop_path: float = 0.,
  874. device=None,
  875. dtype=None,
  876. ):
  877. """
  878. Args:
  879. dim: Input dimension.
  880. partition_type: Partition type ('block' or 'grid').
  881. cfg: Transformer block configuration.
  882. drop_path: Drop path rate.
  883. """
  884. dd = {'device': device, 'dtype': dtype}
  885. super().__init__()
  886. norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) # NOTE this block is channels-last
  887. act_layer = get_act_layer(cfg.act_layer)
  888. self.partition_block = partition_type == 'block'
  889. self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size)
  890. rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size)
  891. self.norm1 = norm_layer(dim, **dd)
  892. self.attn = Attention2d(
  893. dim,
  894. dim,
  895. dim_head=cfg.dim_head,
  896. bias=cfg.attn_bias,
  897. head_first=cfg.head_first,
  898. rel_pos_cls=rel_pos_cls,
  899. attn_drop=cfg.attn_drop,
  900. proj_drop=cfg.proj_drop,
  901. **dd,
  902. )
  903. self.ls1 = LayerScale2d(dim, init_values=cfg.init_values, **dd) if cfg.init_values else nn.Identity()
  904. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  905. self.norm2 = norm_layer(dim, **dd)
  906. self.mlp = ConvMlp(
  907. in_features=dim,
  908. hidden_features=int(dim * cfg.expand_ratio),
  909. act_layer=act_layer,
  910. drop=cfg.proj_drop,
  911. **dd,
  912. )
  913. self.ls2 = LayerScale2d(dim, init_values=cfg.init_values, **dd) if cfg.init_values else nn.Identity()
  914. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  915. def _partition_attn(self, x: torch.Tensor) -> torch.Tensor:
  916. img_size = x.shape[-2:]
  917. if self.partition_block:
  918. partitioned = window_partition_nchw(x, self.partition_size)
  919. else:
  920. partitioned = grid_partition_nchw(x, self.partition_size)
  921. partitioned = self.attn(partitioned)
  922. if self.partition_block:
  923. x = window_reverse_nchw(partitioned, self.partition_size, img_size)
  924. else:
  925. x = grid_reverse_nchw(partitioned, self.partition_size, img_size)
  926. return x
  927. def forward(self, x: torch.Tensor) -> torch.Tensor:
  928. x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x))))
  929. x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
  930. return x
  931. class MaxxVitBlock(nn.Module):
  932. """MaxVit conv, window partition + FFN , grid partition + FFN."""
  933. def __init__(
  934. self,
  935. dim: int,
  936. dim_out: int,
  937. stride: int = 1,
  938. conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
  939. transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
  940. drop_path: float = 0.,
  941. device=None,
  942. dtype=None,
  943. ):
  944. """Initialize MaxxVitBlock.
  945. Args:
  946. dim: Input channel dimension.
  947. dim_out: Output channel dimension.
  948. stride: Stride for downsampling.
  949. conv_cfg: Configuration for convolutional blocks.
  950. transformer_cfg: Configuration for transformer blocks.
  951. drop_path: Drop path rate.
  952. """
  953. dd = {'device': device, 'dtype': dtype}
  954. super().__init__()
  955. self.nchw_attn = transformer_cfg.use_nchw_attn
  956. conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
  957. self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path, **dd)
  958. attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path, **dd)
  959. partition_layer = PartitionAttention2d if self.nchw_attn else PartitionAttentionCl
  960. self.attn_block = None if transformer_cfg.no_block_attn else partition_layer(**attn_kwargs)
  961. self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs)
  962. def init_weights(self, scheme=''):
  963. if self.attn_block is not None:
  964. named_apply(partial(_init_transformer, scheme=scheme), self.attn_block)
  965. named_apply(partial(_init_transformer, scheme=scheme), self.attn_grid)
  966. named_apply(partial(_init_conv, scheme=scheme), self.conv)
  967. def forward(self, x):
  968. # NCHW format
  969. x = self.conv(x)
  970. if not self.nchw_attn:
  971. x = x.permute(0, 2, 3, 1) # to NHWC (channels-last)
  972. if self.attn_block is not None:
  973. x = self.attn_block(x)
  974. x = self.attn_grid(x)
  975. if not self.nchw_attn:
  976. x = x.permute(0, 3, 1, 2) # back to NCHW
  977. return x
  978. class ParallelMaxxVitBlock(nn.Module):
  979. """MaxVit block with parallel cat(window + grid), one FF.
  980. Experimental timm block.
  981. """
  982. def __init__(
  983. self,
  984. dim: int,
  985. dim_out: int,
  986. stride: int = 1,
  987. num_conv: int = 2,
  988. conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
  989. transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
  990. drop_path: float = 0.,
  991. device=None,
  992. dtype=None,
  993. ):
  994. """
  995. Args:
  996. dim: Input dimension.
  997. dim_out: Output dimension.
  998. stride: Stride for first conv block.
  999. num_conv: Number of convolution blocks.
  1000. conv_cfg: Convolution block configuration.
  1001. transformer_cfg: Transformer block configuration.
  1002. drop_path: Drop path rate.
  1003. """
  1004. dd = {'device': device, 'dtype': dtype}
  1005. super().__init__()
  1006. conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
  1007. if num_conv > 1:
  1008. convs = [conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path, **dd)]
  1009. convs += [conv_cls(dim_out, dim_out, cfg=conv_cfg, drop_path=drop_path, **dd)] * (num_conv - 1)
  1010. self.conv = nn.Sequential(*convs)
  1011. else:
  1012. self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path, **dd)
  1013. self.attn = ParallelPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path, **dd)
  1014. def init_weights(self, scheme: str = '') -> None:
  1015. named_apply(partial(_init_transformer, scheme=scheme), self.attn)
  1016. named_apply(partial(_init_conv, scheme=scheme), self.conv)
  1017. def forward(self, x: torch.Tensor) -> torch.Tensor:
  1018. x = self.conv(x)
  1019. x = x.permute(0, 2, 3, 1)
  1020. x = self.attn(x)
  1021. x = x.permute(0, 3, 1, 2)
  1022. return x
  1023. class MaxxVitStage(nn.Module):
  1024. """MaxxVit stage consisting of mixed convolution and transformer blocks."""
  1025. def __init__(
  1026. self,
  1027. in_chs: int,
  1028. out_chs: int,
  1029. stride: int = 2,
  1030. depth: int = 4,
  1031. feat_size: Tuple[int, int] = (14, 14),
  1032. block_types: Union[str, Tuple[str]] = 'C',
  1033. transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
  1034. conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
  1035. drop_path: Union[float, List[float]] = 0.,
  1036. device=None,
  1037. dtype=None,
  1038. ):
  1039. """
  1040. Args:
  1041. in_chs: Input channels.
  1042. out_chs: Output channels.
  1043. stride: Stride for first block.
  1044. depth: Number of blocks in stage.
  1045. feat_size: Feature map size.
  1046. block_types: Block types ('C' for conv, 'T' for transformer, etc).
  1047. transformer_cfg: Transformer block configuration.
  1048. conv_cfg: Convolution block configuration.
  1049. drop_path: Drop path rate(s).
  1050. """
  1051. dd = {'device': device, 'dtype': dtype}
  1052. super().__init__()
  1053. self.grad_checkpointing = False
  1054. block_types = extend_tuple(block_types, depth)
  1055. blocks = []
  1056. for i, t in enumerate(block_types):
  1057. block_stride = stride if i == 0 else 1
  1058. assert t in ('C', 'T', 'M', 'PM')
  1059. if t == 'C':
  1060. conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
  1061. blocks += [conv_cls(
  1062. in_chs,
  1063. out_chs,
  1064. stride=block_stride,
  1065. cfg=conv_cfg,
  1066. drop_path=drop_path[i],
  1067. **dd,
  1068. )]
  1069. elif t == 'T':
  1070. rel_pos_cls = get_rel_pos_cls(transformer_cfg, feat_size)
  1071. blocks += [TransformerBlock2d(
  1072. in_chs,
  1073. out_chs,
  1074. stride=block_stride,
  1075. rel_pos_cls=rel_pos_cls,
  1076. cfg=transformer_cfg,
  1077. drop_path=drop_path[i],
  1078. **dd,
  1079. )]
  1080. elif t == 'M':
  1081. blocks += [MaxxVitBlock(
  1082. in_chs,
  1083. out_chs,
  1084. stride=block_stride,
  1085. conv_cfg=conv_cfg,
  1086. transformer_cfg=transformer_cfg,
  1087. drop_path=drop_path[i],
  1088. **dd,
  1089. )]
  1090. elif t == 'PM':
  1091. blocks += [ParallelMaxxVitBlock(
  1092. in_chs,
  1093. out_chs,
  1094. stride=block_stride,
  1095. conv_cfg=conv_cfg,
  1096. transformer_cfg=transformer_cfg,
  1097. drop_path=drop_path[i],
  1098. **dd,
  1099. )]
  1100. in_chs = out_chs
  1101. self.blocks = nn.Sequential(*blocks)
  1102. def forward(self, x: torch.Tensor) -> torch.Tensor:
  1103. if self.grad_checkpointing and not torch.jit.is_scripting():
  1104. x = checkpoint_seq(self.blocks, x)
  1105. else:
  1106. x = self.blocks(x)
  1107. return x
  1108. class Stem(nn.Module):
  1109. """Stem layer for feature extraction."""
  1110. def __init__(
  1111. self,
  1112. in_chs: int,
  1113. out_chs: int,
  1114. kernel_size: int = 3,
  1115. padding: str = '',
  1116. bias: bool = False,
  1117. act_layer: str = 'gelu',
  1118. norm_layer: str = 'batchnorm2d',
  1119. norm_eps: float = 1e-5,
  1120. device=None,
  1121. dtype=None,
  1122. ):
  1123. """
  1124. Args:
  1125. in_chs: Input channels.
  1126. out_chs: Output channels.
  1127. kernel_size: Kernel size for convolutions.
  1128. padding: Padding mode.
  1129. bias: Whether to use bias.
  1130. act_layer: Activation layer.
  1131. norm_layer: Normalization layer.
  1132. norm_eps: Normalization epsilon.
  1133. """
  1134. dd = {'device': device, 'dtype': dtype}
  1135. super().__init__()
  1136. if not isinstance(out_chs, (list, tuple)):
  1137. out_chs = to_2tuple(out_chs)
  1138. norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
  1139. self.out_chs = out_chs[-1]
  1140. self.stride = 2
  1141. self.conv1 = create_conv2d(in_chs, out_chs[0], kernel_size, stride=2, padding=padding, bias=bias, **dd)
  1142. self.norm1 = norm_act_layer(out_chs[0], **dd)
  1143. self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1, padding=padding, bias=bias, **dd)
  1144. def init_weights(self, scheme: str = '') -> None:
  1145. named_apply(partial(_init_conv, scheme=scheme), self)
  1146. def forward(self, x: torch.Tensor) -> torch.Tensor:
  1147. x = self.conv1(x)
  1148. x = self.norm1(x)
  1149. x = self.conv2(x)
  1150. return x
  1151. def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]) -> MaxxVitTransformerCfg:
  1152. """Configure window size based on image size and partition ratio."""
  1153. if cfg.window_size is not None:
  1154. assert cfg.grid_size
  1155. return cfg
  1156. partition_size = img_size[0] // cfg.partition_ratio, img_size[1] // cfg.partition_ratio
  1157. cfg = replace(cfg, window_size=partition_size, grid_size=partition_size)
  1158. return cfg
  1159. def _overlay_kwargs(cfg: MaxxVitCfg, **kwargs: Any) -> MaxxVitCfg:
  1160. """Overlay keyword arguments onto configuration."""
  1161. transformer_kwargs = {}
  1162. conv_kwargs = {}
  1163. base_kwargs = {}
  1164. for k, v in kwargs.items():
  1165. if k.startswith('transformer_'):
  1166. transformer_kwargs[k.replace('transformer_', '')] = v
  1167. elif k.startswith('conv_'):
  1168. conv_kwargs[k.replace('conv_', '')] = v
  1169. else:
  1170. base_kwargs[k] = v
  1171. cfg = replace(
  1172. cfg,
  1173. transformer_cfg=replace(cfg.transformer_cfg, **transformer_kwargs),
  1174. conv_cfg=replace(cfg.conv_cfg, **conv_kwargs),
  1175. **base_kwargs
  1176. )
  1177. return cfg
  1178. class MaxxVit(nn.Module):
  1179. """CoaTNet + MaxVit base model.
  1180. Highly configurable for different block compositions, tensor layouts, pooling types.
  1181. """
  1182. def __init__(
  1183. self,
  1184. cfg: MaxxVitCfg,
  1185. img_size: Union[int, Tuple[int, int]] = 224,
  1186. in_chans: int = 3,
  1187. num_classes: int = 1000,
  1188. global_pool: str = 'avg',
  1189. drop_rate: float = 0.,
  1190. drop_path_rate: float = 0.,
  1191. device=None,
  1192. dtype=None,
  1193. **kwargs: Any,
  1194. ):
  1195. """
  1196. Args:
  1197. cfg: Model configuration.
  1198. img_size: Input image size.
  1199. in_chans: Number of input channels.
  1200. num_classes: Number of classification classes.
  1201. global_pool: Global pooling type.
  1202. drop_rate: Dropout rate.
  1203. drop_path_rate: Drop path rate.
  1204. **kwargs: Additional keyword arguments to overlay on config.
  1205. """
  1206. super().__init__()
  1207. dd = {'device': device, 'dtype': dtype}
  1208. img_size = to_2tuple(img_size)
  1209. if kwargs:
  1210. cfg = _overlay_kwargs(cfg, **kwargs)
  1211. transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size)
  1212. self.num_classes = num_classes
  1213. self.in_chans = in_chans
  1214. self.global_pool = global_pool
  1215. self.num_features = self.embed_dim = cfg.embed_dim[-1]
  1216. self.drop_rate = drop_rate
  1217. self.grad_checkpointing = False
  1218. self.feature_info = []
  1219. self.stem = Stem(
  1220. in_chs=in_chans,
  1221. out_chs=cfg.stem_width,
  1222. padding=cfg.conv_cfg.padding,
  1223. bias=cfg.stem_bias,
  1224. act_layer=cfg.conv_cfg.act_layer,
  1225. norm_layer=cfg.conv_cfg.norm_layer,
  1226. norm_eps=cfg.conv_cfg.norm_eps,
  1227. **dd,
  1228. )
  1229. stride = self.stem.stride
  1230. self.feature_info += [dict(num_chs=self.stem.out_chs, reduction=2, module='stem')]
  1231. feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))])
  1232. num_stages = len(cfg.embed_dim)
  1233. assert len(cfg.depths) == num_stages
  1234. dpr = calculate_drop_path_rates(drop_path_rate, cfg.depths, stagewise=True)
  1235. in_chs = self.stem.out_chs
  1236. stages = []
  1237. for i in range(num_stages):
  1238. stage_stride = 2
  1239. out_chs = cfg.embed_dim[i]
  1240. feat_size = tuple([(r - 1) // stage_stride + 1 for r in feat_size])
  1241. stages += [MaxxVitStage(
  1242. in_chs,
  1243. out_chs,
  1244. depth=cfg.depths[i],
  1245. block_types=cfg.block_type[i],
  1246. conv_cfg=cfg.conv_cfg,
  1247. transformer_cfg=transformer_cfg,
  1248. feat_size=feat_size,
  1249. drop_path=dpr[i],
  1250. **dd,
  1251. )]
  1252. stride *= stage_stride
  1253. in_chs = out_chs
  1254. self.feature_info += [dict(num_chs=out_chs, reduction=stride, module=f'stages.{i}')]
  1255. self.stages = nn.Sequential(*stages)
  1256. final_norm_layer = partial(get_norm_layer(cfg.transformer_cfg.norm_layer), eps=cfg.transformer_cfg.norm_eps)
  1257. if cfg.head_hidden_size:
  1258. self.norm = nn.Identity()
  1259. self.head_hidden_size = cfg.head_hidden_size
  1260. self.head = NormMlpClassifierHead(
  1261. self.num_features,
  1262. num_classes,
  1263. hidden_size=self.head_hidden_size,
  1264. pool_type=global_pool,
  1265. drop_rate=drop_rate,
  1266. norm_layer=final_norm_layer,
  1267. **dd,
  1268. )
  1269. else:
  1270. # standard classifier head w/ norm, pooling, fc classifier
  1271. self.head_hidden_size = self.num_features
  1272. self.norm = final_norm_layer(self.num_features, **dd)
  1273. self.head = ClassifierHead(
  1274. self.num_features,
  1275. num_classes,
  1276. pool_type=global_pool,
  1277. drop_rate=drop_rate,
  1278. **dd,
  1279. )
  1280. # Weight init (default PyTorch init works well for AdamW if scheme not set)
  1281. assert cfg.weight_init in ('', 'normal', 'trunc_normal', 'xavier_normal', 'vit_eff')
  1282. if cfg.weight_init:
  1283. named_apply(partial(self._init_weights, scheme=cfg.weight_init), self)
  1284. def _init_weights(self, module: nn.Module, name: str, scheme: str = '') -> None:
  1285. if hasattr(module, 'init_weights'):
  1286. try:
  1287. module.init_weights(scheme=scheme)
  1288. except TypeError:
  1289. module.init_weights()
  1290. @torch.jit.ignore
  1291. def no_weight_decay(self) -> Set[str]:
  1292. return {
  1293. k for k, _ in self.named_parameters()
  1294. if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])}
  1295. @torch.jit.ignore
  1296. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  1297. matcher = dict(
  1298. stem=r'^stem', # stem and embed
  1299. blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
  1300. )
  1301. return matcher
  1302. @torch.jit.ignore
  1303. def set_grad_checkpointing(self, enable: bool = True) -> None:
  1304. for s in self.stages:
  1305. s.grad_checkpointing = enable
  1306. @torch.jit.ignore
  1307. def get_classifier(self) -> nn.Module:
  1308. return self.head.fc
  1309. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
  1310. self.num_classes = num_classes
  1311. self.head.reset(num_classes, global_pool)
def forward_intermediates(
        self,
        x: torch.Tensor,
        indices: Optional[Union[int, List[int]]] = None,
        norm: bool = False,
        stop_early: bool = False,
        output_fmt: str = 'NCHW',
        intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
    """ Forward features that returns intermediates.

    Feature index 0 is the stem output; indices 1..len(self.stages) are the
    outputs of the corresponding stages.

    Args:
        x: Input image tensor
        indices: Take last n blocks if int, all if None, select matching indices if sequence
        norm: Apply norm layer to compatible intermediates
        stop_early: Stop iterating over blocks when last desired intermediate hit
        output_fmt: Shape of intermediate feature outputs
        intermediates_only: Only return intermediate features
    Returns:
        The list of selected intermediates when `intermediates_only` is set,
        otherwise a tuple of (final features, list of selected intermediates).
        The final norm is applied to the returned features only when the
        last stage was actually executed.
    """
    assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
    intermediates = []
    # one extra feature slot is reserved for the stem output
    take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)

    # forward pass
    feat_idx = 0  # stem is index 0
    x = self.stem(x)
    if feat_idx in take_indices:
        intermediates.append(x)

    last_idx = len(self.stages)
    if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
        stages = self.stages
    else:
        # stem occupies index 0, so stage i has feature index i + 1
        stages = self.stages[:max_index]

    for stage in stages:
        feat_idx += 1
        x = stage(x)
        if feat_idx in take_indices:
            if norm and feat_idx == last_idx:
                x_inter = self.norm(x)  # applying final norm to last intermediate
            else:
                x_inter = x
            intermediates.append(x_inter)

    if intermediates_only:
        return intermediates

    # only normalise the output when the loop reached the final stage
    if feat_idx == last_idx:
        x = self.norm(x)

    return x, intermediates
  1358. def prune_intermediate_layers(
  1359. self,
  1360. indices: Union[int, List[int]] = 1,
  1361. prune_norm: bool = False,
  1362. prune_head: bool = True,
  1363. ) -> Tuple[int, ...]:
  1364. """Prune layers not required for specified intermediates."""
  1365. take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
  1366. self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
  1367. if prune_norm:
  1368. self.norm = nn.Identity()
  1369. if prune_head:
  1370. self.head = self.reset_classifier(0, '')
  1371. return take_indices
  1372. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  1373. x = self.stem(x)
  1374. x = self.stages(x)
  1375. x = self.norm(x)
  1376. return x
  1377. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  1378. return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  1379. def forward(self, x: torch.Tensor) -> torch.Tensor:
  1380. x = self.forward_features(x)
  1381. x = self.forward_head(x)
  1382. return x
  1383. def _rw_coat_cfg(
  1384. stride_mode: str = 'pool',
  1385. pool_type: str = 'avg2',
  1386. conv_output_bias: bool = False,
  1387. conv_attn_early: bool = False,
  1388. conv_attn_act_layer: str = 'relu',
  1389. conv_norm_layer: str = '',
  1390. transformer_shortcut_bias: bool = True,
  1391. transformer_norm_layer: str = 'layernorm2d',
  1392. transformer_norm_layer_cl: str = 'layernorm',
  1393. init_values: Optional[float] = None,
  1394. rel_pos_type: str = 'bias',
  1395. rel_pos_dim: int = 512,
  1396. ) -> Dict[str, Any]:
  1397. """RW variant configuration for CoAtNet models.
  1398. These models were created and trained before seeing https://github.com/google-research/maxvit
  1399. Common differences for initial timm models:
  1400. - pre-norm layer in MZBConv included an activation after norm
  1401. - mbconv expansion calculated from input instead of output chs
  1402. - mbconv shortcut and final 1x1 conv did not have a bias
  1403. - SE act layer was relu, not silu
  1404. - mbconv uses silu in timm, not gelu
  1405. - expansion in attention block done via output proj, not input proj
  1406. Variable differences (evolved over training initial models):
  1407. - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat)
  1408. - SE attention was between conv2 and norm/act
  1409. - default to avg pool for mbconv downsample instead of 1x1 or dw conv
  1410. - transformer block shortcut has no bias
  1411. """
  1412. return dict(
  1413. conv_cfg=MaxxVitConvCfg(
  1414. stride_mode=stride_mode,
  1415. pool_type=pool_type,
  1416. pre_norm_act=True,
  1417. expand_output=False,
  1418. output_bias=conv_output_bias,
  1419. attn_early=conv_attn_early,
  1420. attn_act_layer=conv_attn_act_layer,
  1421. act_layer='silu',
  1422. norm_layer=conv_norm_layer,
  1423. ),
  1424. transformer_cfg=MaxxVitTransformerCfg(
  1425. expand_first=False,
  1426. shortcut_bias=transformer_shortcut_bias,
  1427. pool_type=pool_type,
  1428. init_values=init_values,
  1429. norm_layer=transformer_norm_layer,
  1430. norm_layer_cl=transformer_norm_layer_cl,
  1431. rel_pos_type=rel_pos_type,
  1432. rel_pos_dim=rel_pos_dim,
  1433. ),
  1434. )
  1435. def _rw_max_cfg(
  1436. stride_mode: str = 'dw',
  1437. pool_type: str = 'avg2',
  1438. conv_output_bias: bool = False,
  1439. conv_attn_ratio: float = 1 / 16,
  1440. conv_norm_layer: str = '',
  1441. transformer_norm_layer: str = 'layernorm2d',
  1442. transformer_norm_layer_cl: str = 'layernorm',
  1443. window_size: Optional[Tuple[int, int]] = None,
  1444. dim_head: int = 32,
  1445. init_values: Optional[float] = None,
  1446. rel_pos_type: str = 'bias',
  1447. rel_pos_dim: int = 512,
  1448. ) -> Dict[str, Any]:
  1449. """RW variant configuration for MaxViT models.
  1450. These models were created and trained before seeing https://github.com/google-research/maxvit
  1451. Differences of initial timm models:
  1452. - mbconv expansion calculated from input instead of output chs
  1453. - mbconv shortcut and final 1x1 conv did not have a bias
  1454. - mbconv uses silu in timm, not gelu
  1455. - expansion in attention block done via output proj, not input proj
  1456. """
  1457. return dict(
  1458. conv_cfg=MaxxVitConvCfg(
  1459. stride_mode=stride_mode,
  1460. pool_type=pool_type,
  1461. expand_output=False,
  1462. output_bias=conv_output_bias,
  1463. attn_ratio=conv_attn_ratio,
  1464. act_layer='silu',
  1465. norm_layer=conv_norm_layer,
  1466. ),
  1467. transformer_cfg=MaxxVitTransformerCfg(
  1468. expand_first=False,
  1469. pool_type=pool_type,
  1470. dim_head=dim_head,
  1471. window_size=window_size,
  1472. init_values=init_values,
  1473. norm_layer=transformer_norm_layer,
  1474. norm_layer_cl=transformer_norm_layer_cl,
  1475. rel_pos_type=rel_pos_type,
  1476. rel_pos_dim=rel_pos_dim,
  1477. ),
  1478. )
  1479. def _next_cfg(
  1480. stride_mode: str = 'dw',
  1481. pool_type: str = 'avg2',
  1482. conv_norm_layer: str = 'layernorm2d',
  1483. conv_norm_layer_cl: str = 'layernorm',
  1484. transformer_norm_layer: str = 'layernorm2d',
  1485. transformer_norm_layer_cl: str = 'layernorm',
  1486. window_size: Optional[Tuple[int, int]] = None,
  1487. no_block_attn: bool = False,
  1488. init_values: Union[float, Tuple[float, float]] = 1e-6,
  1489. rel_pos_type: str = 'mlp', # MLP by default for maxxvit
  1490. rel_pos_dim: int = 512,
  1491. ) -> Dict[str, Any]:
  1492. """Configuration for experimental ConvNeXt-based MaxxViT models."""
  1493. init_values = to_2tuple(init_values)
  1494. return dict(
  1495. conv_cfg=MaxxVitConvCfg(
  1496. block_type='convnext',
  1497. stride_mode=stride_mode,
  1498. pool_type=pool_type,
  1499. expand_output=False,
  1500. init_values=init_values[0],
  1501. norm_layer=conv_norm_layer,
  1502. norm_layer_cl=conv_norm_layer_cl,
  1503. ),
  1504. transformer_cfg=MaxxVitTransformerCfg(
  1505. expand_first=False,
  1506. pool_type=pool_type,
  1507. window_size=window_size,
  1508. no_block_attn=no_block_attn, # enabled for MaxxViT-V2
  1509. init_values=init_values[1],
  1510. norm_layer=transformer_norm_layer,
  1511. norm_layer_cl=transformer_norm_layer_cl,
  1512. rel_pos_type=rel_pos_type,
  1513. rel_pos_dim=rel_pos_dim,
  1514. ),
  1515. )
  1516. def _tf_cfg() -> Dict[str, Any]:
  1517. """Configuration matching TensorFlow MaxViT models."""
  1518. return dict(
  1519. conv_cfg=MaxxVitConvCfg(
  1520. norm_eps=1e-3,
  1521. act_layer='gelu_tanh',
  1522. padding='same',
  1523. ),
  1524. transformer_cfg=MaxxVitTransformerCfg(
  1525. norm_eps=1e-5,
  1526. act_layer='gelu_tanh',
  1527. head_first=False, # heads are interleaved (q_nh, q_hdim, k_nh, q_hdim, ....)
  1528. rel_pos_type='bias_tf',
  1529. ),
  1530. )
# Architecture configs keyed by variant name without the trailing resolution suffix.
# _create_maxxvit() looks entries up here, dropping the last '_'-separated token of
# the model name when the full name is not itself a key.
model_cfgs = dict(
    # timm specific CoAtNet configs
    coatnet_pico_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 3, 5, 2),
        stem_width=(32, 64),
        **_rw_max_cfg(  # using newer max defaults here
            conv_output_bias=True,
            conv_attn_ratio=0.25,
        ),
    ),
    coatnet_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(3, 4, 6, 3),
        stem_width=(32, 64),
        **_rw_max_cfg(  # using newer max defaults here
            stride_mode='pool',
            conv_output_bias=True,
            conv_attn_ratio=0.25,
        ),
    ),
    coatnet_0_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 7, 2),  # deeper than paper '0' model
        stem_width=(32, 64),
        **_rw_coat_cfg(
            conv_attn_early=True,
            transformer_shortcut_bias=False,
        ),
    ),
    coatnet_1_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        stem_width=(32, 64),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_early=True,
            transformer_shortcut_bias=False,
        )
    ),
    coatnet_2_rw=MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        stem_width=(64, 128),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_act_layer='silu',
            #init_values=1e-6,
        ),
    ),
    coatnet_3_rw=MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 6, 14, 2),
        stem_width=(96, 192),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_act_layer='silu',
            init_values=1e-6,
        ),
    ),

    # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
    coatnet_bn_0_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 7, 2),  # deeper than paper '0' model
        stem_width=(32, 64),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_early=True,
            transformer_shortcut_bias=False,
            transformer_norm_layer='batchnorm2d',
        )
    ),
    coatnet_rmlp_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(3, 4, 6, 3),
        stem_width=(32, 64),
        **_rw_max_cfg(
            conv_output_bias=True,
            conv_attn_ratio=0.25,
            rel_pos_type='mlp',
            rel_pos_dim=384,
        ),
    ),
    coatnet_rmlp_0_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 7, 2),  # deeper than paper '0' model
        stem_width=(32, 64),
        **_rw_coat_cfg(
            stride_mode='dw',
            rel_pos_type='mlp',
        ),
    ),
    coatnet_rmlp_1_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        stem_width=(32, 64),
        **_rw_coat_cfg(
            pool_type='max',
            conv_attn_early=True,
            transformer_shortcut_bias=False,
            rel_pos_type='mlp',
            rel_pos_dim=384,  # was supposed to be 512, woops
        ),
    ),
    coatnet_rmlp_1_rw2=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        stem_width=(32, 64),
        **_rw_coat_cfg(
            stride_mode='dw',
            rel_pos_type='mlp',
            rel_pos_dim=512,  # the intended 512 (coatnet_rmlp_1_rw above ended up with 384)
        ),
    ),
    coatnet_rmlp_2_rw=MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        stem_width=(64, 128),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_act_layer='silu',
            init_values=1e-6,
            rel_pos_type='mlp'
        ),
    ),
    coatnet_rmlp_3_rw=MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 6, 14, 2),
        stem_width=(96, 192),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_act_layer='silu',
            init_values=1e-6,
            rel_pos_type='mlp'
        ),
    ),

    coatnet_nano_cc=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(3, 4, 6, 3),
        stem_width=(32, 64),
        block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
        **_rw_coat_cfg(),
    ),
    coatnext_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(3, 4, 6, 3),
        stem_width=(32, 64),
        weight_init='normal',
        **_next_cfg(
            rel_pos_type='bias',
            init_values=(1e-5, None)
        ),
    ),

    # Trying to be like the CoAtNet paper configs
    coatnet_0=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 5, 2),
        stem_width=64,
        head_hidden_size=768,
    ),
    coatnet_1=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        stem_width=64,
        head_hidden_size=768,
    ),
    coatnet_2=MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        stem_width=128,
        head_hidden_size=1024,
    ),
    coatnet_3=MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 6, 14, 2),
        stem_width=192,
        head_hidden_size=1536,
    ),
    coatnet_4=MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 12, 28, 2),
        stem_width=192,
        head_hidden_size=1536,
    ),
    coatnet_5=MaxxVitCfg(
        embed_dim=(256, 512, 1280, 2048),
        depths=(2, 12, 28, 2),
        stem_width=192,
        head_hidden_size=2048,
    ),

    # Experimental MaxVit configs
    maxvit_pico_rw=MaxxVitCfg(
        embed_dim=(32, 64, 128, 256),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(24, 32),
        **_rw_max_cfg(),
    ),
    maxvit_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(1, 2, 3, 1),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(),
    ),
    maxvit_tiny_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(),
    ),
    maxvit_tiny_pm=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('PM',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(),
    ),

    maxvit_rmlp_pico_rw=MaxxVitCfg(
        embed_dim=(32, 64, 128, 256),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(24, 32),
        **_rw_max_cfg(rel_pos_type='mlp'),
    ),
    maxvit_rmlp_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(1, 2, 3, 1),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(rel_pos_type='mlp'),
    ),
    maxvit_rmlp_tiny_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(rel_pos_type='mlp'),
    ),
    maxvit_rmlp_small_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(
            rel_pos_type='mlp',
            init_values=1e-6,
        ),
    ),
    maxvit_rmlp_base_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        head_hidden_size=768,
        **_rw_max_cfg(
            rel_pos_type='mlp',
        ),
    ),

    maxxvit_rmlp_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(1, 2, 3, 1),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        weight_init='normal',
        **_next_cfg(),
    ),
    maxxvit_rmlp_tiny_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_next_cfg(),
    ),
    maxxvit_rmlp_small_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(48, 96),
        **_next_cfg(),
    ),

    maxxvitv2_nano_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(1, 2, 3, 1),
        block_type=('M',) * 4,
        stem_width=(48, 96),
        weight_init='normal',
        **_next_cfg(
            no_block_attn=True,
            rel_pos_type='bias',
        ),
    ),
    maxxvitv2_rmlp_base_rw=MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 12, 2),
        block_type=('M',) * 4,
        stem_width=(64, 128),
        **_next_cfg(
            no_block_attn=True,
        ),
    ),
    maxxvitv2_rmlp_large_rw=MaxxVitCfg(
        embed_dim=(160, 320, 640, 1280),
        depths=(2, 6, 16, 2),
        block_type=('M',) * 4,
        stem_width=(80, 160),
        head_hidden_size=1280,
        **_next_cfg(
            no_block_attn=True,
        ),
    ),

    # Trying to be like the MaxViT paper configs
    maxvit_tiny_tf=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=64,
        stem_bias=True,
        head_hidden_size=512,
        **_tf_cfg(),
    ),
    maxvit_small_tf=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=64,
        stem_bias=True,
        head_hidden_size=768,
        **_tf_cfg(),
    ),
    maxvit_base_tf=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        block_type=('M',) * 4,
        stem_width=64,
        stem_bias=True,
        head_hidden_size=768,
        **_tf_cfg(),
    ),
    maxvit_large_tf=MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        block_type=('M',) * 4,
        stem_width=128,
        stem_bias=True,
        head_hidden_size=1024,
        **_tf_cfg(),
    ),
    maxvit_xlarge_tf=MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 6, 14, 2),
        block_type=('M',) * 4,
        stem_width=192,
        stem_bias=True,
        head_hidden_size=1536,
        **_tf_cfg(),
    ),
)
  1890. def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
  1891. """Filter checkpoint state dict for compatibility."""
  1892. model_state_dict = model.state_dict()
  1893. out_dict = {}
  1894. for k, v in state_dict.items():
  1895. if k.endswith('relative_position_bias_table'):
  1896. m = model.get_submodule(k[:-29])
  1897. if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]:
  1898. v = resize_rel_pos_bias_table(
  1899. v,
  1900. new_window_size=m.window_size,
  1901. new_bias_shape=m.relative_position_bias_table.shape,
  1902. )
  1903. if k in model_state_dict and v.ndim != model_state_dict[k].ndim and v.numel() == model_state_dict[k].numel():
  1904. # adapt between conv2d / linear layers
  1905. assert v.ndim in (2, 4)
  1906. v = v.reshape(model_state_dict[k].shape)
  1907. out_dict[k] = v
  1908. return out_dict
  1909. def _create_maxxvit(variant: str, cfg_variant: Optional[str] = None, pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  1910. """Create a MaxxVit model variant."""
  1911. if cfg_variant is None:
  1912. if variant in model_cfgs:
  1913. cfg_variant = variant
  1914. else:
  1915. cfg_variant = '_'.join(variant.split('_')[:-1])
  1916. return build_model_with_cfg(
  1917. MaxxVit, variant, pretrained,
  1918. model_cfg=model_cfgs[cfg_variant],
  1919. feature_cfg=dict(flatten_sequential=True),
  1920. pretrained_filter_fn=checkpoint_filter_fn,
  1921. **kwargs)
  1922. def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
  1923. """Create a default configuration dict."""
  1924. return {
  1925. 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  1926. 'crop_pct': 0.95, 'interpolation': 'bicubic',
  1927. 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
  1928. 'first_conv': 'stem.conv1', 'classifier': 'head.fc',
  1929. 'fixed_input_size': True,
  1930. 'license': 'apache-2.0', **kwargs
  1931. }
  1932. default_cfgs = generate_default_cfgs({
  1933. # timm specific CoAtNet configs, ImageNet-1k pretrain, fixed rel-pos
  1934. 'coatnet_pico_rw_224.untrained': _cfg(url=''),
  1935. 'coatnet_nano_rw_224.sw_in1k': _cfg(
  1936. hf_hub_id='timm/',
  1937. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth',
  1938. crop_pct=0.9),
  1939. 'coatnet_0_rw_224.sw_in1k': _cfg(
  1940. hf_hub_id='timm/',
  1941. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'),
  1942. 'coatnet_1_rw_224.sw_in1k': _cfg(
  1943. hf_hub_id='timm/',
  1944. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
  1945. ),
  1946. # timm specific CoAtNet configs, ImageNet-12k pretrain w/ 1k fine-tune, fixed rel-pos
  1947. 'coatnet_2_rw_224.sw_in12k_ft_in1k': _cfg(
  1948. hf_hub_id='timm/'),
  1949. #'coatnet_3_rw_224.untrained': _cfg(url=''),
  1950. # Experimental CoAtNet configs w/ ImageNet-12k pretrain -> 1k fine-tune (different norm layers, MLP rel-pos)
  1951. 'coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k': _cfg(
  1952. hf_hub_id='timm/'),
  1953. 'coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k': _cfg(
  1954. hf_hub_id='timm/'),
  1955. 'coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k': _cfg(
  1956. hf_hub_id='timm/',
  1957. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
  1958. # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
  1959. 'coatnet_bn_0_rw_224.sw_in1k': _cfg(
  1960. hf_hub_id='timm/',
  1961. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth',
  1962. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  1963. crop_pct=0.95),
  1964. 'coatnet_rmlp_nano_rw_224.sw_in1k': _cfg(
  1965. hf_hub_id='timm/',
  1966. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth',
  1967. crop_pct=0.9),
  1968. 'coatnet_rmlp_0_rw_224.untrained': _cfg(url=''),
  1969. 'coatnet_rmlp_1_rw_224.sw_in1k': _cfg(
  1970. hf_hub_id='timm/',
  1971. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
  1972. 'coatnet_rmlp_2_rw_224.sw_in1k': _cfg(
  1973. hf_hub_id='timm/',
  1974. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'),
  1975. 'coatnet_rmlp_3_rw_224.untrained': _cfg(url=''),
  1976. 'coatnet_nano_cc_224.untrained': _cfg(url=''),
  1977. 'coatnext_nano_rw_224.sw_in1k': _cfg(
  1978. hf_hub_id='timm/',
  1979. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth',
  1980. crop_pct=0.9),
  1981. # ImagenNet-12k pretrain CoAtNet
  1982. 'coatnet_2_rw_224.sw_in12k': _cfg(
  1983. hf_hub_id='timm/',
  1984. num_classes=11821),
  1985. 'coatnet_3_rw_224.sw_in12k': _cfg(
  1986. hf_hub_id='timm/',
  1987. num_classes=11821),
  1988. 'coatnet_rmlp_1_rw2_224.sw_in12k': _cfg(
  1989. hf_hub_id='timm/',
  1990. num_classes=11821),
  1991. 'coatnet_rmlp_2_rw_224.sw_in12k': _cfg(
  1992. hf_hub_id='timm/',
  1993. num_classes=11821),
  1994. # Trying to be like the CoAtNet paper configs (will adapt if 'tf' weights are ever released)
  1995. 'coatnet_0_224.untrained': _cfg(url=''),
  1996. 'coatnet_1_224.untrained': _cfg(url=''),
  1997. 'coatnet_2_224.untrained': _cfg(url=''),
  1998. 'coatnet_3_224.untrained': _cfg(url=''),
  1999. 'coatnet_4_224.untrained': _cfg(url=''),
  2000. 'coatnet_5_224.untrained': _cfg(url=''),
  2001. # timm specific MaxVit configs, ImageNet-1k pretrain or untrained
  2002. 'maxvit_pico_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
  2003. 'maxvit_nano_rw_256.sw_in1k': _cfg(
  2004. hf_hub_id='timm/',
  2005. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth',
  2006. input_size=(3, 256, 256), pool_size=(8, 8)),
  2007. 'maxvit_tiny_rw_224.sw_in1k': _cfg(
  2008. hf_hub_id='timm/',
  2009. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'),
  2010. 'maxvit_tiny_rw_256.untrained': _cfg(
  2011. url='',
  2012. input_size=(3, 256, 256), pool_size=(8, 8)),
  2013. 'maxvit_tiny_pm_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
  2014. # timm specific MaxVit w/ MLP rel-pos, ImageNet-1k pretrain
  2015. 'maxvit_rmlp_pico_rw_256.sw_in1k': _cfg(
  2016. hf_hub_id='timm/',
  2017. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth',
  2018. input_size=(3, 256, 256), pool_size=(8, 8)),
  2019. 'maxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
  2020. hf_hub_id='timm/',
  2021. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
  2022. input_size=(3, 256, 256), pool_size=(8, 8)),
  2023. 'maxvit_rmlp_tiny_rw_256.sw_in1k': _cfg(
  2024. hf_hub_id='timm/',
  2025. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth',
  2026. input_size=(3, 256, 256), pool_size=(8, 8)),
  2027. 'maxvit_rmlp_small_rw_224.sw_in1k': _cfg(
  2028. hf_hub_id='timm/',
  2029. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth',
  2030. crop_pct=0.9,
  2031. ),
  2032. 'maxvit_rmlp_small_rw_256.untrained': _cfg(
  2033. url='',
  2034. input_size=(3, 256, 256), pool_size=(8, 8)),
  2035. # timm specific MaxVit w/ ImageNet-12k pretrain and 1k fine-tune
  2036. 'maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
  2037. hf_hub_id='timm/',
  2038. ),
  2039. 'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
  2040. hf_hub_id='timm/',
  2041. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
  2042. # timm specific MaxVit w/ ImageNet-12k pretrain
  2043. 'maxvit_rmlp_base_rw_224.sw_in12k': _cfg(
  2044. hf_hub_id='timm/',
  2045. num_classes=11821,
  2046. ),
  2047. # timm MaxxViT configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks)
  2048. 'maxxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
  2049. hf_hub_id='timm/',
  2050. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth',
  2051. input_size=(3, 256, 256), pool_size=(8, 8)),
  2052. 'maxxvit_rmlp_tiny_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
  2053. 'maxxvit_rmlp_small_rw_256.sw_in1k': _cfg(
  2054. hf_hub_id='timm/',
  2055. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth',
  2056. input_size=(3, 256, 256), pool_size=(8, 8)),
  2057. # timm MaxxViT-V2 configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks, more width, no block attn)
  2058. 'maxxvitv2_nano_rw_256.sw_in1k': _cfg(
  2059. hf_hub_id='timm/',
  2060. input_size=(3, 256, 256), pool_size=(8, 8)),
  2061. 'maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
  2062. hf_hub_id='timm/'),
  2063. 'maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
  2064. hf_hub_id='timm/',
  2065. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
  2066. 'maxxvitv2_rmlp_large_rw_224.untrained': _cfg(url=''),
  2067. 'maxxvitv2_rmlp_base_rw_224.sw_in12k': _cfg(
  2068. hf_hub_id='timm/',
  2069. num_classes=11821),
  2070. # MaxViT models ported from official Tensorflow impl
  2071. 'maxvit_tiny_tf_224.in1k': _cfg(
  2072. hf_hub_id='timm/',
  2073. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
  2074. 'maxvit_tiny_tf_384.in1k': _cfg(
  2075. hf_hub_id='timm/',
  2076. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
  2077. 'maxvit_tiny_tf_512.in1k': _cfg(
  2078. hf_hub_id='timm/',
  2079. input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
  2080. 'maxvit_small_tf_224.in1k': _cfg(
  2081. hf_hub_id='timm/',
  2082. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
  2083. 'maxvit_small_tf_384.in1k': _cfg(
  2084. hf_hub_id='timm/',
  2085. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
  2086. 'maxvit_small_tf_512.in1k': _cfg(
  2087. hf_hub_id='timm/',
  2088. input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
  2089. 'maxvit_base_tf_224.in1k': _cfg(
  2090. hf_hub_id='timm/',
  2091. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
  2092. 'maxvit_base_tf_384.in1k': _cfg(
  2093. hf_hub_id='timm/',
  2094. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
  2095. 'maxvit_base_tf_512.in1k': _cfg(
  2096. hf_hub_id='timm/',
  2097. input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
  2098. 'maxvit_large_tf_224.in1k': _cfg(
  2099. hf_hub_id='timm/',
  2100. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
  2101. 'maxvit_large_tf_384.in1k': _cfg(
  2102. hf_hub_id='timm/',
  2103. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
  2104. 'maxvit_large_tf_512.in1k': _cfg(
  2105. hf_hub_id='timm/',
  2106. input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
  2107. 'maxvit_base_tf_224.in21k': _cfg(
  2108. hf_hub_id='timm/',
  2109. num_classes=21843),
  2110. 'maxvit_base_tf_384.in21k_ft_in1k': _cfg(
  2111. hf_hub_id='timm/',
  2112. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
  2113. 'maxvit_base_tf_512.in21k_ft_in1k': _cfg(
  2114. hf_hub_id='timm/',
  2115. input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
  2116. 'maxvit_large_tf_224.in21k': _cfg(
  2117. hf_hub_id='timm/',
  2118. num_classes=21843),
  2119. 'maxvit_large_tf_384.in21k_ft_in1k': _cfg(
  2120. hf_hub_id='timm/',
  2121. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
  2122. 'maxvit_large_tf_512.in21k_ft_in1k': _cfg(
  2123. hf_hub_id='timm/',
  2124. input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
  2125. 'maxvit_xlarge_tf_224.in21k': _cfg(
  2126. hf_hub_id='timm/',
  2127. num_classes=21843),
  2128. 'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg(
  2129. hf_hub_id='timm/',
  2130. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
  2131. 'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg(
  2132. hf_hub_id='timm/',
  2133. input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
  2134. })
  2135. @register_model
  2136. def coatnet_pico_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2137. """CoatNet Pico model with RW configuration."""
  2138. return _create_maxxvit('coatnet_pico_rw_224', pretrained=pretrained, **kwargs)
  2139. @register_model
  2140. def coatnet_nano_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2141. """CoatNet Nano model with RW configuration."""
  2142. return _create_maxxvit('coatnet_nano_rw_224', pretrained=pretrained, **kwargs)
  2143. @register_model
  2144. def coatnet_0_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2145. """CoatNet-0 model with RW configuration."""
  2146. return _create_maxxvit('coatnet_0_rw_224', pretrained=pretrained, **kwargs)
  2147. @register_model
  2148. def coatnet_1_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2149. """CoatNet-1 model with RW configuration."""
  2150. return _create_maxxvit('coatnet_1_rw_224', pretrained=pretrained, **kwargs)
  2151. @register_model
  2152. def coatnet_2_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2153. """CoatNet-2 model with RW configuration."""
  2154. return _create_maxxvit('coatnet_2_rw_224', pretrained=pretrained, **kwargs)
  2155. @register_model
  2156. def coatnet_3_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2157. """CoatNet-3 model with RW configuration."""
  2158. return _create_maxxvit('coatnet_3_rw_224', pretrained=pretrained, **kwargs)
  2159. @register_model
  2160. def coatnet_bn_0_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2161. """CoatNet-0 model with BatchNorm and RW configuration."""
  2162. return _create_maxxvit('coatnet_bn_0_rw_224', pretrained=pretrained, **kwargs)
  2163. @register_model
  2164. def coatnet_rmlp_nano_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2165. """CoatNet Nano model with Relative Position MLP."""
  2166. return _create_maxxvit('coatnet_rmlp_nano_rw_224', pretrained=pretrained, **kwargs)
  2167. @register_model
  2168. def coatnet_rmlp_0_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2169. """CoatNet-0 model with Relative Position MLP."""
  2170. return _create_maxxvit('coatnet_rmlp_0_rw_224', pretrained=pretrained, **kwargs)
  2171. @register_model
  2172. def coatnet_rmlp_1_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2173. """CoatNet-1 model with Relative Position MLP."""
  2174. return _create_maxxvit('coatnet_rmlp_1_rw_224', pretrained=pretrained, **kwargs)
  2175. @register_model
  2176. def coatnet_rmlp_1_rw2_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2177. """CoatNet-1 model with Relative Position MLP v2."""
  2178. return _create_maxxvit('coatnet_rmlp_1_rw2_224', pretrained=pretrained, **kwargs)
  2179. @register_model
  2180. def coatnet_rmlp_2_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2181. """CoatNet-2 model with Relative Position MLP."""
  2182. return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs)
  2183. @register_model
  2184. def coatnet_rmlp_2_rw_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2185. """CoatNet-2 model with Relative Position MLP at 384x384."""
  2186. return _create_maxxvit('coatnet_rmlp_2_rw_384', pretrained=pretrained, **kwargs)
  2187. @register_model
  2188. def coatnet_rmlp_3_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2189. """CoatNet-3 model with Relative Position MLP."""
  2190. return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs)
  2191. @register_model
  2192. def coatnet_nano_cc_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2193. """CoatNet Nano model with ConvNeXt blocks."""
  2194. return _create_maxxvit('coatnet_nano_cc_224', pretrained=pretrained, **kwargs)
  2195. @register_model
  2196. def coatnext_nano_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2197. """CoAtNeXt Nano model with RW configuration."""
  2198. return _create_maxxvit('coatnext_nano_rw_224', pretrained=pretrained, **kwargs)
  2199. @register_model
  2200. def coatnet_0_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2201. """CoatNet-0 model."""
  2202. return _create_maxxvit('coatnet_0_224', pretrained=pretrained, **kwargs)
  2203. @register_model
  2204. def coatnet_1_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2205. """CoatNet-1 model."""
  2206. return _create_maxxvit('coatnet_1_224', pretrained=pretrained, **kwargs)
  2207. @register_model
  2208. def coatnet_2_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2209. """CoatNet-2 model."""
  2210. return _create_maxxvit('coatnet_2_224', pretrained=pretrained, **kwargs)
  2211. @register_model
  2212. def coatnet_3_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2213. """CoatNet-3 model."""
  2214. return _create_maxxvit('coatnet_3_224', pretrained=pretrained, **kwargs)
  2215. @register_model
  2216. def coatnet_4_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2217. """CoatNet-4 model."""
  2218. return _create_maxxvit('coatnet_4_224', pretrained=pretrained, **kwargs)
  2219. @register_model
  2220. def coatnet_5_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2221. """CoatNet-5 model."""
  2222. return _create_maxxvit('coatnet_5_224', pretrained=pretrained, **kwargs)
  2223. @register_model
  2224. def maxvit_pico_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2225. """MaxViT Pico model with RW configuration."""
  2226. return _create_maxxvit('maxvit_pico_rw_256', pretrained=pretrained, **kwargs)
  2227. @register_model
  2228. def maxvit_nano_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2229. """MaxViT Nano model with RW configuration."""
  2230. return _create_maxxvit('maxvit_nano_rw_256', pretrained=pretrained, **kwargs)
  2231. @register_model
  2232. def maxvit_tiny_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2233. """MaxViT Tiny model with RW configuration."""
  2234. return _create_maxxvit('maxvit_tiny_rw_224', pretrained=pretrained, **kwargs)
  2235. @register_model
  2236. def maxvit_tiny_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2237. """MaxViT Tiny model with RW configuration at 256x256."""
  2238. return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs)
  2239. @register_model
  2240. def maxvit_rmlp_pico_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2241. """MaxViT Relative Position MLP Pico RW 256x256 model."""
  2242. return _create_maxxvit('maxvit_rmlp_pico_rw_256', pretrained=pretrained, **kwargs)
  2243. @register_model
  2244. def maxvit_rmlp_nano_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2245. """MaxViT Relative Position MLP Nano RW 256x256 model."""
  2246. return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)
  2247. @register_model
  2248. def maxvit_rmlp_tiny_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2249. """MaxViT Relative Position MLP Tiny RW 256x256 model."""
  2250. return _create_maxxvit('maxvit_rmlp_tiny_rw_256', pretrained=pretrained, **kwargs)
  2251. @register_model
  2252. def maxvit_rmlp_small_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2253. """MaxViT Relative Position MLP Small RW 224x224 model."""
  2254. return _create_maxxvit('maxvit_rmlp_small_rw_224', pretrained=pretrained, **kwargs)
  2255. @register_model
  2256. def maxvit_rmlp_small_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2257. """MaxViT Small model with Relative Position MLP at 256x256."""
  2258. return _create_maxxvit('maxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs)
  2259. @register_model
  2260. def maxvit_rmlp_base_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2261. """MaxViT Base model with Relative Position MLP."""
  2262. return _create_maxxvit('maxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs)
  2263. @register_model
  2264. def maxvit_rmlp_base_rw_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2265. """MaxViT Base model with Relative Position MLP at 384x384."""
  2266. return _create_maxxvit('maxvit_rmlp_base_rw_384', pretrained=pretrained, **kwargs)
  2267. @register_model
  2268. def maxvit_tiny_pm_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2269. """MaxViT Tiny model with parallel blocks."""
  2270. return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)
  2271. @register_model
  2272. def maxxvit_rmlp_nano_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2273. """MaxxViT Relative Position MLP Nano RW 256x256 model."""
  2274. return _create_maxxvit('maxxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)
  2275. @register_model
  2276. def maxxvit_rmlp_tiny_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2277. """MaxxViT Tiny model with Relative Position MLP."""
  2278. return _create_maxxvit('maxxvit_rmlp_tiny_rw_256', pretrained=pretrained, **kwargs)
  2279. @register_model
  2280. def maxxvit_rmlp_small_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2281. """MaxxViT Small model with Relative Position MLP."""
  2282. return _create_maxxvit('maxxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs)
  2283. @register_model
  2284. def maxxvitv2_nano_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2285. """MaxxViT-V2 Nano model."""
  2286. return _create_maxxvit('maxxvitv2_nano_rw_256', pretrained=pretrained, **kwargs)
  2287. @register_model
  2288. def maxxvitv2_rmlp_base_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2289. """MaxxViT-V2 Base model with Relative Position MLP."""
  2290. return _create_maxxvit('maxxvitv2_rmlp_base_rw_224', pretrained=pretrained, **kwargs)
  2291. @register_model
  2292. def maxxvitv2_rmlp_base_rw_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2293. """MaxxViT-V2 Base model with Relative Position MLP at 384x384."""
  2294. return _create_maxxvit('maxxvitv2_rmlp_base_rw_384', pretrained=pretrained, **kwargs)
  2295. @register_model
  2296. def maxxvitv2_rmlp_large_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2297. """MaxxViT-V2 Large model with Relative Position MLP."""
  2298. return _create_maxxvit('maxxvitv2_rmlp_large_rw_224', pretrained=pretrained, **kwargs)
  2299. @register_model
  2300. def maxvit_tiny_tf_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2301. """MaxViT Tiny model from TensorFlow."""
  2302. return _create_maxxvit('maxvit_tiny_tf_224', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs)
  2303. @register_model
  2304. def maxvit_tiny_tf_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2305. """MaxViT Tiny model from TensorFlow at 384x384."""
  2306. return _create_maxxvit('maxvit_tiny_tf_384', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs)
  2307. @register_model
  2308. def maxvit_tiny_tf_512(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2309. """MaxViT Tiny model from TensorFlow at 512x512."""
  2310. return _create_maxxvit('maxvit_tiny_tf_512', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs)
  2311. @register_model
  2312. def maxvit_small_tf_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2313. """MaxViT Small model from TensorFlow."""
  2314. return _create_maxxvit('maxvit_small_tf_224', 'maxvit_small_tf', pretrained=pretrained, **kwargs)
  2315. @register_model
  2316. def maxvit_small_tf_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2317. """MaxViT Small model from TensorFlow at 384x384."""
  2318. return _create_maxxvit('maxvit_small_tf_384', 'maxvit_small_tf', pretrained=pretrained, **kwargs)
  2319. @register_model
  2320. def maxvit_small_tf_512(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2321. """MaxViT Small model from TensorFlow at 512x512."""
  2322. return _create_maxxvit('maxvit_small_tf_512', 'maxvit_small_tf', pretrained=pretrained, **kwargs)
  2323. @register_model
  2324. def maxvit_base_tf_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2325. """MaxViT Base model from TensorFlow."""
  2326. return _create_maxxvit('maxvit_base_tf_224', 'maxvit_base_tf', pretrained=pretrained, **kwargs)
  2327. @register_model
  2328. def maxvit_base_tf_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2329. """MaxViT Base model from TensorFlow at 384x384."""
  2330. return _create_maxxvit('maxvit_base_tf_384', 'maxvit_base_tf', pretrained=pretrained, **kwargs)
  2331. @register_model
  2332. def maxvit_base_tf_512(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2333. """MaxViT Base model from TensorFlow at 512x512."""
  2334. return _create_maxxvit('maxvit_base_tf_512', 'maxvit_base_tf', pretrained=pretrained, **kwargs)
  2335. @register_model
  2336. def maxvit_large_tf_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2337. """MaxViT Large model from TensorFlow."""
  2338. return _create_maxxvit('maxvit_large_tf_224', 'maxvit_large_tf', pretrained=pretrained, **kwargs)
  2339. @register_model
  2340. def maxvit_large_tf_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2341. """MaxViT Large model from TensorFlow at 384x384."""
  2342. return _create_maxxvit('maxvit_large_tf_384', 'maxvit_large_tf', pretrained=pretrained, **kwargs)
  2343. @register_model
  2344. def maxvit_large_tf_512(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2345. """MaxViT Large model from TensorFlow at 512x512."""
  2346. return _create_maxxvit('maxvit_large_tf_512', 'maxvit_large_tf', pretrained=pretrained, **kwargs)
  2347. @register_model
  2348. def maxvit_xlarge_tf_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2349. """MaxViT XLarge model from TensorFlow."""
  2350. return _create_maxxvit('maxvit_xlarge_tf_224', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs)
  2351. @register_model
  2352. def maxvit_xlarge_tf_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2353. """MaxViT XLarge model from TensorFlow at 384x384."""
  2354. return _create_maxxvit('maxvit_xlarge_tf_384', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs)
  2355. @register_model
  2356. def maxvit_xlarge_tf_512(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
  2357. """MaxViT XLarge model from TensorFlow at 512x512."""
  2358. return _create_maxxvit('maxvit_xlarge_tf_512', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs)