| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626 |
- """ ViTamin
- Paper: Designing Scalable Vision Models in the Vision-Language Era
- A family of model weights on Huggingface: https://huggingface.co/collections/jienengchen/vitamin-family-661048126b72debdaca060bf
- @inproceedings{chen2024vitamin,
- title={ViTamin: Designing Scalable Vision Models in the Vision-language Era},
- author={Chen, Jieneng and Yu, Qihang and Shen, Xiaohui and Yuille, Alan and Chen, Liang-Chieh},
- booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
- year={2024}
- }
- Based on Apache 2.0 licensed code at https://github.com/ViTamin/ViTamin
- Modifications and timm support by Jieneng Chen 2024
- Reference:
- https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
- https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer_hybrid.py
- """
- import math
- from dataclasses import dataclass, field
- from functools import partial
- from typing import Optional, Union, Tuple
- import torch
- import torch.nn as nn
- from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
- from timm.layers import create_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, \
- make_divisible, DropPath, HybridEmbed
- from ._builder import build_model_with_cfg
- from ._manipulate import named_apply, checkpoint_seq
- from ._registry import register_model, generate_default_cfgs
- from .vision_transformer import VisionTransformer, checkpoint_filter_fn
@dataclass
class VitConvCfg:
    """Option bundle for the convolutional blocks of a ViTamin model.

    In the code visible in this file the object is only constructed and stored
    on ``VitCfg.conv_cfg``; field order and defaults are part of the interface
    (positional construction), so they must not be reordered.
    """

    # --- channel expansion ---------------------------------------------
    expand_ratio: float = 4.0
    expand_output: bool = True  # calculate expansion channels from output (vs input chs)

    # --- kxk convolution -------------------------------------------------
    kernel_size: int = 3
    group_size: int = 1  # 1 == depthwise

    # --- normalisation / striding / pooling ------------------------------
    pre_norm_act: bool = False  # activation after pre-norm
    stride_mode: str = 'dw'  # stride done via one of 'pool', '1x1', 'dw'
    pool_type: str = 'avg2'
    downsample_pool_type: str = 'avg2'

    # --- layer selection (by name) ---------------------------------------
    act_layer: str = 'gelu'  # stem & stage 1234
    norm_layer: str = ''
    norm_eps: float = 1e-5

    # --- misc ------------------------------------------------------------
    down_shortcut: Optional[bool] = True
    mlp: str = 'mlp'
@dataclass
class VitCfg:
    """Top-level ViTamin embedding/backbone configuration.

    ``MbConvStages`` in this file reads ``stem_width``, ``embed_dim`` (first
    two entries as conv stage widths, third as the width produced by its final
    strided conv) and ``depths`` (first two entries). ``conv_cfg`` and
    ``head_type`` are carried along but not read by the code visible here.
    """

    # Per-stage channel widths.
    embed_dim: Tuple[Union[int, Tuple[int, ...]], ...] = (96, 192, 384, 768)
    # Per-stage block counts.
    depths: Tuple[Union[int, Tuple[int, ...]], ...] = (2, 3, 5, 2)
    # Output channels of the convolutional stem.
    stem_width: int = 64
    # A fresh VitConvCfg per instance (default_factory avoids a shared mutable default).
    conv_cfg: VitConvCfg = field(default_factory=VitConvCfg)
    head_type: str = ""
- def _init_conv(module, name, scheme=''):
- if isinstance(module, nn.Conv2d):
- fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
- fan_out //= module.groups
- nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
- if module.bias is not None:
- nn.init.zeros_(module.bias)
class Stem(nn.Module):
    """Convolutional stem: 3x3 stride-2 conv -> norm+act -> 3x3 stride-1 conv.

    Args:
        in_chs: Input channels.
        out_chs: Output channels of both convolutions (also stored as ``out_chs``).
        act_layer: Activation name passed to ``get_norm_act_layer``.
        norm_layer: Norm name passed to ``get_norm_act_layer``.
        norm_eps: Epsilon forwarded to the norm+act layer.
        bias: Whether the two convolutions carry a bias.
        device, dtype: Factory kwargs forwarded to every created layer.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            act_layer: str = 'gelu',
            norm_layer: str = 'layernorm2d',
            norm_eps: float = 1e-6,
            bias: bool = True,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        norm_act_cls = get_norm_act_layer(norm_layer, act_layer)

        self.out_chs = out_chs
        # Layers are created in forward order: conv1, norm1, conv2.
        self.conv1 = create_conv2d(in_chs, out_chs, 3, stride=2, bias=bias, **factory_kwargs)
        self.norm1 = norm_act_cls(out_chs, eps=norm_eps, **factory_kwargs)
        self.conv2 = create_conv2d(out_chs, out_chs, 3, stride=1, bias=bias, **factory_kwargs)

        # Re-initialise every Conv2d below this module.
        named_apply(_init_conv, self)

    def forward(self, x):
        return self.conv2(self.norm1(self.conv1(x)))
class Downsample2d(nn.Module):
    """Spatial downsample by average pooling, then optional 1x1 channel projection.

    Pooling is a fixed ``AvgPool2d(kernel_size=3, stride=2, padding=1,
    count_include_pad=False)``. When ``dim != dim_out`` a 1x1 ``Conv2d`` maps
    the channels; otherwise the projection is an identity.

    Args:
        dim: Input channels.
        dim_out: Output channels.
        pool_type: Accepted for interface compatibility; not used by this class.
        bias: Whether the 1x1 projection (if created) has a bias.
        device, dtype: Factory kwargs for the 1x1 projection.
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            pool_type: str = 'avg2',
            bias: bool = True,
            device=None,
            dtype=None,
    ):
        super().__init__()
        self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)
        if dim == dim_out:
            self.expand = nn.Identity()
        else:
            # 1x1 conv to change the channel count
            self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias, device=device, dtype=dtype)

    def forward(self, x):
        pooled = self.pool(x)  # spatial downsample
        return self.expand(pooled)  # expand chs
class StridedConv(nn.Module):
    """Norm-then-strided-conv downsample.

    The input is normalised with ``layernorm2d`` (eps=1e-6) over its
    ``in_chans`` channels and then projected to ``embed_dim`` channels by a
    strided ``Conv2d``.

    Args:
        kernel_size, stride, padding: Geometry of the projection conv.
        in_chans: Input channels (also the width of the norm).
        embed_dim: Output channels.
        device, dtype: Factory kwargs forwarded to both layers.
    """

    def __init__(
            self,
            kernel_size: int = 3,
            stride: int = 2,
            padding: int = 1,
            in_chans: int = 3,
            embed_dim: int = 768,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        norm_cls = get_norm_layer('layernorm2d')

        # Registration order (proj, then norm) is kept even though norm runs first.
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            **factory_kwargs,
        )
        self.norm = norm_cls(in_chans, eps=1e-6, **factory_kwargs)  # affine over C

    def forward(self, x):
        return self.proj(self.norm(x))
class MbConvLNBlock(nn.Module):
    """Pre-norm inverted-bottleneck conv block: norm -> 1x1 -> depthwise kxk -> 1x1, plus shortcut.

    The shortcut is a ``Downsample2d`` when ``stride == 2``, a 1x1 ``Conv2d``
    when only the channel count changes, and an identity otherwise.

    Args:
        in_chs: Input channels.
        out_chs: Output channels.
        stride: Stride of the depthwise conv (and trigger for the pooled shortcut).
        drop_path: Drop-path rate for the residual branch; 0 disables it.
        kernel_size: Kernel size of the depthwise conv.
        norm_layer, act_layer: Layer names resolved through timm's factories.
        norm_eps: Epsilon for the pre-norm.
        expand_ratio: Hidden width is ``make_divisible(out_chs * expand_ratio)``.
        device, dtype: Factory kwargs forwarded to created layers.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int = 1,
            drop_path: float = 0.,
            kernel_size: int = 3,
            norm_layer: str = 'layernorm2d',
            norm_eps: float = 1e-6,
            act_layer: str = 'gelu',
            expand_ratio: float = 4.0,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        self.stride = stride
        self.in_chs = in_chs
        self.out_chs = out_chs
        hidden_chs = make_divisible(out_chs * expand_ratio)
        norm_act_cls = get_norm_act_layer(norm_layer, act_layer)

        # Residual path: must match the main path's output resolution and width.
        if stride == 2:
            shortcut = Downsample2d(in_chs, out_chs, pool_type='avg', bias=True, **factory_kwargs)
        elif in_chs == out_chs:
            shortcut = nn.Identity()
        else:
            shortcut = nn.Conv2d(in_chs, out_chs, 1, bias=True, **factory_kwargs)
        self.shortcut = shortcut

        # Main path; the pre-norm is built without its activation.
        self.pre_norm = norm_act_cls(in_chs, eps=norm_eps, apply_act=False, **factory_kwargs)
        self.down = nn.Identity()
        self.conv1_1x1 = create_conv2d(in_chs, hidden_chs, 1, stride=1, bias=True, **factory_kwargs)
        self.act1 = create_act_layer(act_layer, inplace=True)
        self.conv2_kxk = create_conv2d(
            hidden_chs,
            hidden_chs,
            kernel_size,
            stride=stride,
            dilation=1,
            groups=hidden_chs,
            bias=True,
            **factory_kwargs,
        )
        self.act2 = create_act_layer(act_layer, inplace=True)
        self.conv3_1x1 = create_conv2d(hidden_chs, out_chs, 1, bias=True, **factory_kwargs)
        self.drop_path = nn.Identity() if not drop_path > 0. else DropPath(drop_path)

    def init_weights(self, scheme=''):
        """Apply ``_init_conv`` to every sub-module of this block."""
        initializer = partial(_init_conv, scheme=scheme)
        named_apply(initializer, self)

    def forward(self, x):
        residual = self.shortcut(x)

        out = self.down(self.pre_norm(x))  # self.down is nn.Identity()
        out = self.act1(self.conv1_1x1(out))  # 1x1 expansion conv & act
        out = self.act2(self.conv2_kxk(out))  # (strided) depthwise kxk conv & act
        out = self.conv3_1x1(out)  # 1x1 linear projection to output width

        return self.drop_path(out) + residual
class MbConvStages(nn.Module):
    """Convolutional front-end of ViTamin: stem, two MbConv stages, strided-conv pool.

    Only the first two entries of ``cfg.embed_dim`` / ``cfg.depths`` build
    MbConv stages; ``cfg.embed_dim[2]`` is the channel width produced by the
    final ``StridedConv``.

    Args:
        cfg: Model configuration (``stem_width``, ``embed_dim``, ``depths`` are read).
        img_size: Placeholder; not used by this module.
        in_chans: Input image channels.
        device, dtype: Factory kwargs forwarded to all sub-modules.
    """

    def __init__(
            self,
            cfg: VitCfg,
            img_size: Union[int, Tuple[int, int]] = 224,  # place holder
            in_chans: int = 3,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        self.grad_checkpointing = False

        self.stem = Stem(in_chs=in_chans, out_chs=cfg.stem_width, **factory_kwargs)

        self.num_stages = len(cfg.embed_dim)
        stage_modules = []
        prev_chs = cfg.stem_width  # each stage consumes the previous stage's width
        for stage_idx, dim in enumerate(cfg.embed_dim[:2]):
            stage_blocks = []
            for block_idx in range(cfg.depths[stage_idx]):
                is_first = block_idx == 0
                # The first block of a stage changes width and downsamples.
                stage_blocks.append(MbConvLNBlock(
                    in_chs=prev_chs if is_first else dim,
                    out_chs=dim,
                    stride=2 if is_first else 1,
                    **factory_kwargs,
                ))
            stage_modules.append(nn.Sequential(*stage_blocks))
            prev_chs = dim
        self.stages = nn.Sequential(*stage_modules)

        self.pool = StridedConv(
            stride=2,
            in_chans=cfg.embed_dim[1],
            embed_dim=cfg.embed_dim[2],
            **factory_kwargs,
        )

    def forward(self, x):
        x = self.stem(x)
        use_checkpoint = self.grad_checkpointing and not torch.jit.is_scripting()
        x = checkpoint_seq(self.stages, x) if use_checkpoint else self.stages(x)
        return self.pool(x)
class GeGluMlp(nn.Module):
    """Gated MLP with its own input norm: ``w2(act(w0(n)) * w1(n))`` where ``n = norm(x)``.

    Args:
        in_features: Input and output width.
        hidden_features: Width of the gated hidden projection.
        act_layer: Activation name for the gate branch.
        norm_layer: Norm name; falls back to ``'layernorm'`` when falsy. Built with eps=1e-6.
        bias: Whether the three linear layers carry a bias.
        drop: Accepted for interface compatibility; not used by this class.
        device, dtype: Factory kwargs forwarded to created layers.
    """

    def __init__(
            self,
            in_features: int,
            hidden_features: int,
            act_layer: str = 'gelu',
            norm_layer: Optional[str] = None,
            bias: bool = True,
            drop: float = 0.0,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        norm_cls = get_norm_layer(norm_layer or 'layernorm')

        self.norm = norm_cls(in_features, eps=1e-6, **factory_kwargs)
        self.w0 = nn.Linear(in_features, hidden_features, bias=bias, **factory_kwargs)  # gate branch
        self.act = create_act_layer(act_layer)
        self.w1 = nn.Linear(in_features, hidden_features, bias=bias, **factory_kwargs)  # value branch
        self.w2 = nn.Linear(hidden_features, in_features, bias=bias, **factory_kwargs)  # output projection

    def forward(self, x):
        normed = self.norm(x)
        gate = self.act(self.w0(normed))
        hidden = gate * self.w1(normed)
        return self.w2(hidden)
def _create_vitamin(variant, pretrained=False, embed_cfg=None, **kwargs):
    """Build a ``VisionTransformer`` whose patch embedding is a ViTamin conv backbone.

    Args:
        variant: Model variant name handed to ``build_model_with_cfg``.
        pretrained: Forwarded to ``build_model_with_cfg``.
        embed_cfg: ``VitCfg`` describing the ``MbConvStages`` backbone; required.
        **kwargs: Remaining model arguments. ``out_indices`` is removed and used
            for ``feature_cfg`` (default 3); ``in_chans``, ``device`` and
            ``dtype`` are read for the backbone and still passed through;
            ``patch_size`` defaults to 1 when absent.

    Raises:
        AssertionError: If ``embed_cfg`` is None.
    """
    out_indices = kwargs.pop('out_indices', 3)
    assert embed_cfg is not None

    backbone = MbConvStages(
        cfg=embed_cfg,
        in_chans=kwargs.get('in_chans', 3),
        device=kwargs.get('device', None),
        dtype=kwargs.get('dtype', None),
    )
    # The backbone instance is bound into the embed-layer factory.
    kwargs['embed_layer'] = partial(HybridEmbed, backbone=backbone, proj=False)
    if 'patch_size' not in kwargs:
        kwargs['patch_size'] = 1  # default patch size for hybrid models if not set

    feature_cfg = dict(out_indices=out_indices, feature_cls='getter')
    return build_model_with_cfg(
        VisionTransformer,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=feature_cfg,
        **kwargs,
    )
def _cfg(url='', **kwargs):
    """Return a pretrained-config dict: ViTamin defaults overlaid with ``kwargs``.

    Args:
        url: Value stored under ``'url'``.
        **kwargs: Extra or overriding entries; they take precedence over the defaults.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=.9,
        interpolation='bicubic',
        fixed_input_size=True,
        mean=OPENAI_CLIP_MEAN,
        std=OPENAI_CLIP_STD,
        first_conv='patch_embed.backbone.stem.conv1',
        classifier='head',
        license='mit',
    )
    cfg.update(kwargs)
    return cfg
# Pretrained configurations, keyed by '<model name>.<pretrained tag>'.
# Each entry starts from the `_cfg` defaults and overrides the fields given here
# (Hugging Face hub id, num_classes, and for non-224 models input_size/crop_pct).
# NOTE(review): num_classes presumably matches the output width of the released
# CLIP image towers — confirm against the hub checkpoints.
default_cfgs = generate_default_cfgs({
    # --- small / base -----------------------------------------------------
    'vitamin_small_224.datacomp1b_clip_ltt': _cfg(
        hf_hub_id='jienengchen/ViTamin-S-LTT', num_classes=768),
    'vitamin_small_224.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-S', num_classes=384),
    'vitamin_base_224.datacomp1b_clip_ltt': _cfg(
        hf_hub_id='jienengchen/ViTamin-B-LTT', num_classes=768),
    'vitamin_base_224.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-B', num_classes=768),
    # --- large ------------------------------------------------------------
    'vitamin_large_224.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L-224px', num_classes=768),
    'vitamin_large_256.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L-256px', num_classes=768,
        input_size=(3, 256, 256), crop_pct=1.0),
    'vitamin_large_336.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L-336px', num_classes=768,
        input_size=(3, 336, 336), crop_pct=1.0),
    'vitamin_large_384.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L-384px', num_classes=768,
        input_size=(3, 384, 384), crop_pct=1.0),
    # --- large2 -----------------------------------------------------------
    'vitamin_large2_224.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L2-224px', num_classes=1024),
    'vitamin_large2_256.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L2-256px', num_classes=1024,
        input_size=(3, 256, 256), crop_pct=1.0),
    'vitamin_large2_336.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L2-336px', num_classes=1024,
        input_size=(3, 336, 336), crop_pct=1.0),
    'vitamin_large2_384.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L2-384px', num_classes=1024,
        input_size=(3, 384, 384), crop_pct=1.0),
    # --- xlarge -----------------------------------------------------------
    'vitamin_xlarge_256.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-XL-256px', num_classes=1152,
        input_size=(3, 256, 256), crop_pct=1.0),
    'vitamin_xlarge_336.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-XL-336px', num_classes=1152,
        input_size=(3, 336, 336), crop_pct=1.0),
    'vitamin_xlarge_384.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-XL-384px', num_classes=1152,
        input_size=(3, 384, 384), crop_pct=1.0),
})
@register_model
def vitamin_small_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ViTamin-Small: stem 64, conv stages (64, 128), transformer embed_dim=384, depth=14, 6 heads.

    Caller ``kwargs`` override the defaults assembled here.
    """
    conv_cfg = VitConvCfg(norm_layer='layernorm2d', norm_eps=1e-6)
    embed_cfg = VitCfg(
        embed_dim=(64, 128, 384),
        depths=(2, 4, 1),
        stem_width=64,
        conv_cfg=conv_cfg,
        head_type='1d',
    )
    model_args = {
        'embed_dim': 384,
        'depth': 14,
        'num_heads': 6,
        'mlp_layer': GeGluMlp,
        'mlp_ratio': 2.,
        'class_token': False,
        'global_pool': 'avg',
        'embed_cfg': embed_cfg,
    }
    model_args.update(kwargs)
    return _create_vitamin('vitamin_small_224', pretrained=pretrained, **model_args)
@register_model
def vitamin_base_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ViTamin-Base: stem 128, conv stages (128, 256), transformer embed_dim=768, depth=14, 12 heads.

    Caller ``kwargs`` override the defaults assembled here.
    """
    conv_cfg = VitConvCfg(norm_layer='layernorm2d', norm_eps=1e-6)
    embed_cfg = VitCfg(
        embed_dim=(128, 256, 768),
        depths=(2, 4, 1),
        stem_width=128,
        conv_cfg=conv_cfg,
        head_type='1d',
    )
    model_args = {
        'embed_dim': 768,
        'depth': 14,
        'num_heads': 12,
        'mlp_layer': GeGluMlp,
        'mlp_ratio': 2.,
        'class_token': False,
        'global_pool': 'avg',
        'embed_cfg': embed_cfg,
    }
    model_args.update(kwargs)
    return _create_vitamin('vitamin_base_224', pretrained=pretrained, **model_args)
@register_model
def vitamin_large_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ViTamin-Large: stem 160, conv stages (160, 320), transformer embed_dim=1024, depth=31, 16 heads.

    Caller ``kwargs`` override the defaults assembled here.
    """
    conv_cfg = VitConvCfg(norm_layer='layernorm2d', norm_eps=1e-6)
    embed_cfg = VitCfg(
        embed_dim=(160, 320, 1024),
        depths=(2, 4, 1),
        stem_width=160,
        conv_cfg=conv_cfg,
        head_type='1d',
    )
    model_args = {
        'embed_dim': 1024,
        'depth': 31,
        'num_heads': 16,
        'mlp_layer': GeGluMlp,
        'mlp_ratio': 2.,
        'class_token': False,
        'global_pool': 'avg',
        'embed_cfg': embed_cfg,
    }
    model_args.update(kwargs)
    return _create_vitamin('vitamin_large_224', pretrained=pretrained, **model_args)
@register_model
def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
    """ViTamin-Large with img_size=256: stem 160, conv stages (160, 320), embed_dim=1024, depth=31, 16 heads.

    Caller ``kwargs`` override the defaults assembled here.
    """
    conv_cfg = VitConvCfg(norm_layer='layernorm2d', norm_eps=1e-6)
    embed_cfg = VitCfg(
        embed_dim=(160, 320, 1024),
        depths=(2, 4, 1),
        stem_width=160,
        conv_cfg=conv_cfg,
        head_type='1d',
    )
    model_args = {
        'img_size': 256,
        'embed_dim': 1024,
        'depth': 31,
        'num_heads': 16,
        'mlp_layer': GeGluMlp,
        'mlp_ratio': 2.,
        'class_token': False,
        'global_pool': 'avg',
        'embed_cfg': embed_cfg,
    }
    model_args.update(kwargs)
    return _create_vitamin('vitamin_large_256', pretrained=pretrained, **model_args)
@register_model
def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
    """ViTamin-Large with img_size=336: stem 160, conv stages (160, 320), embed_dim=1024, depth=31, 16 heads.

    Caller ``kwargs`` override the defaults assembled here.
    """
    conv_cfg = VitConvCfg(norm_layer='layernorm2d', norm_eps=1e-6)
    embed_cfg = VitCfg(
        embed_dim=(160, 320, 1024),
        depths=(2, 4, 1),
        stem_width=160,
        conv_cfg=conv_cfg,
        head_type='1d',
    )
    model_args = {
        'img_size': 336,
        'embed_dim': 1024,
        'depth': 31,
        'num_heads': 16,
        'mlp_layer': GeGluMlp,
        'mlp_ratio': 2.,
        'class_token': False,
        'global_pool': 'avg',
        'embed_cfg': embed_cfg,
    }
    model_args.update(kwargs)
    return _create_vitamin('vitamin_large_336', pretrained=pretrained, **model_args)
@register_model
def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
    """ViTamin-Large with img_size=384: stem 160, conv stages (160, 320), embed_dim=1024, depth=31, 16 heads.

    Caller ``kwargs`` override the defaults assembled here.
    """
    conv_cfg = VitConvCfg(norm_layer='layernorm2d', norm_eps=1e-6)
    embed_cfg = VitCfg(
        embed_dim=(160, 320, 1024),
        depths=(2, 4, 1),
        stem_width=160,
        conv_cfg=conv_cfg,
        head_type='1d',
    )
    model_args = {
        'img_size': 384,
        'embed_dim': 1024,
        'depth': 31,
        'num_heads': 16,
        'mlp_layer': GeGluMlp,
        'mlp_ratio': 2.,
        'class_token': False,
        'global_pool': 'avg',
        'embed_cfg': embed_cfg,
    }
    model_args.update(kwargs)
    return _create_vitamin('vitamin_large_384', pretrained=pretrained, **model_args)
@register_model
def vitamin_large2_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ViTamin-Large2: stem 160, conv stages (160, 320), transformer embed_dim=1024, depth=31, 16 heads.

    The architecture arguments are the same as ``vitamin_large_224``; only the
    registered variant name differs. Caller ``kwargs`` override the defaults.
    """
    conv_cfg = VitConvCfg(norm_layer='layernorm2d', norm_eps=1e-6)
    embed_cfg = VitCfg(
        embed_dim=(160, 320, 1024),
        depths=(2, 4, 1),
        stem_width=160,
        conv_cfg=conv_cfg,
        head_type='1d',
    )
    model_args = {
        'embed_dim': 1024,
        'depth': 31,
        'num_heads': 16,
        'mlp_layer': GeGluMlp,
        'mlp_ratio': 2.,
        'class_token': False,
        'global_pool': 'avg',
        'embed_cfg': embed_cfg,
    }
    model_args.update(kwargs)
    return _create_vitamin('vitamin_large2_224', pretrained=pretrained, **model_args)
@register_model
def vitamin_large2_256(pretrained=False, **kwargs) -> VisionTransformer:
    """ViTamin-Large2 with img_size=256: stem 160, conv stages (160, 320), embed_dim=1024, depth=31, 16 heads.

    The architecture arguments are the same as ``vitamin_large_256``; only the
    registered variant name differs. Caller ``kwargs`` override the defaults.
    """
    conv_cfg = VitConvCfg(norm_layer='layernorm2d', norm_eps=1e-6)
    embed_cfg = VitCfg(
        embed_dim=(160, 320, 1024),
        depths=(2, 4, 1),
        stem_width=160,
        conv_cfg=conv_cfg,
        head_type='1d',
    )
    model_args = {
        'img_size': 256,
        'embed_dim': 1024,
        'depth': 31,
        'num_heads': 16,
        'mlp_layer': GeGluMlp,
        'mlp_ratio': 2.,
        'class_token': False,
        'global_pool': 'avg',
        'embed_cfg': embed_cfg,
    }
    model_args.update(kwargs)
    return _create_vitamin('vitamin_large2_256', pretrained=pretrained, **model_args)
@register_model
def vitamin_large2_336(pretrained=False, **kwargs) -> VisionTransformer:
    """ViTamin-Large2 with img_size=336: stem 160, conv stages (160, 320), embed_dim=1024, depth=31, 16 heads.

    The architecture arguments are the same as ``vitamin_large_336``; only the
    registered variant name differs. Caller ``kwargs`` override the defaults.
    """
    conv_cfg = VitConvCfg(norm_layer='layernorm2d', norm_eps=1e-6)
    embed_cfg = VitCfg(
        embed_dim=(160, 320, 1024),
        depths=(2, 4, 1),
        stem_width=160,
        conv_cfg=conv_cfg,
        head_type='1d',
    )
    model_args = {
        'img_size': 336,
        'embed_dim': 1024,
        'depth': 31,
        'num_heads': 16,
        'mlp_layer': GeGluMlp,
        'mlp_ratio': 2.,
        'class_token': False,
        'global_pool': 'avg',
        'embed_cfg': embed_cfg,
    }
    model_args.update(kwargs)
    return _create_vitamin('vitamin_large2_336', pretrained=pretrained, **model_args)
@register_model
def vitamin_large2_384(pretrained=False, **kwargs) -> VisionTransformer:
    """ViTamin-Large2 with img_size=384: stem 160, conv stages (160, 320), embed_dim=1024, depth=31, 16 heads.

    The architecture arguments are the same as ``vitamin_large_384``; only the
    registered variant name differs. Caller ``kwargs`` override the defaults.
    """
    conv_cfg = VitConvCfg(norm_layer='layernorm2d', norm_eps=1e-6)
    embed_cfg = VitCfg(
        embed_dim=(160, 320, 1024),
        depths=(2, 4, 1),
        stem_width=160,
        conv_cfg=conv_cfg,
        head_type='1d',
    )
    model_args = {
        'img_size': 384,
        'embed_dim': 1024,
        'depth': 31,
        'num_heads': 16,
        'mlp_layer': GeGluMlp,
        'mlp_ratio': 2.,
        'class_token': False,
        'global_pool': 'avg',
        'embed_cfg': embed_cfg,
    }
    model_args.update(kwargs)
    return _create_vitamin('vitamin_large2_384', pretrained=pretrained, **model_args)
@register_model
def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
    """ViTamin-XLarge with img_size=256: stem 192, conv stages (192, 384), embed_dim=1152, depth=32, 16 heads.

    Built with ``pos_embed='none'``. Caller ``kwargs`` override the defaults
    assembled here.
    """
    conv_cfg = VitConvCfg(norm_layer='layernorm2d', norm_eps=1e-6)
    embed_cfg = VitCfg(
        embed_dim=(192, 384, 1152),
        depths=(2, 4, 1),
        stem_width=192,
        conv_cfg=conv_cfg,
        head_type='1d',
    )
    model_args = {
        'img_size': 256,
        'embed_dim': 1152,
        'depth': 32,
        'num_heads': 16,
        'mlp_layer': GeGluMlp,
        'mlp_ratio': 2.,
        'class_token': False,
        'global_pool': 'avg',
        'pos_embed': 'none',
        'embed_cfg': embed_cfg,
    }
    model_args.update(kwargs)
    return _create_vitamin('vitamin_xlarge_256', pretrained=pretrained, **model_args)
@register_model
def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
    """ViTamin-XLarge with img_size=336: stem 192, conv stages (192, 384), embed_dim=1152, depth=32, 16 heads.

    Built with ``pos_embed='none'``. Caller ``kwargs`` override the defaults
    assembled here.

    Args:
        pretrained: Forwarded to ``_create_vitamin``.
        **kwargs: Extra or overriding model arguments.

    Returns:
        The constructed ``VisionTransformer``.
    """
    embed_cfg = VitCfg(
        embed_dim=(192, 384, 1152),
        depths=(2, 4, 1),
        stem_width=192,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    )
    model_args = dict(
        img_size=336, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
    # The variant name must match this function's registered name; it previously
    # said 'vitamin_xlarge_256', so this model was built under the 256 variant.
    model = _create_vitamin('vitamin_xlarge_336', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
@register_model
def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
    """ViTamin-XLarge with img_size=384: stem 192, conv stages (192, 384), embed_dim=1152, depth=32, 16 heads.

    Built with ``pos_embed='none'``. Caller ``kwargs`` override the defaults
    assembled here.
    """
    conv_cfg = VitConvCfg(norm_layer='layernorm2d', norm_eps=1e-6)
    embed_cfg = VitCfg(
        embed_dim=(192, 384, 1152),
        depths=(2, 4, 1),
        stem_width=192,
        conv_cfg=conv_cfg,
        head_type='1d',
    )
    model_args = {
        'img_size': 384,
        'embed_dim': 1152,
        'depth': 32,
        'num_heads': 16,
        'mlp_layer': GeGluMlp,
        'mlp_ratio': 2.,
        'class_token': False,
        'global_pool': 'avg',
        'pos_embed': 'none',
        'embed_cfg': embed_cfg,
    }
    model_args.update(kwargs)
    return _create_vitamin('vitamin_xlarge_384', pretrained=pretrained, **model_args)
|