vovnet.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559
  1. """ VoVNet (V1 & V2)
  2. Papers:
  3. * `An Energy and GPU-Computation Efficient Backbone Network` - https://arxiv.org/abs/1904.09730
  4. * `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
  5. Looked at https://github.com/youngwanLEE/vovnet-detectron2 &
  6. https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py
  7. for some reference, rewrote most of the code.
  8. Hacked together by / Copyright 2020 Ross Wightman
  9. """
  10. from typing import List, Optional, Tuple, Union, Type
  11. import torch
  12. import torch.nn as nn
  13. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  14. from timm.layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath, \
  15. create_attn, create_norm_act_layer, calculate_drop_path_rates
  16. from ._builder import build_model_with_cfg
  17. from ._features import feature_take_indices
  18. from ._manipulate import checkpoint_seq
  19. from ._registry import register_model, generate_default_cfgs
  # Names exported by ``from <module> import *``. Per the original note, the model
  # registry appends each @register_model entrypoint function name to this list.
  __all__ = ['VovNet']
  class SequentialAppendList(nn.Sequential):
      """``nn.Sequential`` variant that records every layer output for concatenation.

      Layers are chained: the first is applied to ``x``, each later one to the previous
      layer's output. Every output is appended to the caller-supplied ``concat_list`` (which
      is mutated in place), and the whole list is concatenated along dim 1 and returned.
      """

      def __init__(self, *args, **kwargs):
          # Keyword arguments are accepted for call-site compatibility but not used.
          super().__init__(*args)

      def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor:
          """Run the chained layers, appending each output to ``concat_list``.

          Args:
              x: input to the first layer.
              concat_list: list that receives every layer output; anything already in it
                  (e.g. the block input) is included in the concatenation.

          Returns:
              ``torch.cat(concat_list, dim=1)`` after all layer outputs were appended.
          """
          prev = x
          for layer in self:
              prev = layer(prev)
              concat_list.append(prev)
          return torch.cat(concat_list, dim=1)
  class OsaBlock(nn.Module):
      """VoVNet OSA block: chained mid convs whose outputs are aggregated once.

      ``layer_per_block`` convs (``SeparableConvNormAct`` when ``depthwise``, otherwise 3x3
      ``ConvNormAct``) are run in sequence; the block input plus every mid-conv output are
      concatenated and fused to ``out_chs`` by ``conv_concat`` (a ``ConvNormAct`` using its
      default kernel size). An optional attention module, drop-path module and identity
      residual are then applied in that order.
      """

      def __init__(
              self,
              in_chs: int,
              mid_chs: int,
              out_chs: int,
              layer_per_block: int,
              residual: bool = False,
              depthwise: bool = False,
              attn: str = '',
              norm_layer: Type[nn.Module] = BatchNormAct2d,
              act_layer: Type[nn.Module] = nn.ReLU,
              drop_path: Optional[nn.Module] = None,
              device=None,
              dtype=None,
      ):
          """
          Args:
              in_chs: channels of the block input.
              mid_chs: channels produced by each mid conv.
              out_chs: channels produced by the aggregation conv.
              layer_per_block: number of chained mid convs.
              residual: add the block input to the output (requires in_chs == out_chs).
              depthwise: use separable mid convs; a 1x1 reduction conv is inserted first when
                  in_chs != mid_chs (not allowed together with ``residual``).
              attn: attention type for ``create_attn`` ('' disables attention).
              norm_layer: norm-act layer used by every conv.
              act_layer: activation used by every conv.
              drop_path: optional module applied to the output before the residual add.
              device: device for created parameters.
              dtype: dtype for created parameters.
          """
          factory = {'device': device, 'dtype': dtype}
          super().__init__()
          self.residual = residual
          self.depthwise = depthwise
          conv_kwargs = dict(norm_layer=norm_layer, act_layer=act_layer, **factory)

          # Separable convs keep channels at mid_chs, so the input must be projected first.
          if depthwise and in_chs != mid_chs:
              assert not residual
              self.conv_reduction = ConvNormAct(in_chs, mid_chs, 1, **conv_kwargs)
          else:
              self.conv_reduction = None

          layers = []
          for idx in range(layer_per_block):
              if depthwise:
                  layers.append(SeparableConvNormAct(mid_chs, mid_chs, **conv_kwargs))
              else:
                  # only the first dense conv sees the raw block input width
                  conv_in = in_chs if idx == 0 else mid_chs
                  layers.append(ConvNormAct(conv_in, mid_chs, 3, **conv_kwargs))
          self.conv_mid = SequentialAppendList(*layers)

          # feature aggregation over [input, mid_1, ..., mid_N]
          agg_chs = in_chs + layer_per_block * mid_chs
          self.conv_concat = ConvNormAct(agg_chs, out_chs, **conv_kwargs)
          self.attn = create_attn(attn, out_chs, **factory) if attn else None
          self.drop_path = drop_path

      def forward(self, x):
          shortcut = x
          # The un-reduced input is the first entry of the aggregation list.
          feats = [shortcut]
          if self.conv_reduction is not None:
              x = self.conv_reduction(x)
          x = self.conv_concat(self.conv_mid(x, feats))
          if self.attn is not None:
              x = self.attn(x)
          if self.drop_path is not None:
              x = self.drop_path(x)
          if self.residual:
              x = x + shortcut
          return x
  class OsaStage(nn.Module):
      """One VoVNet stage: optional max-pool downsample followed by ``OsaBlock``s.

      The first block maps ``in_chs -> out_chs`` and later blocks ``out_chs -> out_chs``.
      ``residual`` is applied to every block except the first, and ``attn`` only to the
      last block of the stage.
      """

      def __init__(
              self,
              in_chs: int,
              mid_chs: int,
              out_chs: int,
              block_per_stage: int,
              layer_per_block: int,
              downsample: bool = True,
              residual: bool = True,
              depthwise: bool = False,
              attn: str = 'ese',
              norm_layer: Type[nn.Module] = BatchNormAct2d,
              act_layer: Type[nn.Module] = nn.ReLU,
              drop_path_rates: Optional[List[float]] = None,
              device=None,
              dtype=None,
      ):
          """
          Args:
              in_chs: channels entering the stage.
              mid_chs: mid-conv channels inside each block.
              out_chs: channels produced by every block.
              block_per_stage: number of blocks.
              layer_per_block: mid convs per block.
              downsample: prepend a 3x3 / stride-2 (ceil_mode) max-pool.
              residual: enable identity shortcuts on blocks after the first.
              depthwise: use separable convs inside the blocks.
              attn: attention type for the last block ('' disables it).
              norm_layer: norm-act layer for all convs.
              act_layer: activation for all convs.
              drop_path_rates: optional per-block drop-path rates; a rate > 0 adds DropPath.
              device: device for created parameters.
              dtype: dtype for created parameters.
          """
          factory = {'device': device, 'dtype': dtype}
          super().__init__()
          self.grad_checkpointing = False
          self.pool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) if downsample else None

          final_idx = block_per_stage - 1
          blocks = []
          for idx in range(block_per_stage):
              rate = 0. if drop_path_rates is None else drop_path_rates[idx]
              blocks.append(OsaBlock(
                  in_chs if idx == 0 else out_chs,
                  mid_chs,
                  out_chs,
                  layer_per_block,
                  residual=residual and idx > 0,
                  depthwise=depthwise,
                  attn=attn if idx == final_idx else '',
                  norm_layer=norm_layer,
                  act_layer=act_layer,
                  drop_path=DropPath(rate) if rate > 0. else None,
                  **factory,
              ))
          self.blocks = nn.Sequential(*blocks)

      def forward(self, x):
          if self.pool is not None:
              x = self.pool(x)
          # Keep this exact `if` shape so TorchScript can statically drop the checkpoint branch.
          if self.grad_checkpointing and not torch.jit.is_scripting():
              return checkpoint_seq(self.blocks, x)
          return self.blocks(x)
  class VovNet(nn.Module):
      """VoVNet (V1 & V2) classification model / feature backbone.

      Built from a three-conv stem, four ``OsaStage`` modules and a ``ClassifierHead``.
      Architecture hyper-parameters come from a ``cfg`` dict (see ``model_cfgs``), optionally
      overridden by extra constructor kwargs.
      """

      def __init__(
              self,
              cfg: dict,
              in_chans: int = 3,
              num_classes: int = 1000,
              global_pool: str = 'avg',
              output_stride: int = 32,
              norm_layer: Type[nn.Module] = BatchNormAct2d,
              act_layer: Type[nn.Module] = nn.ReLU,
              drop_rate: float = 0.,
              drop_path_rate: float = 0.,
              device=None,
              dtype=None,
              **kwargs,
      ):
          """
          Args:
              cfg (dict): Model architecture configuration. Keys read: 'stem_chs', 'stage_conv_chs',
                  'stage_out_chs', 'block_per_stage', 'layer_per_block', 'residual', 'depthwise',
                  'attn', and optional 'stem_stride' (default 4; the stem/feature bookkeeping
                  below distinguishes 4 from 2).
              in_chans (int): Number of input channels (default: 3)
              num_classes (int): Number of classifier classes (default: 1000)
              global_pool (str): Global pooling type (default: 'avg')
              output_stride (int): Output stride of network; only 32 is currently supported.
              norm_layer (Union[str, nn.Module]): normalization layer
              act_layer (Union[str, nn.Module]): activation layer
              drop_rate (float): Dropout rate (default: 0.)
              drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
              device: Device for created parameters.
              dtype: Dtype for created parameters.
              kwargs (dict): Extra kwargs overlayed onto cfg
          """
          super().__init__()
          dd = {'device': device, 'dtype': dtype}
          self.num_classes = num_classes
          self.in_chans = in_chans
          self.drop_rate = drop_rate
          assert output_stride == 32  # FIXME support dilation

          cfg = dict(cfg, **kwargs)
          stem_stride = cfg.get("stem_stride", 4)
          stem_chs = cfg["stem_chs"]
          stage_conv_chs = cfg["stage_conv_chs"]
          stage_out_chs = cfg["stage_out_chs"]
          block_per_stage = cfg["block_per_stage"]
          layer_per_block = cfg["layer_per_block"]
          conv_kwargs = dict(norm_layer=norm_layer, act_layer=act_layer, **dd)

          # Stem module: stem.0 always has stride 2; stem.2 has stride stem_stride // 2
          # (2 when stem_stride == 4, 1 when stem_stride == 2).
          last_stem_stride = stem_stride // 2
          conv_type = SeparableConvNormAct if cfg["depthwise"] else ConvNormAct
          self.stem = nn.Sequential(*[
              ConvNormAct(in_chans, stem_chs[0], 3, stride=2, **conv_kwargs),
              conv_type(stem_chs[0], stem_chs[1], 3, stride=1, **conv_kwargs),
              conv_type(stem_chs[1], stem_chs[2], 3, stride=last_stem_stride, **conv_kwargs),
          ])
          # Deepest stem module whose output is still at reduction 2. The reported channel
          # count must use the same index: stem.2 outputs stem_chs[2] channels, not stem_chs[1].
          self._stem_feat_idx = 1 if stem_stride == 4 else 2
          self.feature_info = [dict(
              num_chs=stem_chs[self._stem_feat_idx],
              reduction=2,
              module=f'stem.{self._stem_feat_idx}',
          )]
          current_stride = stem_stride

          # OSA stages
          stage_dpr = calculate_drop_path_rates(drop_path_rate, block_per_stage, stagewise=True)
          in_ch_list = stem_chs[-1:] + stage_out_chs[:-1]
          stage_args = dict(residual=cfg["residual"], depthwise=cfg["depthwise"], attn=cfg["attn"], **conv_kwargs)
          stages = []
          for i in range(4):  # num_stages
              # first stage has no stride/downsample if stem_stride is 4
              downsample = stem_stride == 2 or i > 0
              stages += [OsaStage(
                  in_ch_list[i],
                  stage_conv_chs[i],
                  stage_out_chs[i],
                  block_per_stage[i],
                  layer_per_block,
                  downsample=downsample,
                  drop_path_rates=stage_dpr[i],
                  **stage_args,
              )]
              self.num_features = stage_out_chs[i]
              current_stride *= 2 if downsample else 1
              self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]
          self.stages = nn.Sequential(*stages)

          self.head_hidden_size = self.num_features
          self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate, **dd)

          for m in self.modules():
              if isinstance(m, nn.Conv2d):
                  nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
              elif isinstance(m, nn.Linear) and m.bias is not None:
                  # Only the bias is reset; bias-free Linear layers are left untouched.
                  nn.init.zeros_(m.bias)

      @torch.jit.ignore
      def group_matcher(self, coarse=False):
          """Return regexes grouping parameter names into the stem and stage / block groups."""
          return dict(
              stem=r'^stem',
              blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.blocks\.(\d+)',
          )

      @torch.jit.ignore
      def set_grad_checkpointing(self, enable=True):
          """Toggle gradient checkpointing on every stage."""
          for s in self.stages:
              s.grad_checkpointing = enable

      @torch.jit.ignore
      def get_classifier(self) -> nn.Module:
          """Return the final classifier module (``head.fc``)."""
          return self.head.fc

      def reset_classifier(self, num_classes, global_pool: Optional[str] = None):
          """Replace the classifier for ``num_classes`` (optionally changing the pooling type)."""
          self.num_classes = num_classes
          self.head.reset(num_classes, global_pool)

      def forward_intermediates(
              self,
              x: torch.Tensor,
              indices: Optional[Union[int, List[int]]] = None,
              norm: bool = False,
              stop_early: bool = False,
              output_fmt: str = 'NCHW',
              intermediates_only: bool = False,
      ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
          """ Forward features that returns intermediates.

          Feature index 0 is the reduction-2 stem output, taken at the same stem module as
          ``feature_info[0]``; indices 1-4 are the outputs of ``stages.0`` .. ``stages.3``.

          Args:
              x: Input image tensor
              indices: Take last n blocks if int, all if None, select matching indices if sequence
              norm: Apply norm layer to compatible intermediates (unused; this model has no final norm)
              stop_early: Stop iterating over blocks when last desired intermediate hit
              output_fmt: Shape of intermediate feature outputs (only 'NCHW' supported)
              intermediates_only: Only return intermediate features
          Returns:
              The list of selected intermediates if ``intermediates_only``; otherwise a tuple of
              (last computed feature map, list of intermediates).
          """
          assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
          intermediates = []
          take_indices, max_index = feature_take_indices(5, indices)

          # forward pass: stem, capturing feature 0 at the module listed in feature_info[0]
          for stem_idx, stem_layer in enumerate(self.stem):
              x = stem_layer(x)
              if stem_idx == self._stem_feat_idx and 0 in take_indices:
                  intermediates.append(x)

          if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
              stages = self.stages
          else:
              stages = self.stages[:max_index]
          for feat_idx, stage in enumerate(stages, start=1):
              x = stage(x)
              if feat_idx in take_indices:
                  intermediates.append(x)

          if intermediates_only:
              return intermediates

          return x, intermediates

      def prune_intermediate_layers(
              self,
              indices: Union[int, List[int]] = 1,
              prune_norm: bool = False,
              prune_head: bool = True,
      ):
          """ Prune layers not required for specified intermediates.

          Truncates ``self.stages`` after the last requested feature index and, if
          ``prune_head``, resets the classifier to zero classes. ``prune_norm`` is unused.
          Returns the resolved feature indices.
          """
          take_indices, max_index = feature_take_indices(5, indices)
          self.stages = self.stages[:max_index]  # truncate blocks w/ stem as idx 0
          if prune_head:
              self.reset_classifier(0, '')
          return take_indices

      def forward_features(self, x):
          """Run the stem and all stages; returns the final feature map."""
          x = self.stem(x)
          return self.stages(x)

      def forward_head(self, x, pre_logits: bool = False):
          """Apply the classifier head (pooled pre-logit features when ``pre_logits``)."""
          return self.head(x, pre_logits=pre_logits)

      def forward(self, x):
          x = self.forward_features(x)
          x = self.forward_head(x)
          return x
  # Model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 &
  # https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py
  #
  # Each entry is the ``cfg`` handed to ``VovNet``:
  #   stem_chs        - output channels of the three stem convs
  #   stage_conv_chs  - mid-conv channels of the OSA blocks in each of the 4 stages
  #   stage_out_chs   - output channels of each stage
  #   layer_per_block - number of mid convs aggregated per OSA block
  #   block_per_stage - number of OSA blocks in each stage
  #   residual        - identity shortcut on every block after the first in a stage
  #   depthwise       - separable convs in the stem and OSA blocks
  #   attn            - attention type for the last block of each stage ('' = none)
  model_cfgs = {
      'vovnet39a': {
          'stem_chs': [64, 64, 128],
          'stage_conv_chs': [128, 160, 192, 224],
          'stage_out_chs': [256, 512, 768, 1024],
          'layer_per_block': 5,
          'block_per_stage': [1, 1, 2, 2],
          'residual': False,
          'depthwise': False,
          'attn': '',
      },
      'vovnet57a': {
          'stem_chs': [64, 64, 128],
          'stage_conv_chs': [128, 160, 192, 224],
          'stage_out_chs': [256, 512, 768, 1024],
          'layer_per_block': 5,
          'block_per_stage': [1, 1, 4, 3],
          'residual': False,
          'depthwise': False,
          'attn': '',
      },
      'ese_vovnet19b_slim_dw': {
          'stem_chs': [64, 64, 64],
          'stage_conv_chs': [64, 80, 96, 112],
          'stage_out_chs': [112, 256, 384, 512],
          'layer_per_block': 3,
          'block_per_stage': [1, 1, 1, 1],
          'residual': True,
          'depthwise': True,
          'attn': 'ese',
      },
      'ese_vovnet19b_dw': {
          'stem_chs': [64, 64, 64],
          'stage_conv_chs': [128, 160, 192, 224],
          'stage_out_chs': [256, 512, 768, 1024],
          'layer_per_block': 3,
          'block_per_stage': [1, 1, 1, 1],
          'residual': True,
          'depthwise': True,
          'attn': 'ese',
      },
      'ese_vovnet19b_slim': {
          'stem_chs': [64, 64, 128],
          'stage_conv_chs': [64, 80, 96, 112],
          'stage_out_chs': [112, 256, 384, 512],
          'layer_per_block': 3,
          'block_per_stage': [1, 1, 1, 1],
          'residual': True,
          'depthwise': False,
          'attn': 'ese',
      },
      'ese_vovnet19b': {
          'stem_chs': [64, 64, 128],
          'stage_conv_chs': [128, 160, 192, 224],
          'stage_out_chs': [256, 512, 768, 1024],
          'layer_per_block': 3,
          'block_per_stage': [1, 1, 1, 1],
          'residual': True,
          'depthwise': False,
          'attn': 'ese',
      },
      'ese_vovnet39b': {
          'stem_chs': [64, 64, 128],
          'stage_conv_chs': [128, 160, 192, 224],
          'stage_out_chs': [256, 512, 768, 1024],
          'layer_per_block': 5,
          'block_per_stage': [1, 1, 2, 2],
          'residual': True,
          'depthwise': False,
          'attn': 'ese',
      },
      'ese_vovnet57b': {
          'stem_chs': [64, 64, 128],
          'stage_conv_chs': [128, 160, 192, 224],
          'stage_out_chs': [256, 512, 768, 1024],
          'layer_per_block': 5,
          'block_per_stage': [1, 1, 4, 3],
          'residual': True,
          'depthwise': False,
          'attn': 'ese',
      },
      'ese_vovnet99b': {
          'stem_chs': [64, 64, 128],
          'stage_conv_chs': [128, 160, 192, 224],
          'stage_out_chs': [256, 512, 768, 1024],
          'layer_per_block': 5,
          'block_per_stage': [1, 3, 9, 3],
          'residual': True,
          'depthwise': False,
          'attn': 'ese',
      },
      'eca_vovnet39b': {
          'stem_chs': [64, 64, 128],
          'stage_conv_chs': [128, 160, 192, 224],
          'stage_out_chs': [256, 512, 768, 1024],
          'layer_per_block': 5,
          'block_per_stage': [1, 1, 2, 2],
          'residual': True,
          'depthwise': False,
          'attn': 'eca',
      },
  }
  # The EvoNorm experiment reuses the ese_vovnet39b architecture (same dict object).
  model_cfgs['ese_vovnet39b_evos'] = model_cfgs['ese_vovnet39b']
  def _create_vovnet(variant, pretrained=False, **kwargs):
      """Build a ``VovNet`` for ``variant`` through timm's ``build_model_with_cfg``.

      Args:
          variant: model name; also the key used to look up the architecture in ``model_cfgs``.
          pretrained: forwarded to ``build_model_with_cfg``.
          **kwargs: forwarded to ``build_model_with_cfg``.

      Returns:
          Whatever ``build_model_with_cfg`` returns for the ``VovNet`` class.
      """
      arch_cfg = model_cfgs[variant]
      feature_cfg = dict(flatten_sequential=True)
      return build_model_with_cfg(
          VovNet,
          variant,
          pretrained,
          model_cfg=arch_cfg,
          feature_cfg=feature_cfg,
          **kwargs,
      )
  def _cfg(url='', **kwargs):
      """Return a pretrained-cfg dict holding the defaults shared by all VovNet variants.

      Args:
          url: weight URL ('' when there is none).
          **kwargs: extra fields, or overrides of the defaults below.

      Returns:
          A new dict of pretrained-cfg fields.
      """
      cfg = dict(
          url=url,
          num_classes=1000,
          input_size=(3, 224, 224),
          pool_size=(7, 7),
          crop_pct=0.875,
          interpolation='bicubic',
          mean=IMAGENET_DEFAULT_MEAN,
          std=IMAGENET_DEFAULT_STD,
          first_conv='stem.0.conv',
          classifier='head.fc',
          license='apache-2.0',
      )
      # kwargs win over the defaults; new keys are appended in call order
      cfg.update(kwargs)
      return cfg
  # Pretrained-cfg registry keyed by '<model>.<tag>'. Entries with ``hf_hub_id`` reference
  # hub-hosted weights; '.untrained' entries carry no weight location (empty url, no hub id).
  default_cfgs = generate_default_cfgs({
      'vovnet39a.untrained': _cfg(),
      'vovnet57a.untrained': _cfg(),
      'ese_vovnet19b_slim_dw.untrained': _cfg(),
      'ese_vovnet19b_dw.ra_in1k': _cfg(
          hf_hub_id='timm/',
          test_input_size=(3, 288, 288),
          test_crop_pct=0.95,
      ),
      'ese_vovnet19b_slim.untrained': _cfg(),
      'ese_vovnet39b.ra_in1k': _cfg(
          hf_hub_id='timm/',
          test_input_size=(3, 288, 288),
          test_crop_pct=0.95,
      ),
      'ese_vovnet57b.ra4_e3600_r256_in1k': _cfg(
          hf_hub_id='timm/',
          mean=(0.5, 0.5, 0.5),
          std=(0.5, 0.5, 0.5),
          crop_pct=0.95,
          input_size=(3, 256, 256),
          pool_size=(8, 8),
          test_input_size=(3, 320, 320),
          test_crop_pct=1.0,
      ),
      'ese_vovnet99b.untrained': _cfg(),
      'eca_vovnet39b.untrained': _cfg(),
      'ese_vovnet39b_evos.untrained': _cfg(),
  })
  @register_model
  def vovnet39a(pretrained=False, **kwargs) -> VovNet:
      """VoVNet-39 (V1): blocks per stage [1, 1, 2, 2], no residual, no attention."""
      model = _create_vovnet('vovnet39a', pretrained=pretrained, **kwargs)
      return model
  @register_model
  def vovnet57a(pretrained=False, **kwargs) -> VovNet:
      """VoVNet-57 (V1): blocks per stage [1, 1, 4, 3], no residual, no attention."""
      model = _create_vovnet('vovnet57a', pretrained=pretrained, **kwargs)
      return model
  @register_model
  def ese_vovnet19b_slim_dw(pretrained=False, **kwargs) -> VovNet:
      """ESE-VoVNet-19b, slim widths, depthwise-separable convs, residual, ESE attention."""
      model = _create_vovnet('ese_vovnet19b_slim_dw', pretrained=pretrained, **kwargs)
      return model
  @register_model
  def ese_vovnet19b_dw(pretrained=False, **kwargs) -> VovNet:
      """ESE-VoVNet-19b with depthwise-separable convs, residual, ESE attention."""
      model = _create_vovnet('ese_vovnet19b_dw', pretrained=pretrained, **kwargs)
      return model
  @register_model
  def ese_vovnet19b_slim(pretrained=False, **kwargs) -> VovNet:
      """ESE-VoVNet-19b with slim stage widths, regular convs, residual, ESE attention."""
      model = _create_vovnet('ese_vovnet19b_slim', pretrained=pretrained, **kwargs)
      return model
  @register_model
  def ese_vovnet39b(pretrained=False, **kwargs) -> VovNet:
      """ESE-VoVNet-39b: blocks per stage [1, 1, 2, 2], residual, ESE attention."""
      model = _create_vovnet('ese_vovnet39b', pretrained=pretrained, **kwargs)
      return model
  @register_model
  def ese_vovnet57b(pretrained=False, **kwargs) -> VovNet:
      """ESE-VoVNet-57b: blocks per stage [1, 1, 4, 3], residual, ESE attention."""
      model = _create_vovnet('ese_vovnet57b', pretrained=pretrained, **kwargs)
      return model
  @register_model
  def ese_vovnet99b(pretrained=False, **kwargs) -> VovNet:
      """ESE-VoVNet-99b: blocks per stage [1, 3, 9, 3], residual, ESE attention."""
      model = _create_vovnet('ese_vovnet99b', pretrained=pretrained, **kwargs)
      return model
  @register_model
  def eca_vovnet39b(pretrained=False, **kwargs) -> VovNet:
      """VoVNet-39b with ECA attention instead of ESE: blocks per stage [1, 1, 2, 2], residual."""
      model = _create_vovnet('eca_vovnet39b', pretrained=pretrained, **kwargs)
      return model
  # Experimental Models
  @register_model
  def ese_vovnet39b_evos(pretrained=False, **kwargs) -> VovNet:
      """ESE-VoVNet-39b architecture using the 'evonorms0' norm-act layer for every conv."""
      def _evonorm_s0(num_features, **norm_kwargs):
          # norm_layer factory: builds an 'evonorms0' norm-act layer with jit disabled
          return create_norm_act_layer('evonorms0', num_features, jit=False, **norm_kwargs)

      return _create_vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=_evonorm_s0, **kwargs)