metaformer.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183
  1. """
  2. Poolformer from MetaFormer is Actually What You Need for Vision https://arxiv.org/abs/2111.11418
  3. IdentityFormer, RandFormer, PoolFormerV2, ConvFormer, and CAFormer
  4. from MetaFormer Baselines for Vision https://arxiv.org/abs/2210.13452
  5. All implemented models support feature extraction and variable input resolution.
  6. Original implementation by Weihao Yu et al.,
  7. adapted for timm by Fredo Guan and Ross Wightman.
  8. Adapted from https://github.com/sail-sg/metaformer, original copyright below
  9. """
  10. # Copyright 2022 Garena Online Private Limited
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. from collections import OrderedDict
  24. from functools import partial
  25. from typing import List, Optional, Tuple, Union, Type
  26. import torch
  27. import torch.nn as nn
  28. import torch.nn.functional as F
  29. from torch import Tensor
  30. from torch.jit import Final
  31. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  32. from timm.layers import (
  33. trunc_normal_,
  34. DropPath,
  35. calculate_drop_path_rates,
  36. SelectAdaptivePool2d,
  37. GroupNorm1,
  38. LayerNorm,
  39. LayerNorm2d,
  40. Mlp,
  41. use_fused_attn,
  42. )
  43. from ._builder import build_model_with_cfg
  44. from ._features import feature_take_indices
  45. from ._manipulate import checkpoint, checkpoint_seq
  46. from ._registry import generate_default_cfgs, register_model
  47. __all__ = ['MetaFormer']
  48. class Stem(nn.Module):
  49. """
  50. Stem implemented by a layer of convolution.
  51. Conv2d params constant across all models.
  52. """
  53. def __init__(
  54. self,
  55. in_channels: int,
  56. out_channels: int,
  57. norm_layer: Optional[Type[nn.Module]] = None,
  58. device=None,
  59. dtype=None,
  60. ):
  61. dd = {'device': device, 'dtype': dtype}
  62. super().__init__()
  63. self.conv = nn.Conv2d(
  64. in_channels,
  65. out_channels,
  66. kernel_size=7,
  67. stride=4,
  68. padding=2,
  69. **dd,
  70. )
  71. self.norm = norm_layer(out_channels, **dd) if norm_layer else nn.Identity()
  72. def forward(self, x):
  73. x = self.conv(x)
  74. x = self.norm(x)
  75. return x
  76. class Downsampling(nn.Module):
  77. """
  78. Downsampling implemented by a layer of convolution.
  79. """
  80. def __init__(
  81. self,
  82. in_channels: int,
  83. out_channels: int,
  84. kernel_size: int,
  85. stride: int = 1,
  86. padding: int = 0,
  87. norm_layer: Optional[Type[nn.Module]] = None,
  88. device=None,
  89. dtype=None,
  90. ):
  91. dd = {'device': device, 'dtype': dtype}
  92. super().__init__()
  93. self.norm = norm_layer(in_channels, **dd) if norm_layer else nn.Identity()
  94. self.conv = nn.Conv2d(
  95. in_channels,
  96. out_channels,
  97. kernel_size=kernel_size,
  98. stride=stride,
  99. padding=padding,
  100. **dd
  101. )
  102. def forward(self, x):
  103. x = self.norm(x)
  104. x = self.conv(x)
  105. return x
  106. class Scale(nn.Module):
  107. """
  108. Scale vector by element multiplications.
  109. """
  110. def __init__(
  111. self,
  112. dim: int,
  113. init_value: float = 1.0,
  114. trainable: bool = True,
  115. use_nchw: bool = True,
  116. device=None,
  117. dtype=None,
  118. ):
  119. dd = {'device': device, 'dtype': dtype}
  120. super().__init__()
  121. self.shape = (dim, 1, 1) if use_nchw else (dim,)
  122. self.scale = nn.Parameter(init_value * torch.ones(dim, **dd), requires_grad=trainable)
  123. def forward(self, x):
  124. return x * self.scale.view(self.shape)
  125. class SquaredReLU(nn.Module):
  126. """
  127. Squared ReLU: https://arxiv.org/abs/2109.08668
  128. """
  129. def __init__(self, inplace: bool = False):
  130. super().__init__()
  131. self.relu = nn.ReLU(inplace=inplace)
  132. def forward(self, x):
  133. return torch.square(self.relu(x))
  134. class StarReLU(nn.Module):
  135. """
  136. StarReLU: s * relu(x) ** 2 + b
  137. """
  138. def __init__(
  139. self,
  140. scale_value: float = 1.0,
  141. bias_value: float = 0.0,
  142. scale_learnable: bool = True,
  143. bias_learnable: bool = True,
  144. mode: Optional[str] = None,
  145. inplace: bool = False,
  146. device=None,
  147. dtype=None,
  148. ):
  149. dd = {'device': device, 'dtype': dtype}
  150. super().__init__()
  151. self.inplace = inplace
  152. self.relu = nn.ReLU(inplace=inplace)
  153. self.scale = nn.Parameter(scale_value * torch.ones(1, **dd), requires_grad=scale_learnable)
  154. self.bias = nn.Parameter(bias_value * torch.ones(1, **dd), requires_grad=bias_learnable)
  155. def forward(self, x):
  156. return self.scale * self.relu(x) ** 2 + self.bias
  157. class Attention(nn.Module):
  158. """
  159. Vanilla self-attention from Transformer: https://arxiv.org/abs/1706.03762.
  160. Modified from timm.
  161. """
  162. fused_attn: Final[bool]
  163. def __init__(
  164. self,
  165. dim: int,
  166. head_dim: int = 32,
  167. num_heads: Optional[int] = None,
  168. qkv_bias: bool = False,
  169. attn_drop: float = 0.,
  170. proj_drop: float = 0.,
  171. proj_bias: bool = False,
  172. device=None,
  173. dtype=None,
  174. **kwargs
  175. ):
  176. dd = {'device': device, 'dtype': dtype}
  177. super().__init__()
  178. self.head_dim = head_dim
  179. self.scale = head_dim ** -0.5
  180. self.fused_attn = use_fused_attn()
  181. self.num_heads = num_heads if num_heads else dim // head_dim
  182. if self.num_heads == 0:
  183. self.num_heads = 1
  184. self.attention_dim = self.num_heads * self.head_dim
  185. self.qkv = nn.Linear(dim, self.attention_dim * 3, bias=qkv_bias, **dd)
  186. self.attn_drop = nn.Dropout(attn_drop)
  187. self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias, **dd)
  188. self.proj_drop = nn.Dropout(proj_drop)
  189. def forward(self, x):
  190. B, N, C = x.shape
  191. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
  192. q, k, v = qkv.unbind(0)
  193. if self.fused_attn:
  194. x = F.scaled_dot_product_attention(
  195. q, k, v,
  196. dropout_p=self.attn_drop.p if self.training else 0.,
  197. )
  198. else:
  199. attn = (q @ k.transpose(-2, -1)) * self.scale
  200. attn = attn.softmax(dim=-1)
  201. attn = self.attn_drop(attn)
  202. x = attn @ v
  203. x = x.transpose(1, 2).reshape(B, N, C)
  204. x = self.proj(x)
  205. x = self.proj_drop(x)
  206. return x
# Custom norm modules that disable the bias term, since the original model definitions
# used a custom norm with a weight term but no bias term.
  209. class GroupNorm1NoBias(GroupNorm1):
  210. def __init__(self, num_channels: int, **kwargs):
  211. super().__init__(num_channels, **kwargs)
  212. self.eps = kwargs.get('eps', 1e-6)
  213. self.bias = None
  214. class LayerNorm2dNoBias(LayerNorm2d):
  215. def __init__(self, num_channels: int, **kwargs):
  216. super().__init__(num_channels, **kwargs)
  217. self.eps = kwargs.get('eps', 1e-6)
  218. self.bias = None
  219. class LayerNormNoBias(nn.LayerNorm):
  220. def __init__(self, num_channels: int, **kwargs):
  221. super().__init__(num_channels, **kwargs)
  222. self.eps = kwargs.get('eps', 1e-6)
  223. self.bias = None
  224. class SepConv(nn.Module):
  225. r"""
  226. Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381.
  227. """
  228. def __init__(
  229. self,
  230. dim: int,
  231. expansion_ratio: float = 2,
  232. act1_layer: Type[nn.Module] = StarReLU,
  233. act2_layer: Type[nn.Module] = nn.Identity,
  234. bias: bool = False,
  235. kernel_size: int = 7,
  236. padding: int = 3,
  237. device=None,
  238. dtype=None,
  239. **kwargs
  240. ):
  241. dd = {'device': device, 'dtype': dtype}
  242. super().__init__()
  243. mid_channels = int(expansion_ratio * dim)
  244. self.pwconv1 = nn.Conv2d(dim, mid_channels, kernel_size=1, bias=bias, **dd)
  245. self.act1 = act1_layer(**dd) if issubclass(act1_layer, StarReLU) else act1_layer()
  246. self.dwconv = nn.Conv2d(
  247. mid_channels,
  248. mid_channels,
  249. kernel_size=kernel_size,
  250. padding=padding,
  251. groups=mid_channels,
  252. bias=bias,
  253. **dd,
  254. ) # depthwise conv
  255. self.act2 = act2_layer(**dd) if issubclass(act2_layer, StarReLU) else act2_layer()
  256. self.pwconv2 = nn.Conv2d(mid_channels, dim, kernel_size=1, bias=bias, **dd)
  257. def forward(self, x):
  258. x = self.pwconv1(x)
  259. x = self.act1(x)
  260. x = self.dwconv(x)
  261. x = self.act2(x)
  262. x = self.pwconv2(x)
  263. return x
  264. class Pooling(nn.Module):
  265. """
  266. Implementation of pooling for PoolFormer: https://arxiv.org/abs/2111.11418
  267. """
  268. def __init__(self, pool_size: int = 3, **kwargs):
  269. super().__init__()
  270. self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)
  271. def forward(self, x):
  272. y = self.pool(x)
  273. return y - x
  274. class MlpHead(nn.Module):
  275. """ MLP classification head
  276. """
  277. def __init__(
  278. self,
  279. dim: int,
  280. num_classes: int = 1000,
  281. mlp_ratio: float = 4,
  282. act_layer: Type[nn.Module] = SquaredReLU,
  283. norm_layer: Type[nn.Module] = LayerNorm,
  284. drop_rate: float = 0.,
  285. bias: bool = True,
  286. device=None,
  287. dtype=None,
  288. ):
  289. dd = {'device': device, 'dtype': dtype}
  290. super().__init__()
  291. hidden_features = int(mlp_ratio * dim)
  292. self.fc1 = nn.Linear(dim, hidden_features, bias=bias, **dd)
  293. self.act = act_layer()
  294. self.norm = norm_layer(hidden_features, **dd)
  295. self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias, **dd)
  296. self.head_drop = nn.Dropout(drop_rate)
  297. def forward(self, x):
  298. x = self.fc1(x)
  299. x = self.act(x)
  300. x = self.norm(x)
  301. x = self.head_drop(x)
  302. x = self.fc2(x)
  303. return x
  304. class MetaFormerBlock(nn.Module):
  305. """
  306. Implementation of one MetaFormer block.
  307. """
  308. def __init__(
  309. self,
  310. dim: int,
  311. token_mixer: Type[nn.Module] = Pooling,
  312. mlp_act: Type[nn.Module] = StarReLU,
  313. mlp_bias: bool = False,
  314. norm_layer: Type[nn.Module] = LayerNorm2d,
  315. proj_drop: float = 0.,
  316. drop_path: float = 0.,
  317. use_nchw: bool = True,
  318. layer_scale_init_value: Optional[float] = None,
  319. res_scale_init_value: Optional[float] = None,
  320. device=None,
  321. dtype=None,
  322. **kwargs
  323. ):
  324. dd = {'device': device, 'dtype': dtype}
  325. super().__init__()
  326. ls_layer = partial(Scale, dim=dim, init_value=layer_scale_init_value, use_nchw=use_nchw, **dd)
  327. rs_layer = partial(Scale, dim=dim, init_value=res_scale_init_value, use_nchw=use_nchw, **dd)
  328. self.norm1 = norm_layer(dim, **dd)
  329. self.token_mixer = token_mixer(dim=dim, proj_drop=proj_drop, **dd, **kwargs)
  330. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  331. self.layer_scale1 = ls_layer() if layer_scale_init_value is not None else nn.Identity()
  332. self.res_scale1 = rs_layer() if res_scale_init_value is not None else nn.Identity()
  333. self.norm2 = norm_layer(dim, **dd)
  334. self.mlp = Mlp(
  335. dim,
  336. int(4 * dim),
  337. act_layer=mlp_act,
  338. bias=mlp_bias,
  339. drop=proj_drop,
  340. use_conv=use_nchw,
  341. **dd
  342. )
  343. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  344. self.layer_scale2 = ls_layer() if layer_scale_init_value is not None else nn.Identity()
  345. self.res_scale2 = rs_layer() if res_scale_init_value is not None else nn.Identity()
  346. def forward(self, x):
  347. x = self.res_scale1(x) + \
  348. self.layer_scale1(
  349. self.drop_path1(
  350. self.token_mixer(self.norm1(x))
  351. )
  352. )
  353. x = self.res_scale2(x) + \
  354. self.layer_scale2(
  355. self.drop_path2(
  356. self.mlp(self.norm2(x))
  357. )
  358. )
  359. return x
  360. class MetaFormerStage(nn.Module):
  361. def __init__(
  362. self,
  363. in_chs: int,
  364. out_chs: int,
  365. depth: int = 2,
  366. token_mixer: Type[nn.Module] = nn.Identity,
  367. mlp_act: Type[nn.Module] = StarReLU,
  368. mlp_bias: bool = False,
  369. downsample_norm: Optional[Type[nn.Module]] = LayerNorm2d,
  370. norm_layer: Type[nn.Module] = LayerNorm2d,
  371. proj_drop: float = 0.,
  372. dp_rates: List[float] = [0.] * 2,
  373. layer_scale_init_value: Optional[float] = None,
  374. res_scale_init_value: Optional[float] = None,
  375. device=None,
  376. dtype=None,
  377. **kwargs,
  378. ):
  379. dd = {'device': device, 'dtype': dtype}
  380. super().__init__()
  381. self.grad_checkpointing = False
  382. self.use_nchw = not issubclass(token_mixer, Attention)
  383. # don't downsample if in_chs and out_chs are the same
  384. self.downsample = nn.Identity() if in_chs == out_chs else Downsampling(
  385. in_chs,
  386. out_chs,
  387. kernel_size=3,
  388. stride=2,
  389. padding=1,
  390. norm_layer=downsample_norm,
  391. **dd,
  392. )
  393. self.blocks = nn.Sequential(*[MetaFormerBlock(
  394. dim=out_chs,
  395. token_mixer=token_mixer,
  396. mlp_act=mlp_act,
  397. mlp_bias=mlp_bias,
  398. norm_layer=norm_layer,
  399. proj_drop=proj_drop,
  400. drop_path=dp_rates[i],
  401. layer_scale_init_value=layer_scale_init_value,
  402. res_scale_init_value=res_scale_init_value,
  403. use_nchw=self.use_nchw,
  404. **dd,
  405. **kwargs,
  406. ) for i in range(depth)])
  407. @torch.jit.ignore
  408. def set_grad_checkpointing(self, enable=True):
  409. self.grad_checkpointing = enable
  410. def forward(self, x: Tensor):
  411. x = self.downsample(x)
  412. B, C, H, W = x.shape
  413. if not self.use_nchw:
  414. x = x.reshape(B, C, -1).transpose(1, 2)
  415. if self.grad_checkpointing and not torch.jit.is_scripting():
  416. x = checkpoint_seq(self.blocks, x)
  417. else:
  418. x = self.blocks(x)
  419. if not self.use_nchw:
  420. x = x.transpose(1, 2).reshape(B, C, H, W)
  421. return x
  422. class MetaFormer(nn.Module):
  423. r""" MetaFormer
  424. A PyTorch impl of : `MetaFormer Baselines for Vision` -
  425. https://arxiv.org/abs/2210.13452
  426. Args:
  427. in_chans (int): Number of input image channels.
  428. num_classes (int): Number of classes for classification head.
  429. global_pool: Pooling for classifier head.
  430. depths (list or tuple): Number of blocks at each stage.
  431. dims (list or tuple): Feature dimension at each stage.
  432. token_mixers (list, tuple or token_fcn): Token mixer for each stage.
  433. mlp_act: Activation layer for MLP.
  434. mlp_bias (boolean): Enable or disable mlp bias term.
  435. drop_path_rate (float): Stochastic depth rate.
  436. drop_rate (float): Dropout rate.
  437. layer_scale_init_values (list, tuple, float or None): Init value for Layer Scale.
  438. None means not use the layer scale. Form: https://arxiv.org/abs/2103.17239.
  439. res_scale_init_values (list, tuple, float or None): Init value for res Scale on residual connections.
  440. None means not use the res scale. From: https://arxiv.org/abs/2110.09456.
  441. downsample_norm (nn.Module): Norm layer used in stem and downsampling layers.
  442. norm_layers (list, tuple or norm_fcn): Norm layers for each stage.
  443. output_norm: Norm layer before classifier head.
  444. use_mlp_head: Use MLP classification head.
  445. """
  446. def __init__(
  447. self,
  448. in_chans: int = 3,
  449. num_classes: int = 1000,
  450. global_pool: str = 'avg',
  451. depths: Tuple[int, ...] = (2, 2, 6, 2),
  452. dims: Tuple[int, ...] = (64, 128, 320, 512),
  453. token_mixers: Union[Type[nn.Module], List[Type[nn.Module]]] = Pooling,
  454. mlp_act: Type[nn.Module] = StarReLU,
  455. mlp_bias: bool = False,
  456. drop_path_rate: float = 0.,
  457. proj_drop_rate: float = 0.,
  458. drop_rate: float = 0.0,
  459. layer_scale_init_values: Optional[Union[float, List[float]]] = None,
  460. res_scale_init_values: Union[Tuple[Optional[float], ...], List[Optional[float]]] = (None, None, 1.0, 1.0),
  461. downsample_norm: Optional[Type[nn.Module]] = LayerNorm2dNoBias,
  462. norm_layers: Union[Type[nn.Module], List[Type[nn.Module]]] = LayerNorm2dNoBias,
  463. output_norm: Type[nn.Module] = LayerNorm2d,
  464. use_mlp_head: bool = True,
  465. device=None,
  466. dtype=None,
  467. **kwargs,
  468. ):
  469. super().__init__()
  470. dd = {'device': device, 'dtype': dtype}
  471. # Bind dd kwargs to activation layers that need them
  472. if mlp_act in (StarReLU,):
  473. mlp_act = partial(mlp_act, **dd)
  474. self.num_classes = num_classes
  475. self.in_chans = in_chans
  476. self.num_features = dims[-1]
  477. self.drop_rate = drop_rate
  478. self.use_mlp_head = use_mlp_head
  479. self.num_stages = len(depths)
  480. # convert everything to lists if they aren't indexable
  481. if not isinstance(depths, (list, tuple)):
  482. depths = [depths] # it means the model has only one stage
  483. if not isinstance(dims, (list, tuple)):
  484. dims = [dims]
  485. if not isinstance(token_mixers, (list, tuple)):
  486. token_mixers = [token_mixers] * self.num_stages
  487. if not isinstance(norm_layers, (list, tuple)):
  488. norm_layers = [norm_layers] * self.num_stages
  489. if not isinstance(layer_scale_init_values, (list, tuple)):
  490. layer_scale_init_values = [layer_scale_init_values] * self.num_stages
  491. if not isinstance(res_scale_init_values, (list, tuple)):
  492. res_scale_init_values = [res_scale_init_values] * self.num_stages
  493. self.grad_checkpointing = False
  494. self.feature_info = []
  495. self.stem = Stem(
  496. in_chans,
  497. dims[0],
  498. norm_layer=downsample_norm,
  499. **dd,
  500. )
  501. stages = []
  502. prev_dim = dims[0]
  503. dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
  504. for i in range(self.num_stages):
  505. stages += [MetaFormerStage(
  506. prev_dim,
  507. dims[i],
  508. depth=depths[i],
  509. token_mixer=token_mixers[i],
  510. mlp_act=mlp_act,
  511. mlp_bias=mlp_bias,
  512. proj_drop=proj_drop_rate,
  513. dp_rates=dp_rates[i],
  514. layer_scale_init_value=layer_scale_init_values[i],
  515. res_scale_init_value=res_scale_init_values[i],
  516. downsample_norm=downsample_norm,
  517. norm_layer=norm_layers[i],
  518. **dd,
  519. **kwargs,
  520. )]
  521. prev_dim = dims[i]
  522. self.feature_info += [dict(num_chs=dims[i], reduction=2**(i+2), module=f'stages.{i}')]
  523. self.stages = nn.Sequential(*stages)
  524. # if using MlpHead, dropout is handled by MlpHead
  525. if num_classes > 0:
  526. if self.use_mlp_head:
  527. # FIXME not actually returning mlp hidden state right now as pre-logits.
  528. final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate, **dd)
  529. self.head_hidden_size = self.num_features
  530. else:
  531. final = nn.Linear(self.num_features, num_classes, **dd)
  532. self.head_hidden_size = self.num_features
  533. else:
  534. final = nn.Identity()
  535. self.head = nn.Sequential(OrderedDict([
  536. ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
  537. ('norm', output_norm(self.num_features, **dd)),
  538. ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
  539. ('drop', nn.Dropout(drop_rate) if self.use_mlp_head else nn.Identity()),
  540. ('fc', final)
  541. ]))
  542. self.apply(self._init_weights)
  543. def _init_weights(self, m):
  544. if isinstance(m, (nn.Conv2d, nn.Linear)):
  545. trunc_normal_(m.weight, std=.02)
  546. if m.bias is not None:
  547. nn.init.constant_(m.bias, 0)
  548. @torch.jit.ignore
  549. def set_grad_checkpointing(self, enable=True):
  550. self.grad_checkpointing = enable
  551. for stage in self.stages:
  552. stage.set_grad_checkpointing(enable=enable)
  553. @torch.jit.ignore
  554. def get_classifier(self) -> nn.Module:
  555. return self.head.fc
  556. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, device=None, dtype=None):
  557. dd = {'device': device, 'dtype': dtype}
  558. self.num_classes = num_classes
  559. if global_pool is not None:
  560. self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  561. self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
  562. if num_classes > 0:
  563. if self.use_mlp_head:
  564. final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate, **dd)
  565. else:
  566. final = nn.Linear(self.num_features, num_classes, **dd)
  567. else:
  568. final = nn.Identity()
  569. self.head.fc = final
  570. def forward_intermediates(
  571. self,
  572. x: torch.Tensor,
  573. indices: Optional[Union[int, List[int]]] = None,
  574. norm: bool = False,
  575. stop_early: bool = False,
  576. output_fmt: str = 'NCHW',
  577. intermediates_only: bool = False,
  578. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  579. """ Forward features that returns intermediates.
  580. Args:
  581. x: Input image tensor
  582. indices: Take last n blocks if int, all if None, select matching indices if sequence
  583. norm: Apply norm layer to compatible intermediates
  584. stop_early: Stop iterating over blocks when last desired intermediate hit
  585. output_fmt: Shape of intermediate feature outputs
  586. intermediates_only: Only return intermediate features
  587. Returns:
  588. """
  589. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  590. intermediates = []
  591. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  592. # forward pass
  593. x = self.stem(x)
  594. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  595. stages = self.stages
  596. else:
  597. stages = self.stages[:max_index + 1]
  598. for feat_idx, stage in enumerate(stages):
  599. if self.grad_checkpointing and not torch.jit.is_scripting():
  600. x = checkpoint(stage, x)
  601. else:
  602. x = stage(x)
  603. if feat_idx in take_indices:
  604. intermediates.append(x)
  605. if intermediates_only:
  606. return intermediates
  607. return x, intermediates
  608. def prune_intermediate_layers(
  609. self,
  610. indices: Union[int, List[int]] = 1,
  611. prune_norm: bool = False,
  612. prune_head: bool = True,
  613. ):
  614. """ Prune layers not required for specified intermediates.
  615. """
  616. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  617. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  618. if prune_head:
  619. self.reset_classifier(0, '')
  620. return take_indices
  621. def forward_head(self, x: Tensor, pre_logits: bool = False):
  622. # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
  623. x = self.head.global_pool(x)
  624. x = self.head.norm(x)
  625. x = self.head.flatten(x)
  626. x = self.head.drop(x)
  627. return x if pre_logits else self.head.fc(x)
  628. def forward_features(self, x: Tensor):
  629. x = self.stem(x)
  630. if self.grad_checkpointing and not torch.jit.is_scripting():
  631. x = checkpoint_seq(self.stages, x)
  632. else:
  633. x = self.stages(x)
  634. return x
  635. def forward(self, x: Tensor):
  636. x = self.forward_features(x)
  637. x = self.forward_head(x)
  638. return x
  639. # this works but it's long and breaks backwards compatibility with weights from the poolformer-only impl
  640. def checkpoint_filter_fn(state_dict, model):
  641. if 'stem.conv.weight' in state_dict:
  642. return state_dict
  643. import re
  644. out_dict = {}
  645. is_poolformerv1 = 'network.0.0.mlp.fc1.weight' in state_dict
  646. model_state_dict = model.state_dict()
  647. for k, v in state_dict.items():
  648. if is_poolformerv1:
  649. k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k)
  650. k = k.replace('network.1', 'downsample_layers.1')
  651. k = k.replace('network.3', 'downsample_layers.2')
  652. k = k.replace('network.5', 'downsample_layers.3')
  653. k = k.replace('network.2', 'network.1')
  654. k = k.replace('network.4', 'network.2')
  655. k = k.replace('network.6', 'network.3')
  656. k = k.replace('network', 'stages')
  657. k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
  658. k = k.replace('downsample.proj', 'downsample.conv')
  659. k = k.replace('patch_embed.proj', 'patch_embed.conv')
  660. k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
  661. k = k.replace('stages.0.downsample', 'patch_embed')
  662. k = k.replace('patch_embed', 'stem')
  663. k = k.replace('post_norm', 'norm')
  664. k = k.replace('pre_norm', 'norm')
  665. k = re.sub(r'^head', 'head.fc', k)
  666. k = re.sub(r'^norm', 'head.norm', k)
  667. if v.shape != model_state_dict[k] and v.numel() == model_state_dict[k].numel():
  668. v = v.reshape(model_state_dict[k].shape)
  669. out_dict[k] = v
  670. return out_dict
  671. def _create_metaformer(variant, pretrained=False, **kwargs):
  672. default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (2, 2, 6, 2))))
  673. out_indices = kwargs.pop('out_indices', default_out_indices)
  674. model = build_model_with_cfg(
  675. MetaFormer,
  676. variant,
  677. pretrained,
  678. pretrained_filter_fn=checkpoint_filter_fn,
  679. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  680. **kwargs,
  681. )
  682. return model
  683. def _cfg(url='', **kwargs):
  684. return {
  685. 'url': url,
  686. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  687. 'crop_pct': 1.0, 'interpolation': 'bicubic',
  688. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  689. 'classifier': 'head.fc', 'first_conv': 'stem.conv',
  690. 'license': 'apache-2.0',
  691. **kwargs
  692. }
  693. default_cfgs = generate_default_cfgs({
  694. 'poolformer_s12.sail_in1k': _cfg(
  695. hf_hub_id='timm/',
  696. crop_pct=0.9),
  697. 'poolformer_s24.sail_in1k': _cfg(
  698. hf_hub_id='timm/',
  699. crop_pct=0.9),
  700. 'poolformer_s36.sail_in1k': _cfg(
  701. hf_hub_id='timm/',
  702. crop_pct=0.9),
  703. 'poolformer_m36.sail_in1k': _cfg(
  704. hf_hub_id='timm/',
  705. crop_pct=0.95),
  706. 'poolformer_m48.sail_in1k': _cfg(
  707. hf_hub_id='timm/',
  708. crop_pct=0.95),
  709. 'poolformerv2_s12.sail_in1k': _cfg(hf_hub_id='timm/'),
  710. 'poolformerv2_s24.sail_in1k': _cfg(hf_hub_id='timm/'),
  711. 'poolformerv2_s36.sail_in1k': _cfg(hf_hub_id='timm/'),
  712. 'poolformerv2_m36.sail_in1k': _cfg(hf_hub_id='timm/'),
  713. 'poolformerv2_m48.sail_in1k': _cfg(hf_hub_id='timm/'),
  714. 'convformer_s18.sail_in1k': _cfg(
  715. hf_hub_id='timm/',
  716. classifier='head.fc.fc2'),
  717. 'convformer_s18.sail_in1k_384': _cfg(
  718. hf_hub_id='timm/',
  719. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  720. 'convformer_s18.sail_in22k_ft_in1k': _cfg(
  721. hf_hub_id='timm/',
  722. classifier='head.fc.fc2'),
  723. 'convformer_s18.sail_in22k_ft_in1k_384': _cfg(
  724. hf_hub_id='timm/',
  725. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  726. 'convformer_s18.sail_in22k': _cfg(
  727. hf_hub_id='timm/',
  728. classifier='head.fc.fc2', num_classes=21841),
  729. 'convformer_s36.sail_in1k': _cfg(
  730. hf_hub_id='timm/',
  731. classifier='head.fc.fc2'),
  732. 'convformer_s36.sail_in1k_384': _cfg(
  733. hf_hub_id='timm/',
  734. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  735. 'convformer_s36.sail_in22k_ft_in1k': _cfg(
  736. hf_hub_id='timm/',
  737. classifier='head.fc.fc2'),
  738. 'convformer_s36.sail_in22k_ft_in1k_384': _cfg(
  739. hf_hub_id='timm/',
  740. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  741. 'convformer_s36.sail_in22k': _cfg(
  742. hf_hub_id='timm/',
  743. classifier='head.fc.fc2', num_classes=21841),
  744. 'convformer_m36.sail_in1k': _cfg(
  745. hf_hub_id='timm/',
  746. classifier='head.fc.fc2'),
  747. 'convformer_m36.sail_in1k_384': _cfg(
  748. hf_hub_id='timm/',
  749. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  750. 'convformer_m36.sail_in22k_ft_in1k': _cfg(
  751. hf_hub_id='timm/',
  752. classifier='head.fc.fc2'),
  753. 'convformer_m36.sail_in22k_ft_in1k_384': _cfg(
  754. hf_hub_id='timm/',
  755. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  756. 'convformer_m36.sail_in22k': _cfg(
  757. hf_hub_id='timm/',
  758. classifier='head.fc.fc2', num_classes=21841),
  759. 'convformer_b36.sail_in1k': _cfg(
  760. hf_hub_id='timm/',
  761. classifier='head.fc.fc2'),
  762. 'convformer_b36.sail_in1k_384': _cfg(
  763. hf_hub_id='timm/',
  764. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  765. 'convformer_b36.sail_in22k_ft_in1k': _cfg(
  766. hf_hub_id='timm/',
  767. classifier='head.fc.fc2'),
  768. 'convformer_b36.sail_in22k_ft_in1k_384': _cfg(
  769. hf_hub_id='timm/',
  770. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  771. 'convformer_b36.sail_in22k': _cfg(
  772. hf_hub_id='timm/',
  773. classifier='head.fc.fc2', num_classes=21841),
  774. 'caformer_s18.sail_in1k': _cfg(
  775. hf_hub_id='timm/',
  776. classifier='head.fc.fc2'),
  777. 'caformer_s18.sail_in1k_384': _cfg(
  778. hf_hub_id='timm/',
  779. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  780. 'caformer_s18.sail_in22k_ft_in1k': _cfg(
  781. hf_hub_id='timm/',
  782. classifier='head.fc.fc2'),
  783. 'caformer_s18.sail_in22k_ft_in1k_384': _cfg(
  784. hf_hub_id='timm/',
  785. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  786. 'caformer_s18.sail_in22k': _cfg(
  787. hf_hub_id='timm/',
  788. classifier='head.fc.fc2', num_classes=21841),
  789. 'caformer_s36.sail_in1k': _cfg(
  790. hf_hub_id='timm/',
  791. classifier='head.fc.fc2'),
  792. 'caformer_s36.sail_in1k_384': _cfg(
  793. hf_hub_id='timm/',
  794. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  795. 'caformer_s36.sail_in22k_ft_in1k': _cfg(
  796. hf_hub_id='timm/',
  797. classifier='head.fc.fc2'),
  798. 'caformer_s36.sail_in22k_ft_in1k_384': _cfg(
  799. hf_hub_id='timm/',
  800. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  801. 'caformer_s36.sail_in22k': _cfg(
  802. hf_hub_id='timm/',
  803. classifier='head.fc.fc2', num_classes=21841),
  804. 'caformer_m36.sail_in1k': _cfg(
  805. hf_hub_id='timm/',
  806. classifier='head.fc.fc2'),
  807. 'caformer_m36.sail_in1k_384': _cfg(
  808. hf_hub_id='timm/',
  809. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  810. 'caformer_m36.sail_in22k_ft_in1k': _cfg(
  811. hf_hub_id='timm/',
  812. classifier='head.fc.fc2'),
  813. 'caformer_m36.sail_in22k_ft_in1k_384': _cfg(
  814. hf_hub_id='timm/',
  815. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  816. 'caformer_m36.sail_in22k': _cfg(
  817. hf_hub_id='timm/',
  818. classifier='head.fc.fc2', num_classes=21841),
  819. 'caformer_b36.sail_in1k': _cfg(
  820. hf_hub_id='timm/',
  821. classifier='head.fc.fc2'),
  822. 'caformer_b36.sail_in1k_384': _cfg(
  823. hf_hub_id='timm/',
  824. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  825. 'caformer_b36.sail_in22k_ft_in1k': _cfg(
  826. hf_hub_id='timm/',
  827. classifier='head.fc.fc2'),
  828. 'caformer_b36.sail_in22k_ft_in1k_384': _cfg(
  829. hf_hub_id='timm/',
  830. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  831. 'caformer_b36.sail_in22k': _cfg(
  832. hf_hub_id='timm/',
  833. classifier='head.fc.fc2', num_classes=21841),
  834. })
  835. @register_model
  836. def poolformer_s12(pretrained=False, **kwargs) -> MetaFormer:
  837. model_kwargs = dict(
  838. depths=[2, 2, 6, 2],
  839. dims=[64, 128, 320, 512],
  840. downsample_norm=None,
  841. mlp_act=nn.GELU,
  842. mlp_bias=True,
  843. norm_layers=GroupNorm1,
  844. layer_scale_init_values=1e-5,
  845. res_scale_init_values=None,
  846. use_mlp_head=False,
  847. **kwargs)
  848. return _create_metaformer('poolformer_s12', pretrained=pretrained, **model_kwargs)
  849. @register_model
  850. def poolformer_s24(pretrained=False, **kwargs) -> MetaFormer:
  851. model_kwargs = dict(
  852. depths=[4, 4, 12, 4],
  853. dims=[64, 128, 320, 512],
  854. downsample_norm=None,
  855. mlp_act=nn.GELU,
  856. mlp_bias=True,
  857. norm_layers=GroupNorm1,
  858. layer_scale_init_values=1e-5,
  859. res_scale_init_values=None,
  860. use_mlp_head=False,
  861. **kwargs)
  862. return _create_metaformer('poolformer_s24', pretrained=pretrained, **model_kwargs)
  863. @register_model
  864. def poolformer_s36(pretrained=False, **kwargs) -> MetaFormer:
  865. model_kwargs = dict(
  866. depths=[6, 6, 18, 6],
  867. dims=[64, 128, 320, 512],
  868. downsample_norm=None,
  869. mlp_act=nn.GELU,
  870. mlp_bias=True,
  871. norm_layers=GroupNorm1,
  872. layer_scale_init_values=1e-6,
  873. res_scale_init_values=None,
  874. use_mlp_head=False,
  875. **kwargs)
  876. return _create_metaformer('poolformer_s36', pretrained=pretrained, **model_kwargs)
  877. @register_model
  878. def poolformer_m36(pretrained=False, **kwargs) -> MetaFormer:
  879. model_kwargs = dict(
  880. depths=[6, 6, 18, 6],
  881. dims=[96, 192, 384, 768],
  882. downsample_norm=None,
  883. mlp_act=nn.GELU,
  884. mlp_bias=True,
  885. norm_layers=GroupNorm1,
  886. layer_scale_init_values=1e-6,
  887. res_scale_init_values=None,
  888. use_mlp_head=False,
  889. **kwargs)
  890. return _create_metaformer('poolformer_m36', pretrained=pretrained, **model_kwargs)
  891. @register_model
  892. def poolformer_m48(pretrained=False, **kwargs) -> MetaFormer:
  893. model_kwargs = dict(
  894. depths=[8, 8, 24, 8],
  895. dims=[96, 192, 384, 768],
  896. downsample_norm=None,
  897. mlp_act=nn.GELU,
  898. mlp_bias=True,
  899. norm_layers=GroupNorm1,
  900. layer_scale_init_values=1e-6,
  901. res_scale_init_values=None,
  902. use_mlp_head=False,
  903. **kwargs)
  904. return _create_metaformer('poolformer_m48', pretrained=pretrained, **model_kwargs)
  905. @register_model
  906. def poolformerv2_s12(pretrained=False, **kwargs) -> MetaFormer:
  907. model_kwargs = dict(
  908. depths=[2, 2, 6, 2],
  909. dims=[64, 128, 320, 512],
  910. norm_layers=GroupNorm1NoBias,
  911. use_mlp_head=False,
  912. **kwargs)
  913. return _create_metaformer('poolformerv2_s12', pretrained=pretrained, **model_kwargs)
  914. @register_model
  915. def poolformerv2_s24(pretrained=False, **kwargs) -> MetaFormer:
  916. model_kwargs = dict(
  917. depths=[4, 4, 12, 4],
  918. dims=[64, 128, 320, 512],
  919. norm_layers=GroupNorm1NoBias,
  920. use_mlp_head=False,
  921. **kwargs)
  922. return _create_metaformer('poolformerv2_s24', pretrained=pretrained, **model_kwargs)
  923. @register_model
  924. def poolformerv2_s36(pretrained=False, **kwargs) -> MetaFormer:
  925. model_kwargs = dict(
  926. depths=[6, 6, 18, 6],
  927. dims=[64, 128, 320, 512],
  928. norm_layers=GroupNorm1NoBias,
  929. use_mlp_head=False,
  930. **kwargs)
  931. return _create_metaformer('poolformerv2_s36', pretrained=pretrained, **model_kwargs)
  932. @register_model
  933. def poolformerv2_m36(pretrained=False, **kwargs) -> MetaFormer:
  934. model_kwargs = dict(
  935. depths=[6, 6, 18, 6],
  936. dims=[96, 192, 384, 768],
  937. norm_layers=GroupNorm1NoBias,
  938. use_mlp_head=False,
  939. **kwargs)
  940. return _create_metaformer('poolformerv2_m36', pretrained=pretrained, **model_kwargs)
  941. @register_model
  942. def poolformerv2_m48(pretrained=False, **kwargs) -> MetaFormer:
  943. model_kwargs = dict(
  944. depths=[8, 8, 24, 8],
  945. dims=[96, 192, 384, 768],
  946. norm_layers=GroupNorm1NoBias,
  947. use_mlp_head=False,
  948. **kwargs)
  949. return _create_metaformer('poolformerv2_m48', pretrained=pretrained, **model_kwargs)
  950. @register_model
  951. def convformer_s18(pretrained=False, **kwargs) -> MetaFormer:
  952. model_kwargs = dict(
  953. depths=[3, 3, 9, 3],
  954. dims=[64, 128, 320, 512],
  955. token_mixers=SepConv,
  956. norm_layers=LayerNorm2dNoBias,
  957. **kwargs)
  958. return _create_metaformer('convformer_s18', pretrained=pretrained, **model_kwargs)
  959. @register_model
  960. def convformer_s36(pretrained=False, **kwargs) -> MetaFormer:
  961. model_kwargs = dict(
  962. depths=[3, 12, 18, 3],
  963. dims=[64, 128, 320, 512],
  964. token_mixers=SepConv,
  965. norm_layers=LayerNorm2dNoBias,
  966. **kwargs)
  967. return _create_metaformer('convformer_s36', pretrained=pretrained, **model_kwargs)
  968. @register_model
  969. def convformer_m36(pretrained=False, **kwargs) -> MetaFormer:
  970. model_kwargs = dict(
  971. depths=[3, 12, 18, 3],
  972. dims=[96, 192, 384, 576],
  973. token_mixers=SepConv,
  974. norm_layers=LayerNorm2dNoBias,
  975. **kwargs)
  976. return _create_metaformer('convformer_m36', pretrained=pretrained, **model_kwargs)
  977. @register_model
  978. def convformer_b36(pretrained=False, **kwargs) -> MetaFormer:
  979. model_kwargs = dict(
  980. depths=[3, 12, 18, 3],
  981. dims=[128, 256, 512, 768],
  982. token_mixers=SepConv,
  983. norm_layers=LayerNorm2dNoBias,
  984. **kwargs)
  985. return _create_metaformer('convformer_b36', pretrained=pretrained, **model_kwargs)
  986. @register_model
  987. def caformer_s18(pretrained=False, **kwargs) -> MetaFormer:
  988. model_kwargs = dict(
  989. depths=[3, 3, 9, 3],
  990. dims=[64, 128, 320, 512],
  991. token_mixers=[SepConv, SepConv, Attention, Attention],
  992. norm_layers=[LayerNorm2dNoBias] * 2 + [LayerNormNoBias] * 2,
  993. **kwargs)
  994. return _create_metaformer('caformer_s18', pretrained=pretrained, **model_kwargs)
  995. @register_model
  996. def caformer_s36(pretrained=False, **kwargs) -> MetaFormer:
  997. model_kwargs = dict(
  998. depths=[3, 12, 18, 3],
  999. dims=[64, 128, 320, 512],
  1000. token_mixers=[SepConv, SepConv, Attention, Attention],
  1001. norm_layers=[LayerNorm2dNoBias] * 2 + [LayerNormNoBias] * 2,
  1002. **kwargs)
  1003. return _create_metaformer('caformer_s36', pretrained=pretrained, **model_kwargs)
  1004. @register_model
  1005. def caformer_m36(pretrained=False, **kwargs) -> MetaFormer:
  1006. model_kwargs = dict(
  1007. depths=[3, 12, 18, 3],
  1008. dims=[96, 192, 384, 576],
  1009. token_mixers=[SepConv, SepConv, Attention, Attention],
  1010. norm_layers=[LayerNorm2dNoBias] * 2 + [LayerNormNoBias] * 2,
  1011. **kwargs)
  1012. return _create_metaformer('caformer_m36', pretrained=pretrained, **model_kwargs)
  1013. @register_model
  1014. def caformer_b36(pretrained=False, **kwargs) -> MetaFormer:
  1015. model_kwargs = dict(
  1016. depths=[3, 12, 18, 3],
  1017. dims=[128, 256, 512, 768],
  1018. token_mixers=[SepConv, SepConv, Attention, Attention],
  1019. norm_layers=[LayerNorm2dNoBias] * 2 + [LayerNormNoBias] * 2,
  1020. **kwargs)
  1021. return _create_metaformer('caformer_b36', pretrained=pretrained, **model_kwargs)