vitamin.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626
  1. """ ViTamin
  2. Paper: Designing Scalable Vision Models in the Vision-Language Era
  3. A family of model weights on Huggingface: https://huggingface.co/collections/jienengchen/vitamin-family-661048126b72debdaca060bf
  4. @inproceedings{chen2024vitamin,
  5. title={ViTamin: Designing Scalable Vision Models in the Vision-language Era},
  6. author={Chen, Jieneng and Yu, Qihang and Shen, Xiaohui and Yuille, Alan and Chen, Liang-Chieh},
  7. booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  8. year={2024}
  9. }
  10. Based on Apache 2.0 licensed code at https://github.com/ViTamin/ViTamin
  11. Modifications and timm support by Jieneng Chen 2024
  12. Reference:
  13. https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
  14. https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer_hybrid.py
  15. """
  16. import math
  17. from dataclasses import dataclass, field
  18. from functools import partial
  19. from typing import Optional, Union, Tuple
  20. import torch
  21. import torch.nn as nn
  22. from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
  23. from timm.layers import create_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, \
  24. make_divisible, DropPath, HybridEmbed
  25. from ._builder import build_model_with_cfg
  26. from ._manipulate import named_apply, checkpoint_seq
  27. from ._registry import register_model, generate_default_cfgs
  28. from .vision_transformer import VisionTransformer, checkpoint_filter_fn
  29. @dataclass
  30. class VitConvCfg:
  31. expand_ratio: float = 4.0
  32. expand_output: bool = True # calculate expansion channels from output (vs input chs)
  33. kernel_size: int = 3
  34. group_size: int = 1 # 1 == depthwise
  35. pre_norm_act: bool = False # activation after pre-norm
  36. stride_mode: str = 'dw' # stride done via one of 'pool', '1x1', 'dw'
  37. pool_type: str = 'avg2'
  38. downsample_pool_type: str = 'avg2'
  39. act_layer: str = 'gelu' # stem & stage 1234
  40. norm_layer: str = ''
  41. norm_eps: float = 1e-5
  42. down_shortcut: Optional[bool] = True
  43. mlp: str = 'mlp'
  44. @dataclass
  45. class VitCfg:
  46. embed_dim: Tuple[Union[int, Tuple[int, ...]], ...] = (96, 192, 384, 768)
  47. depths: Tuple[Union[int, Tuple[int, ...]], ...] = (2, 3, 5, 2)
  48. stem_width: int = 64
  49. conv_cfg: VitConvCfg = field(default_factory=VitConvCfg)
  50. head_type: str = ""
  51. def _init_conv(module, name, scheme=''):
  52. if isinstance(module, nn.Conv2d):
  53. fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
  54. fan_out //= module.groups
  55. nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
  56. if module.bias is not None:
  57. nn.init.zeros_(module.bias)
  58. class Stem(nn.Module):
  59. def __init__(
  60. self,
  61. in_chs: int,
  62. out_chs: int,
  63. act_layer: str = 'gelu',
  64. norm_layer: str = 'layernorm2d',
  65. norm_eps: float = 1e-6,
  66. bias: bool = True,
  67. device=None,
  68. dtype=None,
  69. ):
  70. dd = {'device': device, 'dtype': dtype}
  71. super().__init__()
  72. norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
  73. self.out_chs = out_chs
  74. self.conv1 = create_conv2d(in_chs, out_chs, 3, stride=2, bias=bias, **dd)
  75. self.norm1 = norm_act_layer(out_chs, **dd)
  76. self.conv2 = create_conv2d(out_chs, out_chs, 3, stride=1, bias=bias, **dd)
  77. named_apply(_init_conv, self)
  78. def forward(self, x):
  79. x = self.conv1(x)
  80. x = self.norm1(x)
  81. x = self.conv2(x)
  82. return x
  83. class Downsample2d(nn.Module):
  84. def __init__(
  85. self,
  86. dim: int,
  87. dim_out: int,
  88. pool_type: str = 'avg2',
  89. bias: bool = True,
  90. device=None,
  91. dtype=None,
  92. ):
  93. dd = {'device': device, 'dtype': dtype}
  94. super().__init__()
  95. self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)
  96. if dim != dim_out:
  97. self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias, **dd) # 1x1 conv
  98. else:
  99. self.expand = nn.Identity()
  100. def forward(self, x):
  101. x = self.pool(x) # spatial downsample
  102. x = self.expand(x) # expand chs
  103. return x
  104. class StridedConv(nn.Module):
  105. """ downsample 2d as well
  106. """
  107. def __init__(
  108. self,
  109. kernel_size: int = 3,
  110. stride: int = 2,
  111. padding: int = 1,
  112. in_chans: int = 3,
  113. embed_dim: int = 768,
  114. device=None,
  115. dtype=None,
  116. ):
  117. dd = {'device': device, 'dtype': dtype}
  118. super().__init__()
  119. norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6)
  120. self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding, **dd)
  121. self.norm = norm_layer(in_chans, **dd) # affine over C
  122. def forward(self, x):
  123. x = self.norm(x)
  124. x = self.proj(x)
  125. return x
  126. class MbConvLNBlock(nn.Module):
  127. """ Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand)
  128. """
  129. def __init__(
  130. self,
  131. in_chs: int,
  132. out_chs: int,
  133. stride: int = 1,
  134. drop_path: float = 0.,
  135. kernel_size: int = 3,
  136. norm_layer: str = 'layernorm2d',
  137. norm_eps: float = 1e-6,
  138. act_layer: str = 'gelu',
  139. expand_ratio: float = 4.0,
  140. device=None,
  141. dtype=None,
  142. ):
  143. dd = {'device': device, 'dtype': dtype}
  144. super().__init__()
  145. self.stride, self.in_chs, self.out_chs = stride, in_chs, out_chs
  146. mid_chs = make_divisible(out_chs * expand_ratio)
  147. prenorm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
  148. if stride == 2:
  149. self.shortcut = Downsample2d(in_chs, out_chs, pool_type='avg', bias=True, **dd)
  150. elif in_chs != out_chs:
  151. self.shortcut = nn.Conv2d(in_chs, out_chs, 1, bias=True, **dd)
  152. else:
  153. self.shortcut = nn.Identity()
  154. self.pre_norm = prenorm_act_layer(in_chs, apply_act=False, **dd)
  155. self.down = nn.Identity()
  156. self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=1, bias=True, **dd)
  157. self.act1 = create_act_layer(act_layer, inplace=True)
  158. self.conv2_kxk = create_conv2d(
  159. mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True, **dd)
  160. self.act2 = create_act_layer(act_layer, inplace=True)
  161. self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=True, **dd)
  162. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  163. def init_weights(self, scheme=''):
  164. named_apply(partial(_init_conv, scheme=scheme), self)
  165. def forward(self, x):
  166. shortcut = self.shortcut(x)
  167. x = self.pre_norm(x)
  168. x = self.down(x) # nn.Identity()
  169. # 1x1 expansion conv & act
  170. x = self.conv1_1x1(x)
  171. x = self.act1(x)
  172. # (strided) depthwise 3x3 conv & act
  173. x = self.conv2_kxk(x)
  174. x = self.act2(x)
  175. # 1x1 linear projection to output width
  176. x = self.conv3_1x1(x)
  177. x = self.drop_path(x) + shortcut
  178. return x
  179. class MbConvStages(nn.Module):
  180. """ MobileConv for stage 1 and stage 2 of ViTamin
  181. """
  182. def __init__(
  183. self,
  184. cfg: VitCfg,
  185. img_size: Union[int, Tuple[int, int]] = 224, # place holder
  186. in_chans: int = 3,
  187. device=None,
  188. dtype=None,
  189. ):
  190. dd = {'device': device, 'dtype': dtype}
  191. super().__init__()
  192. self.grad_checkpointing = False
  193. self.stem = Stem(
  194. in_chs=in_chans,
  195. out_chs=cfg.stem_width,
  196. **dd,
  197. )
  198. stages = []
  199. self.num_stages = len(cfg.embed_dim)
  200. for s, dim in enumerate(cfg.embed_dim[:2]): # stage
  201. stage_in_chs = cfg.embed_dim[s-1] if s>0 else cfg.stem_width
  202. blocks = [
  203. MbConvLNBlock(
  204. in_chs = stage_in_chs if d==0 else dim,
  205. out_chs = dim,
  206. stride = 2 if d == 0 else 1,
  207. **dd,
  208. )
  209. for d in range(cfg.depths[s])
  210. ]
  211. stages += [nn.Sequential(*blocks)]
  212. self.stages = nn.Sequential(*stages)
  213. self.pool = StridedConv(
  214. stride=2,
  215. in_chans=cfg.embed_dim[1],
  216. embed_dim=cfg.embed_dim[2],
  217. **dd,
  218. )
  219. def forward(self, x):
  220. x = self.stem(x)
  221. if self.grad_checkpointing and not torch.jit.is_scripting():
  222. x = checkpoint_seq(self.stages, x)
  223. else:
  224. x = self.stages(x)
  225. x = self.pool(x)
  226. return x
  227. class GeGluMlp(nn.Module):
  228. def __init__(
  229. self,
  230. in_features: int,
  231. hidden_features: int,
  232. act_layer: str = 'gelu',
  233. norm_layer: Optional[str] = None,
  234. bias: bool = True,
  235. drop: float = 0.0,
  236. device=None,
  237. dtype=None,
  238. ):
  239. dd = {'device': device, 'dtype': dtype}
  240. super().__init__()
  241. norm_layer = partial(get_norm_layer(norm_layer or 'layernorm'), eps=1e-6)
  242. self.norm = norm_layer(in_features, **dd)
  243. self.w0 = nn.Linear(in_features, hidden_features, bias=bias, **dd)
  244. self.act = create_act_layer(act_layer)
  245. self.w1 = nn.Linear(in_features, hidden_features, bias=bias, **dd)
  246. self.w2 = nn.Linear(hidden_features, in_features, bias=bias, **dd)
  247. def forward(self, x):
  248. x = self.norm(x)
  249. x = self.act(self.w0(x)) * self.w1(x)
  250. x = self.w2(x)
  251. return x
  252. def _create_vitamin(variant, pretrained=False, embed_cfg=None, **kwargs):
  253. out_indices = kwargs.pop('out_indices', 3)
  254. assert embed_cfg is not None
  255. dd = {'device': kwargs.get('device', None), 'dtype': kwargs.get('dtype', None)}
  256. backbone = MbConvStages(cfg=embed_cfg, in_chans=kwargs.get('in_chans', 3), **dd)
  257. kwargs['embed_layer'] = partial(HybridEmbed, backbone=backbone, proj=False)
  258. kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
  259. return build_model_with_cfg(
  260. VisionTransformer,
  261. variant,
  262. pretrained,
  263. pretrained_filter_fn=checkpoint_filter_fn,
  264. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  265. **kwargs,
  266. )
  267. def _cfg(url='', **kwargs):
  268. return {
  269. 'url': url,
  270. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  271. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  272. 'mean': OPENAI_CLIP_MEAN, 'std': OPENAI_CLIP_STD,
  273. 'first_conv': 'patch_embed.backbone.stem.conv1',
  274. 'classifier': 'head', 'license': 'mit',
  275. **kwargs
  276. }
# Pretrained configurations, keyed as '<model entrypoint name>.<pretrained tag>'.
# Every entry points at a Hugging Face Hub repo via `hf_hub_id` and overrides the
# `_cfg` default for `num_classes`; the 256/336/384 entries also override
# `input_size` and set `crop_pct=1.0`.
default_cfgs = generate_default_cfgs({
    # small / base / large at the default 224x224 input size
    'vitamin_small_224.datacomp1b_clip_ltt': _cfg(
        hf_hub_id='jienengchen/ViTamin-S-LTT', num_classes=768),
    'vitamin_small_224.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-S', num_classes=384),
    'vitamin_base_224.datacomp1b_clip_ltt': _cfg(
        hf_hub_id='jienengchen/ViTamin-B-LTT', num_classes=768),
    'vitamin_base_224.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-B', num_classes=768),
    'vitamin_large_224.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L-224px', num_classes=768),
    # large at higher input resolutions
    'vitamin_large_256.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L-256px', num_classes=768,
        input_size=(3, 256, 256), crop_pct=1.0),
    'vitamin_large_336.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L-336px', num_classes=768,
        input_size=(3, 336, 336), crop_pct=1.0),
    'vitamin_large_384.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L-384px', num_classes=768,
        input_size=(3, 384, 384), crop_pct=1.0),
    # large2 family
    'vitamin_large2_224.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L2-224px', num_classes=1024),
    'vitamin_large2_256.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L2-256px', num_classes=1024,
        input_size=(3, 256, 256), crop_pct=1.0),
    'vitamin_large2_336.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L2-336px', num_classes=1024,
        input_size=(3, 336, 336), crop_pct=1.0),
    'vitamin_large2_384.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L2-384px', num_classes=1024,
        input_size=(3, 384, 384), crop_pct=1.0),
    # xlarge family
    'vitamin_xlarge_256.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-XL-256px', num_classes=1152,
        input_size=(3, 256, 256), crop_pct=1.0),
    'vitamin_xlarge_336.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-XL-336px', num_classes=1152,
        input_size=(3, 336, 336), crop_pct=1.0),
    'vitamin_xlarge_384.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-XL-384px', num_classes=1152,
        input_size=(3, 384, 384), crop_pct=1.0),
})
  318. @register_model
  319. def vitamin_small_224(pretrained=False, **kwargs) -> VisionTransformer:
  320. embed_cfg = VitCfg(
  321. embed_dim=(64, 128, 384),
  322. depths=(2, 4, 1),
  323. stem_width=64,
  324. conv_cfg=VitConvCfg(
  325. norm_layer='layernorm2d',
  326. norm_eps=1e-6,
  327. ),
  328. head_type='1d',
  329. )
  330. model_args = dict(
  331. embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2.,
  332. class_token=False, global_pool='avg', embed_cfg=embed_cfg
  333. )
  334. model = _create_vitamin('vitamin_small_224', pretrained=pretrained, **dict(model_args, **kwargs))
  335. return model
  336. @register_model
  337. def vitamin_base_224(pretrained=False, **kwargs) -> VisionTransformer:
  338. embed_cfg = VitCfg(
  339. embed_dim=(128, 256, 768),
  340. depths=(2, 4, 1),
  341. stem_width=128,
  342. conv_cfg=VitConvCfg(
  343. norm_layer='layernorm2d',
  344. norm_eps=1e-6,
  345. ),
  346. head_type='1d',
  347. )
  348. model_args = dict(
  349. embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2.,
  350. class_token=False, global_pool='avg', embed_cfg=embed_cfg)
  351. model = _create_vitamin('vitamin_base_224', pretrained=pretrained, **dict(model_args, **kwargs))
  352. return model
  353. @register_model
  354. def vitamin_large_224(pretrained=False, **kwargs) -> VisionTransformer:
  355. embed_cfg = VitCfg(
  356. embed_dim=(160, 320, 1024),
  357. depths=(2, 4, 1),
  358. stem_width=160,
  359. conv_cfg=VitConvCfg(
  360. norm_layer='layernorm2d',
  361. norm_eps=1e-6,
  362. ),
  363. head_type='1d',
  364. )
  365. model_args = dict(
  366. embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
  367. class_token=False, global_pool='avg', embed_cfg=embed_cfg,
  368. )
  369. model = _create_vitamin('vitamin_large_224', pretrained=pretrained, **dict(model_args, **kwargs))
  370. return model
  371. @register_model
  372. def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
  373. embed_cfg = VitCfg(
  374. embed_dim=(160, 320, 1024),
  375. depths=(2, 4, 1),
  376. stem_width=160,
  377. conv_cfg=VitConvCfg(
  378. norm_layer='layernorm2d',
  379. norm_eps=1e-6,
  380. ),
  381. head_type='1d',
  382. )
  383. model_args = dict(
  384. img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
  385. class_token=False, global_pool='avg', embed_cfg=embed_cfg)
  386. model = _create_vitamin('vitamin_large_256', pretrained=pretrained, **dict(model_args, **kwargs))
  387. return model
  388. @register_model
  389. def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
  390. embed_cfg = VitCfg(
  391. embed_dim=(160, 320, 1024),
  392. depths=(2, 4, 1),
  393. stem_width=160,
  394. conv_cfg=VitConvCfg(
  395. norm_layer='layernorm2d',
  396. norm_eps=1e-6,
  397. ),
  398. head_type='1d',
  399. )
  400. model_args = dict(
  401. img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
  402. class_token=False, global_pool='avg', embed_cfg=embed_cfg
  403. )
  404. model = _create_vitamin('vitamin_large_336', pretrained=pretrained, **dict(model_args, **kwargs))
  405. return model
  406. @register_model
  407. def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
  408. embed_cfg = VitCfg(
  409. embed_dim=(160, 320, 1024),
  410. depths=(2, 4, 1),
  411. stem_width=160,
  412. conv_cfg=VitConvCfg(
  413. norm_layer='layernorm2d',
  414. norm_eps=1e-6,
  415. ),
  416. head_type='1d',
  417. )
  418. model_args = dict(
  419. img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
  420. class_token=False, global_pool='avg', embed_cfg=embed_cfg)
  421. model = _create_vitamin('vitamin_large_384', pretrained=pretrained, **dict(model_args, **kwargs))
  422. return model
  423. @register_model
  424. def vitamin_large2_224(pretrained=False, **kwargs) -> VisionTransformer:
  425. embed_cfg = VitCfg(
  426. embed_dim=(160, 320, 1024),
  427. depths=(2, 4, 1),
  428. stem_width=160,
  429. conv_cfg=VitConvCfg(
  430. norm_layer='layernorm2d',
  431. norm_eps=1e-6,
  432. ),
  433. head_type='1d',
  434. )
  435. model_args = dict(
  436. embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
  437. class_token=False, global_pool='avg', embed_cfg=embed_cfg,
  438. )
  439. model = _create_vitamin('vitamin_large2_224', pretrained=pretrained, **dict(model_args, **kwargs))
  440. return model
  441. @register_model
  442. def vitamin_large2_256(pretrained=False, **kwargs) -> VisionTransformer:
  443. embed_cfg = VitCfg(
  444. embed_dim=(160, 320, 1024),
  445. depths=(2, 4, 1),
  446. stem_width=160,
  447. conv_cfg=VitConvCfg(
  448. norm_layer='layernorm2d',
  449. norm_eps=1e-6,
  450. ),
  451. head_type='1d',
  452. )
  453. model_args = dict(
  454. img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
  455. class_token=False, global_pool='avg', embed_cfg=embed_cfg)
  456. model = _create_vitamin('vitamin_large2_256', pretrained=pretrained, **dict(model_args, **kwargs))
  457. return model
  458. @register_model
  459. def vitamin_large2_336(pretrained=False, **kwargs) -> VisionTransformer:
  460. embed_cfg = VitCfg(
  461. embed_dim=(160, 320, 1024),
  462. depths=(2, 4, 1),
  463. stem_width=160,
  464. conv_cfg=VitConvCfg(
  465. norm_layer='layernorm2d',
  466. norm_eps=1e-6,
  467. ),
  468. head_type='1d',
  469. )
  470. model_args = dict(
  471. img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
  472. class_token=False, global_pool='avg', embed_cfg=embed_cfg
  473. )
  474. model = _create_vitamin('vitamin_large2_336', pretrained=pretrained, **dict(model_args, **kwargs))
  475. return model
  476. @register_model
  477. def vitamin_large2_384(pretrained=False, **kwargs) -> VisionTransformer:
  478. embed_cfg = VitCfg(
  479. embed_dim=(160, 320, 1024),
  480. depths=(2, 4, 1),
  481. stem_width=160,
  482. conv_cfg=VitConvCfg(
  483. norm_layer='layernorm2d',
  484. norm_eps=1e-6,
  485. ),
  486. head_type='1d',
  487. )
  488. model_args = dict(
  489. img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
  490. class_token=False, global_pool='avg', embed_cfg=embed_cfg)
  491. model = _create_vitamin('vitamin_large2_384', pretrained=pretrained, **dict(model_args, **kwargs))
  492. return model
  493. @register_model
  494. def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
  495. embed_cfg=VitCfg(
  496. embed_dim=(192, 384, 1152),
  497. depths=(2, 4, 1),
  498. stem_width=192,
  499. conv_cfg=VitConvCfg(
  500. norm_layer='layernorm2d',
  501. norm_eps=1e-6,
  502. ),
  503. head_type='1d',
  504. )
  505. model_args = dict(
  506. img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
  507. class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
  508. model = _create_vitamin(
  509. 'vitamin_xlarge_256', pretrained=pretrained, **dict(model_args, **kwargs))
  510. return model
  511. @register_model
  512. def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
  513. embed_cfg = VitCfg(
  514. embed_dim=(192, 384, 1152),
  515. depths=(2, 4, 1),
  516. stem_width=192,
  517. conv_cfg=VitConvCfg(
  518. norm_layer='layernorm2d',
  519. norm_eps=1e-6,
  520. ),
  521. head_type='1d',
  522. )
  523. model_args = dict(
  524. img_size=336, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
  525. class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
  526. model = _create_vitamin('vitamin_xlarge_256', pretrained=pretrained, **dict(model_args, **kwargs))
  527. return model
  528. @register_model
  529. def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
  530. embed_cfg = VitCfg(
  531. embed_dim=(192, 384, 1152),
  532. depths=(2, 4, 1),
  533. stem_width=192,
  534. conv_cfg=VitConvCfg(
  535. norm_layer='layernorm2d',
  536. norm_eps=1e-6,
  537. ),
  538. head_type='1d',
  539. )
  540. model_args = dict(
  541. img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
  542. class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
  543. model = _create_vitamin('vitamin_xlarge_384', pretrained=pretrained, **dict(model_args, **kwargs))
  544. return model