mambaout.py 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737
  1. """
  2. MambaOut models for image classification.
  3. Some implementations are modified from:
  4. timm (https://github.com/rwightman/pytorch-image-models),
  5. MetaFormer (https://github.com/sail-sg/metaformer),
  6. InceptionNeXt (https://github.com/sail-sg/inceptionnext)
  7. """
  8. from collections import OrderedDict
  9. from typing import List, Optional, Tuple, Type, Union
  10. import torch
  11. from torch import nn
  12. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  13. from timm.layers import trunc_normal_, DropPath, calculate_drop_path_rates, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer
  14. from ._builder import build_model_with_cfg
  15. from ._features import feature_take_indices
  16. from ._manipulate import checkpoint_seq
  17. from ._registry import register_model, generate_default_cfgs
  class Stem(nn.Module):
      r"""Two-conv patch stem, modified from InternImage:
      https://github.com/OpenGVLab/InternImage

      Both 3x3 convs use stride 2 / padding 1, so the spatial size is reduced 4x.
      The input is fed straight to a Conv2d (NCHW); the output is permuted to
      channels-last (NHWC) before the final norm and returned in that layout.
      """

      def __init__(
              self,
              in_chs: int = 3,
              out_chs: int = 96,
              mid_norm: bool = True,
              act_layer: Type[nn.Module] = nn.GELU,
              norm_layer: Type[nn.Module] = LayerNorm,
              device=None,
              dtype=None,
      ):
          factory_kwargs = {'device': device, 'dtype': dtype}
          super().__init__()
          mid_chs = out_chs // 2
          # identical geometry for both convs: 3x3, stride 2, padding 1
          conv_kwargs = dict(kernel_size=3, stride=2, padding=1, **factory_kwargs)
          self.conv1 = nn.Conv2d(in_chs, mid_chs, **conv_kwargs)
          if mid_norm:
              self.norm1 = norm_layer(mid_chs, **factory_kwargs)
          else:
              self.norm1 = None
          self.act = act_layer()
          self.conv2 = nn.Conv2d(mid_chs, out_chs, **conv_kwargs)
          self.norm2 = norm_layer(out_chs, **factory_kwargs)

      def forward(self, x):
          x = self.conv1(x)
          if self.norm1 is not None:
              # the norm is applied channels-last: NCHW -> NHWC -> norm -> NCHW
              x = self.norm1(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
          x = self.conv2(self.act(x))
          # leave the result channels-last (NHWC) after the final norm
          return self.norm2(x.permute(0, 2, 3, 1))
  class DownsampleNormFirst(nn.Module):
      """Downsample layer that normalizes first, then applies a 3x3 stride-2 conv.

      The input is normalized as-is, then permuted NHWC -> NCHW for the conv, and the
      conv output is permuted back to NHWC.
      """

      def __init__(
              self,
              in_chs: int = 96,
              out_chs: int = 198,
              norm_layer: Type[nn.Module] = LayerNorm,
              device=None,
              dtype=None,
      ):
          factory_kwargs = dict(device=device, dtype=dtype)
          super().__init__()
          self.norm = norm_layer(in_chs, **factory_kwargs)
          self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1, **factory_kwargs)

      def forward(self, x):
          nchw = self.norm(x).permute(0, 3, 1, 2)  # channels-last -> channels-first for the conv
          return self.conv(nchw).permute(0, 2, 3, 1)  # back to channels-last
  class Downsample(nn.Module):
      """Downsample layer: 3x3 stride-2 conv, then norm.

      The input is permuted NHWC -> NCHW for the conv, permuted back to NHWC, and then
      normalized over the channel dimension of the channels-last tensor.
      """

      def __init__(
              self,
              in_chs: int = 96,
              out_chs: int = 198,
              norm_layer: Type[nn.Module] = LayerNorm,
              device=None,
              dtype=None,
      ):
          factory_kwargs = dict(device=device, dtype=dtype)
          super().__init__()
          self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1, **factory_kwargs)
          self.norm = norm_layer(out_chs, **factory_kwargs)

      def forward(self, x):
          downsampled = self.conv(x.permute(0, 3, 1, 2))  # conv runs channels-first
          return self.norm(downsampled.permute(0, 2, 3, 1))  # norm runs channels-last
  class MlpHead(nn.Module):
      """MLP classification head: [avg pool] -> norm -> [fc -> act -> norm] -> dropout -> fc.

      'avg' pooling averages over dims 1 and 2 (H and W for the channels-last features
      this model produces). The bracketed hidden MLP exists only when ``mlp_ratio``
      gives a non-zero width; otherwise ``pre_logits`` is an identity.
      """

      def __init__(
              self,
              in_features: int,
              num_classes: int = 1000,
              pool_type: str = 'avg',
              act_layer: Type[nn.Module] = nn.GELU,
              mlp_ratio: Optional[int] = 4,
              norm_layer: Type[nn.Module] = LayerNorm,
              drop_rate: float = 0.,
              bias: bool = True,
              device=None,
              dtype=None,
      ):
          factory_kwargs = {'device': device, 'dtype': dtype}
          super().__init__()
          hidden_size = None if mlp_ratio is None else int(mlp_ratio * in_features)
          self.pool_type = pool_type
          self.in_features = in_features
          self.hidden_size = hidden_size or in_features
          self.norm = norm_layer(in_features, **factory_kwargs)
          if not hidden_size:
              # no hidden MLP: classifier sits directly on the normalized pooled features
              self.num_features = in_features
              self.pre_logits = nn.Identity()
          else:
              layers = OrderedDict()
              layers['fc'] = nn.Linear(in_features, hidden_size, **factory_kwargs)
              layers['act'] = act_layer()
              layers['norm'] = norm_layer(hidden_size, **factory_kwargs)
              self.pre_logits = nn.Sequential(layers)
              self.num_features = hidden_size
          if num_classes > 0:
              self.fc = nn.Linear(self.num_features, num_classes, bias=bias, **factory_kwargs)
          else:
              self.fc = nn.Identity()
          self.head_dropout = nn.Dropout(drop_rate)

      def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
          """Replace the final classifier; optionally change pooling and drop norm / hidden MLP."""
          if pool_type is not None:
              self.pool_type = pool_type
          if reset_other:
              self.norm = nn.Identity()
              self.pre_logits = nn.Identity()
              self.num_features = self.in_features
          if num_classes > 0:
              self.fc = nn.Linear(self.num_features, num_classes)
          else:
              self.fc = nn.Identity()

      def forward(self, x, pre_logits: bool = False):
          if self.pool_type == 'avg':
              x = x.mean((1, 2))
          x = self.head_dropout(self.pre_logits(self.norm(x)))
          return x if pre_logits else self.fc(x)
  class GatedConvBlock(nn.Module):
      r"""Our implementation of the Gated CNN Block: https://arxiv.org/pdf/1612.08083

      The forward pass permutes with (0, 3, 1, 2) around the depthwise conv, so inputs are
      channels-last ``[B, H, W, C]``. The output has the same shape as the input:
      norm -> fc1 (to ``2 * hidden``) -> split into gate ``g``, pass-through ``i`` and conv
      part ``c`` -> depthwise conv on ``c`` -> fc2(act(g) * cat(i, c)) -> LayerScale ->
      drop path -> residual add.

      Args:
          dim: Input / output channels.
          expansion_ratio: Hidden width is ``int(expansion_ratio * dim)``.
          kernel_size: Depthwise conv kernel size; padding is ``kernel_size // 2``.
          conv_ratio: Controls how many channels (``int(conv_ratio * dim)``) go through the
              depthwise convolution; the remaining hidden channels pass through untouched.
              Convolving only part of the channels can improve practical efficiency. The idea
              of partial channels is from ShuffleNet V2 (https://arxiv.org/abs/1807.11164) and
              is also used by InceptionNeXt (https://arxiv.org/abs/2303.16900) and FasterNet
              (https://arxiv.org/abs/2303.03667).
          ls_init_value: Initial LayerScale value; ``None`` disables LayerScale entirely.
          norm_layer: Norm layer type applied to the block input.
          act_layer: Activation layer type applied to the gate branch.
          drop_path: Stochastic depth rate for the residual branch.
          **kwargs: Accepted and ignored.
      """

      def __init__(
              self,
              dim: int,
              expansion_ratio: float = 8 / 3,
              kernel_size: int = 7,
              conv_ratio: float = 1.0,
              ls_init_value: Optional[float] = None,
              norm_layer: Type[nn.Module] = LayerNorm,
              act_layer: Type[nn.Module] = nn.GELU,
              drop_path: float = 0.,
              device=None,
              dtype=None,
              **kwargs
      ):
          dd = {'device': device, 'dtype': dtype}
          super().__init__()
          self.norm = norm_layer(dim, **dd)
          hidden = int(expansion_ratio * dim)
          self.fc1 = nn.Linear(dim, hidden * 2, **dd)
          self.act = act_layer()
          conv_channels = int(conv_ratio * dim)
          # fc1 output (2 * hidden) is split as: gate (hidden) | pass-through | conv part
          self.split_indices = (hidden, hidden - conv_channels, conv_channels)
          self.conv = nn.Conv2d(
              conv_channels,
              conv_channels,
              kernel_size=kernel_size,
              padding=kernel_size // 2,
              groups=conv_channels,  # depthwise
              **dd,
          )
          self.fc2 = nn.Linear(hidden, dim, **dd)
          # pass the requested init value through; previously it was dropped and
          # LayerScale silently fell back to its own default
          self.ls = LayerScale(dim, ls_init_value, **dd) if ls_init_value is not None else nn.Identity()
          self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

      def forward(self, x):
          shortcut = x  # [B, H, W, C]
          x = self.norm(x)
          x = self.fc1(x)
          g, i, c = torch.split(x, self.split_indices, dim=-1)
          c = c.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]
          c = self.conv(c)
          c = c.permute(0, 2, 3, 1)  # [B, C, H, W] -> [B, H, W, C]
          x = self.fc2(self.act(g) * torch.cat((i, c), dim=-1))
          x = self.ls(x)
          x = self.drop_path(x)
          return x + shortcut
  class MambaOutStage(nn.Module):
      """One MambaOut stage: an optional downsample layer followed by ``depth`` GatedConvBlocks.

      ``downsample`` selects the transition layer: ``'conv'`` -> ``Downsample``,
      ``'conv_nf'`` -> ``DownsampleNormFirst``; any other value means no downsampling,
      which requires ``dim == dim_out``. ``drop_path`` may be a scalar shared by every
      block or a list/tuple indexed per block.
      """

      def __init__(
              self,
              dim: int,
              dim_out: Optional[int] = None,
              depth: int = 4,
              expansion_ratio: float = 8 / 3,
              kernel_size: int = 7,
              conv_ratio: float = 1.0,
              downsample: str = '',
              ls_init_value: Optional[float] = None,
              norm_layer: Type[nn.Module] = LayerNorm,
              act_layer: Type[nn.Module] = nn.GELU,
              drop_path: float = 0.,
              device=None,
              dtype=None,
      ):
          factory_kwargs = {'device': device, 'dtype': dtype}
          super().__init__()
          dim_out = dim_out or dim
          self.grad_checkpointing = False

          if downsample in ('conv', 'conv_nf'):
              downsample_cls = Downsample if downsample == 'conv' else DownsampleNormFirst
              self.downsample = downsample_cls(dim, dim_out, norm_layer=norm_layer, **factory_kwargs)
          else:
              assert dim == dim_out
              self.downsample = nn.Identity()

          per_block_rates = isinstance(drop_path, (list, tuple))
          stage_blocks = []
          for block_idx in range(depth):
              stage_blocks.append(GatedConvBlock(
                  dim=dim_out,
                  expansion_ratio=expansion_ratio,
                  kernel_size=kernel_size,
                  conv_ratio=conv_ratio,
                  ls_init_value=ls_init_value,
                  norm_layer=norm_layer,
                  act_layer=act_layer,
                  drop_path=drop_path[block_idx] if per_block_rates else drop_path,
                  **factory_kwargs,
              ))
          self.blocks = nn.Sequential(*stage_blocks)

      def forward(self, x):
          x = self.downsample(x)
          if self.grad_checkpointing and not torch.jit.is_scripting():
              return checkpoint_seq(self.blocks, x)
          return self.blocks(x)
  class MambaOut(nn.Module):
      r"""MambaOut image classification model.

      A hierarchical network in the MetaFormer style (`MetaFormer Baselines for Vision`,
      https://arxiv.org/abs/2210.13452) built from Gated CNN blocks. After the stem,
      features stay channels-last (NHWC) through every stage (``output_fmt = 'NHWC'``).

      Args:
          in_chans: Number of input image channels. Default: 3.
          num_classes: Number of classes for the classification head. Default: 1000.
          global_pool: Head pooling type. Default: 'avg'.
          depths: Number of blocks in each stage; a bare int means a single stage.
              Default: (3, 3, 9, 3).
          dims: Channel width of each stage; a bare int means a single stage.
              Default: (96, 192, 384, 576).
          norm_layer: Norm layer used in the stem, downsample layers, blocks and head.
          act_layer: Activation layer (type or name, resolved with ``get_act_layer``).
          conv_ratio: Fraction of channels convolved in each Gated CNN block. Default: 1.0.
          expansion_ratio: Hidden-width multiplier of each Gated CNN block. Default: 8/3.
          kernel_size: Depthwise conv kernel size in each block. Default: 7.
          stem_mid_norm: Apply a norm between the two stem convs. Default: True.
          ls_init_value: LayerScale init value for blocks; None disables LayerScale.
          downsample: Downsample layer placed before stages 1+: 'conv' (conv -> norm) or
              'conv_nf' (norm -> conv). Default: 'conv'.
          drop_path_rate: Stochastic depth rate, distributed over blocks by
              ``calculate_drop_path_rates``. Default: 0.
          drop_rate: Dropout rate in the classifier head. Default: 0.
          head_fn: 'default' selects ``MlpHead`` (pool -> norm -> fc -> act -> norm -> fc);
              any other value selects ``ClNormMlpClassifierHead`` with a 4x hidden width.
          device: Device forwarded to every submodule.
          dtype: Dtype forwarded to every submodule.
      """

      def __init__(
              self,
              in_chans: int = 3,
              num_classes: int = 1000,
              global_pool: str = 'avg',
              depths: Tuple[int, ...] = (3, 3, 9, 3),
              dims: Tuple[int, ...] = (96, 192, 384, 576),
              norm_layer: Type[nn.Module] = LayerNorm,
              act_layer: Type[nn.Module] = nn.GELU,
              conv_ratio: float = 1.0,
              expansion_ratio: float = 8/3,
              kernel_size: int = 7,
              stem_mid_norm: bool = True,
              ls_init_value: Optional[float] = None,
              downsample: str = 'conv',
              drop_path_rate: float = 0.,
              drop_rate: float = 0.,
              head_fn: str = 'default',
              device=None,
              dtype=None,
      ):
          super().__init__()
          dd = {'device': device, 'dtype': dtype}
          self.num_classes = num_classes
          self.in_chans = in_chans
          self.drop_rate = drop_rate
          self.output_fmt = 'NHWC'
          if not isinstance(depths, (list, tuple)):
              depths = [depths]  # it means the model has only one stage
          if not isinstance(dims, (list, tuple)):
              dims = [dims]
          act_layer = get_act_layer(act_layer)

          num_stage = len(depths)
          self.num_stage = num_stage
          self.feature_info = []

          self.stem = Stem(
              in_chans,
              dims[0],
              mid_norm=stem_mid_norm,
              act_layer=act_layer,
              norm_layer=norm_layer,
              **dd,
          )
          prev_dim = dims[0]
          dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
          curr_stride = 4  # the stem's two stride-2 convs
          self.stages = nn.Sequential()
          for i in range(num_stage):
              dim = dims[i]
              # stage 0 keeps the stem resolution; each later stage downsamples by 2
              stride = 2 if i > 0 else 1
              curr_stride *= stride
              stage = MambaOutStage(
                  dim=prev_dim,
                  dim_out=dim,
                  depth=depths[i],
                  kernel_size=kernel_size,
                  conv_ratio=conv_ratio,
                  expansion_ratio=expansion_ratio,
                  downsample=downsample if i > 0 else '',
                  ls_init_value=ls_init_value,
                  norm_layer=norm_layer,
                  act_layer=act_layer,
                  drop_path=dp_rates[i],
                  **dd,
              )
              self.stages.append(stage)
              prev_dim = dim
              # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
              self.feature_info += [dict(num_chs=prev_dim, reduction=curr_stride, module=f'stages.{i}')]

          if head_fn == 'default':
              # specific to this model, unusual norm -> pool -> fc -> act -> norm -> fc combo
              self.head = MlpHead(
                  prev_dim,
                  num_classes,
                  pool_type=global_pool,
                  drop_rate=drop_rate,
                  norm_layer=norm_layer,
                  **dd,
              )
          else:
              # more typical norm -> pool -> fc -> act -> fc
              self.head = ClNormMlpClassifierHead(
                  prev_dim,
                  num_classes,
                  hidden_size=int(prev_dim * 4),
                  pool_type=global_pool,
                  norm_layer=norm_layer,
                  drop_rate=drop_rate,
                  **dd,
              )
          self.num_features = prev_dim
          self.head_hidden_size = self.head.num_features

          self.apply(self._init_weights)

      def _init_weights(self, m):
          """Truncated-normal (std=.02) weights and zero biases for every Conv2d / Linear."""
          if isinstance(m, (nn.Conv2d, nn.Linear)):
              trunc_normal_(m.weight, std=.02)
              if m.bias is not None:
                  nn.init.constant_(m.bias, 0)

      @torch.jit.ignore
      def group_matcher(self, coarse=False):
          return dict(
              stem=r'^stem',
              blocks=r'^stages\.(\d+)' if coarse else [
                  (r'^stages\.(\d+)\.downsample', (0,)),  # blocks
                  (r'^stages\.(\d+)\.blocks\.(\d+)', None),
              ]
          )

      @torch.jit.ignore
      def set_grad_checkpointing(self, enable=True):
          for s in self.stages:
              s.grad_checkpointing = enable

      @torch.jit.ignore
      def get_classifier(self) -> nn.Module:
          return self.head.fc

      def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
          self.num_classes = num_classes
          self.head.reset(num_classes, global_pool)

      def forward_intermediates(
              self,
              x: torch.Tensor,
              indices: Optional[Union[int, List[int]]] = None,
              norm: bool = False,
              stop_early: bool = False,
              output_fmt: str = 'NCHW',
              intermediates_only: bool = False,
      ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
          """ Forward features that returns intermediates.

          Args:
              x: Input image tensor
              indices: Take last n blocks if int, all if None, select matching indices if sequence
              norm: Accepted for API compatibility; not used by this model
              stop_early: Stop iterating over blocks when last desired intermediate hit
              output_fmt: Shape of intermediate feature outputs ('NCHW' or 'NHWC')
              intermediates_only: Only return intermediate features

          Returns:
              The list of stage outputs if ``intermediates_only``, otherwise a tuple of
              (final features in NHWC, list of stage outputs in ``output_fmt``).
          """
          assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW or NHWC.'
          channel_first = output_fmt == 'NCHW'
          intermediates = []
          take_indices, max_index = feature_take_indices(len(self.stages), indices)

          # forward pass
          x = self.stem(x)
          if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
              stages = self.stages
          else:
              stages = self.stages[:max_index + 1]

          for feat_idx, stage in enumerate(stages):
              x = stage(x)
              if feat_idx in take_indices:
                  intermediates.append(x)

          if channel_first:
              # reshape to BCHW output format
              intermediates = [y.permute(0, 3, 1, 2).contiguous() for y in intermediates]

          if intermediates_only:
              return intermediates

          return x, intermediates

      def prune_intermediate_layers(
              self,
              indices: Union[int, List[int]] = 1,
              prune_norm: bool = False,
              prune_head: bool = True,
      ):
          """ Prune layers not required for specified intermediates.

          ``prune_norm`` is accepted for API compatibility; there is no final norm to prune.
          """
          take_indices, max_index = feature_take_indices(len(self.stages), indices)
          self.stages = self.stages[:max_index + 1]  # keep stages up to the last requested index
          if prune_head:
              self.reset_classifier(0, '')
          return take_indices

      def forward_features(self, x):
          x = self.stem(x)
          x = self.stages(x)
          return x

      def forward_head(self, x, pre_logits: bool = False):
          return self.head(x, pre_logits=pre_logits)

      def forward(self, x):
          x = self.forward_features(x)
          x = self.forward_head(x)
          return x
  def checkpoint_filter_fn(state_dict, model):
      """Remap an original-layout MambaOut checkpoint to this module's parameter names.

      Args:
          state_dict: Checkpoint dict, optionally wrapped under a ``'model'`` key.
          model: Unused; kept for the ``pretrained_filter_fn`` call signature.

      Returns:
          ``state_dict`` itself if it already uses these names (contains
          ``stem.conv1.weight``); otherwise a new dict with keys renamed:
            * ``downsample_layers.0.*`` -> ``stem.*``
            * ``downsample_layers.N.*`` -> ``stages.N.downsample.*``
            * ``stages.N.M.*`` -> ``stages.N.blocks.M.*``
            * leading ``norm.*`` -> ``head.norm.*`` (the final norm lives in the head, after pooling)
            * ``head.fc1.*`` / ``head.norm.*`` / ``head.fc2.*`` ->
              ``head.pre_logits.fc.*`` / ``head.pre_logits.norm.*`` / ``head.fc.*``
      """
      import re

      if 'model' in state_dict:
          state_dict = state_dict['model']
      if 'stem.conv1.weight' in state_dict:
          return state_dict

      out_dict = {}
      for k, v in state_dict.items():
          k = k.replace('downsample_layers.0.', 'stem.')
          # dots are escaped so they only match literal separators
          k = re.sub(r'stages\.([0-9]+)\.([0-9]+)', r'stages.\1.blocks.\2', k)
          k = re.sub(r'downsample_layers\.([0-9]+)', r'stages.\1.downsample', k)
          # remap head names
          if k.startswith('norm.'):
              # this is moving to head since it's after the pooling; only the prefix is rewritten
              k = 'head.' + k
          elif k.startswith('head.'):
              k = k.replace('head.fc1.', 'head.pre_logits.fc.')
              k = k.replace('head.norm.', 'head.pre_logits.norm.')
              k = k.replace('head.fc2.', 'head.fc.')
          out_dict[k] = v
      return out_dict
  def _cfg(url='', **kwargs):
      """Return a pretrained-config dict with MambaOut defaults; ``kwargs`` override or extend it."""
      cfg = dict(
          url=url,
          num_classes=1000,
          input_size=(3, 224, 224),
          test_input_size=(3, 288, 288),
          pool_size=(7, 7),
          crop_pct=1.0,
          interpolation='bicubic',
          mean=IMAGENET_DEFAULT_MEAN,
          std=IMAGENET_DEFAULT_STD,
          first_conv='stem.conv1',
          classifier='head.fc',
          license='apache-2.0',
      )
      cfg.update(kwargs)
      return cfg
  default_cfgs = generate_default_cfgs({
      # original weights
      **{
          f'mambaout_{size}.in1k': _cfg(hf_hub_id='timm/')
          for size in ('femto', 'kobe', 'tiny', 'small', 'base')
      },

      # timm experiments below
      'mambaout_small_rw.sw_e450_in1k': _cfg(hf_hub_id='timm/'),
      **{
          f'mambaout_base_{shape}_rw.sw_e500_in1k': _cfg(
              hf_hub_id='timm/', crop_pct=0.95, test_crop_pct=1.0)
          for shape in ('short', 'tall', 'wide')
      },
      'mambaout_base_plus_rw.sw_e150_in12k_ft_in1k': _cfg(hf_hub_id='timm/'),
      'mambaout_base_plus_rw.sw_e150_r384_in12k_ft_in1k': _cfg(
          hf_hub_id='timm/',
          input_size=(3, 384, 384),
          test_input_size=(3, 384, 384),
          crop_mode='squash',
          pool_size=(12, 12),
      ),
      'mambaout_base_plus_rw.sw_e150_in12k': _cfg(hf_hub_id='timm/', num_classes=11821),
      'test_mambaout': _cfg(input_size=(3, 160, 160), test_input_size=(3, 192, 192), pool_size=(5, 5)),
  })
  def _create_mambaout(variant, pretrained=False, **kwargs):
      """Build a MambaOut ``variant`` via ``build_model_with_cfg``.

      ``checkpoint_filter_fn`` is registered as the pretrained filter, and feature
      extraction defaults to the outputs of all four stages.
      """
      feature_cfg = dict(out_indices=(0, 1, 2, 3), flatten_sequential=True)
      return build_model_with_cfg(
          MambaOut,
          variant,
          pretrained,
          pretrained_filter_fn=checkpoint_filter_fn,
          feature_cfg=feature_cfg,
          **kwargs,
      )
  # a series of MambaOut models
  @register_model
  def mambaout_femto(pretrained=False, **kwargs):
      """MambaOut-Femto: depths (3, 3, 9, 3), dims (48, 96, 192, 288)."""
      args = {'depths': (3, 3, 9, 3), 'dims': (48, 96, 192, 288), **kwargs}
      return _create_mambaout('mambaout_femto', pretrained=pretrained, **args)
  # Kobe Memorial Version with 24 Gated CNN blocks
  @register_model
  def mambaout_kobe(pretrained=False, **kwargs):
      """MambaOut-Kobe: femto widths with a deeper third stage (3 + 3 + 15 + 3 = 24 blocks)."""
      args = {'depths': [3, 3, 15, 3], 'dims': [48, 96, 192, 288], **kwargs}
      return _create_mambaout('mambaout_kobe', pretrained=pretrained, **args)
  @register_model
  def mambaout_tiny(pretrained=False, **kwargs):
      """MambaOut-Tiny: depths [3, 3, 9, 3], dims [96, 192, 384, 576]."""
      args = {'depths': [3, 3, 9, 3], 'dims': [96, 192, 384, 576], **kwargs}
      return _create_mambaout('mambaout_tiny', pretrained=pretrained, **args)
  @register_model
  def mambaout_small(pretrained=False, **kwargs):
      """MambaOut-Small: depths [3, 4, 27, 3], dims [96, 192, 384, 576]."""
      args = {'depths': [3, 4, 27, 3], 'dims': [96, 192, 384, 576], **kwargs}
      return _create_mambaout('mambaout_small', pretrained=pretrained, **args)
  @register_model
  def mambaout_base(pretrained=False, **kwargs):
      """MambaOut-Base: depths [3, 4, 27, 3], dims [128, 256, 512, 768]."""
      args = {'depths': [3, 4, 27, 3], 'dims': [128, 256, 512, 768], **kwargs}
      return _create_mambaout('mambaout_base', pretrained=pretrained, **args)
  @register_model
  def mambaout_small_rw(pretrained=False, **kwargs):
      """MambaOut-Small timm variant: no stem mid-norm, norm-first downsample,
      LayerScale (1e-6) and the 'norm_mlp' head."""
      args = {
          'depths': [3, 4, 27, 3],
          'dims': [96, 192, 384, 576],
          'stem_mid_norm': False,
          'downsample': 'conv_nf',
          'ls_init_value': 1e-6,
          'head_fn': 'norm_mlp',
          **kwargs,
      }
      return _create_mambaout('mambaout_small_rw', pretrained=pretrained, **args)
  @register_model
  def mambaout_base_short_rw(pretrained=False, **kwargs):
      """MambaOut-Base timm variant with a shorter third stage (25 blocks),
      3.0 expansion and 1.25 conv ratio."""
      args = {
          'depths': (3, 3, 25, 3),
          'dims': (128, 256, 512, 768),
          'expansion_ratio': 3.0,
          'conv_ratio': 1.25,
          'stem_mid_norm': False,
          'downsample': 'conv_nf',
          'ls_init_value': 1e-6,
          'head_fn': 'norm_mlp',
          **kwargs,
      }
      return _create_mambaout('mambaout_base_short_rw', pretrained=pretrained, **args)
  @register_model
  def mambaout_base_tall_rw(pretrained=False, **kwargs):
      """MambaOut-Base timm variant with a deeper third stage (30 blocks),
      2.5 expansion and 1.25 conv ratio."""
      args = {
          'depths': (3, 4, 30, 3),
          'dims': (128, 256, 512, 768),
          'expansion_ratio': 2.5,
          'conv_ratio': 1.25,
          'stem_mid_norm': False,
          'downsample': 'conv_nf',
          'ls_init_value': 1e-6,
          'head_fn': 'norm_mlp',
          **kwargs,
      }
      return _create_mambaout('mambaout_base_tall_rw', pretrained=pretrained, **args)
  @register_model
  def mambaout_base_wide_rw(pretrained=False, **kwargs):
      """MambaOut-Base timm variant with wider blocks (3.0 expansion, 1.5 conv ratio) and SiLU."""
      args = {
          'depths': (3, 4, 27, 3),
          'dims': (128, 256, 512, 768),
          'expansion_ratio': 3.0,
          'conv_ratio': 1.5,
          'stem_mid_norm': False,
          'downsample': 'conv_nf',
          'ls_init_value': 1e-6,
          'act_layer': 'silu',
          'head_fn': 'norm_mlp',
          **kwargs,
      }
      return _create_mambaout('mambaout_base_wide_rw', pretrained=pretrained, **args)
  @register_model
  def mambaout_base_plus_rw(pretrained=False, **kwargs):
      """MambaOut-Base timm variant combining the deep (30-block) third stage
      with wide blocks and SiLU."""
      args = {
          'depths': (3, 4, 30, 3),
          'dims': (128, 256, 512, 768),
          'expansion_ratio': 3.0,
          'conv_ratio': 1.5,
          'stem_mid_norm': False,
          'downsample': 'conv_nf',
          'ls_init_value': 1e-6,
          'act_layer': 'silu',
          'head_fn': 'norm_mlp',
          **kwargs,
      }
      return _create_mambaout('mambaout_base_plus_rw', pretrained=pretrained, **args)
  @register_model
  def test_mambaout(pretrained=False, **kwargs):
      """Tiny MambaOut configuration (depths (1, 1, 3, 1), dims (16, 32, 48, 64))
      used for tests."""
      args = {
          'depths': (1, 1, 3, 1),
          'dims': (16, 32, 48, 64),
          'expansion_ratio': 3,
          'stem_mid_norm': False,
          'downsample': 'conv_nf',
          'ls_init_value': 1e-4,
          'act_layer': 'silu',
          'head_fn': 'norm_mlp',
          **kwargs,
      }
      return _create_mambaout('test_mambaout', pretrained=pretrained, **args)