nextvit.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821
  1. """ Next-ViT
  2. As described in https://arxiv.org/abs/2207.05501
  3. Next-ViT model defs and weights adapted from https://github.com/bytedance/Next-ViT, original copyright below
  4. """
  5. # Copyright (c) ByteDance Inc. All rights reserved.
  6. from functools import partial
  7. from typing import List, Optional, Tuple, Union, Type
  8. import torch
  9. import torch.nn.functional as F
  10. from torch import nn
  11. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  12. from timm.layers import DropPath, calculate_drop_path_rates, trunc_normal_, ConvMlp, get_norm_layer, get_act_layer, use_fused_attn
  13. from timm.layers import ClassifierHead
  14. from ._builder import build_model_with_cfg
  15. from ._features import feature_take_indices
  16. from ._manipulate import checkpoint, checkpoint_seq
  17. from ._registry import generate_default_cfgs, register_model
# Public API of this module: only the top-level model class is exported by `import *`.
__all__ = ['NextViT']
  19. def merge_pre_bn(module, pre_bn_1, pre_bn_2=None):
  20. """ Merge pre BN to reduce inference runtime.
  21. """
  22. weight = module.weight.data
  23. if module.bias is None:
  24. zeros = torch.zeros(module.out_chs, device=weight.device).type(weight.type())
  25. module.bias = nn.Parameter(zeros)
  26. bias = module.bias.data
  27. if pre_bn_2 is None:
  28. assert pre_bn_1.track_running_stats is True, "Unsupported bn_module.track_running_stats is False"
  29. assert pre_bn_1.affine is True, "Unsupported bn_module.affine is False"
  30. scale_invstd = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5)
  31. extra_weight = scale_invstd * pre_bn_1.weight
  32. extra_bias = pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd
  33. else:
  34. assert pre_bn_1.track_running_stats is True, "Unsupported bn_module.track_running_stats is False"
  35. assert pre_bn_1.affine is True, "Unsupported bn_module.affine is False"
  36. assert pre_bn_2.track_running_stats is True, "Unsupported bn_module.track_running_stats is False"
  37. assert pre_bn_2.affine is True, "Unsupported bn_module.affine is False"
  38. scale_invstd_1 = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5)
  39. scale_invstd_2 = pre_bn_2.running_var.add(pre_bn_2.eps).pow(-0.5)
  40. extra_weight = scale_invstd_1 * pre_bn_1.weight * scale_invstd_2 * pre_bn_2.weight
  41. extra_bias = (
  42. scale_invstd_2 * pre_bn_2.weight
  43. * (pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd_1 - pre_bn_2.running_mean)
  44. + pre_bn_2.bias
  45. )
  46. if isinstance(module, nn.Linear):
  47. extra_bias = weight @ extra_bias
  48. weight.mul_(extra_weight.view(1, weight.size(1)).expand_as(weight))
  49. elif isinstance(module, nn.Conv2d):
  50. assert weight.shape[2] == 1 and weight.shape[3] == 1
  51. weight = weight.reshape(weight.shape[0], weight.shape[1])
  52. extra_bias = weight @ extra_bias
  53. weight.mul_(extra_weight.view(1, weight.size(1)).expand_as(weight))
  54. weight = weight.reshape(weight.shape[0], weight.shape[1], 1, 1)
  55. bias.add_(extra_bias)
  56. module.weight.data = weight
  57. module.bias.data = bias
  58. class ConvNormAct(nn.Module):
  59. def __init__(
  60. self,
  61. in_chs: int,
  62. out_chs: int,
  63. kernel_size: int = 3,
  64. stride: int = 1,
  65. groups: int = 1,
  66. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  67. act_layer: Type[nn.Module] = nn.ReLU,
  68. device=None,
  69. dtype=None,
  70. ):
  71. dd = {'device': device, 'dtype': dtype}
  72. super().__init__()
  73. self.conv = nn.Conv2d(
  74. in_chs,
  75. out_chs,
  76. kernel_size=kernel_size,
  77. stride=stride,
  78. padding=1,
  79. groups=groups,
  80. bias=False,
  81. **dd,
  82. )
  83. self.norm = norm_layer(out_chs, **dd)
  84. self.act = act_layer()
  85. def forward(self, x):
  86. x = self.conv(x)
  87. x = self.norm(x)
  88. x = self.act(x)
  89. return x
  90. def _make_divisible(v, divisor, min_value=None):
  91. if min_value is None:
  92. min_value = divisor
  93. new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
  94. # Make sure that round down does not go down by more than 10%.
  95. if new_v < 0.9 * v:
  96. new_v += divisor
  97. return new_v
  98. class PatchEmbed(nn.Module):
  99. def __init__(
  100. self,
  101. in_chs: int,
  102. out_chs: int,
  103. stride: int = 1,
  104. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  105. device=None,
  106. dtype=None,
  107. ):
  108. dd = {'device': device, 'dtype': dtype}
  109. super().__init__()
  110. if stride == 2:
  111. self.pool = nn.AvgPool2d((2, 2), stride=2, ceil_mode=True, count_include_pad=False)
  112. self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=1, stride=1, bias=False, **dd)
  113. self.norm = norm_layer(out_chs, **dd)
  114. elif in_chs != out_chs:
  115. self.pool = nn.Identity()
  116. self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=1, stride=1, bias=False, **dd)
  117. self.norm = norm_layer(out_chs, **dd)
  118. else:
  119. self.pool = nn.Identity()
  120. self.conv = nn.Identity()
  121. self.norm = nn.Identity()
  122. def forward(self, x):
  123. return self.norm(self.conv(self.pool(x)))
  124. class ConvAttention(nn.Module):
  125. """
  126. Multi-Head Convolutional Attention
  127. """
  128. def __init__(
  129. self,
  130. out_chs: int,
  131. head_dim: int,
  132. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  133. act_layer: Type[nn.Module] = nn.ReLU,
  134. device=None,
  135. dtype=None,
  136. ):
  137. dd = {'device': device, 'dtype': dtype}
  138. super().__init__()
  139. self.group_conv3x3 = nn.Conv2d(
  140. out_chs,
  141. out_chs,
  142. kernel_size=3,
  143. stride=1,
  144. padding=1,
  145. groups=out_chs // head_dim,
  146. bias=False,
  147. **dd,
  148. )
  149. self.norm = norm_layer(out_chs, **dd)
  150. self.act = act_layer()
  151. self.projection = nn.Conv2d(out_chs, out_chs, kernel_size=1, bias=False, **dd)
  152. def forward(self, x):
  153. out = self.group_conv3x3(x)
  154. out = self.norm(out)
  155. out = self.act(out)
  156. out = self.projection(out)
  157. return out
  158. class NextConvBlock(nn.Module):
  159. """
  160. Next Convolution Block
  161. """
  162. def __init__(
  163. self,
  164. in_chs: int,
  165. out_chs: int,
  166. stride: int = 1,
  167. drop_path: float = 0.,
  168. drop: float = 0.,
  169. head_dim: int = 32,
  170. mlp_ratio: float = 3.,
  171. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  172. act_layer: Type[nn.Module] = nn.ReLU,
  173. device=None,
  174. dtype=None,
  175. ):
  176. dd = {'device': device, 'dtype': dtype}
  177. super().__init__()
  178. self.in_chs = in_chs
  179. self.out_chs = out_chs
  180. assert out_chs % head_dim == 0
  181. self.patch_embed = PatchEmbed(in_chs, out_chs, stride, norm_layer=norm_layer, **dd)
  182. self.mhca = ConvAttention(
  183. out_chs,
  184. head_dim,
  185. norm_layer=norm_layer,
  186. act_layer=act_layer,
  187. **dd,
  188. )
  189. self.attn_drop_path = DropPath(drop_path)
  190. self.norm = norm_layer(out_chs, **dd)
  191. self.mlp = ConvMlp(
  192. out_chs,
  193. hidden_features=int(out_chs * mlp_ratio),
  194. drop=drop,
  195. bias=True,
  196. act_layer=act_layer,
  197. **dd,
  198. )
  199. self.mlp_drop_path = DropPath(drop_path)
  200. self.is_fused = False
  201. @torch.no_grad()
  202. def reparameterize(self):
  203. if not self.is_fused:
  204. merge_pre_bn(self.mlp.fc1, self.norm)
  205. self.norm = nn.Identity()
  206. self.is_fused = True
  207. def forward(self, x):
  208. x = self.patch_embed(x)
  209. x = x + self.attn_drop_path(self.mhca(x))
  210. out = self.norm(x)
  211. x = x + self.mlp_drop_path(self.mlp(out))
  212. return x
  213. class EfficientAttention(nn.Module):
  214. """
  215. Efficient Multi-Head Self Attention
  216. """
  217. fused_attn: torch.jit.Final[bool]
  218. def __init__(
  219. self,
  220. dim: int,
  221. out_dim: Optional[int] = None,
  222. head_dim: int = 32,
  223. qkv_bias: bool = True,
  224. attn_drop: float = 0.,
  225. proj_drop: float = 0.,
  226. sr_ratio: int = 1,
  227. norm_layer: Type[nn.Module] = nn.BatchNorm1d,
  228. device=None,
  229. dtype=None,
  230. ):
  231. dd = {'device': device, 'dtype': dtype}
  232. super().__init__()
  233. self.dim = dim
  234. self.out_dim = out_dim if out_dim is not None else dim
  235. self.num_heads = self.dim // head_dim
  236. self.head_dim = head_dim
  237. self.scale = head_dim ** -0.5
  238. self.fused_attn = use_fused_attn()
  239. self.q = nn.Linear(dim, self.dim, bias=qkv_bias, **dd)
  240. self.k = nn.Linear(dim, self.dim, bias=qkv_bias, **dd)
  241. self.v = nn.Linear(dim, self.dim, bias=qkv_bias, **dd)
  242. self.proj = nn.Linear(self.dim, self.out_dim, **dd)
  243. self.attn_drop = nn.Dropout(attn_drop)
  244. self.proj_drop = nn.Dropout(proj_drop)
  245. self.sr_ratio = sr_ratio
  246. self.N_ratio = sr_ratio ** 2
  247. if sr_ratio > 1:
  248. self.sr = nn.AvgPool1d(kernel_size=self.N_ratio, stride=self.N_ratio)
  249. self.norm = norm_layer(dim, **dd)
  250. else:
  251. self.sr = None
  252. self.norm = None
  253. def forward(self, x):
  254. B, N, C = x.shape
  255. q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
  256. if self.sr is not None:
  257. x = self.sr(x.transpose(1, 2))
  258. x = self.norm(x).transpose(1, 2)
  259. k = self.k(x).reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
  260. v = self.v(x).reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
  261. if self.fused_attn:
  262. x = F.scaled_dot_product_attention(
  263. q, k, v,
  264. dropout_p=self.attn_drop.p if self.training else 0.,
  265. )
  266. else:
  267. q = q * self.scale
  268. attn = q @ k.transpose(-1, -2)
  269. attn = attn.softmax(dim=-1)
  270. attn = self.attn_drop(attn)
  271. x = attn @ v
  272. x = x.transpose(1, 2).reshape(B, N, C)
  273. x = self.proj(x)
  274. x = self.proj_drop(x)
  275. return x
  276. class NextTransformerBlock(nn.Module):
  277. """
  278. Next Transformer Block
  279. """
  280. def __init__(
  281. self,
  282. in_chs: int,
  283. out_chs: int,
  284. drop_path: float,
  285. stride: int = 1,
  286. sr_ratio: int = 1,
  287. mlp_ratio: float = 2,
  288. head_dim: int = 32,
  289. mix_block_ratio: float = 0.75,
  290. attn_drop: float = 0.,
  291. drop: float = 0.,
  292. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  293. act_layer: Type[nn.Module] = nn.ReLU,
  294. device=None,
  295. dtype=None,
  296. ):
  297. dd = {'device': device, 'dtype': dtype}
  298. super().__init__()
  299. self.in_chs = in_chs
  300. self.out_chs = out_chs
  301. self.mix_block_ratio = mix_block_ratio
  302. self.mhsa_out_chs = _make_divisible(int(out_chs * mix_block_ratio), 32)
  303. self.mhca_out_chs = out_chs - self.mhsa_out_chs
  304. self.patch_embed = PatchEmbed(in_chs, self.mhsa_out_chs, stride, **dd)
  305. self.norm1 = norm_layer(self.mhsa_out_chs, **dd)
  306. self.e_mhsa = EfficientAttention(
  307. self.mhsa_out_chs,
  308. head_dim=head_dim,
  309. sr_ratio=sr_ratio,
  310. attn_drop=attn_drop,
  311. proj_drop=drop,
  312. **dd,
  313. )
  314. self.mhsa_drop_path = DropPath(drop_path * mix_block_ratio)
  315. self.projection = PatchEmbed(
  316. self.mhsa_out_chs,
  317. self.mhca_out_chs,
  318. stride=1,
  319. norm_layer=norm_layer,
  320. **dd,
  321. )
  322. self.mhca = ConvAttention(
  323. self.mhca_out_chs,
  324. head_dim=head_dim,
  325. norm_layer=norm_layer,
  326. act_layer=act_layer,
  327. **dd,
  328. )
  329. self.mhca_drop_path = DropPath(drop_path * (1 - mix_block_ratio))
  330. self.norm2 = norm_layer(out_chs, **dd)
  331. self.mlp = ConvMlp(
  332. out_chs,
  333. hidden_features=int(out_chs * mlp_ratio),
  334. act_layer=act_layer,
  335. drop=drop,
  336. **dd,
  337. )
  338. self.mlp_drop_path = DropPath(drop_path)
  339. self.is_fused = False
  340. @torch.no_grad()
  341. def reparameterize(self):
  342. if not self.is_fused:
  343. merge_pre_bn(self.e_mhsa.q, self.norm1)
  344. if self.e_mhsa.norm is not None:
  345. merge_pre_bn(self.e_mhsa.k, self.norm1, self.e_mhsa.norm)
  346. merge_pre_bn(self.e_mhsa.v, self.norm1, self.e_mhsa.norm)
  347. self.e_mhsa.norm = nn.Identity()
  348. else:
  349. merge_pre_bn(self.e_mhsa.k, self.norm1)
  350. merge_pre_bn(self.e_mhsa.v, self.norm1)
  351. self.norm1 = nn.Identity()
  352. merge_pre_bn(self.mlp.fc1, self.norm2)
  353. self.norm2 = nn.Identity()
  354. self.is_fused = True
  355. def forward(self, x):
  356. x = self.patch_embed(x)
  357. B, C, H, W = x.shape
  358. out = self.norm1(x)
  359. out = out.reshape(B, C, -1).transpose(-1, -2)
  360. out = self.mhsa_drop_path(self.e_mhsa(out))
  361. x = x + out.transpose(-1, -2).reshape(B, C, H, W)
  362. out = self.projection(x)
  363. out = out + self.mhca_drop_path(self.mhca(out))
  364. x = torch.cat([x, out], dim=1)
  365. out = self.norm2(x)
  366. x = x + self.mlp_drop_path(self.mlp(out))
  367. return x
  368. class NextStage(nn.Module):
  369. def __init__(
  370. self,
  371. in_chs: int,
  372. block_chs: List[int],
  373. block_types: List[Type[nn.Module]],
  374. stride: int = 2,
  375. sr_ratio: int = 1,
  376. mix_block_ratio: float = 1.0,
  377. drop: float = 0.,
  378. attn_drop: float = 0.,
  379. drop_path: Union[float, List[float], Tuple[float, ...]] = 0.,
  380. head_dim: int = 32,
  381. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  382. act_layer: Type[nn.Module] = nn.ReLU,
  383. device=None,
  384. dtype=None,
  385. ):
  386. dd = {'device': device, 'dtype': dtype}
  387. super().__init__()
  388. self.grad_checkpointing = False
  389. blocks = []
  390. for block_idx, block_type in enumerate(block_types):
  391. stride = stride if block_idx == 0 else 1
  392. out_chs = block_chs[block_idx]
  393. block_type = block_types[block_idx]
  394. dpr = drop_path[block_idx] if isinstance(drop_path, (list, tuple)) else drop_path
  395. if block_type is NextConvBlock:
  396. layer = NextConvBlock(
  397. in_chs,
  398. out_chs,
  399. stride=stride,
  400. drop_path=dpr,
  401. drop=drop,
  402. head_dim=head_dim,
  403. norm_layer=norm_layer,
  404. act_layer=act_layer,
  405. **dd,
  406. )
  407. blocks.append(layer)
  408. elif block_type is NextTransformerBlock:
  409. layer = NextTransformerBlock(
  410. in_chs,
  411. out_chs,
  412. drop_path=dpr,
  413. stride=stride,
  414. sr_ratio=sr_ratio,
  415. head_dim=head_dim,
  416. mix_block_ratio=mix_block_ratio,
  417. attn_drop=attn_drop,
  418. drop=drop,
  419. norm_layer=norm_layer,
  420. act_layer=act_layer,
  421. **dd,
  422. )
  423. blocks.append(layer)
  424. in_chs = out_chs
  425. self.blocks = nn.Sequential(*blocks)
  426. @torch.jit.ignore
  427. def set_grad_checkpointing(self, enable=True):
  428. self.grad_checkpointing = enable
  429. def forward(self, x):
  430. if self.grad_checkpointing and not torch.jit.is_scripting():
  431. x = checkpoint_seq(self.blocks, x)
  432. else:
  433. x = self.blocks(x)
  434. return x
  435. class NextViT(nn.Module):
  436. def __init__(
  437. self,
  438. in_chans: int,
  439. num_classes: int = 1000,
  440. global_pool: str = 'avg',
  441. stem_chs: Tuple[int, ...] = (64, 32, 64),
  442. depths: Tuple[int, ...] = (3, 4, 10, 3),
  443. strides: Tuple[int, ...] = (1, 2, 2, 2),
  444. sr_ratios: Tuple[int, ...] = (8, 4, 2, 1),
  445. drop_path_rate: float = 0.1,
  446. attn_drop_rate: float = 0.,
  447. drop_rate: float = 0.,
  448. head_dim: int = 32,
  449. mix_block_ratio: float = 0.75,
  450. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  451. act_layer: Optional[Type[nn.Module]] = None,
  452. device=None,
  453. dtype=None,
  454. ):
  455. super().__init__()
  456. dd = {'device': device, 'dtype': dtype}
  457. self.grad_checkpointing = False
  458. self.num_classes = num_classes
  459. self.in_chans = in_chans
  460. norm_layer = get_norm_layer(norm_layer)
  461. if act_layer is None:
  462. act_layer = partial(nn.ReLU, inplace=True)
  463. else:
  464. act_layer = get_act_layer(act_layer)
  465. self.stage_out_chs = [
  466. [96] * (depths[0]),
  467. [192] * (depths[1] - 1) + [256],
  468. [384, 384, 384, 384, 512] * (depths[2] // 5),
  469. [768] * (depths[3] - 1) + [1024]
  470. ]
  471. self.feature_info = [dict(
  472. num_chs=sc[-1],
  473. reduction=2**(i + 2),
  474. module=f'stages.{i}'
  475. ) for i, sc in enumerate(self.stage_out_chs)]
  476. # Next Hybrid Strategy
  477. self.stage_block_types = [
  478. [NextConvBlock] * depths[0],
  479. [NextConvBlock] * (depths[1] - 1) + [NextTransformerBlock],
  480. [NextConvBlock, NextConvBlock, NextConvBlock, NextConvBlock, NextTransformerBlock] * (depths[2] // 5),
  481. [NextConvBlock] * (depths[3] - 1) + [NextTransformerBlock]]
  482. self.stem = nn.Sequential(
  483. ConvNormAct(
  484. in_chans, stem_chs[0], kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer, **dd),
  485. ConvNormAct(
  486. stem_chs[0], stem_chs[1], kernel_size=3, stride=1, norm_layer=norm_layer, act_layer=act_layer, **dd),
  487. ConvNormAct(
  488. stem_chs[1], stem_chs[2], kernel_size=3, stride=1, norm_layer=norm_layer, act_layer=act_layer, **dd),
  489. ConvNormAct(
  490. stem_chs[2], stem_chs[2], kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer, **dd),
  491. )
  492. in_chs = out_chs = stem_chs[-1]
  493. stages = []
  494. idx = 0
  495. dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
  496. for stage_idx in range(len(depths)):
  497. stage = NextStage(
  498. in_chs=in_chs,
  499. block_chs=self.stage_out_chs[stage_idx],
  500. block_types=self.stage_block_types[stage_idx],
  501. stride=strides[stage_idx],
  502. sr_ratio=sr_ratios[stage_idx],
  503. mix_block_ratio=mix_block_ratio,
  504. head_dim=head_dim,
  505. drop=drop_rate,
  506. attn_drop=attn_drop_rate,
  507. drop_path=dpr[stage_idx],
  508. norm_layer=norm_layer,
  509. act_layer=act_layer,
  510. **dd,
  511. )
  512. in_chs = out_chs = self.stage_out_chs[stage_idx][-1]
  513. stages += [stage]
  514. idx += depths[stage_idx]
  515. self.num_features = self.head_hidden_size = out_chs
  516. self.stages = nn.Sequential(*stages)
  517. self.norm = norm_layer(out_chs, **dd)
  518. self.head = ClassifierHead(pool_type=global_pool, in_features=out_chs, num_classes=num_classes, **dd)
  519. self.stage_out_idx = [sum(depths[:idx + 1]) - 1 for idx in range(len(depths))]
  520. self._initialize_weights()
  521. def _initialize_weights(self):
  522. for n, m in self.named_modules():
  523. if isinstance(m, nn.Linear):
  524. trunc_normal_(m.weight, std=.02)
  525. if hasattr(m, 'bias') and m.bias is not None:
  526. nn.init.constant_(m.bias, 0)
  527. elif isinstance(m, nn.Conv2d):
  528. trunc_normal_(m.weight, std=.02)
  529. if hasattr(m, 'bias') and m.bias is not None:
  530. nn.init.constant_(m.bias, 0)
  531. @torch.jit.ignore
  532. def group_matcher(self, coarse=False):
  533. return dict(
  534. stem=r'^stem', # stem and embed
  535. blocks=r'^stages\.(\d+)' if coarse else [
  536. (r'^stages\.(\d+)\.blocks\.(\d+)', None),
  537. (r'^norm', (99999,)),
  538. ]
  539. )
  540. @torch.jit.ignore
  541. def set_grad_checkpointing(self, enable=True):
  542. self.grad_checkpointing = enable
  543. for stage in self.stages:
  544. stage.set_grad_checkpointing(enable=enable)
  545. @torch.jit.ignore
  546. def get_classifier(self) -> nn.Module:
  547. return self.head.fc
  548. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  549. self.num_classes = num_classes
  550. self.head.reset(num_classes, pool_type=global_pool)
  551. def forward_intermediates(
  552. self,
  553. x: torch.Tensor,
  554. indices: Optional[Union[int, List[int]]] = None,
  555. norm: bool = False,
  556. stop_early: bool = False,
  557. output_fmt: str = 'NCHW',
  558. intermediates_only: bool = False,
  559. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  560. """ Forward features that returns intermediates.
  561. Args:
  562. x: Input image tensor
  563. indices: Take last n blocks if int, all if None, select matching indices if sequence
  564. norm: Apply norm layer to compatible intermediates
  565. stop_early: Stop iterating over blocks when last desired intermediate hit
  566. output_fmt: Shape of intermediate feature outputs
  567. intermediates_only: Only return intermediate features
  568. Returns:
  569. """
  570. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  571. intermediates = []
  572. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  573. # forward pass
  574. x = self.stem(x)
  575. last_idx = len(self.stages) - 1
  576. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  577. stages = self.stages
  578. else:
  579. stages = self.stages[:max_index + 1]
  580. for feat_idx, stage in enumerate(stages):
  581. if self.grad_checkpointing and not torch.jit.is_scripting():
  582. x = checkpoint(stage, x)
  583. else:
  584. x = stage(x)
  585. if feat_idx in take_indices:
  586. if feat_idx == last_idx:
  587. x_inter = self.norm(x) if norm else x
  588. intermediates.append(x_inter)
  589. else:
  590. intermediates.append(x)
  591. if intermediates_only:
  592. return intermediates
  593. if feat_idx == last_idx:
  594. x = self.norm(x)
  595. return x, intermediates
  596. def prune_intermediate_layers(
  597. self,
  598. indices: Union[int, List[int]] = 1,
  599. prune_norm: bool = False,
  600. prune_head: bool = True,
  601. ):
  602. """ Prune layers not required for specified intermediates.
  603. """
  604. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  605. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  606. if prune_norm:
  607. self.norm = nn.Identity()
  608. if prune_head:
  609. self.reset_classifier(0, '')
  610. return take_indices
  611. def forward_features(self, x):
  612. x = self.stem(x)
  613. if self.grad_checkpointing and not torch.jit.is_scripting():
  614. x = checkpoint_seq(self.stages, x)
  615. else:
  616. x = self.stages(x)
  617. x = self.norm(x)
  618. return x
  619. def forward_head(self, x, pre_logits: bool = False):
  620. return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  621. def forward(self, x):
  622. x = self.forward_features(x)
  623. x = self.forward_head(x)
  624. return x
  625. def checkpoint_filter_fn(state_dict, model):
  626. """ Remap original checkpoints -> timm """
  627. if 'head.fc.weight' in state_dict:
  628. return state_dict # non-original
  629. D = model.state_dict()
  630. out_dict = {}
  631. # remap originals based on order
  632. for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
  633. out_dict[ka] = vb
  634. return out_dict
  635. def _create_nextvit(variant, pretrained=False, **kwargs):
  636. default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
  637. out_indices = kwargs.pop('out_indices', default_out_indices)
  638. model = build_model_with_cfg(
  639. NextViT,
  640. variant,
  641. pretrained,
  642. pretrained_filter_fn=checkpoint_filter_fn,
  643. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  644. **kwargs)
  645. return model
  646. def _cfg(url='', **kwargs):
  647. return {
  648. 'url': url,
  649. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  650. 'crop_pct': 0.95, 'interpolation': 'bicubic',
  651. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  652. 'first_conv': 'stem.0.conv', 'classifier': 'head.fc',
  653. 'license': 'apache-2.0',
  654. **kwargs
  655. }
  656. default_cfgs = generate_default_cfgs({
  657. 'nextvit_small.bd_in1k': _cfg(
  658. hf_hub_id='timm/',
  659. ),
  660. 'nextvit_base.bd_in1k': _cfg(
  661. hf_hub_id='timm/',
  662. ),
  663. 'nextvit_large.bd_in1k': _cfg(
  664. hf_hub_id='timm/',
  665. ),
  666. 'nextvit_small.bd_in1k_384': _cfg(
  667. hf_hub_id='timm/',
  668. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
  669. ),
  670. 'nextvit_base.bd_in1k_384': _cfg(
  671. hf_hub_id='timm/',
  672. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
  673. ),
  674. 'nextvit_large.bd_in1k_384': _cfg(
  675. hf_hub_id='timm/',
  676. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
  677. ),
  678. 'nextvit_small.bd_ssld_6m_in1k': _cfg(
  679. hf_hub_id='timm/',
  680. ),
  681. 'nextvit_base.bd_ssld_6m_in1k': _cfg(
  682. hf_hub_id='timm/',
  683. ),
  684. 'nextvit_large.bd_ssld_6m_in1k': _cfg(
  685. hf_hub_id='timm/',
  686. ),
  687. 'nextvit_small.bd_ssld_6m_in1k_384': _cfg(
  688. hf_hub_id='timm/',
  689. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
  690. ),
  691. 'nextvit_base.bd_ssld_6m_in1k_384': _cfg(
  692. hf_hub_id='timm/',
  693. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
  694. ),
  695. 'nextvit_large.bd_ssld_6m_in1k_384': _cfg(
  696. hf_hub_id='timm/',
  697. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
  698. ),
  699. })
  700. @register_model
  701. def nextvit_small(pretrained=False, **kwargs):
  702. model_args = dict(depths=(3, 4, 10, 3), drop_path_rate=0.1)
  703. model = _create_nextvit(
  704. 'nextvit_small', pretrained=pretrained, **dict(model_args, **kwargs))
  705. return model
  706. @register_model
  707. def nextvit_base(pretrained=False, **kwargs):
  708. model_args = dict(depths=(3, 4, 20, 3), drop_path_rate=0.2)
  709. model = _create_nextvit(
  710. 'nextvit_base', pretrained=pretrained, **dict(model_args, **kwargs))
  711. return model
  712. @register_model
  713. def nextvit_large(pretrained=False, **kwargs):
  714. model_args = dict(depths=(3, 4, 30, 3), drop_path_rate=0.2)
  715. model = _create_nextvit(
  716. 'nextvit_large', pretrained=pretrained, **dict(model_args, **kwargs))
  717. return model