tiny_vit.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879
  1. """ TinyViT
  2. Paper: `TinyViT: Fast Pretraining Distillation for Small Vision Transformers`
  3. - https://arxiv.org/abs/2207.10666
  4. Adapted from official impl at https://github.com/microsoft/Cream/tree/main/TinyViT
  5. """
# Public API of this module: only the TinyVit model class is exported by `from ... import *`.
__all__ = ['TinyVit']
  7. import itertools
  8. from functools import partial
  9. from typing import Dict, List, Optional, Tuple, Union, Type, Any
  10. import torch
  11. import torch.nn as nn
  12. import torch.nn.functional as F
  13. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  14. from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\
  15. trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn, calculate_drop_path_rates
  16. from ._builder import build_model_with_cfg
  17. from ._features import feature_take_indices
  18. from ._features_fx import register_notrace_module
  19. from ._manipulate import checkpoint, checkpoint_seq
  20. from ._registry import register_model, generate_default_cfgs
  21. class ConvNorm(torch.nn.Sequential):
  22. def __init__(
  23. self,
  24. in_chs: int,
  25. out_chs: int,
  26. ks: int = 1,
  27. stride: int = 1,
  28. pad: int = 0,
  29. dilation: int = 1,
  30. groups: int = 1,
  31. bn_weight_init: float = 1,
  32. device=None,
  33. dtype=None,
  34. ):
  35. dd = {'device': device, 'dtype': dtype}
  36. super().__init__()
  37. self.conv = nn.Conv2d(in_chs, out_chs, ks, stride, pad, dilation, groups, bias=False, **dd)
  38. self.bn = nn.BatchNorm2d(out_chs, **dd)
  39. torch.nn.init.constant_(self.bn.weight, bn_weight_init)
  40. torch.nn.init.constant_(self.bn.bias, 0)
  41. @torch.no_grad()
  42. def fuse(self):
  43. c, bn = self.conv, self.bn
  44. w = bn.weight / (bn.running_var + bn.eps) ** 0.5
  45. w = c.weight * w[:, None, None, None]
  46. b = bn.bias - bn.running_mean * bn.weight / \
  47. (bn.running_var + bn.eps) ** 0.5
  48. m = torch.nn.Conv2d(
  49. w.size(1) * self.conv.groups, w.size(0), w.shape[2:],
  50. stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups)
  51. m.weight.data.copy_(w)
  52. m.bias.data.copy_(b)
  53. return m
  54. class PatchEmbed(nn.Module):
  55. def __init__(
  56. self,
  57. in_chs: int,
  58. out_chs: int,
  59. act_layer: Type[nn.Module],
  60. device=None,
  61. dtype=None,
  62. ):
  63. dd = {'device': device, 'dtype': dtype}
  64. super().__init__()
  65. self.stride = 4
  66. self.conv1 = ConvNorm(in_chs, out_chs // 2, 3, 2, 1, **dd)
  67. self.act = act_layer()
  68. self.conv2 = ConvNorm(out_chs // 2, out_chs, 3, 2, 1, **dd)
  69. def forward(self, x):
  70. x = self.conv1(x)
  71. x = self.act(x)
  72. x = self.conv2(x)
  73. return x
  74. class MBConv(nn.Module):
  75. def __init__(
  76. self,
  77. in_chs: int,
  78. out_chs: int,
  79. expand_ratio: float,
  80. act_layer: Type[nn.Module],
  81. drop_path: float,
  82. device=None,
  83. dtype=None,
  84. ):
  85. dd = {'device': device, 'dtype': dtype}
  86. super().__init__()
  87. mid_chs = int(in_chs * expand_ratio)
  88. self.conv1 = ConvNorm(in_chs, mid_chs, ks=1, **dd)
  89. self.act1 = act_layer()
  90. self.conv2 = ConvNorm(mid_chs, mid_chs, ks=3, stride=1, pad=1, groups=mid_chs, **dd)
  91. self.act2 = act_layer()
  92. self.conv3 = ConvNorm(mid_chs, out_chs, ks=1, bn_weight_init=0.0, **dd)
  93. self.act3 = act_layer()
  94. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  95. def forward(self, x):
  96. shortcut = x
  97. x = self.conv1(x)
  98. x = self.act1(x)
  99. x = self.conv2(x)
  100. x = self.act2(x)
  101. x = self.conv3(x)
  102. x = self.drop_path(x)
  103. x += shortcut
  104. x = self.act3(x)
  105. return x
  106. class PatchMerging(nn.Module):
  107. def __init__(
  108. self,
  109. dim: int,
  110. out_dim: int,
  111. act_layer: Type[nn.Module],
  112. device=None,
  113. dtype=None,
  114. ):
  115. dd = {'device': device, 'dtype': dtype}
  116. super().__init__()
  117. self.conv1 = ConvNorm(dim, out_dim, 1, 1, 0, **dd)
  118. self.act1 = act_layer()
  119. self.conv2 = ConvNorm(out_dim, out_dim, 3, 2, 1, groups=out_dim, **dd)
  120. self.act2 = act_layer()
  121. self.conv3 = ConvNorm(out_dim, out_dim, 1, 1, 0, **dd)
  122. def forward(self, x):
  123. x = self.conv1(x)
  124. x = self.act1(x)
  125. x = self.conv2(x)
  126. x = self.act2(x)
  127. x = self.conv3(x)
  128. return x
  129. class ConvLayer(nn.Module):
  130. def __init__(
  131. self,
  132. dim: int,
  133. depth: int,
  134. act_layer: Type[nn.Module],
  135. drop_path: Union[float, List[float]] = 0.,
  136. conv_expand_ratio: float = 4.,
  137. device=None,
  138. dtype=None,
  139. ):
  140. dd = {'device': device, 'dtype': dtype}
  141. super().__init__()
  142. self.dim = dim
  143. self.depth = depth
  144. self.blocks = nn.Sequential(*[
  145. MBConv(
  146. dim,
  147. dim,
  148. conv_expand_ratio,
  149. act_layer,
  150. drop_path[i] if isinstance(drop_path, list) else drop_path,
  151. **dd,
  152. )
  153. for i in range(depth)
  154. ])
  155. def forward(self, x):
  156. x = self.blocks(x)
  157. return x
  158. class NormMlp(nn.Module):
  159. def __init__(
  160. self,
  161. in_features: int,
  162. hidden_features: Optional[int] = None,
  163. out_features: Optional[int] = None,
  164. norm_layer: Type[nn.Module] = nn.LayerNorm,
  165. act_layer: Type[nn.Module] = nn.GELU,
  166. drop: float = 0.,
  167. device=None,
  168. dtype=None,
  169. ):
  170. dd = {'device': device, 'dtype': dtype}
  171. super().__init__()
  172. out_features = out_features or in_features
  173. hidden_features = hidden_features or in_features
  174. self.norm = norm_layer(in_features, **dd)
  175. self.fc1 = nn.Linear(in_features, hidden_features, **dd)
  176. self.act = act_layer()
  177. self.drop1 = nn.Dropout(drop)
  178. self.fc2 = nn.Linear(hidden_features, out_features, **dd)
  179. self.drop2 = nn.Dropout(drop)
  180. def forward(self, x):
  181. x = self.norm(x)
  182. x = self.fc1(x)
  183. x = self.act(x)
  184. x = self.drop1(x)
  185. x = self.fc2(x)
  186. x = self.drop2(x)
  187. return x
  188. class Attention(torch.nn.Module):
  189. fused_attn: torch.jit.Final[bool]
  190. attention_bias_cache: Dict[str, torch.Tensor]
  191. def __init__(
  192. self,
  193. dim: int,
  194. key_dim: int,
  195. num_heads: int = 8,
  196. attn_ratio: int = 4,
  197. resolution: Tuple[int, int] = (14, 14),
  198. device=None,
  199. dtype=None,
  200. ):
  201. dd = {'device': device, 'dtype': dtype}
  202. super().__init__()
  203. assert isinstance(resolution, tuple) and len(resolution) == 2
  204. self.num_heads = num_heads
  205. self.scale = key_dim ** -0.5
  206. self.key_dim = key_dim
  207. self.val_dim = int(attn_ratio * key_dim)
  208. self.out_dim = self.val_dim * num_heads
  209. self.attn_ratio = attn_ratio
  210. self.resolution = resolution
  211. self.fused_attn = use_fused_attn()
  212. self.norm = nn.LayerNorm(dim, **dd)
  213. self.qkv = nn.Linear(dim, num_heads * (self.val_dim + 2 * key_dim), **dd)
  214. self.proj = nn.Linear(self.out_dim, dim, **dd)
  215. N = resolution[0] * resolution[1]
  216. num_offsets = resolution[0] * resolution[1] # unique offset count
  217. self.attention_biases = torch.nn.Parameter(torch.empty(num_heads, num_offsets, **dd))
  218. self.register_buffer(
  219. 'attention_bias_idxs',
  220. torch.empty((N, N), device=device, dtype=torch.long),
  221. persistent=False,
  222. )
  223. self.attention_bias_cache = {}
  224. # TODO: skip init when on meta device when safe to do so
  225. self.reset_parameters()
  226. @torch.no_grad()
  227. def train(self, mode=True):
  228. super().train(mode)
  229. if mode and self.attention_bias_cache:
  230. self.attention_bias_cache = {} # clear ab cache
  231. def get_attention_biases(self, device: torch.device) -> torch.Tensor:
  232. if torch.jit.is_tracing() or self.training:
  233. return self.attention_biases[:, self.attention_bias_idxs]
  234. else:
  235. device_key = str(device)
  236. if device_key not in self.attention_bias_cache:
  237. self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
  238. return self.attention_bias_cache[device_key]
  239. def forward(self, x):
  240. attn_bias = self.get_attention_biases(x.device)
  241. B, N, _ = x.shape
  242. # Normalization
  243. x = self.norm(x)
  244. qkv = self.qkv(x)
  245. # (B, N, num_heads, d)
  246. q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)
  247. # (B, num_heads, N, d)
  248. q = q.permute(0, 2, 1, 3)
  249. k = k.permute(0, 2, 1, 3)
  250. v = v.permute(0, 2, 1, 3)
  251. if self.fused_attn:
  252. x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
  253. else:
  254. q = q * self.scale
  255. attn = q @ k.transpose(-2, -1)
  256. attn = attn + attn_bias
  257. attn = attn.softmax(dim=-1)
  258. x = attn @ v
  259. x = x.transpose(1, 2).reshape(B, N, self.out_dim)
  260. x = self.proj(x)
  261. return x
  262. def reset_parameters(self) -> None:
  263. """Initialize parameters and buffers."""
  264. nn.init.zeros_(self.attention_biases)
  265. self._init_buffers()
  266. def _init_buffers(self) -> None:
  267. """Compute and fill non-persistent buffer values."""
  268. device = self.attention_bias_idxs.device
  269. points = list(itertools.product(range(self.resolution[0]), range(self.resolution[1])))
  270. N = len(points)
  271. attention_offsets = {}
  272. idxs = []
  273. for p1 in points:
  274. for p2 in points:
  275. offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
  276. if offset not in attention_offsets:
  277. attention_offsets[offset] = len(attention_offsets)
  278. idxs.append(attention_offsets[offset])
  279. self.attention_bias_idxs.copy_(torch.tensor(idxs, device=device, dtype=torch.long).view(N, N))
  280. self.attention_bias_cache = {}
  281. def init_non_persistent_buffers(self) -> None:
  282. """Initialize non-persistent buffers."""
  283. self._init_buffers()
class TinyVitBlock(nn.Module):
    """ TinyViT Block.

    Windowed self-attention (residual), then a depthwise ConvNorm "local conv", then a
    pre-normalized MLP (residual). Takes a channels-last ``(B, H, W, C)`` tensor and returns
    the same shape.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        local_conv_size (int): the kernel size of the convolution between
                               Attention and MLP. Default: 3
        act_layer: the activation function. Default: nn.GELU
    """

    def __init__(
            self,
            dim: int,
            num_heads: int,
            window_size: int = 7,
            mlp_ratio: float = 4.,
            drop: float = 0.,
            drop_path: float = 0.,
            local_conv_size: int = 3,
            act_layer: Type[nn.Module] = nn.GELU,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        assert window_size > 0, 'window_size must be greater than 0'
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio

        assert dim % num_heads == 0, 'dim must be divisible by num_heads'
        head_dim = dim // num_heads

        # Attention bias tables are sized for exactly one window of tokens.
        window_resolution = (window_size, window_size)
        self.attn = Attention(
            dim,
            head_dim,
            num_heads,
            attn_ratio=1,
            resolution=window_resolution,
            **dd,
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.mlp = NormMlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
            **dd,
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # NOTE(review): 'same' padding only holds for odd kernels. With an even local_conv_size,
        # ks // 2 grows H and W by one, and the reshape to (B, C, L) in forward() would fail.
        pad = local_conv_size // 2
        self.local_conv = ConvNorm(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim, **dd)

    def forward(self, x):
        B, H, W, C = x.shape
        L = H * W

        shortcut = x
        if H == self.window_size and W == self.window_size:
            # The feature map is exactly one window: attend over all H*W tokens directly.
            x = x.reshape(B, L, C)
            x = self.attn(x)
            x = x.view(B, H, W, C)
        else:
            # Bottom/right amounts needed to reach a multiple of window_size (0 if already aligned).
            pad_b = (self.window_size - H % self.window_size) % self.window_size
            pad_r = (self.window_size - W % self.window_size) % self.window_size
            padding = pad_b > 0 or pad_r > 0
            if padding:
                # F.pad pairs run from the last dim backwards: C (0, 0), W (0, pad_r), H (0, pad_b).
                # NOTE(review): the zero tokens are not masked in self.attn, so they take part in
                # attention inside edge windows.
                x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))

            # window partition
            # (B, nH, ws, nW, ws, C) -> swap ws/nW -> (B * nH * nW, ws * ws, C)
            pH, pW = H + pad_b, W + pad_r
            nH = pH // self.window_size
            nW = pW // self.window_size
            x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape(
                B * nH * nW, self.window_size * self.window_size, C
            )

            x = self.attn(x)

            # window reverse
            # (B, nH, nW, ws, ws, C) -> swap nW/ws -> (B, pH, pW, C)
            x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C)

            if padding:
                # Crop the padded rows/columns back off.
                x = x[:, :H, :W].contiguous()
        x = shortcut + self.drop_path1(x)

        # Depthwise local conv runs on BCHW; the result is flattened back to (B, L, C) tokens.
        x = x.permute(0, 3, 1, 2)
        x = self.local_conv(x)
        x = x.reshape(B, C, L).transpose(1, 2)

        x = x + self.drop_path2(self.mlp(x))
        return x.view(B, H, W, C)

    def extra_repr(self) -> str:
        return f"dim={self.dim}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"


# Registered with the FX feature-extraction helpers. This is presumably so the block is treated
# as a leaf during tracing, since its forward branches on the runtime H/W values.
register_notrace_module(TinyVitBlock)
  375. class TinyVitStage(nn.Module):
  376. """ A basic TinyViT layer for one stage.
  377. Args:
  378. dim (int): Number of input channels.
  379. out_dim: the output dimension of the layer
  380. depth (int): Number of blocks.
  381. num_heads (int): Number of attention heads.
  382. window_size (int): Local window size.
  383. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  384. drop (float, optional): Dropout rate. Default: 0.0
  385. drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
  386. downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
  387. local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3
  388. act_layer: the activation function. Default: nn.GELU
  389. """
  390. def __init__(
  391. self,
  392. dim: int,
  393. out_dim: int,
  394. depth: int,
  395. num_heads: int,
  396. window_size: int,
  397. mlp_ratio: float = 4.,
  398. drop: float = 0.,
  399. drop_path: Union[float, List[float]] = 0.,
  400. downsample: Optional[Type[nn.Module]] = None,
  401. local_conv_size: int = 3,
  402. act_layer: Type[nn.Module] = nn.GELU,
  403. device=None,
  404. dtype=None,
  405. ):
  406. dd = {'device': device, 'dtype': dtype}
  407. super().__init__()
  408. self.depth = depth
  409. self.out_dim = out_dim
  410. # patch merging layer
  411. if downsample is not None:
  412. self.downsample = downsample(
  413. dim=dim,
  414. out_dim=out_dim,
  415. act_layer=act_layer,
  416. **dd,
  417. )
  418. else:
  419. self.downsample = nn.Identity()
  420. assert dim == out_dim
  421. # build blocks
  422. self.blocks = nn.Sequential(*[
  423. TinyVitBlock(
  424. dim=out_dim,
  425. num_heads=num_heads,
  426. window_size=window_size,
  427. mlp_ratio=mlp_ratio,
  428. drop=drop,
  429. drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
  430. local_conv_size=local_conv_size,
  431. act_layer=act_layer,
  432. **dd,
  433. )
  434. for i in range(depth)])
  435. def forward(self, x):
  436. x = self.downsample(x)
  437. x = x.permute(0, 2, 3, 1) # BCHW -> BHWC
  438. x = self.blocks(x)
  439. x = x.permute(0, 3, 1, 2) # BHWC -> BCHW
  440. return x
  441. def extra_repr(self) -> str:
  442. return f"dim={self.out_dim}, depth={self.depth}"
  443. class TinyVit(nn.Module):
  444. def __init__(
  445. self,
  446. in_chans: int = 3,
  447. num_classes: int = 1000,
  448. global_pool: str = 'avg',
  449. embed_dims: Tuple[int, ...] = (96, 192, 384, 768),
  450. depths: Tuple[int, ...] = (2, 2, 6, 2),
  451. num_heads: Tuple[int, ...] = (3, 6, 12, 24),
  452. window_sizes: Tuple[int, ...] = (7, 7, 14, 7),
  453. mlp_ratio: float = 4.,
  454. drop_rate: float = 0.,
  455. drop_path_rate: float = 0.1,
  456. use_checkpoint: bool = False,
  457. mbconv_expand_ratio: float = 4.0,
  458. local_conv_size: int = 3,
  459. act_layer: Type[nn.Module] = nn.GELU,
  460. device=None,
  461. dtype=None,
  462. ):
  463. super().__init__()
  464. dd = {'device': device, 'dtype': dtype}
  465. self.num_classes = num_classes
  466. self.in_chans = in_chans
  467. self.depths = depths
  468. self.num_stages = len(depths)
  469. self.mlp_ratio = mlp_ratio
  470. self.grad_checkpointing = use_checkpoint
  471. self.patch_embed = PatchEmbed(
  472. in_chs=in_chans,
  473. out_chs=embed_dims[0],
  474. act_layer=act_layer,
  475. **dd,
  476. )
  477. # stochastic depth rate rule
  478. dpr = calculate_drop_path_rates(drop_path_rate, sum(depths))
  479. # build stages
  480. self.stages = nn.Sequential()
  481. stride = self.patch_embed.stride
  482. prev_dim = embed_dims[0]
  483. self.feature_info = []
  484. for stage_idx in range(self.num_stages):
  485. if stage_idx == 0:
  486. stage = ConvLayer(
  487. dim=prev_dim,
  488. depth=depths[stage_idx],
  489. act_layer=act_layer,
  490. drop_path=dpr[:depths[stage_idx]],
  491. conv_expand_ratio=mbconv_expand_ratio,
  492. **dd,
  493. )
  494. else:
  495. out_dim = embed_dims[stage_idx]
  496. drop_path_rate = dpr[sum(depths[:stage_idx]):sum(depths[:stage_idx + 1])]
  497. stage = TinyVitStage(
  498. dim=embed_dims[stage_idx - 1],
  499. out_dim=out_dim,
  500. depth=depths[stage_idx],
  501. num_heads=num_heads[stage_idx],
  502. window_size=window_sizes[stage_idx],
  503. mlp_ratio=self.mlp_ratio,
  504. drop=drop_rate,
  505. local_conv_size=local_conv_size,
  506. drop_path=drop_path_rate,
  507. downsample=PatchMerging,
  508. act_layer=act_layer,
  509. **dd,
  510. )
  511. prev_dim = out_dim
  512. stride *= 2
  513. self.stages.append(stage)
  514. self.feature_info += [dict(num_chs=prev_dim, reduction=stride, module=f'stages.{stage_idx}')]
  515. # Classifier head
  516. self.num_features = self.head_hidden_size = embed_dims[-1]
  517. norm_layer_cf = partial(LayerNorm2d, eps=1e-5)
  518. self.head = NormMlpClassifierHead(
  519. self.num_features,
  520. num_classes,
  521. pool_type=global_pool,
  522. norm_layer=norm_layer_cf,
  523. **dd,
  524. )
  525. # TODO: skip init when on meta device when safe to do so
  526. self.init_weights(needs_reset=False)
  527. def init_weights(self, needs_reset: bool = True):
  528. self.apply(partial(self._init_weights, needs_reset=needs_reset))
  529. def _init_weights(self, m: nn.Module, needs_reset: bool = True):
  530. if isinstance(m, nn.Linear):
  531. trunc_normal_(m.weight, std=.02)
  532. if m.bias is not None:
  533. nn.init.constant_(m.bias, 0)
  534. elif needs_reset and hasattr(m, 'reset_parameters'):
  535. m.reset_parameters()
  536. @torch.jit.ignore
  537. def no_weight_decay_keywords(self):
  538. return {'attention_biases'}
  539. @torch.jit.ignore
  540. def no_weight_decay(self):
  541. return {x for x in self.state_dict().keys() if 'attention_biases' in x}
  542. @torch.jit.ignore
  543. def group_matcher(self, coarse=False):
  544. matcher = dict(
  545. stem=r'^patch_embed',
  546. blocks=r'^stages\.(\d+)' if coarse else [
  547. (r'^stages\.(\d+).downsample', (0,)),
  548. (r'^stages\.(\d+)\.\w+\.(\d+)', None),
  549. ]
  550. )
  551. return matcher
  552. @torch.jit.ignore
  553. def set_grad_checkpointing(self, enable=True):
  554. self.grad_checkpointing = enable
  555. @torch.jit.ignore
  556. def get_classifier(self) -> nn.Module:
  557. return self.head.fc
  558. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  559. self.num_classes = num_classes
  560. self.head.reset(num_classes, pool_type=global_pool)
  561. def forward_intermediates(
  562. self,
  563. x: torch.Tensor,
  564. indices: Optional[Union[int, List[int]]] = None,
  565. norm: bool = False,
  566. stop_early: bool = False,
  567. output_fmt: str = 'NCHW',
  568. intermediates_only: bool = False,
  569. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  570. """ Forward features that returns intermediates.
  571. Args:
  572. x: Input image tensor
  573. indices: Take last n blocks if int, all if None, select matching indices if sequence
  574. norm: Apply norm layer to compatible intermediates
  575. stop_early: Stop iterating over blocks when last desired intermediate hit
  576. output_fmt: Shape of intermediate feature outputs
  577. intermediates_only: Only return intermediate features
  578. Returns:
  579. """
  580. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  581. intermediates = []
  582. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  583. # forward pass
  584. x = self.patch_embed(x)
  585. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  586. stages = self.stages
  587. else:
  588. stages = self.stages[:max_index + 1]
  589. for feat_idx, stage in enumerate(stages):
  590. if self.grad_checkpointing and not torch.jit.is_scripting():
  591. x = checkpoint(stage, x)
  592. else:
  593. x = stage(x)
  594. if feat_idx in take_indices:
  595. intermediates.append(x)
  596. if intermediates_only:
  597. return intermediates
  598. return x, intermediates
  599. def prune_intermediate_layers(
  600. self,
  601. indices: Union[int, List[int]] = 1,
  602. prune_norm: bool = False,
  603. prune_head: bool = True,
  604. ):
  605. """ Prune layers not required for specified intermediates.
  606. """
  607. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  608. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  609. if prune_head:
  610. self.reset_classifier(0, '')
  611. return take_indices
  612. def forward_features(self, x):
  613. x = self.patch_embed(x)
  614. if self.grad_checkpointing and not torch.jit.is_scripting():
  615. x = checkpoint_seq(self.stages, x)
  616. else:
  617. x = self.stages(x)
  618. return x
  619. def forward_head(self, x, pre_logits: bool = False):
  620. x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  621. return x
  622. def forward(self, x):
  623. x = self.forward_features(x)
  624. x = self.forward_head(x)
  625. return x
  626. def checkpoint_filter_fn(state_dict, model):
  627. if 'model' in state_dict.keys():
  628. state_dict = state_dict['model']
  629. target_sd = model.state_dict()
  630. out_dict = {}
  631. for k, v in state_dict.items():
  632. if k.endswith('attention_bias_idxs'):
  633. continue
  634. if 'attention_biases' in k:
  635. # TODO: whether move this func into model for dynamic input resolution? (high risk)
  636. v = resize_rel_pos_bias_table_levit(v.T, target_sd[k].shape[::-1]).T
  637. out_dict[k] = v
  638. return out_dict
  639. def _cfg(url='', **kwargs):
  640. return {
  641. 'url': url,
  642. 'num_classes': 1000,
  643. 'mean': IMAGENET_DEFAULT_MEAN,
  644. 'std': IMAGENET_DEFAULT_STD,
  645. 'first_conv': 'patch_embed.conv1.conv',
  646. 'classifier': 'head.fc',
  647. 'pool_size': (7, 7),
  648. 'input_size': (3, 224, 224),
  649. 'crop_pct': 0.95,
  650. 'license': 'apache-2.0',
  651. **kwargs,
  652. }
# Pretrained configs keyed by '<model_name>.<pretrained_tag>'. Every entry sets hf_hub_id='timm/';
# the commented URLs record the original TinyViT model-zoo release file for each weight set.
# '.dist_in22k' entries keep a 21841-class head (num_classes=21841). The 384/512 variants
# override _cfg's 224x224 input_size / pool_size / crop settings.
default_cfgs = generate_default_cfgs({
    'tiny_vit_5m_224.dist_in22k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22k_distill.pth',
        num_classes=21841
    ),
    'tiny_vit_5m_224.dist_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22kto1k_distill.pth'
    ),
    'tiny_vit_5m_224.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_1k.pth'
    ),
    'tiny_vit_11m_224.dist_in22k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22k_distill.pth',
        num_classes=21841
    ),
    'tiny_vit_11m_224.dist_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22kto1k_distill.pth'
    ),
    'tiny_vit_11m_224.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_1k.pth'
    ),
    'tiny_vit_21m_224.dist_in22k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22k_distill.pth',
        num_classes=21841
    ),
    'tiny_vit_21m_224.dist_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_distill.pth'
    ),
    'tiny_vit_21m_224.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_1k.pth'
    ),
    'tiny_vit_21m_384.dist_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_384_distill.pth',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
    ),
    'tiny_vit_21m_512.dist_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_512_distill.pth',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash',
    ),
})
  704. def _create_tiny_vit(variant, pretrained=False, **kwargs):
  705. out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
  706. model = build_model_with_cfg(
  707. TinyVit,
  708. variant,
  709. pretrained,
  710. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  711. pretrained_filter_fn=checkpoint_filter_fn,
  712. **kwargs
  713. )
  714. return model
  715. @register_model
  716. def tiny_vit_5m_224(pretrained=False, **kwargs):
  717. model_kwargs = dict(
  718. embed_dims=[64, 128, 160, 320],
  719. depths=[2, 2, 6, 2],
  720. num_heads=[2, 4, 5, 10],
  721. window_sizes=[7, 7, 14, 7],
  722. drop_path_rate=0.0,
  723. )
  724. model_kwargs.update(kwargs)
  725. return _create_tiny_vit('tiny_vit_5m_224', pretrained, **model_kwargs)
  726. @register_model
  727. def tiny_vit_11m_224(pretrained=False, **kwargs):
  728. model_kwargs = dict(
  729. embed_dims=[64, 128, 256, 448],
  730. depths=[2, 2, 6, 2],
  731. num_heads=[2, 4, 8, 14],
  732. window_sizes=[7, 7, 14, 7],
  733. drop_path_rate=0.1,
  734. )
  735. model_kwargs.update(kwargs)
  736. return _create_tiny_vit('tiny_vit_11m_224', pretrained, **model_kwargs)
  737. @register_model
  738. def tiny_vit_21m_224(pretrained=False, **kwargs):
  739. model_kwargs = dict(
  740. embed_dims=[96, 192, 384, 576],
  741. depths=[2, 2, 6, 2],
  742. num_heads=[3, 6, 12, 18],
  743. window_sizes=[7, 7, 14, 7],
  744. drop_path_rate=0.2,
  745. )
  746. model_kwargs.update(kwargs)
  747. return _create_tiny_vit('tiny_vit_21m_224', pretrained, **model_kwargs)
  748. @register_model
  749. def tiny_vit_21m_384(pretrained=False, **kwargs):
  750. model_kwargs = dict(
  751. embed_dims=[96, 192, 384, 576],
  752. depths=[2, 2, 6, 2],
  753. num_heads=[3, 6, 12, 18],
  754. window_sizes=[12, 12, 24, 12],
  755. drop_path_rate=0.1,
  756. )
  757. model_kwargs.update(kwargs)
  758. return _create_tiny_vit('tiny_vit_21m_384', pretrained, **model_kwargs)
  759. @register_model
  760. def tiny_vit_21m_512(pretrained=False, **kwargs):
  761. model_kwargs = dict(
  762. embed_dims=[96, 192, 384, 576],
  763. depths=[2, 2, 6, 2],
  764. num_heads=[3, 6, 12, 18],
  765. window_sizes=[16, 16, 32, 16],
  766. drop_path_rate=0.1,
  767. )
  768. model_kwargs.update(kwargs)
  769. return _create_tiny_vit('tiny_vit_21m_512', pretrained, **model_kwargs)