swiftformer.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650
  1. """SwiftFormer
  2. SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications
  3. Code: https://github.com/Amshaker/SwiftFormer
  4. Paper: https://arxiv.org/pdf/2303.15446
  5. @InProceedings{Shaker_2023_ICCV,
  6. author = {Shaker, Abdelrahman and Maaz, Muhammad and Rasheed, Hanoona and Khan, Salman and Yang, Ming-Hsuan and Khan, Fahad Shahbaz},
  7. title = {SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications},
  8. booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  9. year = {2023},
  10. }
  11. """
  12. import re
  13. from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
  14. import torch
  15. import torch.nn as nn
  16. import torch.nn.functional as F
  17. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  18. from timm.layers import DropPath, Linear, LayerType, to_2tuple, trunc_normal_
  19. from ._builder import build_model_with_cfg
  20. from ._features import feature_take_indices
  21. from ._manipulate import checkpoint_seq
  22. from ._registry import generate_default_cfgs, register_model
  23. __all__ = ['SwiftFormer']
  24. class LayerScale2d(nn.Module):
  25. def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False, device=None, dtype=None):
  26. dd = {'device': device, 'dtype': dtype}
  27. super().__init__()
  28. self.inplace = inplace
  29. self.gamma = nn.Parameter(
  30. init_values * torch.ones(dim, 1, 1, **dd), requires_grad=True)
  31. def forward(self, x: torch.Tensor) -> torch.Tensor:
  32. return x.mul_(self.gamma) if self.inplace else x * self.gamma
  33. class Embedding(nn.Module):
  34. """
  35. Patch Embedding that is implemented by a layer of conv.
  36. Input: tensor in shape [B, C, H, W]
  37. Output: tensor in shape [B, C, H/stride, W/stride]
  38. """
  39. def __init__(
  40. self,
  41. in_chans: int = 3,
  42. embed_dim: int = 768,
  43. patch_size: int = 16,
  44. stride: int = 16,
  45. padding: int = 0,
  46. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  47. device=None,
  48. dtype=None,
  49. ):
  50. dd = {'device': device, 'dtype': dtype}
  51. super().__init__()
  52. patch_size = to_2tuple(patch_size)
  53. stride = to_2tuple(stride)
  54. padding = to_2tuple(padding)
  55. self.proj = nn.Conv2d(in_chans, embed_dim, patch_size, stride, padding, **dd)
  56. self.norm = norm_layer(embed_dim, **dd) if norm_layer else nn.Identity()
  57. def forward(self, x: torch.Tensor) -> torch.Tensor:
  58. x = self.proj(x)
  59. x = self.norm(x)
  60. return x
  61. class ConvEncoder(nn.Module):
  62. """
  63. Implementation of ConvEncoder with 3*3 and 1*1 convolutions.
  64. Input: tensor with shape [B, C, H, W]
  65. Output: tensor with shape [B, C, H, W]
  66. """
  67. def __init__(
  68. self,
  69. dim: int,
  70. hidden_dim: int = 64,
  71. kernel_size: int = 3,
  72. drop_path: float = 0.,
  73. act_layer: Type[nn.Module] = nn.GELU,
  74. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  75. use_layer_scale: bool = True,
  76. device=None,
  77. dtype=None,
  78. ):
  79. dd = {'device': device, 'dtype': dtype}
  80. super().__init__()
  81. self.dwconv = nn.Conv2d(dim, dim, kernel_size, padding=kernel_size // 2, groups=dim, **dd)
  82. self.norm = norm_layer(dim, **dd)
  83. self.pwconv1 = nn.Conv2d(dim, hidden_dim, 1, **dd)
  84. self.act = act_layer()
  85. self.pwconv2 = nn.Conv2d(hidden_dim, dim, 1, **dd)
  86. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  87. self.layer_scale = LayerScale2d(dim, 1, **dd) if use_layer_scale else nn.Identity()
  88. def forward(self, x: torch.Tensor) -> torch.Tensor:
  89. input = x
  90. x = self.dwconv(x)
  91. x = self.norm(x)
  92. x = self.pwconv1(x)
  93. x = self.act(x)
  94. x = self.pwconv2(x)
  95. x = self.layer_scale(x)
  96. x = input + self.drop_path(x)
  97. return x
  98. class Mlp(nn.Module):
  99. """
  100. Implementation of MLP layer with 1*1 convolutions.
  101. Input: tensor with shape [B, C, H, W]
  102. Output: tensor with shape [B, C, H, W]
  103. """
  104. def __init__(
  105. self,
  106. in_features: int,
  107. hidden_features: Optional[int] = None,
  108. out_features: Optional[int] = None,
  109. act_layer: Type[nn.Module] = nn.GELU,
  110. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  111. drop: float = 0.,
  112. device=None,
  113. dtype=None,
  114. ):
  115. dd = {'device': device, 'dtype': dtype}
  116. super().__init__()
  117. out_features = out_features or in_features
  118. hidden_features = hidden_features or in_features
  119. self.norm1 = norm_layer(in_features, **dd)
  120. self.fc1 = nn.Conv2d(in_features, hidden_features, 1, **dd)
  121. self.act = act_layer()
  122. self.fc2 = nn.Conv2d(hidden_features, out_features, 1, **dd)
  123. self.drop = nn.Dropout(drop)
  124. def forward(self, x: torch.Tensor) -> torch.Tensor:
  125. x = self.norm1(x)
  126. x = self.fc1(x)
  127. x = self.act(x)
  128. x = self.drop(x)
  129. x = self.fc2(x)
  130. x = self.drop(x)
  131. return x
  132. class EfficientAdditiveAttention(nn.Module):
  133. """
  134. Efficient Additive Attention module for SwiftFormer.
  135. Input: tensor in shape [B, C, H, W]
  136. Output: tensor in shape [B, C, H, W]
  137. """
  138. def __init__(
  139. self,
  140. in_dims: int = 512,
  141. token_dim: int = 256,
  142. num_heads: int = 1,
  143. device=None,
  144. dtype=None,
  145. ):
  146. dd = {'device': device, 'dtype': dtype}
  147. super().__init__()
  148. self.scale_factor = token_dim ** -0.5
  149. self.to_query = nn.Linear(in_dims, token_dim * num_heads, **dd)
  150. self.to_key = nn.Linear(in_dims, token_dim * num_heads, **dd)
  151. self.w_g = nn.Parameter(torch.randn(token_dim * num_heads, 1, **dd))
  152. self.proj = nn.Linear(token_dim * num_heads, token_dim * num_heads, **dd)
  153. self.final = nn.Linear(token_dim * num_heads, token_dim, **dd)
  154. def forward(self, x: torch.Tensor) -> torch.Tensor:
  155. B, _, H, W = x.shape
  156. x = x.flatten(2).permute(0, 2, 1)
  157. query = F.normalize(self.to_query(x), dim=-1)
  158. key = F.normalize(self.to_key(x), dim=-1)
  159. attn = F.normalize(query @ self.w_g * self.scale_factor, dim=1)
  160. attn = torch.sum(attn * query, dim=1, keepdim=True)
  161. out = self.proj(attn * key) + query
  162. out = self.final(out).permute(0, 2, 1).reshape(B, -1, H, W)
  163. return out
  164. class LocalRepresentation(nn.Module):
  165. """
  166. Local Representation module for SwiftFormer that is implemented by 3*3 depth-wise and point-wise convolutions.
  167. Input: tensor in shape [B, C, H, W]
  168. Output: tensor in shape [B, C, H, W]
  169. """
  170. def __init__(
  171. self,
  172. dim: int,
  173. kernel_size: int = 3,
  174. drop_path: float = 0.,
  175. use_layer_scale: bool = True,
  176. act_layer: Type[nn.Module] = nn.GELU,
  177. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  178. device=None,
  179. dtype=None,
  180. ):
  181. dd = {'device': device, 'dtype': dtype}
  182. super().__init__()
  183. self.dwconv = nn.Conv2d(dim, dim, kernel_size, padding=kernel_size // 2, groups=dim, **dd)
  184. self.norm = norm_layer(dim, **dd)
  185. self.pwconv1 = nn.Conv2d(dim, dim, kernel_size=1, **dd)
  186. self.act = act_layer()
  187. self.pwconv2 = nn.Conv2d(dim, dim, kernel_size=1, **dd)
  188. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  189. self.layer_scale = LayerScale2d(dim, 1, **dd) if use_layer_scale else nn.Identity()
  190. def forward(self, x: torch.Tensor) -> torch.Tensor:
  191. skip = x
  192. x = self.dwconv(x)
  193. x = self.norm(x)
  194. x = self.pwconv1(x)
  195. x = self.act(x)
  196. x = self.pwconv2(x)
  197. x = self.layer_scale(x)
  198. x = skip + self.drop_path(x)
  199. return x
  200. class Block(nn.Module):
  201. """
  202. SwiftFormer Encoder Block for SwiftFormer. It consists of :
  203. (1) Local representation module, (2) EfficientAdditiveAttention, and (3) MLP block.
  204. Input: tensor in shape [B, C, H, W]
  205. Output: tensor in shape [B, C, H, W]
  206. """
  207. def __init__(
  208. self,
  209. dim: int,
  210. mlp_ratio: float = 4.,
  211. drop_rate: float = 0.,
  212. drop_path: float = 0.,
  213. act_layer: Type[nn.Module] = nn.GELU,
  214. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  215. use_layer_scale: bool = True,
  216. layer_scale_init_value: float = 1e-5,
  217. device=None,
  218. dtype=None,
  219. ):
  220. dd = {'device': device, 'dtype': dtype}
  221. super().__init__()
  222. self.local_representation = LocalRepresentation(
  223. dim=dim,
  224. use_layer_scale=use_layer_scale,
  225. act_layer=act_layer,
  226. norm_layer=norm_layer,
  227. **dd,
  228. )
  229. self.attn = EfficientAdditiveAttention(in_dims=dim, token_dim=dim, **dd)
  230. self.linear = Mlp(
  231. in_features=dim,
  232. hidden_features=int(dim * mlp_ratio),
  233. act_layer=act_layer,
  234. norm_layer=norm_layer,
  235. drop=drop_rate,
  236. **dd,
  237. )
  238. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  239. self.layer_scale_1 = LayerScale2d(dim, layer_scale_init_value, **dd) \
  240. if use_layer_scale else nn.Identity()
  241. self.layer_scale_2 = LayerScale2d(dim, layer_scale_init_value, **dd) \
  242. if use_layer_scale else nn.Identity()
  243. def forward(self, x: torch.Tensor) -> torch.Tensor:
  244. x = self.local_representation(x)
  245. x = x + self.drop_path(self.layer_scale_1(self.attn(x)))
  246. x = x + self.drop_path(self.layer_scale_2(self.linear(x)))
  247. return x
  248. class Stage(nn.Module):
  249. """
  250. Implementation of each SwiftFormer stages. Here, SwiftFormerEncoder used as the last block in all stages, while ConvEncoder used in the rest of the blocks.
  251. Input: tensor in shape [B, C, H, W]
  252. Output: tensor in shape [B, C, H, W]
  253. """
  254. def __init__(
  255. self,
  256. dim: int,
  257. index: int,
  258. layers: List[int],
  259. mlp_ratio: float = 4.,
  260. act_layer: Type[nn.Module] = nn.GELU,
  261. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  262. drop_rate: float = 0.,
  263. drop_path_rate: float = 0.,
  264. use_layer_scale: bool = True,
  265. layer_scale_init_value: float = 1e-5,
  266. downsample: Optional[Type[nn.Module]] = None,
  267. device=None,
  268. dtype=None,
  269. ):
  270. dd = {'device': device, 'dtype': dtype}
  271. super().__init__()
  272. self.grad_checkpointing = False
  273. self.downsample = downsample if downsample is not None else nn.Identity()
  274. blocks = []
  275. for block_idx in range(layers[index]):
  276. block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
  277. if layers[index] - block_idx <= 1:
  278. blocks.append(Block(
  279. dim,
  280. mlp_ratio=mlp_ratio,
  281. drop_rate=drop_rate,
  282. drop_path=block_dpr,
  283. act_layer=act_layer,
  284. norm_layer=norm_layer,
  285. use_layer_scale=use_layer_scale,
  286. layer_scale_init_value=layer_scale_init_value,
  287. **dd,
  288. ))
  289. else:
  290. blocks.append(ConvEncoder(
  291. dim=dim,
  292. hidden_dim=int(mlp_ratio * dim),
  293. kernel_size=3,
  294. drop_path=block_dpr,
  295. act_layer=act_layer,
  296. norm_layer=norm_layer,
  297. use_layer_scale=use_layer_scale,
  298. **dd,
  299. ))
  300. self.blocks = nn.Sequential(*blocks)
  301. def forward(self, x: torch.Tensor) -> torch.Tensor:
  302. x = self.downsample(x)
  303. if self.grad_checkpointing and not torch.jit.is_scripting():
  304. x = checkpoint_seq(self.blocks, x)
  305. else:
  306. x = self.blocks(x)
  307. return x
  308. class SwiftFormer(nn.Module):
  309. def __init__(
  310. self,
  311. layers: List[int] = [3, 3, 6, 4],
  312. embed_dims: List[int] = [48, 56, 112, 220],
  313. mlp_ratios: int = 4,
  314. downsamples: List[bool] = [False, True, True, True],
  315. act_layer: Type[nn.Module] = nn.GELU,
  316. down_patch_size: int = 3,
  317. down_stride: int = 2,
  318. down_pad: int = 1,
  319. num_classes: int = 1000,
  320. drop_rate: float = 0.,
  321. drop_path_rate: float = 0.,
  322. use_layer_scale: bool = True,
  323. layer_scale_init_value: float = 1e-5,
  324. global_pool: str = 'avg',
  325. output_stride: int = 32,
  326. in_chans: int = 3,
  327. device=None,
  328. dtype=None,
  329. **kwargs,
  330. ):
  331. super().__init__()
  332. dd = {'device': device, 'dtype': dtype}
  333. assert output_stride == 32
  334. self.num_classes = num_classes
  335. self.in_chans = in_chans
  336. self.global_pool = global_pool
  337. self.feature_info = []
  338. self.stem = nn.Sequential(
  339. nn.Conv2d(in_chans, embed_dims[0] // 2, 3, 2, 1, **dd),
  340. nn.BatchNorm2d(embed_dims[0] // 2, **dd),
  341. nn.ReLU(),
  342. nn.Conv2d(embed_dims[0] // 2, embed_dims[0], 3, 2, 1, **dd),
  343. nn.BatchNorm2d(embed_dims[0], **dd),
  344. nn.ReLU(),
  345. )
  346. prev_dim = embed_dims[0]
  347. stages = []
  348. for i in range(len(layers)):
  349. downsample = Embedding(
  350. in_chans=prev_dim,
  351. embed_dim=embed_dims[i],
  352. patch_size=down_patch_size,
  353. stride=down_stride,
  354. padding=down_pad,
  355. **dd,
  356. ) if downsamples[i] else nn.Identity()
  357. stage = Stage(
  358. dim=embed_dims[i],
  359. index=i,
  360. layers=layers,
  361. mlp_ratio=mlp_ratios,
  362. act_layer=act_layer,
  363. drop_rate=drop_rate,
  364. drop_path_rate=drop_path_rate,
  365. use_layer_scale=use_layer_scale,
  366. layer_scale_init_value=layer_scale_init_value,
  367. downsample=downsample,
  368. **dd,
  369. )
  370. prev_dim = embed_dims[i]
  371. stages.append(stage)
  372. self.feature_info += [dict(num_chs=embed_dims[i], reduction=2**(i+2), module=f'stages.{i}')]
  373. self.stages = nn.Sequential(*stages)
  374. # Classifier head
  375. self.num_features = self.head_hidden_size = out_chs = embed_dims[-1]
  376. self.norm = nn.BatchNorm2d(out_chs, **dd)
  377. self.head_drop = nn.Dropout(drop_rate)
  378. self.head = Linear(out_chs, num_classes, **dd) if num_classes > 0 else nn.Identity()
  379. # assuming model is always distilled (valid for current checkpoints, will split def if that changes)
  380. self.head_dist = Linear(out_chs, num_classes, **dd) if num_classes > 0 else nn.Identity()
  381. self.distilled_training = False # must set this True to train w/ distillation token
  382. self._initialize_weights()
  383. def _initialize_weights(self):
  384. for name, m in self.named_modules():
  385. if isinstance(m, nn.Linear):
  386. trunc_normal_(m.weight, std=.02)
  387. if m.bias is not None:
  388. nn.init.constant_(m.bias, 0)
  389. elif isinstance(m, nn.Conv2d):
  390. trunc_normal_(m.weight, std=.02)
  391. if m.bias is not None:
  392. nn.init.constant_(m.bias, 0)
  393. @torch.jit.ignore
  394. def no_weight_decay(self) -> Set:
  395. return set()
  396. @torch.jit.ignore
  397. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  398. matcher = dict(
  399. stem=r'^stem', # stem and embed
  400. blocks=r'^stages\.(\d+)' if coarse else [
  401. (r'^stages\.(\d+).downsample', (0,)),
  402. (r'^stages\.(\d+)\.blocks\.(\d+)', None),
  403. (r'^norm', (99999,)),
  404. ]
  405. )
  406. return matcher
  407. @torch.jit.ignore
  408. def set_grad_checkpointing(self, enable: bool = True):
  409. for s in self.stages:
  410. s.grad_checkpointing = enable
  411. @torch.jit.ignore
  412. def get_classifier(self) -> Tuple[nn.Module, nn.Module]:
  413. return self.head, self.head_dist
  414. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  415. self.num_classes = num_classes
  416. if global_pool is not None:
  417. self.global_pool = global_pool
  418. device, dtype = self.head.weight.device, self.head.weight.dtype if hasattr(self.head, 'weight') else (None, None)
  419. self.head = Linear(self.num_features, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity()
  420. self.head_dist = Linear(self.num_features, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity()
  421. @torch.jit.ignore
  422. def set_distilled_training(self, enable: bool = True):
  423. self.distilled_training = enable
  424. def forward_intermediates(
  425. self,
  426. x: torch.Tensor,
  427. indices: Optional[Union[int, List[int]]] = None,
  428. norm: bool = False,
  429. stop_early: bool = False,
  430. output_fmt: str = 'NCHW',
  431. intermediates_only: bool = False,
  432. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  433. """ Forward features that returns intermediates.
  434. Args:
  435. x: Input image tensor
  436. indices: Take last n blocks if int, all if None, select matching indices if sequence
  437. norm: Apply norm layer to compatible intermediates
  438. stop_early: Stop iterating over blocks when last desired intermediate hit
  439. output_fmt: Shape of intermediate feature outputs
  440. intermediates_only: Only return intermediate features
  441. Returns:
  442. """
  443. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  444. intermediates = []
  445. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  446. last_idx = len(self.stages) - 1
  447. # forward pass
  448. x = self.stem(x)
  449. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  450. stages = self.stages
  451. else:
  452. stages = self.stages[:max_index + 1]
  453. for feat_idx, stage in enumerate(stages):
  454. x = stage(x)
  455. if feat_idx in take_indices:
  456. if norm and feat_idx == last_idx:
  457. x_inter = self.norm(x) # applying final norm last intermediate
  458. else:
  459. x_inter = x
  460. intermediates.append(x_inter)
  461. if intermediates_only:
  462. return intermediates
  463. if feat_idx == last_idx:
  464. x = self.norm(x)
  465. return x, intermediates
  466. def prune_intermediate_layers(
  467. self,
  468. indices: Union[int, List[int]] = 1,
  469. prune_norm: bool = False,
  470. prune_head: bool = True,
  471. ):
  472. """ Prune layers not required for specified intermediates.
  473. """
  474. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  475. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  476. if prune_norm:
  477. self.norm = nn.Identity()
  478. if prune_head:
  479. self.reset_classifier(0, '')
  480. return take_indices
  481. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  482. x = self.stem(x)
  483. x = self.stages(x)
  484. x = self.norm(x)
  485. return x
  486. def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
  487. if self.global_pool == 'avg':
  488. x = x.mean(dim=(2, 3))
  489. x = self.head_drop(x)
  490. if pre_logits:
  491. return x
  492. x, x_dist = self.head(x), self.head_dist(x)
  493. if self.distilled_training and self.training and not torch.jit.is_scripting():
  494. # only return separate classification predictions when training in distilled mode
  495. return x, x_dist
  496. else:
  497. # during standard train/finetune, inference average the classifier predictions
  498. return (x + x_dist) / 2
  499. def forward(self, x: torch.Tensor):
  500. x = self.forward_features(x)
  501. x = self.forward_head(x)
  502. return x
  503. def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
  504. state_dict = state_dict.get('model', state_dict)
  505. if 'stem.0.weight' in state_dict:
  506. return state_dict
  507. out_dict = {}
  508. for k, v in state_dict.items():
  509. k = k.replace('patch_embed.', 'stem.')
  510. k = k.replace('dist_head.', 'head_dist.')
  511. k = k.replace('attn.Proj.', 'attn.proj.')
  512. k = k.replace('.layer_scale_1', '.layer_scale_1.gamma')
  513. k = k.replace('.layer_scale_2', '.layer_scale_2.gamma')
  514. k = re.sub(r'\.layer_scale(?=$|\.)', '.layer_scale.gamma', k)
  515. m = re.match(r'^network\.(\d+)\.(.*)', k)
  516. if m:
  517. n_idx, rest = int(m.group(1)), m.group(2)
  518. stage_idx = n_idx // 2
  519. if n_idx % 2 == 0:
  520. k = f'stages.{stage_idx}.blocks.{rest}'
  521. else:
  522. k = f'stages.{stage_idx+1}.downsample.{rest}'
  523. out_dict[k] = v
  524. return out_dict
  525. def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
  526. return {
  527. 'url': url,
  528. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True,
  529. 'crop_pct': .95, 'interpolation': 'bicubic',
  530. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  531. 'first_conv': 'stem.0', 'classifier': ('head', 'head_dist'),
  532. 'license': 'apache-2.0',
  533. 'paper_ids': 'arXiv:2303.15446',
  534. 'paper_name': 'SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications',
  535. 'origin_url': 'https://github.com/Amshaker/SwiftFormer',
  536. **kwargs
  537. }
  538. default_cfgs = generate_default_cfgs({
  539. 'swiftformer_xs.dist_in1k': _cfg(
  540. hf_hub_id='timm/',
  541. ),
  542. 'swiftformer_s.dist_in1k': _cfg(
  543. hf_hub_id='timm/'
  544. ),
  545. 'swiftformer_l1.dist_in1k': _cfg(
  546. hf_hub_id='timm/'
  547. ),
  548. 'swiftformer_l3.dist_in1k': _cfg(
  549. hf_hub_id='timm/'
  550. ),
  551. })
  552. def _create_swiftformer(variant: str, pretrained: bool = False, **kwargs: Any) -> SwiftFormer:
  553. model = build_model_with_cfg(
  554. SwiftFormer, variant, pretrained,
  555. pretrained_filter_fn=checkpoint_filter_fn,
  556. feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
  557. **kwargs,
  558. )
  559. return model
  560. @register_model
  561. def swiftformer_xs(pretrained: bool = False, **kwargs: Any) -> SwiftFormer:
  562. model_args = dict(layers=[3, 3, 6, 4], embed_dims=[48, 56, 112, 220])
  563. return _create_swiftformer('swiftformer_xs', pretrained=pretrained, **dict(model_args, **kwargs))
  564. @register_model
  565. def swiftformer_s(pretrained: bool = False, **kwargs: Any) -> SwiftFormer:
  566. model_args = dict(layers=[3, 3, 9, 6], embed_dims=[48, 64, 168, 224])
  567. return _create_swiftformer('swiftformer_s', pretrained=pretrained, **dict(model_args, **kwargs))
  568. @register_model
  569. def swiftformer_l1(pretrained: bool = False, **kwargs: Any) -> SwiftFormer:
  570. model_args = dict(layers=[4, 3, 10, 5], embed_dims=[48, 96, 192, 384])
  571. return _create_swiftformer('swiftformer_l1', pretrained=pretrained, **dict(model_args, **kwargs))
  572. @register_model
  573. def swiftformer_l3(pretrained: bool = False, **kwargs: Any) -> SwiftFormer:
  574. model_args = dict(layers=[4, 4, 12, 6], embed_dims=[64, 128, 320, 512])
  575. return _create_swiftformer('swiftformer_l3', pretrained=pretrained, **dict(model_args, **kwargs))