rdnet.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. """
  2. RDNet
  3. Copyright (c) 2024-present NAVER Cloud Corp.
  4. Apache-2.0
  5. """
from functools import partial
from typing import List, Optional, Tuple, Union, Callable, Type

import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, calculate_drop_path_rates, NormMlpClassifierHead, ClassifierHead, EffectiveSEModule, \
    make_divisible, get_act_layer, get_norm_layer

from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import named_apply, checkpoint_seq
from ._registry import register_model, generate_default_cfgs
# Public API of this module for `from ... import *`: only the RDNet model class.
__all__ = ["RDNet"]
  18. class Block(nn.Module):
  19. def __init__(
  20. self,
  21. in_chs: int,
  22. inter_chs: int,
  23. out_chs: int,
  24. norm_layer: Type[nn.Module],
  25. act_layer: Type[nn.Module],
  26. device=None,
  27. dtype=None,
  28. ):
  29. dd = {'device': device, 'dtype': dtype}
  30. super().__init__()
  31. self.layers = nn.Sequential(
  32. nn.Conv2d(in_chs, in_chs, groups=in_chs, kernel_size=7, stride=1, padding=3, **dd),
  33. norm_layer(in_chs, **dd),
  34. nn.Conv2d(in_chs, inter_chs, kernel_size=1, stride=1, padding=0, **dd),
  35. act_layer(),
  36. nn.Conv2d(inter_chs, out_chs, kernel_size=1, stride=1, padding=0, **dd),
  37. )
  38. def forward(self, x):
  39. return self.layers(x)
  40. class BlockESE(nn.Module):
  41. def __init__(
  42. self,
  43. in_chs: int,
  44. inter_chs: int,
  45. out_chs: int,
  46. norm_layer: Type[nn.Module],
  47. act_layer: Type[nn.Module],
  48. device=None,
  49. dtype=None,
  50. ):
  51. dd = {'device': device, 'dtype': dtype}
  52. super().__init__()
  53. self.layers = nn.Sequential(
  54. nn.Conv2d(in_chs, in_chs, groups=in_chs, kernel_size=7, stride=1, padding=3, **dd),
  55. norm_layer(in_chs, **dd),
  56. nn.Conv2d(in_chs, inter_chs, kernel_size=1, stride=1, padding=0, **dd),
  57. act_layer(),
  58. nn.Conv2d(inter_chs, out_chs, kernel_size=1, stride=1, padding=0, **dd),
  59. EffectiveSEModule(out_chs, **dd),
  60. )
  61. def forward(self, x):
  62. return self.layers(x)
  63. def _get_block_type(block: str):
  64. block = block.lower().strip()
  65. if block == "block":
  66. return Block
  67. elif block == "blockese":
  68. return BlockESE
  69. else:
  70. assert False, f"Unknown block type ({block})."
  71. class DenseBlock(nn.Module):
  72. def __init__(
  73. self,
  74. num_input_features: int = 64,
  75. growth_rate: int = 64,
  76. bottleneck_width_ratio: float = 4.0,
  77. drop_path_rate: float = 0.0,
  78. drop_rate: float = 0.0,
  79. rand_gather_step_prob: float = 0.0,
  80. block_idx: int = 0,
  81. block_type: str = "Block",
  82. ls_init_value: float = 1e-6,
  83. norm_layer: Type[nn.Module] = nn.LayerNorm,
  84. act_layer: Type[nn.Module] = nn.GELU,
  85. device=None,
  86. dtype=None,
  87. ):
  88. dd = {'device': device, 'dtype': dtype}
  89. super().__init__()
  90. self.drop_rate = drop_rate
  91. self.drop_path_rate = drop_path_rate
  92. self.rand_gather_step_prob = rand_gather_step_prob
  93. self.block_idx = block_idx
  94. self.growth_rate = growth_rate
  95. self.gamma = nn.Parameter(ls_init_value * torch.ones(growth_rate, **dd)) if ls_init_value > 0 else None
  96. growth_rate = int(growth_rate)
  97. inter_chs = int(num_input_features * bottleneck_width_ratio / 8) * 8
  98. self.drop_path = DropPath(drop_path_rate)
  99. self.layers = _get_block_type(block_type)(
  100. in_chs=num_input_features,
  101. inter_chs=inter_chs,
  102. out_chs=growth_rate,
  103. norm_layer=norm_layer,
  104. act_layer=act_layer,
  105. **dd,
  106. )
  107. def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
  108. x = torch.cat(x, 1)
  109. x = self.layers(x)
  110. if self.gamma is not None:
  111. x = x.mul(self.gamma.reshape(1, -1, 1, 1))
  112. x = self.drop_path(x)
  113. return x
  114. class DenseStage(nn.Sequential):
  115. def __init__(
  116. self,
  117. num_block: int,
  118. num_input_features: int,
  119. drop_path_rates: List[float],
  120. growth_rate: int,
  121. device=None,
  122. dtype=None,
  123. **kwargs,
  124. ):
  125. dd = {'device': device, 'dtype': dtype}
  126. super().__init__()
  127. for i in range(num_block):
  128. layer = DenseBlock(
  129. num_input_features=num_input_features,
  130. growth_rate=growth_rate,
  131. drop_path_rate=drop_path_rates[i],
  132. block_idx=i,
  133. **dd,
  134. **kwargs,
  135. )
  136. num_input_features += growth_rate
  137. self.add_module(f"dense_block{i}", layer)
  138. self.num_out_features = num_input_features
  139. def forward(self, init_feature: torch.Tensor) -> torch.Tensor:
  140. features = [init_feature]
  141. for module in self:
  142. new_feature = module(features)
  143. features.append(new_feature)
  144. return torch.cat(features, 1)
  145. class RDNet(nn.Module):
  146. def __init__(
  147. self,
  148. in_chans: int = 3, # timm option [--in-chans]
  149. num_classes: int = 1000, # timm option [--num-classes]
  150. global_pool: str = 'avg', # timm option [--gp]
  151. growth_rates: Union[List[int], Tuple[int]] = (64, 104, 128, 128, 128, 128, 224),
  152. num_blocks_list: Union[List[int], Tuple[int]] = (3, 3, 3, 3, 3, 3, 3),
  153. block_type: Union[List[int], Tuple[int]] = ("Block",) * 2 + ("BlockESE",) * 5,
  154. is_downsample_block: Union[List[bool], Tuple[bool]] = (None, True, True, False, False, False, True),
  155. bottleneck_width_ratio: float = 4.0,
  156. transition_compression_ratio: float = 0.5,
  157. ls_init_value: float = 1e-6,
  158. stem_type: str = 'patch',
  159. patch_size: int = 4,
  160. num_init_features: int = 64,
  161. head_init_scale: float = 1.,
  162. head_norm_first: bool = False,
  163. conv_bias: bool = True,
  164. act_layer: Union[str, Callable] = 'gelu',
  165. norm_layer: str = "layernorm2d",
  166. norm_eps: Optional[float] = None,
  167. drop_rate: float = 0.0, # timm option [--drop: dropout ratio]
  168. drop_path_rate: float = 0.0, # timm option [--drop-path: drop-path ratio]
  169. device=None,
  170. dtype=None,
  171. ):
  172. """
  173. Args:
  174. in_chans: Number of input image channels.
  175. num_classes: Number of classes for classification head.
  176. global_pool: Global pooling type.
  177. growth_rates: Growth rate at each stage.
  178. num_blocks_list: Number of blocks at each stage.
  179. is_downsample_block: Whether to downsample at each stage.
  180. bottleneck_width_ratio: Bottleneck width ratio (similar to mlp expansion ratio).
  181. transition_compression_ratio: Channel compression ratio of transition layers.
  182. ls_init_value: Init value for Layer Scale, disabled if None.
  183. stem_type: Type of stem.
  184. patch_size: Stem patch size for patch stem.
  185. num_init_features: Number of features of stem.
  186. head_init_scale: Init scaling value for classifier weights and biases.
  187. head_norm_first: Apply normalization before global pool + head.
  188. conv_bias: Use bias layers w/ all convolutions.
  189. act_layer: Activation layer type.
  190. norm_layer: Normalization layer type.
  191. norm_eps: Small value to avoid division by zero in normalization.
  192. drop_rate: Head pre-classifier dropout rate.
  193. drop_path_rate: Stochastic depth drop rate.
  194. """
  195. super().__init__()
  196. dd = {'device': device, 'dtype': dtype}
  197. assert len(growth_rates) == len(num_blocks_list) == len(is_downsample_block)
  198. act_layer = get_act_layer(act_layer)
  199. norm_layer = get_norm_layer(norm_layer)
  200. if norm_eps is not None:
  201. norm_layer = partial(norm_layer, eps=norm_eps)
  202. self.num_classes = num_classes
  203. self.in_chans = in_chans
  204. self.drop_rate = drop_rate
  205. # stem
  206. assert stem_type in ('patch', 'overlap', 'overlap_tiered')
  207. if stem_type == 'patch':
  208. # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
  209. self.stem = nn.Sequential(
  210. nn.Conv2d(in_chans, num_init_features, kernel_size=patch_size, stride=patch_size, bias=conv_bias, **dd),
  211. norm_layer(num_init_features, **dd),
  212. )
  213. stem_stride = patch_size
  214. else:
  215. mid_chs = make_divisible(num_init_features // 2) if 'tiered' in stem_type else num_init_features
  216. self.stem = nn.Sequential(
  217. nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias, **dd),
  218. nn.Conv2d(mid_chs, num_init_features, kernel_size=3, stride=2, padding=1, bias=conv_bias, **dd),
  219. norm_layer(num_init_features, **dd),
  220. )
  221. stem_stride = 4
  222. # features
  223. self.feature_info = []
  224. self.num_stages = len(growth_rates)
  225. curr_stride = stem_stride
  226. num_features = num_init_features
  227. dp_rates = calculate_drop_path_rates(drop_path_rate, num_blocks_list, stagewise=True)
  228. dense_stages = []
  229. for i in range(self.num_stages):
  230. dense_stage_layers = []
  231. if i != 0:
  232. compressed_num_features = int(num_features * transition_compression_ratio / 8) * 8
  233. k_size = stride = 1
  234. if is_downsample_block[i]:
  235. curr_stride *= 2
  236. k_size = stride = 2
  237. dense_stage_layers.append(norm_layer(num_features, **dd))
  238. dense_stage_layers.append(nn.Conv2d(
  239. num_features,
  240. compressed_num_features,
  241. kernel_size=k_size,
  242. stride=stride,
  243. padding=0,
  244. **dd,
  245. ))
  246. num_features = compressed_num_features
  247. stage = DenseStage(
  248. num_block=num_blocks_list[i],
  249. num_input_features=num_features,
  250. growth_rate=growth_rates[i],
  251. bottleneck_width_ratio=bottleneck_width_ratio,
  252. drop_rate=drop_rate,
  253. drop_path_rates=dp_rates[i],
  254. ls_init_value=ls_init_value,
  255. block_type=block_type[i],
  256. norm_layer=norm_layer,
  257. act_layer=act_layer,
  258. **dd,
  259. )
  260. dense_stage_layers.append(stage)
  261. num_features += num_blocks_list[i] * growth_rates[i]
  262. if i + 1 == self.num_stages or (i + 1 != self.num_stages and is_downsample_block[i + 1]):
  263. self.feature_info += [
  264. dict(
  265. num_chs=num_features,
  266. reduction=curr_stride,
  267. module=f'dense_stages.{i}',
  268. growth_rate=growth_rates[i],
  269. )
  270. ]
  271. dense_stages.append(nn.Sequential(*dense_stage_layers))
  272. self.dense_stages = nn.Sequential(*dense_stages)
  273. self.num_features = self.head_hidden_size = num_features
  274. # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
  275. # otherwise pool -> norm -> fc, the default RDNet ordering (pretrained NV weights)
  276. if head_norm_first:
  277. self.norm_pre = norm_layer(self.num_features, **dd)
  278. self.head = ClassifierHead(
  279. self.num_features,
  280. num_classes,
  281. pool_type=global_pool,
  282. drop_rate=self.drop_rate,
  283. **dd,
  284. )
  285. else:
  286. self.norm_pre = nn.Identity()
  287. self.head = NormMlpClassifierHead(
  288. self.num_features,
  289. num_classes,
  290. pool_type=global_pool,
  291. drop_rate=self.drop_rate,
  292. norm_layer=norm_layer,
  293. **dd,
  294. )
  295. named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
  296. @torch.jit.ignore
  297. def group_matcher(self, coarse=False):
  298. assert not coarse, "coarse grouping is not implemented for RDNet"
  299. return dict(
  300. stem=r'^stem',
  301. blocks=r'^dense_stages\.(\d+)',
  302. )
  303. @torch.jit.ignore
  304. def set_grad_checkpointing(self, enable=True):
  305. for s in self.dense_stages:
  306. s.grad_checkpointing = enable
  307. @torch.jit.ignore
  308. def get_classifier(self) -> nn.Module:
  309. return self.head.fc
  310. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  311. self.num_classes = num_classes
  312. self.head.reset(num_classes, global_pool)
  313. def forward_intermediates(
  314. self,
  315. x: torch.Tensor,
  316. indices: Optional[Union[int, List[int]]] = None,
  317. norm: bool = False,
  318. stop_early: bool = False,
  319. output_fmt: str = 'NCHW',
  320. intermediates_only: bool = False,
  321. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  322. """ Forward features that returns intermediates.
  323. Args:
  324. x: Input image tensor
  325. indices: Take last n blocks if int, all if None, select matching indices if sequence
  326. norm: Apply norm layer to compatible intermediates
  327. stop_early: Stop iterating over blocks when last desired intermediate hit
  328. output_fmt: Shape of intermediate feature outputs
  329. intermediates_only: Only return intermediate features
  330. """
  331. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  332. intermediates = []
  333. stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
  334. take_indices, max_index = feature_take_indices(len(stage_ends), indices)
  335. take_indices = [stage_ends[i] for i in take_indices]
  336. max_index = stage_ends[max_index]
  337. # forward pass
  338. x = self.stem(x)
  339. last_idx = len(self.dense_stages) - 1
  340. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  341. dense_stages = self.dense_stages
  342. else:
  343. dense_stages = self.dense_stages[:max_index + 1]
  344. for feat_idx, stage in enumerate(dense_stages):
  345. x = stage(x)
  346. if feat_idx in take_indices:
  347. if norm and feat_idx == last_idx:
  348. x_inter = self.norm_pre(x) # applying final norm to last intermediate
  349. else:
  350. x_inter = x
  351. intermediates.append(x_inter)
  352. if intermediates_only:
  353. return intermediates
  354. if feat_idx == last_idx:
  355. x = self.norm_pre(x)
  356. return x, intermediates
  357. def prune_intermediate_layers(
  358. self,
  359. indices: Union[int, List[int]] = 1,
  360. prune_norm: bool = False,
  361. prune_head: bool = True,
  362. ):
  363. """ Prune layers not required for specified intermediates.
  364. """
  365. stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
  366. take_indices, max_index = feature_take_indices(len(stage_ends), indices)
  367. max_index = stage_ends[max_index]
  368. self.dense_stages = self.dense_stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  369. if prune_norm:
  370. self.norm_pre = nn.Identity()
  371. if prune_head:
  372. self.reset_classifier(0, '')
  373. return take_indices
  374. def forward_features(self, x):
  375. x = self.stem(x)
  376. x = self.dense_stages(x)
  377. x = self.norm_pre(x)
  378. return x
  379. def forward_head(self, x, pre_logits: bool = False):
  380. return self.head(x, pre_logits=True) if pre_logits else self.head(x)
  381. def forward(self, x):
  382. x = self.forward_features(x)
  383. x = self.forward_head(x)
  384. return x
  385. def _init_weights(module, name=None, head_init_scale=1.0):
  386. if isinstance(module, nn.Conv2d):
  387. nn.init.kaiming_normal_(module.weight)
  388. elif isinstance(module, nn.BatchNorm2d):
  389. nn.init.constant_(module.weight, 1)
  390. nn.init.constant_(module.bias, 0)
  391. elif isinstance(module, nn.Linear):
  392. nn.init.constant_(module.bias, 0)
  393. if name and 'head.' in name:
  394. module.weight.data.mul_(head_init_scale)
  395. module.bias.data.mul_(head_init_scale)
  396. def checkpoint_filter_fn(state_dict, model):
  397. """ Remap NV checkpoints -> timm """
  398. if 'stem.0.weight' in state_dict:
  399. return state_dict # non-NV checkpoint
  400. if 'model' in state_dict:
  401. state_dict = state_dict['model']
  402. out_dict = {}
  403. for k, v in state_dict.items():
  404. k = k.replace('stem.stem.', 'stem.')
  405. out_dict[k] = v
  406. return out_dict
  407. def _create_rdnet(variant, pretrained=False, **kwargs):
  408. model = build_model_with_cfg(
  409. RDNet, variant, pretrained,
  410. pretrained_filter_fn=checkpoint_filter_fn,
  411. feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
  412. **kwargs)
  413. return model
  414. def _cfg(url='', **kwargs):
  415. return {
  416. "url": url,
  417. "num_classes": 1000, "input_size": (3, 224, 224), "pool_size": (7, 7),
  418. "crop_pct": 0.9, "interpolation": "bicubic",
  419. "mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD,
  420. "first_conv": "stem.0", "classifier": "head.fc",
  421. "paper_ids": "arXiv:2403.19588",
  422. "paper_name": "DenseNets Reloaded: Paradigm Shift Beyond ResNets and ViTs",
  423. "origin_url": "https://github.com/naver-ai/rdnet",
  424. "license": "apache-2.0",
  425. **kwargs,
  426. }
  427. default_cfgs = generate_default_cfgs({
  428. 'rdnet_tiny.nv_in1k': _cfg(
  429. hf_hub_id='naver-ai/rdnet_tiny.nv_in1k'),
  430. 'rdnet_small.nv_in1k': _cfg(
  431. hf_hub_id='naver-ai/rdnet_small.nv_in1k'),
  432. 'rdnet_base.nv_in1k': _cfg(
  433. hf_hub_id='naver-ai/rdnet_base.nv_in1k'),
  434. 'rdnet_large.nv_in1k': _cfg(
  435. hf_hub_id='naver-ai/rdnet_large.nv_in1k'),
  436. 'rdnet_large.nv_in1k_ft_in1k_384': _cfg(
  437. hf_hub_id='naver-ai/rdnet_large.nv_in1k_ft_in1k_384',
  438. input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
  439. })
  440. @register_model
  441. def rdnet_tiny(pretrained=False, **kwargs):
  442. n_layer = 7
  443. model_args = {
  444. "num_init_features": 64,
  445. "growth_rates": [64] + [104] + [128] * 4 + [224],
  446. "num_blocks_list": [3] * n_layer,
  447. "is_downsample_block": (None, True, True, False, False, False, True),
  448. "transition_compression_ratio": 0.5,
  449. "block_type": ["Block"] + ["Block"] + ["BlockESE"] * 4 + ["BlockESE"],
  450. }
  451. model = _create_rdnet("rdnet_tiny", pretrained=pretrained, **dict(model_args, **kwargs))
  452. return model
  453. @register_model
  454. def rdnet_small(pretrained=False, **kwargs):
  455. n_layer = 11
  456. model_args = {
  457. "num_init_features": 72,
  458. "growth_rates": [64] + [128] + [128] * (n_layer - 4) + [240] * 2,
  459. "num_blocks_list": [3] * n_layer,
  460. "is_downsample_block": (None, True, True, False, False, False, False, False, False, True, False),
  461. "transition_compression_ratio": 0.5,
  462. "block_type": ["Block"] + ["Block"] + ["BlockESE"] * (n_layer - 4) + ["BlockESE"] * 2,
  463. }
  464. model = _create_rdnet("rdnet_small", pretrained=pretrained, **dict(model_args, **kwargs))
  465. return model
  466. @register_model
  467. def rdnet_base(pretrained=False, **kwargs):
  468. n_layer = 11
  469. model_args = {
  470. "num_init_features": 120,
  471. "growth_rates": [96] + [128] + [168] * (n_layer - 4) + [336] * 2,
  472. "num_blocks_list": [3] * n_layer,
  473. "is_downsample_block": (None, True, True, False, False, False, False, False, False, True, False),
  474. "transition_compression_ratio": 0.5,
  475. "block_type": ["Block"] + ["Block"] + ["BlockESE"] * (n_layer - 4) + ["BlockESE"] * 2,
  476. }
  477. model = _create_rdnet("rdnet_base", pretrained=pretrained, **dict(model_args, **kwargs))
  478. return model
  479. @register_model
  480. def rdnet_large(pretrained=False, **kwargs):
  481. n_layer = 12
  482. model_args = {
  483. "num_init_features": 144,
  484. "growth_rates": [128] + [192] + [256] * (n_layer - 4) + [360] * 2,
  485. "num_blocks_list": [3] * n_layer,
  486. "is_downsample_block": (None, True, True, False, False, False, False, False, False, False, True, False),
  487. "transition_compression_ratio": 0.5,
  488. "block_type": ["Block"] + ["Block"] + ["BlockESE"] * (n_layer - 4) + ["BlockESE"] * 2,
  489. }
  490. model = _create_rdnet("rdnet_large", pretrained=pretrained, **dict(model_args, **kwargs))
  491. return model