ghostnet.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016
  1. """
  2. An implementation of the GhostNet, GhostNetV2 & GhostNetV3 models as defined in:
  3. GhostNet: More Features from Cheap Operations. https://arxiv.org/abs/1911.11907
  4. GhostNetV2: Enhance Cheap Operation with Long-Range Attention. https://proceedings.neurips.cc/paper_files/paper/2022/file/40b60852a4abdaa696b5a1a78da34635-Paper-Conference.pdf
  5. GhostNetV3: Exploring the Training Strategies for Compact Models. https://arxiv.org/abs/2404.11202
  6. The train script & code of models at:
  7. Original model: https://github.com/huawei-noah/CV-backbones/tree/master/ghostnet_pytorch
  8. Original model: https://github.com/huawei-noah/Efficient-AI-Backbones/blob/master/ghostnetv2_pytorch/model/ghostnetv2_torch.py
  9. Original model: https://github.com/huawei-noah/Efficient-AI-Backbones/blob/master/ghostnetv3_pytorch/ghostnetv3.py
  10. """
  11. import math
  12. from functools import partial
  13. from typing import Any, Dict, List, Set, Optional, Tuple, Union, Type
  14. import torch
  15. import torch.nn as nn
  16. import torch.nn.functional as F
  17. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  18. from timm.layers import SelectAdaptivePool2d, Linear, make_divisible
  19. from timm.utils.model import reparameterize_model
  20. from ._builder import build_model_with_cfg
  21. from ._efficientnet_blocks import SqueezeExcite, ConvBnAct
  22. from ._features import feature_take_indices
  23. from ._manipulate import checkpoint_seq
  24. from ._registry import register_model, generate_default_cfgs
  25. __all__ = ['GhostNet']
  26. _SE_LAYER = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=partial(make_divisible, divisor=4))
  27. class GhostModule(nn.Module):
  28. def __init__(
  29. self,
  30. in_chs: int,
  31. out_chs: int,
  32. kernel_size: int = 1,
  33. ratio: int = 2,
  34. dw_size: int = 3,
  35. stride: int = 1,
  36. act_layer: Type[nn.Module] = nn.ReLU,
  37. device=None,
  38. dtype=None,
  39. ):
  40. dd = {'device': device, 'dtype': dtype}
  41. super().__init__()
  42. self.out_chs = out_chs
  43. init_chs = math.ceil(out_chs / ratio)
  44. new_chs = init_chs * (ratio - 1)
  45. self.primary_conv = nn.Sequential(
  46. nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False, **dd),
  47. nn.BatchNorm2d(init_chs, **dd),
  48. act_layer(inplace=True),
  49. )
  50. self.cheap_operation = nn.Sequential(
  51. nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size//2, groups=init_chs, bias=False, **dd),
  52. nn.BatchNorm2d(new_chs, **dd),
  53. act_layer(inplace=True),
  54. )
  55. def forward(self, x: torch.Tensor) -> torch.Tensor:
  56. x1 = self.primary_conv(x)
  57. x2 = self.cheap_operation(x1)
  58. out = torch.cat([x1, x2], dim=1)
  59. return out[:, :self.out_chs, :, :]
  60. class GhostModuleV2(nn.Module):
  61. def __init__(
  62. self,
  63. in_chs: int,
  64. out_chs: int,
  65. kernel_size: int = 1,
  66. ratio: int = 2,
  67. dw_size: int = 3,
  68. stride: int = 1,
  69. act_layer: Type[nn.Module] = nn.ReLU,
  70. device=None,
  71. dtype=None,
  72. ):
  73. dd = {'device': device, 'dtype': dtype}
  74. super().__init__()
  75. self.gate_fn = nn.Sigmoid()
  76. self.out_chs = out_chs
  77. init_chs = math.ceil(out_chs / ratio)
  78. new_chs = init_chs * (ratio - 1)
  79. self.primary_conv = nn.Sequential(
  80. nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False, **dd),
  81. nn.BatchNorm2d(init_chs, **dd),
  82. act_layer(inplace=True),
  83. )
  84. self.cheap_operation = nn.Sequential(
  85. nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size // 2, groups=init_chs, bias=False, **dd),
  86. nn.BatchNorm2d(new_chs, **dd),
  87. act_layer(inplace=True),
  88. )
  89. self.short_conv = nn.Sequential(
  90. nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size // 2, bias=False, **dd),
  91. nn.BatchNorm2d(out_chs, **dd),
  92. nn.Conv2d(out_chs, out_chs, kernel_size=(1, 5), stride=1, padding=(0, 2), groups=out_chs, bias=False, **dd),
  93. nn.BatchNorm2d(out_chs, **dd),
  94. nn.Conv2d(out_chs, out_chs, kernel_size=(5, 1), stride=1, padding=(2, 0), groups=out_chs, bias=False, **dd),
  95. nn.BatchNorm2d(out_chs, **dd),
  96. )
  97. def forward(self, x: torch.Tensor) -> torch.Tensor:
  98. res = self.short_conv(F.avg_pool2d(x, kernel_size=2, stride=2))
  99. x1 = self.primary_conv(x)
  100. x2 = self.cheap_operation(x1)
  101. out = torch.cat([x1, x2], dim=1)
  102. return out[:, :self.out_chs, :, :] * F.interpolate(
  103. self.gate_fn(res), size=(out.shape[-2], out.shape[-1]), mode='nearest')
  104. class GhostModuleV3(nn.Module):
  105. def __init__(
  106. self,
  107. in_chs: int,
  108. out_chs: int,
  109. kernel_size: int = 1,
  110. ratio: int = 2,
  111. dw_size: int = 3,
  112. stride: int = 1,
  113. act_layer: Type[nn.Module] = nn.ReLU,
  114. mode: str = 'original',
  115. device=None,
  116. dtype=None,
  117. ):
  118. dd = {'device': device, 'dtype': dtype}
  119. super().__init__()
  120. self.gate_fn = nn.Sigmoid()
  121. self.out_chs = out_chs
  122. init_chs = math.ceil(out_chs / ratio)
  123. new_chs = init_chs * (ratio - 1)
  124. self.mode = mode
  125. self.num_conv_branches = 3
  126. self.infer_mode = False
  127. if not self.infer_mode:
  128. self.primary_conv = nn.Identity()
  129. self.cheap_operation = nn.Identity()
  130. self.primary_rpr_skip = None
  131. self.primary_rpr_scale = None
  132. self.primary_rpr_conv = nn.ModuleList([
  133. ConvBnAct(
  134. in_chs,
  135. init_chs,
  136. kernel_size,
  137. stride,
  138. pad_type=kernel_size // 2,
  139. act_layer=None,
  140. **dd,
  141. ) for _ in range(self.num_conv_branches)
  142. ])
  143. # Re-parameterizable scale branch
  144. self.primary_activation = act_layer(inplace=True)
  145. self.cheap_rpr_skip = nn.BatchNorm2d(init_chs, **dd)
  146. self.cheap_rpr_conv = nn.ModuleList([
  147. ConvBnAct(
  148. init_chs,
  149. new_chs,
  150. dw_size,
  151. 1,
  152. pad_type=dw_size // 2,
  153. group_size=1,
  154. act_layer=None,
  155. **dd,
  156. ) for _ in range(self.num_conv_branches)
  157. ])
  158. # Re-parameterizable scale branch
  159. self.cheap_rpr_scale = ConvBnAct(init_chs, new_chs, 1, 1, pad_type=0, group_size=1, act_layer=None, **dd)
  160. self.cheap_activation = act_layer(inplace=True)
  161. self.short_conv = nn.Sequential(
  162. nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size//2, bias=False, **dd),
  163. nn.BatchNorm2d(out_chs, **dd),
  164. nn.Conv2d(out_chs, out_chs, kernel_size=(1,5), stride=1, padding=(0,2), groups=out_chs, bias=False, **dd),
  165. nn.BatchNorm2d(out_chs, **dd),
  166. nn.Conv2d(out_chs, out_chs, kernel_size=(5,1), stride=1, padding=(2,0), groups=out_chs, bias=False, **dd),
  167. nn.BatchNorm2d(out_chs, **dd),
  168. ) if self.mode in ['shortcut'] else nn.Identity()
  169. self.in_channels = init_chs
  170. self.groups = init_chs
  171. self.kernel_size = dw_size
  172. def forward(self, x):
  173. if self.infer_mode:
  174. x1 = self.primary_conv(x)
  175. x2 = self.cheap_operation(x1)
  176. else:
  177. x1 = 0
  178. for primary_rpr_conv in self.primary_rpr_conv:
  179. x1 += primary_rpr_conv(x)
  180. x1 = self.primary_activation(x1)
  181. x2 = self.cheap_rpr_scale(x1) + self.cheap_rpr_skip(x1)
  182. for cheap_rpr_conv in self.cheap_rpr_conv:
  183. x2 += cheap_rpr_conv(x1)
  184. x2 = self.cheap_activation(x2)
  185. out = torch.cat([x1,x2], dim=1)
  186. if self.mode not in ['shortcut']:
  187. return out
  188. else:
  189. res = self.short_conv(F.avg_pool2d(x, kernel_size=2, stride=2))
  190. return out[:,:self.out_chs,:,:] * F.interpolate(
  191. self.gate_fn(res), size=(out.shape[-2], out.shape[-1]), mode='nearest')
  192. def _get_kernel_bias_primary(self):
  193. kernel_scale = 0
  194. bias_scale = 0
  195. if self.primary_rpr_scale is not None:
  196. kernel_scale, bias_scale = self._fuse_bn_tensor(self.primary_rpr_scale)
  197. pad = self.kernel_size // 2
  198. kernel_scale = F.pad(kernel_scale, [pad, pad, pad, pad])
  199. kernel_identity = 0
  200. bias_identity = 0
  201. if self.primary_rpr_skip is not None:
  202. kernel_identity, bias_identity = self._fuse_bn_tensor(self.primary_rpr_skip)
  203. kernel_conv = 0
  204. bias_conv = 0
  205. for ix in range(self.num_conv_branches):
  206. _kernel, _bias = self._fuse_bn_tensor(self.primary_rpr_conv[ix])
  207. kernel_conv += _kernel
  208. bias_conv += _bias
  209. kernel_final = kernel_conv + kernel_scale + kernel_identity
  210. bias_final = bias_conv + bias_scale + bias_identity
  211. return kernel_final, bias_final
  212. def _get_kernel_bias_cheap(self):
  213. kernel_scale = 0
  214. bias_scale = 0
  215. if self.cheap_rpr_scale is not None:
  216. kernel_scale, bias_scale = self._fuse_bn_tensor(self.cheap_rpr_scale)
  217. pad = self.kernel_size // 2
  218. kernel_scale = F.pad(kernel_scale, [pad, pad, pad, pad])
  219. kernel_identity = 0
  220. bias_identity = 0
  221. if self.cheap_rpr_skip is not None:
  222. kernel_identity, bias_identity = self._fuse_bn_tensor(self.cheap_rpr_skip)
  223. kernel_conv = 0
  224. bias_conv = 0
  225. for ix in range(self.num_conv_branches):
  226. _kernel, _bias = self._fuse_bn_tensor(self.cheap_rpr_conv[ix])
  227. kernel_conv += _kernel
  228. bias_conv += _bias
  229. kernel_final = kernel_conv + kernel_scale + kernel_identity
  230. bias_final = bias_conv + bias_scale + bias_identity
  231. return kernel_final, bias_final
  232. def _fuse_bn_tensor(self, branch):
  233. if isinstance(branch, ConvBnAct):
  234. kernel = branch.conv.weight
  235. running_mean = branch.bn1.running_mean
  236. running_var = branch.bn1.running_var
  237. gamma = branch.bn1.weight
  238. beta = branch.bn1.bias
  239. eps = branch.bn1.eps
  240. else:
  241. assert isinstance(branch, nn.BatchNorm2d)
  242. if not hasattr(self, 'id_tensor'):
  243. input_dim = self.in_channels // self.groups
  244. kernel_value = torch.zeros(
  245. (self.in_channels, input_dim, self.kernel_size, self.kernel_size),
  246. dtype=branch.weight.dtype,
  247. device=branch.weight.device
  248. )
  249. for i in range(self.in_channels):
  250. kernel_value[i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2] = 1
  251. self.id_tensor = kernel_value
  252. kernel = self.id_tensor
  253. running_mean = branch.running_mean
  254. running_var = branch.running_var
  255. gamma = branch.weight
  256. beta = branch.bias
  257. eps = branch.eps
  258. std = (running_var + eps).sqrt()
  259. t = (gamma / std).reshape(-1, 1, 1, 1)
  260. return kernel * t, beta - running_mean * gamma / std
  261. def switch_to_deploy(self):
  262. if self.infer_mode:
  263. return
  264. primary_kernel, primary_bias = self._get_kernel_bias_primary()
  265. self.primary_conv = nn.Conv2d(
  266. in_channels=self.primary_rpr_conv[0].conv.in_channels,
  267. out_channels=self.primary_rpr_conv[0].conv.out_channels,
  268. kernel_size=self.primary_rpr_conv[0].conv.kernel_size,
  269. stride=self.primary_rpr_conv[0].conv.stride,
  270. padding=self.primary_rpr_conv[0].conv.padding,
  271. dilation=self.primary_rpr_conv[0].conv.dilation,
  272. groups=self.primary_rpr_conv[0].conv.groups,
  273. bias=True
  274. )
  275. self.primary_conv.weight.data = primary_kernel
  276. self.primary_conv.bias.data = primary_bias
  277. self.primary_conv = nn.Sequential(
  278. self.primary_conv,
  279. self.primary_activation if self.primary_activation is not None else nn.Sequential()
  280. )
  281. cheap_kernel, cheap_bias = self._get_kernel_bias_cheap()
  282. self.cheap_operation = nn.Conv2d(
  283. in_channels=self.cheap_rpr_conv[0].conv.in_channels,
  284. out_channels=self.cheap_rpr_conv[0].conv.out_channels,
  285. kernel_size=self.cheap_rpr_conv[0].conv.kernel_size,
  286. stride=self.cheap_rpr_conv[0].conv.stride,
  287. padding=self.cheap_rpr_conv[0].conv.padding,
  288. dilation=self.cheap_rpr_conv[0].conv.dilation,
  289. groups=self.cheap_rpr_conv[0].conv.groups,
  290. bias=True
  291. )
  292. self.cheap_operation.weight.data = cheap_kernel
  293. self.cheap_operation.bias.data = cheap_bias
  294. self.cheap_operation = nn.Sequential(
  295. self.cheap_operation,
  296. self.cheap_activation if self.cheap_activation is not None else nn.Sequential()
  297. )
  298. # Delete un-used branches
  299. for para in self.parameters():
  300. para.detach_()
  301. if hasattr(self, 'primary_rpr_conv'):
  302. self.__delattr__('primary_rpr_conv')
  303. if hasattr(self, 'primary_rpr_scale'):
  304. self.__delattr__('primary_rpr_scale')
  305. if hasattr(self, 'primary_rpr_skip'):
  306. self.__delattr__('primary_rpr_skip')
  307. if hasattr(self, 'cheap_rpr_conv'):
  308. self.__delattr__('cheap_rpr_conv')
  309. if hasattr(self, 'cheap_rpr_scale'):
  310. self.__delattr__('cheap_rpr_scale')
  311. if hasattr(self, 'cheap_rpr_skip'):
  312. self.__delattr__('cheap_rpr_skip')
  313. self.infer_mode = True
  314. def reparameterize(self):
  315. self.switch_to_deploy()
  316. class GhostBottleneck(nn.Module):
  317. """ GhostV1/V2 bottleneck w/ optional SE"""
  318. def __init__(
  319. self,
  320. in_chs: int,
  321. mid_chs: int,
  322. out_chs: int,
  323. dw_kernel_size: int = 3,
  324. stride: int = 1,
  325. act_layer: Type[nn.Module] = nn.ReLU,
  326. se_ratio: float = 0.,
  327. mode: str = 'original',
  328. device=None,
  329. dtype=None,
  330. ):
  331. dd = {'device': device, 'dtype': dtype}
  332. super().__init__()
  333. has_se = se_ratio is not None and se_ratio > 0.
  334. self.stride = stride
  335. # Point-wise expansion
  336. if mode == 'original':
  337. self.ghost1 = GhostModule(in_chs, mid_chs, act_layer=act_layer, **dd)
  338. else:
  339. self.ghost1 = GhostModuleV2(in_chs, mid_chs, act_layer=act_layer, **dd)
  340. # Depth-wise convolution
  341. if self.stride > 1:
  342. self.conv_dw = nn.Conv2d(
  343. mid_chs,
  344. mid_chs,
  345. dw_kernel_size,
  346. stride=stride,
  347. padding=(dw_kernel_size-1)//2,
  348. groups=mid_chs,
  349. bias=False,
  350. **dd,
  351. )
  352. self.bn_dw = nn.BatchNorm2d(mid_chs, **dd)
  353. else:
  354. self.conv_dw = None
  355. self.bn_dw = None
  356. # Squeeze-and-excitation
  357. self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio, **dd) if has_se else None
  358. # Point-wise linear projection
  359. self.ghost2 = GhostModule(mid_chs, out_chs, act_layer=nn.Identity, **dd)
  360. # shortcut
  361. if in_chs == out_chs and self.stride == 1:
  362. self.shortcut = nn.Sequential()
  363. else:
  364. self.shortcut = nn.Sequential(
  365. nn.Conv2d(
  366. in_chs,
  367. in_chs,
  368. dw_kernel_size,
  369. stride=stride,
  370. padding=(dw_kernel_size-1)//2,
  371. groups=in_chs,
  372. bias=False,
  373. **dd,
  374. ),
  375. nn.BatchNorm2d(in_chs, **dd),
  376. nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False, **dd),
  377. nn.BatchNorm2d(out_chs, **dd),
  378. )
  379. def forward(self, x: torch.Tensor) -> torch.Tensor:
  380. shortcut = x
  381. # 1st ghost bottleneck
  382. x = self.ghost1(x)
  383. # Depth-wise convolution
  384. if self.conv_dw is not None:
  385. x = self.conv_dw(x)
  386. x = self.bn_dw(x)
  387. # Squeeze-and-excitation
  388. if self.se is not None:
  389. x = self.se(x)
  390. # 2nd ghost bottleneck
  391. x = self.ghost2(x)
  392. x += self.shortcut(shortcut)
  393. return x
  394. class GhostBottleneckV3(nn.Module):
  395. """ GhostV3 bottleneck w/ optional SE"""
  396. def __init__(
  397. self,
  398. in_chs: int,
  399. mid_chs: int,
  400. out_chs: int,
  401. dw_kernel_size: int = 3,
  402. stride: int = 1,
  403. act_layer: Type[nn.Module] = nn.ReLU,
  404. se_ratio: float = 0.,
  405. mode: str = 'original',
  406. device=None,
  407. dtype=None,
  408. ):
  409. dd = {'device': device, 'dtype': dtype}
  410. super().__init__()
  411. has_se = se_ratio is not None and se_ratio > 0.
  412. self.stride = stride
  413. self.num_conv_branches = 3
  414. self.infer_mode = False
  415. if not self.infer_mode:
  416. self.conv_dw = nn.Identity()
  417. self.bn_dw = nn.Identity()
  418. # Point-wise expansion
  419. self.ghost1 = GhostModuleV3(in_chs, mid_chs, act_layer=act_layer, mode=mode, **dd)
  420. # Depth-wise convolution
  421. if self.stride > 1:
  422. self.dw_rpr_conv = nn.ModuleList([ConvBnAct(
  423. mid_chs,
  424. mid_chs,
  425. dw_kernel_size,
  426. stride,
  427. pad_type=(dw_kernel_size - 1) // 2,
  428. group_size=1,
  429. act_layer=None,
  430. **dd,
  431. ) for _ in range(self.num_conv_branches)
  432. ])
  433. # Re-parameterizable scale branch
  434. self.dw_rpr_scale = ConvBnAct(mid_chs, mid_chs, 1, 2, pad_type=0, group_size=1, act_layer=None, **dd)
  435. self.kernel_size = dw_kernel_size
  436. self.in_channels = mid_chs
  437. else:
  438. self.dw_rpr_conv = nn.ModuleList()
  439. self.dw_rpr_scale = nn.Identity()
  440. self.dw_rpr_skip = None
  441. # Squeeze-and-excitation
  442. self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio, **dd) if has_se else nn.Identity()
  443. # Point-wise linear projection
  444. self.ghost2 = GhostModuleV3(mid_chs, out_chs, act_layer=nn.Identity, mode='original', **dd)
  445. # shortcut
  446. if in_chs == out_chs and self.stride == 1:
  447. self.shortcut = nn.Identity()
  448. else:
  449. self.shortcut = nn.Sequential(
  450. nn.Conv2d(
  451. in_chs,
  452. in_chs,
  453. dw_kernel_size,
  454. stride=stride,
  455. padding=(dw_kernel_size-1)//2,
  456. groups=in_chs,
  457. bias=False,
  458. **dd,
  459. ),
  460. nn.BatchNorm2d(in_chs, **dd),
  461. nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False, **dd),
  462. nn.BatchNorm2d(out_chs, **dd),
  463. )
  464. def forward(self, x: torch.Tensor) -> torch.Tensor:
  465. shortcut = x
  466. # 1st ghost bottleneck
  467. x = self.ghost1(x)
  468. # Depth-wise convolution
  469. if self.stride > 1:
  470. if self.infer_mode:
  471. x = self.conv_dw(x)
  472. x = self.bn_dw(x)
  473. else:
  474. x1 = self.dw_rpr_scale(x)
  475. for dw_rpr_conv in self.dw_rpr_conv:
  476. x1 += dw_rpr_conv(x)
  477. x = x1
  478. # Squeeze-and-excitation
  479. x = self.se(x)
  480. # 2nd ghost bottleneck
  481. x = self.ghost2(x)
  482. x += self.shortcut(shortcut)
  483. return x
  484. def _get_kernel_bias_dw(self):
  485. kernel_scale = 0
  486. bias_scale = 0
  487. if self.dw_rpr_scale is not None:
  488. kernel_scale, bias_scale = self._fuse_bn_tensor(self.dw_rpr_scale)
  489. pad = self.kernel_size // 2
  490. kernel_scale = F.pad(kernel_scale, [pad, pad, pad, pad])
  491. kernel_identity = 0
  492. bias_identity = 0
  493. if self.dw_rpr_skip is not None:
  494. kernel_identity, bias_identity = self._fuse_bn_tensor(self.dw_rpr_skip)
  495. kernel_conv = 0
  496. bias_conv = 0
  497. for ix in range(self.num_conv_branches):
  498. _kernel, _bias = self._fuse_bn_tensor(self.dw_rpr_conv[ix])
  499. kernel_conv += _kernel
  500. bias_conv += _bias
  501. kernel_final = kernel_conv + kernel_scale + kernel_identity
  502. bias_final = bias_conv + bias_scale + bias_identity
  503. return kernel_final, bias_final
  504. def _fuse_bn_tensor(self, branch):
  505. if isinstance(branch, ConvBnAct):
  506. kernel = branch.conv.weight
  507. running_mean = branch.bn1.running_mean
  508. running_var = branch.bn1.running_var
  509. gamma = branch.bn1.weight
  510. beta = branch.bn1.bias
  511. eps = branch.bn1.eps
  512. else:
  513. assert isinstance(branch, nn.BatchNorm2d)
  514. if not hasattr(self, 'id_tensor'):
  515. input_dim = self.in_channels // self.groups
  516. kernel_value = torch.zeros(
  517. (self.in_channels, input_dim, self.kernel_size, self.kernel_size),
  518. dtype=branch.weight.dtype,
  519. device=branch.weight.device
  520. )
  521. for i in range(self.in_channels):
  522. kernel_value[i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2] = 1
  523. self.id_tensor = kernel_value
  524. kernel = self.id_tensor
  525. running_mean = branch.running_mean
  526. running_var = branch.running_var
  527. gamma = branch.weight
  528. beta = branch.bias
  529. eps = branch.eps
  530. std = (running_var + eps).sqrt()
  531. t = (gamma / std).reshape(-1, 1, 1, 1)
  532. return kernel * t, beta - running_mean * gamma / std
  533. def switch_to_deploy(self):
  534. if self.infer_mode or self.stride == 1:
  535. return
  536. dw_kernel, dw_bias = self._get_kernel_bias_dw()
  537. self.conv_dw = nn.Conv2d(
  538. in_channels=self.dw_rpr_conv[0].conv.in_channels,
  539. out_channels=self.dw_rpr_conv[0].conv.out_channels,
  540. kernel_size=self.dw_rpr_conv[0].conv.kernel_size,
  541. stride=self.dw_rpr_conv[0].conv.stride,
  542. padding=self.dw_rpr_conv[0].conv.padding,
  543. dilation=self.dw_rpr_conv[0].conv.dilation,
  544. groups=self.dw_rpr_conv[0].conv.groups,
  545. bias=True
  546. )
  547. self.conv_dw.weight.data = dw_kernel
  548. self.conv_dw.bias.data = dw_bias
  549. self.bn_dw = nn.Identity()
  550. # Delete un-used branches
  551. for para in self.parameters():
  552. para.detach_()
  553. if hasattr(self, 'dw_rpr_conv'):
  554. self.__delattr__('dw_rpr_conv')
  555. if hasattr(self, 'dw_rpr_scale'):
  556. self.__delattr__('dw_rpr_scale')
  557. if hasattr(self, 'dw_rpr_skip'):
  558. self.__delattr__('dw_rpr_skip')
  559. self.infer_mode = True
  560. def reparameterize(self):
  561. self.switch_to_deploy()
  562. class GhostNet(nn.Module):
  563. def __init__(
  564. self,
  565. cfgs: List[List[List[Union[int, float]]]],
  566. num_classes: int = 1000,
  567. width: float = 1.0,
  568. in_chans: int = 3,
  569. output_stride: int = 32,
  570. global_pool: str = 'avg',
  571. drop_rate: float = 0.2,
  572. version: str = 'v1',
  573. device=None,
  574. dtype=None,
  575. ):
  576. super().__init__()
  577. dd = {'device': device, 'dtype': dtype}
  578. # setting of inverted residual blocks
  579. assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported'
  580. self.cfgs = cfgs
  581. self.num_classes = num_classes
  582. self.in_chans = in_chans
  583. self.drop_rate = drop_rate
  584. self.grad_checkpointing = False
  585. self.feature_info = []
  586. Bottleneck = GhostBottleneckV3 if version == 'v3' else GhostBottleneck
  587. # building first layer
  588. stem_chs = make_divisible(16 * width, 4)
  589. self.conv_stem = nn.Conv2d(in_chans, stem_chs, 3, 2, 1, bias=False, **dd)
  590. self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=f'conv_stem'))
  591. self.bn1 = nn.BatchNorm2d(stem_chs, **dd)
  592. self.act1 = nn.ReLU(inplace=True)
  593. prev_chs = stem_chs
  594. # building inverted residual blocks
  595. stages = nn.ModuleList([])
  596. stage_idx = 0
  597. layer_idx = 0
  598. net_stride = 2
  599. for cfg in self.cfgs:
  600. layers = []
  601. s = 1
  602. for k, exp_size, c, se_ratio, s in cfg:
  603. out_chs = make_divisible(c * width, 4)
  604. mid_chs = make_divisible(exp_size * width, 4)
  605. layer_kwargs = dict(**dd)
  606. if version == 'v2' and layer_idx > 1:
  607. layer_kwargs['mode'] = 'attn'
  608. if version == 'v3' and layer_idx > 1:
  609. layer_kwargs['mode'] = 'shortcut'
  610. layers.append(Bottleneck(prev_chs, mid_chs, out_chs, k, s, se_ratio=se_ratio, **layer_kwargs))
  611. prev_chs = out_chs
  612. layer_idx += 1
  613. if s > 1:
  614. net_stride *= 2
  615. self.feature_info.append(dict(
  616. num_chs=prev_chs, reduction=net_stride, module=f'blocks.{stage_idx}'))
  617. stages.append(nn.Sequential(*layers))
  618. stage_idx += 1
  619. out_chs = make_divisible(exp_size * width, 4)
  620. stages.append(nn.Sequential(ConvBnAct(prev_chs, out_chs, 1, **dd)))
  621. self.pool_dim = prev_chs = out_chs
  622. self.blocks = nn.Sequential(*stages)
  623. # building last several layers
  624. self.num_features = prev_chs
  625. self.head_hidden_size = out_chs = 1280
  626. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  627. self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True, **dd)
  628. self.act2 = nn.ReLU(inplace=True)
  629. self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
  630. self.classifier = Linear(out_chs, num_classes, **dd) if num_classes > 0 else nn.Identity()
  631. # FIXME init
  632. @torch.jit.ignore
  633. def no_weight_decay(self) -> Set:
  634. return set()
  635. @torch.jit.ignore
  636. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  637. matcher = dict(
  638. stem=r'^conv_stem|bn1',
  639. blocks=[
  640. (r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)', None),
  641. (r'conv_head', (99999,))
  642. ]
  643. )
  644. return matcher
  645. @torch.jit.ignore
  646. def set_grad_checkpointing(self, enable: bool = True):
  647. self.grad_checkpointing = enable
  648. @torch.jit.ignore
  649. def get_classifier(self) -> nn.Module:
  650. return self.classifier
  651. def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
  652. self.num_classes = num_classes
  653. # cannot meaningfully change pooling of efficient head after creation
  654. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  655. self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
  656. self.classifier = Linear(
  657. self.head_hidden_size, num_classes,
  658. device=self.conv_head.weight.device, dtype=self.conv_head.weight.dtype
  659. ) if num_classes > 0 else nn.Identity()
  660. def forward_intermediates(
  661. self,
  662. x: torch.Tensor,
  663. indices: Optional[Union[int, List[int]]] = None,
  664. norm: bool = False,
  665. stop_early: bool = False,
  666. output_fmt: str = 'NCHW',
  667. intermediates_only: bool = False,
  668. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  669. """ Forward features that returns intermediates.
  670. Args:
  671. x: Input image tensor
  672. indices: Take last n blocks if int, all if None, select matching indices if sequence
  673. norm: Apply norm layer to compatible intermediates
  674. stop_early: Stop iterating over blocks when last desired intermediate hit
  675. output_fmt: Shape of intermediate feature outputs
  676. intermediates_only: Only return intermediate features
  677. Returns:
  678. """
  679. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  680. intermediates = []
  681. stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]]
  682. take_indices, max_index = feature_take_indices(len(stage_ends), indices)
  683. take_indices = [stage_ends[i]+1 for i in take_indices]
  684. max_index = stage_ends[max_index]
  685. # forward pass
  686. feat_idx = 0
  687. x = self.conv_stem(x)
  688. if feat_idx in take_indices:
  689. intermediates.append(x)
  690. x = self.bn1(x)
  691. x = self.act1(x)
  692. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  693. stages = self.blocks
  694. else:
  695. stages = self.blocks[:max_index + 1]
  696. for feat_idx, stage in enumerate(stages, start=1):
  697. if self.grad_checkpointing and not torch.jit.is_scripting():
  698. x = checkpoint_seq(stage, x)
  699. else:
  700. x = stage(x)
  701. if feat_idx in take_indices:
  702. intermediates.append(x)
  703. if intermediates_only:
  704. return intermediates
  705. return x, intermediates
  706. def prune_intermediate_layers(
  707. self,
  708. indices: Union[int, List[int]] = 1,
  709. prune_norm: bool = False,
  710. prune_head: bool = True,
  711. ):
  712. """ Prune layers not required for specified intermediates.
  713. """
  714. stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]]
  715. take_indices, max_index = feature_take_indices(len(stage_ends), indices)
  716. max_index = stage_ends[max_index]
  717. self.blocks = self.blocks[:max_index + 1] # truncate blocks w/ stem as idx 0
  718. if prune_head:
  719. self.reset_classifier(0, '')
  720. return take_indices
  721. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  722. x = self.conv_stem(x)
  723. x = self.bn1(x)
  724. x = self.act1(x)
  725. if self.grad_checkpointing and not torch.jit.is_scripting():
  726. x = checkpoint_seq(self.blocks, x, flatten=True)
  727. else:
  728. x = self.blocks(x)
  729. return x
  730. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  731. x = self.global_pool(x)
  732. x = self.conv_head(x)
  733. x = self.act2(x)
  734. x = self.flatten(x)
  735. if self.drop_rate > 0.:
  736. x = F.dropout(x, p=self.drop_rate, training=self.training)
  737. return x if pre_logits else self.classifier(x)
  738. def forward(self, x: torch.Tensor) -> torch.Tensor:
  739. x = self.forward_features(x)
  740. x = self.forward_head(x)
  741. return x
  742. def convert_to_deploy(self):
  743. reparameterize_model(self, inplace=False)
  744. def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
  745. if 'state_dict' in state_dict:
  746. state_dict = state_dict['state_dict']
  747. out_dict = {}
  748. for k, v in state_dict.items():
  749. if 'bn.' in k and '.ghost' in k:
  750. k = k.replace('bn.', 'bn1.')
  751. if 'bn.' in k and '.dw_rpr_' in k:
  752. k = k.replace('bn.', 'bn1.')
  753. if 'total' in k:
  754. continue
  755. out_dict[k] = v
  756. return out_dict
  757. def _create_ghostnet(variant: str, width: float = 1.0, pretrained: bool = False, **kwargs: Any) -> GhostNet:
  758. """
  759. Constructs a GhostNet model
  760. """
  761. cfgs = [
  762. # k, t, c, SE, s
  763. # stage1
  764. [[3, 16, 16, 0, 1]],
  765. # stage2
  766. [[3, 48, 24, 0, 2]],
  767. [[3, 72, 24, 0, 1]],
  768. # stage3
  769. [[5, 72, 40, 0.25, 2]],
  770. [[5, 120, 40, 0.25, 1]],
  771. # stage4
  772. [[3, 240, 80, 0, 2]],
  773. [[3, 200, 80, 0, 1],
  774. [3, 184, 80, 0, 1],
  775. [3, 184, 80, 0, 1],
  776. [3, 480, 112, 0.25, 1],
  777. [3, 672, 112, 0.25, 1]
  778. ],
  779. # stage5
  780. [[5, 672, 160, 0.25, 2]],
  781. [[5, 960, 160, 0, 1],
  782. [5, 960, 160, 0.25, 1],
  783. [5, 960, 160, 0, 1],
  784. [5, 960, 160, 0.25, 1]
  785. ]
  786. ]
  787. model_kwargs = dict(
  788. cfgs=cfgs,
  789. width=width,
  790. **kwargs,
  791. )
  792. return build_model_with_cfg(
  793. GhostNet,
  794. variant,
  795. pretrained,
  796. pretrained_filter_fn=checkpoint_filter_fn,
  797. feature_cfg=dict(flatten_sequential=True),
  798. **model_kwargs,
  799. )
  800. def _cfg(url='', **kwargs):
  801. return {
  802. 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  803. 'crop_pct': 0.875, 'interpolation': 'bicubic',
  804. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  805. 'first_conv': 'conv_stem', 'classifier': 'classifier',
  806. 'license': 'apache-2.0',
  807. **kwargs
  808. }
# Pretrained-config registry, keyed as '<model name>.<tag>'.
# Entries tagged '.untrained' use the plain `_cfg()` defaults; entries tagged '.in1k'
# additionally set `hf_hub_id='timm/'`.
# The commented-out URLs are presumably the original release locations of the
# corresponding weights (TODO confirm); they are not used at runtime.
default_cfgs = generate_default_cfgs({
    # GhostNet (v1)
    'ghostnet_050.untrained': _cfg(),
    'ghostnet_100.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/huawei-noah/CV-backbones/releases/download/ghostnet_pth/ghostnet_1x.pth'
    ),
    'ghostnet_130.untrained': _cfg(),
    # GhostNetV2
    'ghostnetv2_100.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV2/ck_ghostnetv2_10.pth.tar'
    ),
    'ghostnetv2_130.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV2/ck_ghostnetv2_13.pth.tar'
    ),
    'ghostnetv2_160.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV2/ck_ghostnetv2_16.pth.tar'
    ),
    # GhostNetV3
    'ghostnetv3_050.untrained': _cfg(),
    'ghostnetv3_100.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV3/ghostnetv3-1.0.pth.tar'
    ),
    'ghostnetv3_130.untrained': _cfg(),
    'ghostnetv3_160.untrained': _cfg(),
})
  836. @register_model
  837. def ghostnet_050(pretrained=False, **kwargs) -> GhostNet:
  838. """ GhostNet-0.5x """
  839. model = _create_ghostnet('ghostnet_050', width=0.5, pretrained=pretrained, **kwargs)
  840. return model
  841. @register_model
  842. def ghostnet_100(pretrained=False, **kwargs) -> GhostNet:
  843. """ GhostNet-1.0x """
  844. model = _create_ghostnet('ghostnet_100', width=1.0, pretrained=pretrained, **kwargs)
  845. return model
  846. @register_model
  847. def ghostnet_130(pretrained=False, **kwargs) -> GhostNet:
  848. """ GhostNet-1.3x """
  849. model = _create_ghostnet('ghostnet_130', width=1.3, pretrained=pretrained, **kwargs)
  850. return model
  851. @register_model
  852. def ghostnetv2_100(pretrained=False, **kwargs) -> GhostNet:
  853. """ GhostNetV2-1.0x """
  854. model = _create_ghostnet('ghostnetv2_100', width=1.0, pretrained=pretrained, version='v2', **kwargs)
  855. return model
  856. @register_model
  857. def ghostnetv2_130(pretrained=False, **kwargs) -> GhostNet:
  858. """ GhostNetV2-1.3x """
  859. model = _create_ghostnet('ghostnetv2_130', width=1.3, pretrained=pretrained, version='v2', **kwargs)
  860. return model
  861. @register_model
  862. def ghostnetv2_160(pretrained=False, **kwargs) -> GhostNet:
  863. """ GhostNetV2-1.6x """
  864. model = _create_ghostnet('ghostnetv2_160', width=1.6, pretrained=pretrained, version='v2', **kwargs)
  865. return model
  866. @register_model
  867. def ghostnetv3_050(pretrained: bool = False, **kwargs: Any) -> GhostNet:
  868. """ GhostNetV3-0.5x """
  869. model = _create_ghostnet('ghostnetv3_050', width=0.5, pretrained=pretrained, version='v3', **kwargs)
  870. return model
  871. @register_model
  872. def ghostnetv3_100(pretrained: bool = False, **kwargs: Any) -> GhostNet:
  873. """ GhostNetV3-1.0x """
  874. model = _create_ghostnet('ghostnetv3_100', width=1.0, pretrained=pretrained, version='v3', **kwargs)
  875. return model
  876. @register_model
  877. def ghostnetv3_130(pretrained: bool = False, **kwargs: Any) -> GhostNet:
  878. """ GhostNetV3-1.3x """
  879. model = _create_ghostnet('ghostnetv3_130', width=1.3, pretrained=pretrained, version='v3', **kwargs)
  880. return model
  881. @register_model
  882. def ghostnetv3_160(pretrained: bool = False, **kwargs: Any) -> GhostNet:
  883. """ GhostNetV3-1.6x """
  884. model = _create_ghostnet('ghostnetv3_160', width=1.6, pretrained=pretrained, version='v3', **kwargs)
  885. return model