gcvit.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686
""" Global Context ViT

From-scratch implementation of GCViT in the style of timm swin_transformer_v2_cr.py

Global Context Vision Transformers - https://arxiv.org/abs/2206.09959

@article{hatamizadeh2022global,
  title={Global Context Vision Transformers},
  author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
  journal={arXiv preprint arXiv:2206.09959},
  year={2022}
}

Free of any code related to NVIDIA GCVit impl at https://github.com/NVlabs/GCVit.
The license for this code release is Apache 2.0 with no commercial restrictions.

However, weight files adapted from NVIDIA GCVit impl ARE under a non-commercial share-alike license
(https://creativecommons.org/licenses/by-nc-sa/4.0/) until I have a chance to train new ones...

Hacked together by / Copyright 2022, Ross Wightman
"""
  16. import math
  17. from functools import partial
  18. from typing import Callable, List, Optional, Tuple, Type, Union
  19. import torch
  20. import torch.nn as nn
  21. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  22. from timm.layers import (
  23. DropPath,
  24. calculate_drop_path_rates,
  25. to_2tuple,
  26. to_ntuple,
  27. Mlp,
  28. ClassifierHead,
  29. LayerNorm2d,
  30. LayerScale,
  31. get_attn,
  32. get_act_layer,
  33. get_norm_layer,
  34. RelPosBias,
  35. _assert,
  36. )
  37. from ._builder import build_model_with_cfg
  38. from ._features import feature_take_indices
  39. from ._features_fx import register_notrace_function
  40. from ._manipulate import named_apply, checkpoint
  41. from ._registry import register_model, generate_default_cfgs
  42. __all__ = ['GlobalContextVit']
  43. class MbConvBlock(nn.Module):
  44. """ A depthwise separable / fused mbconv style residual block with SE, `no norm.
  45. """
  46. def __init__(
  47. self,
  48. in_chs: int,
  49. out_chs: Optional[int] = None,
  50. expand_ratio: float = 1.0,
  51. attn_layer: str = 'se',
  52. bias: bool = False,
  53. act_layer: Type[nn.Module] = nn.GELU,
  54. device=None,
  55. dtype=None,
  56. ):
  57. dd = {'device': device, 'dtype': dtype}
  58. super().__init__()
  59. attn_kwargs = dict(act_layer=act_layer, **dd)
  60. if isinstance(attn_layer, str) and attn_layer == 'se' or attn_layer == 'eca':
  61. attn_kwargs['rd_ratio'] = 0.25
  62. attn_kwargs['bias'] = False
  63. attn_layer = get_attn(attn_layer)
  64. out_chs = out_chs or in_chs
  65. mid_chs = int(expand_ratio * in_chs)
  66. self.conv_dw = nn.Conv2d(in_chs, mid_chs, 3, 1, 1, groups=in_chs, bias=bias, **dd)
  67. self.act = act_layer()
  68. self.se = attn_layer(mid_chs, **attn_kwargs)
  69. self.conv_pw = nn.Conv2d(mid_chs, out_chs, 1, 1, 0, bias=bias, **dd)
  70. def forward(self, x):
  71. shortcut = x
  72. x = self.conv_dw(x)
  73. x = self.act(x)
  74. x = self.se(x)
  75. x = self.conv_pw(x)
  76. x = x + shortcut
  77. return x
  78. class Downsample2d(nn.Module):
  79. def __init__(
  80. self,
  81. dim: int,
  82. dim_out: Optional[int] = None,
  83. reduction: str = 'conv',
  84. act_layer: Type[nn.Module] = nn.GELU,
  85. norm_layer: Type[nn.Module] = LayerNorm2d, # NOTE in NCHW
  86. device=None,
  87. dtype=None,
  88. ):
  89. dd = {'device': device, 'dtype': dtype}
  90. super().__init__()
  91. dim_out = dim_out or dim
  92. self.norm1 = norm_layer(dim, **dd) if norm_layer is not None else nn.Identity()
  93. self.conv_block = MbConvBlock(dim, act_layer=act_layer, **dd)
  94. assert reduction in ('conv', 'max', 'avg')
  95. if reduction == 'conv':
  96. self.reduction = nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False, **dd)
  97. elif reduction == 'max':
  98. assert dim == dim_out
  99. self.reduction = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  100. else:
  101. assert dim == dim_out
  102. self.reduction = nn.AvgPool2d(kernel_size=2)
  103. self.norm2 = norm_layer(dim_out, **dd) if norm_layer is not None else nn.Identity()
  104. def forward(self, x):
  105. x = self.norm1(x)
  106. x = self.conv_block(x)
  107. x = self.reduction(x)
  108. x = self.norm2(x)
  109. return x
  110. class FeatureBlock(nn.Module):
  111. def __init__(
  112. self,
  113. dim: int,
  114. levels: int = 0,
  115. reduction: str = 'max',
  116. act_layer: Type[nn.Module] = nn.GELU,
  117. device=None,
  118. dtype=None,
  119. ):
  120. dd = {'device': device, 'dtype': dtype}
  121. super().__init__()
  122. reductions = levels
  123. levels = max(1, levels)
  124. if reduction == 'avg':
  125. pool_fn = partial(nn.AvgPool2d, kernel_size=2)
  126. else:
  127. pool_fn = partial(nn.MaxPool2d, kernel_size=3, stride=2, padding=1)
  128. self.blocks = nn.Sequential()
  129. for i in range(levels):
  130. self.blocks.add_module(f'conv{i+1}', MbConvBlock(dim, act_layer=act_layer, **dd))
  131. if reductions:
  132. self.blocks.add_module(f'pool{i+1}', pool_fn())
  133. reductions -= 1
  134. def forward(self, x):
  135. return self.blocks(x)
  136. class Stem(nn.Module):
  137. def __init__(
  138. self,
  139. in_chs: int = 3,
  140. out_chs: int = 96,
  141. act_layer: Type[nn.Module] = nn.GELU,
  142. norm_layer: Type[nn.Module] = LayerNorm2d, # NOTE stem in NCHW
  143. device=None,
  144. dtype=None,
  145. ):
  146. super().__init__()
  147. dd = {'device': device, 'dtype': dtype}
  148. self.conv1 = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1, **dd)
  149. self.down = Downsample2d(out_chs, act_layer=act_layer, norm_layer=norm_layer, **dd)
  150. def forward(self, x):
  151. x = self.conv1(x)
  152. x = self.down(x)
  153. return x
  154. class WindowAttentionGlobal(nn.Module):
  155. def __init__(
  156. self,
  157. dim: int,
  158. num_heads: int,
  159. window_size: Tuple[int, int],
  160. use_global: bool = True,
  161. qkv_bias: bool = True,
  162. attn_drop: float = 0.,
  163. proj_drop: float = 0.,
  164. device=None,
  165. dtype=None,
  166. ):
  167. dd = {'device': device, 'dtype': dtype}
  168. super().__init__()
  169. window_size = to_2tuple(window_size)
  170. self.window_size = window_size
  171. self.num_heads = num_heads
  172. self.head_dim = dim // num_heads
  173. self.scale = self.head_dim ** -0.5
  174. self.use_global = use_global
  175. self.rel_pos = RelPosBias(window_size=window_size, num_heads=num_heads, **dd)
  176. if self.use_global:
  177. self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias, **dd)
  178. else:
  179. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
  180. self.attn_drop = nn.Dropout(attn_drop)
  181. self.proj = nn.Linear(dim, dim, **dd)
  182. self.proj_drop = nn.Dropout(proj_drop)
  183. def forward(self, x, q_global: Optional[torch.Tensor] = None):
  184. B, N, C = x.shape
  185. if self.use_global and q_global is not None:
  186. _assert(x.shape[-1] == q_global.shape[-1], 'x and q_global seq lengths should be equal')
  187. kv = self.qkv(x)
  188. kv = kv.reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
  189. k, v = kv.unbind(0)
  190. q = q_global.repeat(B // q_global.shape[0], 1, 1, 1)
  191. q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
  192. else:
  193. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
  194. q, k, v = qkv.unbind(0)
  195. q = q * self.scale
  196. attn = q @ k.transpose(-2, -1).contiguous() # NOTE contiguous() fixes an odd jit bug in PyTorch 2.0
  197. attn = self.rel_pos(attn)
  198. attn = attn.softmax(dim=-1)
  199. attn = self.attn_drop(attn)
  200. x = (attn @ v).transpose(1, 2).reshape(B, N, C)
  201. x = self.proj(x)
  202. x = self.proj_drop(x)
  203. return x
  204. def window_partition(x, window_size: Tuple[int, int]):
  205. B, H, W, C = x.shape
  206. x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
  207. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
  208. return windows
  209. @register_notrace_function # reason: int argument is a Proxy
  210. def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
  211. H, W = img_size
  212. C = windows.shape[-1]
  213. x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
  214. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
  215. return x
  216. class GlobalContextVitBlock(nn.Module):
  217. def __init__(
  218. self,
  219. dim: int,
  220. feat_size: Tuple[int, int],
  221. num_heads: int,
  222. window_size: int = 7,
  223. mlp_ratio: float = 4.,
  224. use_global: bool = True,
  225. qkv_bias: bool = True,
  226. layer_scale: Optional[float] = None,
  227. proj_drop: float = 0.,
  228. attn_drop: float = 0.,
  229. drop_path: float = 0.,
  230. attn_layer: Callable = WindowAttentionGlobal,
  231. act_layer: Type[nn.Module] = nn.GELU,
  232. norm_layer: Type[nn.Module] = nn.LayerNorm,
  233. device=None,
  234. dtype=None,
  235. ):
  236. dd = {'device': device, 'dtype': dtype}
  237. super().__init__()
  238. feat_size = to_2tuple(feat_size)
  239. window_size = to_2tuple(window_size)
  240. self.window_size = window_size
  241. self.num_windows = int((feat_size[0] // window_size[0]) * (feat_size[1] // window_size[1]))
  242. self.norm1 = norm_layer(dim, **dd)
  243. self.attn = attn_layer(
  244. dim,
  245. num_heads=num_heads,
  246. window_size=window_size,
  247. use_global=use_global,
  248. qkv_bias=qkv_bias,
  249. attn_drop=attn_drop,
  250. proj_drop=proj_drop,
  251. **dd,
  252. )
  253. self.ls1 = LayerScale(dim, layer_scale, **dd) if layer_scale is not None else nn.Identity()
  254. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  255. self.norm2 = norm_layer(dim, **dd)
  256. self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop, **dd)
  257. self.ls2 = LayerScale(dim, layer_scale, **dd) if layer_scale is not None else nn.Identity()
  258. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  259. def _window_attn(self, x, q_global: Optional[torch.Tensor] = None):
  260. B, H, W, C = x.shape
  261. x_win = window_partition(x, self.window_size)
  262. x_win = x_win.view(-1, self.window_size[0] * self.window_size[1], C)
  263. attn_win = self.attn(x_win, q_global)
  264. x = window_reverse(attn_win, self.window_size, (H, W))
  265. return x
  266. def forward(self, x, q_global: Optional[torch.Tensor] = None):
  267. x = x + self.drop_path1(self.ls1(self._window_attn(self.norm1(x), q_global)))
  268. x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
  269. return x
  270. class GlobalContextVitStage(nn.Module):
  271. def __init__(
  272. self,
  273. dim: int,
  274. depth: int,
  275. num_heads: int,
  276. feat_size: Tuple[int, int],
  277. window_size: Tuple[int, int],
  278. downsample: bool = True,
  279. global_norm: bool = False,
  280. stage_norm: bool = False,
  281. mlp_ratio: float = 4.,
  282. qkv_bias: bool = True,
  283. layer_scale: Optional[float] = None,
  284. proj_drop: float = 0.,
  285. attn_drop: float = 0.,
  286. drop_path: Union[List[float], float] = 0.0,
  287. act_layer: Type[nn.Module] = nn.GELU,
  288. norm_layer: Type[nn.Module] = nn.LayerNorm,
  289. norm_layer_cl: Type[nn.Module] = LayerNorm2d,
  290. device=None,
  291. dtype=None,
  292. ):
  293. dd = {'device': device, 'dtype': dtype}
  294. super().__init__()
  295. if downsample:
  296. self.downsample = Downsample2d(
  297. dim=dim,
  298. dim_out=dim * 2,
  299. norm_layer=norm_layer,
  300. **dd,
  301. )
  302. dim = dim * 2
  303. feat_size = (feat_size[0] // 2, feat_size[1] // 2)
  304. else:
  305. self.downsample = nn.Identity()
  306. self.feat_size = feat_size
  307. window_size = to_2tuple(window_size)
  308. feat_levels = int(math.log2(min(feat_size) / min(window_size)))
  309. self.global_block = FeatureBlock(dim, feat_levels, **dd)
  310. self.global_norm = norm_layer_cl(dim, **dd) if global_norm else nn.Identity()
  311. self.blocks = nn.ModuleList([
  312. GlobalContextVitBlock(
  313. dim=dim,
  314. num_heads=num_heads,
  315. feat_size=feat_size,
  316. window_size=window_size,
  317. mlp_ratio=mlp_ratio,
  318. qkv_bias=qkv_bias,
  319. use_global=(i % 2 != 0),
  320. layer_scale=layer_scale,
  321. proj_drop=proj_drop,
  322. attn_drop=attn_drop,
  323. drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
  324. act_layer=act_layer,
  325. norm_layer=norm_layer_cl,
  326. **dd,
  327. )
  328. for i in range(depth)
  329. ])
  330. self.norm = norm_layer_cl(dim, **dd) if stage_norm else nn.Identity()
  331. self.dim = dim
  332. self.feat_size = feat_size
  333. self.grad_checkpointing = False
  334. def forward(self, x):
  335. # input NCHW, downsample & global block are 2d conv + pooling
  336. x = self.downsample(x)
  337. global_query = self.global_block(x)
  338. # reshape NCHW --> NHWC for transformer blocks
  339. x = x.permute(0, 2, 3, 1)
  340. global_query = self.global_norm(global_query.permute(0, 2, 3, 1))
  341. for blk in self.blocks:
  342. if self.grad_checkpointing and not torch.jit.is_scripting():
  343. x = checkpoint(blk, x, global_query)
  344. else:
  345. x = blk(x, global_query)
  346. x = self.norm(x)
  347. x = x.permute(0, 3, 1, 2).contiguous() # back to NCHW
  348. return x
  349. class GlobalContextVit(nn.Module):
  350. def __init__(
  351. self,
  352. in_chans: int = 3,
  353. num_classes: int = 1000,
  354. global_pool: str = 'avg',
  355. img_size: Union[int, Tuple[int, int]] = 224,
  356. window_ratio: Tuple[int, ...] = (32, 32, 16, 32),
  357. window_size: Optional[Union[int, Tuple[int, ...]]] = None,
  358. embed_dim: int = 64,
  359. depths: Tuple[int, ...] = (3, 4, 19, 5),
  360. num_heads: Tuple[int, ...] = (2, 4, 8, 16),
  361. mlp_ratio: float = 3.0,
  362. qkv_bias: bool = True,
  363. layer_scale: Optional[float] = None,
  364. drop_rate: float = 0.,
  365. proj_drop_rate: float = 0.,
  366. attn_drop_rate: float = 0.,
  367. drop_path_rate: float = 0.,
  368. weight_init: str = '',
  369. act_layer: str = 'gelu',
  370. norm_layer: str = 'layernorm2d',
  371. norm_layer_cl: str = 'layernorm',
  372. norm_eps: float = 1e-5,
  373. device=None,
  374. dtype=None,
  375. ):
  376. super().__init__()
  377. dd = {'device': device, 'dtype': dtype}
  378. act_layer = get_act_layer(act_layer)
  379. norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
  380. norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
  381. self.feature_info = []
  382. img_size = to_2tuple(img_size)
  383. feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4
  384. self.global_pool = global_pool
  385. self.num_classes = num_classes
  386. self.in_chans = in_chans
  387. self.drop_rate = drop_rate
  388. num_stages = len(depths)
  389. self.num_features = self.head_hidden_size = int(embed_dim * 2 ** (num_stages - 1))
  390. if window_size is not None:
  391. window_size = to_ntuple(num_stages)(window_size)
  392. else:
  393. assert window_ratio is not None
  394. window_size = tuple([(img_size[0] // r, img_size[1] // r) for r in to_ntuple(num_stages)(window_ratio)])
  395. self.stem = Stem(
  396. in_chs=in_chans,
  397. out_chs=embed_dim,
  398. act_layer=act_layer,
  399. norm_layer=norm_layer,
  400. **dd,
  401. )
  402. dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
  403. stages = []
  404. for i in range(num_stages):
  405. last_stage = i == num_stages - 1
  406. stage_scale = 2 ** max(i - 1, 0)
  407. stages.append(GlobalContextVitStage(
  408. dim=embed_dim * stage_scale,
  409. depth=depths[i],
  410. num_heads=num_heads[i],
  411. feat_size=(feat_size[0] // stage_scale, feat_size[1] // stage_scale),
  412. window_size=window_size[i],
  413. downsample=i != 0,
  414. stage_norm=last_stage,
  415. mlp_ratio=mlp_ratio,
  416. qkv_bias=qkv_bias,
  417. layer_scale=layer_scale,
  418. proj_drop=proj_drop_rate,
  419. attn_drop=attn_drop_rate,
  420. drop_path=dpr[i],
  421. act_layer=act_layer,
  422. norm_layer=norm_layer,
  423. norm_layer_cl=norm_layer_cl,
  424. **dd,
  425. ))
  426. self.feature_info += [dict(num_chs=stages[-1].dim, reduction=2**(i+2), module=f'stages.{i}')]
  427. self.stages = nn.Sequential(*stages)
  428. # Classifier head
  429. self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate, **dd)
  430. if weight_init:
  431. named_apply(partial(self._init_weights, scheme=weight_init), self)
  432. def _init_weights(self, module, name, scheme='vit'):
  433. # note Conv2d left as default init
  434. if scheme == 'vit':
  435. if isinstance(module, nn.Linear):
  436. nn.init.xavier_uniform_(module.weight)
  437. if module.bias is not None:
  438. if 'mlp' in name:
  439. nn.init.normal_(module.bias, std=1e-6)
  440. else:
  441. nn.init.zeros_(module.bias)
  442. else:
  443. if isinstance(module, nn.Linear):
  444. nn.init.normal_(module.weight, std=.02)
  445. if module.bias is not None:
  446. nn.init.zeros_(module.bias)
  447. @torch.jit.ignore
  448. def no_weight_decay(self):
  449. return {
  450. k for k, _ in self.named_parameters()
  451. if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])}
  452. @torch.jit.ignore
  453. def group_matcher(self, coarse=False):
  454. matcher = dict(
  455. stem=r'^stem', # stem and embed
  456. blocks=r'^stages\.(\d+)'
  457. )
  458. return matcher
  459. @torch.jit.ignore
  460. def set_grad_checkpointing(self, enable=True):
  461. for s in self.stages:
  462. s.grad_checkpointing = enable
  463. @torch.jit.ignore
  464. def get_classifier(self) -> nn.Module:
  465. return self.head.fc
  466. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, device=None, dtype=None):
  467. dd = {'device': device, 'dtype': dtype}
  468. self.num_classes = num_classes
  469. if global_pool is None:
  470. global_pool = self.head.global_pool.pool_type
  471. self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, **dd)
  472. def forward_intermediates(
  473. self,
  474. x: torch.Tensor,
  475. indices: Optional[Union[int, List[int]]] = None,
  476. norm: bool = False,
  477. stop_early: bool = False,
  478. output_fmt: str = 'NCHW',
  479. intermediates_only: bool = False,
  480. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  481. """ Forward features that returns intermediates.
  482. Args:
  483. x: Input image tensor
  484. indices: Take last n blocks if int, all if None, select matching indices if sequence
  485. norm: Apply norm layer to compatible intermediates
  486. stop_early: Stop iterating over blocks when last desired intermediate hit
  487. output_fmt: Shape of intermediate feature outputs
  488. intermediates_only: Only return intermediate features
  489. Returns:
  490. """
  491. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  492. intermediates = []
  493. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  494. # forward pass
  495. x = self.stem(x)
  496. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  497. stages = self.stages
  498. else:
  499. stages = self.stages[:max_index + 1]
  500. for feat_idx, stage in enumerate(stages):
  501. x = stage(x)
  502. if feat_idx in take_indices:
  503. intermediates.append(x)
  504. if intermediates_only:
  505. return intermediates
  506. return x, intermediates
  507. def prune_intermediate_layers(
  508. self,
  509. indices: Union[int, List[int]] = 1,
  510. prune_norm: bool = False,
  511. prune_head: bool = True,
  512. ):
  513. """ Prune layers not required for specified intermediates.
  514. """
  515. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  516. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  517. if prune_head:
  518. self.reset_classifier(0, '')
  519. return take_indices
  520. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  521. x = self.stem(x)
  522. x = self.stages(x)
  523. return x
  524. def forward_head(self, x, pre_logits: bool = False):
  525. return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  526. def forward(self, x: torch.Tensor) -> torch.Tensor:
  527. x = self.forward_features(x)
  528. x = self.forward_head(x)
  529. return x
  530. def _create_gcvit(variant, pretrained=False, **kwargs):
  531. model = build_model_with_cfg(
  532. GlobalContextVit, variant, pretrained,
  533. feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
  534. **kwargs
  535. )
  536. return model
  537. def _cfg(url='', **kwargs):
  538. return {
  539. 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  540. 'crop_pct': 0.875, 'interpolation': 'bicubic',
  541. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  542. 'first_conv': 'stem.conv1', 'classifier': 'head.fc',
  543. 'fixed_input_size': True,
  544. 'license': 'apache-2.0',
  545. **kwargs
  546. }
  547. default_cfgs = generate_default_cfgs({
  548. 'gcvit_xxtiny.in1k': _cfg(
  549. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_xxtiny_224_nvidia-d1d86009.pth'),
  550. 'gcvit_xtiny.in1k': _cfg(
  551. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_xtiny_224_nvidia-274b92b7.pth'),
  552. 'gcvit_tiny.in1k': _cfg(
  553. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_tiny_224_nvidia-ac783954.pth'),
  554. 'gcvit_small.in1k': _cfg(
  555. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_small_224_nvidia-4e98afa2.pth'),
  556. 'gcvit_base.in1k': _cfg(
  557. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_base_224_nvidia-f009139b.pth'),
  558. })
  559. @register_model
  560. def gcvit_xxtiny(pretrained=False, **kwargs) -> GlobalContextVit:
  561. model_kwargs = dict(
  562. depths=(2, 2, 6, 2),
  563. num_heads=(2, 4, 8, 16),
  564. **kwargs)
  565. return _create_gcvit('gcvit_xxtiny', pretrained=pretrained, **model_kwargs)
  566. @register_model
  567. def gcvit_xtiny(pretrained=False, **kwargs) -> GlobalContextVit:
  568. model_kwargs = dict(
  569. depths=(3, 4, 6, 5),
  570. num_heads=(2, 4, 8, 16),
  571. **kwargs)
  572. return _create_gcvit('gcvit_xtiny', pretrained=pretrained, **model_kwargs)
  573. @register_model
  574. def gcvit_tiny(pretrained=False, **kwargs) -> GlobalContextVit:
  575. model_kwargs = dict(
  576. depths=(3, 4, 19, 5),
  577. num_heads=(2, 4, 8, 16),
  578. **kwargs)
  579. return _create_gcvit('gcvit_tiny', pretrained=pretrained, **model_kwargs)
  580. @register_model
  581. def gcvit_small(pretrained=False, **kwargs) -> GlobalContextVit:
  582. model_kwargs = dict(
  583. depths=(3, 4, 19, 5),
  584. num_heads=(3, 6, 12, 24),
  585. embed_dim=96,
  586. mlp_ratio=2,
  587. layer_scale=1e-5,
  588. **kwargs)
  589. return _create_gcvit('gcvit_small', pretrained=pretrained, **model_kwargs)
  590. @register_model
  591. def gcvit_base(pretrained=False, **kwargs) -> GlobalContextVit:
  592. model_kwargs = dict(
  593. depths=(3, 4, 19, 5),
  594. num_heads=(4, 8, 16, 32),
  595. embed_dim=128,
  596. mlp_ratio=2,
  597. layer_scale=1e-5,
  598. **kwargs)
  599. return _create_gcvit('gcvit_base', pretrained=pretrained, **model_kwargs)