hgnet.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852
""" PP-HGNet (V1 & V2)

Reference:
https://github.com/PaddlePaddle/PaddleClas/blob/develop/docs/zh_CN/models/ImageNet1k/PP-HGNetV2.md
The Paddle implementation of PP-HGNet (https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/docs/en/models/PP-HGNet_en.md)
PP-HGNet: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet.py
PP-HGNetv2: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet_v2.py
"""
  8. from typing import Dict, List, Optional, Tuple, Type, Union
  9. import torch
  10. import torch.nn as nn
  11. import torch.nn.functional as F
  12. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  13. from timm.layers import SelectAdaptivePool2d, DropPath, calculate_drop_path_rates, create_conv2d
  14. from ._builder import build_model_with_cfg
  15. from ._features import feature_take_indices
  16. from ._registry import register_model, generate_default_cfgs
  17. from ._manipulate import checkpoint_seq
  18. __all__ = ['HighPerfGpuNet']
  19. class LearnableAffineBlock(nn.Module):
  20. def __init__(
  21. self,
  22. scale_value: float = 1.0,
  23. bias_value: float = 0.0,
  24. device=None,
  25. dtype=None,
  26. ):
  27. dd = {'device': device, 'dtype': dtype}
  28. super().__init__()
  29. self.scale = nn.Parameter(torch.tensor([scale_value], **dd), requires_grad=True)
  30. self.bias = nn.Parameter(torch.tensor([bias_value], **dd), requires_grad=True)
  31. def forward(self, x):
  32. return self.scale * x + self.bias
  33. class ConvBNAct(nn.Module):
  34. def __init__(
  35. self,
  36. in_chs: int,
  37. out_chs: int,
  38. kernel_size: int,
  39. stride: int = 1,
  40. groups: int = 1,
  41. padding: str = '',
  42. use_act: bool = True,
  43. use_lab: bool = False,
  44. device=None,
  45. dtype=None,
  46. ):
  47. dd = {'device': device, 'dtype': dtype}
  48. super().__init__()
  49. self.use_act = use_act
  50. self.use_lab = use_lab
  51. self.conv = create_conv2d(
  52. in_chs,
  53. out_chs,
  54. kernel_size,
  55. stride=stride,
  56. padding=padding,
  57. groups=groups,
  58. **dd,
  59. )
  60. self.bn = nn.BatchNorm2d(out_chs, **dd)
  61. if self.use_act:
  62. self.act = nn.ReLU()
  63. else:
  64. self.act = nn.Identity()
  65. if self.use_act and self.use_lab:
  66. self.lab = LearnableAffineBlock(**dd)
  67. else:
  68. self.lab = nn.Identity()
  69. def forward(self, x):
  70. x = self.conv(x)
  71. x = self.bn(x)
  72. x = self.act(x)
  73. x = self.lab(x)
  74. return x
  75. class LightConvBNAct(nn.Module):
  76. def __init__(
  77. self,
  78. in_chs: int,
  79. out_chs: int,
  80. kernel_size: int,
  81. groups: int = 1,
  82. use_lab: bool = False,
  83. device=None,
  84. dtype=None,
  85. ):
  86. dd = {'device': device, 'dtype': dtype}
  87. super().__init__()
  88. self.conv1 = ConvBNAct(
  89. in_chs,
  90. out_chs,
  91. kernel_size=1,
  92. use_act=False,
  93. use_lab=use_lab,
  94. **dd,
  95. )
  96. self.conv2 = ConvBNAct(
  97. out_chs,
  98. out_chs,
  99. kernel_size=kernel_size,
  100. groups=out_chs,
  101. use_act=True,
  102. use_lab=use_lab,
  103. **dd,
  104. )
  105. def forward(self, x):
  106. x = self.conv1(x)
  107. x = self.conv2(x)
  108. return x
  109. class EseModule(nn.Module):
  110. def __init__(self, chs: int, device=None, dtype=None):
  111. dd = {'device': device, 'dtype': dtype}
  112. super().__init__()
  113. self.conv = nn.Conv2d(
  114. chs,
  115. chs,
  116. kernel_size=1,
  117. stride=1,
  118. padding=0,
  119. **dd,
  120. )
  121. self.sigmoid = nn.Sigmoid()
  122. def forward(self, x):
  123. identity = x
  124. x = x.mean((2, 3), keepdim=True)
  125. x = self.conv(x)
  126. x = self.sigmoid(x)
  127. return torch.mul(identity, x)
  128. class StemV1(nn.Module):
  129. # for PP-HGNet
  130. def __init__(self, stem_chs: List[int], device=None, dtype=None):
  131. dd = {'device': device, 'dtype': dtype}
  132. super().__init__()
  133. self.stem = nn.Sequential(*[
  134. ConvBNAct(
  135. stem_chs[i],
  136. stem_chs[i + 1],
  137. kernel_size=3,
  138. stride=2 if i == 0 else 1,
  139. **dd) for i in range(
  140. len(stem_chs) - 1)
  141. ])
  142. self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  143. def forward(self, x):
  144. x = self.stem(x)
  145. x = self.pool(x)
  146. return x
  147. class StemV2(nn.Module):
  148. # for PP-HGNetv2
  149. def __init__(
  150. self,
  151. in_chs: int,
  152. mid_chs: int,
  153. out_chs: int,
  154. use_lab: bool = False,
  155. device=None,
  156. dtype=None,
  157. ):
  158. dd = {'device': device, 'dtype': dtype}
  159. super().__init__()
  160. self.stem1 = ConvBNAct(
  161. in_chs,
  162. mid_chs,
  163. kernel_size=3,
  164. stride=2,
  165. use_lab=use_lab,
  166. **dd,
  167. )
  168. self.stem2a = ConvBNAct(
  169. mid_chs,
  170. mid_chs // 2,
  171. kernel_size=2,
  172. stride=1,
  173. use_lab=use_lab,
  174. **dd,
  175. )
  176. self.stem2b = ConvBNAct(
  177. mid_chs // 2,
  178. mid_chs,
  179. kernel_size=2,
  180. stride=1,
  181. use_lab=use_lab,
  182. **dd,
  183. )
  184. self.stem3 = ConvBNAct(
  185. mid_chs * 2,
  186. mid_chs,
  187. kernel_size=3,
  188. stride=2,
  189. use_lab=use_lab,
  190. **dd,
  191. )
  192. self.stem4 = ConvBNAct(
  193. mid_chs,
  194. out_chs,
  195. kernel_size=1,
  196. stride=1,
  197. use_lab=use_lab,
  198. **dd,
  199. )
  200. self.pool = nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True)
  201. def forward(self, x):
  202. x = self.stem1(x)
  203. x = F.pad(x, (0, 1, 0, 1))
  204. x2 = self.stem2a(x)
  205. x2 = F.pad(x2, (0, 1, 0, 1))
  206. x2 = self.stem2b(x2)
  207. x1 = self.pool(x)
  208. x = torch.cat([x1, x2], dim=1)
  209. x = self.stem3(x)
  210. x = self.stem4(x)
  211. return x
  212. class HighPerfGpuBlock(nn.Module):
  213. def __init__(
  214. self,
  215. in_chs: int,
  216. mid_chs: int,
  217. out_chs: int,
  218. layer_num: int,
  219. kernel_size: int = 3,
  220. residual: bool = False,
  221. light_block: bool = False,
  222. use_lab: bool = False,
  223. agg: str = 'ese',
  224. drop_path: Union[List[float], float] = 0.,
  225. device=None,
  226. dtype=None,
  227. ):
  228. dd = {'device': device, 'dtype': dtype}
  229. super().__init__()
  230. self.residual = residual
  231. self.layers = nn.ModuleList()
  232. for i in range(layer_num):
  233. if light_block:
  234. self.layers.append(
  235. LightConvBNAct(
  236. in_chs if i == 0 else mid_chs,
  237. mid_chs,
  238. kernel_size=kernel_size,
  239. use_lab=use_lab,
  240. **dd,
  241. )
  242. )
  243. else:
  244. self.layers.append(
  245. ConvBNAct(
  246. in_chs if i == 0 else mid_chs,
  247. mid_chs,
  248. kernel_size=kernel_size,
  249. stride=1,
  250. use_lab=use_lab,
  251. **dd,
  252. )
  253. )
  254. # feature aggregation
  255. total_chs = in_chs + layer_num * mid_chs
  256. if agg == 'se':
  257. aggregation_squeeze_conv = ConvBNAct(
  258. total_chs,
  259. out_chs // 2,
  260. kernel_size=1,
  261. stride=1,
  262. use_lab=use_lab,
  263. **dd,
  264. )
  265. aggregation_excitation_conv = ConvBNAct(
  266. out_chs // 2,
  267. out_chs,
  268. kernel_size=1,
  269. stride=1,
  270. use_lab=use_lab,
  271. **dd,
  272. )
  273. self.aggregation = nn.Sequential(
  274. aggregation_squeeze_conv,
  275. aggregation_excitation_conv,
  276. )
  277. else:
  278. aggregation_conv = ConvBNAct(
  279. total_chs,
  280. out_chs,
  281. kernel_size=1,
  282. stride=1,
  283. use_lab=use_lab,
  284. **dd,
  285. )
  286. att = EseModule(out_chs, **dd)
  287. self.aggregation = nn.Sequential(
  288. aggregation_conv,
  289. att,
  290. )
  291. self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
  292. def forward(self, x):
  293. identity = x
  294. output = [x]
  295. for layer in self.layers:
  296. x = layer(x)
  297. output.append(x)
  298. x = torch.cat(output, dim=1)
  299. x = self.aggregation(x)
  300. if self.residual:
  301. x = self.drop_path(x) + identity
  302. return x
  303. class HighPerfGpuStage(nn.Module):
  304. def __init__(
  305. self,
  306. in_chs: int,
  307. mid_chs: int,
  308. out_chs: int,
  309. block_num: int,
  310. layer_num: int,
  311. downsample: bool = True,
  312. stride: int = 2,
  313. light_block: bool = False,
  314. kernel_size: int = 3,
  315. use_lab: bool = False,
  316. agg: str = 'ese',
  317. drop_path: Union[List[float], float] = 0.,
  318. device=None,
  319. dtype=None,
  320. ):
  321. dd = {'device': device, 'dtype': dtype}
  322. super().__init__()
  323. self.downsample = downsample
  324. if downsample:
  325. self.downsample = ConvBNAct(
  326. in_chs,
  327. in_chs,
  328. kernel_size=3,
  329. stride=stride,
  330. groups=in_chs,
  331. use_act=False,
  332. use_lab=use_lab,
  333. **dd,
  334. )
  335. else:
  336. self.downsample = nn.Identity()
  337. blocks_list = []
  338. for i in range(block_num):
  339. blocks_list.append(
  340. HighPerfGpuBlock(
  341. in_chs if i == 0 else out_chs,
  342. mid_chs,
  343. out_chs,
  344. layer_num,
  345. residual=False if i == 0 else True,
  346. kernel_size=kernel_size,
  347. light_block=light_block,
  348. use_lab=use_lab,
  349. agg=agg,
  350. drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path,
  351. **dd,
  352. )
  353. )
  354. self.blocks = nn.Sequential(*blocks_list)
  355. self.grad_checkpointing= False
  356. def forward(self, x):
  357. x = self.downsample(x)
  358. if self.grad_checkpointing and not torch.jit.is_scripting():
  359. x = checkpoint_seq(self.blocks, x)
  360. else:
  361. x = self.blocks(x)
  362. return x
  363. class ClassifierHead(nn.Module):
  364. def __init__(
  365. self,
  366. in_features: int,
  367. num_classes: int,
  368. pool_type: str = 'avg',
  369. drop_rate: float = 0.,
  370. hidden_size: Optional[int] = 2048,
  371. use_lab: bool = False,
  372. device=None,
  373. dtype=None,
  374. ):
  375. dd = {'device': device, 'dtype': dtype}
  376. super().__init__()
  377. self.num_features = in_features
  378. if pool_type is not None:
  379. if not pool_type:
  380. assert num_classes == 0, 'Classifier head must be removed if pooling is disabled'
  381. self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
  382. if hidden_size is not None:
  383. self.num_features = hidden_size
  384. last_conv = nn.Conv2d(
  385. in_features,
  386. hidden_size,
  387. kernel_size=1,
  388. stride=1,
  389. padding=0,
  390. bias=False,
  391. **dd,
  392. )
  393. act = nn.ReLU()
  394. if use_lab:
  395. lab = LearnableAffineBlock(**dd)
  396. self.last_conv = nn.Sequential(last_conv, act, lab)
  397. else:
  398. self.last_conv = nn.Sequential(last_conv, act)
  399. else:
  400. self.last_conv = nn.Identity()
  401. self.dropout = nn.Dropout(drop_rate)
  402. self.flatten = nn.Flatten(1) if pool_type else nn.Identity() # don't flatten if pooling disabled
  403. self.fc = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()
  404. def reset(self, num_classes: int, pool_type: Optional[str] = None, device=None, dtype=None):
  405. dd = {'device': device, 'dtype': dtype}
  406. if pool_type is not None:
  407. if not pool_type:
  408. assert num_classes == 0, 'Classifier head must be removed if pooling is disabled'
  409. self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
  410. self.flatten = nn.Flatten(1) if pool_type else nn.Identity() # don't flatten if pooling disabled
  411. self.fc = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()
  412. def forward(self, x, pre_logits: bool = False):
  413. x = self.global_pool(x)
  414. x = self.last_conv(x)
  415. x = self.dropout(x)
  416. x = self.flatten(x)
  417. if pre_logits:
  418. return x
  419. x = self.fc(x)
  420. return x
  421. class HighPerfGpuNet(nn.Module):
  422. def __init__(
  423. self,
  424. cfg: Dict,
  425. in_chans: int = 3,
  426. num_classes: int = 1000,
  427. global_pool: str = 'avg',
  428. head_hidden_size: Optional[int] = 2048,
  429. drop_rate: float = 0.,
  430. drop_path_rate: float = 0.,
  431. use_lab: bool = False,
  432. device=None,
  433. dtype=None,
  434. **kwargs,
  435. ):
  436. super().__init__()
  437. dd = {'device': device, 'dtype': dtype}
  438. stem_type = cfg["stem_type"]
  439. stem_chs = cfg["stem_chs"]
  440. stages_cfg = [cfg["stage1"], cfg["stage2"], cfg["stage3"], cfg["stage4"]]
  441. self.num_classes = num_classes
  442. self.in_chans = in_chans
  443. self.drop_rate = drop_rate
  444. self.use_lab = use_lab
  445. assert stem_type in ['v1', 'v2']
  446. if stem_type == 'v2':
  447. self.stem = StemV2(
  448. in_chs=in_chans,
  449. mid_chs=stem_chs[0],
  450. out_chs=stem_chs[1],
  451. use_lab=use_lab,
  452. **dd,
  453. )
  454. else:
  455. self.stem = StemV1([in_chans] + stem_chs, **dd)
  456. current_stride = 4
  457. stages = []
  458. self.feature_info = []
  459. block_depths = [c[3] for c in stages_cfg]
  460. dpr = calculate_drop_path_rates(drop_path_rate, block_depths, stagewise=True)
  461. for i, stage_config in enumerate(stages_cfg):
  462. in_chs, mid_chs, out_chs, block_num, downsample, light_block, kernel_size, layer_num = stage_config
  463. stages += [HighPerfGpuStage(
  464. in_chs=in_chs,
  465. mid_chs=mid_chs,
  466. out_chs=out_chs,
  467. block_num=block_num,
  468. layer_num=layer_num,
  469. downsample=downsample,
  470. light_block=light_block,
  471. kernel_size=kernel_size,
  472. use_lab=use_lab,
  473. agg='ese' if stem_type == 'v1' else 'se',
  474. drop_path=dpr[i],
  475. **dd,
  476. )]
  477. self.num_features = out_chs
  478. if downsample:
  479. current_stride *= 2
  480. self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]
  481. self.stages = nn.Sequential(*stages)
  482. self.head = ClassifierHead(
  483. self.num_features,
  484. num_classes=num_classes,
  485. pool_type=global_pool,
  486. drop_rate=drop_rate,
  487. hidden_size=head_hidden_size,
  488. use_lab=use_lab,
  489. **dd,
  490. )
  491. self.head_hidden_size = self.head.num_features
  492. for n, m in self.named_modules():
  493. if isinstance(m, nn.Conv2d):
  494. nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  495. elif isinstance(m, nn.BatchNorm2d):
  496. nn.init.ones_(m.weight)
  497. nn.init.zeros_(m.bias)
  498. elif isinstance(m, nn.Linear):
  499. nn.init.zeros_(m.bias)
  500. @torch.jit.ignore
  501. def group_matcher(self, coarse=False):
  502. return dict(
  503. stem=r'^stem',
  504. blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+).blocks\.(\d+)',
  505. )
  506. @torch.jit.ignore
  507. def set_grad_checkpointing(self, enable=True):
  508. for s in self.stages:
  509. s.grad_checkpointing = enable
  510. @torch.jit.ignore
  511. def get_classifier(self) -> nn.Module:
  512. return self.head.fc
  513. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, device=None, dtype=None):
  514. self.num_classes = num_classes
  515. self.head.reset(num_classes, global_pool, device=device, dtype=dtype)
  516. def forward_intermediates(
  517. self,
  518. x: torch.Tensor,
  519. indices: Optional[Union[int, List[int]]] = None,
  520. norm: bool = False,
  521. stop_early: bool = False,
  522. output_fmt: str = 'NCHW',
  523. intermediates_only: bool = False,
  524. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  525. """ Forward features that returns intermediates.
  526. Args:
  527. x: Input image tensor
  528. indices: Take last n blocks if int, all if None, select matching indices if sequence
  529. norm: Apply norm layer to compatible intermediates
  530. stop_early: Stop iterating over blocks when last desired intermediate hit
  531. output_fmt: Shape of intermediate feature outputs
  532. intermediates_only: Only return intermediate features
  533. Returns:
  534. """
  535. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  536. intermediates = []
  537. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  538. # forward pass
  539. x = self.stem(x)
  540. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  541. stages = self.stages
  542. else:
  543. stages = self.stages[:max_index + 1]
  544. for feat_idx, stage in enumerate(stages):
  545. x = stage(x)
  546. if feat_idx in take_indices:
  547. intermediates.append(x)
  548. if intermediates_only:
  549. return intermediates
  550. return x, intermediates
  551. def prune_intermediate_layers(
  552. self,
  553. indices: Union[int, List[int]] = 1,
  554. prune_norm: bool = False,
  555. prune_head: bool = True,
  556. ):
  557. """ Prune layers not required for specified intermediates.
  558. """
  559. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  560. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  561. if prune_head:
  562. self.reset_classifier(0, 'avg')
  563. return take_indices
  564. def forward_features(self, x):
  565. x = self.stem(x)
  566. return self.stages(x)
  567. def forward_head(self, x, pre_logits: bool = False):
  568. return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  569. def forward(self, x):
  570. x = self.forward_features(x)
  571. x = self.forward_head(x)
  572. return x
  573. model_cfgs = dict(
  574. # PP-HGNet
  575. hgnet_tiny={
  576. "stem_type": 'v1',
  577. "stem_chs": [48, 48, 96],
  578. # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
  579. "stage1": [96, 96, 224, 1, False, False, 3, 5],
  580. "stage2": [224, 128, 448, 1, True, False, 3, 5],
  581. "stage3": [448, 160, 512, 2, True, False, 3, 5],
  582. "stage4": [512, 192, 768, 1, True, False, 3, 5],
  583. },
  584. hgnet_small={
  585. "stem_type": 'v1',
  586. "stem_chs": [64, 64, 128],
  587. # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
  588. "stage1": [128, 128, 256, 1, False, False, 3, 6],
  589. "stage2": [256, 160, 512, 1, True, False, 3, 6],
  590. "stage3": [512, 192, 768, 2, True, False, 3, 6],
  591. "stage4": [768, 224, 1024, 1, True, False, 3, 6],
  592. },
  593. hgnet_base={
  594. "stem_type": 'v1',
  595. "stem_chs": [96, 96, 160],
  596. # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
  597. "stage1": [160, 192, 320, 1, False, False, 3, 7],
  598. "stage2": [320, 224, 640, 2, True, False, 3, 7],
  599. "stage3": [640, 256, 960, 3, True, False, 3, 7],
  600. "stage4": [960, 288, 1280, 2, True, False, 3, 7],
  601. },
  602. # PP-HGNetv2
  603. hgnetv2_b0={
  604. "stem_type": 'v2',
  605. "stem_chs": [16, 16],
  606. # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
  607. "stage1": [16, 16, 64, 1, False, False, 3, 3],
  608. "stage2": [64, 32, 256, 1, True, False, 3, 3],
  609. "stage3": [256, 64, 512, 2, True, True, 5, 3],
  610. "stage4": [512, 128, 1024, 1, True, True, 5, 3],
  611. },
  612. hgnetv2_b1={
  613. "stem_type": 'v2',
  614. "stem_chs": [24, 32],
  615. # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
  616. "stage1": [32, 32, 64, 1, False, False, 3, 3],
  617. "stage2": [64, 48, 256, 1, True, False, 3, 3],
  618. "stage3": [256, 96, 512, 2, True, True, 5, 3],
  619. "stage4": [512, 192, 1024, 1, True, True, 5, 3],
  620. },
  621. hgnetv2_b2={
  622. "stem_type": 'v2',
  623. "stem_chs": [24, 32],
  624. # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
  625. "stage1": [32, 32, 96, 1, False, False, 3, 4],
  626. "stage2": [96, 64, 384, 1, True, False, 3, 4],
  627. "stage3": [384, 128, 768, 3, True, True, 5, 4],
  628. "stage4": [768, 256, 1536, 1, True, True, 5, 4],
  629. },
  630. hgnetv2_b3={
  631. "stem_type": 'v2',
  632. "stem_chs": [24, 32],
  633. # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
  634. "stage1": [32, 32, 128, 1, False, False, 3, 5],
  635. "stage2": [128, 64, 512, 1, True, False, 3, 5],
  636. "stage3": [512, 128, 1024, 3, True, True, 5, 5],
  637. "stage4": [1024, 256, 2048, 1, True, True, 5, 5],
  638. },
  639. hgnetv2_b4={
  640. "stem_type": 'v2',
  641. "stem_chs": [32, 48],
  642. # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
  643. "stage1": [48, 48, 128, 1, False, False, 3, 6],
  644. "stage2": [128, 96, 512, 1, True, False, 3, 6],
  645. "stage3": [512, 192, 1024, 3, True, True, 5, 6],
  646. "stage4": [1024, 384, 2048, 1, True, True, 5, 6],
  647. },
  648. hgnetv2_b5={
  649. "stem_type": 'v2',
  650. "stem_chs": [32, 64],
  651. # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
  652. "stage1": [64, 64, 128, 1, False, False, 3, 6],
  653. "stage2": [128, 128, 512, 2, True, False, 3, 6],
  654. "stage3": [512, 256, 1024, 5, True, True, 5, 6],
  655. "stage4": [1024, 512, 2048, 2, True, True, 5, 6],
  656. },
  657. hgnetv2_b6={
  658. "stem_type": 'v2',
  659. "stem_chs": [48, 96],
  660. # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
  661. "stage1": [96, 96, 192, 2, False, False, 3, 6],
  662. "stage2": [192, 192, 512, 3, True, False, 3, 6],
  663. "stage3": [512, 384, 1024, 6, True, True, 5, 6],
  664. "stage4": [1024, 768, 2048, 3, True, True, 5, 6],
  665. },
  666. )
  667. def _create_hgnet(variant, pretrained=False, **kwargs):
  668. out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
  669. return build_model_with_cfg(
  670. HighPerfGpuNet,
  671. variant,
  672. pretrained,
  673. model_cfg=model_cfgs[variant],
  674. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  675. **kwargs,
  676. )
  677. def _cfg(url='', **kwargs):
  678. return {
  679. 'url': url,
  680. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  681. 'crop_pct': 0.965, 'interpolation': 'bicubic',
  682. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  683. 'classifier': 'head.fc', 'first_conv': 'stem.stem1.conv',
  684. 'test_crop_pct': 1.0, 'test_input_size': (3, 288, 288),
  685. 'license': 'apache-2.0',
  686. **kwargs,
  687. }
  688. default_cfgs = generate_default_cfgs({
  689. 'hgnet_tiny.paddle_in1k': _cfg(
  690. first_conv='stem.stem.0.conv',
  691. hf_hub_id='timm/'),
  692. 'hgnet_tiny.ssld_in1k': _cfg(
  693. first_conv='stem.stem.0.conv',
  694. hf_hub_id='timm/'),
  695. 'hgnet_small.paddle_in1k': _cfg(
  696. first_conv='stem.stem.0.conv',
  697. hf_hub_id='timm/'),
  698. 'hgnet_small.ssld_in1k': _cfg(
  699. first_conv='stem.stem.0.conv',
  700. hf_hub_id='timm/'),
  701. 'hgnet_base.ssld_in1k': _cfg(
  702. first_conv='stem.stem.0.conv',
  703. hf_hub_id='timm/'),
  704. 'hgnetv2_b0.ssld_stage2_ft_in1k': _cfg(
  705. hf_hub_id='timm/'),
  706. 'hgnetv2_b0.ssld_stage1_in22k_in1k': _cfg(
  707. hf_hub_id='timm/'),
  708. 'hgnetv2_b1.ssld_stage2_ft_in1k': _cfg(
  709. hf_hub_id='timm/'),
  710. 'hgnetv2_b1.ssld_stage1_in22k_in1k': _cfg(
  711. hf_hub_id='timm/'),
  712. 'hgnetv2_b2.ssld_stage2_ft_in1k': _cfg(
  713. hf_hub_id='timm/'),
  714. 'hgnetv2_b2.ssld_stage1_in22k_in1k': _cfg(
  715. hf_hub_id='timm/'),
  716. 'hgnetv2_b3.ssld_stage2_ft_in1k': _cfg(
  717. hf_hub_id='timm/'),
  718. 'hgnetv2_b3.ssld_stage1_in22k_in1k': _cfg(
  719. hf_hub_id='timm/'),
  720. 'hgnetv2_b4.ssld_stage2_ft_in1k': _cfg(
  721. hf_hub_id='timm/'),
  722. 'hgnetv2_b4.ssld_stage1_in22k_in1k': _cfg(
  723. hf_hub_id='timm/'),
  724. 'hgnetv2_b5.ssld_stage2_ft_in1k': _cfg(
  725. hf_hub_id='timm/'),
  726. 'hgnetv2_b5.ssld_stage1_in22k_in1k': _cfg(
  727. hf_hub_id='timm/'),
  728. 'hgnetv2_b6.ssld_stage2_ft_in1k': _cfg(
  729. hf_hub_id='timm/'),
  730. 'hgnetv2_b6.ssld_stage1_in22k_in1k': _cfg(
  731. hf_hub_id='timm/'),
  732. })
  733. @register_model
  734. def hgnet_tiny(pretrained=False, **kwargs) -> HighPerfGpuNet:
  735. return _create_hgnet('hgnet_tiny', pretrained=pretrained, **kwargs)
  736. @register_model
  737. def hgnet_small(pretrained=False, **kwargs) -> HighPerfGpuNet:
  738. return _create_hgnet('hgnet_small', pretrained=pretrained, **kwargs)
  739. @register_model
  740. def hgnet_base(pretrained=False, **kwargs) -> HighPerfGpuNet:
  741. return _create_hgnet('hgnet_base', pretrained=pretrained, **kwargs)
  742. @register_model
  743. def hgnetv2_b0(pretrained=False, **kwargs) -> HighPerfGpuNet:
  744. return _create_hgnet('hgnetv2_b0', pretrained=pretrained, use_lab=True, **kwargs)
  745. @register_model
  746. def hgnetv2_b1(pretrained=False, **kwargs) -> HighPerfGpuNet:
  747. return _create_hgnet('hgnetv2_b1', pretrained=pretrained, use_lab=True, **kwargs)
  748. @register_model
  749. def hgnetv2_b2(pretrained=False, **kwargs) -> HighPerfGpuNet:
  750. return _create_hgnet('hgnetv2_b2', pretrained=pretrained, use_lab=True, **kwargs)
  751. @register_model
  752. def hgnetv2_b3(pretrained=False, **kwargs) -> HighPerfGpuNet:
  753. return _create_hgnet('hgnetv2_b3', pretrained=pretrained, use_lab=True, **kwargs)
  754. @register_model
  755. def hgnetv2_b4(pretrained=False, **kwargs) -> HighPerfGpuNet:
  756. return _create_hgnet('hgnetv2_b4', pretrained=pretrained, **kwargs)
  757. @register_model
  758. def hgnetv2_b5(pretrained=False, **kwargs) -> HighPerfGpuNet:
  759. return _create_hgnet('hgnetv2_b5', pretrained=pretrained, **kwargs)
  760. @register_model
  761. def hgnetv2_b6(pretrained=False, **kwargs) -> HighPerfGpuNet:
  762. return _create_hgnet('hgnetv2_b6', pretrained=pretrained, **kwargs)