visformer.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591
  1. """ Visformer
  2. Paper: Visformer: The Vision-friendly Transformer - https://arxiv.org/abs/2104.12533
  3. From original at https://github.com/danczs/Visformer
  4. Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
  5. """
  6. from typing import Optional, Union, Type, Any
  7. import torch
  8. import torch.nn as nn
  9. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  10. from timm.layers import to_2tuple, trunc_normal_, DropPath, calculate_drop_path_rates, PatchEmbed, LayerNorm2d, create_classifier, use_fused_attn
  11. from ._builder import build_model_with_cfg
  12. from ._manipulate import checkpoint_seq
  13. from ._registry import register_model, generate_default_cfgs
  14. __all__ = ['Visformer']
  15. class SpatialMlp(nn.Module):
  16. def __init__(
  17. self,
  18. in_features: int,
  19. hidden_features: Optional[int] = None,
  20. out_features: Optional[int] = None,
  21. act_layer: Type[nn.Module] = nn.GELU,
  22. drop: float = 0.,
  23. group: int = 8,
  24. spatial_conv: bool = False,
  25. device=None,
  26. dtype=None,
  27. ):
  28. dd = {'device': device, 'dtype': dtype}
  29. super().__init__()
  30. out_features = out_features or in_features
  31. hidden_features = hidden_features or in_features
  32. drop_probs = to_2tuple(drop)
  33. self.in_features = in_features
  34. self.out_features = out_features
  35. self.spatial_conv = spatial_conv
  36. if self.spatial_conv:
  37. if group < 2: # net setting
  38. hidden_features = in_features * 5 // 6
  39. else:
  40. hidden_features = in_features * 2
  41. self.hidden_features = hidden_features
  42. self.group = group
  43. self.conv1 = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0, bias=False, **dd)
  44. self.act1 = act_layer()
  45. self.drop1 = nn.Dropout(drop_probs[0])
  46. if self.spatial_conv:
  47. self.conv2 = nn.Conv2d(
  48. hidden_features, hidden_features, 3, stride=1, padding=1, groups=self.group, bias=False, **dd)
  49. self.act2 = act_layer()
  50. else:
  51. self.conv2 = None
  52. self.act2 = None
  53. self.conv3 = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0, bias=False, **dd)
  54. self.drop3 = nn.Dropout(drop_probs[1])
  55. def forward(self, x):
  56. x = self.conv1(x)
  57. x = self.act1(x)
  58. x = self.drop1(x)
  59. if self.conv2 is not None:
  60. x = self.conv2(x)
  61. x = self.act2(x)
  62. x = self.conv3(x)
  63. x = self.drop3(x)
  64. return x
  65. class Attention(nn.Module):
  66. fused_attn: torch.jit.Final[bool]
  67. def __init__(
  68. self,
  69. dim: int,
  70. num_heads: int = 8,
  71. head_dim_ratio: float = 1.,
  72. attn_drop: float = 0.,
  73. proj_drop: float = 0.,
  74. device=None,
  75. dtype=None,
  76. ):
  77. dd = {'device': device, 'dtype': dtype}
  78. super().__init__()
  79. self.dim = dim
  80. self.num_heads = num_heads
  81. head_dim = round(dim // num_heads * head_dim_ratio)
  82. self.head_dim = head_dim
  83. self.scale = head_dim ** -0.5
  84. self.fused_attn = use_fused_attn(experimental=True)
  85. self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False, **dd)
  86. self.attn_drop = nn.Dropout(attn_drop)
  87. self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False, **dd)
  88. self.proj_drop = nn.Dropout(proj_drop)
  89. def forward(self, x):
  90. B, C, H, W = x.shape
  91. x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3)
  92. q, k, v = x.unbind(0)
  93. if self.fused_attn:
  94. x = torch.nn.functional.scaled_dot_product_attention(
  95. q.contiguous(), k.contiguous(), v.contiguous(),
  96. dropout_p=self.attn_drop.p if self.training else 0.,
  97. )
  98. else:
  99. attn = (q @ k.transpose(-2, -1)) * self.scale
  100. attn = attn.softmax(dim=-1)
  101. attn = self.attn_drop(attn)
  102. x = attn @ v
  103. x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W)
  104. x = self.proj(x)
  105. x = self.proj_drop(x)
  106. return x
  107. class Block(nn.Module):
  108. def __init__(
  109. self,
  110. dim: int,
  111. num_heads: int,
  112. head_dim_ratio: float = 1.,
  113. mlp_ratio: float = 4.,
  114. proj_drop: float = 0.,
  115. attn_drop: float = 0.,
  116. drop_path: float = 0.,
  117. act_layer: Type[nn.Module] = nn.GELU,
  118. norm_layer: Type[nn.Module] = LayerNorm2d,
  119. group: int = 8,
  120. attn_disabled: bool = False,
  121. spatial_conv: bool = False,
  122. device=None,
  123. dtype=None,
  124. ):
  125. dd = {'device': device, 'dtype': dtype}
  126. super().__init__()
  127. self.spatial_conv = spatial_conv
  128. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  129. if attn_disabled:
  130. self.norm1 = None
  131. self.attn = None
  132. else:
  133. self.norm1 = norm_layer(dim, **dd)
  134. self.attn = Attention(
  135. dim,
  136. num_heads=num_heads,
  137. head_dim_ratio=head_dim_ratio,
  138. attn_drop=attn_drop,
  139. proj_drop=proj_drop,
  140. **dd,
  141. )
  142. self.norm2 = norm_layer(dim, **dd)
  143. self.mlp = SpatialMlp(
  144. in_features=dim,
  145. hidden_features=int(dim * mlp_ratio),
  146. act_layer=act_layer,
  147. drop=proj_drop,
  148. group=group,
  149. spatial_conv=spatial_conv,
  150. **dd,
  151. )
  152. def forward(self, x):
  153. if self.attn is not None:
  154. x = x + self.drop_path(self.attn(self.norm1(x)))
  155. x = x + self.drop_path(self.mlp(self.norm2(x)))
  156. return x
  157. class Visformer(nn.Module):
  158. def __init__(
  159. self,
  160. img_size: int = 224,
  161. patch_size: int = 16,
  162. in_chans: int = 3,
  163. num_classes: int = 1000,
  164. init_channels: Optional[int] = 32,
  165. embed_dim: int = 384,
  166. depth: Union[int, tuple] = 12,
  167. num_heads: int = 6,
  168. mlp_ratio: float = 4.,
  169. drop_rate: float = 0.,
  170. pos_drop_rate: float = 0.,
  171. proj_drop_rate: float = 0.,
  172. attn_drop_rate: float = 0.,
  173. drop_path_rate: float = 0.,
  174. norm_layer: Type[nn.Module] = LayerNorm2d,
  175. attn_stage: str = '111',
  176. use_pos_embed: bool = True,
  177. spatial_conv: str = '111',
  178. vit_stem: bool = False,
  179. group: int = 8,
  180. global_pool: str = 'avg',
  181. conv_init: bool = False,
  182. embed_norm: Optional[Type[nn.Module]] = None,
  183. device=None,
  184. dtype=None,
  185. ):
  186. super().__init__()
  187. dd = {'device': device, 'dtype': dtype}
  188. img_size = to_2tuple(img_size)
  189. self.num_classes = num_classes
  190. self.in_chans = in_chans
  191. self.embed_dim = embed_dim
  192. self.init_channels = init_channels
  193. self.img_size = img_size
  194. self.vit_stem = vit_stem
  195. self.conv_init = conv_init
  196. if isinstance(depth, (list, tuple)):
  197. self.stage_num1, self.stage_num2, self.stage_num3 = depth
  198. depth = sum(depth)
  199. else:
  200. self.stage_num1 = self.stage_num3 = depth // 3
  201. self.stage_num2 = depth - self.stage_num1 - self.stage_num3
  202. self.use_pos_embed = use_pos_embed
  203. self.grad_checkpointing = False
  204. dpr = calculate_drop_path_rates(drop_path_rate, depth)
  205. # stage 1
  206. if self.vit_stem:
  207. self.stem = None
  208. self.patch_embed1 = PatchEmbed(
  209. img_size=img_size,
  210. patch_size=patch_size,
  211. in_chans=in_chans,
  212. embed_dim=embed_dim,
  213. norm_layer=embed_norm,
  214. flatten=False,
  215. **dd,
  216. )
  217. img_size = [x // patch_size for x in img_size]
  218. else:
  219. if self.init_channels is None:
  220. self.stem = None
  221. self.patch_embed1 = PatchEmbed(
  222. img_size=img_size,
  223. patch_size=patch_size // 2,
  224. in_chans=in_chans,
  225. embed_dim=embed_dim // 2,
  226. norm_layer=embed_norm,
  227. flatten=False,
  228. **dd,
  229. )
  230. img_size = [x // (patch_size // 2) for x in img_size]
  231. else:
  232. self.stem = nn.Sequential(
  233. nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False, **dd),
  234. nn.BatchNorm2d(self.init_channels, **dd),
  235. nn.ReLU(inplace=True)
  236. )
  237. img_size = [x // 2 for x in img_size]
  238. self.patch_embed1 = PatchEmbed(
  239. img_size=img_size,
  240. patch_size=patch_size // 4,
  241. in_chans=self.init_channels,
  242. embed_dim=embed_dim // 2,
  243. norm_layer=embed_norm,
  244. flatten=False,
  245. **dd,
  246. )
  247. img_size = [x // (patch_size // 4) for x in img_size]
  248. if self.use_pos_embed:
  249. if self.vit_stem:
  250. self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, *img_size, **dd))
  251. else:
  252. self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, *img_size, **dd))
  253. self.pos_drop = nn.Dropout(p=pos_drop_rate)
  254. else:
  255. self.pos_embed1 = None
  256. self.stage1 = nn.Sequential(*[
  257. Block(
  258. dim=embed_dim//2,
  259. num_heads=num_heads,
  260. head_dim_ratio=0.5,
  261. mlp_ratio=mlp_ratio,
  262. proj_drop=proj_drop_rate,
  263. attn_drop=attn_drop_rate,
  264. drop_path=dpr[i],
  265. norm_layer=norm_layer,
  266. group=group,
  267. attn_disabled=(attn_stage[0] == '0'),
  268. spatial_conv=(spatial_conv[0] == '1'),
  269. **dd,
  270. )
  271. for i in range(self.stage_num1)
  272. ])
  273. # stage2
  274. if not self.vit_stem:
  275. self.patch_embed2 = PatchEmbed(
  276. img_size=img_size,
  277. patch_size=patch_size // 8,
  278. in_chans=embed_dim // 2,
  279. embed_dim=embed_dim,
  280. norm_layer=embed_norm,
  281. flatten=False,
  282. **dd,
  283. )
  284. img_size = [x // (patch_size // 8) for x in img_size]
  285. if self.use_pos_embed:
  286. self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size, **dd))
  287. else:
  288. self.pos_embed2 = None
  289. else:
  290. self.patch_embed2 = None
  291. self.stage2 = nn.Sequential(*[
  292. Block(
  293. dim=embed_dim,
  294. num_heads=num_heads,
  295. head_dim_ratio=1.0,
  296. mlp_ratio=mlp_ratio,
  297. proj_drop=proj_drop_rate,
  298. attn_drop=attn_drop_rate,
  299. drop_path=dpr[i],
  300. norm_layer=norm_layer,
  301. group=group,
  302. attn_disabled=(attn_stage[1] == '0'),
  303. spatial_conv=(spatial_conv[1] == '1'),
  304. **dd,
  305. )
  306. for i in range(self.stage_num1, self.stage_num1+self.stage_num2)
  307. ])
  308. # stage 3
  309. if not self.vit_stem:
  310. self.patch_embed3 = PatchEmbed(
  311. img_size=img_size,
  312. patch_size=patch_size // 8,
  313. in_chans=embed_dim,
  314. embed_dim=embed_dim * 2,
  315. norm_layer=embed_norm,
  316. flatten=False,
  317. **dd,
  318. )
  319. img_size = [x // (patch_size // 8) for x in img_size]
  320. if self.use_pos_embed:
  321. self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size, **dd))
  322. else:
  323. self.pos_embed3 = None
  324. else:
  325. self.patch_embed3 = None
  326. self.stage3 = nn.Sequential(*[
  327. Block(
  328. dim=embed_dim * 2,
  329. num_heads=num_heads,
  330. head_dim_ratio=1.0,
  331. mlp_ratio=mlp_ratio,
  332. proj_drop=proj_drop_rate,
  333. attn_drop=attn_drop_rate,
  334. drop_path=dpr[i],
  335. norm_layer=norm_layer,
  336. group=group,
  337. attn_disabled=(attn_stage[2] == '0'),
  338. spatial_conv=(spatial_conv[2] == '1'),
  339. **dd,
  340. )
  341. for i in range(self.stage_num1+self.stage_num2, depth)
  342. ])
  343. self.num_features = self.head_hidden_size = embed_dim if self.vit_stem else embed_dim * 2
  344. self.norm = norm_layer(self.num_features, **dd)
  345. # head
  346. global_pool, head = create_classifier(
  347. self.num_features,
  348. self.num_classes,
  349. pool_type=global_pool,
  350. device=device,
  351. dtype=dtype,
  352. )
  353. self.global_pool = global_pool
  354. self.head_drop = nn.Dropout(drop_rate)
  355. self.head = head
  356. # weights init
  357. if self.use_pos_embed:
  358. trunc_normal_(self.pos_embed1, std=0.02)
  359. if not self.vit_stem:
  360. trunc_normal_(self.pos_embed2, std=0.02)
  361. trunc_normal_(self.pos_embed3, std=0.02)
  362. self.apply(self._init_weights)
  363. def _init_weights(self, m):
  364. if isinstance(m, nn.Linear):
  365. trunc_normal_(m.weight, std=0.02)
  366. if m.bias is not None:
  367. nn.init.constant_(m.bias, 0)
  368. elif isinstance(m, nn.Conv2d):
  369. if self.conv_init:
  370. nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  371. else:
  372. trunc_normal_(m.weight, std=0.02)
  373. if m.bias is not None:
  374. nn.init.constant_(m.bias, 0.)
  375. @torch.jit.ignore
  376. def group_matcher(self, coarse=False):
  377. return dict(
  378. stem=r'^patch_embed1|pos_embed1|stem', # stem and embed
  379. blocks=[
  380. (r'^stage(\d+)\.(\d+)' if coarse else r'^stage(\d+)\.(\d+)', None),
  381. (r'^(?:patch_embed|pos_embed)(\d+)', (0,)),
  382. (r'^norm', (99999,))
  383. ]
  384. )
  385. @torch.jit.ignore
  386. def set_grad_checkpointing(self, enable=True):
  387. self.grad_checkpointing = enable
  388. @torch.jit.ignore
  389. def get_classifier(self) -> nn.Module:
  390. return self.head
  391. def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
  392. self.num_classes = num_classes
  393. device = self.head.weight.device if hasattr(self.head, 'weight') else None
  394. dtype = self.head.weight.dtype if hasattr(self.head, 'weight') else None
  395. self.global_pool, self.head = create_classifier(
  396. self.num_features, self.num_classes, pool_type=global_pool, device=device, dtype=dtype)
  397. def forward_features(self, x):
  398. if self.stem is not None:
  399. x = self.stem(x)
  400. # stage 1
  401. x = self.patch_embed1(x)
  402. if self.pos_embed1 is not None:
  403. x = self.pos_drop(x + self.pos_embed1)
  404. if self.grad_checkpointing and not torch.jit.is_scripting():
  405. x = checkpoint_seq(self.stage1, x)
  406. else:
  407. x = self.stage1(x)
  408. # stage 2
  409. if self.patch_embed2 is not None:
  410. x = self.patch_embed2(x)
  411. if self.pos_embed2 is not None:
  412. x = self.pos_drop(x + self.pos_embed2)
  413. if self.grad_checkpointing and not torch.jit.is_scripting():
  414. x = checkpoint_seq(self.stage2, x)
  415. else:
  416. x = self.stage2(x)
  417. # stage3
  418. if self.patch_embed3 is not None:
  419. x = self.patch_embed3(x)
  420. if self.pos_embed3 is not None:
  421. x = self.pos_drop(x + self.pos_embed3)
  422. if self.grad_checkpointing and not torch.jit.is_scripting():
  423. x = checkpoint_seq(self.stage3, x)
  424. else:
  425. x = self.stage3(x)
  426. x = self.norm(x)
  427. return x
  428. def forward_head(self, x, pre_logits: bool = False):
  429. x = self.global_pool(x)
  430. x = self.head_drop(x)
  431. return x if pre_logits else self.head(x)
  432. def forward(self, x):
  433. x = self.forward_features(x)
  434. x = self.forward_head(x)
  435. return x
  436. def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs):
  437. if kwargs.get('features_only', None):
  438. raise RuntimeError('features_only not implemented for Vision Transformer models.')
  439. model = build_model_with_cfg(Visformer, variant, pretrained, **kwargs)
  440. return model
  441. def _cfg(url='', **kwargs):
  442. return {
  443. 'url': url,
  444. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  445. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  446. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  447. 'first_conv': 'stem.0', 'classifier': 'head',
  448. 'license': 'apache-2.0',
  449. **kwargs
  450. }
  451. default_cfgs = generate_default_cfgs({
  452. 'visformer_tiny.in1k': _cfg(hf_hub_id='timm/'),
  453. 'visformer_small.in1k': _cfg(hf_hub_id='timm/'),
  454. })
  455. @register_model
  456. def visformer_tiny(pretrained=False, **kwargs) -> Visformer:
  457. model_cfg = dict(
  458. init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8,
  459. attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
  460. embed_norm=nn.BatchNorm2d)
  461. model = _create_visformer('visformer_tiny', pretrained=pretrained, **dict(model_cfg, **kwargs))
  462. return model
  463. @register_model
  464. def visformer_small(pretrained=False, **kwargs) -> Visformer:
  465. model_cfg = dict(
  466. init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8,
  467. attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
  468. embed_norm=nn.BatchNorm2d)
  469. model = _create_visformer('visformer_small', pretrained=pretrained, **dict(model_cfg, **kwargs))
  470. return model
  471. # @register_model
  472. # def visformer_net1(pretrained=False, **kwargs):
  473. # model = Visformer(
  474. # init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111',
  475. # spatial_conv='000', vit_stem=True, conv_init=True, **kwargs)
  476. # model.default_cfg = _cfg()
  477. # return model
  478. #
  479. #
  480. # @register_model
  481. # def visformer_net2(pretrained=False, **kwargs):
  482. # model = Visformer(
  483. # init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111',
  484. # spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
  485. # model.default_cfg = _cfg()
  486. # return model
  487. #
  488. #
  489. # @register_model
  490. # def visformer_net3(pretrained=False, **kwargs):
  491. # model = Visformer(
  492. # init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111',
  493. # spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
  494. # model.default_cfg = _cfg()
  495. # return model
  496. #
  497. #
  498. # @register_model
  499. # def visformer_net4(pretrained=False, **kwargs):
  500. # model = Visformer(
  501. # init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111',
  502. # spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
  503. # model.default_cfg = _cfg()
  504. # return model
  505. #
  506. #
  507. # @register_model
  508. # def visformer_net5(pretrained=False, **kwargs):
  509. # model = Visformer(
  510. # init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111',
  511. # spatial_conv='111', vit_stem=False, conv_init=True, **kwargs)
  512. # model.default_cfg = _cfg()
  513. # return model
  514. #
  515. #
  516. # @register_model
  517. # def visformer_net6(pretrained=False, **kwargs):
  518. # model = Visformer(
  519. # init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111',
#         use_pos_embed=False, spatial_conv='111', conv_init=True, **kwargs)
  521. # model.default_cfg = _cfg()
  522. # return model
  523. #
  524. #
  525. # @register_model
  526. # def visformer_net7(pretrained=False, **kwargs):
  527. # model = Visformer(
  528. # init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000',
#         use_pos_embed=False, spatial_conv='111', conv_init=True, **kwargs)
  530. # model.default_cfg = _cfg()
  531. # return model