mobilevit.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710
  1. """ MobileViT
  2. Paper:
  3. V1: `MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178
  4. V2: `Separable Self-attention for Mobile Vision Transformers` - https://arxiv.org/abs/2206.02680
  5. MobileVitBlock and checkpoints adapted from https://github.com/apple/ml-cvnets (original copyright below)
  6. License: https://github.com/apple/ml-cvnets/blob/main/LICENSE (Apple open source)
  7. Rest of code, ByobNet, and Transformer block hacked together by / Copyright 2022, Ross Wightman
  8. """
  9. #
  10. # For licensing see accompanying LICENSE file.
  11. # Copyright (C) 2020 Apple Inc. All Rights Reserved.
  12. #
  13. import math
  14. from typing import Callable, Tuple, Optional, Type
  15. import torch
  16. import torch.nn.functional as F
  17. from torch import nn
  18. from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath, is_exportable
  19. from ._builder import build_model_with_cfg
  20. from ._features_fx import register_notrace_module
  21. from ._registry import register_model, generate_default_cfgs, register_model_deprecations
  22. from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups
  23. from .vision_transformer import Block as TransformerBlock
# Left empty, so `from ... import *` exports nothing from this module; the model
# constructors below are exposed via the @register_model decorator instead.
__all__ = []
  25. def _inverted_residual_block(d, c, s, br=4.0):
  26. # inverted residual is a bottleneck block with bottle_ratio > 1 applied to in_chs, linear output, gs=1 (depthwise)
  27. return ByoBlockCfg(
  28. type='bottle', d=d, c=c, s=s, gs=1, br=br,
  29. block_kwargs=dict(bottle_in=True, linear_out=True))
  30. def _mobilevit_block(d, c, s, transformer_dim, transformer_depth, patch_size=4, br=4.0):
  31. # inverted residual + mobilevit blocks as per MobileViT network
  32. return (
  33. _inverted_residual_block(d=d, c=c, s=s, br=br),
  34. ByoBlockCfg(
  35. type='mobilevit', d=1, c=c, s=1,
  36. block_kwargs=dict(
  37. transformer_dim=transformer_dim,
  38. transformer_depth=transformer_depth,
  39. patch_size=patch_size)
  40. )
  41. )
  42. def _mobilevitv2_block(d, c, s, transformer_depth, patch_size=2, br=2.0, transformer_br=0.5):
  43. # inverted residual + mobilevit blocks as per MobileViT network
  44. return (
  45. _inverted_residual_block(d=d, c=c, s=s, br=br),
  46. ByoBlockCfg(
  47. type='mobilevit2', d=1, c=c, s=1, br=transformer_br, gs=1,
  48. block_kwargs=dict(
  49. transformer_depth=transformer_depth,
  50. patch_size=patch_size)
  51. )
  52. )
  53. def _mobilevitv2_cfg(multiplier=1.0):
  54. chs = (64, 128, 256, 384, 512)
  55. if multiplier != 1.0:
  56. chs = tuple([int(c * multiplier) for c in chs])
  57. cfg = ByoModelCfg(
  58. blocks=(
  59. _inverted_residual_block(d=1, c=chs[0], s=1, br=2.0),
  60. _inverted_residual_block(d=2, c=chs[1], s=2, br=2.0),
  61. _mobilevitv2_block(d=1, c=chs[2], s=2, transformer_depth=2),
  62. _mobilevitv2_block(d=1, c=chs[3], s=2, transformer_depth=4),
  63. _mobilevitv2_block(d=1, c=chs[4], s=2, transformer_depth=3),
  64. ),
  65. stem_chs=int(32 * multiplier),
  66. stem_type='3x3',
  67. stem_pool='',
  68. downsample='',
  69. act_layer='silu',
  70. )
  71. return cfg
# Architecture configs keyed by model name; looked up by _create_mobilevit /
# _create_mobilevit2 (via `variant` or `cfg_variant`).
model_cfgs = dict(
    # MobileViT V1 variants: two inverted residual stages followed by three
    # inverted residual + MobileViT stages, all with a 3x3 stem and no stem pooling.
    mobilevit_xxs=ByoModelCfg(
        blocks=(
            _inverted_residual_block(d=1, c=16, s=1, br=2.0),
            _inverted_residual_block(d=3, c=24, s=2, br=2.0),
            _mobilevit_block(d=1, c=48, s=2, transformer_dim=64, transformer_depth=2, patch_size=2, br=2.0),
            _mobilevit_block(d=1, c=64, s=2, transformer_dim=80, transformer_depth=4, patch_size=2, br=2.0),
            _mobilevit_block(d=1, c=80, s=2, transformer_dim=96, transformer_depth=3, patch_size=2, br=2.0),
        ),
        stem_chs=16,
        stem_type='3x3',
        stem_pool='',
        downsample='',
        act_layer='silu',
        num_features=320,
    ),
    mobilevit_xs=ByoModelCfg(
        blocks=(
            _inverted_residual_block(d=1, c=32, s=1),
            _inverted_residual_block(d=3, c=48, s=2),
            _mobilevit_block(d=1, c=64, s=2, transformer_dim=96, transformer_depth=2, patch_size=2),
            _mobilevit_block(d=1, c=80, s=2, transformer_dim=120, transformer_depth=4, patch_size=2),
            _mobilevit_block(d=1, c=96, s=2, transformer_dim=144, transformer_depth=3, patch_size=2),
        ),
        stem_chs=16,
        stem_type='3x3',
        stem_pool='',
        downsample='',
        act_layer='silu',
        num_features=384,
    ),
    mobilevit_s=ByoModelCfg(
        blocks=(
            _inverted_residual_block(d=1, c=32, s=1),
            _inverted_residual_block(d=3, c=64, s=2),
            _mobilevit_block(d=1, c=96, s=2, transformer_dim=144, transformer_depth=2, patch_size=2),
            _mobilevit_block(d=1, c=128, s=2, transformer_dim=192, transformer_depth=4, patch_size=2),
            _mobilevit_block(d=1, c=160, s=2, transformer_dim=240, transformer_depth=3, patch_size=2),
        ),
        stem_chs=16,
        stem_type='3x3',
        stem_pool='',
        downsample='',
        act_layer='silu',
        num_features=640,
    ),
    # Same blocks as mobilevit_s plus an 'se' attention layer (rd_ratio=1/8).
    # NOTE(review): unlike the other V1 configs this sets no act_layer, so ByoModelCfg's
    # default activation applies; no @register_model constructor in this file uses it.
    semobilevit_s=ByoModelCfg(
        blocks=(
            _inverted_residual_block(d=1, c=32, s=1),
            _inverted_residual_block(d=3, c=64, s=2),
            _mobilevit_block(d=1, c=96, s=2, transformer_dim=144, transformer_depth=2, patch_size=2),
            _mobilevit_block(d=1, c=128, s=2, transformer_dim=192, transformer_depth=4, patch_size=2),
            _mobilevit_block(d=1, c=160, s=2, transformer_dim=240, transformer_depth=3, patch_size=2),
        ),
        stem_chs=16,
        stem_type='3x3',
        stem_pool='',
        downsample='',
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=1/8),
        num_features=640,
    ),
    # MobileViT V2 variants: one shared layout scaled by a width multiplier (name suffix / 100).
    mobilevitv2_050=_mobilevitv2_cfg(.50),
    mobilevitv2_075=_mobilevitv2_cfg(.75),
    mobilevitv2_125=_mobilevitv2_cfg(1.25),
    mobilevitv2_100=_mobilevitv2_cfg(1.0),
    mobilevitv2_150=_mobilevitv2_cfg(1.5),
    mobilevitv2_175=_mobilevitv2_cfg(1.75),
    mobilevitv2_200=_mobilevitv2_cfg(2.0),
)
@register_notrace_module
class MobileVitBlock(nn.Module):
    """ MobileViT block
    Paper: https://arxiv.org/abs/2110.02178?context=cs.LG

    Pipeline:
      1. Local kxk conv + 1x1 conv into `transformer_dim`.
      2. Unfold into patch tokens.
      3. Stack of ViT TransformerBlocks + norm.
      4. Fold back.
      5. 1x1 conv_norm_act projection to `out_chs`.
      6. Optional kxk fusion conv over concat(block input, projection).
    """
    def __init__(
            self,
            in_chs: int,
            out_chs: Optional[int] = None,
            kernel_size: int = 3,
            stride: int = 1,
            bottle_ratio: float = 1.0,
            group_size: Optional[int] = None,
            dilation: Tuple[int, int] = (1, 1),
            mlp_ratio: float = 2.0,
            transformer_dim: Optional[int] = None,
            transformer_depth: int = 2,
            patch_size: int = 8,
            num_heads: int = 4,
            attn_drop: float = 0.,
            drop: float = 0.,
            no_fusion: bool = False,
            drop_path_rate: float = 0.,
            layers: LayerFn = None,
            transformer_norm_layer: Type[nn.Module] = nn.LayerNorm,
            device=None,
            dtype=None,
            **kwargs,  # eat unused args
    ):
        """
        Args:
            in_chs: Input channels.
            out_chs: Output channels; defaults to `in_chs`.
            kernel_size: Kernel size of the local conv and of the fusion conv.
            stride: Stride of the local kxk conv.
            bottle_ratio: Sets `transformer_dim = make_divisible(bottle_ratio * in_chs)` when
                `transformer_dim` is not given.
            group_size: Group size of the local kxk conv (converted with `num_groups`).
            dilation: Only `dilation[0]` is used, for the local kxk conv.
            mlp_ratio, num_heads, attn_drop, drop, drop_path_rate: Forwarded to every TransformerBlock
                (`drop` as `proj_drop`, `drop_path_rate` as `drop_path`).
            transformer_dim: Token embedding dim.
            transformer_depth: Number of TransformerBlocks.
            patch_size: Patch size used for unfolding; expanded with `to_2tuple`.
            no_fusion: If True, skip the final concat + fusion conv.
            layers: LayerFn supplying `conv_norm_act` and `act`; defaults to `LayerFn()`.
            transformer_norm_layer: Norm layer used inside the TransformerBlocks and after them.
            device, dtype: Factory kwargs for parameter creation.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        layers = layers or LayerFn()
        groups = num_groups(group_size, in_chs)
        out_chs = out_chs or in_chs
        transformer_dim = transformer_dim or make_divisible(bottle_ratio * in_chs)

        # Local representation: kxk conv (carries the block stride), then a bias-free
        # 1x1 conv into the transformer dim.
        self.conv_kxk = layers.conv_norm_act(
            in_chs,
            in_chs,
            kernel_size=kernel_size,
            stride=stride,
            groups=groups,
            dilation=dilation[0],
            **dd,
        )
        self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False, **dd)

        # Global representation: standard ViT blocks over the unfolded patch tokens.
        self.transformer = nn.Sequential(*[
            TransformerBlock(
                transformer_dim,
                mlp_ratio=mlp_ratio,
                num_heads=num_heads,
                qkv_bias=True,
                attn_drop=attn_drop,
                proj_drop=drop,
                drop_path=drop_path_rate,
                act_layer=layers.act,
                norm_layer=transformer_norm_layer,
                **dd,
            )
            for _ in range(transformer_depth)
        ])
        self.norm = transformer_norm_layer(transformer_dim, **dd)

        # Project back to out_chs.
        self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1, **dd)

        # Optional fusion of the block input with the projected features (channel concat).
        # NOTE(review): forward concatenates the *un-strided* input, so fusion presumably
        # requires stride == 1 (all configs in this file use s=1 for 'mobilevit' blocks).
        if no_fusion:
            self.conv_fusion = None
        else:
            self.conv_fusion = layers.conv_norm_act(in_chs + out_chs, out_chs, kernel_size=kernel_size, stride=1, **dd)

        self.patch_size = to_2tuple(patch_size)
        self.patch_area = self.patch_size[0] * self.patch_size[1]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x

        # Local representation
        x = self.conv_kxk(x)
        x = self.conv_1x1(x)

        # Unfold (feature map -> patches)
        patch_h, patch_w = self.patch_size
        B, C, H, W = x.shape  # read after conv_1x1, so C is the transformer dim
        new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w
        num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w  # n_h, n_w
        num_patches = num_patch_h * num_patch_w  # N
        interpolate = False
        if new_h != H or new_w != W:
            # Note: Padding can be done, but then it needs to be handled in attention function.
            # Resize up to the next patch multiple; resized back to (H, W) after folding.
            x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)
            interpolate = True

        # [B, C, H, W] --> [B * C * n_h, n_w, p_h, p_w]
        x = x.reshape(B * C * num_patch_h, patch_h, num_patch_w, patch_w).transpose(1, 2)
        # [B * C * n_h, n_w, p_h, p_w] --> [BP, N, C] where P = p_h * p_w and N = n_h * n_w
        x = x.reshape(B, C, num_patches, self.patch_area).transpose(1, 3).reshape(B * self.patch_area, num_patches, -1)

        # Global representations: sequences of length N (patches), batched over B * P pixel positions
        x = self.transformer(x)
        x = self.norm(x)

        # Fold (patch -> feature map)
        # [B, P, N, C] --> [B*C*n_h, n_w, p_h, p_w]
        x = x.contiguous().view(B, self.patch_area, num_patches, -1)
        x = x.transpose(1, 3).reshape(B * C * num_patch_h, num_patch_w, patch_h, patch_w)
        # [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W]
        x = x.transpose(1, 2).reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
        if interpolate:
            x = F.interpolate(x, size=(H, W), mode="bilinear", align_corners=False)

        x = self.conv_proj(x)
        if self.conv_fusion is not None:
            x = self.conv_fusion(torch.cat((shortcut, x), dim=1))
        return x
  245. class LinearSelfAttention(nn.Module):
  246. """
  247. This layer applies a self-attention with linear complexity, as described in `https://arxiv.org/abs/2206.02680`
  248. This layer can be used for self- as well as cross-attention.
  249. Args:
  250. embed_dim (int): :math:`C` from an expected input of size :math:`(N, C, H, W)`
  251. attn_drop (float): Dropout value for context scores. Default: 0.0
  252. bias (bool): Use bias in learnable layers. Default: True
  253. Shape:
  254. - Input: :math:`(N, C, P, N)` where :math:`N` is the batch size, :math:`C` is the input channels,
  255. :math:`P` is the number of pixels in the patch, and :math:`N` is the number of patches
  256. - Output: same as the input
  257. .. note::
  258. For MobileViTv2, we unfold the feature map [B, C, H, W] into [B, C, P, N] where P is the number of pixels
  259. in a patch and N is the number of patches. Because channel is the first dimension in this unfolded tensor,
  260. we use point-wise convolution (instead of a linear layer). This avoids a transpose operation (which may be
  261. expensive on resource-constrained devices) that may be required to convert the unfolded tensor from
  262. channel-first to channel-last format in case of a linear layer.
  263. """
  264. def __init__(
  265. self,
  266. embed_dim: int,
  267. attn_drop: float = 0.0,
  268. proj_drop: float = 0.0,
  269. bias: bool = True,
  270. device=None,
  271. dtype=None,
  272. ) -> None:
  273. dd = {'device': device, 'dtype': dtype}
  274. super().__init__()
  275. self.embed_dim = embed_dim
  276. self.qkv_proj = nn.Conv2d(
  277. in_channels=embed_dim,
  278. out_channels=1 + (2 * embed_dim),
  279. bias=bias,
  280. kernel_size=1,
  281. **dd,
  282. )
  283. self.attn_drop = nn.Dropout(attn_drop)
  284. self.out_proj = nn.Conv2d(
  285. in_channels=embed_dim,
  286. out_channels=embed_dim,
  287. bias=bias,
  288. kernel_size=1,
  289. **dd,
  290. )
  291. self.out_drop = nn.Dropout(proj_drop)
  292. def _forward_self_attn(self, x: torch.Tensor) -> torch.Tensor:
  293. # [B, C, P, N] --> [B, h + 2d, P, N]
  294. qkv = self.qkv_proj(x)
  295. # Project x into query, key and value
  296. # Query --> [B, 1, P, N]
  297. # value, key --> [B, d, P, N]
  298. query, key, value = qkv.split([1, self.embed_dim, self.embed_dim], dim=1)
  299. # apply softmax along N dimension
  300. context_scores = F.softmax(query, dim=-1)
  301. context_scores = self.attn_drop(context_scores)
  302. # Compute context vector
  303. # [B, d, P, N] x [B, 1, P, N] -> [B, d, P, N] --> [B, d, P, 1]
  304. context_vector = (key * context_scores).sum(dim=-1, keepdim=True)
  305. # combine context vector with values
  306. # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
  307. out = F.relu(value) * context_vector.expand_as(value)
  308. out = self.out_proj(out)
  309. out = self.out_drop(out)
  310. return out
  311. @torch.jit.ignore()
  312. def _forward_cross_attn(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
  313. # x --> [B, C, P, N]
  314. # x_prev = [B, C, P, M]
  315. batch_size, in_dim, kv_patch_area, kv_num_patches = x.shape
  316. q_patch_area, q_num_patches = x.shape[-2:]
  317. assert (
  318. kv_patch_area == q_patch_area
  319. ), "The number of pixels in a patch for query and key_value should be the same"
  320. # compute query, key, and value
  321. # [B, C, P, M] --> [B, 1 + d, P, M]
  322. qk = F.conv2d(
  323. x_prev,
  324. weight=self.qkv_proj.weight[:self.embed_dim + 1],
  325. bias=self.qkv_proj.bias[:self.embed_dim + 1],
  326. )
  327. # [B, 1 + d, P, M] --> [B, 1, P, M], [B, d, P, M]
  328. query, key = qk.split([1, self.embed_dim], dim=1)
  329. # [B, C, P, N] --> [B, d, P, N]
  330. value = F.conv2d(
  331. x,
  332. weight=self.qkv_proj.weight[self.embed_dim + 1],
  333. bias=self.qkv_proj.bias[self.embed_dim + 1] if self.qkv_proj.bias is not None else None,
  334. )
  335. # apply softmax along M dimension
  336. context_scores = F.softmax(query, dim=-1)
  337. context_scores = self.attn_drop(context_scores)
  338. # compute context vector
  339. # [B, d, P, M] * [B, 1, P, M] -> [B, d, P, M] --> [B, d, P, 1]
  340. context_vector = (key * context_scores).sum(dim=-1, keepdim=True)
  341. # combine context vector with values
  342. # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
  343. out = F.relu(value) * context_vector.expand_as(value)
  344. out = self.out_proj(out)
  345. out = self.out_drop(out)
  346. return out
  347. def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
  348. if x_prev is None:
  349. return self._forward_self_attn(x)
  350. else:
  351. return self._forward_cross_attn(x, x_prev=x_prev)
  352. class LinearTransformerBlock(nn.Module):
  353. """
  354. This class defines the pre-norm transformer encoder with linear self-attention in `MobileViTv2 paper <>`_
  355. Args:
  356. embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, P, N)`
  357. mlp_ratio (float): Inner dimension ratio of the FFN relative to embed_dim
  358. drop (float): Dropout rate. Default: 0.0
  359. attn_drop (float): Dropout rate for attention in multi-head attention. Default: 0.0
  360. drop_path (float): Stochastic depth rate Default: 0.0
  361. norm_layer (Callable): Normalization layer. Default: layer_norm_2d
  362. Shape:
  363. - Input: :math:`(B, C_{in}, P, N)` where :math:`B` is batch size, :math:`C_{in}` is input embedding dim,
  364. :math:`P` is number of pixels in a patch, and :math:`N` is number of patches,
  365. - Output: same shape as the input
  366. """
  367. def __init__(
  368. self,
  369. embed_dim: int,
  370. mlp_ratio: float = 2.0,
  371. drop: float = 0.0,
  372. attn_drop: float = 0.0,
  373. drop_path: float = 0.0,
  374. act_layer: Optional[Type[nn.Module]] = None,
  375. norm_layer: Optional[Type[nn.Module]] = None,
  376. device=None,
  377. dtype=None,
  378. ) -> None:
  379. dd = {'device': device, 'dtype': dtype}
  380. super().__init__()
  381. act_layer = act_layer or nn.SiLU
  382. norm_layer = norm_layer or GroupNorm1
  383. self.norm1 = norm_layer(embed_dim, **dd)
  384. self.attn = LinearSelfAttention(embed_dim=embed_dim, attn_drop=attn_drop, proj_drop=drop, **dd)
  385. self.drop_path1 = DropPath(drop_path)
  386. self.norm2 = norm_layer(embed_dim, **dd)
  387. self.mlp = ConvMlp(
  388. in_features=embed_dim,
  389. hidden_features=int(embed_dim * mlp_ratio),
  390. act_layer=act_layer,
  391. drop=drop,
  392. **dd)
  393. self.drop_path2 = DropPath(drop_path)
  394. def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
  395. if x_prev is None:
  396. # self-attention
  397. x = x + self.drop_path1(self.attn(self.norm1(x)))
  398. else:
  399. # cross-attention
  400. res = x
  401. x = self.norm1(x) # norm
  402. x = self.attn(x, x_prev) # attn
  403. x = self.drop_path1(x) + res # residual
  404. # Feed forward network
  405. x = x + self.drop_path2(self.mlp(self.norm2(x)))
  406. return x
@register_notrace_module
class MobileVitV2Block(nn.Module):
    """ MobileViTv2 block (https://arxiv.org/abs/2206.02680)

    Pipeline:
      1. Local kxk conv + 1x1 conv into `transformer_dim`.
      2. Unfold to [B, C, P, N].
      3. Stack of LinearTransformerBlocks + norm.
      4. Fold back.
      5. 1x1 conv_norm_act projection to `out_chs`, with no activation.

    Unlike MobileVitBlock there is no stride argument (the local conv uses stride 1)
    and no fusion with the block input.
    """
    def __init__(
            self,
            in_chs: int,
            out_chs: Optional[int] = None,
            kernel_size: int = 3,
            bottle_ratio: float = 1.0,
            group_size: Optional[int] = 1,
            dilation: Tuple[int, int] = (1, 1),
            mlp_ratio: float = 2.0,
            transformer_dim: Optional[int] = None,
            transformer_depth: int = 2,
            patch_size: int = 8,
            attn_drop: float = 0.,
            drop: float = 0.,
            drop_path_rate: float = 0.,
            layers: LayerFn = None,
            transformer_norm_layer: Type[nn.Module] = GroupNorm1,
            device=None,
            dtype=None,
            **kwargs,  # eat unused args
    ):
        """
        Args:
            in_chs: Input channels.
            out_chs: Output channels; defaults to `in_chs`.
            kernel_size: Kernel size of the local conv.
            bottle_ratio: Sets `transformer_dim = make_divisible(bottle_ratio * in_chs)` when
                `transformer_dim` is not given.
            group_size: Group size of the local kxk conv (converted with `num_groups`).
            dilation: Only `dilation[0]` is used, for the local kxk conv.
            mlp_ratio, attn_drop, drop, drop_path_rate: Forwarded to every LinearTransformerBlock
                (`drop_path_rate` as `drop_path`).
            transformer_dim: Token embedding dim.
            transformer_depth: Number of LinearTransformerBlocks.
            patch_size: Patch size used for unfolding; expanded with `to_2tuple`.
            layers: LayerFn supplying `conv_norm_act` and `act`; defaults to `LayerFn()`.
            transformer_norm_layer: Norm layer inside the transformer blocks and after them.
            device, dtype: Factory kwargs for parameter creation.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        layers = layers or LayerFn()
        groups = num_groups(group_size, in_chs)
        out_chs = out_chs or in_chs
        transformer_dim = transformer_dim or make_divisible(bottle_ratio * in_chs)

        # Local representation: stride-1 kxk conv, then a bias-free 1x1 conv into the transformer dim.
        self.conv_kxk = layers.conv_norm_act(
            in_chs,
            in_chs,
            kernel_size=kernel_size,
            stride=1,
            groups=groups,
            dilation=dilation[0],
            **dd,
        )
        self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False, **dd)

        # Global representation over [B, C, P, N] tensors.
        self.transformer = nn.Sequential(*[
            LinearTransformerBlock(
                transformer_dim,
                mlp_ratio=mlp_ratio,
                attn_drop=attn_drop,
                drop=drop,
                drop_path=drop_path_rate,
                act_layer=layers.act,
                norm_layer=transformer_norm_layer,
                **dd,
            )
            for _ in range(transformer_depth)
        ])
        self.norm = transformer_norm_layer(transformer_dim, **dd)

        # Projection to out_chs with the activation disabled.
        self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1, apply_act=False, **dd)

        self.patch_size = to_2tuple(patch_size)
        self.patch_area = self.patch_size[0] * self.patch_size[1]
        # Sampled once at construction; selects the unfold / pixel_shuffle code path in forward.
        self.coreml_exportable = is_exportable()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        patch_h, patch_w = self.patch_size
        new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w
        num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w  # n_h, n_w
        num_patches = num_patch_h * num_patch_w  # N
        if new_h != H or new_w != W:
            # Resize the *input* up to a multiple of the patch size. This is not undone at the
            # end, so in this case the output spatial size is (new_h, new_w), not (H, W).
            x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=True)

        # Local representation
        x = self.conv_kxk(x)
        x = self.conv_1x1(x)

        # Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N]
        # (the patch counts above assume conv_kxk/conv_1x1 keep the spatial size — TODO confirm padding)
        C = x.shape[1]  # re-read: conv_1x1 changed channels to the transformer dim
        if self.coreml_exportable:
            # [B, C, H, W] -> [B, C * P, N]
            x = F.unfold(x, kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w))
        else:
            # [B, C, n_h, p_h, n_w, p_w] -> [B, C, p_h, p_w, n_h, n_w]
            x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
        x = x.reshape(B, C, -1, num_patches)

        # Global representations
        x = self.transformer(x)
        x = self.norm(x)

        # Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W]
        if self.coreml_exportable:
            # adopted from https://github.com/apple/ml-cvnets/blob/main/cvnets/modules/mobilevit_block.py#L609-L624
            # [B, C, P, N] -> [B, C * P, n_h, n_w] -> pixel_shuffle -> [B, C, n_h * p_h, n_w * p_w]
            # NOTE: pixel_shuffle takes a single upscale factor (patch_h), so this path assumes square patches.
            x = x.reshape(B, C * patch_h * patch_w, num_patch_h, num_patch_w)
            x = F.pixel_shuffle(x, upscale_factor=patch_h)
        else:
            # [B, C, p_h, p_w, n_h, n_w] -> [B, C, n_h, p_h, n_w, p_w]
            x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
            x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)

        x = self.conv_proj(x)
        return x
# Make the blocks available to ByobNet under the `type` names used by the
# ByoBlockCfg entries built in _mobilevit_block / _mobilevitv2_block.
register_block('mobilevit', MobileVitBlock)
register_block('mobilevit2', MobileVitV2Block)
  500. def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs):
  501. return build_model_with_cfg(
  502. ByobNet, variant, pretrained,
  503. model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
  504. feature_cfg=dict(flatten_sequential=True),
  505. **kwargs)
  506. def _create_mobilevit2(variant, cfg_variant=None, pretrained=False, **kwargs):
  507. return build_model_with_cfg(
  508. ByobNet, variant, pretrained,
  509. model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
  510. feature_cfg=dict(flatten_sequential=True),
  511. **kwargs)
  512. def _cfg(url='', **kwargs):
  513. return {
  514. 'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
  515. 'crop_pct': 0.9, 'interpolation': 'bicubic',
  516. 'mean': (0., 0., 0.), 'std': (1., 1., 1.),
  517. 'first_conv': 'stem.conv', 'classifier': 'head.fc',
  518. 'fixed_input_size': False,
  519. 'license': 'cvnets-license',
  520. **kwargs
  521. }
# Pretrained weight configs keyed by '<model name>.<weight tag>'. Values start from _cfg()
# and override per tag:
#   - V1 weights keep _cfg's crop_pct=0.9.
#   - V2 weights use crop_pct=0.888.
#   - The *_384 tags switch to 384x384 inputs (12x12 pool) with crop_pct=1.0.
default_cfgs = generate_default_cfgs({
    'mobilevit_xxs.cvnets_in1k': _cfg(hf_hub_id='timm/'),
    'mobilevit_xs.cvnets_in1k': _cfg(hf_hub_id='timm/'),
    'mobilevit_s.cvnets_in1k': _cfg(hf_hub_id='timm/'),

    'mobilevitv2_050.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_075.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_100.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_125.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_150.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_175.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_200.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),

    'mobilevitv2_150.cvnets_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_175.cvnets_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_200.cvnets_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),

    'mobilevitv2_150.cvnets_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
    'mobilevitv2_175.cvnets_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
    'mobilevitv2_200.cvnets_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
})
  566. @register_model
  567. def mobilevit_xxs(pretrained=False, **kwargs) -> ByobNet:
  568. return _create_mobilevit('mobilevit_xxs', pretrained=pretrained, **kwargs)
  569. @register_model
  570. def mobilevit_xs(pretrained=False, **kwargs) -> ByobNet:
  571. return _create_mobilevit('mobilevit_xs', pretrained=pretrained, **kwargs)
  572. @register_model
  573. def mobilevit_s(pretrained=False, **kwargs) -> ByobNet:
  574. return _create_mobilevit('mobilevit_s', pretrained=pretrained, **kwargs)
  575. @register_model
  576. def mobilevitv2_050(pretrained=False, **kwargs) -> ByobNet:
  577. return _create_mobilevit('mobilevitv2_050', pretrained=pretrained, **kwargs)
  578. @register_model
  579. def mobilevitv2_075(pretrained=False, **kwargs) -> ByobNet:
  580. return _create_mobilevit('mobilevitv2_075', pretrained=pretrained, **kwargs)
  581. @register_model
  582. def mobilevitv2_100(pretrained=False, **kwargs) -> ByobNet:
  583. return _create_mobilevit('mobilevitv2_100', pretrained=pretrained, **kwargs)
  584. @register_model
  585. def mobilevitv2_125(pretrained=False, **kwargs) -> ByobNet:
  586. return _create_mobilevit('mobilevitv2_125', pretrained=pretrained, **kwargs)
  587. @register_model
  588. def mobilevitv2_150(pretrained=False, **kwargs) -> ByobNet:
  589. return _create_mobilevit('mobilevitv2_150', pretrained=pretrained, **kwargs)
  590. @register_model
  591. def mobilevitv2_175(pretrained=False, **kwargs) -> ByobNet:
  592. return _create_mobilevit('mobilevitv2_175', pretrained=pretrained, **kwargs)
  593. @register_model
  594. def mobilevitv2_200(pretrained=False, **kwargs) -> ByobNet:
  595. return _create_mobilevit('mobilevitv2_200', pretrained=pretrained, **kwargs)
# Map old model names (from before per-weight tags existed) to their current
# '<model>.<tag>' equivalents, so the old names still resolve.
register_model_deprecations(__name__, {
    'mobilevitv2_150_in22ft1k': 'mobilevitv2_150.cvnets_in22k_ft_in1k',
    'mobilevitv2_175_in22ft1k': 'mobilevitv2_175.cvnets_in22k_ft_in1k',
    'mobilevitv2_200_in22ft1k': 'mobilevitv2_200.cvnets_in22k_ft_in1k',

    'mobilevitv2_150_384_in22ft1k': 'mobilevitv2_150.cvnets_in22k_ft_in1k_384',
    'mobilevitv2_175_384_in22ft1k': 'mobilevitv2_175.cvnets_in22k_ft_in1k_384',
    'mobilevitv2_200_384_in22ft1k': 'mobilevitv2_200.cvnets_in22k_ft_in1k_384',
})