levit.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152
  1. """ LeViT
  2. Paper: `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference`
  3. - https://arxiv.org/abs/2104.01136
  4. @article{graham2021levit,
  5. title={LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
  6. author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze},
  7. journal={arXiv preprint arXiv:2104.01136},
  8. year={2021}
  9. }
  10. Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright below.
  11. This version combines both conv/linear models and fixes torchscript compatibility.
  12. Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
  13. """
  14. # Copyright (c) 2015-present, Facebook, Inc.
  15. # All rights reserved.
  16. # Modified from
  17. # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  18. # Copyright 2020 Ross Wightman, Apache-2.0 License
  19. from collections import OrderedDict
  20. from functools import partial
  21. from typing import Dict, List, Optional, Tuple, Type, Union
  22. import torch
  23. import torch.nn as nn
  24. from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
  25. from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_, ndgrid
  26. from ._builder import build_model_with_cfg
  27. from ._features import feature_take_indices
  28. from ._manipulate import checkpoint, checkpoint_seq
  29. from ._registry import generate_default_cfgs, register_model
  30. __all__ = ['Levit']
  31. class ConvNorm(nn.Module):
  32. def __init__(
  33. self,
  34. in_chs: int,
  35. out_chs: int,
  36. kernel_size: int = 1,
  37. stride: int = 1,
  38. padding: int = 0,
  39. dilation: int = 1,
  40. groups: int = 1,
  41. bn_weight_init: float = 1,
  42. device=None,
  43. dtype=None,
  44. ):
  45. dd = {'device': device, 'dtype': dtype}
  46. super().__init__()
  47. self.linear = nn.Conv2d(in_chs, out_chs, kernel_size, stride, padding, dilation, groups, bias=False, **dd)
  48. self.bn = nn.BatchNorm2d(out_chs, **dd)
  49. nn.init.constant_(self.bn.weight, bn_weight_init)
  50. @torch.no_grad()
  51. def fuse(self):
  52. c, bn = self.linear, self.bn
  53. w = bn.weight / (bn.running_var + bn.eps) ** 0.5
  54. w = c.weight * w[:, None, None, None]
  55. b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
  56. m = nn.Conv2d(
  57. w.size(1), w.size(0), w.shape[2:], stride=self.linear.stride,
  58. padding=self.linear.padding, dilation=self.linear.dilation, groups=self.linear.groups)
  59. m.weight.data.copy_(w)
  60. m.bias.data.copy_(b)
  61. return m
  62. def forward(self, x):
  63. return self.bn(self.linear(x))
  64. class LinearNorm(nn.Module):
  65. def __init__(
  66. self,
  67. in_features: int,
  68. out_features: int,
  69. bn_weight_init: float = 1,
  70. device=None,
  71. dtype=None,
  72. ):
  73. dd = {'device': device, 'dtype': dtype}
  74. super().__init__()
  75. self.linear = nn.Linear(in_features, out_features, bias=False, **dd)
  76. self.bn = nn.BatchNorm1d(out_features, **dd)
  77. nn.init.constant_(self.bn.weight, bn_weight_init)
  78. @torch.no_grad()
  79. def fuse(self):
  80. l, bn = self.linear, self.bn
  81. w = bn.weight / (bn.running_var + bn.eps) ** 0.5
  82. w = l.weight * w[:, None]
  83. b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
  84. m = nn.Linear(w.size(1), w.size(0))
  85. m.weight.data.copy_(w)
  86. m.bias.data.copy_(b)
  87. return m
  88. def forward(self, x):
  89. x = self.linear(x)
  90. return self.bn(x.flatten(0, 1)).reshape_as(x)
  91. class NormLinear(nn.Module):
  92. def __init__(
  93. self,
  94. in_features: int,
  95. out_features: int,
  96. bias: bool = True,
  97. std: float = 0.02,
  98. drop: float = 0.,
  99. device=None,
  100. dtype=None,
  101. ):
  102. dd = {'device': device, 'dtype': dtype}
  103. super().__init__()
  104. self.bn = nn.BatchNorm1d(in_features, **dd)
  105. self.drop = nn.Dropout(drop)
  106. self.linear = nn.Linear(in_features, out_features, bias=bias, **dd)
  107. trunc_normal_(self.linear.weight, std=std)
  108. if self.linear.bias is not None:
  109. nn.init.constant_(self.linear.bias, 0)
  110. @torch.no_grad()
  111. def fuse(self):
  112. bn, l = self.bn, self.linear
  113. w = bn.weight / (bn.running_var + bn.eps) ** 0.5
  114. b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5
  115. w = l.weight * w[None, :]
  116. if l.bias is None:
  117. b = b @ self.linear.weight.T
  118. else:
  119. b = (l.weight @ b[:, None]).view(-1) + self.linear.bias
  120. m = nn.Linear(w.size(1), w.size(0))
  121. m.weight.data.copy_(w)
  122. m.bias.data.copy_(b)
  123. return m
  124. def forward(self, x):
  125. return self.linear(self.drop(self.bn(x)))
  126. class Stem8(nn.Sequential):
  127. def __init__(
  128. self,
  129. in_chs: int,
  130. out_chs: int,
  131. act_layer: Type[nn.Module],
  132. device=None,
  133. dtype=None,
  134. ):
  135. dd = {'device': device, 'dtype': dtype}
  136. super().__init__()
  137. self.stride = 8
  138. self.add_module('conv1', ConvNorm(in_chs, out_chs // 4, 3, stride=2, padding=1, **dd))
  139. self.add_module('act1', act_layer())
  140. self.add_module('conv2', ConvNorm(out_chs // 4, out_chs // 2, 3, stride=2, padding=1, **dd))
  141. self.add_module('act2', act_layer())
  142. self.add_module('conv3', ConvNorm(out_chs // 2, out_chs, 3, stride=2, padding=1, **dd))
  143. class Stem16(nn.Sequential):
  144. def __init__(
  145. self,
  146. in_chs: int,
  147. out_chs: int,
  148. act_layer: Type[nn.Module],
  149. device=None,
  150. dtype=None,
  151. ):
  152. dd = {'device': device, 'dtype': dtype}
  153. super().__init__()
  154. self.stride = 16
  155. self.add_module('conv1', ConvNorm(in_chs, out_chs // 8, 3, stride=2, padding=1, **dd))
  156. self.add_module('act1', act_layer())
  157. self.add_module('conv2', ConvNorm(out_chs // 8, out_chs // 4, 3, stride=2, padding=1, **dd))
  158. self.add_module('act2', act_layer())
  159. self.add_module('conv3', ConvNorm(out_chs // 4, out_chs // 2, 3, stride=2, padding=1, **dd))
  160. self.add_module('act3', act_layer())
  161. self.add_module('conv4', ConvNorm(out_chs // 2, out_chs, 3, stride=2, padding=1, **dd))
  162. class Downsample(nn.Module):
  163. def __init__(
  164. self,
  165. stride: int,
  166. resolution: Union[int, Tuple[int, int]],
  167. use_pool: bool = False,
  168. device=None,
  169. dtype=None,
  170. ):
  171. super().__init__()
  172. self.stride = stride
  173. self.resolution = to_2tuple(resolution)
  174. self.pool = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) if use_pool else None
  175. def forward(self, x):
  176. B, N, C = x.shape
  177. x = x.view(B, self.resolution[0], self.resolution[1], C)
  178. if self.pool is not None:
  179. x = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
  180. else:
  181. x = x[:, ::self.stride, ::self.stride]
  182. return x.reshape(B, -1, C)
  183. class Attention(nn.Module):
  184. attention_bias_cache: Dict[str, torch.Tensor]
  185. def __init__(
  186. self,
  187. dim: int,
  188. key_dim: int,
  189. num_heads: int = 8,
  190. attn_ratio: float = 4.,
  191. resolution: Union[int, Tuple[int, int]] = 14,
  192. use_conv: bool = False,
  193. act_layer: Type[nn.Module] = nn.SiLU,
  194. device=None,
  195. dtype=None,
  196. ):
  197. dd = {'device': device, 'dtype': dtype}
  198. super().__init__()
  199. ln_layer = ConvNorm if use_conv else LinearNorm
  200. resolution = to_2tuple(resolution)
  201. self.use_conv = use_conv
  202. self.num_heads = num_heads
  203. self.scale = key_dim ** -0.5
  204. self.key_dim = key_dim
  205. self.key_attn_dim = key_dim * num_heads
  206. self.val_dim = int(attn_ratio * key_dim)
  207. self.val_attn_dim = int(attn_ratio * key_dim) * num_heads
  208. self.resolution = resolution
  209. self.qkv = ln_layer(dim, self.val_attn_dim + self.key_attn_dim * 2, **dd)
  210. self.proj = nn.Sequential(OrderedDict([
  211. ('act', act_layer()),
  212. ('ln', ln_layer(self.val_attn_dim, dim, bn_weight_init=0, **dd))
  213. ]))
  214. N = resolution[0] * resolution[1]
  215. self.attention_biases = nn.Parameter(torch.empty(num_heads, N, **dd))
  216. self.register_buffer(
  217. 'attention_bias_idxs', torch.empty((N, N), device=device, dtype=torch.long), persistent=False)
  218. self.attention_bias_cache = {}
  219. # TODO: skip init when on meta device when safe to do so
  220. self.reset_parameters()
  221. @torch.no_grad()
  222. def train(self, mode=True):
  223. super().train(mode)
  224. if mode and self.attention_bias_cache:
  225. self.attention_bias_cache = {} # clear ab cache
  226. def reset_parameters(self) -> None:
  227. """Initialize parameters and buffers."""
  228. nn.init.zeros_(self.attention_biases)
  229. self._init_buffers()
  230. def _compute_attention_bias_idxs(self, device=None):
  231. """Compute relative position indices for attention bias."""
  232. pos = torch.stack(ndgrid(
  233. torch.arange(self.resolution[0], device=device, dtype=torch.long),
  234. torch.arange(self.resolution[1], device=device, dtype=torch.long),
  235. )).flatten(1)
  236. rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
  237. rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1]
  238. return rel_pos
  239. def _init_buffers(self) -> None:
  240. """Compute and fill non-persistent buffer values."""
  241. self.attention_bias_idxs.copy_(
  242. self._compute_attention_bias_idxs(device=self.attention_bias_idxs.device)
  243. )
  244. self.attention_bias_cache = {}
  245. def init_non_persistent_buffers(self) -> None:
  246. """Initialize non-persistent buffers."""
  247. self._init_buffers()
  248. def get_attention_biases(self, device: torch.device) -> torch.Tensor:
  249. if torch.jit.is_tracing() or self.training:
  250. return self.attention_biases[:, self.attention_bias_idxs]
  251. else:
  252. device_key = str(device)
  253. if device_key not in self.attention_bias_cache:
  254. self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
  255. return self.attention_bias_cache[device_key]
  256. def forward(self, x): # x (B,C,H,W)
  257. if self.use_conv:
  258. B, C, H, W = x.shape
  259. q, k, v = self.qkv(x).view(
  260. B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.val_dim], dim=2)
  261. attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
  262. attn = attn.softmax(dim=-1)
  263. x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
  264. else:
  265. B, N, C = x.shape
  266. q, k, v = self.qkv(x).view(
  267. B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)
  268. q = q.permute(0, 2, 1, 3)
  269. k = k.permute(0, 2, 3, 1)
  270. v = v.permute(0, 2, 1, 3)
  271. attn = q @ k * self.scale + self.get_attention_biases(x.device)
  272. attn = attn.softmax(dim=-1)
  273. x = (attn @ v).transpose(1, 2).reshape(B, N, self.val_attn_dim)
  274. x = self.proj(x)
  275. return x
  276. class AttentionDownsample(nn.Module):
  277. attention_bias_cache: Dict[str, torch.Tensor]
  278. def __init__(
  279. self,
  280. in_dim: int,
  281. out_dim: int,
  282. key_dim: int,
  283. num_heads: int = 8,
  284. attn_ratio: float = 2.0,
  285. stride: int = 2,
  286. resolution: Union[int, Tuple[int, int]] = 14,
  287. use_conv: bool = False,
  288. use_pool: bool = False,
  289. act_layer: Type[nn.Module] = nn.SiLU,
  290. device=None,
  291. dtype=None,
  292. ):
  293. dd = {'device': device, 'dtype': dtype}
  294. super().__init__()
  295. resolution = to_2tuple(resolution)
  296. self.stride = stride
  297. self.resolution = resolution
  298. self.num_heads = num_heads
  299. self.key_dim = key_dim
  300. self.key_attn_dim = key_dim * num_heads
  301. self.val_dim = int(attn_ratio * key_dim)
  302. self.val_attn_dim = self.val_dim * self.num_heads
  303. self.scale = key_dim ** -0.5
  304. self.use_conv = use_conv
  305. if self.use_conv:
  306. ln_layer = ConvNorm
  307. sub_layer = partial(
  308. nn.AvgPool2d,
  309. kernel_size=3 if use_pool else 1, padding=1 if use_pool else 0, count_include_pad=False)
  310. else:
  311. ln_layer = LinearNorm
  312. sub_layer = partial(Downsample, resolution=resolution, use_pool=use_pool, **dd)
  313. self.kv = ln_layer(in_dim, self.val_attn_dim + self.key_attn_dim, **dd)
  314. self.q = nn.Sequential(OrderedDict([
  315. ('down', sub_layer(stride=stride)),
  316. ('ln', ln_layer(in_dim, self.key_attn_dim, **dd))
  317. ]))
  318. self.proj = nn.Sequential(OrderedDict([
  319. ('act', act_layer()),
  320. ('ln', ln_layer(self.val_attn_dim, out_dim, **dd))
  321. ]))
  322. N_k = resolution[0] * resolution[1]
  323. N_q = -(-resolution[0] // stride) * -(-resolution[1] // stride) # ceiling division
  324. self.attention_biases = nn.Parameter(torch.empty(num_heads, N_k, **dd))
  325. self.register_buffer('attention_bias_idxs', torch.empty((N_q, N_k), device=device, dtype=torch.long), persistent=False)
  326. self.attention_bias_cache = {}
  327. # TODO: skip init when on meta device when safe to do so
  328. self.reset_parameters()
  329. @torch.no_grad()
  330. def train(self, mode=True):
  331. super().train(mode)
  332. if mode and self.attention_bias_cache:
  333. self.attention_bias_cache = {} # clear ab cache
  334. def reset_parameters(self) -> None:
  335. """Initialize parameters and buffers."""
  336. nn.init.zeros_(self.attention_biases)
  337. self._init_buffers()
  338. def _compute_attention_bias_idxs(self, device=None):
  339. """Compute relative position indices for attention bias."""
  340. k_pos = torch.stack(ndgrid(
  341. torch.arange(self.resolution[0], device=device, dtype=torch.long),
  342. torch.arange(self.resolution[1], device=device, dtype=torch.long),
  343. )).flatten(1)
  344. q_pos = torch.stack(ndgrid(
  345. torch.arange(0, self.resolution[0], step=self.stride, device=device, dtype=torch.long),
  346. torch.arange(0, self.resolution[1], step=self.stride, device=device, dtype=torch.long),
  347. )).flatten(1)
  348. rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
  349. rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1]
  350. return rel_pos
  351. def _init_buffers(self) -> None:
  352. """Compute and fill non-persistent buffer values."""
  353. self.attention_bias_idxs.copy_(
  354. self._compute_attention_bias_idxs(device=self.attention_bias_idxs.device)
  355. )
  356. self.attention_bias_cache = {}
  357. def init_non_persistent_buffers(self) -> None:
  358. """Initialize non-persistent buffers."""
  359. self._init_buffers()
  360. def get_attention_biases(self, device: torch.device) -> torch.Tensor:
  361. if torch.jit.is_tracing() or self.training:
  362. return self.attention_biases[:, self.attention_bias_idxs]
  363. else:
  364. device_key = str(device)
  365. if device_key not in self.attention_bias_cache:
  366. self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
  367. return self.attention_bias_cache[device_key]
  368. def forward(self, x):
  369. if self.use_conv:
  370. B, C, H, W = x.shape
  371. HH, WW = (H - 1) // self.stride + 1, (W - 1) // self.stride + 1
  372. k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.val_dim], dim=2)
  373. q = self.q(x).view(B, self.num_heads, self.key_dim, -1)
  374. attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
  375. attn = attn.softmax(dim=-1)
  376. x = (v @ attn.transpose(-2, -1)).reshape(B, self.val_attn_dim, HH, WW)
  377. else:
  378. B, N, C = x.shape
  379. k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.val_dim], dim=3)
  380. k = k.permute(0, 2, 3, 1) # BHCN
  381. v = v.permute(0, 2, 1, 3) # BHNC
  382. q = self.q(x).view(B, -1, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
  383. attn = q @ k * self.scale + self.get_attention_biases(x.device)
  384. attn = attn.softmax(dim=-1)
  385. x = (attn @ v).transpose(1, 2).reshape(B, -1, self.val_attn_dim)
  386. x = self.proj(x)
  387. return x
  388. class LevitMlp(nn.Module):
  389. """ MLP for Levit w/ normalization + ability to switch btw conv and linear
  390. """
  391. def __init__(
  392. self,
  393. in_features: int,
  394. hidden_features: Optional[int] = None,
  395. out_features: Optional[int] = None,
  396. use_conv: bool = False,
  397. act_layer: Type[nn.Module] = nn.SiLU,
  398. drop: float = 0.,
  399. device=None,
  400. dtype=None,
  401. ):
  402. dd = {'device': device, 'dtype': dtype}
  403. super().__init__()
  404. out_features = out_features or in_features
  405. hidden_features = hidden_features or in_features
  406. ln_layer = ConvNorm if use_conv else LinearNorm
  407. self.ln1 = ln_layer(in_features, hidden_features, **dd)
  408. self.act = act_layer()
  409. self.drop = nn.Dropout(drop)
  410. self.ln2 = ln_layer(hidden_features, out_features, bn_weight_init=0, **dd)
  411. def forward(self, x):
  412. x = self.ln1(x)
  413. x = self.act(x)
  414. x = self.drop(x)
  415. x = self.ln2(x)
  416. return x
  417. class LevitDownsample(nn.Module):
  418. def __init__(
  419. self,
  420. in_dim: int,
  421. out_dim: int,
  422. key_dim: int,
  423. num_heads: int = 8,
  424. attn_ratio: float = 4.,
  425. mlp_ratio: float = 2.,
  426. act_layer: Type[nn.Module] = nn.SiLU,
  427. attn_act_layer: Optional[Type[nn.Module]] = None,
  428. resolution: Union[int, Tuple[int, int]] = 14,
  429. use_conv: bool = False,
  430. use_pool: bool = False,
  431. drop_path: float = 0.,
  432. device=None,
  433. dtype=None,
  434. ):
  435. dd = {'device': device, 'dtype': dtype}
  436. super().__init__()
  437. attn_act_layer = attn_act_layer or act_layer
  438. self.attn_downsample = AttentionDownsample(
  439. in_dim=in_dim,
  440. out_dim=out_dim,
  441. key_dim=key_dim,
  442. num_heads=num_heads,
  443. attn_ratio=attn_ratio,
  444. act_layer=attn_act_layer,
  445. resolution=resolution,
  446. use_conv=use_conv,
  447. use_pool=use_pool,
  448. **dd,
  449. )
  450. self.mlp = LevitMlp(
  451. out_dim,
  452. int(out_dim * mlp_ratio),
  453. use_conv=use_conv,
  454. act_layer=act_layer,
  455. **dd,
  456. )
  457. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  458. def forward(self, x):
  459. x = self.attn_downsample(x)
  460. x = x + self.drop_path(self.mlp(x))
  461. return x
  462. class LevitBlock(nn.Module):
  463. def __init__(
  464. self,
  465. dim: int,
  466. key_dim: int,
  467. num_heads: int = 8,
  468. attn_ratio: float = 4.,
  469. mlp_ratio: float = 2.,
  470. resolution: Union[int, Tuple[int, int]] = 14,
  471. use_conv: bool = False,
  472. act_layer: Type[nn.Module] = nn.SiLU,
  473. attn_act_layer: Optional[Type[nn.Module]] = None,
  474. drop_path: float = 0.,
  475. device=None,
  476. dtype=None,
  477. ):
  478. dd = {'device': device, 'dtype': dtype}
  479. super().__init__()
  480. attn_act_layer = attn_act_layer or act_layer
  481. self.attn = Attention(
  482. dim=dim,
  483. key_dim=key_dim,
  484. num_heads=num_heads,
  485. attn_ratio=attn_ratio,
  486. resolution=resolution,
  487. use_conv=use_conv,
  488. act_layer=attn_act_layer,
  489. **dd,
  490. )
  491. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  492. self.mlp = LevitMlp(
  493. dim,
  494. int(dim * mlp_ratio),
  495. use_conv=use_conv,
  496. act_layer=act_layer,
  497. **dd,
  498. )
  499. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  500. def forward(self, x):
  501. x = x + self.drop_path1(self.attn(x))
  502. x = x + self.drop_path2(self.mlp(x))
  503. return x
  504. class LevitStage(nn.Module):
  505. def __init__(
  506. self,
  507. in_dim: int,
  508. out_dim: int,
  509. key_dim: int,
  510. depth: int = 4,
  511. num_heads: int = 8,
  512. attn_ratio: float = 4.0,
  513. mlp_ratio: float = 4.0,
  514. act_layer: Type[nn.Module] = nn.SiLU,
  515. attn_act_layer: Optional[Type[nn.Module]] = None,
  516. resolution: Union[int, Tuple[int, int]] = 14,
  517. downsample: str = '',
  518. use_conv: bool = False,
  519. drop_path: float = 0.,
  520. device=None,
  521. dtype=None,
  522. ):
  523. dd = {'device': device, 'dtype': dtype}
  524. super().__init__()
  525. resolution = to_2tuple(resolution)
  526. if downsample:
  527. self.downsample = LevitDownsample(
  528. in_dim,
  529. out_dim,
  530. key_dim=key_dim,
  531. num_heads=in_dim // key_dim,
  532. attn_ratio=4.,
  533. mlp_ratio=2.,
  534. act_layer=act_layer,
  535. attn_act_layer=attn_act_layer,
  536. resolution=resolution,
  537. use_conv=use_conv,
  538. drop_path=drop_path,
  539. **dd,
  540. )
  541. resolution = [(r - 1) // 2 + 1 for r in resolution]
  542. else:
  543. assert in_dim == out_dim
  544. self.downsample = nn.Identity()
  545. blocks = []
  546. for _ in range(depth):
  547. blocks += [LevitBlock(
  548. out_dim,
  549. key_dim,
  550. num_heads=num_heads,
  551. attn_ratio=attn_ratio,
  552. mlp_ratio=mlp_ratio,
  553. act_layer=act_layer,
  554. attn_act_layer=attn_act_layer,
  555. resolution=resolution,
  556. use_conv=use_conv,
  557. drop_path=drop_path,
  558. **dd,
  559. )]
  560. self.blocks = nn.Sequential(*blocks)
  561. def forward(self, x):
  562. x = self.downsample(x)
  563. x = self.blocks(x)
  564. return x
  565. class Levit(nn.Module):
  566. """ Vision Transformer with support for patch or hybrid CNN input stage
  567. NOTE: distillation is defaulted to True since pretrained weights use it, will cause problems
  568. w/ train scripts that don't take tuple outputs,
  569. """
  570. def __init__(
  571. self,
  572. img_size: Union[int, Tuple[int, int]] = 224,
  573. in_chans: int = 3,
  574. num_classes: int = 1000,
  575. embed_dim: Tuple[int, ...] = (192,),
  576. key_dim: int = 64,
  577. depth: Tuple[int, ...] = (12,),
  578. num_heads: Union[int, Tuple[int, ...]] = (3,),
  579. attn_ratio: Union[float, Tuple[float, ...]] = 2.,
  580. mlp_ratio: Union[float, Tuple[float, ...]] = 2.,
  581. stem_backbone: Optional[nn.Module] = None,
  582. stem_stride: Optional[int] = None,
  583. stem_type: str = 's16',
  584. down_op: str = 'subsample',
  585. act_layer: str = 'hard_swish',
  586. attn_act_layer: Optional[str] = None,
  587. use_conv: bool = False,
  588. global_pool: str = 'avg',
  589. drop_rate: float = 0.,
  590. drop_path_rate: float = 0.,
  591. device=None,
  592. dtype=None,
  593. ):
  594. super().__init__()
  595. dd = {'device': device, 'dtype': dtype}
  596. act_layer = get_act_layer(act_layer)
  597. attn_act_layer = get_act_layer(attn_act_layer or act_layer)
  598. self.use_conv = use_conv
  599. self.num_classes = num_classes
  600. self.in_chans = in_chans
  601. self.global_pool = global_pool
  602. self.num_features = self.head_hidden_size = embed_dim[-1]
  603. self.embed_dim = embed_dim
  604. self.drop_rate = drop_rate
  605. self.grad_checkpointing = False
  606. self.feature_info = []
  607. num_stages = len(embed_dim)
  608. assert len(depth) == num_stages
  609. num_heads = to_ntuple(num_stages)(num_heads)
  610. attn_ratio = to_ntuple(num_stages)(attn_ratio)
  611. mlp_ratio = to_ntuple(num_stages)(mlp_ratio)
  612. if stem_backbone is not None:
  613. assert stem_stride >= 2
  614. self.stem = stem_backbone
  615. stride = stem_stride
  616. else:
  617. assert stem_type in ('s16', 's8')
  618. if stem_type == 's16':
  619. self.stem = Stem16(in_chans, embed_dim[0], act_layer=act_layer, **dd)
  620. else:
  621. self.stem = Stem8(in_chans, embed_dim[0], act_layer=act_layer, **dd)
  622. stride = self.stem.stride
  623. resolution = tuple([i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride))])
  624. in_dim = embed_dim[0]
  625. stages = []
  626. for i in range(num_stages):
  627. stage_stride = 2 if i > 0 else 1
  628. stages += [LevitStage(
  629. in_dim,
  630. embed_dim[i],
  631. key_dim,
  632. depth=depth[i],
  633. num_heads=num_heads[i],
  634. attn_ratio=attn_ratio[i],
  635. mlp_ratio=mlp_ratio[i],
  636. act_layer=act_layer,
  637. attn_act_layer=attn_act_layer,
  638. resolution=resolution,
  639. use_conv=use_conv,
  640. downsample=down_op if stage_stride == 2 else '',
  641. drop_path=drop_path_rate,
  642. **dd,
  643. )]
  644. stride *= stage_stride
  645. resolution = tuple([(r - 1) // stage_stride + 1 for r in resolution])
  646. self.feature_info += [dict(num_chs=embed_dim[i], reduction=stride, module=f'stages.{i}')]
  647. in_dim = embed_dim[i]
  648. self.stages = nn.Sequential(*stages)
  649. # Classifier head
  650. self.head = NormLinear(embed_dim[-1], num_classes, drop=drop_rate, **dd) if num_classes > 0 else nn.Identity()
  651. # TODO: skip init when on meta device when safe to do so
  652. self.init_weights(needs_reset=False)
  653. def init_weights(self, needs_reset: bool = True):
  654. self.apply(partial(self._init_weights, needs_reset=needs_reset))
  655. def _init_weights(self, m: nn.Module, needs_reset: bool = True) -> None:
  656. if needs_reset and hasattr(m, 'reset_parameters'):
  657. m.reset_parameters()
  658. @torch.jit.ignore
  659. def no_weight_decay(self):
  660. return {x for x in self.state_dict().keys() if 'attention_biases' in x}
  661. @torch.jit.ignore
  662. def group_matcher(self, coarse=False):
  663. matcher = dict(
  664. stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
  665. blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
  666. )
  667. return matcher
  668. @torch.jit.ignore
  669. def set_grad_checkpointing(self, enable=True):
  670. self.grad_checkpointing = enable
  671. @torch.jit.ignore
  672. def get_classifier(self) -> nn.Module:
  673. return self.head
  674. def reset_classifier(self, num_classes: int , global_pool: Optional[str] = None):
  675. self.num_classes = num_classes
  676. if global_pool is not None:
  677. self.global_pool = global_pool
  678. self.head = NormLinear(
  679. self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else nn.Identity()
  680. def forward_intermediates(
  681. self,
  682. x: torch.Tensor,
  683. indices: Optional[Union[int, List[int]]] = None,
  684. norm: bool = False,
  685. stop_early: bool = False,
  686. output_fmt: str = 'NCHW',
  687. intermediates_only: bool = False,
  688. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  689. """ Forward features that returns intermediates.
  690. Args:
  691. x: Input image tensor
  692. indices: Take last n blocks if int, all if None, select matching indices if sequence
  693. norm: Apply norm layer to compatible intermediates
  694. stop_early: Stop iterating over blocks when last desired intermediate hit
  695. output_fmt: Shape of intermediate feature outputs
  696. intermediates_only: Only return intermediate features
  697. Returns:
  698. """
  699. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  700. intermediates = []
  701. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  702. # forward pass
  703. x = self.stem(x)
  704. B, C, H, W = x.shape
  705. if not self.use_conv:
  706. x = x.flatten(2).transpose(1, 2)
  707. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  708. stages = self.stages
  709. else:
  710. stages = self.stages[:max_index + 1]
  711. for feat_idx, stage in enumerate(stages):
  712. if self.grad_checkpointing and not torch.jit.is_scripting():
  713. x = checkpoint(stage, x)
  714. else:
  715. x = stage(x)
  716. if feat_idx in take_indices:
  717. if self.use_conv:
  718. intermediates.append(x)
  719. else:
  720. intermediates.append(x.reshape(B, H, W, -1).permute(0, 3, 1, 2))
  721. H = (H + 2 - 1) // 2
  722. W = (W + 2 - 1) // 2
  723. if intermediates_only:
  724. return intermediates
  725. return x, intermediates
  726. def prune_intermediate_layers(
  727. self,
  728. indices: Union[int, List[int]] = 1,
  729. prune_norm: bool = False,
  730. prune_head: bool = True,
  731. ):
  732. """ Prune layers not required for specified intermediates.
  733. """
  734. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  735. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  736. if prune_head:
  737. self.reset_classifier(0, '')
  738. return take_indices
  739. def forward_features(self, x):
  740. x = self.stem(x)
  741. if not self.use_conv:
  742. x = x.flatten(2).transpose(1, 2)
  743. if self.grad_checkpointing and not torch.jit.is_scripting():
  744. x = checkpoint_seq(self.stages, x)
  745. else:
  746. x = self.stages(x)
  747. return x
  748. def forward_head(self, x, pre_logits: bool = False):
  749. if self.global_pool == 'avg':
  750. x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1)
  751. return x if pre_logits else self.head(x)
  752. def forward(self, x):
  753. x = self.forward_features(x)
  754. x = self.forward_head(x)
  755. return x
  756. class LevitDistilled(Levit):
  757. def __init__(self, *args, **kwargs):
  758. super().__init__(*args, **kwargs)
  759. dd = {'device': kwargs.get('device', None), 'dtype': kwargs.get('dtype', None)}
  760. self.head_dist = NormLinear(self.num_features, self.num_classes, **dd) if self.num_classes > 0 else nn.Identity()
  761. self.distilled_training = False # must set this True to train w/ distillation token
  762. @torch.jit.ignore
  763. def get_classifier(self) -> nn.Module:
  764. return self.head, self.head_dist
  765. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  766. self.num_classes = num_classes
  767. if global_pool is not None:
  768. self.global_pool = global_pool
  769. self.head = NormLinear(
  770. self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else nn.Identity()
  771. self.head_dist = NormLinear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  772. @torch.jit.ignore
  773. def set_distilled_training(self, enable=True):
  774. self.distilled_training = enable
  775. def forward_head(self, x, pre_logits: bool = False):
  776. if self.global_pool == 'avg':
  777. x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1)
  778. if pre_logits:
  779. return x
  780. x, x_dist = self.head(x), self.head_dist(x)
  781. if self.distilled_training and self.training and not torch.jit.is_scripting():
  782. # only return separate classification predictions when training in distilled mode
  783. return x, x_dist
  784. else:
  785. # during standard train/finetune, inference average the classifier predictions
  786. return (x + x_dist) / 2
  787. def checkpoint_filter_fn(state_dict, model):
  788. if 'model' in state_dict:
  789. state_dict = state_dict['model']
  790. # filter out attn biases, should not have been persistent
  791. state_dict = {k: v for k, v in state_dict.items() if 'attention_bias_idxs' not in k}
  792. # NOTE: old weight conversion code, disabled
  793. # D = model.state_dict()
  794. # out_dict = {}
  795. # for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
  796. # if va.ndim == 4 and vb.ndim == 2:
  797. # vb = vb[:, :, None, None]
  798. # if va.shape != vb.shape:
  799. # # head or first-conv shapes may change for fine-tune
  800. # assert 'head' in ka or 'stem.conv1.linear' in ka
  801. # out_dict[ka] = vb
  802. return state_dict
  803. model_cfgs = dict(
  804. levit_128s=dict(
  805. embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)),
  806. levit_128=dict(
  807. embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 8, 12), depth=(4, 4, 4)),
  808. levit_192=dict(
  809. embed_dim=(192, 288, 384), key_dim=32, num_heads=(3, 5, 6), depth=(4, 4, 4)),
  810. levit_256=dict(
  811. embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)),
  812. levit_384=dict(
  813. embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)),
  814. # stride-8 stem experiments
  815. levit_384_s8=dict(
  816. embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4),
  817. act_layer='silu', stem_type='s8'),
  818. levit_512_s8=dict(
  819. embed_dim=(512, 640, 896), key_dim=64, num_heads=(8, 10, 14), depth=(4, 4, 4),
  820. act_layer='silu', stem_type='s8'),
  821. # wider experiments
  822. levit_512=dict(
  823. embed_dim=(512, 768, 1024), key_dim=64, num_heads=(8, 12, 16), depth=(4, 4, 4), act_layer='silu'),
  824. # deeper experiments
  825. levit_256d=dict(
  826. embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 8, 6), act_layer='silu'),
  827. levit_512d=dict(
  828. embed_dim=(512, 640, 768), key_dim=64, num_heads=(8, 10, 12), depth=(4, 8, 6), act_layer='silu'),
  829. )
  830. def create_levit(variant, cfg_variant=None, pretrained=False, distilled=True, **kwargs):
  831. is_conv = '_conv' in variant
  832. out_indices = kwargs.pop('out_indices', (0, 1, 2))
  833. if kwargs.get('features_only', False) and not is_conv:
  834. kwargs.setdefault('feature_cls', 'getter')
  835. if cfg_variant is None:
  836. if variant in model_cfgs:
  837. cfg_variant = variant
  838. elif is_conv:
  839. cfg_variant = variant.replace('_conv', '')
  840. model_cfg = dict(model_cfgs[cfg_variant], **kwargs)
  841. model = build_model_with_cfg(
  842. LevitDistilled if distilled else Levit,
  843. variant,
  844. pretrained,
  845. pretrained_filter_fn=checkpoint_filter_fn,
  846. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  847. **model_cfg,
  848. )
  849. return model
  850. def _cfg(url='', **kwargs):
  851. return {
  852. 'url': url,
  853. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  854. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  855. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  856. 'first_conv': 'stem.conv1.linear', 'classifier': ('head.linear', 'head_dist.linear'),
  857. 'license': 'apache-2.0',
  858. **kwargs
  859. }
  860. default_cfgs = generate_default_cfgs({
  861. # weights in nn.Linear mode
  862. 'levit_128s.fb_dist_in1k': _cfg(
  863. hf_hub_id='timm/',
  864. ),
  865. 'levit_128.fb_dist_in1k': _cfg(
  866. hf_hub_id='timm/',
  867. ),
  868. 'levit_192.fb_dist_in1k': _cfg(
  869. hf_hub_id='timm/',
  870. ),
  871. 'levit_256.fb_dist_in1k': _cfg(
  872. hf_hub_id='timm/',
  873. ),
  874. 'levit_384.fb_dist_in1k': _cfg(
  875. hf_hub_id='timm/',
  876. ),
  877. # weights in nn.Conv2d mode
  878. 'levit_conv_128s.fb_dist_in1k': _cfg(
  879. hf_hub_id='timm/',
  880. pool_size=(4, 4),
  881. ),
  882. 'levit_conv_128.fb_dist_in1k': _cfg(
  883. hf_hub_id='timm/',
  884. pool_size=(4, 4),
  885. ),
  886. 'levit_conv_192.fb_dist_in1k': _cfg(
  887. hf_hub_id='timm/',
  888. pool_size=(4, 4),
  889. ),
  890. 'levit_conv_256.fb_dist_in1k': _cfg(
  891. hf_hub_id='timm/',
  892. pool_size=(4, 4),
  893. ),
  894. 'levit_conv_384.fb_dist_in1k': _cfg(
  895. hf_hub_id='timm/',
  896. pool_size=(4, 4),
  897. ),
  898. 'levit_384_s8.untrained': _cfg(classifier='head.linear'),
  899. 'levit_512_s8.untrained': _cfg(classifier='head.linear'),
  900. 'levit_512.untrained': _cfg(classifier='head.linear'),
  901. 'levit_256d.untrained': _cfg(classifier='head.linear'),
  902. 'levit_512d.untrained': _cfg(classifier='head.linear'),
  903. 'levit_conv_384_s8.untrained': _cfg(classifier='head.linear'),
  904. 'levit_conv_512_s8.untrained': _cfg(classifier='head.linear'),
  905. 'levit_conv_512.untrained': _cfg(classifier='head.linear'),
  906. 'levit_conv_256d.untrained': _cfg(classifier='head.linear'),
  907. 'levit_conv_512d.untrained': _cfg(classifier='head.linear'),
  908. })
  909. @register_model
  910. def levit_128s(pretrained=False, **kwargs) -> Levit:
  911. return create_levit('levit_128s', pretrained=pretrained, **kwargs)
  912. @register_model
  913. def levit_128(pretrained=False, **kwargs) -> Levit:
  914. return create_levit('levit_128', pretrained=pretrained, **kwargs)
  915. @register_model
  916. def levit_192(pretrained=False, **kwargs) -> Levit:
  917. return create_levit('levit_192', pretrained=pretrained, **kwargs)
  918. @register_model
  919. def levit_256(pretrained=False, **kwargs) -> Levit:
  920. return create_levit('levit_256', pretrained=pretrained, **kwargs)
  921. @register_model
  922. def levit_384(pretrained=False, **kwargs) -> Levit:
  923. return create_levit('levit_384', pretrained=pretrained, **kwargs)
  924. @register_model
  925. def levit_384_s8(pretrained=False, **kwargs) -> Levit:
  926. return create_levit('levit_384_s8', pretrained=pretrained, **kwargs)
  927. @register_model
  928. def levit_512_s8(pretrained=False, **kwargs) -> Levit:
  929. return create_levit('levit_512_s8', pretrained=pretrained, distilled=False, **kwargs)
  930. @register_model
  931. def levit_512(pretrained=False, **kwargs) -> Levit:
  932. return create_levit('levit_512', pretrained=pretrained, distilled=False, **kwargs)
  933. @register_model
  934. def levit_256d(pretrained=False, **kwargs) -> Levit:
  935. return create_levit('levit_256d', pretrained=pretrained, distilled=False, **kwargs)
  936. @register_model
  937. def levit_512d(pretrained=False, **kwargs) -> Levit:
  938. return create_levit('levit_512d', pretrained=pretrained, distilled=False, **kwargs)
  939. @register_model
  940. def levit_conv_128s(pretrained=False, **kwargs) -> Levit:
  941. return create_levit('levit_conv_128s', pretrained=pretrained, use_conv=True, **kwargs)
  942. @register_model
  943. def levit_conv_128(pretrained=False, **kwargs) -> Levit:
  944. return create_levit('levit_conv_128', pretrained=pretrained, use_conv=True, **kwargs)
  945. @register_model
  946. def levit_conv_192(pretrained=False, **kwargs) -> Levit:
  947. return create_levit('levit_conv_192', pretrained=pretrained, use_conv=True, **kwargs)
  948. @register_model
  949. def levit_conv_256(pretrained=False, **kwargs) -> Levit:
  950. return create_levit('levit_conv_256', pretrained=pretrained, use_conv=True, **kwargs)
  951. @register_model
  952. def levit_conv_384(pretrained=False, **kwargs) -> Levit:
  953. return create_levit('levit_conv_384', pretrained=pretrained, use_conv=True, **kwargs)
  954. @register_model
  955. def levit_conv_384_s8(pretrained=False, **kwargs) -> Levit:
  956. return create_levit('levit_conv_384_s8', pretrained=pretrained, use_conv=True, **kwargs)
  957. @register_model
  958. def levit_conv_512_s8(pretrained=False, **kwargs) -> Levit:
  959. return create_levit('levit_conv_512_s8', pretrained=pretrained, use_conv=True, distilled=False, **kwargs)
  960. @register_model
  961. def levit_conv_512(pretrained=False, **kwargs) -> Levit:
  962. return create_levit('levit_conv_512', pretrained=pretrained, use_conv=True, distilled=False, **kwargs)
  963. @register_model
  964. def levit_conv_256d(pretrained=False, **kwargs) -> Levit:
  965. return create_levit('levit_conv_256d', pretrained=pretrained, use_conv=True, distilled=False, **kwargs)
  966. @register_model
  967. def levit_conv_512d(pretrained=False, **kwargs) -> Levit:
  968. return create_levit('levit_conv_512d', pretrained=pretrained, use_conv=True, distilled=False, **kwargs)