inception_next.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522
  1. """
  2. InceptionNeXt paper: https://arxiv.org/abs/2303.16900
  3. Original implementation & weights from: https://github.com/sail-sg/inceptionnext
  4. """
  5. from functools import partial
  6. from typing import List, Optional, Tuple, Union, Type
  7. import torch
  8. import torch.nn as nn
  9. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  10. from timm.layers import trunc_normal_, DropPath, calculate_drop_path_rates, to_2tuple, get_padding, SelectAdaptivePool2d
  11. from ._builder import build_model_with_cfg
  12. from ._features import feature_take_indices
  13. from ._manipulate import checkpoint_seq
  14. from ._registry import register_model, generate_default_cfgs
__all__ = ['MetaNeXt']  # only the model class is exported; the variant factories below are registered via @register_model
  16. class InceptionDWConv2d(nn.Module):
  17. """ Inception depthwise convolution
  18. """
  19. def __init__(
  20. self,
  21. in_chs: int,
  22. square_kernel_size: int = 3,
  23. band_kernel_size: int = 11,
  24. branch_ratio: float = 0.125,
  25. dilation: int = 1,
  26. device=None,
  27. dtype=None,
  28. ):
  29. dd = {'device': device, 'dtype': dtype}
  30. super().__init__()
  31. gc = int(in_chs * branch_ratio) # channel numbers of a convolution branch
  32. square_padding = get_padding(square_kernel_size, dilation=dilation)
  33. band_padding = get_padding(band_kernel_size, dilation=dilation)
  34. self.dwconv_hw = nn.Conv2d(
  35. gc, gc, square_kernel_size,
  36. padding=square_padding, dilation=dilation, groups=gc, **dd)
  37. self.dwconv_w = nn.Conv2d(
  38. gc, gc, (1, band_kernel_size),
  39. padding=(0, band_padding), dilation=(1, dilation), groups=gc, **dd)
  40. self.dwconv_h = nn.Conv2d(
  41. gc, gc, (band_kernel_size, 1),
  42. padding=(band_padding, 0), dilation=(dilation, 1), groups=gc, **dd)
  43. self.split_indexes = (in_chs - 3 * gc, gc, gc, gc)
  44. def forward(self, x):
  45. x_id, x_hw, x_w, x_h = torch.split(x, self.split_indexes, dim=1)
  46. return torch.cat((
  47. x_id,
  48. self.dwconv_hw(x_hw),
  49. self.dwconv_w(x_w),
  50. self.dwconv_h(x_h)
  51. ), dim=1,
  52. )
  53. class ConvMlp(nn.Module):
  54. """ MLP using 1x1 convs that keeps spatial dims
  55. copied from timm: https://github.com/huggingface/pytorch-image-models/blob/v0.6.11/timm/models/layers/mlp.py
  56. """
  57. def __init__(
  58. self,
  59. in_features: int,
  60. hidden_features: Optional[int] = None,
  61. out_features: Optional[int] = None,
  62. act_layer: Type[nn.Module] = nn.ReLU,
  63. norm_layer: Optional[Type[nn.Module]] = None,
  64. bias: bool = True,
  65. drop: float = 0.,
  66. device=None,
  67. dtype=None,
  68. ):
  69. dd = {'device': device, 'dtype': dtype}
  70. super().__init__()
  71. out_features = out_features or in_features
  72. hidden_features = hidden_features or in_features
  73. bias = to_2tuple(bias)
  74. self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0], **dd)
  75. self.norm = norm_layer(hidden_features, **dd) if norm_layer else nn.Identity()
  76. self.act = act_layer()
  77. self.drop = nn.Dropout(drop)
  78. self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1], **dd)
  79. def forward(self, x):
  80. x = self.fc1(x)
  81. x = self.norm(x)
  82. x = self.act(x)
  83. x = self.drop(x)
  84. x = self.fc2(x)
  85. return x
  86. class MlpClassifierHead(nn.Module):
  87. """ MLP classification head
  88. """
  89. def __init__(
  90. self,
  91. in_features: int,
  92. num_classes: int = 1000,
  93. pool_type: str = 'avg',
  94. mlp_ratio: float = 3,
  95. act_layer: Type[nn.Module] = nn.GELU,
  96. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  97. drop: float = 0.,
  98. bias: bool = True,
  99. device=None,
  100. dtype=None,
  101. ):
  102. dd = {'device': device, 'dtype': dtype}
  103. super().__init__()
  104. self.use_conv = False
  105. self.in_features = in_features
  106. self.num_features = hidden_features = int(mlp_ratio * in_features)
  107. assert pool_type, 'Cannot disable pooling'
  108. self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True)
  109. self.fc1 = nn.Linear(in_features * self.global_pool.feat_mult(), hidden_features, bias=bias, **dd)
  110. self.act = act_layer()
  111. self.norm = norm_layer(hidden_features, **dd)
  112. self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias, **dd)
  113. self.drop = nn.Dropout(drop)
  114. def reset(self, num_classes: int, pool_type: Optional[str] = None):
  115. if pool_type is not None:
  116. assert pool_type, 'Cannot disable pooling'
  117. self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True)
  118. self.fc2 = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  119. def forward(self, x, pre_logits: bool = False):
  120. x = self.global_pool(x)
  121. x = self.fc1(x)
  122. x = self.act(x)
  123. x = self.norm(x)
  124. x = self.drop(x)
  125. return x if pre_logits else self.fc2(x)
  126. class MetaNeXtBlock(nn.Module):
  127. """ MetaNeXtBlock Block
  128. Args:
  129. dim (int): Number of input channels.
  130. drop_path (float): Stochastic depth rate. Default: 0.0
  131. ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
  132. """
  133. def __init__(
  134. self,
  135. dim: int,
  136. dilation: int = 1,
  137. token_mixer: Type[nn.Module] = InceptionDWConv2d,
  138. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  139. mlp_layer: Type[nn.Module] = ConvMlp,
  140. mlp_ratio: float = 4,
  141. act_layer: Type[nn.Module] = nn.GELU,
  142. ls_init_value: float = 1e-6,
  143. drop_path: float = 0.,
  144. device=None,
  145. dtype=None,
  146. ):
  147. dd = {'device': device, 'dtype': dtype}
  148. super().__init__()
  149. self.token_mixer = token_mixer(dim, dilation=dilation, **dd)
  150. self.norm = norm_layer(dim, **dd)
  151. self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=act_layer, **dd)
  152. self.gamma = nn.Parameter(ls_init_value * torch.ones(dim, **dd)) if ls_init_value else None
  153. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  154. def forward(self, x):
  155. shortcut = x
  156. x = self.token_mixer(x)
  157. x = self.norm(x)
  158. x = self.mlp(x)
  159. if self.gamma is not None:
  160. x = x.mul(self.gamma.reshape(1, -1, 1, 1))
  161. x = self.drop_path(x) + shortcut
  162. return x
  163. class MetaNeXtStage(nn.Module):
  164. def __init__(
  165. self,
  166. in_chs: int,
  167. out_chs: int,
  168. stride: int = 2,
  169. depth: int = 2,
  170. dilation: Tuple[int, int] = (1, 1),
  171. drop_path_rates: Optional[List[float]] = None,
  172. ls_init_value: float = 1.0,
  173. token_mixer: Type[nn.Module] = InceptionDWConv2d,
  174. act_layer: Type[nn.Module] = nn.GELU,
  175. norm_layer: Optional[Type[nn.Module]] = None,
  176. mlp_ratio: float = 4,
  177. device=None,
  178. dtype=None,
  179. ):
  180. dd = {'device': device, 'dtype': dtype}
  181. super().__init__()
  182. self.grad_checkpointing = False
  183. if stride > 1 or dilation[0] != dilation[1]:
  184. self.downsample = nn.Sequential(
  185. norm_layer(in_chs, **dd),
  186. nn.Conv2d(
  187. in_chs,
  188. out_chs,
  189. kernel_size=2,
  190. stride=stride,
  191. dilation=dilation[0],
  192. **dd,
  193. ),
  194. )
  195. else:
  196. self.downsample = nn.Identity()
  197. drop_path_rates = drop_path_rates or [0.] * depth
  198. stage_blocks = []
  199. for i in range(depth):
  200. stage_blocks.append(MetaNeXtBlock(
  201. dim=out_chs,
  202. dilation=dilation[1],
  203. drop_path=drop_path_rates[i],
  204. ls_init_value=ls_init_value,
  205. token_mixer=token_mixer,
  206. act_layer=act_layer,
  207. norm_layer=norm_layer,
  208. mlp_ratio=mlp_ratio,
  209. **dd,
  210. ))
  211. self.blocks = nn.Sequential(*stage_blocks)
  212. def forward(self, x):
  213. x = self.downsample(x)
  214. if self.grad_checkpointing and not torch.jit.is_scripting():
  215. x = checkpoint_seq(self.blocks, x)
  216. else:
  217. x = self.blocks(x)
  218. return x
  219. class MetaNeXt(nn.Module):
  220. r""" MetaNeXt
  221. A PyTorch impl of : `InceptionNeXt: When Inception Meets ConvNeXt` - https://arxiv.org/abs/2303.16900
  222. Args:
  223. in_chans (int): Number of input image channels. Default: 3
  224. num_classes (int): Number of classes for classification head. Default: 1000
  225. depths (tuple(int)): Number of blocks at each stage. Default: (3, 3, 9, 3)
  226. dims (tuple(int)): Feature dimension at each stage. Default: (96, 192, 384, 768)
  227. token_mixers: Token mixer function. Default: nn.Identity
  228. norm_layer: Normalization layer. Default: nn.BatchNorm2d
  229. act_layer: Activation function for MLP. Default: nn.GELU
  230. mlp_ratios (int or tuple(int)): MLP ratios. Default: (4, 4, 4, 3)
  231. drop_rate (float): Head dropout rate
  232. drop_path_rate (float): Stochastic depth rate. Default: 0.
  233. ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
  234. """
  235. def __init__(
  236. self,
  237. in_chans: int = 3,
  238. num_classes: int = 1000,
  239. global_pool: str = 'avg',
  240. output_stride: int = 32,
  241. depths: Tuple[int, ...] = (3, 3, 9, 3),
  242. dims: Tuple[int, ...] = (96, 192, 384, 768),
  243. token_mixers: Union[Type[nn.Module], List[Type[nn.Module]]] = InceptionDWConv2d,
  244. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  245. act_layer: Type[nn.Module] = nn.GELU,
  246. mlp_ratios: Union[int, Tuple[int, ...]] = (4, 4, 4, 3),
  247. drop_rate: float = 0.,
  248. drop_path_rate: float = 0.,
  249. ls_init_value: float = 1e-6,
  250. device=None,
  251. dtype=None,
  252. ):
  253. super().__init__()
  254. dd = {'device': device, 'dtype': dtype}
  255. num_stage = len(depths)
  256. if not isinstance(token_mixers, (list, tuple)):
  257. token_mixers = [token_mixers] * num_stage
  258. if not isinstance(mlp_ratios, (list, tuple)):
  259. mlp_ratios = [mlp_ratios] * num_stage
  260. self.num_classes = num_classes
  261. self.in_chans = in_chans
  262. self.global_pool = global_pool
  263. self.drop_rate = drop_rate
  264. self.feature_info = []
  265. self.stem = nn.Sequential(
  266. nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4, **dd),
  267. norm_layer(dims[0], **dd)
  268. )
  269. dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
  270. prev_chs = dims[0]
  271. curr_stride = 4
  272. dilation = 1
  273. # feature resolution stages, each consisting of multiple residual blocks
  274. self.stages = nn.Sequential()
  275. for i in range(num_stage):
  276. stride = 2 if curr_stride == 2 or i > 0 else 1
  277. if curr_stride >= output_stride and stride > 1:
  278. dilation *= stride
  279. stride = 1
  280. curr_stride *= stride
  281. first_dilation = 1 if dilation in (1, 2) else 2
  282. out_chs = dims[i]
  283. self.stages.append(MetaNeXtStage(
  284. prev_chs,
  285. out_chs,
  286. stride=stride if i > 0 else 1,
  287. dilation=(first_dilation, dilation),
  288. depth=depths[i],
  289. drop_path_rates=dp_rates[i],
  290. ls_init_value=ls_init_value,
  291. act_layer=act_layer,
  292. token_mixer=token_mixers[i],
  293. norm_layer=norm_layer,
  294. mlp_ratio=mlp_ratios[i],
  295. **dd,
  296. ))
  297. prev_chs = out_chs
  298. self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
  299. self.num_features = prev_chs
  300. self.head = MlpClassifierHead(self.num_features, num_classes, pool_type=self.global_pool, drop=drop_rate, **dd)
  301. self.head_hidden_size = self.head.num_features
  302. self.apply(self._init_weights)
  303. def _init_weights(self, m):
  304. if isinstance(m, (nn.Conv2d, nn.Linear)):
  305. trunc_normal_(m.weight, std=.02)
  306. if m.bias is not None:
  307. nn.init.constant_(m.bias, 0)
  308. @torch.jit.ignore
  309. def group_matcher(self, coarse=False):
  310. return dict(
  311. stem=r'^stem',
  312. blocks=r'^stages\.(\d+)' if coarse else [
  313. (r'^stages\.(\d+)\.downsample', (0,)), # blocks
  314. (r'^stages\.(\d+)\.blocks\.(\d+)', None),
  315. ]
  316. )
  317. @torch.jit.ignore
  318. def get_classifier(self) -> nn.Module:
  319. return self.head.fc2
  320. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  321. self.num_classes = num_classes
  322. self.head.reset(num_classes, global_pool)
  323. @torch.jit.ignore
  324. def set_grad_checkpointing(self, enable=True):
  325. for s in self.stages:
  326. s.grad_checkpointing = enable
  327. @torch.jit.ignore
  328. def no_weight_decay(self):
  329. return set()
  330. def forward_intermediates(
  331. self,
  332. x: torch.Tensor,
  333. indices: Optional[Union[int, List[int]]] = None,
  334. norm: bool = False,
  335. stop_early: bool = False,
  336. output_fmt: str = 'NCHW',
  337. intermediates_only: bool = False,
  338. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  339. """ Forward features that returns intermediates.
  340. Args:
  341. x: Input image tensor
  342. indices: Take last n blocks if int, all if None, select matching indices if sequence
  343. norm: Apply norm layer to compatible intermediates
  344. stop_early: Stop iterating over blocks when last desired intermediate hit
  345. output_fmt: Shape of intermediate feature outputs
  346. intermediates_only: Only return intermediate features
  347. Returns:
  348. """
  349. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  350. intermediates = []
  351. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  352. # forward pass
  353. x = self.stem(x)
  354. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  355. stages = self.stages
  356. else:
  357. stages = self.stages[:max_index + 1]
  358. for feat_idx, stage in enumerate(stages):
  359. x = stage(x)
  360. if feat_idx in take_indices:
  361. intermediates.append(x)
  362. if intermediates_only:
  363. return intermediates
  364. return x, intermediates
  365. def prune_intermediate_layers(
  366. self,
  367. indices: Union[int, List[int]] = 1,
  368. prune_norm: bool = False,
  369. prune_head: bool = True,
  370. ):
  371. """ Prune layers not required for specified intermediates.
  372. """
  373. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  374. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  375. if prune_head:
  376. self.reset_classifier(0, 'avg')
  377. return take_indices
  378. def forward_features(self, x):
  379. x = self.stem(x)
  380. x = self.stages(x)
  381. return x
  382. def forward_head(self, x, pre_logits: bool = False):
  383. return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  384. def forward(self, x):
  385. x = self.forward_features(x)
  386. x = self.forward_head(x)
  387. return x
  388. def _cfg(url='', **kwargs):
  389. return {
  390. 'url': url,
  391. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  392. 'crop_pct': 0.875, 'interpolation': 'bicubic',
  393. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  394. 'first_conv': 'stem.0', 'classifier': 'head.fc2',
  395. 'license': 'apache-2.0',
  396. **kwargs
  397. }
  398. default_cfgs = generate_default_cfgs({
  399. 'inception_next_atto.sail_in1k': _cfg(
  400. hf_hub_id='timm/',
  401. # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_atto.pth',
  402. ),
  403. 'inception_next_tiny.sail_in1k': _cfg(
  404. hf_hub_id='timm/',
  405. # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_tiny.pth',
  406. ),
  407. 'inception_next_small.sail_in1k': _cfg(
  408. hf_hub_id='timm/',
  409. # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_small.pth',
  410. ),
  411. 'inception_next_base.sail_in1k': _cfg(
  412. hf_hub_id='timm/',
  413. # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base.pth',
  414. crop_pct=0.95,
  415. ),
  416. 'inception_next_base.sail_in1k_384': _cfg(
  417. hf_hub_id='timm/',
  418. # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base_384.pth',
  419. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
  420. ),
  421. })
  422. def _create_inception_next(variant, pretrained=False, **kwargs):
  423. model = build_model_with_cfg(
  424. MetaNeXt, variant, pretrained,
  425. feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
  426. **kwargs,
  427. )
  428. return model
  429. @register_model
  430. def inception_next_atto(pretrained=False, **kwargs):
  431. model_args = dict(
  432. depths=(2, 2, 6, 2), dims=(40, 80, 160, 320),
  433. token_mixers=partial(InceptionDWConv2d, band_kernel_size=9, branch_ratio=0.25)
  434. )
  435. return _create_inception_next('inception_next_atto', pretrained=pretrained, **dict(model_args, **kwargs))
  436. @register_model
  437. def inception_next_tiny(pretrained=False, **kwargs):
  438. model_args = dict(
  439. depths=(3, 3, 9, 3), dims=(96, 192, 384, 768),
  440. token_mixers=InceptionDWConv2d,
  441. )
  442. return _create_inception_next('inception_next_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
  443. @register_model
  444. def inception_next_small(pretrained=False, **kwargs):
  445. model_args = dict(
  446. depths=(3, 3, 27, 3), dims=(96, 192, 384, 768),
  447. token_mixers=InceptionDWConv2d,
  448. )
  449. return _create_inception_next('inception_next_small', pretrained=pretrained, **dict(model_args, **kwargs))
  450. @register_model
  451. def inception_next_base(pretrained=False, **kwargs):
  452. model_args = dict(
  453. depths=(3, 3, 27, 3), dims=(128, 256, 512, 1024),
  454. token_mixers=InceptionDWConv2d,
  455. )
  456. return _create_inception_next('inception_next_base', pretrained=pretrained, **dict(model_args, **kwargs))