repvit.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693
  1. """ RepViT
  2. Paper: `RepViT: Revisiting Mobile CNN From ViT Perspective`
  3. - https://arxiv.org/abs/2307.09283
  4. @misc{wang2023repvit,
  5. title={RepViT: Revisiting Mobile CNN From ViT Perspective},
  6. author={Ao Wang and Hui Chen and Zijia Lin and Hengjun Pu and Guiguang Ding},
  7. year={2023},
  8. eprint={2307.09283},
  9. archivePrefix={arXiv},
  10. primaryClass={cs.CV}
  11. }
  12. Adapted from official impl at https://github.com/jameslahm/RepViT
  13. """
  14. from typing import List, Optional, Tuple, Union, Type
  15. import torch
  16. import torch.nn as nn
  17. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  18. from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple
  19. from ._builder import build_model_with_cfg
  20. from ._features import feature_take_indices
  21. from ._manipulate import checkpoint, checkpoint_seq
  22. from ._registry import register_model, generate_default_cfgs
  23. __all__ = ['RepVit']
  24. class ConvNorm(nn.Sequential):
  25. def __init__(
  26. self,
  27. in_dim: int,
  28. out_dim: int,
  29. ks: int = 1,
  30. stride: int = 1,
  31. pad: int = 0,
  32. dilation: int = 1,
  33. groups: int = 1,
  34. bn_weight_init: float = 1,
  35. device=None,
  36. dtype=None,
  37. ):
  38. dd = {'device': device, 'dtype': dtype}
  39. super().__init__()
  40. self.add_module('c', nn.Conv2d(in_dim, out_dim, ks, stride, pad, dilation, groups, bias=False, **dd))
  41. self.add_module('bn', nn.BatchNorm2d(out_dim, **dd))
  42. nn.init.constant_(self.bn.weight, bn_weight_init)
  43. nn.init.constant_(self.bn.bias, 0)
  44. @torch.no_grad()
  45. def fuse(self):
  46. c, bn = self._modules.values()
  47. w = bn.weight / (bn.running_var + bn.eps) ** 0.5
  48. w = c.weight * w[:, None, None, None]
  49. b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
  50. m = nn.Conv2d(
  51. w.size(1) * self.c.groups,
  52. w.size(0),
  53. w.shape[2:],
  54. stride=self.c.stride,
  55. padding=self.c.padding,
  56. dilation=self.c.dilation,
  57. groups=self.c.groups,
  58. device=c.weight.device,
  59. )
  60. m.weight.data.copy_(w)
  61. m.bias.data.copy_(b)
  62. return m
  63. class NormLinear(nn.Sequential):
  64. def __init__(
  65. self,
  66. in_dim: int,
  67. out_dim: int,
  68. bias: bool = True,
  69. std: float = 0.02,
  70. device=None,
  71. dtype=None,
  72. ):
  73. dd = {'device': device, 'dtype': dtype}
  74. super().__init__()
  75. self.add_module('bn', nn.BatchNorm1d(in_dim, **dd))
  76. self.add_module('l', nn.Linear(in_dim, out_dim, bias=bias, **dd))
  77. trunc_normal_(self.l.weight, std=std)
  78. if bias:
  79. nn.init.constant_(self.l.bias, 0)
  80. @torch.no_grad()
  81. def fuse(self):
  82. bn, l = self._modules.values()
  83. w = bn.weight / (bn.running_var + bn.eps) ** 0.5
  84. b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5
  85. w = l.weight * w[None, :]
  86. if l.bias is None:
  87. b = b @ self.l.weight.T
  88. else:
  89. b = (l.weight @ b[:, None]).view(-1) + self.l.bias
  90. m = nn.Linear(w.size(1), w.size(0), device=l.weight.device)
  91. m.weight.data.copy_(w)
  92. m.bias.data.copy_(b)
  93. return m
  94. class RepVggDw(nn.Module):
  95. def __init__(
  96. self,
  97. ed: int,
  98. kernel_size: int,
  99. legacy: bool = False,
  100. device=None,
  101. dtype=None,
  102. ):
  103. dd = {'device': device, 'dtype': dtype}
  104. super().__init__()
  105. self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed, **dd)
  106. if legacy:
  107. self.conv1 = ConvNorm(ed, ed, 1, 1, 0, groups=ed, **dd)
  108. # Make torchscript happy.
  109. self.bn = nn.Identity()
  110. else:
  111. self.conv1 = nn.Conv2d(ed, ed, 1, 1, 0, groups=ed, **dd)
  112. self.bn = nn.BatchNorm2d(ed, **dd)
  113. self.dim = ed
  114. self.legacy = legacy
  115. def forward(self, x):
  116. return self.bn(self.conv(x) + self.conv1(x) + x)
  117. @torch.no_grad()
  118. def fuse(self):
  119. conv = self.conv.fuse()
  120. if self.legacy:
  121. conv1 = self.conv1.fuse()
  122. else:
  123. conv1 = self.conv1
  124. conv_w = conv.weight
  125. conv_b = conv.bias
  126. conv1_w = conv1.weight
  127. conv1_b = conv1.bias
  128. conv1_w = nn.functional.pad(conv1_w, [1, 1, 1, 1])
  129. identity = nn.functional.pad(
  130. torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device), [1, 1, 1, 1]
  131. )
  132. final_conv_w = conv_w + conv1_w + identity
  133. final_conv_b = conv_b + conv1_b
  134. conv.weight.data.copy_(final_conv_w)
  135. conv.bias.data.copy_(final_conv_b)
  136. if not self.legacy:
  137. bn = self.bn
  138. w = bn.weight / (bn.running_var + bn.eps) ** 0.5
  139. w = conv.weight * w[:, None, None, None]
  140. b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / (bn.running_var + bn.eps) ** 0.5
  141. conv.weight.data.copy_(w)
  142. conv.bias.data.copy_(b)
  143. return conv
  144. class RepVitMlp(nn.Module):
  145. def __init__(
  146. self,
  147. in_dim: int,
  148. hidden_dim: int,
  149. act_layer: Type[nn.Module],
  150. device=None,
  151. dtype=None,
  152. ):
  153. dd = {'device': device, 'dtype': dtype}
  154. super().__init__()
  155. self.conv1 = ConvNorm(in_dim, hidden_dim, 1, 1, 0, **dd)
  156. self.act = act_layer()
  157. self.conv2 = ConvNorm(hidden_dim, in_dim, 1, 1, 0, bn_weight_init=0, **dd)
  158. def forward(self, x):
  159. return self.conv2(self.act(self.conv1(x)))
  160. class RepViTBlock(nn.Module):
  161. def __init__(
  162. self,
  163. in_dim: int,
  164. mlp_ratio: float,
  165. kernel_size: int,
  166. use_se: bool,
  167. act_layer: Type[nn.Module],
  168. legacy: bool = False,
  169. device=None,
  170. dtype=None,
  171. ):
  172. dd = {'device': device, 'dtype': dtype}
  173. super().__init__()
  174. self.token_mixer = RepVggDw(in_dim, kernel_size, legacy, **dd)
  175. self.se = SqueezeExcite(in_dim, 0.25, **dd) if use_se else nn.Identity()
  176. self.channel_mixer = RepVitMlp(in_dim, in_dim * mlp_ratio, act_layer, **dd)
  177. def forward(self, x):
  178. x = self.token_mixer(x)
  179. x = self.se(x)
  180. identity = x
  181. x = self.channel_mixer(x)
  182. return identity + x
  183. class RepVitStem(nn.Module):
  184. def __init__(
  185. self,
  186. in_chs: int,
  187. out_chs: int,
  188. act_layer: Type[nn.Module],
  189. device=None,
  190. dtype=None,
  191. ):
  192. dd = {'device': device, 'dtype': dtype}
  193. super().__init__()
  194. self.conv1 = ConvNorm(in_chs, out_chs // 2, 3, 2, 1, **dd)
  195. self.act1 = act_layer()
  196. self.conv2 = ConvNorm(out_chs // 2, out_chs, 3, 2, 1, **dd)
  197. self.stride = 4
  198. def forward(self, x):
  199. return self.conv2(self.act1(self.conv1(x)))
  200. class RepVitDownsample(nn.Module):
  201. def __init__(
  202. self,
  203. in_dim: int,
  204. mlp_ratio: float,
  205. out_dim: int,
  206. kernel_size: int,
  207. act_layer: Type[nn.Module],
  208. legacy: bool = False,
  209. device=None,
  210. dtype=None,
  211. ):
  212. dd = {'device': device, 'dtype': dtype}
  213. super().__init__()
  214. self.pre_block = RepViTBlock(
  215. in_dim,
  216. mlp_ratio,
  217. kernel_size,
  218. use_se=False,
  219. act_layer=act_layer,
  220. legacy=legacy,
  221. **dd,
  222. )
  223. self.spatial_downsample = ConvNorm(
  224. in_dim,
  225. in_dim,
  226. kernel_size,
  227. stride=2,
  228. pad=(kernel_size - 1) // 2,
  229. groups=in_dim,
  230. **dd,
  231. )
  232. self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1, **dd)
  233. self.ffn = RepVitMlp(out_dim, out_dim * mlp_ratio, act_layer, **dd)
  234. def forward(self, x):
  235. x = self.pre_block(x)
  236. x = self.spatial_downsample(x)
  237. x = self.channel_downsample(x)
  238. identity = x
  239. x = self.ffn(x)
  240. return x + identity
  241. class RepVitClassifier(nn.Module):
  242. def __init__(
  243. self,
  244. dim: int,
  245. num_classes: int,
  246. distillation: bool = False,
  247. drop: float = 0.0,
  248. device=None,
  249. dtype=None,
  250. ):
  251. dd = {'device': device, 'dtype': dtype}
  252. super().__init__()
  253. self.head_drop = nn.Dropout(drop)
  254. self.head = NormLinear(dim, num_classes, **dd) if num_classes > 0 else nn.Identity()
  255. self.distillation = distillation
  256. self.distilled_training = False
  257. self.num_classes = num_classes
  258. if distillation:
  259. self.head_dist = NormLinear(dim, num_classes, **dd) if num_classes > 0 else nn.Identity()
  260. def forward(self, x):
  261. x = self.head_drop(x)
  262. if self.distillation:
  263. x1, x2 = self.head(x), self.head_dist(x)
  264. if self.training and self.distilled_training and not torch.jit.is_scripting():
  265. return x1, x2
  266. else:
  267. return (x1 + x2) / 2
  268. else:
  269. x = self.head(x)
  270. return x
  271. @torch.no_grad()
  272. def fuse(self):
  273. if not self.num_classes > 0:
  274. return nn.Identity()
  275. head = self.head.fuse()
  276. if self.distillation:
  277. head_dist = self.head_dist.fuse()
  278. head.weight += head_dist.weight
  279. head.bias += head_dist.bias
  280. head.weight /= 2
  281. head.bias /= 2
  282. return head
  283. else:
  284. return head
  285. class RepVitStage(nn.Module):
  286. def __init__(
  287. self,
  288. in_dim: int,
  289. out_dim: int,
  290. depth: int,
  291. mlp_ratio: float,
  292. act_layer: Type[nn.Module],
  293. kernel_size: int = 3,
  294. downsample: bool = True,
  295. legacy: bool = False,
  296. device=None,
  297. dtype=None,
  298. ):
  299. dd = {'device': device, 'dtype': dtype}
  300. super().__init__()
  301. if downsample:
  302. self.downsample = RepVitDownsample(
  303. in_dim,
  304. mlp_ratio,
  305. out_dim,
  306. kernel_size,
  307. act_layer=act_layer,
  308. legacy=legacy,
  309. **dd,
  310. )
  311. else:
  312. assert in_dim == out_dim
  313. self.downsample = nn.Identity()
  314. blocks = []
  315. use_se = True
  316. for _ in range(depth):
  317. blocks.append(RepViTBlock(out_dim, mlp_ratio, kernel_size, use_se, act_layer, legacy, **dd))
  318. use_se = not use_se
  319. self.blocks = nn.Sequential(*blocks)
  320. def forward(self, x):
  321. x = self.downsample(x)
  322. x = self.blocks(x)
  323. return x
  324. class RepVit(nn.Module):
  325. def __init__(
  326. self,
  327. in_chans: int = 3,
  328. img_size: int = 224,
  329. embed_dim: Tuple[int, ...] = (48,),
  330. depth: Tuple[int, ...] = (2,),
  331. mlp_ratio: float = 2,
  332. global_pool: str = 'avg',
  333. kernel_size: int = 3,
  334. num_classes: int = 1000,
  335. act_layer: Type[nn.Module] = nn.GELU,
  336. distillation: bool = True,
  337. drop_rate: float = 0.0,
  338. legacy: bool = False,
  339. device=None,
  340. dtype=None,
  341. ):
  342. super().__init__()
  343. dd = {'device': device, 'dtype': dtype}
  344. self.grad_checkpointing = False
  345. self.global_pool = global_pool
  346. self.embed_dim = embed_dim
  347. self.num_classes = num_classes
  348. self.in_chans = in_chans
  349. in_dim = embed_dim[0]
  350. self.stem = RepVitStem(in_chans, in_dim, act_layer, **dd)
  351. stride = self.stem.stride
  352. resolution = tuple([i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride))])
  353. num_stages = len(embed_dim)
  354. mlp_ratios = to_ntuple(num_stages)(mlp_ratio)
  355. self.feature_info = []
  356. stages = []
  357. for i in range(num_stages):
  358. downsample = True if i != 0 else False
  359. stages.append(
  360. RepVitStage(
  361. in_dim,
  362. embed_dim[i],
  363. depth[i],
  364. mlp_ratio=mlp_ratios[i],
  365. act_layer=act_layer,
  366. kernel_size=kernel_size,
  367. downsample=downsample,
  368. legacy=legacy,
  369. **dd,
  370. )
  371. )
  372. stage_stride = 2 if downsample else 1
  373. stride *= stage_stride
  374. resolution = tuple([(r - 1) // stage_stride + 1 for r in resolution])
  375. self.feature_info += [dict(num_chs=embed_dim[i], reduction=stride, module=f'stages.{i}')]
  376. in_dim = embed_dim[i]
  377. self.stages = nn.Sequential(*stages)
  378. self.num_features = self.head_hidden_size = embed_dim[-1]
  379. self.head_drop = nn.Dropout(drop_rate)
  380. self.head = RepVitClassifier(embed_dim[-1], num_classes, distillation, **dd)
  381. @torch.jit.ignore
  382. def group_matcher(self, coarse=False):
  383. matcher = dict(stem=r'^stem', blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]) # stem and embed
  384. return matcher
  385. @torch.jit.ignore
  386. def set_grad_checkpointing(self, enable=True):
  387. self.grad_checkpointing = enable
  388. @torch.jit.ignore
  389. def get_classifier(self) -> nn.Module:
  390. return self.head
  391. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation: bool = False, device=None, dtype=None):
  392. self.num_classes = num_classes
  393. if global_pool is not None:
  394. self.global_pool = global_pool
  395. dd = {'device': device, 'dtype': dtype}
  396. self.head = RepVitClassifier(self.embed_dim[-1], num_classes, distillation, **dd)
  397. @torch.jit.ignore
  398. def set_distilled_training(self, enable=True):
  399. self.head.distilled_training = enable
  400. def forward_intermediates(
  401. self,
  402. x: torch.Tensor,
  403. indices: Optional[Union[int, List[int]]] = None,
  404. norm: bool = False,
  405. stop_early: bool = False,
  406. output_fmt: str = 'NCHW',
  407. intermediates_only: bool = False,
  408. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  409. """ Forward features that returns intermediates.
  410. Args:
  411. x: Input image tensor
  412. indices: Take last n blocks if int, all if None, select matching indices if sequence
  413. norm: Apply norm layer to compatible intermediates
  414. stop_early: Stop iterating over blocks when last desired intermediate hit
  415. output_fmt: Shape of intermediate feature outputs
  416. intermediates_only: Only return intermediate features
  417. Returns:
  418. """
  419. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  420. intermediates = []
  421. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  422. # forward pass
  423. x = self.stem(x)
  424. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  425. stages = self.stages
  426. else:
  427. stages = self.stages[:max_index + 1]
  428. for feat_idx, stage in enumerate(stages):
  429. if self.grad_checkpointing and not torch.jit.is_scripting():
  430. x = checkpoint(stage, x)
  431. else:
  432. x = stage(x)
  433. if feat_idx in take_indices:
  434. intermediates.append(x)
  435. if intermediates_only:
  436. return intermediates
  437. return x, intermediates
  438. def prune_intermediate_layers(
  439. self,
  440. indices: Union[int, List[int]] = 1,
  441. prune_norm: bool = False,
  442. prune_head: bool = True,
  443. ):
  444. """ Prune layers not required for specified intermediates.
  445. """
  446. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  447. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  448. if prune_head:
  449. self.reset_classifier(0, '')
  450. return take_indices
  451. def forward_features(self, x):
  452. x = self.stem(x)
  453. if self.grad_checkpointing and not torch.jit.is_scripting():
  454. x = checkpoint_seq(self.stages, x)
  455. else:
  456. x = self.stages(x)
  457. return x
  458. def forward_head(self, x, pre_logits: bool = False):
  459. if self.global_pool == 'avg':
  460. x = x.mean((2, 3), keepdim=False)
  461. x = self.head_drop(x)
  462. if pre_logits:
  463. return x
  464. return self.head(x)
  465. def forward(self, x):
  466. x = self.forward_features(x)
  467. x = self.forward_head(x)
  468. return x
  469. @torch.no_grad()
  470. def fuse(self):
  471. def fuse_children(net):
  472. for child_name, child in net.named_children():
  473. if hasattr(child, 'fuse'):
  474. fused = child.fuse()
  475. setattr(net, child_name, fused)
  476. fuse_children(fused)
  477. else:
  478. fuse_children(child)
  479. fuse_children(self)
  480. def _cfg(url='', **kwargs):
  481. return {
  482. 'url': url,
  483. 'num_classes': 1000,
  484. 'input_size': (3, 224, 224),
  485. 'pool_size': (7, 7),
  486. 'crop_pct': 0.95,
  487. 'interpolation': 'bicubic',
  488. 'mean': IMAGENET_DEFAULT_MEAN,
  489. 'std': IMAGENET_DEFAULT_STD,
  490. 'first_conv': 'stem.conv1.c',
  491. 'classifier': ('head.head.l', 'head.head_dist.l'),
  492. 'license': 'apache-2.0',
  493. **kwargs,
  494. }
  495. default_cfgs = generate_default_cfgs(
  496. {
  497. 'repvit_m1.dist_in1k': _cfg(
  498. hf_hub_id='timm/',
  499. ),
  500. 'repvit_m2.dist_in1k': _cfg(
  501. hf_hub_id='timm/',
  502. ),
  503. 'repvit_m3.dist_in1k': _cfg(
  504. hf_hub_id='timm/',
  505. ),
  506. 'repvit_m0_9.dist_300e_in1k': _cfg(
  507. hf_hub_id='timm/',
  508. ),
  509. 'repvit_m0_9.dist_450e_in1k': _cfg(
  510. hf_hub_id='timm/',
  511. ),
  512. 'repvit_m1_0.dist_300e_in1k': _cfg(
  513. hf_hub_id='timm/',
  514. ),
  515. 'repvit_m1_0.dist_450e_in1k': _cfg(
  516. hf_hub_id='timm/',
  517. ),
  518. 'repvit_m1_1.dist_300e_in1k': _cfg(
  519. hf_hub_id='timm/',
  520. ),
  521. 'repvit_m1_1.dist_450e_in1k': _cfg(
  522. hf_hub_id='timm/',
  523. ),
  524. 'repvit_m1_5.dist_300e_in1k': _cfg(
  525. hf_hub_id='timm/',
  526. ),
  527. 'repvit_m1_5.dist_450e_in1k': _cfg(
  528. hf_hub_id='timm/',
  529. ),
  530. 'repvit_m2_3.dist_300e_in1k': _cfg(
  531. hf_hub_id='timm/',
  532. ),
  533. 'repvit_m2_3.dist_450e_in1k': _cfg(
  534. hf_hub_id='timm/',
  535. ),
  536. }
  537. )
  538. def _create_repvit(variant, pretrained=False, **kwargs):
  539. out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
  540. model = build_model_with_cfg(
  541. RepVit,
  542. variant,
  543. pretrained,
  544. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  545. **kwargs,
  546. )
  547. return model
  548. @register_model
  549. def repvit_m1(pretrained=False, **kwargs):
  550. """
  551. Constructs a RepViT-M1 model
  552. """
  553. model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2), legacy=True)
  554. return _create_repvit('repvit_m1', pretrained=pretrained, **dict(model_args, **kwargs))
  555. @register_model
  556. def repvit_m2(pretrained=False, **kwargs):
  557. """
  558. Constructs a RepViT-M2 model
  559. """
  560. model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2), legacy=True)
  561. return _create_repvit('repvit_m2', pretrained=pretrained, **dict(model_args, **kwargs))
  562. @register_model
  563. def repvit_m3(pretrained=False, **kwargs):
  564. """
  565. Constructs a RepViT-M3 model
  566. """
  567. model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 18, 2), legacy=True)
  568. return _create_repvit('repvit_m3', pretrained=pretrained, **dict(model_args, **kwargs))
  569. @register_model
  570. def repvit_m0_9(pretrained=False, **kwargs):
  571. """
  572. Constructs a RepViT-M0.9 model
  573. """
  574. model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2))
  575. return _create_repvit('repvit_m0_9', pretrained=pretrained, **dict(model_args, **kwargs))
  576. @register_model
  577. def repvit_m1_0(pretrained=False, **kwargs):
  578. """
  579. Constructs a RepViT-M1.0 model
  580. """
  581. model_args = dict(embed_dim=(56, 112, 224, 448), depth=(2, 2, 14, 2))
  582. return _create_repvit('repvit_m1_0', pretrained=pretrained, **dict(model_args, **kwargs))
  583. @register_model
  584. def repvit_m1_1(pretrained=False, **kwargs):
  585. """
  586. Constructs a RepViT-M1.1 model
  587. """
  588. model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2))
  589. return _create_repvit('repvit_m1_1', pretrained=pretrained, **dict(model_args, **kwargs))
  590. @register_model
  591. def repvit_m1_5(pretrained=False, **kwargs):
  592. """
  593. Constructs a RepViT-M1.5 model
  594. """
  595. model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 24, 4))
  596. return _create_repvit('repvit_m1_5', pretrained=pretrained, **dict(model_args, **kwargs))
  597. @register_model
  598. def repvit_m2_3(pretrained=False, **kwargs):
  599. """
  600. Constructs a RepViT-M2.3 model
  601. """
  602. model_args = dict(embed_dim=(80, 160, 320, 640), depth=(6, 6, 34, 2))
  603. return _create_repvit('repvit_m2_3', pretrained=pretrained, **dict(model_args, **kwargs))