convit.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. """ ConViT Model
  2. @article{d2021convit,
  3. title={ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases},
  4. author={d'Ascoli, St{\'e}phane and Touvron, Hugo and Leavitt, Matthew and Morcos, Ari and Biroli, Giulio and Sagun, Levent},
  5. journal={arXiv preprint arXiv:2103.10697},
  6. year={2021}
  7. }
  8. Paper link: https://arxiv.org/abs/2103.10697
  9. Original code: https://github.com/facebookresearch/convit, original copyright below
  10. Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
  11. """
  12. # Copyright (c) 2015-present, Facebook, Inc.
  13. # All rights reserved.
  14. #
  15. # This source code is licensed under the CC-by-NC license found in the
  16. # LICENSE file in the root directory of this source tree.
  17. #
  18. '''These modules are adapted from those of timm, see
  19. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  20. '''
  21. from typing import Optional, Union, Type, Any
  22. import torch
  23. import torch.nn as nn
  24. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  25. from timm.layers import DropPath, calculate_drop_path_rates, trunc_normal_, PatchEmbed, Mlp, LayerNorm, HybridEmbed
  26. from ._builder import build_model_with_cfg
  27. from ._features_fx import register_notrace_module
  28. from ._registry import register_model, generate_default_cfgs
# Public API of this module: only the ConVit model class is exported by a star-import.
__all__ = ['ConVit']
  30. @register_notrace_module # reason: FX can't symbolically trace control flow in forward method
  31. class GPSA(nn.Module):
  32. def __init__(
  33. self,
  34. dim: int,
  35. num_heads: int = 8,
  36. qkv_bias: bool = False,
  37. attn_drop: float = 0.,
  38. proj_drop: float = 0.,
  39. locality_strength: float = 1.,
  40. device=None,
  41. dtype=None,
  42. ):
  43. dd = {'device': device, 'dtype': dtype}
  44. super().__init__()
  45. self.num_heads = num_heads
  46. self.dim = dim
  47. head_dim = dim // num_heads
  48. self.scale = head_dim ** -0.5
  49. self.locality_strength = locality_strength
  50. self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias, **dd)
  51. self.v = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  52. self.attn_drop = nn.Dropout(attn_drop)
  53. self.proj = nn.Linear(dim, dim, **dd)
  54. self.pos_proj = nn.Linear(3, num_heads, **dd)
  55. self.proj_drop = nn.Dropout(proj_drop)
  56. self.gating_param = nn.Parameter(torch.ones(self.num_heads, **dd))
  57. self.rel_indices: torch.Tensor = torch.zeros(1, 1, 1, 3, **dd) # silly torchscript hack, won't work with None
  58. def forward(self, x):
  59. B, N, C = x.shape
  60. if self.rel_indices is None or self.rel_indices.shape[1] != N:
  61. self.rel_indices = self.get_rel_indices(N)
  62. attn = self.get_attention(x)
  63. v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
  64. x = (attn @ v).transpose(1, 2).reshape(B, N, C)
  65. x = self.proj(x)
  66. x = self.proj_drop(x)
  67. return x
  68. def get_attention(self, x):
  69. B, N, C = x.shape
  70. qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  71. q, k = qk[0], qk[1]
  72. pos_score = self.rel_indices.expand(B, -1, -1, -1)
  73. pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2)
  74. patch_score = (q @ k.transpose(-2, -1)) * self.scale
  75. patch_score = patch_score.softmax(dim=-1)
  76. pos_score = pos_score.softmax(dim=-1)
  77. gating = self.gating_param.view(1, -1, 1, 1)
  78. attn = (1. - torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score
  79. attn /= attn.sum(dim=-1).unsqueeze(-1)
  80. attn = self.attn_drop(attn)
  81. return attn
  82. def get_attention_map(self, x, return_map=False):
  83. attn_map = self.get_attention(x).mean(0) # average over batch
  84. distances = self.rel_indices.squeeze()[:, :, -1] ** .5
  85. dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / distances.size(0)
  86. if return_map:
  87. return dist, attn_map
  88. else:
  89. return dist
  90. def local_init(self):
  91. self.v.weight.data.copy_(torch.eye(self.dim))
  92. locality_distance = 1 # max(1,1/locality_strength**.5)
  93. kernel_size = int(self.num_heads ** .5)
  94. center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2
  95. for h1 in range(kernel_size):
  96. for h2 in range(kernel_size):
  97. position = h1 + kernel_size * h2
  98. self.pos_proj.weight.data[position, 2] = -1
  99. self.pos_proj.weight.data[position, 1] = 2 * (h1 - center) * locality_distance
  100. self.pos_proj.weight.data[position, 0] = 2 * (h2 - center) * locality_distance
  101. self.pos_proj.weight.data *= self.locality_strength
  102. def get_rel_indices(self, num_patches: int) -> torch.Tensor:
  103. img_size = int(num_patches ** .5)
  104. rel_indices = torch.zeros(1, num_patches, num_patches, 3)
  105. ind = (
  106. torch.arange(img_size, dtype=torch.float32).view(1, -1)
  107. - torch.arange(img_size, dtype=torch.float32).view(-1, 1)
  108. )
  109. indx = ind.repeat(img_size, img_size)
  110. indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
  111. indd = indx ** 2 + indy ** 2
  112. rel_indices[:, :, :, 2] = indd.unsqueeze(0)
  113. rel_indices[:, :, :, 1] = indy.unsqueeze(0)
  114. rel_indices[:, :, :, 0] = indx.unsqueeze(0)
  115. device = self.qk.weight.device
  116. dtype = self.qk.weight.dtype
  117. return rel_indices.to(device=device, dtype=dtype)
  118. class MHSA(nn.Module):
  119. def __init__(
  120. self,
  121. dim: int,
  122. num_heads: int = 8,
  123. qkv_bias: bool = False,
  124. attn_drop: float = 0.,
  125. proj_drop: float = 0.,
  126. device=None,
  127. dtype=None,
  128. ):
  129. dd = {'device': device, 'dtype': dtype}
  130. super().__init__()
  131. self.num_heads = num_heads
  132. head_dim = dim // num_heads
  133. self.scale = head_dim ** -0.5
  134. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
  135. self.attn_drop = nn.Dropout(attn_drop)
  136. self.proj = nn.Linear(dim, dim, **dd)
  137. self.proj_drop = nn.Dropout(proj_drop)
  138. def get_attention_map(self, x, return_map=False):
  139. B, N, C = x.shape
  140. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  141. q, k, v = qkv[0], qkv[1], qkv[2]
  142. attn_map = (q @ k.transpose(-2, -1)) * self.scale
  143. attn_map = attn_map.softmax(dim=-1).mean(0)
  144. img_size = int(N ** .5)
  145. ind = (
  146. torch.arange(img_size, dtype=torch.float32).view(1, -1)
  147. - torch.arange(img_size, dtype=torch.float32).view(-1, 1)
  148. )
  149. indx = ind.repeat(img_size, img_size)
  150. indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
  151. indd = indx ** 2 + indy ** 2
  152. distances = indd ** .5
  153. distances = distances.to(attn_map.device, attn_map.dtype)
  154. dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / N
  155. if return_map:
  156. return dist, attn_map
  157. else:
  158. return dist
  159. def forward(self, x):
  160. B, N, C = x.shape
  161. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  162. q, k, v = qkv.unbind(0)
  163. attn = (q @ k.transpose(-2, -1)) * self.scale
  164. attn = attn.softmax(dim=-1)
  165. attn = self.attn_drop(attn)
  166. x = (attn @ v).transpose(1, 2).reshape(B, N, C)
  167. x = self.proj(x)
  168. x = self.proj_drop(x)
  169. return x
  170. class Block(nn.Module):
  171. def __init__(
  172. self,
  173. dim: int,
  174. num_heads: int,
  175. mlp_ratio: float = 4.,
  176. qkv_bias: bool = False,
  177. proj_drop: float = 0.,
  178. attn_drop: float = 0.,
  179. drop_path: float = 0.,
  180. act_layer: Type[nn.Module] = nn.GELU,
  181. norm_layer: Type[nn.Module] = LayerNorm,
  182. use_gpsa: bool = True,
  183. locality_strength: float = 1.,
  184. device=None,
  185. dtype=None,
  186. ):
  187. dd = {'device': device, 'dtype': dtype}
  188. super().__init__()
  189. self.norm1 = norm_layer(dim, **dd)
  190. self.use_gpsa = use_gpsa
  191. if self.use_gpsa:
  192. self.attn = GPSA(
  193. dim,
  194. num_heads=num_heads,
  195. qkv_bias=qkv_bias,
  196. attn_drop=attn_drop,
  197. proj_drop=proj_drop,
  198. locality_strength=locality_strength,
  199. **dd,
  200. )
  201. else:
  202. self.attn = MHSA(
  203. dim,
  204. num_heads=num_heads,
  205. qkv_bias=qkv_bias,
  206. attn_drop=attn_drop,
  207. proj_drop=proj_drop,
  208. **dd,
  209. )
  210. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  211. self.norm2 = norm_layer(dim, **dd)
  212. mlp_hidden_dim = int(dim * mlp_ratio)
  213. self.mlp = Mlp(
  214. in_features=dim,
  215. hidden_features=mlp_hidden_dim,
  216. act_layer=act_layer,
  217. drop=proj_drop,
  218. **dd,
  219. )
  220. def forward(self, x):
  221. x = x + self.drop_path(self.attn(self.norm1(x)))
  222. x = x + self.drop_path(self.mlp(self.norm2(x)))
  223. return x
  224. class ConVit(nn.Module):
  225. """ Vision Transformer with support for patch or hybrid CNN input stage
  226. """
  227. def __init__(
  228. self,
  229. img_size: int = 224,
  230. patch_size: int = 16,
  231. in_chans: int = 3,
  232. num_classes: int = 1000,
  233. global_pool: str = 'token',
  234. embed_dim: int = 768,
  235. depth: int = 12,
  236. num_heads: int = 12,
  237. mlp_ratio: float = 4.,
  238. qkv_bias: bool = False,
  239. drop_rate: float = 0.,
  240. pos_drop_rate: float = 0.,
  241. proj_drop_rate: float = 0.,
  242. attn_drop_rate: float = 0.,
  243. drop_path_rate: float = 0.,
  244. hybrid_backbone: Optional[Any] = None,
  245. norm_layer: Type[nn.Module] = LayerNorm,
  246. local_up_to_layer: int = 3,
  247. locality_strength: float = 1.,
  248. use_pos_embed: bool = True,
  249. device=None,
  250. dtype=None,
  251. ):
  252. super().__init__()
  253. dd = {'device': device, 'dtype': dtype}
  254. assert global_pool in ('', 'avg', 'token')
  255. embed_dim *= num_heads
  256. self.num_classes = num_classes
  257. self.in_chans = in_chans
  258. self.global_pool = global_pool
  259. self.local_up_to_layer = local_up_to_layer
  260. self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
  261. self.locality_strength = locality_strength
  262. self.use_pos_embed = use_pos_embed
  263. if hybrid_backbone is not None:
  264. self.patch_embed = HybridEmbed(
  265. hybrid_backbone,
  266. img_size=img_size,
  267. in_chans=in_chans,
  268. embed_dim=embed_dim,
  269. **dd,
  270. )
  271. else:
  272. self.patch_embed = PatchEmbed(
  273. img_size=img_size,
  274. patch_size=patch_size,
  275. in_chans=in_chans,
  276. embed_dim=embed_dim,
  277. **dd,
  278. )
  279. num_patches = self.patch_embed.num_patches
  280. self.num_patches = num_patches
  281. self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd))
  282. self.pos_drop = nn.Dropout(p=pos_drop_rate)
  283. if self.use_pos_embed:
  284. self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim, **dd))
  285. trunc_normal_(self.pos_embed, std=.02)
  286. dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule
  287. self.blocks = nn.ModuleList([
  288. Block(
  289. dim=embed_dim,
  290. num_heads=num_heads,
  291. mlp_ratio=mlp_ratio,
  292. qkv_bias=qkv_bias,
  293. proj_drop=proj_drop_rate,
  294. attn_drop=attn_drop_rate,
  295. drop_path=dpr[i],
  296. norm_layer=norm_layer,
  297. use_gpsa=i < local_up_to_layer,
  298. locality_strength=locality_strength,
  299. **dd,
  300. ) for i in range(depth)])
  301. self.norm = norm_layer(embed_dim, **dd)
  302. # Classifier head
  303. self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
  304. self.head_drop = nn.Dropout(drop_rate)
  305. self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()
  306. trunc_normal_(self.cls_token, std=.02)
  307. self.apply(self._init_weights)
  308. for n, m in self.named_modules():
  309. if hasattr(m, 'local_init'):
  310. m.local_init()
  311. def _init_weights(self, m):
  312. if isinstance(m, nn.Linear):
  313. trunc_normal_(m.weight, std=.02)
  314. if isinstance(m, nn.Linear) and m.bias is not None:
  315. nn.init.constant_(m.bias, 0)
  316. elif isinstance(m, nn.LayerNorm):
  317. nn.init.constant_(m.bias, 0)
  318. nn.init.constant_(m.weight, 1.0)
  319. @torch.jit.ignore
  320. def no_weight_decay(self):
  321. return {'pos_embed', 'cls_token'}
  322. @torch.jit.ignore
  323. def group_matcher(self, coarse=False):
  324. return dict(
  325. stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
  326. blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
  327. )
  328. @torch.jit.ignore
  329. def set_grad_checkpointing(self, enable=True):
  330. assert not enable, 'gradient checkpointing not supported'
  331. @torch.jit.ignore
  332. def get_classifier(self) -> nn.Module:
  333. return self.head
  334. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  335. self.num_classes = num_classes
  336. if global_pool is not None:
  337. assert global_pool in ('', 'token', 'avg')
  338. self.global_pool = global_pool
  339. self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
  340. def forward_features(self, x):
  341. x = self.patch_embed(x)
  342. if self.use_pos_embed:
  343. x = x + self.pos_embed
  344. x = self.pos_drop(x)
  345. cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
  346. for u, blk in enumerate(self.blocks):
  347. if u == self.local_up_to_layer:
  348. x = torch.cat((cls_tokens, x), dim=1)
  349. x = blk(x)
  350. x = self.norm(x)
  351. return x
  352. def forward_head(self, x, pre_logits: bool = False):
  353. if self.global_pool:
  354. x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
  355. x = self.head_drop(x)
  356. return x if pre_logits else self.head(x)
  357. def forward(self, x):
  358. x = self.forward_features(x)
  359. x = self.forward_head(x)
  360. return x
  361. def _create_convit(variant, pretrained=False, **kwargs):
  362. if kwargs.get('features_only', None):
  363. raise RuntimeError('features_only not implemented for Vision Transformer models.')
  364. return build_model_with_cfg(ConVit, variant, pretrained, **kwargs)
  365. def _cfg(url='', **kwargs):
  366. return {
  367. 'url': url,
  368. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  369. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
  370. 'first_conv': 'patch_embed.proj', 'classifier': 'head', 'license': 'apache-2.0',
  371. **kwargs
  372. }
  373. default_cfgs = generate_default_cfgs({
  374. # ConViT
  375. 'convit_tiny.fb_in1k': _cfg(hf_hub_id='timm/'),
  376. 'convit_small.fb_in1k': _cfg(hf_hub_id='timm/'),
  377. 'convit_base.fb_in1k': _cfg(hf_hub_id='timm/')
  378. })
  379. @register_model
  380. def convit_tiny(pretrained=False, **kwargs) -> ConVit:
  381. model_args = dict(
  382. local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=4)
  383. model = _create_convit(variant='convit_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
  384. return model
  385. @register_model
  386. def convit_small(pretrained=False, **kwargs) -> ConVit:
  387. model_args = dict(
  388. local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=9)
  389. model = _create_convit(variant='convit_small', pretrained=pretrained, **dict(model_args, **kwargs))
  390. return model
  391. @register_model
  392. def convit_base(pretrained=False, **kwargs) -> ConVit:
  393. model_args = dict(
  394. local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=16)
  395. model = _create_convit(variant='convit_base', pretrained=pretrained, **dict(model_args, **kwargs))
  396. return model