efficientvit_msra.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826
  1. """ EfficientViT (by MSRA)
  2. Paper: `EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention`
  3. - https://arxiv.org/abs/2305.07027
  4. Adapted from official impl at https://github.com/microsoft/Cream/tree/main/EfficientViT
  5. """
  6. __all__ = ['EfficientVitMsra']
  7. import itertools
  8. from collections import OrderedDict
  9. from functools import partial
  10. from typing import Dict, List, Optional, Tuple, Type, Union
  11. import torch
  12. import torch.nn as nn
  13. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  14. from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert
  15. from ._builder import build_model_with_cfg
  16. from ._features import feature_take_indices
  17. from ._manipulate import checkpoint, checkpoint_seq
  18. from ._registry import register_model, generate_default_cfgs
  19. class ConvNorm(torch.nn.Sequential):
  20. def __init__(
  21. self,
  22. in_chs: int,
  23. out_chs: int,
  24. ks: int = 1,
  25. stride: int = 1,
  26. pad: int = 0,
  27. dilation: int = 1,
  28. groups: int = 1,
  29. bn_weight_init: float = 1,
  30. device=None,
  31. dtype=None,
  32. ):
  33. dd = {'device': device, 'dtype': dtype}
  34. super().__init__()
  35. self.conv = nn.Conv2d(in_chs, out_chs, ks, stride, pad, dilation, groups, bias=False, **dd)
  36. self.bn = nn.BatchNorm2d(out_chs, **dd)
  37. torch.nn.init.constant_(self.bn.weight, bn_weight_init)
  38. @torch.no_grad()
  39. def fuse(self):
  40. c, bn = self.conv, self.bn
  41. w = bn.weight / (bn.running_var + bn.eps)**0.5
  42. w = c.weight * w[:, None, None, None]
  43. b = bn.bias - bn.running_mean * bn.weight / \
  44. (bn.running_var + bn.eps)**0.5
  45. m = torch.nn.Conv2d(
  46. w.size(1) * self.conv.groups, w.size(0), w.shape[2:],
  47. stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups)
  48. m.weight.data.copy_(w)
  49. m.bias.data.copy_(b)
  50. return m
  51. class NormLinear(torch.nn.Sequential):
  52. def __init__(
  53. self,
  54. in_features: int,
  55. out_features: int,
  56. bias: bool = True,
  57. std: float = 0.02,
  58. drop: float = 0.,
  59. device=None,
  60. dtype=None,
  61. ):
  62. dd = {'device': device, 'dtype': dtype}
  63. super().__init__()
  64. self.bn = nn.BatchNorm1d(in_features, **dd)
  65. self.drop = nn.Dropout(drop)
  66. self.linear = nn.Linear(in_features, out_features, bias=bias, **dd)
  67. trunc_normal_(self.linear.weight, std=std)
  68. if self.linear.bias is not None:
  69. nn.init.constant_(self.linear.bias, 0)
  70. @torch.no_grad()
  71. def fuse(self):
  72. bn, linear = self.bn, self.linear
  73. w = bn.weight / (bn.running_var + bn.eps)**0.5
  74. b = bn.bias - self.bn.running_mean * \
  75. self.bn.weight / (bn.running_var + bn.eps)**0.5
  76. w = linear.weight * w[None, :]
  77. if linear.bias is None:
  78. b = b @ self.linear.weight.T
  79. else:
  80. b = (linear.weight @ b[:, None]).view(-1) + self.linear.bias
  81. m = torch.nn.Linear(w.size(1), w.size(0))
  82. m.weight.data.copy_(w)
  83. m.bias.data.copy_(b)
  84. return m
  85. class PatchMerging(torch.nn.Module):
  86. def __init__(
  87. self,
  88. dim: int,
  89. out_dim: int,
  90. device=None,
  91. dtype=None,
  92. ):
  93. dd = {'device': device, 'dtype': dtype}
  94. super().__init__()
  95. hid_dim = int(dim * 4)
  96. self.conv1 = ConvNorm(dim, hid_dim, 1, 1, 0, **dd)
  97. self.act = torch.nn.ReLU()
  98. self.conv2 = ConvNorm(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim, **dd)
  99. self.se = SqueezeExcite(hid_dim, .25, **dd)
  100. self.conv3 = ConvNorm(hid_dim, out_dim, 1, 1, 0, **dd)
  101. def forward(self, x):
  102. x = self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x))))))
  103. return x
  104. class ResidualDrop(torch.nn.Module):
  105. def __init__(self, m: nn.Module, drop: float = 0.):
  106. super().__init__()
  107. self.m = m
  108. self.drop = drop
  109. def forward(self, x):
  110. if self.training and self.drop > 0:
  111. return x + self.m(x) * torch.rand(
  112. x.size(0), 1, 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach()
  113. else:
  114. return x + self.m(x)
  115. class ConvMlp(torch.nn.Module):
  116. def __init__(
  117. self,
  118. ed: int,
  119. h: int,
  120. device=None,
  121. dtype=None,
  122. ):
  123. dd = {'device': device, 'dtype': dtype}
  124. super().__init__()
  125. self.pw1 = ConvNorm(ed, h, **dd)
  126. self.act = torch.nn.ReLU()
  127. self.pw2 = ConvNorm(h, ed, bn_weight_init=0, **dd)
  128. def forward(self, x):
  129. x = self.pw2(self.act(self.pw1(x)))
  130. return x
  131. class CascadedGroupAttention(torch.nn.Module):
  132. attention_bias_cache: Dict[str, torch.Tensor]
  133. r""" Cascaded Group Attention.
  134. Args:
  135. dim (int): Number of input channels.
  136. key_dim (int): The dimension for query and key.
  137. num_heads (int): Number of attention heads.
  138. attn_ratio (int): Multiplier for the query dim for value dimension.
  139. resolution (int): Input resolution, correspond to the window size.
  140. kernels (List[int]): The kernel size of the dw conv on query.
  141. """
  142. def __init__(
  143. self,
  144. dim: int,
  145. key_dim: int,
  146. num_heads: int = 8,
  147. attn_ratio: int = 4,
  148. resolution: int = 14,
  149. kernels: Tuple[int, ...] = (5, 5, 5, 5),
  150. device=None,
  151. dtype=None,
  152. ):
  153. dd = {'device': device, 'dtype': dtype}
  154. super().__init__()
  155. self.num_heads = num_heads
  156. self.scale = key_dim ** -0.5
  157. self.key_dim = key_dim
  158. self.val_dim = int(attn_ratio * key_dim)
  159. self.attn_ratio = attn_ratio
  160. qkvs = []
  161. dws = []
  162. for i in range(num_heads):
  163. qkvs.append(ConvNorm(dim // num_heads, self.key_dim * 2 + self.val_dim, **dd))
  164. dws.append(ConvNorm(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim, **dd))
  165. self.qkvs = torch.nn.ModuleList(qkvs)
  166. self.dws = torch.nn.ModuleList(dws)
  167. self.proj = torch.nn.Sequential(
  168. torch.nn.ReLU(),
  169. ConvNorm(self.val_dim * num_heads, dim, bn_weight_init=0, **dd)
  170. )
  171. self.resolution = resolution
  172. N = resolution * resolution
  173. # Number of unique offsets: abs differences range from 0 to resolution-1 for each dim
  174. num_offsets = resolution * resolution
  175. self.attention_biases = torch.nn.Parameter(torch.empty(num_heads, num_offsets, **dd))
  176. self.register_buffer(
  177. 'attention_bias_idxs',
  178. torch.empty((N, N), device=device, dtype=torch.long),
  179. persistent=False,
  180. )
  181. self.attention_bias_cache = {}
  182. # TODO: skip init when on meta device when safe to do so
  183. self.reset_parameters()
  184. def reset_parameters(self) -> None:
  185. """Initialize parameters and buffers."""
  186. torch.nn.init.zeros_(self.attention_biases)
  187. self._init_buffers()
  188. def _init_buffers(self) -> None:
  189. """Compute and fill non-persistent buffer values."""
  190. points = list(itertools.product(range(self.resolution), range(self.resolution)))
  191. attention_offsets = {}
  192. idxs = []
  193. for p1 in points:
  194. for p2 in points:
  195. offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
  196. if offset not in attention_offsets:
  197. attention_offsets[offset] = len(attention_offsets)
  198. idxs.append(attention_offsets[offset])
  199. self.attention_bias_idxs.copy_(torch.tensor(idxs, dtype=torch.long).view(len(points), len(points)))
  200. def init_non_persistent_buffers(self) -> None:
  201. """Initialize non-persistent buffers."""
  202. self._init_buffers()
  203. @torch.no_grad()
  204. def train(self, mode=True):
  205. super().train(mode)
  206. if mode and self.attention_bias_cache:
  207. self.attention_bias_cache = {} # clear ab cache
  208. def get_attention_biases(self, device: torch.device) -> torch.Tensor:
  209. if torch.jit.is_tracing() or self.training:
  210. return self.attention_biases[:, self.attention_bias_idxs]
  211. else:
  212. device_key = str(device)
  213. if device_key not in self.attention_bias_cache:
  214. self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
  215. return self.attention_bias_cache[device_key]
  216. def forward(self, x):
  217. B, C, H, W = x.shape
  218. feats_in = x.chunk(len(self.qkvs), dim=1)
  219. feats_out = []
  220. feat = feats_in[0]
  221. attn_bias = self.get_attention_biases(x.device)
  222. for head_idx, (qkv, dws) in enumerate(zip(self.qkvs, self.dws)):
  223. if head_idx > 0:
  224. feat = feat + feats_in[head_idx]
  225. feat = qkv(feat)
  226. q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.val_dim], dim=1)
  227. q = dws(q)
  228. q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
  229. q = q * self.scale
  230. attn = q.transpose(-2, -1) @ k
  231. attn = attn + attn_bias[head_idx]
  232. attn = attn.softmax(dim=-1)
  233. feat = v @ attn.transpose(-2, -1)
  234. feat = feat.view(B, self.val_dim, H, W)
  235. feats_out.append(feat)
  236. x = self.proj(torch.cat(feats_out, 1))
  237. return x
  238. class LocalWindowAttention(torch.nn.Module):
  239. r""" Local Window Attention.
  240. Args:
  241. dim (int): Number of input channels.
  242. key_dim (int): The dimension for query and key.
  243. num_heads (int): Number of attention heads.
  244. attn_ratio (int): Multiplier for the query dim for value dimension.
  245. resolution (int): Input resolution.
  246. window_resolution (int): Local window resolution.
  247. kernels (List[int]): The kernel size of the dw conv on query.
  248. """
  249. def __init__(
  250. self,
  251. dim: int,
  252. key_dim: int,
  253. num_heads: int = 8,
  254. attn_ratio: int = 4,
  255. resolution: int = 14,
  256. window_resolution: int = 7,
  257. kernels: Tuple[int, ...] = (5, 5, 5, 5),
  258. device=None,
  259. dtype=None,
  260. ):
  261. dd = {'device': device, 'dtype': dtype}
  262. super().__init__()
  263. self.dim = dim
  264. self.num_heads = num_heads
  265. self.resolution = resolution
  266. assert window_resolution > 0, 'window_size must be greater than 0'
  267. self.window_resolution = window_resolution
  268. window_resolution = min(window_resolution, resolution)
  269. self.attn = CascadedGroupAttention(
  270. dim, key_dim, num_heads,
  271. attn_ratio=attn_ratio,
  272. resolution=window_resolution,
  273. kernels=kernels,
  274. **dd,
  275. )
  276. def forward(self, x):
  277. H = W = self.resolution
  278. B, C, H_, W_ = x.shape
  279. # Only check this for classification models
  280. _assert(H == H_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}')
  281. _assert(W == W_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}')
  282. if H <= self.window_resolution and W <= self.window_resolution:
  283. x = self.attn(x)
  284. else:
  285. x = x.permute(0, 2, 3, 1)
  286. pad_b = (self.window_resolution - H % self.window_resolution) % self.window_resolution
  287. pad_r = (self.window_resolution - W % self.window_resolution) % self.window_resolution
  288. x = torch.nn.functional.pad(x, (0, 0, 0, pad_r, 0, pad_b))
  289. pH, pW = H + pad_b, W + pad_r
  290. nH = pH // self.window_resolution
  291. nW = pW // self.window_resolution
  292. # window partition, BHWC -> B(nHh)(nWw)C -> BnHnWhwC -> (BnHnW)hwC -> (BnHnW)Chw
  293. x = x.view(B, nH, self.window_resolution, nW, self.window_resolution, C).transpose(2, 3)
  294. x = x.reshape(B * nH * nW, self.window_resolution, self.window_resolution, C).permute(0, 3, 1, 2)
  295. x = self.attn(x)
  296. # window reverse, (BnHnW)Chw -> (BnHnW)hwC -> BnHnWhwC -> B(nHh)(nWw)C -> BHWC
  297. x = x.permute(0, 2, 3, 1).view(B, nH, nW, self.window_resolution, self.window_resolution, C)
  298. x = x.transpose(2, 3).reshape(B, pH, pW, C)
  299. x = x[:, :H, :W].contiguous()
  300. x = x.permute(0, 3, 1, 2)
  301. return x
  302. class EfficientVitBlock(torch.nn.Module):
  303. """ A basic EfficientVit building block.
  304. Args:
  305. dim (int): Number of input channels.
  306. key_dim (int): Dimension for query and key in the token mixer.
  307. num_heads (int): Number of attention heads.
  308. attn_ratio (int): Multiplier for the query dim for value dimension.
  309. resolution (int): Input resolution.
  310. window_resolution (int): Local window resolution.
  311. kernels (List[int]): The kernel size of the dw conv on query.
  312. """
  313. def __init__(
  314. self,
  315. dim: int,
  316. key_dim: int,
  317. num_heads: int = 8,
  318. attn_ratio: int = 4,
  319. resolution: int = 14,
  320. window_resolution: int = 7,
  321. kernels: List[int] = [5, 5, 5, 5],
  322. device=None,
  323. dtype=None,
  324. ):
  325. dd = {'device': device, 'dtype': dtype}
  326. super().__init__()
  327. self.dw0 = ResidualDrop(ConvNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0., **dd))
  328. self.ffn0 = ResidualDrop(ConvMlp(dim, int(dim * 2), **dd))
  329. self.mixer = ResidualDrop(
  330. LocalWindowAttention(
  331. dim, key_dim, num_heads,
  332. attn_ratio=attn_ratio,
  333. resolution=resolution,
  334. window_resolution=window_resolution,
  335. kernels=kernels,
  336. **dd,
  337. ),
  338. )
  339. self.dw1 = ResidualDrop(ConvNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0., **dd))
  340. self.ffn1 = ResidualDrop(ConvMlp(dim, int(dim * 2), **dd))
  341. def forward(self, x):
  342. return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x)))))
  343. class EfficientVitStage(torch.nn.Module):
  344. def __init__(
  345. self,
  346. in_dim: int,
  347. out_dim: int,
  348. key_dim: int,
  349. downsample: Tuple[str, int] = ('', 1),
  350. num_heads: int = 8,
  351. attn_ratio: int = 4,
  352. resolution: int = 14,
  353. window_resolution: int = 7,
  354. kernels: List[int] = [5, 5, 5, 5],
  355. depth: int = 1,
  356. device=None,
  357. dtype=None,
  358. ):
  359. dd = {'device': device, 'dtype': dtype}
  360. super().__init__()
  361. if downsample[0] == 'subsample':
  362. self.resolution = (resolution - 1) // downsample[1] + 1
  363. down_blocks = []
  364. down_blocks.append((
  365. 'res1',
  366. torch.nn.Sequential(
  367. ResidualDrop(ConvNorm(in_dim, in_dim, 3, 1, 1, groups=in_dim, **dd)),
  368. ResidualDrop(ConvMlp(in_dim, int(in_dim * 2), **dd)),
  369. )
  370. ))
  371. down_blocks.append(('patchmerge', PatchMerging(in_dim, out_dim, **dd)))
  372. down_blocks.append((
  373. 'res2',
  374. torch.nn.Sequential(
  375. ResidualDrop(ConvNorm(out_dim, out_dim, 3, 1, 1, groups=out_dim, **dd)),
  376. ResidualDrop(ConvMlp(out_dim, int(out_dim * 2), **dd)),
  377. )
  378. ))
  379. self.downsample = nn.Sequential(OrderedDict(down_blocks))
  380. else:
  381. assert in_dim == out_dim
  382. self.downsample = nn.Identity()
  383. self.resolution = resolution
  384. blocks = []
  385. for d in range(depth):
  386. blocks.append(EfficientVitBlock(
  387. out_dim,
  388. key_dim,
  389. num_heads,
  390. attn_ratio,
  391. self.resolution,
  392. window_resolution,
  393. kernels,
  394. **dd,
  395. ))
  396. self.blocks = nn.Sequential(*blocks)
  397. def forward(self, x):
  398. x = self.downsample(x)
  399. x = self.blocks(x)
  400. return x
  401. class PatchEmbedding(torch.nn.Sequential):
  402. def __init__(
  403. self,
  404. in_chans: int,
  405. dim: int,
  406. device=None,
  407. dtype=None,
  408. ):
  409. super().__init__()
  410. dd = {'device': device, 'dtype': dtype}
  411. self.add_module('conv1', ConvNorm(in_chans, dim // 8, 3, 2, 1, **dd))
  412. self.add_module('relu1', torch.nn.ReLU())
  413. self.add_module('conv2', ConvNorm(dim // 8, dim // 4, 3, 2, 1, **dd))
  414. self.add_module('relu2', torch.nn.ReLU())
  415. self.add_module('conv3', ConvNorm(dim // 4, dim // 2, 3, 2, 1, **dd))
  416. self.add_module('relu3', torch.nn.ReLU())
  417. self.add_module('conv4', ConvNorm(dim // 2, dim, 3, 2, 1, **dd))
  418. self.patch_size = 16
  419. class EfficientVitMsra(nn.Module):
  420. def __init__(
  421. self,
  422. img_size: int = 224,
  423. in_chans: int = 3,
  424. num_classes: int = 1000,
  425. embed_dim: Tuple[int, ...] = (64, 128, 192),
  426. key_dim: Tuple[int, ...] = (16, 16, 16),
  427. depth: Tuple[int, ...] = (1, 2, 3),
  428. num_heads: Tuple[int, ...] = (4, 4, 4),
  429. window_size: Tuple[int, ...] = (7, 7, 7),
  430. kernels: Tuple[int, ...] = (5, 5, 5, 5),
  431. down_ops: Tuple[Tuple[str, int], ...] = (('', 1), ('subsample', 2), ('subsample', 2)),
  432. global_pool: str = 'avg',
  433. drop_rate: float = 0.,
  434. device=None,
  435. dtype=None,
  436. ):
  437. super().__init__()
  438. dd = {'device': device, 'dtype': dtype}
  439. self.grad_checkpointing = False
  440. self.num_classes = num_classes
  441. self.in_chans = in_chans
  442. self.drop_rate = drop_rate
  443. # Patch embedding
  444. self.patch_embed = PatchEmbedding(in_chans, embed_dim[0], **dd)
  445. stride = self.patch_embed.patch_size
  446. resolution = img_size // self.patch_embed.patch_size
  447. attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))]
  448. # Build EfficientVit blocks
  449. self.feature_info = []
  450. stages = []
  451. pre_ed = embed_dim[0]
  452. for i, (ed, kd, dpth, nh, ar, wd, do) in enumerate(
  453. zip(embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)):
  454. stage = EfficientVitStage(
  455. in_dim=pre_ed,
  456. out_dim=ed,
  457. key_dim=kd,
  458. downsample=do,
  459. num_heads=nh,
  460. attn_ratio=ar,
  461. resolution=resolution,
  462. window_resolution=wd,
  463. kernels=kernels,
  464. depth=dpth,
  465. **dd,
  466. )
  467. pre_ed = ed
  468. if do[0] == 'subsample' and i != 0:
  469. stride *= do[1]
  470. resolution = stage.resolution
  471. stages.append(stage)
  472. self.feature_info += [dict(num_chs=ed, reduction=stride, module=f'stages.{i}')]
  473. self.stages = nn.Sequential(*stages)
  474. if global_pool == 'avg':
  475. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
  476. else:
  477. assert num_classes == 0
  478. self.global_pool = nn.Identity()
  479. self.num_features = self.head_hidden_size = embed_dim[-1]
  480. self.head = NormLinear(
  481. self.num_features, num_classes, drop=self.drop_rate, **dd) if num_classes > 0 else torch.nn.Identity()
  482. # TODO: skip init when on meta device when safe to do so
  483. self.init_weights(needs_reset=False)
  484. def init_weights(self, needs_reset: bool = True):
  485. self.apply(partial(self._init_weights, needs_reset=needs_reset))
  486. def _init_weights(self, m: nn.Module, needs_reset: bool = True) -> None:
  487. if needs_reset and hasattr(m, 'reset_parameters'):
  488. m.reset_parameters()
  489. @torch.jit.ignore
  490. def no_weight_decay(self):
  491. return {x for x in self.state_dict().keys() if 'attention_biases' in x}
  492. @torch.jit.ignore
  493. def group_matcher(self, coarse=False):
  494. matcher = dict(
  495. stem=r'^patch_embed',
  496. blocks=r'^stages\.(\d+)' if coarse else [
  497. (r'^stages\.(\d+).downsample', (0,)),
  498. (r'^stages\.(\d+)\.\w+\.(\d+)', None),
  499. ]
  500. )
  501. return matcher
  502. @torch.jit.ignore
  503. def set_grad_checkpointing(self, enable=True):
  504. self.grad_checkpointing = enable
  505. @torch.jit.ignore
  506. def get_classifier(self) -> nn.Module:
  507. return self.head.linear
  508. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  509. self.num_classes = num_classes
  510. if global_pool is not None:
  511. if global_pool == 'avg':
  512. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
  513. else:
  514. assert num_classes == 0
  515. self.global_pool = nn.Identity()
  516. self.head = NormLinear(
  517. self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity()
  518. def forward_intermediates(
  519. self,
  520. x: torch.Tensor,
  521. indices: Optional[Union[int, List[int]]] = None,
  522. norm: bool = False,
  523. stop_early: bool = False,
  524. output_fmt: str = 'NCHW',
  525. intermediates_only: bool = False,
  526. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  527. """ Forward features that returns intermediates.
  528. Args:
  529. x: Input image tensor
  530. indices: Take last n blocks if int, all if None, select matching indices if sequence
  531. norm: Apply norm layer to compatible intermediates
  532. stop_early: Stop iterating over blocks when last desired intermediate hit
  533. output_fmt: Shape of intermediate feature outputs
  534. intermediates_only: Only return intermediate features
  535. Returns:
  536. """
  537. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  538. intermediates = []
  539. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  540. # forward pass
  541. x = self.patch_embed(x)
  542. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  543. stages = self.stages
  544. else:
  545. stages = self.stages[:max_index + 1]
  546. for feat_idx, stage in enumerate(stages):
  547. if self.grad_checkpointing and not torch.jit.is_scripting():
  548. x = checkpoint(stage, x)
  549. else:
  550. x = stage(x)
  551. if feat_idx in take_indices:
  552. intermediates.append(x)
  553. if intermediates_only:
  554. return intermediates
  555. return x, intermediates
  556. def prune_intermediate_layers(
  557. self,
  558. indices: Union[int, List[int]] = 1,
  559. prune_norm: bool = False,
  560. prune_head: bool = True,
  561. ):
  562. """ Prune layers not required for specified intermediates.
  563. """
  564. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  565. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  566. if prune_head:
  567. self.reset_classifier(0, '')
  568. return take_indices
  569. def forward_features(self, x):
  570. x = self.patch_embed(x)
  571. if self.grad_checkpointing and not torch.jit.is_scripting():
  572. x = checkpoint_seq(self.stages, x)
  573. else:
  574. x = self.stages(x)
  575. return x
  576. def forward_head(self, x, pre_logits: bool = False):
  577. x = self.global_pool(x)
  578. return x if pre_logits else self.head(x)
  579. def forward(self, x):
  580. x = self.forward_features(x)
  581. x = self.forward_head(x)
  582. return x
  583. # def checkpoint_filter_fn(state_dict, model):
  584. # if 'model' in state_dict.keys():
  585. # state_dict = state_dict['model']
  586. # tmp_dict = {}
  587. # out_dict = {}
  588. # target_keys = model.state_dict().keys()
  589. # target_keys = [k for k in target_keys if k.startswith('stages.')]
  590. #
  591. # for k, v in state_dict.items():
  592. # if 'attention_bias_idxs' in k:
  593. # continue
  594. # k = k.split('.')
  595. # if k[-2] == 'c':
  596. # k[-2] = 'conv'
  597. # if k[-2] == 'l':
  598. # k[-2] = 'linear'
  599. # k = '.'.join(k)
  600. # tmp_dict[k] = v
  601. #
  602. # for k, v in tmp_dict.items():
  603. # if k.startswith('patch_embed'):
  604. # k = k.split('.')
  605. # k[1] = 'conv' + str(int(k[1]) // 2 + 1)
  606. # k = '.'.join(k)
  607. # elif k.startswith('blocks'):
  608. # kw = '.'.join(k.split('.')[2:])
  609. # find_kw = [a for a in list(sorted(tmp_dict.keys())) if kw in a]
  610. # idx = find_kw.index(k)
  611. # k = [a for a in target_keys if kw in a][idx]
  612. # out_dict[k] = v
  613. #
  614. # return out_dict
  615. def _cfg(url='', **kwargs):
  616. return {
  617. 'url': url,
  618. 'num_classes': 1000,
  619. 'mean': IMAGENET_DEFAULT_MEAN,
  620. 'std': IMAGENET_DEFAULT_STD,
  621. 'first_conv': 'patch_embed.conv1.conv',
  622. 'classifier': 'head.linear',
  623. 'fixed_input_size': True,
  624. 'pool_size': (4, 4),
  625. 'license': 'mit',
  626. **kwargs,
  627. }
  628. default_cfgs = generate_default_cfgs({
  629. 'efficientvit_m0.r224_in1k': _cfg(
  630. hf_hub_id='timm/',
  631. #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m0.pth'
  632. ),
  633. 'efficientvit_m1.r224_in1k': _cfg(
  634. hf_hub_id='timm/',
  635. #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m1.pth'
  636. ),
  637. 'efficientvit_m2.r224_in1k': _cfg(
  638. hf_hub_id='timm/',
  639. #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m2.pth'
  640. ),
  641. 'efficientvit_m3.r224_in1k': _cfg(
  642. hf_hub_id='timm/',
  643. #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m3.pth'
  644. ),
  645. 'efficientvit_m4.r224_in1k': _cfg(
  646. hf_hub_id='timm/',
  647. #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m4.pth'
  648. ),
  649. 'efficientvit_m5.r224_in1k': _cfg(
  650. hf_hub_id='timm/',
  651. #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m5.pth'
  652. ),
  653. })
  654. def _create_efficientvit_msra(variant, pretrained=False, **kwargs):
  655. out_indices = kwargs.pop('out_indices', (0, 1, 2))
  656. model = build_model_with_cfg(
  657. EfficientVitMsra,
  658. variant,
  659. pretrained,
  660. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  661. **kwargs
  662. )
  663. return model
  664. @register_model
  665. def efficientvit_m0(pretrained=False, **kwargs):
  666. model_args = dict(
  667. img_size=224,
  668. embed_dim=[64, 128, 192],
  669. depth=[1, 2, 3],
  670. num_heads=[4, 4, 4],
  671. window_size=[7, 7, 7],
  672. kernels=[5, 5, 5, 5]
  673. )
  674. return _create_efficientvit_msra('efficientvit_m0', pretrained=pretrained, **dict(model_args, **kwargs))
  675. @register_model
  676. def efficientvit_m1(pretrained=False, **kwargs):
  677. model_args = dict(
  678. img_size=224,
  679. embed_dim=[128, 144, 192],
  680. depth=[1, 2, 3],
  681. num_heads=[2, 3, 3],
  682. window_size=[7, 7, 7],
  683. kernels=[7, 5, 3, 3]
  684. )
  685. return _create_efficientvit_msra('efficientvit_m1', pretrained=pretrained, **dict(model_args, **kwargs))
  686. @register_model
  687. def efficientvit_m2(pretrained=False, **kwargs):
  688. model_args = dict(
  689. img_size=224,
  690. embed_dim=[128, 192, 224],
  691. depth=[1, 2, 3],
  692. num_heads=[4, 3, 2],
  693. window_size=[7, 7, 7],
  694. kernels=[7, 5, 3, 3]
  695. )
  696. return _create_efficientvit_msra('efficientvit_m2', pretrained=pretrained, **dict(model_args, **kwargs))
  697. @register_model
  698. def efficientvit_m3(pretrained=False, **kwargs):
  699. model_args = dict(
  700. img_size=224,
  701. embed_dim=[128, 240, 320],
  702. depth=[1, 2, 3],
  703. num_heads=[4, 3, 4],
  704. window_size=[7, 7, 7],
  705. kernels=[5, 5, 5, 5]
  706. )
  707. return _create_efficientvit_msra('efficientvit_m3', pretrained=pretrained, **dict(model_args, **kwargs))
  708. @register_model
  709. def efficientvit_m4(pretrained=False, **kwargs):
  710. model_args = dict(
  711. img_size=224,
  712. embed_dim=[128, 256, 384],
  713. depth=[1, 2, 3],
  714. num_heads=[4, 4, 4],
  715. window_size=[7, 7, 7],
  716. kernels=[7, 5, 3, 3]
  717. )
  718. return _create_efficientvit_msra('efficientvit_m4', pretrained=pretrained, **dict(model_args, **kwargs))
  719. @register_model
  720. def efficientvit_m5(pretrained=False, **kwargs):
  721. model_args = dict(
  722. img_size=224,
  723. embed_dim=[192, 288, 384],
  724. depth=[1, 3, 4],
  725. num_heads=[3, 3, 4],
  726. window_size=[7, 7, 7],
  727. kernels=[7, 5, 3, 3]
  728. )
  729. return _create_efficientvit_msra('efficientvit_m5', pretrained=pretrained, **dict(model_args, **kwargs))