mobilenetv5.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880
  1. from functools import partial
  2. from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
  7. from timm.layers import (
  8. SelectAdaptivePool2d,
  9. Linear,
  10. LayerType,
  11. RmsNorm2d,
  12. ConvNormAct,
  13. create_conv2d,
  14. get_norm_layer,
  15. get_norm_act_layer,
  16. to_2tuple,
  17. )
  18. from ._builder import build_model_with_cfg
  19. from ._efficientnet_blocks import SqueezeExcite, UniversalInvertedResidual
  20. from ._efficientnet_builder import (
  21. BlockArgs,
  22. EfficientNetBuilder,
  23. decode_arch_def,
  24. efficientnet_init_weights,
  25. round_channels,
  26. )
  27. from ._features import feature_take_indices
  28. from ._features_fx import register_notrace_module
  29. from ._manipulate import checkpoint_seq
  30. from ._registry import generate_default_cfgs, register_model
  31. __all__ = ['MobileNetV5', 'MobileNetV5Encoder']
  32. _GELU = partial(nn.GELU, approximate='tanh')
  33. @register_notrace_module
  34. class MobileNetV5MultiScaleFusionAdapter(nn.Module):
  35. """Multi-layer fusion token adapter.
  36. Args:
  37. in_chs: List of input channel counts for each feature scale.
  38. out_chs: The number of output channels.
  39. output_resolution: The output resolution.
  40. expansion_ratio: The FFN expansion ratio.
  41. interpolation_mode: The upsampling interpolation mode.
  42. layer_scale_init_value: The initial value of the layer scale, no layer scale if None.
  43. """
  44. def __init__(
  45. self,
  46. in_chs: Union[int, List[int]],
  47. out_chs: int,
  48. output_resolution: int,
  49. expansion_ratio: float = 2.0,
  50. interpolation_mode: str = "nearest",
  51. layer_scale_init_value: Optional[float] = None,
  52. noskip: bool = True,
  53. act_layer: Optional[LayerType] = None,
  54. norm_layer: Optional[LayerType] = None,
  55. device=None,
  56. dtype=None,
  57. ):
  58. dd = {'device': device, 'dtype': dtype}
  59. super().__init__()
  60. self.in_channels = sum(in_chs) if isinstance(in_chs, Sequence) else in_chs
  61. self.out_channels = out_chs
  62. self.output_resolution = to_2tuple(output_resolution)
  63. self.expansion_ratio = expansion_ratio
  64. self.interpolation_mode = interpolation_mode
  65. self.layer_scale_init_value = layer_scale_init_value
  66. self.noskip = noskip
  67. act_layer = act_layer or _GELU
  68. norm_layer = norm_layer or RmsNorm2d
  69. self.ffn = UniversalInvertedResidual(
  70. in_chs=self.in_channels,
  71. out_chs=self.out_channels,
  72. dw_kernel_size_mid=0,
  73. exp_ratio=self.expansion_ratio,
  74. act_layer=act_layer,
  75. norm_layer=norm_layer,
  76. noskip=self.noskip,
  77. layer_scale_init_value=self.layer_scale_init_value,
  78. **dd,
  79. )
  80. self.norm = norm_layer(self.out_channels, **dd)
  81. def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
  82. # Inputs list of [B, C, H, W] tensors
  83. high_resolution = inputs[0].shape[-2:] # Assuming the first input is the highest resolution.
  84. resized_inputs = []
  85. for _, img in enumerate(inputs):
  86. feat_size = img.shape[-2:]
  87. if feat_size[0] < high_resolution[0] or feat_size[1] < high_resolution[1]:
  88. img = F.interpolate(img, size=high_resolution, mode=self.interpolation_mode)
  89. resized_inputs.append(img)
  90. channel_cat_imgs = torch.cat(resized_inputs, dim=1) # Cat on channel dim, must equal self.in_channels
  91. img = self.ffn(channel_cat_imgs)
  92. if high_resolution[0] != self.output_resolution[0] or high_resolution[1] != self.output_resolution[1]:
  93. # Interpolate / pool to target output_resolution if highest feature resolution differs
  94. if (
  95. high_resolution[0] % self.output_resolution[0] != 0 or
  96. high_resolution[1] % self.output_resolution[1] != 0
  97. ):
  98. img = F.interpolate(img, size=self.output_resolution, mode="bilinear")
  99. else:
  100. h_strides = high_resolution[0] // self.output_resolution[0]
  101. w_strides = high_resolution[1] // self.output_resolution[1]
  102. img = F.avg_pool2d(
  103. img,
  104. kernel_size=(h_strides, w_strides),
  105. stride=(h_strides, w_strides),
  106. )
  107. img = self.norm(img)
  108. return img
  109. class MobileNetV5(nn.Module):
  110. """ MobiletNet-V5
  111. """
  112. def __init__(
  113. self,
  114. block_args: BlockArgs,
  115. num_classes: int = 1000,
  116. in_chans: int = 3,
  117. stem_size: int = 16,
  118. stem_bias: bool = True,
  119. fix_stem: bool = False,
  120. num_features: int = 2048,
  121. pad_type: str = '',
  122. use_msfa: bool = True,
  123. msfa_indices: List[int] = (-2, -1),
  124. msfa_output_resolution: int = 16,
  125. act_layer: Optional[LayerType] = None,
  126. norm_layer: Optional[LayerType] = None,
  127. aa_layer: Optional[LayerType] = None,
  128. se_layer: Optional[LayerType] = None,
  129. se_from_exp: bool = True,
  130. round_chs_fn: Callable = round_channels,
  131. drop_rate: float = 0.,
  132. drop_path_rate: float = 0.,
  133. layer_scale_init_value: Optional[float] = None,
  134. global_pool: str = 'avg',
  135. device=None,
  136. dtype=None,
  137. ):
  138. """
  139. Args:
  140. block_args: Arguments for blocks of the network.
  141. num_classes: Number of classes for classification head.
  142. in_chans: Number of input image channels.
  143. stem_size: Number of output channels of the initial stem convolution.
  144. fix_stem: If True, don't scale stem by round_chs_fn.
  145. num_features: Number of output channels of the conv head layer.
  146. head_bias: If True, add a learnable bias to the conv head layer.
  147. pad_type: Type of padding to use for convolution layers.
  148. act_layer: Type of activation layer.
  149. norm_layer: Type of normalization layer.
  150. aa_layer: Type of anti-aliasing layer.
  151. se_layer: Type of Squeeze-and-Excite layer.
  152. se_from_exp: If True, calculate SE channel reduction from expanded mid channels.
  153. round_chs_fn: Callable to round number of filters based on depth multiplier.
  154. drop_rate: Dropout rate.
  155. drop_path_rate: Stochastic depth rate.
  156. layer_scale_init_value: Enable layer scale on compatible blocks if not None.
  157. global_pool: Type of pooling to use for global pooling features of the FC head.
  158. """
  159. super().__init__()
  160. dd = {'device': device, 'dtype': dtype}
  161. act_layer = act_layer or _GELU
  162. norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
  163. norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
  164. se_layer = se_layer or SqueezeExcite
  165. self.num_classes = num_classes
  166. self.in_chans = in_chans
  167. self.drop_rate = drop_rate
  168. self.grad_checkpointing = False
  169. self.msfa_indices = msfa_indices
  170. self.msfa_output_resolution = msfa_output_resolution
  171. # Stem
  172. if not fix_stem:
  173. stem_size = round_chs_fn(stem_size)
  174. self.conv_stem = ConvNormAct(
  175. in_chans,
  176. stem_size,
  177. kernel_size=3,
  178. stride=2,
  179. padding=pad_type,
  180. bias=stem_bias,
  181. norm_layer=norm_layer,
  182. act_layer=act_layer,
  183. **dd,
  184. )
  185. # Middle stages (IR/ER/DS Blocks)
  186. builder = EfficientNetBuilder(
  187. output_stride=32,
  188. pad_type=pad_type,
  189. round_chs_fn=round_chs_fn,
  190. se_from_exp=se_from_exp,
  191. act_layer=act_layer,
  192. norm_layer=norm_layer,
  193. aa_layer=aa_layer,
  194. se_layer=se_layer,
  195. drop_path_rate=drop_path_rate,
  196. layer_scale_init_value=layer_scale_init_value,
  197. **dd,
  198. )
  199. self.blocks = nn.Sequential(*builder(stem_size, block_args))
  200. self.feature_info = builder.features
  201. self.stage_ends = [f['stage'] for f in self.feature_info]
  202. self.num_features = builder.in_chs # features of last stage, output of forward_features()
  203. # Neck (aggregation) + Head + Pooling
  204. if use_msfa:
  205. self.num_features = self.head_hidden_size = num_features # output of msfa is output of forward_features()
  206. # Map msfa indices to feature info and calculate sum of feature channels
  207. self.msfa_indices = feature_take_indices(len(self.feature_info), self.msfa_indices)[0]
  208. self.msfa_in_chs = sum([self.feature_info[mi]['num_chs'] for mi in self.msfa_indices])
  209. self.msfa = MobileNetV5MultiScaleFusionAdapter(
  210. in_chs=self.msfa_in_chs,
  211. out_chs=num_features,
  212. output_resolution=self.msfa_output_resolution,
  213. norm_layer=norm_layer,
  214. act_layer=act_layer,
  215. **dd,
  216. )
  217. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  218. self.conv_head = None
  219. self.norm_head = None
  220. else:
  221. self.num_features = builder.in_chs # features of last stage, output of forward_features()
  222. self.head_hidden_size = num_features
  223. self.msfa = None
  224. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  225. num_pooled_chs = self.num_features * self.global_pool.feat_mult()
  226. # mobilenet-v4 style post-pooling PW conv is followed by a norm+act layer
  227. self.conv_head = create_conv2d(num_pooled_chs, self.head_hidden_size, 1, padding=pad_type, **dd)
  228. self.norm_head = norm_act_layer(self.head_hidden_size, **dd)
  229. self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
  230. self.classifier = Linear(self.head_hidden_size, num_classes, **dd) if num_classes > 0 else nn.Identity()
  231. efficientnet_init_weights(self)
  232. def as_sequential(self):
  233. layers = [self.conv_stem, self.bn1]
  234. layers.extend(self.blocks)
  235. layers.append(self.global_pool)
  236. if self.conv_head is not None:
  237. layers.append(self.conv_head)
  238. if self.norm_head is not None:
  239. layers.append(self.norm_head)
  240. layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
  241. return nn.Sequential(*layers)
  242. @torch.jit.ignore
  243. def group_matcher(self, coarse: bool = False):
  244. return dict(
  245. stem=r'^conv_stem|bn1',
  246. blocks=r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)'
  247. )
  248. @torch.jit.ignore
  249. def set_grad_checkpointing(self, enable: bool = True):
  250. self.grad_checkpointing = enable
  251. @torch.jit.ignore
  252. def get_classifier(self) -> nn.Module:
  253. return self.classifier
  254. def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
  255. self.num_classes = num_classes
  256. # NOTE: cannot meaningfully change pooling of efficient head after creation
  257. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  258. self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
  259. self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
  260. def forward_intermediates(
  261. self,
  262. x: torch.Tensor,
  263. indices: Optional[Union[int, List[int]]] = None,
  264. norm: bool = False,
  265. stop_early: bool = False,
  266. output_fmt: str = 'NCHW',
  267. intermediates_only: bool = False,
  268. extra_blocks: bool = False,
  269. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  270. """ Forward features that returns intermediates.
  271. Args:
  272. x: Input image tensor
  273. indices: Take last n blocks if int, all if None, select matching indices if sequence
  274. norm: Apply norm layer to compatible intermediates
  275. stop_early: Stop iterating over blocks when last desired intermediate hit
  276. output_fmt: Shape of intermediate feature outputs
  277. intermediates_only: Only return intermediate features
  278. extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info
  279. Returns:
  280. """
  281. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  282. if stop_early:
  283. assert intermediates_only, 'Must use intermediates_only for early stopping.'
  284. intermediates = []
  285. if extra_blocks:
  286. take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
  287. else:
  288. take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
  289. take_indices = [self.stage_ends[i] for i in take_indices]
  290. max_index = self.stage_ends[max_index]
  291. # FIXME MFSA and forward_intermediates overlap, they both take indices from specific features
  292. # When a user wants to grab specific feature maps for a downstream task AND have the msfa output
  293. # what should we do? Accumulate two intermediates? One for msfa and one for take_indices?
  294. # forward pass
  295. feat_idx = 0 # stem is index 0
  296. x = self.conv_stem(x)
  297. if feat_idx in take_indices:
  298. intermediates.append(x)
  299. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  300. blocks = self.blocks
  301. else:
  302. blocks = self.blocks[:max_index]
  303. for blk in blocks:
  304. feat_idx += 1
  305. x = blk(x)
  306. if feat_idx in take_indices:
  307. intermediates.append(x)
  308. if intermediates_only:
  309. return intermediates
  310. # FIXME see note above
  311. # self.msfa(msfa_intermediatse)
  312. return x, intermediates
  313. def prune_intermediate_layers(
  314. self,
  315. indices: Union[int, List[int]] = 1,
  316. prune_norm: bool = False,
  317. prune_head: bool = True,
  318. extra_blocks: bool = False,
  319. ):
  320. """ Prune layers not required for specified intermediates.
  321. """
  322. if extra_blocks:
  323. take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
  324. else:
  325. take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
  326. max_index = self.stage_ends[max_index]
  327. self.blocks = self.blocks[:max_index] # truncate blocks w/ stem as idx 0
  328. if max_index < len(self.blocks):
  329. self.conv_head = None
  330. self.norm_head = None
  331. if prune_head:
  332. self.conv_head = None
  333. self.norm_head = None
  334. self.reset_classifier(0, '')
  335. return take_indices
  336. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  337. if self.msfa is not None:
  338. # When MSFA aggregation layer is present, we gather intermediates as is forward_intermediates
  339. feat_idx = 0 # offset by one from blocks index due to stem feature
  340. intermediates = []
  341. x = self.conv_stem(x)
  342. if feat_idx in self.msfa_indices:
  343. intermediates.append(x)
  344. for blk in self.blocks:
  345. feat_idx += 1
  346. # FIXME fix grad checkpointing
  347. x = blk(x)
  348. if feat_idx in self.msfa_indices:
  349. intermediates.append(x)
  350. x = self.msfa(intermediates)
  351. else:
  352. x = self.conv_stem(x)
  353. if self.grad_checkpointing and not torch.jit.is_scripting():
  354. x = checkpoint_seq(self.blocks, x, flatten=True)
  355. else:
  356. x = self.blocks(x)
  357. return x
  358. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  359. x = self.global_pool(x)
  360. if self.conv_head is not None:
  361. x = self.conv_head(x)
  362. if self.norm_head is not None:
  363. x = self.norm_head(x)
  364. x = self.flatten(x)
  365. if self.drop_rate > 0.:
  366. x = F.dropout(x, p=self.drop_rate, training=self.training)
  367. if pre_logits:
  368. return x
  369. return self.classifier(x)
  370. def forward(self, x: torch.Tensor) -> torch.Tensor:
  371. x = self.forward_features(x)
  372. x = self.forward_head(x)
  373. return x
  374. class MobileNetV5Encoder(nn.Module):
  375. """MobileNetV5 Vision Encoder"""
  376. def __init__(
  377. self,
  378. block_args: BlockArgs,
  379. in_chans: int = 3,
  380. stem_size: int = 64,
  381. stem_bias: bool = True,
  382. fix_stem: bool = False,
  383. pad_type: str = '',
  384. msfa_indices: Sequence[int] = (-2, -1),
  385. msfa_output_resolution: int = 16,
  386. act_layer: Optional[LayerType] = None,
  387. norm_layer: Optional[LayerType] = None,
  388. aa_layer: Optional[LayerType] = None,
  389. se_layer: Optional[LayerType] = None,
  390. se_from_exp: bool = True,
  391. round_chs_fn: Callable = round_channels,
  392. drop_rate: float = 0.,
  393. drop_path_rate: float = 0.,
  394. layer_scale_init_value: Optional[float] = None,
  395. device=None,
  396. dtype=None,
  397. ):
  398. super().__init__()
  399. dd = {'device': device, 'dtype': dtype}
  400. act_layer = act_layer or _GELU
  401. norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
  402. se_layer = se_layer or SqueezeExcite
  403. self.num_classes = 0 # Exists to satisfy ._hub module APIs.
  404. self.in_chans = in_chans
  405. self.drop_rate = drop_rate
  406. self.grad_checkpointing = False
  407. # Stem
  408. if not fix_stem:
  409. stem_size = round_chs_fn(stem_size)
  410. self.conv_stem = ConvNormAct(
  411. in_chans,
  412. stem_size,
  413. kernel_size=3,
  414. stride=2,
  415. padding=pad_type,
  416. bias=stem_bias,
  417. norm_layer=norm_layer,
  418. act_layer=act_layer,
  419. **dd,
  420. )
  421. builder = EfficientNetBuilder(
  422. output_stride=32,
  423. pad_type=pad_type,
  424. round_chs_fn=round_chs_fn,
  425. se_from_exp=se_from_exp,
  426. act_layer=act_layer,
  427. norm_layer=norm_layer,
  428. aa_layer=aa_layer,
  429. se_layer=se_layer,
  430. drop_path_rate=drop_path_rate,
  431. layer_scale_init_value=layer_scale_init_value,
  432. **dd,
  433. )
  434. self.blocks = nn.Sequential(*builder(stem_size, block_args))
  435. self.feature_info = builder.features
  436. self.stage_ends = [f['stage'] for f in self.feature_info]
  437. self.num_features = self.head_hidden_size = 2048 # output of msfa is output of forward_features()
  438. # Map msfa indices to feature info and calculate sum of feature channels
  439. self.msfa_indices = feature_take_indices(len(self.feature_info), msfa_indices)[0]
  440. self.msfa_in_chs = sum([self.feature_info[mi]['num_chs'] for mi in self.msfa_indices])
  441. self.msfa_output_resolution = msfa_output_resolution
  442. self.msfa = MobileNetV5MultiScaleFusionAdapter(
  443. in_chs=self.msfa_in_chs,
  444. out_chs=self.num_features,
  445. output_resolution=self.msfa_output_resolution,
  446. norm_layer=norm_layer,
  447. act_layer=act_layer,
  448. **dd,
  449. )
  450. efficientnet_init_weights(self)
  451. def forward_intermediates(
  452. self,
  453. x: torch.Tensor,
  454. indices: Optional[Union[int, List[int]]] = None,
  455. norm: bool = False,
  456. stop_early: bool = False,
  457. output_fmt: str = 'NCHW',
  458. intermediates_only: bool = False,
  459. extra_blocks: bool = False,
  460. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  461. """ Forward features that returns intermediates.
  462. Args:
  463. x: Input image tensor
  464. indices: Take last n blocks if int, all if None, select matching indices if sequence
  465. norm: (Unused) Applies norm layer to compatible intermediates
  466. stop_early: Stop iterating over blocks when last desired intermediate hit
  467. output_fmt: Shape of intermediate feature outputs
  468. intermediates_only: Only return intermediate features
  469. extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info
  470. Returns:
  471. """
  472. del norm
  473. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  474. if stop_early:
  475. assert intermediates_only, 'Must use intermediates_only for early stopping.'
  476. # MobileNet v5's MultiScaleFusionAdapter takes intermediates from specific feature indicies and uses them in
  477. # its computation. These MSFA indices are not guaranteed to be captured by the `indices` parameter passed to
  478. # this function, so we accumulate two sets of indices, one that aligns with the `indices` parameter and one
  479. # that is required by the MSFA block.
  480. intermediates = []
  481. msfa_intermediates = []
  482. if extra_blocks:
  483. take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
  484. else:
  485. take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
  486. take_indices = [self.stage_ends[i] for i in take_indices]
  487. max_index = self.stage_ends[max_index]
  488. # forward pass
  489. feat_idx = 0 # stem is index 0
  490. x = self.conv_stem(x)
  491. if feat_idx in take_indices:
  492. intermediates.append(x)
  493. if feat_idx in self.msfa_indices:
  494. msfa_intermediates.append(x)
  495. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  496. blocks = self.blocks
  497. else:
  498. blocks = self.blocks[:max_index]
  499. for blk in blocks:
  500. feat_idx += 1
  501. x = blk(x)
  502. if feat_idx in take_indices:
  503. intermediates.append(x)
  504. if feat_idx in self.msfa_indices:
  505. msfa_intermediates.append(x)
  506. if intermediates_only:
  507. return intermediates
  508. return self.msfa(msfa_intermediates), intermediates
  509. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  510. feat_idx = 0 # offset by one from blocks index due to stem feature
  511. intermediates = []
  512. x = self.conv_stem(x)
  513. if feat_idx in self.msfa_indices:
  514. intermediates.append(x)
  515. for blk in self.blocks:
  516. feat_idx += 1
  517. # FIXME fix grad checkpointing
  518. x = blk(x)
  519. if feat_idx in self.msfa_indices:
  520. intermediates.append(x)
  521. return self.msfa(intermediates)
  522. def forward_head(self, x: torch.Tensor) -> torch.Tensor:
  523. raise NotImplementedError("MobileNetV5Encoder does not support classification use cases.")
  524. def forward(self, x: torch.Tensor) -> torch.Tensor:
  525. return self.forward_features(x)
  526. def checkpoint_filter_fn(
  527. state_dict: Dict[str, torch.Tensor],
  528. model,
  529. ) -> Dict[str, torch.Tensor]:
  530. """ convert weights from gemma encoders """
  531. state_dict = state_dict.get('model', state_dict)
  532. state_dict = state_dict.get('state_dict', state_dict)
  533. if 'model.vision_tower.timm_model.conv_stem.conv.weight' in state_dict:
  534. prefix = 'model.vision_tower.timm_model.'
  535. state_dict = {k.replace(prefix, ''): v for k, v in state_dict.items() if prefix in k}
  536. return state_dict
  537. def _create_mnv5_encoder(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV5Encoder:
  538. out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
  539. feature_cfg = dict(out_indices=out_indices, feature_cls='getter')
  540. kwargs_filter = (
  541. 'num_classes',
  542. 'num_features',
  543. 'head_conv',
  544. 'head_bias',
  545. 'head_norm',
  546. 'global_pool',
  547. )
  548. model = build_model_with_cfg(
  549. MobileNetV5Encoder,
  550. variant,
  551. pretrained,
  552. pretrained_strict=False,
  553. pretrained_filter_fn=checkpoint_filter_fn,
  554. feature_cfg=feature_cfg,
  555. kwargs_filter=kwargs_filter,
  556. **kwargs,
  557. )
  558. return model
  559. def _create_mnv5(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV5:
  560. out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
  561. feature_cfg = dict(out_indices=out_indices, feature_cls='getter')
  562. model = build_model_with_cfg(
  563. MobileNetV5,
  564. variant,
  565. pretrained,
  566. pretrained_filter_fn=checkpoint_filter_fn,
  567. feature_cfg=feature_cfg,
  568. **kwargs,
  569. )
  570. return model
  571. def _gen_mobilenet_v5(
  572. variant: str,
  573. channel_multiplier: float = 1.0,
  574. group_size=None,
  575. pretrained: bool = False,
  576. encoder: bool = False,
  577. **kwargs,
  578. ) -> MobileNetV5Encoder:
  579. if 'mobilenetv5_base' in variant:
  580. arch_def: list[list[str]] = [
  581. # Stage 0: 128x128 in
  582. [
  583. 'er_r1_k3_s2_e4_c128',
  584. 'er_r1_k3_s1_e4_c128',
  585. 'er_r1_k3_s1_e4_c128',
  586. ],
  587. # Stage 1: 256x256 in
  588. [
  589. 'uir_r1_a3_k5_s2_e6_c256',
  590. 'uir_r1_a5_k0_s1_e4_c256',
  591. 'uir_r1_a3_k0_s1_e4_c256',
  592. 'uir_r1_a5_k0_s1_e4_c256',
  593. 'uir_r1_a3_k0_s1_e4_c256',
  594. ],
  595. # Stage 2: 640x640 in
  596. [
  597. "uir_r1_a5_k5_s2_e6_c512",
  598. "uir_r1_a5_k0_s1_e4_c512",
  599. "uir_r1_a5_k0_s1_e4_c512",
  600. "uir_r1_a0_k0_s1_e1_c512",
  601. 'mqa_r1_k3_h8_s2_d64_c512',
  602. "uir_r1_a0_k0_s1_e2_c512",
  603. 'mqa_r1_k3_h8_s2_d64_c512',
  604. "uir_r1_a0_k0_s1_e2_c512",
  605. 'mqa_r1_k3_h8_s2_d64_c512',
  606. "uir_r1_a0_k0_s1_e2_c512",
  607. 'mqa_r1_k3_h8_s2_d64_c512',
  608. "uir_r1_a0_k0_s1_e2_c512",
  609. 'mqa_r1_k3_h8_s2_d64_c512',
  610. "uir_r1_a0_k0_s1_e2_c512",
  611. 'mqa_r1_k3_h8_s2_d64_c512',
  612. "uir_r1_a0_k0_s1_e2_c512",
  613. ],
  614. # Stage 3: 1280x1280 in
  615. [
  616. "uir_r1_a5_k5_s2_e6_c1024",
  617. 'mqa_r1_k3_h16_s1_d64_c1024',
  618. "uir_r1_a0_k0_s1_e2_c1024",
  619. 'mqa_r1_k3_h16_s1_d64_c1024',
  620. "uir_r1_a0_k0_s1_e2_c1024",
  621. 'mqa_r1_k3_h16_s1_d64_c1024',
  622. "uir_r1_a0_k0_s1_e2_c1024",
  623. 'mqa_r1_k3_h16_s1_d64_c1024',
  624. "uir_r1_a0_k0_s1_e2_c1024",
  625. 'mqa_r1_k3_h16_s1_d64_c1024',
  626. "uir_r1_a0_k0_s1_e2_c1024",
  627. 'mqa_r1_k3_h16_s1_d64_c1024',
  628. "uir_r1_a0_k0_s1_e2_c1024",
  629. 'mqa_r1_k3_h16_s1_d64_c1024',
  630. "uir_r1_a0_k0_s1_e2_c1024",
  631. ],
  632. ]
  633. else:
  634. arch_def: list[list[str]] = [
  635. # Stage 0: 128x128 in
  636. [
  637. 'er_r1_k3_s2_e4_c128',
  638. 'er_r1_k3_s1_e4_c128',
  639. 'er_r1_k3_s1_e4_c128',
  640. ],
  641. # Stage 1: 256x256 in
  642. [
  643. 'uir_r1_a3_k5_s2_e6_c256',
  644. 'uir_r1_a5_k0_s1_e4_c256',
  645. 'uir_r1_a3_k0_s1_e4_c256',
  646. 'uir_r1_a5_k0_s1_e4_c256',
  647. 'uir_r1_a3_k0_s1_e4_c256',
  648. ],
  649. # Stage 2: 640x640 in
  650. [
  651. "uir_r1_a5_k5_s2_e6_c640",
  652. "uir_r1_a5_k0_s1_e4_c640",
  653. "uir_r1_a5_k0_s1_e4_c640",
  654. "uir_r1_a5_k0_s1_e4_c640",
  655. "uir_r1_a5_k0_s1_e4_c640",
  656. "uir_r1_a5_k0_s1_e4_c640",
  657. "uir_r1_a5_k0_s1_e4_c640",
  658. "uir_r1_a5_k0_s1_e4_c640",
  659. "uir_r1_a0_k0_s1_e1_c640",
  660. "mqa_r1_k3_h12_v2_s1_d64_c640",
  661. "uir_r1_a0_k0_s1_e2_c640",
  662. "mqa_r1_k3_h12_v2_s1_d64_c640",
  663. "uir_r1_a0_k0_s1_e2_c640",
  664. "mqa_r1_k3_h12_v2_s1_d64_c640",
  665. "uir_r1_a0_k0_s1_e2_c640",
  666. "mqa_r1_k3_h12_v2_s1_d64_c640",
  667. "uir_r1_a0_k0_s1_e2_c640",
  668. "mqa_r1_k3_h12_v2_s1_d64_c640",
  669. "uir_r1_a0_k0_s1_e2_c640",
  670. "mqa_r1_k3_h12_v2_s1_d64_c640",
  671. "uir_r1_a0_k0_s1_e2_c640",
  672. "mqa_r1_k3_h12_v2_s1_d64_c640",
  673. "uir_r1_a0_k0_s1_e2_c640",
  674. "mqa_r1_k3_h12_v2_s1_d64_c640",
  675. "uir_r1_a0_k0_s1_e2_c640",
  676. "mqa_r1_k3_h12_v2_s1_d64_c640",
  677. "uir_r1_a0_k0_s1_e2_c640",
  678. "mqa_r1_k3_h12_v2_s1_d64_c640",
  679. "uir_r1_a0_k0_s1_e2_c640",
  680. "mqa_r1_k3_h12_v2_s1_d64_c640",
  681. "uir_r1_a0_k0_s1_e2_c640",
  682. "mqa_r1_k3_h12_v2_s1_d64_c640",
  683. "uir_r1_a0_k0_s1_e2_c640",
  684. "mqa_r1_k3_h12_v2_s1_d64_c640",
  685. "uir_r1_a0_k0_s1_e2_c640",
  686. "mqa_r1_k3_h12_v2_s1_d64_c640",
  687. "uir_r1_a0_k0_s1_e2_c640",
  688. ],
  689. # Stage 3: 1280x1280 in
  690. [
  691. "uir_r1_a5_k5_s2_e6_c1280",
  692. "mqa_r1_k3_h16_s1_d96_c1280",
  693. "uir_r1_a0_k0_s1_e2_c1280",
  694. "mqa_r1_k3_h16_s1_d96_c1280",
  695. "uir_r1_a0_k0_s1_e2_c1280",
  696. "mqa_r1_k3_h16_s1_d96_c1280",
  697. "uir_r1_a0_k0_s1_e2_c1280",
  698. "mqa_r1_k3_h16_s1_d96_c1280",
  699. "uir_r1_a0_k0_s1_e2_c1280",
  700. "mqa_r1_k3_h16_s1_d96_c1280",
  701. "uir_r1_a0_k0_s1_e2_c1280",
  702. "mqa_r1_k3_h16_s1_d96_c1280",
  703. "uir_r1_a0_k0_s1_e2_c1280",
  704. "mqa_r1_k3_h16_s1_d96_c1280",
  705. "uir_r1_a0_k0_s1_e2_c1280",
  706. "mqa_r1_k3_h16_s1_d96_c1280",
  707. "uir_r1_a0_k0_s1_e2_c1280",
  708. "mqa_r1_k3_h16_s1_d96_c1280",
  709. "uir_r1_a0_k0_s1_e2_c1280",
  710. "mqa_r1_k3_h16_s1_d96_c1280",
  711. "uir_r1_a0_k0_s1_e2_c1280",
  712. "mqa_r1_k3_h16_s1_d96_c1280",
  713. "uir_r1_a0_k0_s1_e2_c1280",
  714. "mqa_r1_k3_h16_s1_d96_c1280",
  715. "uir_r1_a0_k0_s1_e2_c1280",
  716. "mqa_r1_k3_h16_s1_d96_c1280",
  717. "uir_r1_a0_k0_s1_e2_c1280",
  718. "mqa_r1_k3_h16_s1_d96_c1280",
  719. "uir_r1_a0_k0_s1_e2_c1280",
  720. "mqa_r1_k3_h16_s1_d96_c1280",
  721. "uir_r1_a0_k0_s1_e2_c1280",
  722. "mqa_r1_k3_h16_s1_d96_c1280",
  723. "uir_r1_a0_k0_s1_e2_c1280",
  724. "mqa_r1_k3_h16_s1_d96_c1280",
  725. "uir_r1_a0_k0_s1_e2_c1280",
  726. "mqa_r1_k3_h16_s1_d96_c1280",
  727. "uir_r1_a0_k0_s1_e2_c1280",
  728. "mqa_r1_k3_h16_s1_d96_c1280",
  729. "uir_r1_a0_k0_s1_e2_c1280",
  730. ],
  731. ]
  732. model_kwargs = dict(
  733. block_args=decode_arch_def(arch_def, group_size=group_size),
  734. stem_size=64,
  735. fix_stem=channel_multiplier < 1.0,
  736. round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
  737. norm_layer=RmsNorm2d,
  738. act_layer=_GELU,
  739. layer_scale_init_value=1e-5,
  740. )
  741. model_kwargs = dict(model_kwargs, **kwargs)
  742. if encoder:
  743. model = _create_mnv5_encoder(variant, pretrained, **model_kwargs)
  744. else:
  745. model = _create_mnv5(variant, pretrained, **model_kwargs)
  746. return model
  747. def _cfg(url: str = '', **kwargs):
  748. return {
  749. 'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (16, 16),
  750. 'crop_pct': 1.0, 'interpolation': 'bicubic',
  751. 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
  752. 'first_conv': 'conv_stem.conv', 'classifier': 'classifier',
  753. **kwargs
  754. }
  755. default_cfgs = generate_default_cfgs({
  756. # Encoder-only config for Gemma 3n Transformers integration
  757. 'mobilenetv5_300m_enc': _cfg(
  758. mean=(0., 0., 0.), std=(1., 1., 1.),
  759. input_size=(3, 768, 768),
  760. num_classes=0),
  761. # Gemma 3n encoder weights for timm use / fine-tune
  762. 'mobilenetv5_300m.gemma3n': _cfg(
  763. hf_hub_id='timm/',
  764. mean=(0., 0., 0.), std=(1., 1., 1.),
  765. input_size=(3, 768, 768),
  766. num_classes=0,
  767. license='gemma'),
  768. # WIP classification configs for testing
  769. 'mobilenetv5_base.untrained': _cfg(
  770. # hf_hub_id='timm/',
  771. num_classes=1000)
  772. })
  773. @register_model
  774. def mobilenetv5_300m_enc(pretrained: bool = False, **kwargs) -> MobileNetV5Encoder:
  775. """MobileNet V5 Vision Encoder"""
  776. pad_type = kwargs.pop('pad_type', 'same')
  777. model = _gen_mobilenet_v5(
  778. 'mobilenetv5_300m_enc',
  779. pretrained=pretrained,
  780. encoder=True,
  781. pad_type=pad_type,
  782. **kwargs,
  783. )
  784. return model
  785. @register_model
  786. def mobilenetv5_300m(pretrained: bool = False, **kwargs) -> MobileNetV5:
  787. model = _gen_mobilenet_v5('mobilenetv5_300m', pretrained=pretrained, **kwargs)
  788. return model
  789. @register_model
  790. def mobilenetv5_base(pretrained: bool = False, **kwargs) -> MobileNetV5:
  791. model = _gen_mobilenet_v5('mobilenetv5_base', pretrained=pretrained, **kwargs)
  792. return model