fastvit.py 61 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807
  1. # FastViT for PyTorch
  2. #
  3. # Original implementation and weights from https://github.com/apple/ml-fastvit
  4. #
  5. # For licensing see accompanying LICENSE file at https://github.com/apple/ml-fastvit/tree/main
  6. # Original work is copyright (C) 2023 Apple Inc. All Rights Reserved.
  7. #
  8. import os
  9. from functools import partial
  10. from typing import List, Optional, Tuple, Type, Union
  11. import torch
  12. import torch.nn as nn
  13. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
  14. from timm.layers import (
  15. DropPath,
  16. calculate_drop_path_rates,
  17. trunc_normal_,
  18. create_conv2d,
  19. ConvNormAct,
  20. SqueezeExcite,
  21. use_fused_attn,
  22. ClassifierHead,
  23. LayerNorm2d,
  24. )
  25. from ._builder import build_model_with_cfg
  26. from ._features import feature_take_indices
  27. from ._manipulate import checkpoint_seq
  28. from ._registry import register_model, generate_default_cfgs
  29. __all__ = ['FastVit']
  def num_groups(group_size, channels):
      """Translate a per-group channel count into a conv ``groups`` value.

      Args:
          group_size: Number of channels per group. A falsy value (``0`` or
              ``None``) selects a normal convolution with a single group;
              ``1`` results in a depthwise convolution.
          channels: Total number of channels to split into groups.

      Returns:
          The number of groups to pass to the convolution.
      """
      if group_size:
          # NOTE group_size == 1 -> depthwise conv
          assert channels % group_size == 0
          return channels // group_size
      # 0 or None -> normal conv with 1 group
      return 1
  class MobileOneBlock(nn.Module):
      """MobileOne building block.

      This block has a multi-branched architecture at train-time
      and plain-CNN style architecture at inference time.
      For more details, please refer to the paper:
      `An Improved One millisecond Mobile Backbone` -
      https://arxiv.org/pdf/2206.04040.pdf
      """

      def __init__(
              self,
              in_chs: int,
              out_chs: int,
              kernel_size: int,
              stride: int = 1,
              dilation: int = 1,
              group_size: int = 0,
              inference_mode: bool = False,
              use_se: bool = False,
              use_act: bool = True,
              use_scale_branch: bool = True,
              num_conv_branches: int = 1,
              act_layer: Type[nn.Module] = nn.GELU,
              device=None,
              dtype=None,
      ) -> None:
          """Construct a MobileOneBlock module.

          Args:
              in_chs: Number of channels in the input.
              out_chs: Number of channels produced by the block.
              kernel_size: Size of the convolution kernel.
              stride: Stride size.
              dilation: Kernel dilation factor.
              group_size: Convolution group size (channels per group), converted
                  to a ``groups`` count by ``num_groups``.
              inference_mode: If True, instantiates model in inference mode
                  (a single fused conv, no train-time branches).
              use_se: Whether to apply a Squeeze-and-Excitation layer to the
                  block output before the activation.
              use_act: Whether to use activation. Default: ``True``
              use_scale_branch: Whether to use scale branch. Default: ``True``
              num_conv_branches: Number of linear conv branches.
              act_layer: Activation module class used when ``use_act`` is True.
              device: Device passed to the created sub-modules.
              dtype: Dtype passed to the created sub-modules.
          """
          dd = {'device': device, 'dtype': dtype}
          super().__init__()
          self.inference_mode = inference_mode
          self.groups = num_groups(group_size, in_chs)
          self.stride = stride
          self.dilation = dilation
          self.kernel_size = kernel_size
          self.in_chs = in_chs
          self.out_chs = out_chs
          self.num_conv_branches = num_conv_branches

          # Optional squeeze-excite applied on the block output
          self.se = SqueezeExcite(out_chs, rd_divisor=1, **dd) if use_se else nn.Identity()

          if inference_mode:
              self.reparam_conv = create_conv2d(
                  in_chs,
                  out_chs,
                  kernel_size=kernel_size,
                  stride=stride,
                  dilation=dilation,
                  groups=self.groups,
                  bias=True,
                  **dd,
              )
          else:
              # Re-parameterizable skip connection
              self.reparam_conv = None
              self.identity = (
                  nn.BatchNorm2d(num_features=in_chs, **dd)
                  if out_chs == in_chs and stride == 1
                  else None
              )

              # Re-parameterizable conv branches
              if num_conv_branches > 0:
                  self.conv_kxk = nn.ModuleList([
                      ConvNormAct(
                          self.in_chs,
                          self.out_chs,
                          kernel_size=kernel_size,
                          stride=self.stride,
                          # Must match the dilation of the fused conv built in
                          # `reparameterize()`, otherwise the train-time and
                          # inference-time blocks compute different functions.
                          dilation=self.dilation,
                          groups=self.groups,
                          apply_act=False,
                          **dd,
                      ) for _ in range(self.num_conv_branches)
                  ])
              else:
                  self.conv_kxk = None

              # Re-parameterizable scale branch (1x1 conv, unaffected by dilation)
              self.conv_scale = None
              if kernel_size > 1 and use_scale_branch:
                  self.conv_scale = ConvNormAct(
                      self.in_chs,
                      self.out_chs,
                      kernel_size=1,
                      stride=self.stride,
                      groups=self.groups,
                      apply_act=False,
                      **dd,
                  )

          self.act = act_layer() if use_act else nn.Identity()

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          """Apply forward pass."""
          # Inference mode forward pass.
          if self.reparam_conv is not None:
              return self.act(self.se(self.reparam_conv(x)))

          # Multi-branched train-time forward pass.
          # Identity branch output
          identity_out = 0
          if self.identity is not None:
              identity_out = self.identity(x)

          # Scale branch output
          scale_out = 0
          if self.conv_scale is not None:
              scale_out = self.conv_scale(x)

          # Other kxk conv branches
          out = scale_out + identity_out
          if self.conv_kxk is not None:
              for rc in self.conv_kxk:
                  out += rc(x)

          return self.act(self.se(out))

      def reparameterize(self):
          """Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
          https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
          architecture used at training time to obtain a plain CNN-like structure
          for inference.

          Calling this on a block that already holds a fused conv is a no-op.
          """
          if self.reparam_conv is not None:
              return

          kernel, bias = self._get_kernel_bias()
          self.reparam_conv = create_conv2d(
              in_channels=self.in_chs,
              out_channels=self.out_chs,
              kernel_size=self.kernel_size,
              stride=self.stride,
              dilation=self.dilation,
              groups=self.groups,
              bias=True,
          )
          self.reparam_conv.weight.data = kernel
          self.reparam_conv.bias.data = bias

          # Delete un-used branches
          for name, para in self.named_parameters():
              if 'reparam_conv' in name:
                  continue
              para.detach_()

          self.__delattr__("conv_kxk")
          self.__delattr__("conv_scale")
          if hasattr(self, "identity"):
              self.__delattr__("identity")

          self.inference_mode = True

      def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
          """Method to obtain re-parameterized kernel and bias.
          Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83

          Returns:
              Tuple of (kernel, bias) after fusing branches.
          """
          # get weights and bias of scale branch
          kernel_scale = 0
          bias_scale = 0
          if self.conv_scale is not None:
              kernel_scale, bias_scale = self._fuse_bn_tensor(self.conv_scale)
              # Pad scale branch kernel to match conv branch kernel size.
              pad = self.kernel_size // 2
              kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])

          # get weights and bias of skip branch
          kernel_identity = 0
          bias_identity = 0
          if self.identity is not None:
              kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)

          # get weights and bias of conv branches
          kernel_conv = 0
          bias_conv = 0
          if self.conv_kxk is not None:
              for ix in range(self.num_conv_branches):
                  _kernel, _bias = self._fuse_bn_tensor(self.conv_kxk[ix])
                  kernel_conv += _kernel
                  bias_conv += _bias

          kernel_final = kernel_conv + kernel_scale + kernel_identity
          bias_final = bias_conv + bias_scale + bias_identity
          return kernel_final, bias_final

      def _fuse_bn_tensor(
              self,
              branch: Union[ConvNormAct, nn.BatchNorm2d]
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Method to fuse batchnorm layer with preceding conv layer.
          Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95

          Args:
              branch: Either a conv + batchnorm branch (``ConvNormAct``) or the
                  batchnorm-only identity branch.

          Returns:
              Tuple of (kernel, bias) after fusing batchnorm.
          """
          if isinstance(branch, ConvNormAct):
              kernel = branch.conv.weight
              running_mean = branch.bn.running_mean
              running_var = branch.bn.running_var
              gamma = branch.bn.weight
              beta = branch.bn.bias
              eps = branch.bn.eps
          else:
              assert isinstance(branch, nn.BatchNorm2d)
              if not hasattr(self, "id_tensor"):
                  # Build (once) a kernel that acts as the identity mapping:
                  # a single 1 at the spatial centre of each output channel.
                  input_dim = self.in_chs // self.groups
                  kernel_value = torch.zeros(
                      (self.in_chs, input_dim, self.kernel_size, self.kernel_size),
                      dtype=branch.weight.dtype,
                      device=branch.weight.device,
                  )
                  for i in range(self.in_chs):
                      kernel_value[
                          i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2
                      ] = 1
                  self.id_tensor = kernel_value
              kernel = self.id_tensor
              running_mean = branch.running_mean
              running_var = branch.running_var
              gamma = branch.weight
              beta = branch.bias
              eps = branch.eps

          # Fold the batchnorm affine transform into the kernel and a bias term
          std = (running_var + eps).sqrt()
          t = (gamma / std).reshape(-1, 1, 1, 1)
          return kernel * t, beta - running_mean * gamma / std
  class ReparamLargeKernelConv(nn.Module):
      """Building Block of RepLKNet

      This class defines overparameterized large kernel conv block
      introduced in `RepLKNet <https://arxiv.org/abs/2203.06717>`_

      Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
      """

      def __init__(
              self,
              in_chs: int,
              out_chs: int,
              kernel_size: int,
              stride: int,
              group_size: int,
              small_kernel: Optional[int] = None,
              use_se: bool = False,
              act_layer: Optional[Type[nn.Module]] = None,
              inference_mode: bool = False,
              device=None,
              dtype=None,
      ) -> None:
          """Construct a ReparamLargeKernelConv module.

          Args:
              in_chs: Number of input channels.
              out_chs: Number of output channels.
              kernel_size: Kernel size of the large kernel conv branch.
              stride: Stride size.
              group_size: Group size (channels per group), converted to a
                  ``groups`` count by ``num_groups``.
              small_kernel: Kernel size of small kernel conv branch. ``None``
                  disables the small kernel branch.
              use_se: Whether to apply a Squeeze-and-Excitation layer on the output.
              act_layer: Activation module class. ``None`` means no activation.
              inference_mode: If True, instantiates model in inference mode. Default: ``False``
              device: Device passed to the created sub-modules.
              dtype: Dtype passed to the created sub-modules.
          """
          dd = {'device': device, 'dtype': dtype}
          super().__init__()
          self.stride = stride
          self.groups = num_groups(group_size, in_chs)
          self.in_chs = in_chs
          self.out_chs = out_chs

          self.kernel_size = kernel_size
          self.small_kernel = small_kernel
          if inference_mode:
              self.reparam_conv = create_conv2d(
                  in_chs,
                  out_chs,
                  kernel_size=kernel_size,
                  stride=stride,
                  dilation=1,
                  groups=self.groups,
                  bias=True,
                  **dd,
              )
          else:
              self.reparam_conv = None
              self.large_conv = ConvNormAct(
                  in_chs,
                  out_chs,
                  kernel_size=kernel_size,
                  stride=self.stride,
                  groups=self.groups,
                  apply_act=False,
                  **dd,
              )
              if small_kernel is not None:
                  assert (
                      small_kernel <= kernel_size
                  ), "The kernel size for re-param cannot be larger than the large kernel!"
                  self.small_conv = ConvNormAct(
                      in_chs,
                      out_chs,
                      kernel_size=small_kernel,
                      stride=self.stride,
                      groups=self.groups,
                      apply_act=False,
                      **dd,
                  )
              else:
                  # Always define the attribute so forward() can test it safely
                  # when the block is built without a small kernel branch.
                  self.small_conv = None

          self.se = SqueezeExcite(out_chs, rd_ratio=0.25, **dd) if use_se else nn.Identity()
          # FIXME output of this act was not used in original impl, likely due to bug
          self.act = act_layer() if act_layer is not None else nn.Identity()

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          """Apply the fused conv (inference) or the sum of branches (training)."""
          if self.reparam_conv is not None:
              out = self.reparam_conv(x)
          else:
              out = self.large_conv(x)
              if self.small_conv is not None:
                  out = out + self.small_conv(x)
          out = self.se(out)
          out = self.act(out)
          return out

      def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
          """Method to obtain re-parameterized kernel and bias.
          Reference: https://github.com/DingXiaoH/RepLKNet-pytorch

          Returns:
              Tuple of (kernel, bias) after fusing branches.
          """
          eq_k, eq_b = self._fuse_bn(self.large_conv.conv, self.large_conv.bn)
          small_conv = getattr(self, "small_conv", None)
          if small_conv is not None:
              small_k, small_b = self._fuse_bn(small_conv.conv, small_conv.bn)
              eq_b += small_b
              # Zero-pad the small kernel so it is centred inside the large one
              eq_k += nn.functional.pad(
                  small_k, [(self.kernel_size - self.small_kernel) // 2] * 4
              )
          return eq_k, eq_b

      def reparameterize(self) -> None:
          """
          Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
          https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
          architecture used at training time to obtain a plain CNN-like structure
          for inference.

          Calling this on a block that already holds a fused conv is a no-op.
          """
          if self.reparam_conv is not None:
              return

          eq_k, eq_b = self.get_kernel_bias()
          self.reparam_conv = create_conv2d(
              self.in_chs,
              self.out_chs,
              kernel_size=self.kernel_size,
              stride=self.stride,
              groups=self.groups,
              bias=True,
          )

          self.reparam_conv.weight.data = eq_k
          self.reparam_conv.bias.data = eq_b
          self.__delattr__("large_conv")
          if hasattr(self, "small_conv"):
              self.__delattr__("small_conv")

      @staticmethod
      def _fuse_bn(
              conv: nn.Conv2d,
              bn: nn.BatchNorm2d
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Method to fuse batchnorm layer with conv layer.

          Args:
              conv: Convolution layer whose weight is fused.
              bn: Batchnorm 2d layer.

          Returns:
              Tuple of (kernel, bias) after fusing batchnorm.
          """
          kernel = conv.weight
          running_mean = bn.running_mean
          running_var = bn.running_var
          gamma = bn.weight
          beta = bn.bias
          eps = bn.eps
          std = (running_var + eps).sqrt()
          t = (gamma / std).reshape(-1, 1, 1, 1)
          return kernel * t, beta - running_mean * gamma / std
  def convolutional_stem(
          in_chs: int,
          out_chs: int,
          act_layer: Type[nn.Module] = nn.GELU,
          inference_mode: bool = False,
          use_scale_branch: bool = True,
          device=None,
          dtype=None,
  ) -> nn.Sequential:
      """Assemble the three-block MobileOne convolutional stem.

      The stem is a stride-2 3x3 block, a stride-2 3x3 block with
      ``group_size=1``, and a stride-1 1x1 block.

      Args:
          in_chs: Number of input channels.
          out_chs: Number of output channels.
          act_layer: Activation module class handed to every block.
          inference_mode: Flag to instantiate model in inference mode. Default: ``False``
          use_scale_branch: Forwarded to every block. Default: ``True``
          device: Device forwarded to every block.
          dtype: Dtype forwarded to every block.

      Returns:
          nn.Sequential object with stem elements.
      """
      # Keyword arguments shared by all three stem blocks.
      shared = {
          'act_layer': act_layer,
          'inference_mode': inference_mode,
          'use_scale_branch': use_scale_branch,
          'device': device,
          'dtype': dtype,
      }
      downsample = MobileOneBlock(
          in_chs=in_chs,
          out_chs=out_chs,
          kernel_size=3,
          stride=2,
          **shared,
      )
      downsample_grouped = MobileOneBlock(
          in_chs=out_chs,
          out_chs=out_chs,
          kernel_size=3,
          stride=2,
          group_size=1,
          **shared,
      )
      pointwise = MobileOneBlock(
          in_chs=out_chs,
          out_chs=out_chs,
          kernel_size=1,
          stride=1,
          **shared,
      )
      return nn.Sequential(downsample, downsample_grouped, pointwise)
  class Attention(nn.Module):
      """Multi-headed Self Attention module.

      Source modified from:
      https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
      """
      fused_attn: torch.jit.Final[bool]

      def __init__(
              self,
              dim: int,
              head_dim: int = 32,
              qkv_bias: bool = False,
              attn_drop: float = 0.0,
              proj_drop: float = 0.0,
              device=None,
              dtype=None,
      ) -> None:
          """Build the MHSA module.

          Args:
              dim: Number of embedding dimensions.
              head_dim: Number of hidden dimensions per head. Default: ``32``
              qkv_bias: Use bias or not. Default: ``False``
              attn_drop: Dropout rate for attention tensor.
              proj_drop: Dropout rate for projection tensor.
              device: Device for the linear layers.
              dtype: Dtype for the linear layers.
          """
          factory_kwargs = {'device': device, 'dtype': dtype}
          super().__init__()
          assert dim % head_dim == 0, "dim should be divisible by head_dim"
          self.head_dim = head_dim
          self.num_heads = dim // head_dim
          self.scale = head_dim ** -0.5
          self.fused_attn = use_fused_attn()

          self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **factory_kwargs)
          self.attn_drop = nn.Dropout(attn_drop)
          self.proj = nn.Linear(dim, dim, **factory_kwargs)
          self.proj_drop = nn.Dropout(proj_drop)

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          """Run self-attention over the spatial positions of a ``(B, C, H, W)`` input.

          The shape is unpacked into four dimensions, so the input must be 4D;
          the output has the same ``(B, C, H, W)`` shape.
          """
          batch, chans, height, width = x.shape
          num_tokens = height * width

          # (B, C, H, W) -> (B, N, C)
          tokens = x.flatten(2).transpose(-2, -1)

          # (B, N, 3 * C) -> (3, B, heads, N, head_dim)
          qkv = self.qkv(tokens)
          qkv = qkv.reshape(batch, num_tokens, 3, self.num_heads, self.head_dim)
          qkv = qkv.permute(2, 0, 3, 1, 4)
          q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

          if self.fused_attn:
              out = torch.nn.functional.scaled_dot_product_attention(
                  q, k, v,
                  dropout_p=self.attn_drop.p if self.training else 0.,
              )
          else:
              scores = (q * self.scale) @ k.transpose(-2, -1)
              weights = self.attn_drop(scores.softmax(dim=-1))
              out = weights @ v

          # Merge heads back: (B, heads, N, head_dim) -> (B, N, C)
          out = out.transpose(1, 2).reshape(batch, num_tokens, chans)
          out = self.proj_drop(self.proj(out))

          # (B, N, C) -> (B, C, H, W)
          return out.transpose(-2, -1).reshape(batch, chans, height, width)
  class PatchEmbed(nn.Module):
      """Convolutional patch embedding layer."""

      def __init__(
              self,
              patch_size: int,
              stride: int,
              in_chs: int,
              embed_dim: int,
              act_layer: Type[nn.Module] = nn.GELU,
              lkc_use_act: bool = False,
              use_se: bool = False,
              inference_mode: bool = False,
              device=None,
              dtype=None,
      ) -> None:
          """Build patch embedding layer.

          Args:
              patch_size: Patch size for embedding computation; used as the
                  kernel size of the large kernel conv.
              stride: Stride for convolutional embedding layer.
              in_chs: Number of channels of input tensor.
              embed_dim: Number of embedding dimensions.
              act_layer: Activation module class.
              lkc_use_act: If True, the large kernel conv also gets ``act_layer``.
              use_se: Whether the large kernel conv uses Squeeze-and-Excitation.
              inference_mode: Flag to instantiate model in inference mode. Default: ``False``
              device: Device forwarded to the sub-modules.
              dtype: Dtype forwarded to the sub-modules.
          """
          factory_kwargs = {'device': device, 'dtype': dtype}
          super().__init__()

          # NOTE original weights didn't use this act
          lkc_act = act_layer if lkc_use_act else None

          large_kernel_conv = ReparamLargeKernelConv(
              in_chs=in_chs,
              out_chs=embed_dim,
              kernel_size=patch_size,
              stride=stride,
              group_size=1,
              small_kernel=3,
              use_se=use_se,
              act_layer=lkc_act,
              inference_mode=inference_mode,
              **factory_kwargs,
          )
          pointwise = MobileOneBlock(
              in_chs=embed_dim,
              out_chs=embed_dim,
              kernel_size=1,
              stride=1,
              use_se=False,
              act_layer=act_layer,
              inference_mode=inference_mode,
              **factory_kwargs,
          )
          self.proj = nn.Sequential(large_kernel_conv, pointwise)

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          """Apply the large kernel conv followed by the 1x1 MobileOne block."""
          return self.proj(x)
  class LayerScale2d(nn.Module):
      """Learnable per-channel scaling with a ``(dim, 1, 1)`` shaped parameter."""

      def __init__(
              self,
              dim: int,
              init_values: float = 1e-5,
              inplace: bool = False,
              device=None,
              dtype=None,
      ):
          """Create the scale parameter.

          Args:
              dim: Number of entries in the scale parameter.
              init_values: Initial value of every scale entry.
              inplace: If True, ``forward`` multiplies its input in place.
              device: Device of the scale parameter.
              dtype: Dtype of the scale parameter.
          """
          super().__init__()
          self.inplace = inplace
          ones = torch.ones(dim, 1, 1, device=device, dtype=dtype)
          self.gamma = nn.Parameter(init_values * ones)

      def forward(self, x):
          """Multiply ``x`` by the scale parameter (in place when configured)."""
          if self.inplace:
              return x.mul_(self.gamma)
          return x * self.gamma
  class RepMixer(nn.Module):
      """Reparameterizable token mixer.

      For more details, please refer to the paper:
      `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
      """

      def __init__(
              self,
              dim: int,
              kernel_size: int = 3,
              layer_scale_init_value: Optional[float] = 1e-5,
              inference_mode: bool = False,
              device=None,
              dtype=None,
      ):
          """Build RepMixer Module.

          Args:
              dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
              kernel_size: Kernel size for spatial mixing. Default: 3
              layer_scale_init_value: Initial value for layer scale; ``None``
                  disables layer scale. Default: 1e-5
              inference_mode: If True, instantiates model in inference mode. Default: ``False``
              device: Device forwarded to the sub-modules.
              dtype: Dtype forwarded to the sub-modules.
          """
          dd = {'device': device, 'dtype': dtype}
          super().__init__()
          self.dim = dim
          self.kernel_size = kernel_size
          self.inference_mode = inference_mode

          if inference_mode:
              self.reparam_conv = nn.Conv2d(
                  self.dim,
                  self.dim,
                  kernel_size=self.kernel_size,
                  stride=1,
                  padding=self.kernel_size // 2,
                  groups=self.dim,
                  bias=True,
                  **dd,
              )
          else:
              self.reparam_conv = None
              # `norm` has no conv or scale branches: only its batchnorm identity branch
              self.norm = MobileOneBlock(
                  dim,
                  dim,
                  kernel_size,
                  group_size=1,
                  use_act=False,
                  use_scale_branch=False,
                  num_conv_branches=0,
                  **dd,
              )
              self.mixer = MobileOneBlock(
                  dim,
                  dim,
                  kernel_size,
                  group_size=1,
                  use_act=False,
                  **dd,
              )
              if layer_scale_init_value is not None:
                  self.layer_scale = LayerScale2d(dim, layer_scale_init_value, **dd)
              else:
                  self.layer_scale = nn.Identity()

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          """Apply the fused conv (inference) or the residual mixer (training)."""
          if self.reparam_conv is not None:
              x = self.reparam_conv(x)
          else:
              x = x + self.layer_scale(self.mixer(x) - self.norm(x))
          return x

      def reparameterize(self) -> None:
          """Reparameterize mixer and norm into a single
          convolutional layer for efficient inference.

          Calling this on a module that is already in inference form (built
          with ``inference_mode=True`` or previously reparameterized) is a no-op.
          """
          if self.inference_mode or self.reparam_conv is not None:
              return

          self.mixer.reparameterize()
          self.norm.reparameterize()

          # Fold `x + layer_scale(mixer(x) - norm(x))` into one kernel and bias;
          # `id_tensor` is the identity kernel accounting for the `x +` skip.
          if isinstance(self.layer_scale, LayerScale2d):
              w = self.mixer.id_tensor + self.layer_scale.gamma.unsqueeze(-1) * (
                  self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
              )
              b = torch.squeeze(self.layer_scale.gamma) * (
                  self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
              )
          else:
              w = (
                  self.mixer.id_tensor
                  + self.mixer.reparam_conv.weight
                  - self.norm.reparam_conv.weight
              )
              b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias

          self.reparam_conv = create_conv2d(
              self.dim,
              self.dim,
              kernel_size=self.kernel_size,
              stride=1,
              groups=self.dim,
              bias=True,
          )
          self.reparam_conv.weight.data = w
          self.reparam_conv.bias.data = b

          for name, para in self.named_parameters():
              if 'reparam_conv' in name:
                  continue
              para.detach_()

          self.__delattr__("mixer")
          self.__delattr__("norm")
          self.__delattr__("layer_scale")

          # Record the new state so a repeated call returns early instead of
          # touching the branches deleted above.
          self.inference_mode = True
  class ConvMlp(nn.Module):
      """Convolutional FFN Module."""

      def __init__(
              self,
              in_chs: int,
              hidden_channels: Optional[int] = None,
              out_chs: Optional[int] = None,
              act_layer: Type[nn.Module] = nn.GELU,
              drop: float = 0.0,
              device=None,
              dtype=None,
      ) -> None:
          """Build convolutional FFN module.

          Args:
              in_chs: Number of input channels.
              hidden_channels: Number of channels after expansion. Defaults to
                  ``in_chs`` when ``None``.
              out_chs: Number of output channels. Defaults to ``in_chs`` when ``None``.
              act_layer: Activation layer. Default: ``GELU``
              drop: Dropout rate. Default: ``0.0``.
              device: Device forwarded to the sub-modules.
              dtype: Dtype forwarded to the sub-modules.
          """
          dd = {'device': device, 'dtype': dtype}
          super().__init__()
          out_chs = out_chs or in_chs
          hidden_channels = hidden_channels or in_chs
          # The depthwise conv keeps the channel count at `in_chs` because `fc1`
          # consumes `in_chs` channels; the change to `out_chs` happens in `fc2`.
          # Producing `out_chs` here made the module fail whenever
          # `out_chs != in_chs`. Identical to before when `out_chs == in_chs`.
          self.conv = ConvNormAct(
              in_chs,
              in_chs,
              kernel_size=7,
              groups=in_chs,
              apply_act=False,
              **dd,
          )
          self.fc1 = nn.Conv2d(in_chs, hidden_channels, kernel_size=1, **dd)
          self.act = act_layer()
          self.fc2 = nn.Conv2d(hidden_channels, out_chs, kernel_size=1, **dd)
          self.drop = nn.Dropout(drop)
          self.apply(self._init_weights)

      def _init_weights(self, m: nn.Module) -> None:
          """Truncated-normal init for conv weights, zeros for conv biases."""
          if isinstance(m, nn.Conv2d):
              trunc_normal_(m.weight, std=0.02)
              if m.bias is not None:
                  nn.init.constant_(m.bias, 0)

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          """Apply depthwise conv + norm, then the two-layer 1x1 conv MLP."""
          x = self.conv(x)
          x = self.fc1(x)
          x = self.act(x)
          x = self.drop(x)
          x = self.fc2(x)
          x = self.drop(x)
          return x
  733. class RepConditionalPosEnc(nn.Module):
  734. """Implementation of conditional positional encoding.
  735. For more details refer to paper:
  736. `Conditional Positional Encodings for Vision Transformers <https://arxiv.org/pdf/2102.10882.pdf>`_
  737. In our implementation, we can reparameterize this module to eliminate a skip connection.
  738. """
  739. def __init__(
  740. self,
  741. dim: int,
  742. dim_out: Optional[int] = None,
  743. spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
  744. inference_mode: bool = False,
  745. device=None,
  746. dtype=None,
  747. ) -> None:
  748. """Build reparameterizable conditional positional encoding
  749. Args:
  750. dim: Number of input channels.
  751. dim_out: Number of embedding dimensions. Default: 768
  752. spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
  753. inference_mode: Flag to instantiate block in inference mode. Default: ``False``
  754. """
  755. dd = {'device': device, 'dtype': dtype}
  756. super().__init__()
  757. if isinstance(spatial_shape, int):
  758. spatial_shape = tuple([spatial_shape] * 2)
  759. assert isinstance(spatial_shape, Tuple), (
  760. f'"spatial_shape" must by a sequence or int, '
  761. f"get {type(spatial_shape)} instead."
  762. )
  763. assert len(spatial_shape) == 2, (
  764. f'Length of "spatial_shape" should be 2, '
  765. f"got {len(spatial_shape)} instead."
  766. )
  767. self.spatial_shape = spatial_shape
  768. self.dim = dim
  769. self.dim_out = dim_out or dim
  770. self.groups = dim
  771. if inference_mode:
  772. self.reparam_conv = nn.Conv2d(
  773. self.dim,
  774. self.dim_out,
  775. kernel_size=self.spatial_shape,
  776. stride=1,
  777. padding=spatial_shape[0] // 2,
  778. groups=self.groups,
  779. bias=True,
  780. **dd,
  781. )
  782. else:
  783. self.reparam_conv = None
  784. self.pos_enc = nn.Conv2d(
  785. self.dim,
  786. self.dim_out,
  787. spatial_shape,
  788. 1,
  789. int(spatial_shape[0] // 2),
  790. groups=self.groups,
  791. bias=True,
  792. **dd,
  793. )
  794. def forward(self, x: torch.Tensor) -> torch.Tensor:
  795. if self.reparam_conv is not None:
  796. x = self.reparam_conv(x)
  797. else:
  798. x = self.pos_enc(x) + x
  799. return x
  800. def reparameterize(self) -> None:
  801. # Build equivalent Id tensor
  802. input_dim = self.dim // self.groups
  803. kernel_value = torch.zeros(
  804. (
  805. self.dim,
  806. input_dim,
  807. self.spatial_shape[0],
  808. self.spatial_shape[1],
  809. ),
  810. dtype=self.pos_enc.weight.dtype,
  811. device=self.pos_enc.weight.device,
  812. )
  813. for i in range(self.dim):
  814. kernel_value[
  815. i,
  816. i % input_dim,
  817. self.spatial_shape[0] // 2,
  818. self.spatial_shape[1] // 2,
  819. ] = 1
  820. id_tensor = kernel_value
  821. # Reparameterize Id tensor and conv
  822. w_final = id_tensor + self.pos_enc.weight
  823. b_final = self.pos_enc.bias
  824. # Introduce reparam conv
  825. self.reparam_conv = nn.Conv2d(
  826. self.dim,
  827. self.dim_out,
  828. kernel_size=self.spatial_shape,
  829. stride=1,
  830. padding=int(self.spatial_shape[0] // 2),
  831. groups=self.groups,
  832. bias=True,
  833. )
  834. self.reparam_conv.weight.data = w_final
  835. self.reparam_conv.bias.data = b_final
  836. for name, para in self.named_parameters():
  837. if 'reparam_conv' in name:
  838. continue
  839. para.detach_()
  840. self.__delattr__("pos_enc")
  841. class RepMixerBlock(nn.Module):
  842. """Implementation of Metaformer block with RepMixer as token mixer.
  843. For more details on Metaformer structure, please refer to:
  844. `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
  845. """
  846. def __init__(
  847. self,
  848. dim: int,
  849. kernel_size: int = 3,
  850. mlp_ratio: float = 4.0,
  851. act_layer: Type[nn.Module] = nn.GELU,
  852. proj_drop: float = 0.0,
  853. drop_path: float = 0.0,
  854. layer_scale_init_value: float = 1e-5,
  855. inference_mode: bool = False,
  856. device=None,
  857. dtype=None,
  858. ):
  859. """Build RepMixer Block.
  860. Args:
  861. dim: Number of embedding dimensions.
  862. kernel_size: Kernel size for repmixer. Default: 3
  863. mlp_ratio: MLP expansion ratio. Default: 4.0
  864. act_layer: Activation layer. Default: ``nn.GELU``
  865. proj_drop: Dropout rate. Default: 0.0
  866. drop_path: Drop path rate. Default: 0.0
  867. layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
  868. inference_mode: Flag to instantiate block in inference mode. Default: ``False``
  869. """
  870. dd = {'device': device, 'dtype': dtype}
  871. super().__init__()
  872. self.token_mixer = RepMixer(
  873. dim,
  874. kernel_size=kernel_size,
  875. layer_scale_init_value=layer_scale_init_value,
  876. inference_mode=inference_mode,
  877. **dd,
  878. )
  879. self.mlp = ConvMlp(
  880. in_chs=dim,
  881. hidden_channels=int(dim * mlp_ratio),
  882. act_layer=act_layer,
  883. drop=proj_drop,
  884. **dd,
  885. )
  886. if layer_scale_init_value is not None:
  887. self.layer_scale = LayerScale2d(dim, layer_scale_init_value, **dd)
  888. else:
  889. self.layer_scale = nn.Identity()
  890. self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  891. def forward(self, x):
  892. x = self.token_mixer(x)
  893. x = x + self.drop_path(self.layer_scale(self.mlp(x)))
  894. return x
  895. class AttentionBlock(nn.Module):
  896. """Implementation of metaformer block with MHSA as token mixer.
  897. For more details on Metaformer structure, please refer to:
  898. `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
  899. """
  900. def __init__(
  901. self,
  902. dim: int,
  903. mlp_ratio: float = 4.0,
  904. act_layer: Type[nn.Module] = nn.GELU,
  905. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  906. proj_drop: float = 0.0,
  907. drop_path: float = 0.0,
  908. layer_scale_init_value: float = 1e-5,
  909. device=None,
  910. dtype=None,
  911. ):
  912. """Build Attention Block.
  913. Args:
  914. dim: Number of embedding dimensions.
  915. mlp_ratio: MLP expansion ratio. Default: 4.0
  916. act_layer: Activation layer. Default: ``nn.GELU``
  917. norm_layer: Normalization layer. Default: ``nn.BatchNorm2d``
  918. proj_drop: Dropout rate. Default: 0.0
  919. drop_path: Drop path rate. Default: 0.0
  920. layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
  921. """
  922. dd = {'device': device, 'dtype': dtype}
  923. super().__init__()
  924. self.norm = norm_layer(dim, **dd)
  925. self.token_mixer = Attention(dim=dim, **dd)
  926. if layer_scale_init_value is not None:
  927. self.layer_scale_1 = LayerScale2d(dim, layer_scale_init_value, **dd)
  928. else:
  929. self.layer_scale_1 = nn.Identity()
  930. self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  931. self.mlp = ConvMlp(
  932. in_chs=dim,
  933. hidden_channels=int(dim * mlp_ratio),
  934. act_layer=act_layer,
  935. drop=proj_drop,
  936. **dd,
  937. )
  938. if layer_scale_init_value is not None:
  939. self.layer_scale_2 = LayerScale2d(dim, layer_scale_init_value, **dd)
  940. else:
  941. self.layer_scale_2 = nn.Identity()
  942. self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  943. def forward(self, x):
  944. x = x + self.drop_path1(self.layer_scale_1(self.token_mixer(self.norm(x))))
  945. x = x + self.drop_path2(self.layer_scale_2(self.mlp(x)))
  946. return x
  947. class FastVitStage(nn.Module):
  948. def __init__(
  949. self,
  950. dim: int,
  951. dim_out: int,
  952. depth: int,
  953. token_mixer_type: str,
  954. downsample: bool = True,
  955. se_downsample: bool = False,
  956. down_patch_size: int = 7,
  957. down_stride: int = 2,
  958. pos_emb_layer: Optional[nn.Module] = None,
  959. kernel_size: int = 3,
  960. mlp_ratio: float = 4.0,
  961. act_layer: Type[nn.Module] = nn.GELU,
  962. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  963. proj_drop_rate: float = 0.0,
  964. drop_path_rate: Union[List[float], float] = 0.0,
  965. layer_scale_init_value: Optional[float] = 1e-5,
  966. lkc_use_act: bool = False,
  967. inference_mode: bool = False,
  968. device=None,
  969. dtype=None,
  970. ):
  971. """FastViT stage.
  972. Args:
  973. dim: Number of embedding dimensions.
  974. depth: Number of blocks in stage
  975. token_mixer_type: Token mixer type.
  976. kernel_size: Kernel size for repmixer.
  977. mlp_ratio: MLP expansion ratio.
  978. act_layer: Activation layer.
  979. norm_layer: Normalization layer.
  980. proj_drop_rate: Dropout rate.
  981. drop_path_rate: Drop path rate.
  982. layer_scale_init_value: Layer scale value at initialization.
  983. inference_mode: Flag to instantiate block in inference mode.
  984. """
  985. super().__init__()
  986. dd = {'device': device, 'dtype': dtype}
  987. self.grad_checkpointing = False
  988. if downsample:
  989. self.downsample = PatchEmbed(
  990. patch_size=down_patch_size,
  991. stride=down_stride,
  992. in_chs=dim,
  993. embed_dim=dim_out,
  994. use_se=se_downsample,
  995. act_layer=act_layer,
  996. lkc_use_act=lkc_use_act,
  997. inference_mode=inference_mode,
  998. **dd,
  999. )
  1000. else:
  1001. assert dim == dim_out
  1002. self.downsample = nn.Identity()
  1003. if pos_emb_layer is not None:
  1004. self.pos_emb = pos_emb_layer(dim_out, inference_mode=inference_mode, **dd)
  1005. else:
  1006. self.pos_emb = nn.Identity()
  1007. blocks = []
  1008. for block_idx in range(depth):
  1009. if token_mixer_type == "repmixer":
  1010. blocks.append(RepMixerBlock(
  1011. dim_out,
  1012. kernel_size=kernel_size,
  1013. mlp_ratio=mlp_ratio,
  1014. act_layer=act_layer,
  1015. proj_drop=proj_drop_rate,
  1016. drop_path=drop_path_rate[block_idx],
  1017. layer_scale_init_value=layer_scale_init_value,
  1018. inference_mode=inference_mode,
  1019. **dd,
  1020. ))
  1021. elif token_mixer_type == "attention":
  1022. blocks.append(AttentionBlock(
  1023. dim_out,
  1024. mlp_ratio=mlp_ratio,
  1025. act_layer=act_layer,
  1026. norm_layer=norm_layer,
  1027. proj_drop=proj_drop_rate,
  1028. drop_path=drop_path_rate[block_idx],
  1029. layer_scale_init_value=layer_scale_init_value,
  1030. **dd,
  1031. ))
  1032. else:
  1033. raise ValueError(
  1034. "Token mixer type: {} not supported".format(token_mixer_type)
  1035. )
  1036. self.blocks = nn.Sequential(*blocks)
  1037. def forward(self, x):
  1038. x = self.downsample(x)
  1039. x = self.pos_emb(x)
  1040. if self.grad_checkpointing and not torch.jit.is_scripting():
  1041. x = checkpoint_seq(self.blocks, x)
  1042. else:
  1043. x = self.blocks(x)
  1044. return x
  1045. class FastVit(nn.Module):
  1046. fork_feat: torch.jit.Final[bool]
  1047. """
  1048. This class implements `FastViT architecture <https://arxiv.org/pdf/2303.14189.pdf>`_
  1049. """
  1050. def __init__(
  1051. self,
  1052. in_chans: int = 3,
  1053. layers: Tuple[int, ...] = (2, 2, 6, 2),
  1054. token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"),
  1055. embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
  1056. mlp_ratios: Tuple[float, ...] = (4,) * 4,
  1057. downsamples: Tuple[bool, ...] = (False, True, True, True),
  1058. se_downsamples: Tuple[bool, ...] = (False, False, False, False),
  1059. repmixer_kernel_size: int = 3,
  1060. num_classes: int = 1000,
  1061. pos_embs: Tuple[Optional[nn.Module], ...] = (None,) * 4,
  1062. down_patch_size: int = 7,
  1063. down_stride: int = 2,
  1064. drop_rate: float = 0.0,
  1065. proj_drop_rate: float = 0.0,
  1066. drop_path_rate: float = 0.0,
  1067. layer_scale_init_value: float = 1e-5,
  1068. lkc_use_act: bool = False,
  1069. stem_use_scale_branch: bool = True,
  1070. fork_feat: bool = False,
  1071. cls_ratio: float = 2.0,
  1072. global_pool: str = 'avg',
  1073. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  1074. act_layer: Type[nn.Module] = nn.GELU,
  1075. inference_mode: bool = False,
  1076. device=None,
  1077. dtype=None,
  1078. ) -> None:
  1079. super().__init__()
  1080. dd = {'device': device, 'dtype': dtype}
  1081. self.num_classes = 0 if fork_feat else num_classes
  1082. self.fork_feat = fork_feat
  1083. self.global_pool = global_pool
  1084. self.feature_info = []
  1085. # Convolutional stem
  1086. self.stem = convolutional_stem(
  1087. in_chans,
  1088. embed_dims[0],
  1089. act_layer,
  1090. inference_mode,
  1091. use_scale_branch=stem_use_scale_branch,
  1092. **dd,
  1093. )
  1094. # Build the main stages of the network architecture
  1095. prev_dim = embed_dims[0]
  1096. scale = 1
  1097. dpr = calculate_drop_path_rates(drop_path_rate, layers, stagewise=True)
  1098. stages = []
  1099. for i in range(len(layers)):
  1100. downsample = downsamples[i] or prev_dim != embed_dims[i]
  1101. stage = FastVitStage(
  1102. dim=prev_dim,
  1103. dim_out=embed_dims[i],
  1104. depth=layers[i],
  1105. downsample=downsample,
  1106. se_downsample=se_downsamples[i],
  1107. down_patch_size=down_patch_size,
  1108. down_stride=down_stride,
  1109. pos_emb_layer=pos_embs[i],
  1110. token_mixer_type=token_mixers[i],
  1111. kernel_size=repmixer_kernel_size,
  1112. mlp_ratio=mlp_ratios[i],
  1113. act_layer=act_layer,
  1114. norm_layer=norm_layer,
  1115. proj_drop_rate=proj_drop_rate,
  1116. drop_path_rate=dpr[i],
  1117. layer_scale_init_value=layer_scale_init_value,
  1118. lkc_use_act=lkc_use_act,
  1119. inference_mode=inference_mode,
  1120. **dd,
  1121. )
  1122. stages.append(stage)
  1123. prev_dim = embed_dims[i]
  1124. if downsample:
  1125. scale *= 2
  1126. self.feature_info += [dict(num_chs=prev_dim, reduction=4 * scale, module=f'stages.{i}')]
  1127. self.stages = nn.Sequential(*stages)
  1128. self.num_stages = len(self.stages)
  1129. self.num_features = self.head_hidden_size = prev_dim
  1130. # For segmentation and detection, extract intermediate output
  1131. if self.fork_feat:
  1132. # Add a norm layer for each output. self.stages is slightly different than self.network
  1133. # in the original code, the PatchEmbed layer is part of self.stages in this code where
  1134. # it was part of self.network in the original code. So we do not need to skip out indices.
  1135. self.out_indices = [0, 1, 2, 3]
  1136. for i_emb, i_layer in enumerate(self.out_indices):
  1137. if i_emb == 0 and os.environ.get("FORK_LAST3", None):
  1138. """For RetinaNet, `start_level=1`. The first norm layer will not used.
  1139. cmd: `FORK_LAST3=1 python -m torch.distributed.launch ...`
  1140. """
  1141. layer = nn.Identity()
  1142. else:
  1143. layer = norm_layer(embed_dims[i_emb], **dd)
  1144. layer_name = f"norm{i_layer}"
  1145. self.add_module(layer_name, layer)
  1146. else:
  1147. # Classifier head
  1148. self.num_features = self.head_hidden_size = final_features = int(embed_dims[-1] * cls_ratio)
  1149. self.final_conv = MobileOneBlock(
  1150. in_chs=embed_dims[-1],
  1151. out_chs=final_features,
  1152. kernel_size=3,
  1153. stride=1,
  1154. group_size=1,
  1155. inference_mode=inference_mode,
  1156. use_se=True,
  1157. act_layer=act_layer,
  1158. num_conv_branches=1,
  1159. **dd,
  1160. )
  1161. self.head = ClassifierHead(
  1162. final_features,
  1163. num_classes,
  1164. pool_type=global_pool,
  1165. drop_rate=drop_rate,
  1166. **dd,
  1167. )
  1168. self.apply(self._init_weights)
  1169. def _init_weights(self, m: nn.Module) -> None:
  1170. """Init. for classification"""
  1171. if isinstance(m, nn.Linear):
  1172. trunc_normal_(m.weight, std=0.02)
  1173. if isinstance(m, nn.Linear) and m.bias is not None:
  1174. nn.init.constant_(m.bias, 0)
  1175. @torch.jit.ignore
  1176. def no_weight_decay(self):
  1177. return set()
  1178. @torch.jit.ignore
  1179. def group_matcher(self, coarse=False):
  1180. return dict(
  1181. stem=r'^stem', # stem and embed
  1182. blocks=r'^stages\.(\d+)' if coarse else [
  1183. (r'^stages\.(\d+).downsample', (0,)),
  1184. (r'^stages\.(\d+).pos_emb', (0,)),
  1185. (r'^stages\.(\d+)\.\w+\.(\d+)', None),
  1186. ]
  1187. )
  1188. @torch.jit.ignore
  1189. def set_grad_checkpointing(self, enable=True):
  1190. for s in self.stages:
  1191. s.grad_checkpointing = enable
  1192. @torch.jit.ignore
  1193. def get_classifier(self) -> nn.Module:
  1194. return self.head.fc
  1195. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  1196. self.num_classes = num_classes
  1197. self.head.reset(num_classes, global_pool)
  1198. def forward_intermediates(
  1199. self,
  1200. x: torch.Tensor,
  1201. indices: Optional[Union[int, List[int]]] = None,
  1202. norm: bool = False,
  1203. stop_early: bool = False,
  1204. output_fmt: str = 'NCHW',
  1205. intermediates_only: bool = False,
  1206. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  1207. """ Forward features that returns intermediates.
  1208. Args:
  1209. x: Input image tensor
  1210. indices: Take last n blocks if int, all if None, select matching indices if sequence
  1211. norm: Apply norm layer to compatible intermediates
  1212. stop_early: Stop iterating over blocks when last desired intermediate hit
  1213. output_fmt: Shape of intermediate feature outputs
  1214. intermediates_only: Only return intermediate features
  1215. Returns:
  1216. """
  1217. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  1218. intermediates = []
  1219. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  1220. # forward pass
  1221. x = self.stem(x)
  1222. last_idx = self.num_stages - 1
  1223. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  1224. stages = self.stages
  1225. else:
  1226. stages = self.stages[:max_index + 1]
  1227. feat_idx = 0
  1228. for feat_idx, stage in enumerate(stages):
  1229. x = stage(x)
  1230. if feat_idx in take_indices:
  1231. intermediates.append(x)
  1232. if intermediates_only:
  1233. return intermediates
  1234. if feat_idx == last_idx:
  1235. x = self.final_conv(x)
  1236. return x, intermediates
  1237. def prune_intermediate_layers(
  1238. self,
  1239. indices: Union[int, List[int]] = 1,
  1240. prune_norm: bool = False,
  1241. prune_head: bool = True,
  1242. ):
  1243. """ Prune layers not required for specified intermediates.
  1244. """
  1245. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  1246. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  1247. if prune_head:
  1248. self.reset_classifier(0, '')
  1249. return take_indices
  1250. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  1251. # input embedding
  1252. x = self.stem(x)
  1253. outs = []
  1254. for idx, block in enumerate(self.stages):
  1255. x = block(x)
  1256. if self.fork_feat:
  1257. if idx in self.out_indices:
  1258. norm_layer = getattr(self, f"norm{idx}")
  1259. x_out = norm_layer(x)
  1260. outs.append(x_out)
  1261. if self.fork_feat:
  1262. # output the features of four stages for dense prediction
  1263. return outs
  1264. x = self.final_conv(x)
  1265. return x
  1266. def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
  1267. return self.head(x, pre_logits=True) if pre_logits else self.head(x)
  1268. def forward(self, x: torch.Tensor) -> torch.Tensor:
  1269. x = self.forward_features(x)
  1270. if self.fork_feat:
  1271. return x
  1272. x = self.forward_head(x)
  1273. return x
  1274. def _cfg(url="", **kwargs):
  1275. return {
  1276. "url": url,
  1277. "num_classes": 1000,
  1278. "input_size": (3, 256, 256),
  1279. "pool_size": (8, 8),
  1280. "crop_pct": 0.9,
  1281. "interpolation": "bicubic",
  1282. "mean": IMAGENET_DEFAULT_MEAN,
  1283. "license": "fastvit-license",
  1284. "std": IMAGENET_DEFAULT_STD,
  1285. 'first_conv': ('stem.0.conv_kxk.0.conv', 'stem.0.conv_scale.conv'),
  1286. "classifier": "head.fc",
  1287. **kwargs,
  1288. }
# Pretrained configurations. Each value starts from the defaults in `_cfg` and
# overrides only the fields listed in the call. Keys look like
# "<model name>.<pretrained tag>" -- presumably the entry order matters to
# `generate_default_cfgs`, so it is left untouched (TODO confirm).
default_cfgs = generate_default_cfgs({
    # `apple_in1k` tags
    "fastvit_t8.apple_in1k": _cfg(
        hf_hub_id='timm/'),
    "fastvit_t12.apple_in1k": _cfg(
        hf_hub_id='timm/'),

    "fastvit_s12.apple_in1k": _cfg(
        hf_hub_id='timm/'),
    "fastvit_sa12.apple_in1k": _cfg(
        hf_hub_id='timm/'),
    "fastvit_sa24.apple_in1k": _cfg(
        hf_hub_id='timm/'),
    "fastvit_sa36.apple_in1k": _cfg(
        hf_hub_id='timm/'),

    "fastvit_ma36.apple_in1k": _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95),

    # `apple_dist_in1k` tags
    "fastvit_t8.apple_dist_in1k": _cfg(
        hf_hub_id='timm/'),
    "fastvit_t12.apple_dist_in1k": _cfg(
        hf_hub_id='timm/'),

    "fastvit_s12.apple_dist_in1k": _cfg(
        hf_hub_id='timm/',),
    "fastvit_sa12.apple_dist_in1k": _cfg(
        hf_hub_id='timm/',),
    "fastvit_sa24.apple_dist_in1k": _cfg(
        hf_hub_id='timm/',),
    "fastvit_sa36.apple_dist_in1k": _cfg(
        hf_hub_id='timm/',),

    "fastvit_ma36.apple_dist_in1k": _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95
    ),

    # `apple_mclip` tags: `num_classes` holds the CLIP projection dim and the
    # input normalization is overridden to mean 0 / std 1.
    "fastvit_mci0.apple_mclip": _cfg(
        hf_hub_id='apple/mobileclip_s0_timm',
        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s0.pt',
        crop_pct=0.95,
        num_classes=512,  # CLIP proj dim
        mean=(0., 0., 0.), std=(1., 1., 1.), license='apple-amlr'
    ),
    "fastvit_mci1.apple_mclip": _cfg(
        hf_hub_id='apple/mobileclip_s1_timm',
        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s1.pt',
        crop_pct=0.95,
        num_classes=512,  # CLIP proj dim
        mean=(0., 0., 0.), std=(1., 1., 1.), license='apple-amlr'
    ),
    "fastvit_mci2.apple_mclip": _cfg(
        hf_hub_id='apple/mobileclip_s2_timm',
        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s2.pt',
        crop_pct=0.95,
        num_classes=512,  # CLIP proj dim
        mean=(0., 0., 0.), std=(1., 1., 1.), license='apple-amlr'
    ),

    # `apple_mclip2_dfndr2b` tags
    "fastvit_mci0.apple_mclip2_dfndr2b": _cfg(
        hf_hub_id='timm/',
        crop_pct=1.0,
        num_classes=512,  # CLIP proj dim
        mean=(0., 0., 0.), std=(1., 1., 1.),
        license='apple-amlr'
    ),
    "fastvit_mci2.apple_mclip2_dfndr2b": _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95,
        num_classes=512,  # CLIP proj dim
        mean=(0., 0., 0.), std=(1., 1., 1.),
        license='apple-amlr'
    ),
    # The mci3 / mci4 entries use the OpenAI CLIP mean/std constants, a (4, 4)
    # pool size and a single `first_conv` name instead of the default pair.
    "fastvit_mci3.apple_mclip2_dfndr2b": _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95,
        num_classes=768,  # CLIP proj dim
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        pool_size=(4, 4),
        first_conv='stem.0.conv_kxk.0.conv',
        license='apple-amlr'
    ),
    "fastvit_mci4.apple_mclip2_dfndr2b": _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95,
        num_classes=768,  # CLIP proj dim
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        pool_size=(4, 4),
        first_conv='stem.0.conv_kxk.0.conv',
        license='apple-amlr'
    ),
})
  1375. def checkpoint_filter_fn(state_dict, model):
  1376. """ Remap original checkpoints -> timm """
  1377. if 'stem.0.conv_kxk.0.conv.weight' in state_dict:
  1378. return state_dict # non-original checkpoint, no remapping needed
  1379. if 'module.visual.trunk.stem.0.conv_kxk.0.conv.weight' in state_dict:
  1380. return {k.replace('module.visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('module.visual.trunk')}
  1381. state_dict = state_dict.get('state_dict', state_dict)
  1382. if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
  1383. # remap MobileCLIP checkpoints
  1384. prefix = 'image_encoder.model.'
  1385. else:
  1386. prefix = ''
  1387. import re
  1388. import bisect
  1389. # find stage ends by locating downsample layers
  1390. stage_ends = []
  1391. for k, v in state_dict.items():
  1392. match = re.match(r'^(.*?)network\.(\d+)\.proj.*', k)
  1393. if match:
  1394. stage_ends.append(int(match.group(2)))
  1395. stage_ends = list(sorted(set(stage_ends)))
  1396. out_dict = {}
  1397. for k, v in state_dict.items():
  1398. if prefix:
  1399. if prefix not in k:
  1400. continue
  1401. k = k.replace(prefix, '')
  1402. # remap renamed layers
  1403. k = k.replace('patch_embed', 'stem')
  1404. k = k.replace('rbr_conv', 'conv_kxk')
  1405. k = k.replace('rbr_scale', 'conv_scale')
  1406. k = k.replace('rbr_skip', 'identity')
  1407. k = k.replace('conv_exp', 'final_conv') # to match byobnet, regnet, nfnet
  1408. k = k.replace('lkb_origin', 'large_conv')
  1409. k = k.replace('convffn', 'mlp')
  1410. k = k.replace('se.reduce', 'se.fc1')
  1411. k = k.replace('se.expand', 'se.fc2')
  1412. k = re.sub(r'layer_scale_([0-9])', r'layer_scale_\1.gamma', k)
  1413. if k.endswith('layer_scale'):
  1414. k = k.replace('layer_scale', 'layer_scale.gamma')
  1415. k = k.replace('dist_head', 'head_dist')
  1416. if k.startswith('head.'):
  1417. if k == 'head.proj' and hasattr(model.head, 'fc') and isinstance(model.head.fc, nn.Linear):
  1418. # if CLIP projection, map to head.fc w/ bias = zeros
  1419. k = k.replace('head.proj', 'head.fc.weight')
  1420. v = v.T
  1421. out_dict['head.fc.bias'] = torch.zeros(v.shape[0])
  1422. else:
  1423. k = k.replace('head.', 'head.fc.')
  1424. # remap flat sequential network to stages
  1425. match = re.match(r'^network\.(\d+)', k)
  1426. stage_idx, net_idx = None, None
  1427. if match:
  1428. net_idx = int(match.group(1))
  1429. stage_idx = bisect.bisect_right(stage_ends, net_idx)
  1430. if stage_idx is not None:
  1431. net_prefix = f'network.{net_idx}'
  1432. stage_prefix = f'stages.{stage_idx}'
  1433. if net_prefix + '.proj' in k:
  1434. k = k.replace(net_prefix + '.proj', stage_prefix + '.downsample.proj')
  1435. elif net_prefix + '.pe' in k:
  1436. k = k.replace(net_prefix + '.pe', stage_prefix + '.pos_emb.pos_enc')
  1437. else:
  1438. k = k.replace(net_prefix, stage_prefix + '.blocks')
  1439. out_dict[k] = v
  1440. return out_dict
  1441. def _create_fastvit(variant, pretrained=False, **kwargs):
  1442. out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
  1443. model = build_model_with_cfg(
  1444. FastVit,
  1445. variant,
  1446. pretrained,
  1447. pretrained_filter_fn=checkpoint_filter_fn,
  1448. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  1449. **kwargs
  1450. )
  1451. return model
  1452. @register_model
  1453. def fastvit_t8(pretrained=False, **kwargs):
  1454. """Instantiate FastViT-T8 model variant."""
  1455. model_args = dict(
  1456. layers=(2, 2, 4, 2),
  1457. embed_dims=(48, 96, 192, 384),
  1458. mlp_ratios=(3, 3, 3, 3),
  1459. token_mixers=("repmixer", "repmixer", "repmixer", "repmixer")
  1460. )
  1461. return _create_fastvit('fastvit_t8', pretrained=pretrained, **dict(model_args, **kwargs))
  1462. @register_model
  1463. def fastvit_t12(pretrained=False, **kwargs):
  1464. """Instantiate FastViT-T12 model variant."""
  1465. model_args = dict(
  1466. layers=(2, 2, 6, 2),
  1467. embed_dims=(64, 128, 256, 512),
  1468. mlp_ratios=(3, 3, 3, 3),
  1469. token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"),
  1470. )
  1471. return _create_fastvit('fastvit_t12', pretrained=pretrained, **dict(model_args, **kwargs))
  1472. @register_model
  1473. def fastvit_s12(pretrained=False, **kwargs):
  1474. """Instantiate FastViT-S12 model variant."""
  1475. model_args = dict(
  1476. layers=(2, 2, 6, 2),
  1477. embed_dims=(64, 128, 256, 512),
  1478. mlp_ratios=(4, 4, 4, 4),
  1479. token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"),
  1480. )
  1481. return _create_fastvit('fastvit_s12', pretrained=pretrained, **dict(model_args, **kwargs))
  1482. @register_model
  1483. def fastvit_sa12(pretrained=False, **kwargs):
  1484. """Instantiate FastViT-SA12 model variant."""
  1485. model_args = dict(
  1486. layers=(2, 2, 6, 2),
  1487. embed_dims=(64, 128, 256, 512),
  1488. mlp_ratios=(4, 4, 4, 4),
  1489. pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
  1490. token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
  1491. )
  1492. return _create_fastvit('fastvit_sa12', pretrained=pretrained, **dict(model_args, **kwargs))
  1493. @register_model
  1494. def fastvit_sa24(pretrained=False, **kwargs):
  1495. """Instantiate FastViT-SA24 model variant."""
  1496. model_args = dict(
  1497. layers=(4, 4, 12, 4),
  1498. embed_dims=(64, 128, 256, 512),
  1499. mlp_ratios=(4, 4, 4, 4),
  1500. pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
  1501. token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
  1502. )
  1503. return _create_fastvit('fastvit_sa24', pretrained=pretrained, **dict(model_args, **kwargs))
  1504. @register_model
  1505. def fastvit_sa36(pretrained=False, **kwargs):
  1506. """Instantiate FastViT-SA36 model variant."""
  1507. model_args = dict(
  1508. layers=(6, 6, 18, 6),
  1509. embed_dims=(64, 128, 256, 512),
  1510. mlp_ratios=(4, 4, 4, 4),
  1511. pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
  1512. token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
  1513. )
  1514. return _create_fastvit('fastvit_sa36', pretrained=pretrained, **dict(model_args, **kwargs))
  1515. @register_model
  1516. def fastvit_ma36(pretrained=False, **kwargs):
  1517. """Instantiate FastViT-MA36 model variant."""
  1518. model_args = dict(
  1519. layers=(6, 6, 18, 6),
  1520. embed_dims=(76, 152, 304, 608),
  1521. mlp_ratios=(4, 4, 4, 4),
  1522. pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
  1523. token_mixers=("repmixer", "repmixer", "repmixer", "attention")
  1524. )
  1525. return _create_fastvit('fastvit_ma36', pretrained=pretrained, **dict(model_args, **kwargs))
  1526. @register_model
  1527. def fastvit_mci0(pretrained=False, **kwargs):
  1528. """Instantiate MCi0 model variant."""
  1529. model_args = dict(
  1530. layers=(2, 6, 10, 2),
  1531. embed_dims=(64, 128, 256, 512),
  1532. mlp_ratios=(3, 3, 3, 3),
  1533. se_downsamples=(False, False, True, True),
  1534. pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
  1535. token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
  1536. lkc_use_act=True,
  1537. )
  1538. return _create_fastvit('fastvit_mci0', pretrained=pretrained, **dict(model_args, **kwargs))
  1539. @register_model
  1540. def fastvit_mci1(pretrained=False, **kwargs):
  1541. """Instantiate MCi1 model variant."""
  1542. model_args = dict(
  1543. layers=(4, 12, 20, 4),
  1544. embed_dims=(64, 128, 256, 512),
  1545. mlp_ratios=(3, 3, 3, 3),
  1546. se_downsamples=(False, False, True, True),
  1547. pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
  1548. token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
  1549. lkc_use_act=True,
  1550. )
  1551. return _create_fastvit('fastvit_mci1', pretrained=pretrained, **dict(model_args, **kwargs))
  1552. @register_model
  1553. def fastvit_mci2(pretrained=False, **kwargs):
  1554. """Instantiate MCi2 model variant."""
  1555. model_args = dict(
  1556. layers=(4, 12, 24, 4),
  1557. embed_dims=(80, 160, 320, 640),
  1558. mlp_ratios=(3, 3, 3, 3),
  1559. se_downsamples=(False, False, True, True),
  1560. pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
  1561. token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
  1562. lkc_use_act=True,
  1563. )
  1564. return _create_fastvit('fastvit_mci2', pretrained=pretrained, **dict(model_args, **kwargs))
  1565. @register_model
  1566. def fastvit_mci3(pretrained=False, **kwargs):
  1567. """Instantiate L model variant."""
  1568. model_args = dict(
  1569. layers=(2, 12, 24, 4, 2),
  1570. embed_dims=(96, 192, 384, 768, 1536),
  1571. mlp_ratios=(4, 4, 4, 4, 4),
  1572. se_downsamples=(False, False, False, False, False),
  1573. downsamples=(False, True, True, True, True),
  1574. pos_embs=(
  1575. None,
  1576. None,
  1577. None,
  1578. partial(RepConditionalPosEnc, spatial_shape=(7, 7)),
  1579. partial(RepConditionalPosEnc, spatial_shape=(7, 7))
  1580. ),
  1581. token_mixers=("repmixer", "repmixer", "repmixer", "attention", "attention"),
  1582. lkc_use_act=True,
  1583. norm_layer=partial(LayerNorm2d, eps=1e-5),
  1584. stem_use_scale_branch=False,
  1585. )
  1586. model = _create_fastvit('fastvit_mci3', pretrained=pretrained, **dict(model_args, **kwargs))
  1587. return model
  1588. @register_model
  1589. def fastvit_mci4(pretrained=False, **kwargs):
  1590. """Instantiate XL model variant."""
  1591. model_args = dict(
  1592. layers=(2, 12, 24, 4, 4),
  1593. embed_dims=(128, 256, 512, 1024, 2048),
  1594. mlp_ratios=(4, 4, 4, 4, 4),
  1595. se_downsamples=(False, False, False, False, False),
  1596. downsamples=(False, True, True, True, True),
  1597. pos_embs=(
  1598. None,
  1599. None,
  1600. None,
  1601. partial(RepConditionalPosEnc, spatial_shape=(7, 7)),
  1602. partial(RepConditionalPosEnc, spatial_shape=(7, 7))
  1603. ),
  1604. token_mixers=("repmixer", "repmixer", "repmixer", "attention", "attention"),
  1605. lkc_use_act=True,
  1606. norm_layer=partial(LayerNorm2d, eps=1e-5),
  1607. stem_use_scale_branch=False,
  1608. )
  1609. model = _create_fastvit('fastvit_mci4', pretrained=pretrained, **dict(model_args, **kwargs))
  1610. return model