shvit.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566
  1. """SHViT
  2. SHViT: Single-Head Vision Transformer with Memory Efficient Macro Design
  3. Code: https://github.com/ysj9909/SHViT
  4. Paper: https://arxiv.org/abs/2401.16456
  5. @inproceedings{yun2024shvit,
  6. author={Yun, Seokju and Ro, Youngmin},
  7. title={SHViT: Single-Head Vision Transformer with Memory Efficient Macro Design},
  8. booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  9. pages={5756--5767},
  10. year={2024}
  11. }
  12. """
  13. from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
  14. import torch
  15. import torch.nn as nn
  16. import torch.nn.functional as F
  17. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  18. from timm.layers import GroupNorm1, SqueezeExcite, SelectAdaptivePool2d, LayerType, trunc_normal_
  19. from ._builder import build_model_with_cfg
  20. from ._features import feature_take_indices
  21. from ._manipulate import checkpoint_seq
  22. from ._registry import register_model, generate_default_cfgs
# Public API for `from <module> import *`: only the model class is exported.
__all__ = ['SHViT']
  24. class Residual(nn.Module):
  25. def __init__(self, m: nn.Module):
  26. super().__init__()
  27. self.m = m
  28. def forward(self, x: torch.Tensor) -> torch.Tensor:
  29. return x + self.m(x)
  30. @torch.no_grad()
  31. def fuse(self) -> nn.Module:
  32. if isinstance(self.m, Conv2dNorm):
  33. m = self.m.fuse()
  34. assert(m.groups == m.in_channels)
  35. identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
  36. identity = F.pad(identity, [1,1,1,1])
  37. m.weight += identity.to(m.weight.device)
  38. return m
  39. else:
  40. return self
  41. class Conv2dNorm(nn.Sequential):
  42. def __init__(
  43. self,
  44. in_channels: int,
  45. out_channels: int,
  46. kernel_size: int = 1,
  47. stride: int = 1,
  48. padding: int = 0,
  49. bn_weight_init: int = 1,
  50. device=None,
  51. dtype=None,
  52. **kwargs,
  53. ):
  54. dd = {'device': device, 'dtype': dtype}
  55. super().__init__()
  56. self.add_module('c', nn.Conv2d(
  57. in_channels, out_channels, kernel_size, stride, padding, bias=False, **dd, **kwargs))
  58. self.add_module('bn', nn.BatchNorm2d(out_channels, **dd))
  59. nn.init.constant_(self.bn.weight, bn_weight_init)
  60. nn.init.constant_(self.bn.bias, 0)
  61. @torch.no_grad()
  62. def fuse(self) -> nn.Conv2d:
  63. c, bn = self._modules.values()
  64. w = bn.weight / (bn.running_var + bn.eps) ** 0.5
  65. w = c.weight * w[:, None, None, None]
  66. b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
  67. m = nn.Conv2d(
  68. in_channels=w.size(1) * self.c.groups,
  69. out_channels=w.size(0),
  70. kernel_size=w.shape[2:],
  71. stride=self.c.stride,
  72. padding=self.c.padding,
  73. dilation=self.c.dilation,
  74. groups=self.c.groups,
  75. device=c.weight.device,
  76. dtype=c.weight.dtype,
  77. )
  78. m.weight.data.copy_(w)
  79. m.bias.data.copy_(b)
  80. return m
  81. class NormLinear(nn.Sequential):
  82. def __init__(
  83. self,
  84. in_features: int,
  85. out_features: int,
  86. bias: bool = True,
  87. std: float = 0.02,
  88. device=None,
  89. dtype=None,
  90. ):
  91. dd = {'device': device, 'dtype': dtype}
  92. super().__init__()
  93. self.add_module('bn', nn.BatchNorm1d(in_features, **dd))
  94. self.add_module('l', nn.Linear(in_features, out_features, bias=bias, **dd))
  95. trunc_normal_(self.l.weight, std=std)
  96. if bias:
  97. nn.init.constant_(self.l.bias, 0)
  98. @torch.no_grad()
  99. def fuse(self) -> nn.Linear:
  100. bn, l = self._modules.values()
  101. w = bn.weight / (bn.running_var + bn.eps) ** 0.5
  102. b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5
  103. w = l.weight * w[None, :]
  104. if l.bias is None:
  105. b = b @ self.l.weight.T
  106. else:
  107. b = (l.weight @ b[:, None]).view(-1) + self.l.bias
  108. m = nn.Linear(w.size(1), w.size(0), device=l.weight.device, dtype=l.weight.dtype)
  109. m.weight.data.copy_(w)
  110. m.bias.data.copy_(b)
  111. return m
  112. class PatchMerging(nn.Module):
  113. def __init__(
  114. self,
  115. dim: int,
  116. out_dim: int,
  117. act_layer: Type[nn.Module] = nn.ReLU,
  118. device=None,
  119. dtype=None,
  120. ):
  121. dd = {'device': device, 'dtype': dtype}
  122. super().__init__()
  123. hid_dim = int(dim * 4)
  124. self.conv1 = Conv2dNorm(dim, hid_dim, **dd)
  125. self.act1 = act_layer()
  126. self.conv2 = Conv2dNorm(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim, **dd)
  127. self.act2 = act_layer()
  128. self.se = SqueezeExcite(hid_dim, 0.25, **dd)
  129. self.conv3 = Conv2dNorm(hid_dim, out_dim, **dd)
  130. def forward(self, x: torch.Tensor) -> torch.Tensor:
  131. x = self.conv1(x)
  132. x = self.act1(x)
  133. x = self.conv2(x)
  134. x = self.act2(x)
  135. x = self.se(x)
  136. x = self.conv3(x)
  137. return x
  138. class FFN(nn.Module):
  139. def __init__(
  140. self,
  141. dim: int,
  142. embed_dim: int,
  143. act_layer: Type[nn.Module] = nn.ReLU,
  144. device=None,
  145. dtype=None,
  146. ):
  147. dd = {'device': device, 'dtype': dtype}
  148. super().__init__()
  149. self.pw1 = Conv2dNorm(dim, embed_dim, **dd)
  150. self.act = act_layer()
  151. self.pw2 = Conv2dNorm(embed_dim, dim, bn_weight_init=0, **dd)
  152. def forward(self, x: torch.Tensor) -> torch.Tensor:
  153. x = self.pw1(x)
  154. x = self.act(x)
  155. x = self.pw2(x)
  156. return x
  157. class SHSA(nn.Module):
  158. """Single-Head Self-Attention"""
  159. def __init__(
  160. self,
  161. dim: int,
  162. qk_dim: int,
  163. pdim: int,
  164. norm_layer: Type[nn.Module] = GroupNorm1,
  165. act_layer: Type[nn.Module] = nn.ReLU,
  166. device=None,
  167. dtype=None,
  168. ):
  169. dd = {'device': device, 'dtype': dtype}
  170. super().__init__()
  171. self.scale = qk_dim ** -0.5
  172. self.qk_dim = qk_dim
  173. self.dim = dim
  174. self.pdim = pdim
  175. self.pre_norm = norm_layer(pdim, **dd)
  176. self.qkv = Conv2dNorm(pdim, qk_dim * 2 + pdim, **dd)
  177. self.proj = nn.Sequential(act_layer(), Conv2dNorm(dim, dim, bn_weight_init=0, **dd))
  178. def forward(self, x: torch.Tensor) -> torch.Tensor:
  179. B, _, H, W = x.shape
  180. x1, x2 = torch.split(x, [self.pdim, self.dim - self.pdim], dim = 1)
  181. x1 = self.pre_norm(x1)
  182. qkv = self.qkv(x1)
  183. q, k, v = torch.split(qkv, [self.qk_dim, self.qk_dim, self.pdim], dim=1)
  184. q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
  185. attn = (q.transpose(-2, -1) @ k) * self.scale
  186. attn = attn.softmax(dim=-1)
  187. x1 = (v @ attn.transpose(-2, -1)).reshape(B, self.pdim, H, W)
  188. x = self.proj(torch.cat([x1, x2], dim = 1))
  189. return x
  190. class BasicBlock(nn.Module):
  191. def __init__(
  192. self,
  193. dim: int,
  194. qk_dim: int,
  195. pdim: int,
  196. type: str,
  197. norm_layer: Type[nn.Module] = GroupNorm1,
  198. act_layer: Type[nn.Module] = nn.ReLU,
  199. device=None,
  200. dtype=None,
  201. ):
  202. dd = {'device': device, 'dtype': dtype}
  203. super().__init__()
  204. self.conv = Residual(Conv2dNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0, **dd))
  205. if type == "s":
  206. self.mixer = Residual(SHSA(dim, qk_dim, pdim, norm_layer, act_layer, **dd))
  207. else:
  208. self.mixer = nn.Identity()
  209. self.ffn = Residual(FFN(dim, int(dim * 2), **dd))
  210. def forward(self, x: torch.Tensor) -> torch.Tensor:
  211. x = self.conv(x)
  212. x = self.mixer(x)
  213. x = self.ffn(x)
  214. return x
  215. class StageBlock(nn.Module):
  216. def __init__(
  217. self,
  218. prev_dim: int,
  219. dim: int,
  220. qk_dim: int,
  221. pdim: int,
  222. type: str,
  223. depth: int,
  224. norm_layer: Type[nn.Module] = GroupNorm1,
  225. act_layer: Type[nn.Module] = nn.ReLU,
  226. device=None,
  227. dtype=None,
  228. ):
  229. dd = {'device': device, 'dtype': dtype}
  230. super().__init__()
  231. self.grad_checkpointing = False
  232. self.downsample = nn.Sequential(
  233. Residual(Conv2dNorm(prev_dim, prev_dim, 3, 1, 1, groups=prev_dim, **dd)),
  234. Residual(FFN(prev_dim, int(prev_dim * 2), act_layer, **dd)),
  235. PatchMerging(prev_dim, dim, act_layer, **dd),
  236. Residual(Conv2dNorm(dim, dim, 3, 1, 1, groups=dim, **dd)),
  237. Residual(FFN(dim, int(dim * 2), act_layer, **dd)),
  238. ) if prev_dim != dim else nn.Identity()
  239. self.blocks = nn.Sequential(*[
  240. BasicBlock(dim, qk_dim, pdim, type, norm_layer, act_layer, **dd) for _ in range(depth)
  241. ])
  242. def forward(self, x: torch.Tensor) -> torch.Tensor:
  243. x = self.downsample(x)
  244. if self.grad_checkpointing and not torch.jit.is_scripting():
  245. x = checkpoint_seq(self.blocks, x)
  246. else:
  247. x = self.blocks(x)
  248. return x
  249. class SHViT(nn.Module):
  250. def __init__(
  251. self,
  252. in_chans: int = 3,
  253. num_classes: int = 1000,
  254. global_pool: str = 'avg',
  255. embed_dim: Tuple[int, int, int] = (128, 256, 384),
  256. partial_dim: Tuple[int, int, int] = (32, 64, 96),
  257. qk_dim: Tuple[int, int, int] = (16, 16, 16),
  258. depth: Tuple[int, int, int] = (1, 2, 3),
  259. types: Tuple[str, str, str] = ("s", "s", "s"),
  260. drop_rate: float = 0.,
  261. norm_layer: Type[nn.Module] = GroupNorm1,
  262. act_layer: Type[nn.Module] = nn.ReLU,
  263. device=None,
  264. dtype=None,
  265. ):
  266. super().__init__()
  267. dd = {'device': device, 'dtype': dtype}
  268. self.num_classes = num_classes
  269. self.in_chans = in_chans
  270. self.drop_rate = drop_rate
  271. self.feature_info = []
  272. # Patch embedding
  273. stem_chs = embed_dim[0]
  274. self.patch_embed = nn.Sequential(
  275. Conv2dNorm(in_chans, stem_chs // 8, 3, 2, 1, **dd),
  276. act_layer(),
  277. Conv2dNorm(stem_chs // 8, stem_chs // 4, 3, 2, 1, **dd),
  278. act_layer(),
  279. Conv2dNorm(stem_chs // 4, stem_chs // 2, 3, 2, 1, **dd),
  280. act_layer(),
  281. Conv2dNorm(stem_chs // 2, stem_chs, 3, 2, 1, **dd)
  282. )
  283. # Build SHViT blocks
  284. stages = []
  285. prev_chs = stem_chs
  286. for i in range(len(embed_dim)):
  287. stages.append(StageBlock(
  288. prev_dim=prev_chs,
  289. dim=embed_dim[i],
  290. qk_dim=qk_dim[i],
  291. pdim=partial_dim[i],
  292. type=types[i],
  293. depth=depth[i],
  294. norm_layer=norm_layer,
  295. act_layer=act_layer,
  296. **dd,
  297. ))
  298. prev_chs = embed_dim[i]
  299. self.feature_info.append(dict(num_chs=prev_chs, reduction=2**(i+4), module=f'stages.{i}'))
  300. self.stages = nn.Sequential(*stages)
  301. # Classifier head
  302. self.num_features = self.head_hidden_size = embed_dim[-1]
  303. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  304. self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
  305. self.head = NormLinear(self.head_hidden_size, num_classes, **dd) if num_classes > 0 else nn.Identity()
  306. @torch.jit.ignore
  307. def no_weight_decay(self) -> Set:
  308. return set()
  309. @torch.jit.ignore
  310. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  311. matcher = dict(
  312. stem=r'^patch_embed', # stem and embed
  313. blocks=r'^stages\.(\d+)' if coarse else [
  314. (r'^stages\.(\d+).downsample', (0,)),
  315. (r'^stages\.(\d+)\.blocks\.(\d+)', None),
  316. ]
  317. )
  318. return matcher
  319. @torch.jit.ignore
  320. def set_grad_checkpointing(self, enable=True):
  321. for s in self.stages:
  322. s.grad_checkpointing = enable
  323. @torch.jit.ignore
  324. def get_classifier(self) -> nn.Module:
  325. return self.head.l
  326. def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
  327. self.num_classes = num_classes
  328. # cannot meaningfully change pooling of efficient head after creation
  329. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  330. self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
  331. self.head = NormLinear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
  332. def forward_intermediates(
  333. self,
  334. x: torch.Tensor,
  335. indices: Optional[Union[int, List[int]]] = None,
  336. norm: bool = False,
  337. stop_early: bool = False,
  338. output_fmt: str = 'NCHW',
  339. intermediates_only: bool = False,
  340. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  341. """ Forward features that returns intermediates.
  342. Args:
  343. x: Input image tensor
  344. indices: Take last n blocks if int, all if None, select matching indices if sequence
  345. norm: Apply norm layer to compatible intermediates
  346. stop_early: Stop iterating over blocks when last desired intermediate hit
  347. output_fmt: Shape of intermediate feature outputs
  348. intermediates_only: Only return intermediate features
  349. Returns:
  350. """
  351. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  352. intermediates = []
  353. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  354. # forward pass
  355. x = self.patch_embed(x)
  356. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  357. stages = self.stages
  358. else:
  359. stages = self.stages[:max_index + 1]
  360. for feat_idx, stage in enumerate(stages):
  361. x = stage(x)
  362. if feat_idx in take_indices:
  363. intermediates.append(x)
  364. if intermediates_only:
  365. return intermediates
  366. return x, intermediates
  367. def prune_intermediate_layers(
  368. self,
  369. indices: Union[int, List[int]] = 1,
  370. prune_norm: bool = False,
  371. prune_head: bool = True,
  372. ):
  373. """ Prune layers not required for specified intermediates.
  374. """
  375. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  376. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  377. if prune_head:
  378. self.reset_classifier(0, '')
  379. return take_indices
  380. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  381. x = self.patch_embed(x)
  382. x = self.stages(x)
  383. return x
  384. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  385. x = self.global_pool(x)
  386. x = self.flatten(x)
  387. if self.drop_rate > 0.:
  388. x = F.dropout(x, p=self.drop_rate, training=self.training)
  389. return x if pre_logits else self.head(x)
  390. def forward(self, x: torch.Tensor) -> torch.Tensor:
  391. x = self.forward_features(x)
  392. x = self.forward_head(x)
  393. return x
  394. @torch.no_grad()
  395. def fuse(self):
  396. def fuse_children(net):
  397. for child_name, child in net.named_children():
  398. if hasattr(child, 'fuse'):
  399. fused = child.fuse()
  400. setattr(net, child_name, fused)
  401. fuse_children(fused)
  402. else:
  403. fuse_children(child)
  404. fuse_children(self)
  405. def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
  406. state_dict = state_dict.get('model', state_dict)
  407. # out_dict = {}
  408. # import re
  409. # replace_rules = [
  410. # (re.compile(r'^blocks1\.'), 'stages.0.blocks.'),
  411. # (re.compile(r'^blocks2\.'), 'stages.1.blocks.'),
  412. # (re.compile(r'^blocks3\.'), 'stages.2.blocks.'),
  413. # ]
  414. # downsample_mapping = {}
  415. # for i in range(1, 3):
  416. # downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.0\\.'] = f'stages.{i}.downsample.0.'
  417. # downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.1\\.'] = f'stages.{i}.downsample.1.'
  418. # downsample_mapping[f'^stages\\.{i}\\.blocks\\.1\\.'] = f'stages.{i}.downsample.2.'
  419. # downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.0\\.'] = f'stages.{i}.downsample.3.'
  420. # downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.1\\.'] = f'stages.{i}.downsample.4.'
  421. # for j in range(3, 10):
  422. # downsample_mapping[f'^stages\\.{i}\\.blocks\\.{j}\\.'] = f'stages.{i}.blocks.{j - 3}.'
  423. #
  424. # downsample_patterns = [
  425. # (re.compile(pattern), replacement) for pattern, replacement in downsample_mapping.items()]
  426. #
  427. # for k, v in state_dict.items():
  428. # for pattern, replacement in replace_rules:
  429. # k = pattern.sub(replacement, k)
  430. # for pattern, replacement in downsample_patterns:
  431. # k = pattern.sub(replacement, k)
  432. # out_dict[k] = v
  433. return state_dict
  434. def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
  435. return {
  436. 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (4, 4),
  437. 'crop_pct': 0.875, 'interpolation': 'bicubic',
  438. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  439. 'first_conv': 'patch_embed.0.c', 'classifier': 'head.l',
  440. 'license': 'mit',
  441. 'paper_ids': 'arXiv:2401.16456',
  442. 'paper_name': 'SHViT: Single-Head Vision Transformer with Memory Efficient Macro Design',
  443. 'origin_url': 'https://github.com/ysj9909/SHViT',
  444. **kwargs
  445. }
  446. default_cfgs = generate_default_cfgs({
  447. 'shvit_s1.in1k': _cfg(
  448. hf_hub_id='timm/',
  449. #url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s1.pth',
  450. ),
  451. 'shvit_s2.in1k': _cfg(
  452. hf_hub_id='timm/',
  453. #url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s2.pth',
  454. ),
  455. 'shvit_s3.in1k': _cfg(
  456. hf_hub_id='timm/',
  457. #url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s3.pth',
  458. ),
  459. 'shvit_s4.in1k': _cfg(
  460. hf_hub_id='timm/',
  461. #url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s4.pth',
  462. input_size=(3, 256, 256),
  463. ),
  464. })
  465. def _create_shvit(variant: str, pretrained: bool = False, **kwargs: Any) -> SHViT:
  466. model = build_model_with_cfg(
  467. SHViT, variant, pretrained,
  468. pretrained_filter_fn=checkpoint_filter_fn,
  469. feature_cfg=dict(out_indices=(0, 1, 2), flatten_sequential=True),
  470. **kwargs,
  471. )
  472. return model
  473. @register_model
  474. def shvit_s1(pretrained: bool = False, **kwargs: Any) -> SHViT:
  475. model_args = dict(
  476. embed_dim=(128, 224, 320), depth=(2, 4, 5), partial_dim=(32, 48, 68), types=("i", "s", "s"))
  477. return _create_shvit('shvit_s1', pretrained=pretrained, **dict(model_args, **kwargs))
  478. @register_model
  479. def shvit_s2(pretrained: bool = False, **kwargs: Any) -> SHViT:
  480. model_args = dict(
  481. embed_dim=(128, 308, 448), depth=(2, 4, 5), partial_dim=(32, 66, 96), types=("i", "s", "s"))
  482. return _create_shvit('shvit_s2', pretrained=pretrained, **dict(model_args, **kwargs))
  483. @register_model
  484. def shvit_s3(pretrained: bool = False, **kwargs: Any) -> SHViT:
  485. model_args = dict(
  486. embed_dim=(192, 352, 448), depth=(3, 5, 5), partial_dim=(48, 75, 96), types=("i", "s", "s"))
  487. return _create_shvit('shvit_s3', pretrained=pretrained, **dict(model_args, **kwargs))
  488. @register_model
  489. def shvit_s4(pretrained: bool = False, **kwargs: Any) -> SHViT:
  490. model_args = dict(
  491. embed_dim=(224, 336, 448), depth=(4, 7, 6), partial_dim=(48, 72, 96), types=("i", "s", "s"))
  492. return _create_shvit('shvit_s4', pretrained=pretrained, **dict(model_args, **kwargs))