fasternet.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506
  1. """FasterNet
  2. Run, Don't Walk: Chasing Higher FLOPS for Faster Neural Networks
  3. - paper: https://arxiv.org/abs/2303.03667
  4. - code: https://github.com/JierunChen/FasterNet
  5. @article{chen2023run,
  6. title={Run, Don't Walk: Chasing Higher FLOPS for Faster Neural Networks},
  7. author={Chen, Jierun and Kao, Shiu-hong and He, Hao and Zhuo, Weipeng and Wen, Song and Lee, Chul-Ho and Chan, S-H Gary},
  8. journal={arXiv preprint arXiv:2303.03667},
  9. year={2023}
  10. }
  11. Modifications by / Copyright 2025 Ryan Hou & Ross Wightman, original copyrights below
  12. """
  13. # Copyright (c) Microsoft Corporation.
  14. # Licensed under the MIT License.
  15. from functools import partial
  16. from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
  17. import torch
  18. import torch.nn as nn
  19. import torch.nn.functional as F
  20. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  21. from timm.layers import SelectAdaptivePool2d, Linear, DropPath, trunc_normal_, LayerType, calculate_drop_path_rates
  22. from ._builder import build_model_with_cfg
  23. from ._features import feature_take_indices
  24. from ._manipulate import checkpoint_seq
  25. from ._registry import register_model, generate_default_cfgs
  26. __all__ = ['FasterNet']
  27. class Partial_conv3(nn.Module):
  28. def __init__(self, dim: int, n_div: int, forward: str, device=None, dtype=None):
  29. dd = {'device': device, 'dtype': dtype}
  30. super().__init__()
  31. self.dim_conv3 = dim // n_div
  32. self.dim_untouched = dim - self.dim_conv3
  33. self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False, **dd)
  34. if forward == 'slicing':
  35. self.forward = self.forward_slicing
  36. elif forward == 'split_cat':
  37. self.forward = self.forward_split_cat
  38. else:
  39. raise NotImplementedError
  40. def forward_slicing(self, x: torch.Tensor) -> torch.Tensor:
  41. # only for inference
  42. x = x.clone() # !!! Keep the original input intact for the residual connection later
  43. x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])
  44. return x
  45. def forward_split_cat(self, x: torch.Tensor) -> torch.Tensor:
  46. # for training/inference
  47. x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)
  48. x1 = self.partial_conv3(x1)
  49. x = torch.cat((x1, x2), 1)
  50. return x
  51. class MLPBlock(nn.Module):
  52. def __init__(
  53. self,
  54. dim: int,
  55. n_div: int,
  56. mlp_ratio: float,
  57. drop_path: float,
  58. layer_scale_init_value: float,
  59. act_layer: Type[nn.Module] = partial(nn.ReLU, inplace=True),
  60. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  61. pconv_fw_type: str = 'split_cat',
  62. device=None,
  63. dtype=None,
  64. ):
  65. dd = {'device': device, 'dtype': dtype}
  66. super().__init__()
  67. mlp_hidden_dim = int(dim * mlp_ratio)
  68. self.mlp = nn.Sequential(*[
  69. nn.Conv2d(dim, mlp_hidden_dim, 1, bias=False, **dd),
  70. norm_layer(mlp_hidden_dim, **dd),
  71. act_layer(),
  72. nn.Conv2d(mlp_hidden_dim, dim, 1, bias=False, **dd),
  73. ])
  74. self.spatial_mixing = Partial_conv3(dim, n_div, pconv_fw_type, **dd)
  75. if layer_scale_init_value > 0:
  76. self.layer_scale = nn.Parameter(
  77. layer_scale_init_value * torch.ones((dim), **dd), requires_grad=True)
  78. else:
  79. self.layer_scale = None
  80. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  81. def forward(self, x: torch.Tensor) -> torch.Tensor:
  82. shortcut = x
  83. x = self.spatial_mixing(x)
  84. if self.layer_scale is not None:
  85. x = shortcut + self.drop_path(
  86. self.layer_scale.unsqueeze(-1).unsqueeze(-1) * self.mlp(x))
  87. else:
  88. x = shortcut + self.drop_path(self.mlp(x))
  89. return x
  90. class Block(nn.Module):
  91. def __init__(
  92. self,
  93. dim: int,
  94. depth: int,
  95. n_div: int,
  96. mlp_ratio: float,
  97. drop_path: float,
  98. layer_scale_init_value: float,
  99. act_layer: Type[nn.Module] = partial(nn.ReLU, inplace=True),
  100. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  101. pconv_fw_type: str = 'split_cat',
  102. use_merge: bool = True,
  103. merge_size: Union[int, Tuple[int, int]] = 2,
  104. device=None,
  105. dtype=None,
  106. ):
  107. dd = {'device': device, 'dtype': dtype}
  108. super().__init__()
  109. self.grad_checkpointing = False
  110. self.blocks = nn.Sequential(*[
  111. MLPBlock(
  112. dim=dim,
  113. n_div=n_div,
  114. mlp_ratio=mlp_ratio,
  115. drop_path=drop_path[i],
  116. layer_scale_init_value=layer_scale_init_value,
  117. norm_layer=norm_layer,
  118. act_layer=act_layer,
  119. pconv_fw_type=pconv_fw_type,
  120. **dd,
  121. )
  122. for i in range(depth)
  123. ])
  124. self.downsample = PatchMerging(
  125. dim=dim // 2,
  126. patch_size=merge_size,
  127. norm_layer=norm_layer,
  128. **dd,
  129. ) if use_merge else nn.Identity()
  130. def forward(self, x: torch.Tensor) -> torch.Tensor:
  131. x = self.downsample(x)
  132. if self.grad_checkpointing and not torch.jit.is_scripting():
  133. x = checkpoint_seq(self.blocks, x)
  134. else:
  135. x = self.blocks(x)
  136. return x
  137. class PatchEmbed(nn.Module):
  138. def __init__(
  139. self,
  140. in_chans: int,
  141. embed_dim: int,
  142. patch_size: Union[int, Tuple[int, int]] = 4,
  143. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  144. device=None,
  145. dtype=None,
  146. ):
  147. dd = {'device': device, 'dtype': dtype}
  148. super().__init__()
  149. self.proj = nn.Conv2d(in_chans, embed_dim, patch_size, patch_size, bias=False, **dd)
  150. self.norm = norm_layer(embed_dim, **dd)
  151. def forward(self, x: torch.Tensor) -> torch.Tensor:
  152. return self.norm(self.proj(x))
  153. class PatchMerging(nn.Module):
  154. def __init__(
  155. self,
  156. dim: int,
  157. patch_size: Union[int, Tuple[int, int]] = 2,
  158. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  159. device=None,
  160. dtype=None,
  161. ):
  162. dd = {'device': device, 'dtype': dtype}
  163. super().__init__()
  164. self.reduction = nn.Conv2d(dim, 2 * dim, patch_size, patch_size, bias=False, **dd)
  165. self.norm = norm_layer(2 * dim, **dd)
  166. def forward(self, x: torch.Tensor) -> torch.Tensor:
  167. return self.norm(self.reduction(x))
  168. class FasterNet(nn.Module):
  169. def __init__(
  170. self,
  171. in_chans: int = 3,
  172. num_classes: int = 1000,
  173. global_pool: str = 'avg',
  174. embed_dim: int = 96,
  175. depths: Union[int, Tuple[int, ...]] = (1, 2, 8, 2),
  176. mlp_ratio: float = 2.,
  177. n_div: int = 4,
  178. patch_size: Union[int, Tuple[int, int]] = 4,
  179. merge_size: Union[int, Tuple[int, int]] = 2,
  180. patch_norm: bool = True,
  181. feature_dim: int = 1280,
  182. drop_rate: float = 0.,
  183. drop_path_rate: float = 0.1,
  184. layer_scale_init_value: float = 0.,
  185. act_layer: Type[nn.Module] = partial(nn.ReLU, inplace=True),
  186. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  187. pconv_fw_type: str = 'split_cat',
  188. device=None,
  189. dtype=None,
  190. ):
  191. super().__init__()
  192. dd = {'device': device, 'dtype': dtype}
  193. assert pconv_fw_type in ('split_cat', 'slicing',)
  194. self.num_classes = num_classes
  195. self.in_chans = in_chans
  196. self.drop_rate = drop_rate
  197. if not isinstance(depths, (list, tuple)):
  198. depths = (depths) # it means the model has only one stage
  199. self.num_stages = len(depths)
  200. self.feature_info = []
  201. self.patch_embed = PatchEmbed(
  202. in_chans=in_chans,
  203. embed_dim=embed_dim,
  204. patch_size=patch_size,
  205. norm_layer=norm_layer if patch_norm else nn.Identity,
  206. **dd,
  207. )
  208. # stochastic depth decay rule
  209. dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
  210. # build layers
  211. stages_list = []
  212. for i in range(self.num_stages):
  213. dim = int(embed_dim * 2 ** i)
  214. stage = Block(
  215. dim=dim,
  216. depth=depths[i],
  217. n_div=n_div,
  218. mlp_ratio=mlp_ratio,
  219. drop_path=dpr[i],
  220. layer_scale_init_value=layer_scale_init_value,
  221. norm_layer=norm_layer,
  222. act_layer=act_layer,
  223. pconv_fw_type=pconv_fw_type,
  224. use_merge=False if i == 0 else True,
  225. merge_size=merge_size,
  226. **dd,
  227. )
  228. stages_list.append(stage)
  229. self.feature_info += [dict(num_chs=dim, reduction=2**(i+2), module=f'stages.{i}')]
  230. self.stages = nn.Sequential(*stages_list)
  231. # building last several layers
  232. self.num_features = prev_chs = int(embed_dim * 2 ** (self.num_stages - 1))
  233. self.head_hidden_size = out_chs = feature_dim # 1280
  234. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  235. self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=False, **dd)
  236. self.act = act_layer()
  237. self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
  238. self.classifier = Linear(out_chs, num_classes, bias=True, **dd) if num_classes > 0 else nn.Identity()
  239. self._initialize_weights()
  240. def _initialize_weights(self):
  241. for name, m in self.named_modules():
  242. if isinstance(m, nn.Linear):
  243. trunc_normal_(m.weight, std=.02)
  244. if isinstance(m, nn.Linear) and m.bias is not None:
  245. nn.init.constant_(m.bias, 0)
  246. elif isinstance(m, nn.Conv2d):
  247. trunc_normal_(m.weight, std=.02)
  248. if m.bias is not None:
  249. nn.init.constant_(m.bias, 0)
  250. @torch.jit.ignore
  251. def no_weight_decay(self) -> Set:
  252. return set()
  253. @torch.jit.ignore
  254. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  255. matcher = dict(
  256. stem=r'^patch_embed', # stem and embed
  257. blocks=r'^stages\.(\d+)' if coarse else [
  258. (r'^stages\.(\d+).downsample', (0,)),
  259. (r'^stages\.(\d+)\.blocks\.(\d+)', None),
  260. (r'^conv_head', (99999,)),
  261. ]
  262. )
  263. return matcher
  264. @torch.jit.ignore
  265. def set_grad_checkpointing(self, enable=True):
  266. for s in self.stages:
  267. s.grad_checkpointing = enable
  268. @torch.jit.ignore
  269. def get_classifier(self) -> nn.Module:
  270. return self.classifier
  271. def reset_classifier(self, num_classes: int, global_pool: str = 'avg', device=None, dtype=None):
  272. dd = {'device': device, 'dtype': dtype}
  273. self.num_classes = num_classes
  274. # cannot meaningfully change pooling of efficient head after creation
  275. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  276. self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
  277. self.classifier = Linear(self.head_hidden_size, num_classes, **dd) if num_classes > 0 else nn.Identity()
  278. def forward_intermediates(
  279. self,
  280. x: torch.Tensor,
  281. indices: Optional[Union[int, List[int]]] = None,
  282. norm: bool = False,
  283. stop_early: bool = False,
  284. output_fmt: str = 'NCHW',
  285. intermediates_only: bool = False,
  286. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  287. """ Forward features that returns intermediates.
  288. Args:
  289. x: Input image tensor
  290. indices: Take last n blocks if int, all if None, select matching indices if sequence
  291. norm: Apply norm layer to compatible intermediates
  292. stop_early: Stop iterating over blocks when last desired intermediate hit
  293. output_fmt: Shape of intermediate feature outputs
  294. intermediates_only: Only return intermediate features
  295. Returns:
  296. """
  297. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  298. intermediates = []
  299. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  300. # forward pass
  301. x = self.patch_embed(x)
  302. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  303. stages = self.stages
  304. else:
  305. stages = self.stages[:max_index + 1]
  306. for feat_idx, stage in enumerate(stages):
  307. x = stage(x)
  308. if feat_idx in take_indices:
  309. intermediates.append(x)
  310. if intermediates_only:
  311. return intermediates
  312. return x, intermediates
  313. def prune_intermediate_layers(
  314. self,
  315. indices: Union[int, List[int]] = 1,
  316. prune_norm: bool = False,
  317. prune_head: bool = True,
  318. ):
  319. """ Prune layers not required for specified intermediates.
  320. """
  321. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  322. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  323. if prune_head:
  324. self.reset_classifier(0, '')
  325. return take_indices
  326. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  327. x = self.patch_embed(x)
  328. x = self.stages(x)
  329. return x
  330. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  331. x = self.global_pool(x)
  332. x = self.conv_head(x)
  333. x = self.act(x)
  334. x = self.flatten(x)
  335. if self.drop_rate > 0.:
  336. x = F.dropout(x, p=self.drop_rate, training=self.training)
  337. return x if pre_logits else self.classifier(x)
  338. def forward(self, x: torch.Tensor) -> torch.Tensor:
  339. x = self.forward_features(x)
  340. x = self.forward_head(x)
  341. return x
  342. def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
  343. # if 'avgpool_pre_head' in state_dict:
  344. # return state_dict
  345. #
  346. # out_dict = {
  347. # 'conv_head.weight': state_dict.pop('avgpool_pre_head.1.weight'),
  348. # 'classifier.weight': state_dict.pop('head.weight'),
  349. # 'classifier.bias': state_dict.pop('head.bias')
  350. # }
  351. #
  352. # stage_mapping = {
  353. # 'stages.1.': 'stages.1.downsample.',
  354. # 'stages.2.': 'stages.1.',
  355. # 'stages.3.': 'stages.2.downsample.',
  356. # 'stages.4.': 'stages.2.',
  357. # 'stages.5.': 'stages.3.downsample.',
  358. # 'stages.6.': 'stages.3.'
  359. # }
  360. #
  361. # for k, v in state_dict.items():
  362. # for old_prefix, new_prefix in stage_mapping.items():
  363. # if k.startswith(old_prefix):
  364. # k = k.replace(old_prefix, new_prefix)
  365. # break
  366. # out_dict[k] = v
  367. return state_dict
  368. def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
  369. return {
  370. 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  371. 'crop_pct': 1.0, 'interpolation': 'bicubic', 'test_crop_pct': 0.9,
  372. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  373. 'first_conv': 'patch_embed.proj', 'classifier': 'classifier',
  374. 'paper_ids': 'arXiv:2303.03667',
  375. 'paper_name': "Run, Don't Walk: Chasing Higher FLOPS for Faster Neural Networks",
  376. 'origin_url': 'https://github.com/JierunChen/FasterNet',
  377. 'license': 'apache-2.0',
  378. **kwargs
  379. }
  380. default_cfgs = generate_default_cfgs({
  381. 'fasternet_t0.in1k': _cfg(
  382. hf_hub_id='timm/',
  383. #url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_t0-epoch.281-val_acc1.71.9180.pth',
  384. ),
  385. 'fasternet_t1.in1k': _cfg(
  386. hf_hub_id='timm/',
  387. #url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_t1-epoch.291-val_acc1.76.2180.pth',
  388. ),
  389. 'fasternet_t2.in1k': _cfg(
  390. hf_hub_id='timm/',
  391. #url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_t2-epoch.289-val_acc1.78.8860.pth',
  392. ),
  393. 'fasternet_s.in1k': _cfg(
  394. hf_hub_id='timm/',
  395. #url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_s-epoch.299-val_acc1.81.2840.pth',
  396. ),
  397. 'fasternet_m.in1k': _cfg(
  398. hf_hub_id='timm/',
  399. #url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_m-epoch.291-val_acc1.82.9620.pth',
  400. ),
  401. 'fasternet_l.in1k': _cfg(
  402. hf_hub_id='timm/',
  403. #url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_l-epoch.299-val_acc1.83.5060.pth',
  404. ),
  405. })
  406. def _create_fasternet(variant: str, pretrained: bool = False, **kwargs: Any) -> FasterNet:
  407. model = build_model_with_cfg(
  408. FasterNet, variant, pretrained,
  409. pretrained_filter_fn=checkpoint_filter_fn,
  410. feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
  411. **kwargs,
  412. )
  413. return model
  414. @register_model
  415. def fasternet_t0(pretrained: bool = False, **kwargs: Any) -> FasterNet:
  416. model_args = dict(embed_dim=40, depths=(1, 2, 8, 2), drop_path_rate=0.0, act_layer=nn.GELU)
  417. return _create_fasternet('fasternet_t0', pretrained=pretrained, **dict(model_args, **kwargs))
  418. @register_model
  419. def fasternet_t1(pretrained: bool = False, **kwargs: Any) -> FasterNet:
  420. model_args = dict(embed_dim=64, depths=(1, 2, 8, 2), drop_path_rate=0.02, act_layer=nn.GELU)
  421. return _create_fasternet('fasternet_t1', pretrained=pretrained, **dict(model_args, **kwargs))
  422. @register_model
  423. def fasternet_t2(pretrained: bool = False, **kwargs: Any) -> FasterNet:
  424. model_args = dict(embed_dim=96, depths=(1, 2, 8, 2), drop_path_rate=0.05)
  425. return _create_fasternet('fasternet_t2', pretrained=pretrained, **dict(model_args, **kwargs))
  426. @register_model
  427. def fasternet_s(pretrained: bool = False, **kwargs: Any) -> FasterNet:
  428. model_args = dict(embed_dim=128, depths=(1, 2, 13, 2), drop_path_rate=0.1)
  429. return _create_fasternet('fasternet_s', pretrained=pretrained, **dict(model_args, **kwargs))
  430. @register_model
  431. def fasternet_m(pretrained: bool = False, **kwargs: Any) -> FasterNet:
  432. model_args = dict(embed_dim=144, depths=(3, 4, 18, 3), drop_path_rate=0.2)
  433. return _create_fasternet('fasternet_m', pretrained=pretrained, **dict(model_args, **kwargs))
  434. @register_model
  435. def fasternet_l(pretrained: bool = False, **kwargs: Any) -> FasterNet:
  436. model_args = dict(embed_dim=192, depths=(3, 4, 18, 3), drop_path_rate=0.3)
  437. return _create_fasternet('fasternet_l', pretrained=pretrained, **dict(model_args, **kwargs))