davit.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953
  1. """ DaViT: Dual Attention Vision Transformers
  2. As described in https://arxiv.org/abs/2204.03645
  3. Input size invariant transformer architecture that combines channel and spatial
  4. attention in each block. The attention mechanisms used are linear in complexity.
  5. DaViT model defs and weights adapted from https://github.com/dingmyu/davit, original copyright below
  6. """
  7. # Copyright (c) 2022 Mingyu Ding
  8. # All rights reserved.
  9. # This source code is licensed under the MIT license
  10. from functools import partial
  11. from typing import List, Optional, Tuple, Type, Union
  12. import torch
  13. import torch.nn as nn
  14. import torch.nn.functional as F
  15. from torch import Tensor
  16. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  17. from timm.layers import DropPath, calculate_drop_path_rates, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn
  18. from timm.layers import NormMlpClassifierHead, ClassifierHead
  19. from ._builder import build_model_with_cfg
  20. from ._features import feature_take_indices
  21. from ._features_fx import register_notrace_function
  22. from ._manipulate import checkpoint, checkpoint_seq
  23. from ._registry import generate_default_cfgs, register_model
  24. __all__ = ['DaVit']
  25. class ConvPosEnc(nn.Module):
  26. def __init__(
  27. self,
  28. dim: int,
  29. k: int = 3,
  30. act: bool = False,
  31. device=None,
  32. dtype=None,
  33. ):
  34. dd = {'device': device, 'dtype': dtype}
  35. super().__init__()
  36. self.proj = nn.Conv2d(
  37. dim,
  38. dim,
  39. kernel_size=k,
  40. stride=1,
  41. padding=k // 2,
  42. groups=dim,
  43. **dd,
  44. )
  45. self.act = nn.GELU() if act else nn.Identity()
  46. def forward(self, x: Tensor):
  47. feat = self.proj(x)
  48. x = x + self.act(feat)
  49. return x
  50. class Stem(nn.Module):
  51. """ Size-agnostic implementation of 2D image to patch embedding,
  52. allowing input size to be adjusted during model forward operation
  53. """
  54. def __init__(
  55. self,
  56. in_chs: int = 3,
  57. out_chs: int = 96,
  58. stride: int = 4,
  59. norm_layer: Type[nn.Module] = LayerNorm2d,
  60. device=None,
  61. dtype=None,
  62. ):
  63. dd = {'device': device, 'dtype': dtype}
  64. super().__init__()
  65. stride = to_2tuple(stride)
  66. self.stride = stride
  67. self.in_chs = in_chs
  68. self.out_chs = out_chs
  69. assert stride[0] == 4 # only setup for stride==4
  70. self.conv = nn.Conv2d(
  71. in_chs,
  72. out_chs,
  73. kernel_size=7,
  74. stride=stride,
  75. padding=3,
  76. **dd,
  77. )
  78. self.norm = norm_layer(out_chs, **dd)
  79. def forward(self, x: Tensor):
  80. B, C, H, W = x.shape
  81. pad_r = (self.stride[1] - W % self.stride[1]) % self.stride[1]
  82. pad_b = (self.stride[0] - H % self.stride[0]) % self.stride[0]
  83. x = F.pad(x, (0, pad_r, 0, pad_b))
  84. x = self.conv(x)
  85. x = self.norm(x)
  86. return x
  87. class Downsample(nn.Module):
  88. def __init__(
  89. self,
  90. in_chs: int,
  91. out_chs: int,
  92. kernel_size: int = 3,
  93. norm_layer: Type[nn.Module] = LayerNorm2d,
  94. device=None,
  95. dtype=None,
  96. ):
  97. dd = {'device': device, 'dtype': dtype}
  98. super().__init__()
  99. self.in_chs = in_chs
  100. self.out_chs = out_chs
  101. self.norm = norm_layer(in_chs, **dd)
  102. self.even_k = kernel_size % 2 == 0
  103. self.conv = nn.Conv2d(
  104. in_chs,
  105. out_chs,
  106. kernel_size=kernel_size,
  107. stride=2,
  108. padding=0 if self.even_k else kernel_size // 2,
  109. **dd,
  110. )
  111. def forward(self, x: Tensor):
  112. B, C, H, W = x.shape
  113. x = self.norm(x)
  114. if self.even_k:
  115. k_h, k_w = self.conv.kernel_size
  116. pad_r = (k_w - W % k_w) % k_w
  117. pad_b = (k_h - H % k_h) % k_h
  118. x = F.pad(x, (0, pad_r , 0, pad_b))
  119. x = self.conv(x)
  120. return x
  121. class ChannelAttentionV2(nn.Module):
  122. def __init__(
  123. self,
  124. dim: int,
  125. num_heads: int = 8,
  126. qkv_bias: bool = True,
  127. dynamic_scale: bool = True,
  128. device=None,
  129. dtype=None,
  130. ):
  131. dd = {'device': device, 'dtype': dtype}
  132. super().__init__()
  133. self.groups = num_heads
  134. self.head_dim = dim // num_heads
  135. self.dynamic_scale = dynamic_scale
  136. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
  137. self.proj = nn.Linear(dim, dim, **dd)
  138. def forward(self, x):
  139. B, N, C = x.shape
  140. qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
  141. q, k, v = qkv.unbind(0)
  142. if self.dynamic_scale:
  143. q = q * N ** -0.5
  144. else:
  145. q = q * self.head_dim ** -0.5
  146. attn = q.transpose(-1, -2) @ k
  147. attn = attn.softmax(dim=-1)
  148. x = (attn @ v.transpose(-1, -2)).transpose(-1, -2)
  149. x = x.transpose(1, 2).reshape(B, N, C)
  150. x = self.proj(x)
  151. return x
  152. class ChannelAttention(nn.Module):
  153. def __init__(
  154. self,
  155. dim: int,
  156. num_heads: int = 8,
  157. qkv_bias: bool = False,
  158. device=None,
  159. dtype=None,
  160. ):
  161. dd = {'device': device, 'dtype': dtype}
  162. super().__init__()
  163. self.num_heads = num_heads
  164. head_dim = dim // num_heads
  165. self.scale = head_dim ** -0.5
  166. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
  167. self.proj = nn.Linear(dim, dim, **dd)
  168. def forward(self, x: Tensor):
  169. B, N, C = x.shape
  170. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  171. q, k, v = qkv.unbind(0)
  172. k = k * self.scale
  173. attn = k.transpose(-1, -2) @ v
  174. attn = attn.softmax(dim=-1)
  175. x = (attn @ q.transpose(-1, -2)).transpose(-1, -2)
  176. x = x.transpose(1, 2).reshape(B, N, C)
  177. x = self.proj(x)
  178. return x
  179. class ChannelBlock(nn.Module):
  180. def __init__(
  181. self,
  182. dim: int,
  183. num_heads: int,
  184. mlp_ratio: float = 4.,
  185. qkv_bias: bool = False,
  186. drop_path: float = 0.,
  187. act_layer: Type[nn.Module] = nn.GELU,
  188. norm_layer: Type[nn.Module] = nn.LayerNorm,
  189. ffn: bool = True,
  190. cpe_act: bool = False,
  191. v2: bool = False,
  192. device=None,
  193. dtype=None,
  194. ):
  195. dd = {'device': device, 'dtype': dtype}
  196. super().__init__()
  197. self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act, **dd)
  198. self.ffn = ffn
  199. self.norm1 = norm_layer(dim, **dd)
  200. attn_layer = ChannelAttentionV2 if v2 else ChannelAttention
  201. self.attn = attn_layer(
  202. dim,
  203. num_heads=num_heads,
  204. qkv_bias=qkv_bias,
  205. **dd,
  206. )
  207. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  208. self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act, **dd)
  209. if self.ffn:
  210. self.norm2 = norm_layer(dim, **dd)
  211. self.mlp = Mlp(
  212. in_features=dim,
  213. hidden_features=int(dim * mlp_ratio),
  214. act_layer=act_layer,
  215. **dd,
  216. )
  217. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  218. else:
  219. self.norm2 = None
  220. self.mlp = None
  221. self.drop_path2 = None
  222. def forward(self, x: Tensor):
  223. B, C, H, W = x.shape
  224. x = self.cpe1(x).flatten(2).transpose(1, 2)
  225. cur = self.norm1(x)
  226. cur = self.attn(cur)
  227. x = x + self.drop_path1(cur)
  228. x = self.cpe2(x.transpose(1, 2).view(B, C, H, W))
  229. if self.mlp is not None:
  230. x = x.flatten(2).transpose(1, 2)
  231. x = x + self.drop_path2(self.mlp(self.norm2(x)))
  232. x = x.transpose(1, 2).view(B, C, H, W)
  233. return x
  234. def window_partition(x: Tensor, window_size: Tuple[int, int]):
  235. """
  236. Args:
  237. x: (B, H, W, C)
  238. window_size (int): window size
  239. Returns:
  240. windows: (num_windows*B, window_size, window_size, C)
  241. """
  242. B, H, W, C = x.shape
  243. x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
  244. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
  245. return windows
  246. @register_notrace_function # reason: int argument is a Proxy
  247. def window_reverse(windows: Tensor, window_size: Tuple[int, int], H: int, W: int):
  248. """
  249. Args:
  250. windows: (num_windows*B, window_size, window_size, C)
  251. window_size (int): Window size
  252. H (int): Height of image
  253. W (int): Width of image
  254. Returns:
  255. x: (B, H, W, C)
  256. """
  257. C = windows.shape[-1]
  258. x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
  259. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
  260. return x
  261. class WindowAttention(nn.Module):
  262. r""" Window based multi-head self attention (W-MSA) module with relative position bias.
  263. It supports both of shifted and non-shifted window.
  264. Args:
  265. dim (int): Number of input channels.
  266. window_size (tuple[int]): The height and width of the window.
  267. num_heads (int): Number of attention heads.
  268. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  269. """
  270. fused_attn: torch.jit.Final[bool]
  271. def __init__(
  272. self,
  273. dim: int,
  274. window_size: Tuple[int, int],
  275. num_heads: int,
  276. qkv_bias: bool = True,
  277. device=None,
  278. dtype=None,
  279. ):
  280. dd = {'device': device, 'dtype': dtype}
  281. super().__init__()
  282. self.dim = dim
  283. self.window_size = window_size
  284. self.num_heads = num_heads
  285. head_dim = dim // num_heads
  286. self.scale = head_dim ** -0.5
  287. self.fused_attn = use_fused_attn()
  288. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
  289. self.proj = nn.Linear(dim, dim, **dd)
  290. self.softmax = nn.Softmax(dim=-1)
  291. def forward(self, x: Tensor):
  292. B_, N, C = x.shape
  293. qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  294. q, k, v = qkv.unbind(0)
  295. if self.fused_attn:
  296. x = F.scaled_dot_product_attention(q, k, v)
  297. else:
  298. q = q * self.scale
  299. attn = (q @ k.transpose(-2, -1))
  300. attn = self.softmax(attn)
  301. x = attn @ v
  302. x = x.transpose(1, 2).reshape(B_, N, C)
  303. x = self.proj(x)
  304. return x
  305. class SpatialBlock(nn.Module):
  306. r""" Windows Block.
  307. Args:
  308. dim (int): Number of input channels.
  309. num_heads (int): Number of attention heads.
  310. window_size (int): Window size.
  311. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  312. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  313. drop_path (float, optional): Stochastic depth rate. Default: 0.0
  314. act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
  315. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  316. """
  317. def __init__(
  318. self,
  319. dim: int,
  320. num_heads: int,
  321. window_size: int = 7,
  322. mlp_ratio: float = 4.,
  323. qkv_bias: bool = True,
  324. drop_path: float = 0.,
  325. act_layer: Type[nn.Module] = nn.GELU,
  326. norm_layer: Type[nn.Module] = nn.LayerNorm,
  327. ffn: bool = True,
  328. cpe_act: bool = False,
  329. device=None,
  330. dtype=None,
  331. ):
  332. dd = {'device': device, 'dtype': dtype}
  333. super().__init__()
  334. self.dim = dim
  335. self.ffn = ffn
  336. self.num_heads = num_heads
  337. self.window_size = to_2tuple(window_size)
  338. self.mlp_ratio = mlp_ratio
  339. self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act, **dd)
  340. self.norm1 = norm_layer(dim, **dd)
  341. self.attn = WindowAttention(
  342. dim,
  343. self.window_size,
  344. num_heads=num_heads,
  345. qkv_bias=qkv_bias,
  346. **dd,
  347. )
  348. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  349. self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act, **dd)
  350. if self.ffn:
  351. self.norm2 = norm_layer(dim, **dd)
  352. mlp_hidden_dim = int(dim * mlp_ratio)
  353. self.mlp = Mlp(
  354. in_features=dim,
  355. hidden_features=mlp_hidden_dim,
  356. act_layer=act_layer,
  357. **dd,
  358. )
  359. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  360. else:
  361. self.norm2 = None
  362. self.mlp = None
  363. self.drop_path1 = None
  364. def forward(self, x: Tensor):
  365. B, C, H, W = x.shape
  366. shortcut = self.cpe1(x).flatten(2).transpose(1, 2)
  367. x = self.norm1(shortcut)
  368. x = x.view(B, H, W, C)
  369. pad_l = pad_t = 0
  370. pad_r = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
  371. pad_b = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
  372. x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
  373. _, Hp, Wp, _ = x.shape
  374. x_windows = window_partition(x, self.window_size)
  375. x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)
  376. # W-MSA/SW-MSA
  377. attn_windows = self.attn(x_windows)
  378. # merge windows
  379. attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
  380. x = window_reverse(attn_windows, self.window_size, Hp, Wp)
  381. # if pad_r > 0 or pad_b > 0:
  382. x = x[:, :H, :W, :].contiguous()
  383. x = x.view(B, H * W, C)
  384. x = shortcut + self.drop_path1(x)
  385. x = self.cpe2(x.transpose(1, 2).view(B, C, H, W))
  386. if self.mlp is not None:
  387. x = x.flatten(2).transpose(1, 2)
  388. x = x + self.drop_path2(self.mlp(self.norm2(x)))
  389. x = x.transpose(1, 2).view(B, C, H, W)
  390. return x
  391. class DaVitStage(nn.Module):
  392. def __init__(
  393. self,
  394. in_chs: int,
  395. out_chs: int,
  396. depth:int = 1,
  397. downsample: bool = True,
  398. attn_types: Tuple[str, ...] = ('spatial', 'channel'),
  399. num_heads: int = 3,
  400. window_size: int = 7,
  401. mlp_ratio: float = 4.,
  402. qkv_bias: bool = True,
  403. drop_path_rates: Tuple[float, ...] = (0, 0),
  404. norm_layer: Type[nn.Module] = LayerNorm2d,
  405. norm_layer_cl: Type[nn.Module] = nn.LayerNorm,
  406. ffn: bool = True,
  407. cpe_act: bool = False,
  408. down_kernel_size: int = 2,
  409. named_blocks: bool = False,
  410. channel_attn_v2: bool = False,
  411. device=None,
  412. dtype=None,
  413. ):
  414. dd = {'device': device, 'dtype': dtype}
  415. super().__init__()
  416. self.grad_checkpointing = False
  417. # downsample embedding layer at the beginning of each stage
  418. if downsample:
  419. self.downsample = Downsample(in_chs, out_chs, kernel_size=down_kernel_size, norm_layer=norm_layer, **dd)
  420. else:
  421. self.downsample = nn.Identity()
  422. '''
  423. repeating alternating attention blocks in each stage
  424. default: (spatial -> channel) x depth
  425. potential opportunity to integrate with a more general version of ByobNet/ByoaNet
  426. since the logic is similar
  427. '''
  428. stage_blocks = []
  429. for block_idx in range(depth):
  430. from collections import OrderedDict
  431. dual_attention_block = []
  432. for attn_idx, attn_type in enumerate(attn_types):
  433. if attn_type == 'spatial':
  434. dual_attention_block.append(('spatial_block', SpatialBlock(
  435. dim=out_chs,
  436. num_heads=num_heads,
  437. mlp_ratio=mlp_ratio,
  438. qkv_bias=qkv_bias,
  439. drop_path=drop_path_rates[block_idx],
  440. norm_layer=norm_layer_cl,
  441. ffn=ffn,
  442. cpe_act=cpe_act,
  443. window_size=window_size,
  444. **dd,
  445. )))
  446. elif attn_type == 'channel':
  447. dual_attention_block.append(('channel_block', ChannelBlock(
  448. dim=out_chs,
  449. num_heads=num_heads,
  450. mlp_ratio=mlp_ratio,
  451. qkv_bias=qkv_bias,
  452. drop_path=drop_path_rates[block_idx],
  453. norm_layer=norm_layer_cl,
  454. ffn=ffn,
  455. cpe_act=cpe_act,
  456. v2=channel_attn_v2,
  457. **dd,
  458. )))
  459. if named_blocks:
  460. stage_blocks.append(nn.Sequential(OrderedDict(dual_attention_block)))
  461. else:
  462. stage_blocks.append(nn.Sequential(*[b[1] for b in dual_attention_block]))
  463. self.blocks = nn.Sequential(*stage_blocks)
  464. @torch.jit.ignore
  465. def set_grad_checkpointing(self, enable=True):
  466. self.grad_checkpointing = enable
  467. def forward(self, x: Tensor):
  468. x = self.downsample(x)
  469. if self.grad_checkpointing and not torch.jit.is_scripting():
  470. x = checkpoint_seq(self.blocks, x)
  471. else:
  472. x = self.blocks(x)
  473. return x
class DaVit(nn.Module):
    r""" DaViT
        A PyTorch implementation of `DaViT: Dual Attention Vision Transformers`  - https://arxiv.org/abs/2204.03645
        Supports arbitrary input sizes and pyramid feature extraction

    Args:
        in_chans (int): Number of input image channels. Default: 3
        depths (tuple(int)): Number of blocks in each stage. Default: (1, 1, 3, 1)
        embed_dims (tuple(int)): Channel width of each stage. Default: (96, 192, 384, 768)
        num_heads (tuple(int)): Number of attention heads in each stage. Default: (3, 6, 12, 24)
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        norm_layer (str): Name (resolved with `get_norm_layer`) of the norm used by the stem,
            downsample layers and head. Default: 'layernorm2d'
        norm_layer_cl (str): Name (resolved with `get_norm_layer`) of the norm used inside
            the attention blocks. Default: 'layernorm'
        norm_eps (float): Epsilon passed to both norm layer types. Default: 1e-5
        attn_types (tuple(str)): Attention block types making up each dual block.
            Default: ('spatial', 'channel')
        ffn (bool): Passed to every stage; enables the MLP branch of the blocks. Default: True
        cpe_act (bool): Passed to every stage; enables the ConvPosEnc activation. Default: False
        down_kernel_size (int): Kernel size of the inter-stage downsample conv. Default: 2
        channel_attn_v2 (bool): Passed to every stage; selects ChannelAttentionV2. Default: False
        named_blocks (bool): Passed to every stage; registers block members by name. Default: False
        drop_rate (float): Dropout rate handed to the classifier head. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        num_classes (int): Number of classes for classification head. Default: 1000
        global_pool (str): Pooling type handed to the classifier head. Default: 'avg'
        head_norm_first (bool): If True use norm -> pool -> fc ordering, otherwise
            pool -> norm -> fc. Default: False
        device: Passed through to all sub-layers.
        dtype: Passed through to all sub-layers.
    """

    def __init__(
            self,
            in_chans: int = 3,
            depths: Tuple[int, ...] = (1, 1, 3, 1),
            embed_dims: Tuple[int, ...] = (96, 192, 384, 768),
            num_heads: Tuple[int, ...] = (3, 6, 12, 24),
            window_size: int = 7,
            mlp_ratio: float = 4,
            qkv_bias: bool = True,
            norm_layer: str = 'layernorm2d',
            norm_layer_cl: str = 'layernorm',
            norm_eps: float = 1e-5,
            attn_types: Tuple[str, ...] = ('spatial', 'channel'),
            ffn: bool = True,
            cpe_act: bool = False,
            down_kernel_size: int = 2,
            channel_attn_v2: bool = False,
            named_blocks: bool = False,
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            head_norm_first: bool = False,
            device=None,
            dtype=None,
    ):
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        num_stages = len(embed_dims)
        # one width, head count and depth is required per stage
        assert num_stages == len(num_heads) == len(depths)
        # resolve norm names to layer types with eps pre-bound
        norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
        norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.num_features = self.head_hidden_size = embed_dims[-1]
        self.drop_rate = drop_rate
        self.grad_checkpointing = False
        self.feature_info = []

        self.stem = Stem(in_chans, embed_dims[0], norm_layer=norm_layer, **dd)
        in_chs = embed_dims[0]

        # per-stage lists of drop path rates, indexed as dpr[stage]
        dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
        stages = []
        for i in range(num_stages):
            out_chs = embed_dims[i]
            stage = DaVitStage(
                in_chs,
                out_chs,
                depth=depths[i],
                downsample=i > 0,  # the first stage follows the stem directly
                attn_types=attn_types,
                num_heads=num_heads[i],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path_rates=dpr[i],
                norm_layer=norm_layer,
                norm_layer_cl=norm_layer_cl,
                ffn=ffn,
                cpe_act=cpe_act,
                down_kernel_size=down_kernel_size,
                channel_attn_v2=channel_attn_v2,
                named_blocks=named_blocks,
                **dd,
            )
            in_chs = out_chs
            stages.append(stage)
            # reduction 4 for the first stage (stem stride), doubling with each stride-2 downsample
            self.feature_info += [dict(num_chs=out_chs, reduction=2**(i+2), module=f'stages.{i}')]

        self.stages = nn.Sequential(*stages)

        # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
        # otherwise pool -> norm -> fc, the default DaViT order, similar to ConvNeXt
        # FIXME generalize this structure to ClassifierHead
        if head_norm_first:
            self.norm_pre = norm_layer(self.num_features, **dd)
            self.head = ClassifierHead(
                self.num_features,
                num_classes,
                pool_type=global_pool,
                drop_rate=self.drop_rate,
                **dd,
            )
        else:
            self.norm_pre = nn.Identity()
            self.head = NormMlpClassifierHead(
                self.num_features,
                num_classes,
                pool_type=global_pool,
                drop_rate=self.drop_rate,
                norm_layer=norm_layer,
                **dd,
            )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Init hook for `self.apply`: truncated-normal Linear weights (std 0.02), zero Linear biases."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return regex patterns grouping parameters by stem / stage (coarse) or by block."""
        return dict(
            stem=r'^stem',  # stem and embed
            blocks=r'^stages\.(\d+)' if coarse else [
                (r'^stages\.(\d+).downsample', (0,)),
                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
                (r'^norm_pre', (99999,)),
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on the model and on every stage."""
        self.grad_checkpointing = enable
        for stage in self.stages:
            stage.set_grad_checkpointing(enable=enable)

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        """Return the final classification layer of the head (`head.fc`)."""
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        """Record the new class count and delegate to the head's `reset`."""
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:
            The list of selected stage outputs if `intermediates_only`, otherwise a tuple of
            (output of the last executed stage, list of selected stage outputs). `norm_pre` is
            applied to that output only when the last executed stage is the model's final stage.
        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stages), indices)

        # forward pass
        x = self.stem(x)
        last_idx = len(self.stages) - 1
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.stages
        else:
            stages = self.stages[:max_index + 1]

        for feat_idx, stage in enumerate(stages):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(stage, x)
            else:
                x = stage(x)
            if feat_idx in take_indices:
                if norm and feat_idx == last_idx:
                    x_inter = self.norm_pre(x)  # applying final norm to last intermediate
                else:
                    x_inter = x
                intermediates.append(x_inter)

        if intermediates_only:
            return intermediates

        # only the true final stage output gets the pre-head norm
        if feat_idx == last_idx:
            x = self.norm_pre(x)

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.

        Stages past the highest requested index are dropped; `norm_pre` is replaced by an
        identity if `prune_norm`, and the classifier is reset to 0 classes with no pooling
        if `prune_head`. Returns the resolved take indices.
        """
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
        if prune_norm:
            self.norm_pre = nn.Identity()
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        """Run stem, all stages and `norm_pre`; returns the final feature map."""
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        x = self.norm_pre(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Apply the classifier head; with `pre_logits` the head is called with `pre_logits=True`."""
        return self.head(x, pre_logits=True) if pre_logits else self.head(x)

    def forward(self, x):
        """Full forward pass: features followed by the classifier head."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
  681. def _convert_florence2(state_dict, model, prefix='vision_tower.'):
  682. import re
  683. out_dict = {}
  684. for k, v in state_dict.items():
  685. if k.startswith(prefix):
  686. k = k.replace(prefix, '')
  687. else:
  688. continue
  689. k = re.sub(r'convs.([0-9]+)', r'stages.\1.downsample', k)
  690. k = re.sub(r'blocks.([0-9]+)', r'stages.\1.blocks', k)
  691. k = k.replace('downsample.proj', 'downsample.conv')
  692. k = k.replace('stages.0.downsample', 'stem')
  693. #k = k.replace('head.', 'head.fc.')
  694. #k = k.replace('norms.', 'head.norm.')
  695. k = k.replace('window_attn.norm.', 'norm1.')
  696. k = k.replace('window_attn.fn.', 'attn.')
  697. k = k.replace('channel_attn.norm.', 'norm1.')
  698. k = k.replace('channel_attn.fn.', 'attn.')
  699. k = k.replace('ffn.norm.', 'norm2.')
  700. k = k.replace('ffn.fn.net.', 'mlp.')
  701. k = k.replace('conv1.fn.dw', 'cpe1.proj')
  702. k = k.replace('conv2.fn.dw', 'cpe2.proj')
  703. out_dict[k] = v
  704. return out_dict
  705. def checkpoint_filter_fn(state_dict, model):
  706. """ Remap MSFT checkpoints -> timm """
  707. if 'head.fc.weight' in state_dict:
  708. return state_dict # non-MSFT checkpoint
  709. if 'state_dict' in state_dict:
  710. state_dict = state_dict['state_dict']
  711. if 'vision_tower.convs.0.proj.weight' in state_dict:
  712. return _convert_florence2(state_dict, model)
  713. import re
  714. out_dict = {}
  715. for k, v in state_dict.items():
  716. k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.downsample', k)
  717. k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
  718. k = k.replace('downsample.proj', 'downsample.conv')
  719. k = k.replace('stages.0.downsample', 'stem')
  720. k = k.replace('head.', 'head.fc.')
  721. k = k.replace('norms.', 'head.norm.')
  722. k = k.replace('cpe.0', 'cpe1')
  723. k = k.replace('cpe.1', 'cpe2')
  724. out_dict[k] = v
  725. return out_dict
  726. def _create_davit(variant, pretrained=False, **kwargs):
  727. default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
  728. out_indices = kwargs.pop('out_indices', default_out_indices)
  729. strict = kwargs.pop('pretrained_strict', True)
  730. if variant.endswith('_fl'):
  731. # FIXME cleaner approach to missing head norm?
  732. strict = False
  733. model = build_model_with_cfg(
  734. DaVit,
  735. variant,
  736. pretrained,
  737. pretrained_filter_fn=checkpoint_filter_fn,
  738. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  739. pretrained_strict=strict,
  740. **kwargs)
  741. return model
  742. def _cfg(url='', **kwargs):
  743. return {
  744. 'url': url,
  745. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  746. 'crop_pct': 0.95, 'interpolation': 'bicubic',
  747. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  748. 'first_conv': 'stem.conv', 'classifier': 'head.fc',
  749. 'license': 'apache-2.0',
  750. **kwargs
  751. }
# Pretrained configs keyed by 'architecture' or 'architecture.tag'; each value is built
# by `_cfg`, so anything not given here falls back to the defaults defined there.
# TODO contact authors to get larger pretrained models
default_cfgs = generate_default_cfgs({
    # official microsoft weights from https://github.com/dingmyu/davit
    'davit_tiny.msft_in1k': _cfg(
        hf_hub_id='timm/'),
    'davit_small.msft_in1k': _cfg(
        hf_hub_id='timm/'),
    'davit_base.msft_in1k': _cfg(
        hf_hub_id='timm/'),
    # bare defaults: empty url and no hf_hub_id, i.e. no weight source is configured
    'davit_large': _cfg(),
    'davit_huge': _cfg(),
    'davit_giant': _cfg(),
    # Florence-2 entries override the defaults with num_classes=0 and a 768x768 input
    'davit_base_fl.msft_florence2': _cfg(
        hf_hub_id='microsoft/Florence-2-base',
        num_classes=0, input_size=(3, 768, 768)),
    'davit_huge_fl.msft_florence2': _cfg(
        hf_hub_id='microsoft/Florence-2-large',
        num_classes=0, input_size=(3, 768, 768)),
})
  771. @register_model
  772. def davit_tiny(pretrained=False, **kwargs) -> DaVit:
  773. model_args = dict(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24))
  774. return _create_davit('davit_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
  775. @register_model
  776. def davit_small(pretrained=False, **kwargs) -> DaVit:
  777. model_args = dict(depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24))
  778. return _create_davit('davit_small', pretrained=pretrained, **dict(model_args, **kwargs))
  779. @register_model
  780. def davit_base(pretrained=False, **kwargs) -> DaVit:
  781. model_args = dict(depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32))
  782. return _create_davit('davit_base', pretrained=pretrained, **dict(model_args, **kwargs))
  783. @register_model
  784. def davit_large(pretrained=False, **kwargs) -> DaVit:
  785. model_args = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536), num_heads=(6, 12, 24, 48))
  786. return _create_davit('davit_large', pretrained=pretrained, **dict(model_args, **kwargs))
  787. @register_model
  788. def davit_huge(pretrained=False, **kwargs) -> DaVit:
  789. model_args = dict(depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64))
  790. return _create_davit('davit_huge', pretrained=pretrained, **dict(model_args, **kwargs))
  791. @register_model
  792. def davit_giant(pretrained=False, **kwargs) -> DaVit:
  793. model_args = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96))
  794. return _create_davit('davit_giant', pretrained=pretrained, **dict(model_args, **kwargs))
  795. @register_model
  796. def davit_base_fl(pretrained=False, **kwargs) -> DaVit:
  797. model_args = dict(
  798. depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32),
  799. window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
  800. )
  801. return _create_davit('davit_base_fl', pretrained=pretrained, **dict(model_args, **kwargs))
  802. @register_model
  803. def davit_huge_fl(pretrained=False, **kwargs) -> DaVit:
  804. # NOTE: huge image tower used in 'large' Florence2 model
  805. model_args = dict(
  806. depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64),
  807. window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
  808. )
  809. return _create_davit('davit_huge_fl', pretrained=pretrained, **dict(model_args, **kwargs))