mobilenetv3.py 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526
  1. """ MobileNet V3
  2. A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl.
  3. Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
  4. Hacked together by / Copyright 2019, Ross Wightman
  5. """
  6. from functools import partial
  7. from typing import Any, Dict, Callable, List, Optional, Tuple, Union
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
  12. from timm.layers import SelectAdaptivePool2d, Linear, LayerType, PadType, create_conv2d, get_norm_act_layer
  13. from ._builder import build_model_with_cfg, pretrained_cfg_for_features
  14. from ._efficientnet_blocks import SqueezeExcite
  15. from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
  16. round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
  17. from ._features import FeatureInfo, FeatureHooks, feature_take_indices
  18. from ._manipulate import checkpoint_seq, checkpoint
  19. from ._registry import generate_default_cfgs, register_model, register_model_deprecations
# Public API of this module: the two model classes defined below.
__all__ = ['MobileNetV3', 'MobileNetV3Features']
  21. class MobileNetV3(nn.Module):
  22. """MobileNetV3.
  23. Based on my EfficientNet implementation and building blocks, this model utilizes the MobileNet-v3 specific
  24. 'efficient head', where global pooling is done before the head convolution without a final batch-norm
  25. layer before the classifier.
  26. Paper: `Searching for MobileNetV3` - https://arxiv.org/abs/1905.02244
  27. Other architectures utilizing MobileNet-V3 efficient head that are supported by this impl include:
  28. * HardCoRe-NAS - https://arxiv.org/abs/2102.11646 (defn in hardcorenas.py uses this class)
  29. * FBNet-V3 - https://arxiv.org/abs/2006.02049
  30. * LCNet - https://arxiv.org/abs/2109.15099
  31. * MobileNet-V4 - https://arxiv.org/abs/2404.10518
  32. """
  33. def __init__(
  34. self,
  35. block_args: BlockArgs,
  36. num_classes: int = 1000,
  37. in_chans: int = 3,
  38. stem_size: int = 16,
  39. fix_stem: bool = False,
  40. num_features: int = 1280,
  41. head_bias: bool = True,
  42. head_norm: bool = False,
  43. pad_type: str = '',
  44. act_layer: Optional[LayerType] = None,
  45. norm_layer: Optional[LayerType] = None,
  46. aa_layer: Optional[LayerType] = None,
  47. se_layer: Optional[LayerType] = None,
  48. se_from_exp: bool = True,
  49. round_chs_fn: Callable = round_channels,
  50. drop_rate: float = 0.,
  51. drop_path_rate: float = 0.,
  52. layer_scale_init_value: Optional[float] = None,
  53. global_pool: str = 'avg',
  54. device=None,
  55. dtype=None,
  56. ):
  57. """Initialize MobileNetV3.
  58. Args:
  59. block_args: Arguments for blocks of the network.
  60. num_classes: Number of classes for classification head.
  61. in_chans: Number of input image channels.
  62. stem_size: Number of output channels of the initial stem convolution.
  63. fix_stem: If True, don't scale stem by round_chs_fn.
  64. num_features: Number of output channels of the conv head layer.
  65. head_bias: If True, add a learnable bias to the conv head layer.
  66. head_norm: If True, add normalization to the head layer.
  67. pad_type: Type of padding to use for convolution layers.
  68. act_layer: Type of activation layer.
  69. norm_layer: Type of normalization layer.
  70. aa_layer: Type of anti-aliasing layer.
  71. se_layer: Type of Squeeze-and-Excite layer.
  72. se_from_exp: If True, calculate SE channel reduction from expanded mid channels.
  73. round_chs_fn: Callable to round number of filters based on depth multiplier.
  74. drop_rate: Dropout rate.
  75. drop_path_rate: Stochastic depth rate.
  76. layer_scale_init_value: Enable layer scale on compatible blocks if not None.
  77. global_pool: Type of pooling to use for global pooling features of the FC head.
  78. """
  79. super().__init__()
  80. dd = {'device': device, 'dtype': dtype}
  81. act_layer = act_layer or nn.ReLU
  82. norm_layer = norm_layer or nn.BatchNorm2d
  83. norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
  84. se_layer = se_layer or SqueezeExcite
  85. self.num_classes = num_classes
  86. self.in_chans = in_chans
  87. self.drop_rate = drop_rate
  88. self.grad_checkpointing = False
  89. # Stem
  90. if not fix_stem:
  91. stem_size = round_chs_fn(stem_size)
  92. self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type, **dd)
  93. self.bn1 = norm_act_layer(stem_size, inplace=True, **dd)
  94. # Middle stages (IR/ER/DS Blocks)
  95. builder = EfficientNetBuilder(
  96. output_stride=32,
  97. pad_type=pad_type,
  98. round_chs_fn=round_chs_fn,
  99. se_from_exp=se_from_exp,
  100. act_layer=act_layer,
  101. norm_layer=norm_layer,
  102. aa_layer=aa_layer,
  103. se_layer=se_layer,
  104. drop_path_rate=drop_path_rate,
  105. layer_scale_init_value=layer_scale_init_value,
  106. **dd,
  107. )
  108. self.blocks = nn.Sequential(*builder(stem_size, block_args))
  109. self.feature_info = builder.features
  110. self.stage_ends = [f['stage'] for f in self.feature_info]
  111. self.num_features = builder.in_chs # features of last stage, output of forward_features()
  112. self.head_hidden_size = num_features # features of conv_head, pre_logits output
  113. # Head + Pooling
  114. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  115. num_pooled_chs = self.num_features * self.global_pool.feat_mult()
  116. if head_norm:
  117. # mobilenet-v4 post-pooling PW conv is followed by a norm+act layer
  118. self.conv_head = create_conv2d(
  119. num_pooled_chs,
  120. self.head_hidden_size,
  121. 1,
  122. padding=pad_type,
  123. bias=False, # never a bias
  124. **dd,
  125. )
  126. self.norm_head = norm_act_layer(self.head_hidden_size, **dd)
  127. self.act2 = nn.Identity()
  128. else:
  129. # mobilenet-v3 and others only have an activation after final PW conv
  130. self.conv_head = create_conv2d(
  131. num_pooled_chs,
  132. self.head_hidden_size,
  133. 1,
  134. padding=pad_type,
  135. bias=head_bias,
  136. **dd,
  137. )
  138. self.norm_head = nn.Identity()
  139. self.act2 = act_layer(inplace=True)
  140. self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
  141. self.classifier = Linear(self.head_hidden_size, num_classes, **dd) if num_classes > 0 else nn.Identity()
  142. efficientnet_init_weights(self)
  143. def as_sequential(self) -> nn.Sequential:
  144. """Convert model to sequential form.
  145. Returns:
  146. Sequential module containing all layers.
  147. """
  148. layers = [self.conv_stem, self.bn1]
  149. layers.extend(self.blocks)
  150. layers.extend([self.global_pool, self.conv_head, self.norm_head, self.act2])
  151. layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
  152. return nn.Sequential(*layers)
  153. @torch.jit.ignore
  154. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  155. """Group parameters for optimization."""
  156. return dict(
  157. stem=r'^conv_stem|bn1',
  158. blocks=r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)'
  159. )
  160. @torch.jit.ignore
  161. def set_grad_checkpointing(self, enable: bool = True) -> None:
  162. """Enable or disable gradient checkpointing."""
  163. self.grad_checkpointing = enable
  164. @torch.jit.ignore
  165. def get_classifier(self) -> nn.Module:
  166. """Get the classifier head."""
  167. return self.classifier
  168. def reset_classifier(self, num_classes: int, global_pool: str = 'avg') -> None:
  169. """Reset the classifier head.
  170. Args:
  171. num_classes: Number of classes for new classifier.
  172. global_pool: Global pooling type.
  173. """
  174. self.num_classes = num_classes
  175. # NOTE: cannot meaningfully change pooling of efficient head after creation
  176. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  177. self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
  178. self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
  179. def forward_intermediates(
  180. self,
  181. x: torch.Tensor,
  182. indices: Optional[Union[int, List[int]]] = None,
  183. norm: bool = False,
  184. stop_early: bool = False,
  185. output_fmt: str = 'NCHW',
  186. intermediates_only: bool = False,
  187. extra_blocks: bool = False,
  188. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  189. """ Forward features that returns intermediates.
  190. Args:
  191. x: Input image tensor
  192. indices: Take last n blocks if int, all if None, select matching indices if sequence
  193. norm: Apply norm layer to compatible intermediates
  194. stop_early: Stop iterating over blocks when last desired intermediate hit
  195. output_fmt: Shape of intermediate feature outputs
  196. intermediates_only: Only return intermediate features
  197. extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info
  198. Returns:
  199. """
  200. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  201. if stop_early:
  202. assert intermediates_only, 'Must use intermediates_only for early stopping.'
  203. intermediates = []
  204. if extra_blocks:
  205. take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
  206. else:
  207. take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
  208. take_indices = [self.stage_ends[i] for i in take_indices]
  209. max_index = self.stage_ends[max_index]
  210. # forward pass
  211. feat_idx = 0 # stem is index 0
  212. x = self.conv_stem(x)
  213. x = self.bn1(x)
  214. if feat_idx in take_indices:
  215. intermediates.append(x)
  216. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  217. blocks = self.blocks
  218. else:
  219. blocks = self.blocks[:max_index]
  220. for feat_idx, blk in enumerate(blocks, start=1):
  221. if self.grad_checkpointing and not torch.jit.is_scripting():
  222. x = checkpoint_seq(blk, x)
  223. else:
  224. x = blk(x)
  225. if feat_idx in take_indices:
  226. intermediates.append(x)
  227. if intermediates_only:
  228. return intermediates
  229. return x, intermediates
  230. def prune_intermediate_layers(
  231. self,
  232. indices: Union[int, List[int]] = 1,
  233. prune_norm: bool = False,
  234. prune_head: bool = True,
  235. extra_blocks: bool = False,
  236. ) -> List[int]:
  237. """Prune layers not required for specified intermediates.
  238. Args:
  239. indices: Indices of intermediate layers to keep.
  240. prune_norm: Whether to prune normalization layer.
  241. prune_head: Whether to prune the classifier head.
  242. extra_blocks: Include outputs of all blocks.
  243. Returns:
  244. List of indices that were kept.
  245. """
  246. if extra_blocks:
  247. take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
  248. else:
  249. take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
  250. max_index = self.stage_ends[max_index]
  251. self.blocks = self.blocks[:max_index] # truncate blocks w/ stem as idx 0
  252. if max_index < len(self.blocks):
  253. self.conv_head = nn.Identity()
  254. self.norm_head = nn.Identity()
  255. if prune_head:
  256. self.conv_head = nn.Identity()
  257. self.norm_head = nn.Identity()
  258. self.reset_classifier(0, '')
  259. return take_indices
  260. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  261. """Forward pass through feature extraction layers.
  262. Args:
  263. x: Input tensor.
  264. Returns:
  265. Feature tensor.
  266. """
  267. x = self.conv_stem(x)
  268. x = self.bn1(x)
  269. if self.grad_checkpointing and not torch.jit.is_scripting():
  270. x = checkpoint_seq(self.blocks, x, flatten=True)
  271. else:
  272. x = self.blocks(x)
  273. return x
  274. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  275. """Forward pass through classifier head.
  276. Args:
  277. x: Input features.
  278. pre_logits: Return features before final linear layer.
  279. Returns:
  280. Classification logits or features.
  281. """
  282. x = self.global_pool(x)
  283. x = self.conv_head(x)
  284. x = self.norm_head(x)
  285. x = self.act2(x)
  286. x = self.flatten(x)
  287. if self.drop_rate > 0.:
  288. x = F.dropout(x, p=self.drop_rate, training=self.training)
  289. if pre_logits:
  290. return x
  291. return self.classifier(x)
  292. def forward(self, x: torch.Tensor) -> torch.Tensor:
  293. """Forward pass.
  294. Args:
  295. x: Input tensor.
  296. Returns:
  297. Output logits.
  298. """
  299. x = self.forward_features(x)
  300. x = self.forward_head(x)
  301. return x
  302. class MobileNetV3Features(nn.Module):
  303. """MobileNetV3 Feature Extractor.
  304. A work-in-progress feature extraction module for MobileNet-V3 to use as a backbone for segmentation
  305. and object detection models.
  306. """
  307. def __init__(
  308. self,
  309. block_args: BlockArgs,
  310. out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
  311. feature_location: str = 'bottleneck',
  312. in_chans: int = 3,
  313. stem_size: int = 16,
  314. fix_stem: bool = False,
  315. output_stride: int = 32,
  316. pad_type: PadType = '',
  317. round_chs_fn: Callable = round_channels,
  318. se_from_exp: bool = True,
  319. act_layer: Optional[LayerType] = None,
  320. norm_layer: Optional[LayerType] = None,
  321. aa_layer: Optional[LayerType] = None,
  322. se_layer: Optional[LayerType] = None,
  323. drop_rate: float = 0.,
  324. drop_path_rate: float = 0.,
  325. layer_scale_init_value: Optional[float] = None,
  326. device=None,
  327. dtype=None,
  328. ):
  329. """Initialize MobileNetV3Features.
  330. Args:
  331. block_args: Arguments for blocks of the network.
  332. out_indices: Output from stages at indices.
  333. feature_location: Location of feature before/after each block, must be in ['bottleneck', 'expansion'].
  334. in_chans: Number of input image channels.
  335. stem_size: Number of output channels of the initial stem convolution.
  336. fix_stem: If True, don't scale stem by round_chs_fn.
  337. output_stride: Output stride of the network.
  338. pad_type: Type of padding to use for convolution layers.
  339. round_chs_fn: Callable to round number of filters based on depth multiplier.
  340. se_from_exp: If True, calculate SE channel reduction from expanded mid channels.
  341. act_layer: Type of activation layer.
  342. norm_layer: Type of normalization layer.
  343. aa_layer: Type of anti-aliasing layer.
  344. se_layer: Type of Squeeze-and-Excite layer.
  345. drop_rate: Dropout rate.
  346. drop_path_rate: Stochastic depth rate.
  347. layer_scale_init_value: Enable layer scale on compatible blocks if not None.
  348. """
  349. super().__init__()
  350. dd = {'device': device, 'dtype': dtype}
  351. act_layer = act_layer or nn.ReLU
  352. norm_layer = norm_layer or nn.BatchNorm2d
  353. se_layer = se_layer or SqueezeExcite
  354. self.in_chans = in_chans
  355. self.drop_rate = drop_rate
  356. self.grad_checkpointing = False
  357. # Stem
  358. if not fix_stem:
  359. stem_size = round_chs_fn(stem_size)
  360. self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type, **dd)
  361. self.bn1 = norm_layer(stem_size, **dd)
  362. self.act1 = act_layer(inplace=True)
  363. # Middle stages (IR/ER/DS Blocks)
  364. builder = EfficientNetBuilder(
  365. output_stride=output_stride,
  366. pad_type=pad_type,
  367. round_chs_fn=round_chs_fn,
  368. se_from_exp=se_from_exp,
  369. act_layer=act_layer,
  370. norm_layer=norm_layer,
  371. aa_layer=aa_layer,
  372. se_layer=se_layer,
  373. drop_path_rate=drop_path_rate,
  374. layer_scale_init_value=layer_scale_init_value,
  375. feature_location=feature_location,
  376. **dd,
  377. )
  378. self.blocks = nn.Sequential(*builder(stem_size, block_args))
  379. self.feature_info = FeatureInfo(builder.features, out_indices)
  380. self._stage_out_idx = {f['stage']: f['index'] for f in self.feature_info.get_dicts()}
  381. efficientnet_init_weights(self)
  382. # Register feature extraction hooks with FeatureHooks helper
  383. self.feature_hooks = None
  384. if feature_location != 'bottleneck':
  385. hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
  386. self.feature_hooks = FeatureHooks(hooks, self.named_modules())
  387. @torch.jit.ignore
  388. def set_grad_checkpointing(self, enable: bool = True) -> None:
  389. """Enable or disable gradient checkpointing."""
  390. self.grad_checkpointing = enable
  391. def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
  392. """Forward pass through feature extraction.
  393. Args:
  394. x: Input tensor.
  395. Returns:
  396. List of feature tensors.
  397. """
  398. x = self.conv_stem(x)
  399. x = self.bn1(x)
  400. x = self.act1(x)
  401. if self.feature_hooks is None:
  402. features = []
  403. if 0 in self._stage_out_idx:
  404. features.append(x) # add stem out
  405. for i, b in enumerate(self.blocks):
  406. if self.grad_checkpointing and not torch.jit.is_scripting():
  407. x = checkpoint(b, x)
  408. else:
  409. x = b(x)
  410. if i + 1 in self._stage_out_idx:
  411. features.append(x)
  412. return features
  413. else:
  414. self.blocks(x)
  415. out = self.feature_hooks.get_output(x.device)
  416. return list(out.values())
  417. def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV3:
  418. """Create a MobileNetV3 model.
  419. Args:
  420. variant: Model variant name.
  421. pretrained: Load pretrained weights.
  422. **kwargs: Additional model arguments.
  423. Returns:
  424. MobileNetV3 model instance.
  425. """
  426. features_mode = ''
  427. model_cls = MobileNetV3
  428. kwargs_filter = None
  429. if kwargs.pop('features_only', False):
  430. if 'feature_cfg' in kwargs or 'feature_cls' in kwargs:
  431. features_mode = 'cfg'
  432. else:
  433. kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'head_norm', 'global_pool')
  434. model_cls = MobileNetV3Features
  435. features_mode = 'cls'
  436. model = build_model_with_cfg(
  437. model_cls,
  438. variant,
  439. pretrained,
  440. features_only=features_mode == 'cfg',
  441. pretrained_strict=features_mode != 'cls',
  442. kwargs_filter=kwargs_filter,
  443. **kwargs,
  444. )
  445. if features_mode == 'cls':
  446. model.default_cfg = pretrained_cfg_for_features(model.default_cfg)
  447. return model
  448. def _gen_mobilenet_v3_rw(
  449. variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs
  450. ) -> MobileNetV3:
  451. """Creates a MobileNet-V3 model.
  452. Ref impl: ?
  453. Paper: https://arxiv.org/abs/1905.02244
  454. Args:
  455. variant: Model variant name.
  456. channel_multiplier: Multiplier to number of channels per layer.
  457. pretrained: Load pretrained weights.
  458. **kwargs: Additional model arguments.
  459. Returns:
  460. MobileNetV3 model instance.
  461. """
  462. arch_def = [
  463. # stage 0, 112x112 in
  464. ['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu
  465. # stage 1, 112x112 in
  466. ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
  467. # stage 2, 56x56 in
  468. ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
  469. # stage 3, 28x28 in
  470. ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
  471. # stage 4, 14x14in
  472. ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
  473. # stage 5, 14x14in
  474. ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
  475. # stage 6, 7x7 in
  476. ['cn_r1_k1_s1_c960'], # hard-swish
  477. ]
  478. model_kwargs = dict(
  479. block_args=decode_arch_def(arch_def),
  480. head_bias=False,
  481. round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
  482. norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  483. act_layer=resolve_act_layer(kwargs, 'hard_swish'),
  484. se_layer=partial(SqueezeExcite, gate_layer='hard_sigmoid'),
  485. **kwargs,
  486. )
  487. model = _create_mnv3(variant, pretrained, **model_kwargs)
  488. return model
  489. def _gen_mobilenet_v3(
  490. variant: str,
  491. channel_multiplier: float = 1.0,
  492. depth_multiplier: float = 1.0,
  493. group_size: Optional[int] = None,
  494. pretrained: bool = False,
  495. **kwargs
  496. ) -> MobileNetV3:
  497. """Creates a MobileNet-V3 model.
  498. Ref impl: ?
  499. Paper: https://arxiv.org/abs/1905.02244
  500. Args:
  501. variant: Model variant name.
  502. channel_multiplier: Multiplier to number of channels per layer.
  503. depth_multiplier: Depth multiplier for model scaling.
  504. group_size: Group size for grouped convolutions.
  505. pretrained: Load pretrained weights.
  506. **kwargs: Additional model arguments.
  507. Returns:
  508. MobileNetV3 model instance.
  509. """
  510. if 'small' in variant:
  511. num_features = 1024
  512. if 'minimal' in variant:
  513. act_layer = resolve_act_layer(kwargs, 'relu')
  514. arch_def = [
  515. # stage 0, 112x112 in
  516. ['ds_r1_k3_s2_e1_c16'],
  517. # stage 1, 56x56 in
  518. ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'],
  519. # stage 2, 28x28 in
  520. ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'],
  521. # stage 3, 14x14 in
  522. ['ir_r2_k3_s1_e3_c48'],
  523. # stage 4, 14x14in
  524. ['ir_r3_k3_s2_e6_c96'],
  525. # stage 6, 7x7 in
  526. ['cn_r1_k1_s1_c576'],
  527. ]
  528. else:
  529. act_layer = resolve_act_layer(kwargs, 'hard_swish')
  530. arch_def = [
  531. # stage 0, 112x112 in
  532. ['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu
  533. # stage 1, 56x56 in
  534. ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu
  535. # stage 2, 28x28 in
  536. ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish
  537. # stage 3, 14x14 in
  538. ['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish
  539. # stage 4, 14x14in
  540. ['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish
  541. # stage 6, 7x7 in
  542. ['cn_r1_k1_s1_c576'], # hard-swish
  543. ]
  544. else:
  545. num_features = 1280
  546. if 'minimal' in variant:
  547. act_layer = resolve_act_layer(kwargs, 'relu')
  548. arch_def = [
  549. # stage 0, 112x112 in
  550. ['ds_r1_k3_s1_e1_c16'],
  551. # stage 1, 112x112 in
  552. ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'],
  553. # stage 2, 56x56 in
  554. ['ir_r3_k3_s2_e3_c40'],
  555. # stage 3, 28x28 in
  556. ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],
  557. # stage 4, 14x14in
  558. ['ir_r2_k3_s1_e6_c112'],
  559. # stage 5, 14x14in
  560. ['ir_r3_k3_s2_e6_c160'],
  561. # stage 6, 7x7 in
  562. ['cn_r1_k1_s1_c960'],
  563. ]
  564. else:
  565. act_layer = resolve_act_layer(kwargs, 'hard_swish')
  566. arch_def = [
  567. # stage 0, 112x112 in
  568. ['ds_r1_k3_s1_e1_c16_nre'], # relu
  569. # stage 1, 112x112 in
  570. ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
  571. # stage 2, 56x56 in
  572. ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
  573. # stage 3, 28x28 in
  574. ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
  575. # stage 4, 14x14in
  576. ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
  577. # stage 5, 14x14in
  578. ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
  579. # stage 6, 7x7 in
  580. ['cn_r1_k1_s1_c960'], # hard-swish
  581. ]
  582. se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
  583. model_kwargs = dict(
  584. block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, group_size=group_size),
  585. num_features=num_features,
  586. stem_size=16,
  587. fix_stem=channel_multiplier < 0.75,
  588. round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
  589. norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  590. act_layer=act_layer,
  591. se_layer=se_layer,
  592. **kwargs,
  593. )
  594. model = _create_mnv3(variant, pretrained, **model_kwargs)
  595. return model
  596. def _gen_fbnetv3(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
  597. """FBNetV3 model generator.
  598. Paper: `FBNetV3: Joint Architecture-Recipe Search using Predictor Pretraining`
  599. - https://arxiv.org/abs/2006.02049
  600. FIXME untested, this is a preliminary impl of some FBNet-V3 variants.
  601. Args:
  602. variant: Model variant name.
  603. channel_multiplier: Channel width multiplier.
  604. pretrained: Load pretrained weights.
  605. **kwargs: Additional model arguments.
  606. Returns:
  607. MobileNetV3 model instance.
  608. """
  609. vl = variant.split('_')[-1]
  610. if vl in ('a', 'b'):
  611. stem_size = 16
  612. arch_def = [
  613. ['ds_r2_k3_s1_e1_c16'],
  614. ['ir_r1_k5_s2_e4_c24', 'ir_r3_k5_s1_e2_c24'],
  615. ['ir_r1_k5_s2_e5_c40_se0.25', 'ir_r4_k5_s1_e3_c40_se0.25'],
  616. ['ir_r1_k5_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'],
  617. ['ir_r1_k3_s1_e5_c120_se0.25', 'ir_r5_k5_s1_e3_c120_se0.25'],
  618. ['ir_r1_k3_s2_e6_c184_se0.25', 'ir_r5_k5_s1_e4_c184_se0.25', 'ir_r1_k5_s1_e6_c224_se0.25'],
  619. ['cn_r1_k1_s1_c1344'],
  620. ]
  621. elif vl == 'd':
  622. stem_size = 24
  623. arch_def = [
  624. ['ds_r2_k3_s1_e1_c16'],
  625. ['ir_r1_k3_s2_e5_c24', 'ir_r5_k3_s1_e2_c24'],
  626. ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r4_k3_s1_e3_c40_se0.25'],
  627. ['ir_r1_k3_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'],
  628. ['ir_r1_k3_s1_e5_c128_se0.25', 'ir_r6_k5_s1_e3_c128_se0.25'],
  629. ['ir_r1_k3_s2_e6_c208_se0.25', 'ir_r5_k5_s1_e5_c208_se0.25', 'ir_r1_k5_s1_e6_c240_se0.25'],
  630. ['cn_r1_k1_s1_c1440'],
  631. ]
  632. elif vl == 'g':
  633. stem_size = 32
  634. arch_def = [
  635. ['ds_r3_k3_s1_e1_c24'],
  636. ['ir_r1_k5_s2_e4_c40', 'ir_r4_k5_s1_e2_c40'],
  637. ['ir_r1_k5_s2_e4_c56_se0.25', 'ir_r4_k5_s1_e3_c56_se0.25'],
  638. ['ir_r1_k5_s2_e5_c104', 'ir_r4_k3_s1_e3_c104'],
  639. ['ir_r1_k3_s1_e5_c160_se0.25', 'ir_r8_k5_s1_e3_c160_se0.25'],
  640. ['ir_r1_k3_s2_e6_c264_se0.25', 'ir_r6_k5_s1_e5_c264_se0.25', 'ir_r2_k5_s1_e6_c288_se0.25'],
  641. ['cn_r1_k1_s1_c1728'],
  642. ]
  643. else:
  644. raise NotImplemented
  645. round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.95)
  646. se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=round_chs_fn)
  647. act_layer = resolve_act_layer(kwargs, 'hard_swish')
  648. model_kwargs = dict(
  649. block_args=decode_arch_def(arch_def),
  650. num_features=1984,
  651. head_bias=False,
  652. stem_size=stem_size,
  653. round_chs_fn=round_chs_fn,
  654. se_from_exp=False,
  655. norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  656. act_layer=act_layer,
  657. se_layer=se_layer,
  658. **kwargs,
  659. )
  660. model = _create_mnv3(variant, pretrained, **model_kwargs)
  661. return model
  662. def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
  663. """LCNet model generator.
  664. Essentially a MobileNet-V3 crossed with a MobileNet-V1
  665. Paper: `PP-LCNet: A Lightweight CPU Convolutional Neural Network` - https://arxiv.org/abs/2109.15099
  666. Args:
  667. variant: Model variant name.
  668. channel_multiplier: Multiplier to number of channels per layer.
  669. pretrained: Load pretrained weights.
  670. **kwargs: Additional model arguments.
  671. Returns:
  672. MobileNetV3 model instance.
  673. """
  674. arch_def = [
  675. # stage 0, 112x112 in
  676. ['dsa_r1_k3_s1_c32'],
  677. # stage 1, 112x112 in
  678. ['dsa_r2_k3_s2_c64'],
  679. # stage 2, 56x56 in
  680. ['dsa_r2_k3_s2_c128'],
  681. # stage 3, 28x28 in
  682. ['dsa_r1_k3_s2_c256', 'dsa_r1_k5_s1_c256'],
  683. # stage 4, 14x14in
  684. ['dsa_r4_k5_s1_c256'],
  685. # stage 5, 14x14in
  686. ['dsa_r2_k5_s2_c512_se0.25'],
  687. # 7x7
  688. ]
  689. model_kwargs = dict(
  690. block_args=decode_arch_def(arch_def),
  691. stem_size=16,
  692. round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
  693. norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  694. act_layer=resolve_act_layer(kwargs, 'hard_swish'),
  695. se_layer=partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU),
  696. num_features=1280,
  697. **kwargs,
  698. )
  699. model = _create_mnv3(variant, pretrained, **model_kwargs)
  700. return model
  701. def _gen_mobilenet_v4(
  702. variant: str,
  703. channel_multiplier: float = 1.0,
  704. group_size: Optional[int] = None,
  705. pretrained: bool = False,
  706. **kwargs,
  707. ) -> MobileNetV3:
  708. """Creates a MobileNet-V4 model.
  709. Paper: https://arxiv.org/abs/2404.10518
  710. Args:
  711. variant: Model variant name.
  712. channel_multiplier: Multiplier to number of channels per layer.
  713. group_size: Group size for grouped convolutions.
  714. pretrained: Load pretrained weights.
  715. **kwargs: Additional model arguments.
  716. Returns:
  717. MobileNetV3 model instance.
  718. """
  719. num_features = 1280
  720. if 'hybrid' in variant:
  721. layer_scale_init_value = 1e-5
  722. if 'medium' in variant:
  723. stem_size = 32
  724. act_layer = resolve_act_layer(kwargs, 'relu')
  725. arch_def = [
  726. # stage 0, 112x112 in
  727. [
  728. 'er_r1_k3_s2_e4_c48' # FusedIB (EdgeResidual)
  729. ],
  730. # stage 1, 56x56 in
  731. [
  732. 'uir_r1_a3_k5_s2_e4_c80', # ExtraDW
  733. 'uir_r1_a3_k3_s1_e2_c80', # ExtraDW
  734. ],
  735. # stage 2, 28x28 in
  736. [
  737. 'uir_r1_a3_k5_s2_e6_c160', # ExtraDW
  738. 'uir_r1_a0_k0_s1_e2_c160', # FFN
  739. 'uir_r1_a3_k3_s1_e4_c160', # ExtraDW
  740. 'uir_r1_a3_k5_s1_e4_c160', # ExtraDW
  741. 'mqa_r1_k3_h4_s1_v2_d64_c160', # MQA w/ KV downsample
  742. 'uir_r1_a3_k3_s1_e4_c160', # ExtraDW
  743. 'mqa_r1_k3_h4_s1_v2_d64_c160', # MQA w/ KV downsample
  744. 'uir_r1_a3_k0_s1_e4_c160', # ConvNeXt
  745. 'mqa_r1_k3_h4_s1_v2_d64_c160', # MQA w/ KV downsample
  746. 'uir_r1_a3_k3_s1_e4_c160', # ExtraDW
  747. 'mqa_r1_k3_h4_s1_v2_d64_c160', # MQA w/ KV downsample
  748. 'uir_r1_a3_k0_s1_e4_c160', # ConvNeXt
  749. ],
  750. # stage 3, 14x14in
  751. [
  752. 'uir_r1_a5_k5_s2_e6_c256', # ExtraDW
  753. 'uir_r1_a5_k5_s1_e4_c256', # ExtraDW
  754. 'uir_r2_a3_k5_s1_e4_c256', # ExtraDW
  755. 'uir_r1_a0_k0_s1_e2_c256', # FFN
  756. 'uir_r1_a3_k5_s1_e2_c256', # ExtraDW
  757. 'uir_r1_a0_k0_s1_e2_c256', # FFN
  758. 'uir_r1_a0_k0_s1_e4_c256', # FFN
  759. 'mqa_r1_k3_h4_s1_d64_c256', # MQA
  760. 'uir_r1_a3_k0_s1_e4_c256', # ConvNeXt
  761. 'mqa_r1_k3_h4_s1_d64_c256', # MQA
  762. 'uir_r1_a5_k5_s1_e4_c256', # ExtraDW
  763. 'mqa_r1_k3_h4_s1_d64_c256', # MQA
  764. 'uir_r1_a5_k0_s1_e4_c256', # ConvNeXt
  765. 'mqa_r1_k3_h4_s1_d64_c256', # MQA
  766. 'uir_r1_a5_k0_s1_e4_c256', # ConvNeXt
  767. ],
  768. # stage 4, 7x7 in
  769. [
  770. 'cn_r1_k1_s1_c960' # Conv
  771. ],
  772. ]
  773. elif 'large' in variant:
  774. stem_size = 24
  775. act_layer = resolve_act_layer(kwargs, 'gelu')
  776. arch_def = [
  777. # stage 0, 112x112 in
  778. [
  779. 'er_r1_k3_s2_e4_c48', # FusedIB (EdgeResidual)
  780. ],
  781. # stage 1, 56x56 in
  782. [
  783. 'uir_r1_a3_k5_s2_e4_c96', # ExtraDW
  784. 'uir_r1_a3_k3_s1_e4_c96', # ExtraDW
  785. ],
  786. # stage 2, 28x28 in
  787. [
  788. 'uir_r1_a3_k5_s2_e4_c192', # ExtraDW
  789. 'uir_r3_a3_k3_s1_e4_c192', # ExtraDW
  790. 'uir_r1_a3_k5_s1_e4_c192', # ExtraDW
  791. 'uir_r2_a5_k3_s1_e4_c192', # ExtraDW
  792. 'mqa_r1_k3_h8_s1_v2_d48_c192', # MQA w/ KV downsample
  793. 'uir_r1_a5_k3_s1_e4_c192', # ExtraDW
  794. 'mqa_r1_k3_h8_s1_v2_d48_c192', # MQA w/ KV downsample
  795. 'uir_r1_a5_k3_s1_e4_c192', # ExtraDW
  796. 'mqa_r1_k3_h8_s1_v2_d48_c192', # MQA w/ KV downsample
  797. 'uir_r1_a5_k3_s1_e4_c192', # ExtraDW
  798. 'mqa_r1_k3_h8_s1_v2_d48_c192', # MQA w/ KV downsample
  799. 'uir_r1_a3_k0_s1_e4_c192', # ConvNeXt
  800. ],
  801. # stage 3, 14x14in
  802. [
  803. 'uir_r4_a5_k5_s2_e4_c512', # ExtraDW
  804. 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt
  805. 'uir_r1_a5_k3_s1_e4_c512', # ExtraDW
  806. 'uir_r2_a5_k0_s1_e4_c512', # ConvNeXt
  807. 'uir_r1_a5_k3_s1_e4_c512', # ExtraDW
  808. 'uir_r1_a5_k5_s1_e4_c512', # ExtraDW
  809. 'mqa_r1_k3_h8_s1_d64_c512', # MQA
  810. 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt
  811. 'mqa_r1_k3_h8_s1_d64_c512', # MQA
  812. 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt
  813. 'mqa_r1_k3_h8_s1_d64_c512', # MQA
  814. 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt
  815. 'mqa_r1_k3_h8_s1_d64_c512', # MQA
  816. 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt
  817. ],
  818. # stage 4, 7x7 in
  819. [
  820. 'cn_r1_k1_s1_c960', # Conv
  821. ],
  822. ]
  823. else:
  824. assert False, f'Unknown variant {variant}.'
  825. else:
  826. layer_scale_init_value = None
  827. if 'small' in variant:
  828. stem_size = 32
  829. act_layer = resolve_act_layer(kwargs, 'relu')
  830. arch_def = [
  831. # stage 0, 112x112 in
  832. [
  833. 'cn_r1_k3_s2_e1_c32', # Conv
  834. 'cn_r1_k1_s1_e1_c32', # Conv
  835. ],
  836. # stage 1, 56x56 in
  837. [
  838. 'cn_r1_k3_s2_e1_c96', # Conv
  839. 'cn_r1_k1_s1_e1_c64', # Conv
  840. ],
  841. # stage 2, 28x28 in
  842. [
  843. 'uir_r1_a5_k5_s2_e3_c96', # ExtraDW
  844. 'uir_r4_a0_k3_s1_e2_c96', # IR
  845. 'uir_r1_a3_k0_s1_e4_c96', # ConvNeXt
  846. ],
  847. # stage 3, 14x14 in
  848. [
  849. 'uir_r1_a3_k3_s2_e6_c128', # ExtraDW
  850. 'uir_r1_a5_k5_s1_e4_c128', # ExtraDW
  851. 'uir_r1_a0_k5_s1_e4_c128', # IR
  852. 'uir_r1_a0_k5_s1_e3_c128', # IR
  853. 'uir_r2_a0_k3_s1_e4_c128', # IR
  854. ],
  855. # stage 4, 7x7 in
  856. [
  857. 'cn_r1_k1_s1_c960', # Conv
  858. ],
  859. ]
  860. elif 'medium' in variant:
  861. stem_size = 32
  862. act_layer = resolve_act_layer(kwargs, 'relu')
  863. arch_def = [
  864. # stage 0, 112x112 in
  865. [
  866. 'er_r1_k3_s2_e4_c48', # FusedIB (EdgeResidual)
  867. ],
  868. # stage 1, 56x56 in
  869. [
  870. 'uir_r1_a3_k5_s2_e4_c80', # ExtraDW
  871. 'uir_r1_a3_k3_s1_e2_c80', # ExtraDW
  872. ],
  873. # stage 2, 28x28 in
  874. [
  875. 'uir_r1_a3_k5_s2_e6_c160', # ExtraDW
  876. 'uir_r2_a3_k3_s1_e4_c160', # ExtraDW
  877. 'uir_r1_a3_k5_s1_e4_c160', # ExtraDW
  878. 'uir_r1_a3_k3_s1_e4_c160', # ExtraDW
  879. 'uir_r1_a3_k0_s1_e4_c160', # ConvNeXt
  880. 'uir_r1_a0_k0_s1_e2_c160', # ExtraDW
  881. 'uir_r1_a3_k0_s1_e4_c160', # ConvNeXt
  882. ],
  883. # stage 3, 14x14in
  884. [
  885. 'uir_r1_a5_k5_s2_e6_c256', # ExtraDW
  886. 'uir_r1_a5_k5_s1_e4_c256', # ExtraDW
  887. 'uir_r2_a3_k5_s1_e4_c256', # ExtraDW
  888. 'uir_r1_a0_k0_s1_e4_c256', # FFN
  889. 'uir_r1_a3_k0_s1_e4_c256', # ConvNeXt
  890. 'uir_r1_a3_k5_s1_e2_c256', # ExtraDW
  891. 'uir_r1_a5_k5_s1_e4_c256', # ExtraDW
  892. 'uir_r2_a0_k0_s1_e4_c256', # FFN
  893. 'uir_r1_a5_k0_s1_e2_c256', # ConvNeXt
  894. ],
  895. # stage 4, 7x7 in
  896. [
  897. 'cn_r1_k1_s1_c960', # Conv
  898. ],
  899. ]
  900. elif 'large' in variant:
  901. stem_size = 24
  902. act_layer = resolve_act_layer(kwargs, 'relu')
  903. arch_def = [
  904. # stage 0, 112x112 in
  905. [
  906. 'er_r1_k3_s2_e4_c48', # FusedIB (EdgeResidual)
  907. ],
  908. # stage 1, 56x56 in
  909. [
  910. 'uir_r1_a3_k5_s2_e4_c96', # ExtraDW
  911. 'uir_r1_a3_k3_s1_e4_c96', # ExtraDW
  912. ],
  913. # stage 2, 28x28 in
  914. [
  915. 'uir_r1_a3_k5_s2_e4_c192', # ExtraDW
  916. 'uir_r3_a3_k3_s1_e4_c192', # ExtraDW
  917. 'uir_r1_a3_k5_s1_e4_c192', # ExtraDW
  918. 'uir_r5_a5_k3_s1_e4_c192', # ExtraDW
  919. 'uir_r1_a3_k0_s1_e4_c192', # ConvNeXt
  920. ],
  921. # stage 3, 14x14in
  922. [
  923. 'uir_r4_a5_k5_s2_e4_c512', # ExtraDW
  924. 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt
  925. 'uir_r1_a5_k3_s1_e4_c512', # ExtraDW
  926. 'uir_r2_a5_k0_s1_e4_c512', # ConvNeXt
  927. 'uir_r1_a5_k3_s1_e4_c512', # ExtraDW
  928. 'uir_r1_a5_k5_s1_e4_c512', # ExtraDW
  929. 'uir_r3_a5_k0_s1_e4_c512', # ConvNeXt
  930. ],
  931. # stage 4, 7x7 in
  932. [
  933. 'cn_r1_k1_s1_c960', # Conv
  934. ],
  935. ]
  936. else:
  937. assert False, f'Unknown variant {variant}.'
  938. model_kwargs = dict(
  939. block_args=decode_arch_def(arch_def, group_size=group_size),
  940. head_bias=False,
  941. head_norm=True,
  942. num_features=num_features,
  943. stem_size=stem_size,
  944. fix_stem=channel_multiplier < 1.0,
  945. round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
  946. norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  947. act_layer=act_layer,
  948. layer_scale_init_value=layer_scale_init_value,
  949. **kwargs,
  950. )
  951. model = _create_mnv3(variant, pretrained, **model_kwargs)
  952. return model
  953. def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
  954. """Create default configuration dictionary.
  955. Args:
  956. url: Model weight URL.
  957. **kwargs: Additional configuration options.
  958. Returns:
  959. Configuration dictionary.
  960. """
  961. return {
  962. 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  963. 'crop_pct': 0.875, 'interpolation': 'bilinear',
  964. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  965. 'first_conv': 'conv_stem', 'classifier': 'classifier',
  966. 'license': 'apache-2.0', **kwargs
  967. }
# Default / pretrained configs for the models in this file, keyed by '<model_name>.<tag>'.
# The model-name part matches a @register_model function below. Entries ending in '.untrained'
# set no weight url or hf_hub_id. Keyword arguments given to _cfg() override its defaults.
default_cfgs = generate_default_cfgs({
    # MobileNet-V3
    'mobilenetv3_large_075.untrained': _cfg(url=''),
    'mobilenetv3_large_100.ra_in1k': _cfg(
        interpolation='bicubic',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth',
        hf_hub_id='timm/'),
    'mobilenetv3_large_100.ra4_e3600_r224_in1k': _cfg(
        hf_hub_id='timm/',
        interpolation='bicubic', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
        crop_pct=0.95, test_input_size=(3, 256, 256), test_crop_pct=1.0),
    # MIIL weights: mean 0 / std 1, i.e. the mean/std normalization step is an identity.
    'mobilenetv3_large_100.miil_in21k_ft_in1k': _cfg(
        interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.),
        origin_url='https://github.com/Alibaba-MIIL/ImageNet21K',
        paper_ids='arXiv:2104.10972v4',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_1k_miil_78_0-66471c13.pth',
        hf_hub_id='timm/'),
    'mobilenetv3_large_100.miil_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_in21k_miil-d71cc17b.pth',
        hf_hub_id='timm/',
        origin_url='https://github.com/Alibaba-MIIL/ImageNet21K',
        paper_ids='arXiv:2104.10972v4',
        interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221),
    'mobilenetv3_large_150d.ra4_e3600_r256_in1k': _cfg(
        hf_hub_id='timm/',
        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
        input_size=(3, 256, 256), crop_pct=0.95, pool_size=(8, 8), test_input_size=(3, 320, 320), test_crop_pct=1.0),
    'mobilenetv3_small_050.lamb_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth',
        hf_hub_id='timm/',
        interpolation='bicubic'),
    'mobilenetv3_small_075.lamb_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_075_lambc-384766db.pth',
        hf_hub_id='timm/',
        interpolation='bicubic'),
    'mobilenetv3_small_100.lamb_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_100_lamb-266a294c.pth',
        hf_hub_id='timm/',
        interpolation='bicubic'),
    'mobilenetv3_rw.rmsp_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
        hf_hub_id='timm/',
        interpolation='bicubic'),
    # tf_ MobileNet-V3 variants (registered below with bn_eps / 'same' padding defaults)
    'tf_mobilenetv3_large_075.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
        hf_hub_id='timm/',
        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
    'tf_mobilenetv3_large_100.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
        hf_hub_id='timm/',
        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
    'tf_mobilenetv3_large_minimal_100.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
        hf_hub_id='timm/',
        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
    'tf_mobilenetv3_small_075.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
        hf_hub_id='timm/',
        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
    'tf_mobilenetv3_small_100.in1k': _cfg(
        url= 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
        hf_hub_id='timm/',
        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
    'tf_mobilenetv3_small_minimal_100.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
        hf_hub_id='timm/',
        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
    # FBNet-V3
    'fbnetv3_b.ra2_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_b_224-ead5d2a1.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 256, 256), crop_pct=0.95),
    'fbnetv3_d.ra2_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_d_224-c98bce42.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 256, 256), crop_pct=0.95),
    'fbnetv3_g.ra2_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_g_240-0b1df83b.pth',
        hf_hub_id='timm/',
        input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95, pool_size=(8, 8)),
    # PP-LCNet
    "lcnet_035.untrained": _cfg(),
    "lcnet_050.ra2_in1k": _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_050-f447553b.pth',
        hf_hub_id='timm/',
        interpolation='bicubic',
    ),
    "lcnet_075.ra2_in1k": _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_075-318cad2c.pth',
        hf_hub_id='timm/',
        interpolation='bicubic',
    ),
    "lcnet_100.ra2_in1k": _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_100-a929038c.pth',
        hf_hub_id='timm/',
        interpolation='bicubic',
    ),
    "lcnet_150.untrained": _cfg(),
    # MobileNet-V4 (conv and hybrid); in12k-tagged classifier heads use num_classes=11821
    'mobilenetv4_conv_small_035.untrained': _cfg(
        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
        test_input_size=(3, 256, 256), test_crop_pct=0.95, interpolation='bicubic'),
    'mobilenetv4_conv_small_050.e3000_r224_in1k': _cfg(
        hf_hub_id='timm/',
        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
        test_input_size=(3, 256, 256), test_crop_pct=0.95, interpolation='bicubic'),
    'mobilenetv4_conv_small.e2400_r224_in1k': _cfg(
        hf_hub_id='timm/',
        test_input_size=(3, 256, 256), test_crop_pct=0.95, interpolation='bicubic'),
    'mobilenetv4_conv_small.e1200_r224_in1k': _cfg(
        hf_hub_id='timm/',
        test_input_size=(3, 256, 256), test_crop_pct=0.95, interpolation='bicubic'),
    'mobilenetv4_conv_small.e3600_r256_in1k': _cfg(
        hf_hub_id='timm/',
        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95,
        test_input_size=(3, 320, 320), test_crop_pct=1.0, interpolation='bicubic'),
    'mobilenetv4_conv_medium.e500_r256_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8),
        crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, interpolation='bicubic'),
    'mobilenetv4_conv_medium.e500_r224_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 256, 256), test_crop_pct=1.0, interpolation='bicubic'),
    'mobilenetv4_conv_medium.e250_r384_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12),
        crop_pct=0.95, interpolation='bicubic'),
    'mobilenetv4_conv_medium.e180_r384_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821,
        input_size=(3, 384, 384), pool_size=(12, 12),
        crop_pct=1.0, interpolation='bicubic'),
    'mobilenetv4_conv_medium.e180_ad_r384_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821,
        input_size=(3, 384, 384), pool_size=(12, 12),
        crop_pct=1.0, interpolation='bicubic'),
    'mobilenetv4_conv_medium.e250_r384_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821,
        input_size=(3, 384, 384), pool_size=(12, 12),
        crop_pct=1.0, interpolation='bicubic'),
    'mobilenetv4_conv_large.e600_r384_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12),
        crop_pct=0.95, test_input_size=(3, 448, 448), test_crop_pct=1.0, interpolation='bicubic'),
    'mobilenetv4_conv_large.e500_r256_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8),
        crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, interpolation='bicubic'),
    'mobilenetv4_hybrid_medium.e200_r256_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8),
        crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, interpolation='bicubic'),
    'mobilenetv4_hybrid_medium.ix_e550_r256_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8),
        crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, interpolation='bicubic'),
    'mobilenetv4_hybrid_medium.ix_e550_r384_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12),
        crop_pct=0.95, test_input_size=(3, 448, 448), test_crop_pct=1.0, interpolation='bicubic'),
    'mobilenetv4_hybrid_medium.e500_r224_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 256, 256), test_crop_pct=1.0, interpolation='bicubic'),
    'mobilenetv4_hybrid_medium.e200_r256_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821,
        input_size=(3, 256, 256), pool_size=(8, 8),
        crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, interpolation='bicubic'),
    'mobilenetv4_hybrid_large.ix_e600_r384_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12),
        crop_pct=0.95, test_input_size=(3, 448, 448), test_crop_pct=1.0, interpolation='bicubic'),
    'mobilenetv4_hybrid_large.e600_r384_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12),
        crop_pct=0.95, test_input_size=(3, 448, 448), test_crop_pct=1.0, interpolation='bicubic'),
    # experimental
    'mobilenetv4_conv_aa_medium.untrained': _cfg(
        # hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
    'mobilenetv4_conv_blur_medium.e500_r224_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 256, 256), test_crop_pct=1.0, interpolation='bicubic'),
    'mobilenetv4_conv_aa_large.e230_r448_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 448, 448), pool_size=(14, 14),
        crop_pct=0.95, test_input_size=(3, 544, 544), test_crop_pct=1.0, interpolation='bicubic'),
    'mobilenetv4_conv_aa_large.e230_r384_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12),
        crop_pct=0.95, test_input_size=(3, 480, 480), test_crop_pct=1.0, interpolation='bicubic'),
    'mobilenetv4_conv_aa_large.e600_r384_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12),
        crop_pct=0.95, test_input_size=(3, 480, 480), test_crop_pct=1.0, interpolation='bicubic'),
    'mobilenetv4_conv_aa_large.e230_r384_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821,
        input_size=(3, 384, 384), pool_size=(12, 12),
        crop_pct=0.95, test_input_size=(3, 448, 448), test_crop_pct=1.0, interpolation='bicubic'),
    'mobilenetv4_hybrid_medium_075.untrained': _cfg(
        # hf_hub_id='timm/',
        crop_pct=0.95, interpolation='bicubic'),
    'mobilenetv4_hybrid_large_075.untrained': _cfg(
        # hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
})
  1174. @register_model
  1175. def mobilenetv3_large_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1176. """ MobileNet V3 """
  1177. model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
  1178. return model
  1179. @register_model
  1180. def mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1181. """ MobileNet V3 """
  1182. model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
  1183. return model
  1184. @register_model
  1185. def mobilenetv3_large_150d(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1186. """ MobileNet V3 """
  1187. model = _gen_mobilenet_v3('mobilenetv3_large_150d', 1.5, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
  1188. return model
  1189. @register_model
  1190. def mobilenetv3_small_050(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1191. """ MobileNet V3 """
  1192. model = _gen_mobilenet_v3('mobilenetv3_small_050', 0.50, pretrained=pretrained, **kwargs)
  1193. return model
  1194. @register_model
  1195. def mobilenetv3_small_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1196. """ MobileNet V3 """
  1197. model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
  1198. return model
  1199. @register_model
  1200. def mobilenetv3_small_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1201. """ MobileNet V3 """
  1202. model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
  1203. return model
  1204. @register_model
  1205. def mobilenetv3_rw(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1206. """ MobileNet V3 """
  1207. kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
  1208. model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs)
  1209. return model
  1210. @register_model
  1211. def tf_mobilenetv3_large_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1212. """ MobileNet V3 """
  1213. kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
  1214. kwargs.setdefault('pad_type', 'same')
  1215. model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
  1216. return model
  1217. @register_model
  1218. def tf_mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1219. """ MobileNet V3 """
  1220. kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
  1221. kwargs.setdefault('pad_type', 'same')
  1222. model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
  1223. return model
  1224. @register_model
  1225. def tf_mobilenetv3_large_minimal_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1226. """ MobileNet V3 """
  1227. kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
  1228. kwargs.setdefault('pad_type', 'same')
  1229. model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
  1230. return model
  1231. @register_model
  1232. def tf_mobilenetv3_small_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1233. """ MobileNet V3 """
  1234. kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
  1235. kwargs.setdefault('pad_type', 'same')
  1236. model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
  1237. return model
  1238. @register_model
  1239. def tf_mobilenetv3_small_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1240. """ MobileNet V3 """
  1241. kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
  1242. kwargs.setdefault('pad_type', 'same')
  1243. model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
  1244. return model
  1245. @register_model
  1246. def tf_mobilenetv3_small_minimal_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1247. """ MobileNet V3 """
  1248. kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
  1249. kwargs.setdefault('pad_type', 'same')
  1250. model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
  1251. return model
  1252. @register_model
  1253. def fbnetv3_b(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1254. """ FBNetV3-B """
  1255. model = _gen_fbnetv3('fbnetv3_b', pretrained=pretrained, **kwargs)
  1256. return model
  1257. @register_model
  1258. def fbnetv3_d(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1259. """ FBNetV3-D """
  1260. model = _gen_fbnetv3('fbnetv3_d', pretrained=pretrained, **kwargs)
  1261. return model
  1262. @register_model
  1263. def fbnetv3_g(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1264. """ FBNetV3-G """
  1265. model = _gen_fbnetv3('fbnetv3_g', pretrained=pretrained, **kwargs)
  1266. return model
  1267. @register_model
  1268. def lcnet_035(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1269. """ PP-LCNet 0.35"""
  1270. model = _gen_lcnet('lcnet_035', 0.35, pretrained=pretrained, **kwargs)
  1271. return model
  1272. @register_model
  1273. def lcnet_050(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1274. """ PP-LCNet 0.5"""
  1275. model = _gen_lcnet('lcnet_050', 0.5, pretrained=pretrained, **kwargs)
  1276. return model
  1277. @register_model
  1278. def lcnet_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1279. """ PP-LCNet 1.0"""
  1280. model = _gen_lcnet('lcnet_075', 0.75, pretrained=pretrained, **kwargs)
  1281. return model
  1282. @register_model
  1283. def lcnet_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1284. """ PP-LCNet 1.0"""
  1285. model = _gen_lcnet('lcnet_100', 1.0, pretrained=pretrained, **kwargs)
  1286. return model
  1287. @register_model
  1288. def lcnet_150(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1289. """ PP-LCNet 1.5"""
  1290. model = _gen_lcnet('lcnet_150', 1.5, pretrained=pretrained, **kwargs)
  1291. return model
  1292. @register_model
  1293. def mobilenetv4_conv_small_035(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1294. """ MobileNet V4 """
  1295. model = _gen_mobilenet_v4('mobilenetv4_conv_small_035', 0.35, pretrained=pretrained, **kwargs)
  1296. return model
  1297. @register_model
  1298. def mobilenetv4_conv_small_050(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1299. """ MobileNet V4 """
  1300. model = _gen_mobilenet_v4('mobilenetv4_conv_small_050', 0.50, pretrained=pretrained, **kwargs)
  1301. return model
  1302. @register_model
  1303. def mobilenetv4_conv_small(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1304. """ MobileNet V4 """
  1305. model = _gen_mobilenet_v4('mobilenetv4_conv_small', 1.0, pretrained=pretrained, **kwargs)
  1306. return model
  1307. @register_model
  1308. def mobilenetv4_conv_medium(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1309. """ MobileNet V4 """
  1310. model = _gen_mobilenet_v4('mobilenetv4_conv_medium', 1.0, pretrained=pretrained, **kwargs)
  1311. return model
  1312. @register_model
  1313. def mobilenetv4_conv_large(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1314. """ MobileNet V4 """
  1315. model = _gen_mobilenet_v4('mobilenetv4_conv_large', 1.0, pretrained=pretrained, **kwargs)
  1316. return model
  1317. @register_model
  1318. def mobilenetv4_hybrid_medium(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1319. """ MobileNet V4 Hybrid """
  1320. model = _gen_mobilenet_v4('mobilenetv4_hybrid_medium', 1.0, pretrained=pretrained, **kwargs)
  1321. return model
  1322. @register_model
  1323. def mobilenetv4_hybrid_large(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1324. """ MobileNet V4 Hybrid"""
  1325. model = _gen_mobilenet_v4('mobilenetv4_hybrid_large', 1.0, pretrained=pretrained, **kwargs)
  1326. return model
  1327. @register_model
  1328. def mobilenetv4_conv_aa_medium(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1329. """ MobileNet V4 w/ AvgPool AA """
  1330. model = _gen_mobilenet_v4('mobilenetv4_conv_aa_medium', 1.0, pretrained=pretrained, aa_layer='avg', **kwargs)
  1331. return model
  1332. @register_model
  1333. def mobilenetv4_conv_blur_medium(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1334. """ MobileNet V4 Conv w/ Blur AA """
  1335. model = _gen_mobilenet_v4('mobilenetv4_conv_blur_medium', 1.0, pretrained=pretrained, aa_layer='blurpc', **kwargs)
  1336. return model
  1337. @register_model
  1338. def mobilenetv4_conv_aa_large(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1339. """ MobileNet V4 w/ AvgPool AA """
  1340. model = _gen_mobilenet_v4('mobilenetv4_conv_aa_large', 1.0, pretrained=pretrained, aa_layer='avg', **kwargs)
  1341. return model
  1342. @register_model
  1343. def mobilenetv4_hybrid_medium_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1344. """ MobileNet V4 Hybrid """
  1345. model = _gen_mobilenet_v4('mobilenetv4_hybrid_medium_075', 0.75, pretrained=pretrained, **kwargs)
  1346. return model
  1347. @register_model
  1348. def mobilenetv4_hybrid_large_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
  1349. """ MobileNet V4 Hybrid"""
  1350. model = _gen_mobilenet_v4('mobilenetv4_hybrid_large_075', 0.75, pretrained=pretrained, **kwargs)
  1351. return model
  1352. register_model_deprecations(__name__, {
  1353. 'mobilenetv3_large_100_miil': 'mobilenetv3_large_100.miil_in21k_ft_in1k',
  1354. 'mobilenetv3_large_100_miil_in21k': 'mobilenetv3_large_100.miil_in21k',
  1355. })