repghost.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584
  1. """
  2. An implementation of RepGhostNet Model as defined in:
  3. RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization. https://arxiv.org/abs/2211.06088
  4. Original implementation: https://github.com/ChengpengChen/RepGhost
  5. """
  6. import copy
  7. from functools import partial
  8. from typing import List, Optional, Tuple, Union, Type
  9. import torch
  10. import torch.nn as nn
  11. import torch.nn.functional as F
  12. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  13. from timm.layers import SelectAdaptivePool2d, Linear, make_divisible
  14. from ._builder import build_model_with_cfg
  15. from ._efficientnet_blocks import SqueezeExcite, ConvBnAct
  16. from ._features import feature_take_indices
  17. from ._manipulate import checkpoint_seq
  18. from ._registry import register_model, generate_default_cfgs
  19. __all__ = ['RepGhostNet']
  20. _SE_LAYER = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=partial(make_divisible, divisor=4))
  21. class RepGhostModule(nn.Module):
  22. def __init__(
  23. self,
  24. in_chs: int,
  25. out_chs: int,
  26. kernel_size: int = 1,
  27. dw_size: int = 3,
  28. stride: int = 1,
  29. relu: bool = True,
  30. reparam: bool = True,
  31. device=None,
  32. dtype=None,
  33. ):
  34. dd = {'device': device, 'dtype': dtype}
  35. super().__init__()
  36. self.out_chs = out_chs
  37. init_chs = out_chs
  38. new_chs = out_chs
  39. self.primary_conv = nn.Sequential(
  40. nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False, **dd),
  41. nn.BatchNorm2d(init_chs, **dd),
  42. nn.ReLU(inplace=True) if relu else nn.Identity(),
  43. )
  44. fusion_conv = []
  45. fusion_bn = []
  46. if reparam:
  47. fusion_conv.append(nn.Identity())
  48. fusion_bn.append(nn.BatchNorm2d(init_chs, **dd))
  49. self.fusion_conv = nn.Sequential(*fusion_conv)
  50. self.fusion_bn = nn.Sequential(*fusion_bn)
  51. self.cheap_operation = nn.Sequential(
  52. nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size//2, groups=init_chs, bias=False, **dd),
  53. nn.BatchNorm2d(new_chs, **dd),
  54. # nn.ReLU(inplace=True) if relu else nn.Identity(),
  55. )
  56. self.relu = nn.ReLU(inplace=False) if relu else nn.Identity()
  57. def forward(self, x):
  58. x1 = self.primary_conv(x)
  59. x2 = self.cheap_operation(x1)
  60. for conv, bn in zip(self.fusion_conv, self.fusion_bn):
  61. x2 = x2 + bn(conv(x1))
  62. return self.relu(x2)
  63. def get_equivalent_kernel_bias(self):
  64. kernel3x3, bias3x3 = self._fuse_bn_tensor(self.cheap_operation[0], self.cheap_operation[1])
  65. for conv, bn in zip(self.fusion_conv, self.fusion_bn):
  66. kernel, bias = self._fuse_bn_tensor(conv, bn, kernel3x3.shape[0], kernel3x3.device)
  67. kernel3x3 += self._pad_1x1_to_3x3_tensor(kernel)
  68. bias3x3 += bias
  69. return kernel3x3, bias3x3
  70. @staticmethod
  71. def _pad_1x1_to_3x3_tensor(kernel1x1):
  72. if kernel1x1 is None:
  73. return 0
  74. else:
  75. return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
  76. @staticmethod
  77. def _fuse_bn_tensor(conv, bn, in_channels=None, device=None):
  78. in_channels = in_channels if in_channels else bn.running_mean.shape[0]
  79. device = device if device else bn.weight.device
  80. if isinstance(conv, nn.Conv2d):
  81. kernel = conv.weight
  82. assert conv.bias is None
  83. else:
  84. assert isinstance(conv, nn.Identity)
  85. kernel = torch.ones(in_channels, 1, 1, 1, device=device)
  86. if isinstance(bn, nn.BatchNorm2d):
  87. running_mean = bn.running_mean
  88. running_var = bn.running_var
  89. gamma = bn.weight
  90. beta = bn.bias
  91. eps = bn.eps
  92. std = (running_var + eps).sqrt()
  93. t = (gamma / std).reshape(-1, 1, 1, 1)
  94. return kernel * t, beta - running_mean * gamma / std
  95. assert isinstance(bn, nn.Identity)
  96. return kernel, torch.zeros(in_channels).to(kernel.device)
  97. def switch_to_deploy(self):
  98. if len(self.fusion_conv) == 0 and len(self.fusion_bn) == 0:
  99. return
  100. kernel, bias = self.get_equivalent_kernel_bias()
  101. dd = {'device': kernel.device, 'dtype': kernel.dtype}
  102. self.cheap_operation = nn.Conv2d(
  103. in_channels=self.cheap_operation[0].in_channels,
  104. out_channels=self.cheap_operation[0].out_channels,
  105. kernel_size=self.cheap_operation[0].kernel_size,
  106. padding=self.cheap_operation[0].padding,
  107. dilation=self.cheap_operation[0].dilation,
  108. groups=self.cheap_operation[0].groups,
  109. bias=True,
  110. **dd)
  111. self.cheap_operation.weight.data = kernel
  112. self.cheap_operation.bias.data = bias
  113. self.__delattr__('fusion_conv')
  114. self.__delattr__('fusion_bn')
  115. self.fusion_conv = []
  116. self.fusion_bn = []
  117. def reparameterize(self):
  118. self.switch_to_deploy()
  119. class RepGhostBottleneck(nn.Module):
  120. """ RepGhost bottleneck w/ optional SE"""
  121. def __init__(
  122. self,
  123. in_chs: int,
  124. mid_chs: int,
  125. out_chs: int,
  126. dw_kernel_size: int = 3,
  127. stride: int = 1,
  128. act_layer: Type[nn.Module] = nn.ReLU,
  129. se_ratio: float = 0.,
  130. reparam: bool = True,
  131. device=None,
  132. dtype=None,
  133. ):
  134. dd = {'device': device, 'dtype': dtype}
  135. super().__init__()
  136. has_se = se_ratio is not None and se_ratio > 0.
  137. self.stride = stride
  138. # Point-wise expansion
  139. self.ghost1 = RepGhostModule(in_chs, mid_chs, relu=True, reparam=reparam, **dd)
  140. # Depth-wise convolution
  141. if self.stride > 1:
  142. self.conv_dw = nn.Conv2d(
  143. mid_chs,
  144. mid_chs,
  145. dw_kernel_size,
  146. stride=stride,
  147. padding=(dw_kernel_size-1)//2,
  148. groups=mid_chs,
  149. bias=False,
  150. **dd,
  151. )
  152. self.bn_dw = nn.BatchNorm2d(mid_chs, **dd)
  153. else:
  154. self.conv_dw = None
  155. self.bn_dw = None
  156. # Squeeze-and-excitation
  157. self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio, **dd) if has_se else None
  158. # Point-wise linear projection
  159. self.ghost2 = RepGhostModule(mid_chs, out_chs, relu=False, reparam=reparam, **dd)
  160. # shortcut
  161. if in_chs == out_chs and self.stride == 1:
  162. self.shortcut = nn.Sequential()
  163. else:
  164. self.shortcut = nn.Sequential(
  165. nn.Conv2d(
  166. in_chs,
  167. in_chs,
  168. dw_kernel_size,
  169. stride=stride,
  170. padding=(dw_kernel_size-1)//2,
  171. groups=in_chs,
  172. bias=False,
  173. **dd,
  174. ),
  175. nn.BatchNorm2d(in_chs, **dd),
  176. nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False, **dd),
  177. nn.BatchNorm2d(out_chs, **dd),
  178. )
  179. def forward(self, x):
  180. shortcut = x
  181. # 1st ghost bottleneck
  182. x = self.ghost1(x)
  183. # Depth-wise convolution
  184. if self.conv_dw is not None:
  185. x = self.conv_dw(x)
  186. x = self.bn_dw(x)
  187. # Squeeze-and-excitation
  188. if self.se is not None:
  189. x = self.se(x)
  190. # 2nd ghost bottleneck
  191. x = self.ghost2(x)
  192. x += self.shortcut(shortcut)
  193. return x
  194. class RepGhostNet(nn.Module):
  195. def __init__(
  196. self,
  197. cfgs: List[List[List]],
  198. num_classes: int = 1000,
  199. width: float = 1.0,
  200. in_chans: int = 3,
  201. output_stride: int = 32,
  202. global_pool: str = 'avg',
  203. drop_rate: float = 0.2,
  204. reparam: bool = True,
  205. device=None,
  206. dtype=None,
  207. ):
  208. super().__init__()
  209. dd = {'device': device, 'dtype': dtype}
  210. # setting of inverted residual blocks
  211. assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported'
  212. self.cfgs = cfgs
  213. self.num_classes = num_classes
  214. self.in_chans = in_chans
  215. self.drop_rate = drop_rate
  216. self.grad_checkpointing = False
  217. self.feature_info = []
  218. # building first layer
  219. stem_chs = make_divisible(16 * width, 4)
  220. self.conv_stem = nn.Conv2d(in_chans, stem_chs, 3, 2, 1, bias=False, **dd)
  221. self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=f'conv_stem'))
  222. self.bn1 = nn.BatchNorm2d(stem_chs, **dd)
  223. self.act1 = nn.ReLU(inplace=True)
  224. prev_chs = stem_chs
  225. # building inverted residual blocks
  226. stages = nn.ModuleList([])
  227. block = RepGhostBottleneck
  228. stage_idx = 0
  229. net_stride = 2
  230. for cfg in self.cfgs:
  231. layers = []
  232. s = 1
  233. for k, exp_size, c, se_ratio, s in cfg:
  234. out_chs = make_divisible(c * width, 4)
  235. mid_chs = make_divisible(exp_size * width, 4)
  236. layers.append(block(prev_chs, mid_chs, out_chs, k, s, se_ratio=se_ratio, reparam=reparam, **dd))
  237. prev_chs = out_chs
  238. if s > 1:
  239. net_stride *= 2
  240. self.feature_info.append(dict(
  241. num_chs=prev_chs, reduction=net_stride, module=f'blocks.{stage_idx}'))
  242. stages.append(nn.Sequential(*layers))
  243. stage_idx += 1
  244. out_chs = make_divisible(exp_size * width * 2, 4)
  245. stages.append(nn.Sequential(ConvBnAct(prev_chs, out_chs, 1, **dd)))
  246. self.pool_dim = prev_chs = out_chs
  247. self.blocks = nn.Sequential(*stages)
  248. # building last several layers
  249. self.num_features = prev_chs
  250. self.head_hidden_size = out_chs = 1280
  251. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  252. self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True, **dd)
  253. self.act2 = nn.ReLU(inplace=True)
  254. self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
  255. self.classifier = Linear(out_chs, num_classes, **dd) if num_classes > 0 else nn.Identity()
  256. @torch.jit.ignore
  257. def group_matcher(self, coarse=False):
  258. matcher = dict(
  259. stem=r'^conv_stem|bn1',
  260. blocks=[
  261. (r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)', None),
  262. (r'conv_head', (99999,))
  263. ]
  264. )
  265. return matcher
  266. @torch.jit.ignore
  267. def set_grad_checkpointing(self, enable=True):
  268. self.grad_checkpointing = enable
  269. @torch.jit.ignore
  270. def get_classifier(self) -> nn.Module:
  271. return self.classifier
  272. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  273. self.num_classes = num_classes
  274. if global_pool is not None:
  275. # NOTE: cannot meaningfully change pooling of efficient head after creation
  276. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  277. self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
  278. if num_classes > 0:
  279. device = self.classifier.weight.device if hasattr(self.classifier, 'weight') else None
  280. dtype = self.classifier.weight.dtype if hasattr(self.classifier, 'weight') else None
  281. dd = {'device': device, 'dtype': dtype}
  282. self.classifier = Linear(self.head_hidden_size, num_classes, **dd)
  283. else:
  284. self.classifier = nn.Identity()
  285. def forward_intermediates(
  286. self,
  287. x: torch.Tensor,
  288. indices: Optional[Union[int, List[int]]] = None,
  289. norm: bool = False,
  290. stop_early: bool = False,
  291. output_fmt: str = 'NCHW',
  292. intermediates_only: bool = False,
  293. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  294. """ Forward features that returns intermediates.
  295. Args:
  296. x: Input image tensor
  297. indices: Take last n blocks if int, all if None, select matching indices if sequence
  298. norm: Apply norm layer to compatible intermediates
  299. stop_early: Stop iterating over blocks when last desired intermediate hit
  300. output_fmt: Shape of intermediate feature outputs
  301. intermediates_only: Only return intermediate features
  302. Returns:
  303. """
  304. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  305. intermediates = []
  306. stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]]
  307. take_indices, max_index = feature_take_indices(len(stage_ends), indices)
  308. take_indices = [stage_ends[i]+1 for i in take_indices]
  309. max_index = stage_ends[max_index]
  310. # forward pass
  311. feat_idx = 0
  312. x = self.conv_stem(x)
  313. if feat_idx in take_indices:
  314. intermediates.append(x)
  315. x = self.bn1(x)
  316. x = self.act1(x)
  317. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  318. stages = self.blocks
  319. else:
  320. stages = self.blocks[:max_index + 1]
  321. for feat_idx, stage in enumerate(stages, start=1):
  322. if self.grad_checkpointing and not torch.jit.is_scripting():
  323. x = checkpoint_seq(stage, x)
  324. else:
  325. x = stage(x)
  326. if feat_idx in take_indices:
  327. intermediates.append(x)
  328. if intermediates_only:
  329. return intermediates
  330. return x, intermediates
  331. def prune_intermediate_layers(
  332. self,
  333. indices: Union[int, List[int]] = 1,
  334. prune_norm: bool = False,
  335. prune_head: bool = True,
  336. ):
  337. """ Prune layers not required for specified intermediates.
  338. """
  339. stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]]
  340. take_indices, max_index = feature_take_indices(len(stage_ends), indices)
  341. max_index = stage_ends[max_index]
  342. self.blocks = self.blocks[:max_index + 1] # truncate blocks w/ stem as idx 0
  343. if prune_head:
  344. self.reset_classifier(0, '')
  345. return take_indices
  346. def forward_features(self, x):
  347. x = self.conv_stem(x)
  348. x = self.bn1(x)
  349. x = self.act1(x)
  350. if self.grad_checkpointing and not torch.jit.is_scripting():
  351. x = checkpoint_seq(self.blocks, x, flatten=True)
  352. else:
  353. x = self.blocks(x)
  354. return x
  355. def forward_head(self, x, pre_logits: bool = False):
  356. x = self.global_pool(x)
  357. x = self.conv_head(x)
  358. x = self.act2(x)
  359. x = self.flatten(x)
  360. if self.drop_rate > 0.:
  361. x = F.dropout(x, p=self.drop_rate, training=self.training)
  362. return x if pre_logits else self.classifier(x)
  363. def forward(self, x):
  364. x = self.forward_features(x)
  365. x = self.forward_head(x)
  366. return x
  367. def convert_to_deploy(self):
  368. repghost_model_convert(self, do_copy=False)
  369. def repghost_model_convert(model: torch.nn.Module, save_path=None, do_copy=True):
  370. """
  371. taken from from https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
  372. """
  373. if do_copy:
  374. model = copy.deepcopy(model)
  375. for module in model.modules():
  376. if hasattr(module, 'switch_to_deploy'):
  377. module.switch_to_deploy()
  378. if save_path is not None:
  379. torch.save(model.state_dict(), save_path)
  380. return model
  381. def _create_repghostnet(variant, width=1.0, pretrained=False, **kwargs):
  382. """
  383. Constructs a RepGhostNet model
  384. """
  385. cfgs = [
  386. # k, t, c, SE, s
  387. # stage1
  388. [[3, 8, 16, 0, 1]],
  389. # stage2
  390. [[3, 24, 24, 0, 2]],
  391. [[3, 36, 24, 0, 1]],
  392. # stage3
  393. [[5, 36, 40, 0.25, 2]],
  394. [[5, 60, 40, 0.25, 1]],
  395. # stage4
  396. [[3, 120, 80, 0, 2]],
  397. [[3, 100, 80, 0, 1],
  398. [3, 120, 80, 0, 1],
  399. [3, 120, 80, 0, 1],
  400. [3, 240, 112, 0.25, 1],
  401. [3, 336, 112, 0.25, 1]
  402. ],
  403. # stage5
  404. [[5, 336, 160, 0.25, 2]],
  405. [[5, 480, 160, 0, 1],
  406. [5, 480, 160, 0.25, 1],
  407. [5, 480, 160, 0, 1],
  408. [5, 480, 160, 0.25, 1]
  409. ]
  410. ]
  411. model_kwargs = dict(
  412. cfgs=cfgs,
  413. width=width,
  414. **kwargs,
  415. )
  416. return build_model_with_cfg(
  417. RepGhostNet,
  418. variant,
  419. pretrained,
  420. feature_cfg=dict(flatten_sequential=True),
  421. **model_kwargs,
  422. )
  423. def _cfg(url='', **kwargs):
  424. return {
  425. 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  426. 'crop_pct': 0.875, 'interpolation': 'bicubic',
  427. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  428. 'first_conv': 'conv_stem', 'classifier': 'classifier',
  429. 'license': 'mit',
  430. **kwargs
  431. }
  432. default_cfgs = generate_default_cfgs({
  433. 'repghostnet_050.in1k': _cfg(
  434. hf_hub_id='timm/',
  435. # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_0_5x_43M_66.95.pth.tar'
  436. ),
  437. 'repghostnet_058.in1k': _cfg(
  438. hf_hub_id='timm/',
  439. # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_0_58x_60M_68.94.pth.tar'
  440. ),
  441. 'repghostnet_080.in1k': _cfg(
  442. hf_hub_id='timm/',
  443. # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_0_8x_96M_72.24.pth.tar'
  444. ),
  445. 'repghostnet_100.in1k': _cfg(
  446. hf_hub_id='timm/',
  447. # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_1_0x_142M_74.22.pth.tar'
  448. ),
  449. 'repghostnet_111.in1k': _cfg(
  450. hf_hub_id='timm/',
  451. # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_1_11x_170M_75.07.pth.tar'
  452. ),
  453. 'repghostnet_130.in1k': _cfg(
  454. hf_hub_id='timm/',
  455. # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_1_3x_231M_76.37.pth.tar'
  456. ),
  457. 'repghostnet_150.in1k': _cfg(
  458. hf_hub_id='timm/',
  459. # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_1_5x_301M_77.45.pth.tar'
  460. ),
  461. 'repghostnet_200.in1k': _cfg(
  462. hf_hub_id='timm/',
  463. # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_2_0x_516M_78.81.pth.tar'
  464. ),
  465. })
  466. @register_model
  467. def repghostnet_050(pretrained=False, **kwargs) -> RepGhostNet:
  468. """ RepGhostNet-0.5x """
  469. model = _create_repghostnet('repghostnet_050', width=0.5, pretrained=pretrained, **kwargs)
  470. return model
  471. @register_model
  472. def repghostnet_058(pretrained=False, **kwargs) -> RepGhostNet:
  473. """ RepGhostNet-0.58x """
  474. model = _create_repghostnet('repghostnet_058', width=0.58, pretrained=pretrained, **kwargs)
  475. return model
  476. @register_model
  477. def repghostnet_080(pretrained=False, **kwargs) -> RepGhostNet:
  478. """ RepGhostNet-0.8x """
  479. model = _create_repghostnet('repghostnet_080', width=0.8, pretrained=pretrained, **kwargs)
  480. return model
  481. @register_model
  482. def repghostnet_100(pretrained=False, **kwargs) -> RepGhostNet:
  483. """ RepGhostNet-1.0x """
  484. model = _create_repghostnet('repghostnet_100', width=1.0, pretrained=pretrained, **kwargs)
  485. return model
  486. @register_model
  487. def repghostnet_111(pretrained=False, **kwargs) -> RepGhostNet:
  488. """ RepGhostNet-1.11x """
  489. model = _create_repghostnet('repghostnet_111', width=1.11, pretrained=pretrained, **kwargs)
  490. return model
  491. @register_model
  492. def repghostnet_130(pretrained=False, **kwargs) -> RepGhostNet:
  493. """ RepGhostNet-1.3x """
  494. model = _create_repghostnet('repghostnet_130', width=1.3, pretrained=pretrained, **kwargs)
  495. return model
  496. @register_model
  497. def repghostnet_150(pretrained=False, **kwargs) -> RepGhostNet:
  498. """ RepGhostNet-1.5x """
  499. model = _create_repghostnet('repghostnet_150', width=1.5, pretrained=pretrained, **kwargs)
  500. return model
  501. @register_model
  502. def repghostnet_200(pretrained=False, **kwargs) -> RepGhostNet:
  503. """ RepGhostNet-2.0x """
  504. model = _create_repghostnet('repghostnet_200', width=2.0, pretrained=pretrained, **kwargs)
  505. return model