| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765376637673768376937703771377237733774377537763
77737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154415541564157415841594160416141624163416441654166416741684169417041714172417341744175417641774178417941804181418241834184418541864187418841894190419141924193419441954196419741984199420042014202420342044205420642074208420942104211421242134214421542164217421842194220422142224223422442254226422742284229423042314232423342344235423642374238423942404241424242434244424542464247424842494250425142524253425442554256425742584259426042614262426342644265426642674268426942704271427242734274427542764
27742784279428042814282428342844285428642874288428942904291429242934294429542964297429842994300430143024303430443054306430743084309431043114312431343144315431643174318431943204321432243234324432543264327432843294330433143324333433443354336433743384339434043414342434343444345434643474348434943504351435243534354435543564357435843594360436143624363436443654366436743684369437043714372437343744375437643774378437943804381438243834384438543864387438843894390439143924393439443954396439743984399440044014402440344044405440644074408440944104411441244134414441544164417441844194420442144224423442444254426442744284429443044314432443344344435443644374438443944404441444244434444444544464447444844494450445144524453445444554456445744584459446044614462446344644465446644674468446944704471447244734474447544764477447844794480448144824483448444854486448744884489449044914492449344944495449644974498449945004501450245034504450545064507450845094510451145124513451445154516451745184519452045214522452345244525452645274528452945304531453245334534453545364537453845394540454145424543454445454546454745484549455045514552455345544555455645574558455945604561456245634564456545664567456845694570457145724573457445754576457745784579458045814582458345844585458645874588458945904591459245934594459545964597459845994600460146024603460446054606460746084609461046114612461346144615461646174618461946204621462246234624462546264627462846294630463146324633463446354636463746384639464046414642464346444645464646474648464946504651465246534654465546564657465846594660466146624663466446654666466746684669467046714672467346744675467646774678467946804681468246834684468546864687468846894690469146924693469446954696469746984699470047014702470347044705470647074708470947104711471247134714471547164717471847194720472147224723472447254726472747284729473047314732473347344735473647374738473947404741474247434744474547464747474847494750475147524753475447554756475747584759476047614762476347644765476647674768476947704771477247734774477547764
77747784779478047814782478347844785478647874788478947904791479247934794479547964797479847994800480148024803480448054806480748084809481048114812481348144815481648174818481948204821482248234824482548264827482848294830483148324833483448354836483748384839 |
- """ Vision Transformer (ViT) in PyTorch
- A PyTorch implementation of Vision Transformers as described in:
- 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
- - https://arxiv.org/abs/2010.11929
- `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
- - https://arxiv.org/abs/2106.10270
- `FlexiViT: One Model for All Patch Sizes`
- - https://arxiv.org/abs/2212.08013
- The official jax code is released and available at
- * https://github.com/google-research/vision_transformer
- * https://github.com/google-research/big_vision
- Acknowledgments:
- * The paper authors for releasing code and weights, thanks!
- * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch
- * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
- * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
- Hacked together by / Copyright 2020, Ross Wightman
- """
- import copy
- import logging
- import math
- import os
- from collections import OrderedDict
- from functools import partial
- from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union, List
- try:
- from typing import Literal
- except ImportError:
- from typing_extensions import Literal
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.jit import Final
- from timm.data import (
- IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD,
- IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD,
- OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
- )
- from timm.layers import (
- Attention,
- DiffAttention,
- AttentionPoolLatent,
- AttentionPoolPrr,
- PatchEmbed,
- Mlp,
- SwiGLUPacked,
- SwiGLU,
- LayerNorm,
- RmsNorm,
- DropPath,
- calculate_drop_path_rates,
- PatchDropout,
- trunc_normal_,
- lecun_normal_,
- resample_patch_embed,
- resample_abs_pos_embed,
- use_fused_attn,
- get_act_layer,
- get_norm_layer,
- maybe_add_mask,
- resolve_self_attn_mask,
- LayerType,
- LayerScale,
- )
- from ._builder import build_model_with_cfg
- from ._features import feature_take_indices
- from ._manipulate import named_apply, checkpoint, checkpoint_seq, adapt_input_conv
- from ._registry import generate_default_cfgs, register_model, register_model_deprecations
# Public names exported by this module.
__all__ = ['VisionTransformer']  # model_registry will add each entrypoint fn to this

# Module-level logger named after this module.
_logger = logging.getLogger(__name__)

# String aliases accepted for an `attn_layer` argument, mapped to attention classes.
# Both the empty string and 'attn' select the standard Attention layer.
ATTN_LAYERS: Dict[str, Type[nn.Module]] = {
    '': Attention,
    'attn': Attention,
    'diff': DiffAttention,
}
def _create_attn(
        attn_layer: LayerType,
        dim: int,
        num_heads: int,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        scale_norm: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: Optional[Type[nn.Module]] = None,
        depth: int = 0,
        **kwargs,
) -> nn.Module:
    """Instantiate an attention layer from a class/callable or a registered name.

    Args:
        attn_layer: Attention class (or a callable factory such as a
            ``functools.partial``), or a string key of ``ATTN_LAYERS``.
        dim: Passed as the first positional argument of the attention layer.
        num_heads: Forwarded to the attention layer as ``num_heads``.
        qkv_bias: Forwarded to the attention layer as ``qkv_bias``.
        qk_norm: Forwarded to the attention layer as ``qk_norm``.
        scale_norm: Forwarded to the attention layer as ``scale_norm``.
        proj_bias: Forwarded to the attention layer as ``proj_bias``.
        attn_drop: Forwarded to the attention layer as ``attn_drop``.
        proj_drop: Forwarded to the attention layer as ``proj_drop``.
        norm_layer: Forwarded to the attention layer as ``norm_layer``.
        depth: Block index. Only forwarded (as ``depth``) when the resolved
            layer is ``DiffAttention`` or a subclass of it.
        **kwargs: Additional keyword arguments forwarded unchanged to the
            attention layer (callers in this file pass ``device``/``dtype``).

    Returns:
        The constructed attention module.

    Raises:
        AssertionError: If ``attn_layer`` is a string that is not a key of
            ``ATTN_LAYERS``.
    """
    if isinstance(attn_layer, str):
        attn_name = attn_layer
        attn_layer = ATTN_LAYERS.get(attn_name, None)
        # Report the requested name: on failure the lookup result is always None,
        # so interpolating it would hide which name was actually unknown.
        assert attn_layer is not None, f'Unknown attn_layer: {attn_name}'

    # Only pass depth to attention layers that use it. A functools.partial is
    # unwrapped to its underlying callable, and the type check keeps issubclass()
    # from raising TypeError when a non-class callable is supplied.
    attn_cls = attn_layer.func if isinstance(attn_layer, partial) else attn_layer
    if isinstance(attn_cls, type) and issubclass(attn_cls, DiffAttention):
        kwargs['depth'] = depth

    return attn_layer(
        dim,
        num_heads=num_heads,
        qkv_bias=qkv_bias,
        qk_norm=qk_norm,
        scale_norm=scale_norm,
        proj_bias=proj_bias,
        attn_drop=attn_drop,
        proj_drop=proj_drop,
        norm_layer=norm_layer,
        **kwargs,
    )
class Block(nn.Module):
    """Transformer block with pre-normalization.

    Each of the two branches (attention, then MLP) normalizes its input, and the
    branch output passes through an optional LayerScale and DropPath before being
    added back to the residual stream.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            scale_attn_norm: bool = False,
            scale_mlp_norm: bool = False,
            proj_bias: bool = True,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            init_values: Optional[float] = None,
            drop_path: float = 0.,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = LayerNorm,
            mlp_layer: Type[nn.Module] = Mlp,
            attn_layer: LayerType = Attention,
            depth: int = 0,
            device=None,
            dtype=None,
    ) -> None:
        """Initialize Block.

        Args:
            dim: Number of input channels.
            num_heads: Number of attention heads.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            qkv_bias: If True, add a learnable bias to query, key, value.
            qk_norm: If True, apply normalization to query and key.
            scale_attn_norm: Passed to the attention layer as ``scale_norm``.
            scale_mlp_norm: If True, ``norm_layer`` is passed to the MLP layer as
                its ``norm_layer``; otherwise the MLP receives ``None``.
            proj_bias: If True, add bias to output projection. Also passed to the
                MLP layer as ``bias``.
            proj_drop: Projection dropout rate. Also passed to the MLP layer as ``drop``.
            attn_drop: Attention dropout rate.
            init_values: Initial values for layer scale. A falsy value (``None`` or
                ``0``) replaces LayerScale with an identity.
            drop_path: Stochastic depth rate. DropPath is only created when > 0.
            act_layer: Activation layer.
            norm_layer: Normalization layer.
            mlp_layer: MLP layer.
            attn_layer: Attention layer type (class or string).
            depth: Block index, passed to attention layer for depth-dependent init.
            device: Forwarded to every parameterized sub-module created here.
            dtype: Forwarded to every parameterized sub-module created here.
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}

        # --- attention branch ---
        self.norm1 = norm_layer(dim, **factory_kwargs)
        self.attn = _create_attn(
            attn_layer,
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            scale_norm=scale_attn_norm,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
            depth=depth,
            **factory_kwargs,
        )
        if init_values:
            self.ls1 = LayerScale(dim, init_values=init_values, **factory_kwargs)
        else:
            self.ls1 = nn.Identity()
        if drop_path > 0.:
            self.drop_path1 = DropPath(drop_path)
        else:
            self.drop_path1 = nn.Identity()

        # --- MLP branch ---
        self.norm2 = norm_layer(dim, **factory_kwargs)
        mlp_hidden_dim = int(dim * mlp_ratio)
        mlp_norm_layer = norm_layer if scale_mlp_norm else None
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            norm_layer=mlp_norm_layer,
            bias=proj_bias,
            drop=proj_drop,
            **factory_kwargs,
        )
        if init_values:
            self.ls2 = LayerScale(dim, init_values=init_values, **factory_kwargs)
        else:
            self.ls2 = nn.Identity()
        if drop_path > 0.:
            self.drop_path2 = DropPath(drop_path)
        else:
            self.drop_path2 = nn.Identity()

    def forward(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            is_causal: bool = False,
    ) -> torch.Tensor:
        """Apply the attention and MLP residual branches in sequence.

        Args:
            x: Input tensor.
            attn_mask: Optional mask forwarded to the attention layer.
            is_causal: Forwarded to the attention layer.

        Returns:
            Tensor produced by adding both branch outputs to the input.
        """
        attn_out = self.attn(self.norm1(x), attn_mask=attn_mask, is_causal=is_causal)
        x = x + self.drop_path1(self.ls1(attn_out))
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path2(self.ls2(mlp_out))
        return x
class ResPostBlock(nn.Module):
    """Transformer block with post-normalization on each residual branch.

    The attention and MLP layers operate on the un-normalized input; their outputs
    are normalized (``norm1`` / ``norm2``) before being added to the residual stream.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            scale_attn_norm: bool = False,
            scale_mlp_norm: bool = False,
            proj_bias: bool = True,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            init_values: Optional[float] = None,
            drop_path: float = 0.,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = LayerNorm,
            mlp_layer: Type[nn.Module] = Mlp,
            attn_layer: LayerType = Attention,
            depth: int = 0,
            device=None,
            dtype=None,
    ) -> None:
        """Initialize ResPostBlock.

        Args:
            dim: Number of input channels.
            num_heads: Number of attention heads.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            qkv_bias: Passed to the attention layer as ``qkv_bias``.
            qk_norm: Passed to the attention layer as ``qk_norm``.
            scale_attn_norm: Passed to the attention layer as ``scale_norm``.
            scale_mlp_norm: If True, ``norm_layer`` is passed to the MLP layer as
                its ``norm_layer``; otherwise the MLP receives ``None``.
            proj_bias: Passed to the attention layer as ``proj_bias`` and to the
                MLP layer as ``bias``.
            proj_drop: Passed to the attention layer as ``proj_drop`` and to the
                MLP layer as ``drop``.
            attn_drop: Passed to the attention layer as ``attn_drop``.
            init_values: If not None, the weights of ``norm1`` and ``norm2`` are
                set to this constant by ``init_weights``.
            drop_path: Stochastic depth rate. DropPath is only created when > 0.
            act_layer: Activation layer passed to the MLP layer.
            norm_layer: Normalization layer.
            mlp_layer: MLP layer.
            attn_layer: Attention layer type (class or string).
            depth: Block index, passed through to the attention layer factory.
            device: Forwarded to every parameterized sub-module created here.
            dtype: Forwarded to every parameterized sub-module created here.
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.init_values = init_values

        # --- attention branch (normalized after the attention layer) ---
        self.attn = _create_attn(
            attn_layer,
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            scale_norm=scale_attn_norm,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
            depth=depth,
            **factory_kwargs,
        )
        self.norm1 = norm_layer(dim, **factory_kwargs)
        if drop_path > 0.:
            self.drop_path1 = DropPath(drop_path)
        else:
            self.drop_path1 = nn.Identity()

        # --- MLP branch (normalized after the MLP layer) ---
        mlp_hidden_dim = int(dim * mlp_ratio)
        mlp_norm_layer = norm_layer if scale_mlp_norm else None
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            norm_layer=mlp_norm_layer,
            bias=proj_bias,
            drop=proj_drop,
            **factory_kwargs,
        )
        self.norm2 = norm_layer(dim, **factory_kwargs)
        if drop_path > 0.:
            self.drop_path2 = DropPath(drop_path)
        else:
            self.drop_path2 = nn.Identity()

        self.init_weights()

    def init_weights(self) -> None:
        """Set both post-norm weights to ``init_values`` when it is not None."""
        # NOTE this init overrides that base model init with specific changes for the block type
        if self.init_values is None:
            return
        for norm in (self.norm1, self.norm2):
            nn.init.constant_(norm.weight, self.init_values)

    def forward(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            is_causal: bool = False,
    ) -> torch.Tensor:
        """Apply the attention and MLP residual branches in sequence.

        Args:
            x: Input tensor.
            attn_mask: Optional mask forwarded to the attention layer.
            is_causal: Forwarded to the attention layer.

        Returns:
            Tensor produced by adding both normalized branch outputs to the input.
        """
        attn_out = self.attn(x, attn_mask=attn_mask, is_causal=is_causal)
        x = x + self.drop_path1(self.norm1(attn_out))
        mlp_out = self.mlp(x)
        x = x + self.drop_path2(self.norm2(mlp_out))
        return x
class ParallelScalingBlock(nn.Module):
    """ Parallel ViT block (MLP & Attention in parallel)

    One input projection produces the MLP hidden activations together with q, k
    and v. The attention output and the activated MLP hidden state are then
    projected back to ``dim`` (by one fused projection or two separate ones) and
    added to the input as a single residual update.

    Based on:
      `Scaling Vision Transformers to 22 Billion Parameters` - https://arxiv.org/abs/2302.05442
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            scale_attn_norm: bool = False,
            scale_mlp_norm: bool = False,
            proj_bias: bool = True,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            init_values: Optional[float] = None,
            drop_path: float = 0.,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = LayerNorm,
            mlp_layer: Optional[Type[nn.Module]] = None,  # not used
            attn_layer: Optional[LayerType] = None,  # not used
            depth: int = 0,  # not used
            fuse_out_proj: bool = False,
            device=None,
            dtype=None,
    ) -> None:
        """Initialize ParallelScalingBlock.

        Args:
            dim: Number of input channels; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            qkv_bias: If True, the combined input projection has a bias (covering
                both the MLP and q/k/v outputs). If False, a separate ``mlp_bias``
                parameter is created for the MLP part only.
            qk_norm: If True, apply ``norm_layer`` over the head dim of q and k.
            scale_attn_norm: Not supported; must be False.
            scale_mlp_norm: Not supported; must be False.
            proj_bias: If True, add bias to the output projection(s).
            proj_drop: Dropout rate applied to the activated MLP hidden state.
            attn_drop: Attention dropout rate.
            init_values: Initial values for layer scale; ``None`` disables it.
            drop_path: Stochastic depth rate. DropPath is only created when > 0.
            act_layer: Activation layer for the MLP path.
            norm_layer: Normalization layer.
            mlp_layer: Not used by this block.
            attn_layer: Not used by this block.
            depth: Not used by this block.
            fuse_out_proj: If True, use one output projection over the concatenated
                attention and MLP features instead of two separate projections.
            device: Forwarded to every parameterized sub-module created here.
            dtype: Forwarded to every parameterized sub-module created here.
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        assert not (scale_attn_norm or scale_mlp_norm), 'Scale norms not supported'

        head_dim = dim // num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        # Combined input projection: [mlp hidden | q | k | v]
        mlp_hidden_dim = int(mlp_ratio * dim)
        self.in_norm = norm_layer(dim, **factory_kwargs)
        self.in_proj = nn.Linear(dim, mlp_hidden_dim + 3 * dim, bias=qkv_bias, **factory_kwargs)
        self.in_split = [mlp_hidden_dim, dim, dim, dim]
        if not qkv_bias:
            self.mlp_bias = nn.Parameter(torch.empty(mlp_hidden_dim, **factory_kwargs))
        else:
            # mlp_bias is combined with qkv_bias in in_proj.bias
            self.register_parameter('mlp_bias', None)

        if qk_norm:
            self.q_norm = norm_layer(head_dim, **factory_kwargs)
            self.k_norm = norm_layer(head_dim, **factory_kwargs)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)

        self.mlp_drop = nn.Dropout(proj_drop)
        self.mlp_act = act_layer()

        if fuse_out_proj:
            # Fused output projection for both attention and MLP
            self.out_proj = nn.Linear(dim + mlp_hidden_dim, dim, bias=proj_bias, **factory_kwargs)
            self.attn_out_proj = None
            self.mlp_out_proj = None
        else:
            # Separate output projections
            self.out_proj = None
            self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias, **factory_kwargs)
            self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim, bias=proj_bias, **factory_kwargs)

        if init_values is not None:
            self.ls = LayerScale(dim, init_values=init_values, **factory_kwargs)
        else:
            self.ls = nn.Identity()
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # TODO: skip init when on meta device when safe to do so
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize parameters and buffers."""
        # mlp_bias is created with torch.empty, so it must be zeroed explicitly.
        if self.mlp_bias is not None:
            nn.init.zeros_(self.mlp_bias)

    def forward(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            is_causal: bool = False,
    ) -> torch.Tensor:
        """Run the parallel attention + MLP update.

        Args:
            x: Input tensor, unpacked as ``(B, N, C)``.
            attn_mask: Optional attention mask.
            is_causal: Whether to request causal attention.

        Returns:
            Tensor with the combined branch output added to ``x``.
        """
        B, N, C = x.shape

        # Combined MLP fc1 & qkv projections
        projected = self.in_proj(self.in_norm(x))
        x_mlp, q, k, v = torch.split(projected, self.in_split, dim=-1)
        if self.mlp_bias is not None:
            x_mlp = x_mlp + self.mlp_bias

        # Dot product attention w/ qk norm; heads are moved ahead of the token dim
        q = self.q_norm(q.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
        k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
        v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        if self.fused_attn:
            x_attn = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
                is_causal=is_causal,
            )
        else:
            attn = (q * self.scale) @ k.transpose(-2, -1)
            attn_bias = resolve_self_attn_mask(N, attn, attn_mask, is_causal=is_causal)
            attn = maybe_add_mask(attn, attn_bias)
            attn = self.attn_drop(attn.softmax(dim=-1))
            x_attn = attn @ v
        x_attn = x_attn.transpose(1, 2).reshape(B, N, C)

        # MLP activation & dropout
        x_mlp = self.mlp_drop(self.mlp_act(x_mlp))

        # Output projection (fused or separate)
        if self.out_proj is not None:
            out = self.out_proj(torch.cat((x_attn, x_mlp), dim=-1))
        else:
            out = self.attn_out_proj(x_attn) + self.mlp_out_proj(x_mlp)

        # Add residual w/ drop path & layer scale applied
        return x + self.drop_path(self.ls(out))
- class DiffParallelScalingBlock(nn.Module):
- """ Parallel ViT block with Differential Attention (MLP & Attention in parallel).
- Combines the parallel MLP+Attention structure from 'Scaling Vision Transformers to
- 22 Billion Parameters' (https://arxiv.org/abs/2302.05442) with differential attention
- from 'Differential Transformer' (https://arxiv.org/abs/2410.05258).
- """
- fused_attn: Final[bool]
def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        scale_attn_norm: bool = False,
        scale_mlp_norm: bool = False,
        proj_bias: bool = True,
        proj_drop: float = 0.,
        attn_drop: float = 0.,
        init_values: Optional[float] = None,
        drop_path: float = 0.,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = LayerNorm,
        mlp_layer: Optional[Type[nn.Module]] = None,
        attn_layer: Optional[LayerType] = None,
        depth: int = 0,
        dual_lambda: bool = False,
        device=None,
        dtype=None,
) -> None:
    """Initialize DiffParallelScalingBlock.

    Args:
        dim: Number of input channels; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        mlp_ratio: Ratio of mlp hidden dim to embedding dim.
        qkv_bias: If True, the combined input projection has a bias. If False,
            a separate ``mlp_bias`` parameter is created for the MLP part only.
        qk_norm: If True, create ``norm_layer`` instances over ``head_dim`` for q and k.
        scale_attn_norm: Not supported; must be False.
        scale_mlp_norm: Not supported; must be False.
        proj_bias: If True, add bias to the fused output projection.
        proj_drop: Dropout rate of the ``mlp_drop`` module.
        attn_drop: Attention dropout rate (also stored as ``attn_drop_p``).
        init_values: Initial values for layer scale; ``None`` disables it.
        drop_path: Stochastic depth rate. DropPath is only created when > 0.
        act_layer: Activation layer for the MLP path.
        norm_layer: Normalization layer.
        mlp_layer: Not referenced by this constructor.
        attn_layer: Not referenced by this constructor.
        depth: Block index used to derive ``lambda_init`` via ``set_lambda_init``.
        dual_lambda: If True, create two scalar parameters (``lambda_a``,
            ``lambda_b``); otherwise create four ``head_dim``-sized vectors
            (``lambda_q1``, ``lambda_k1``, ``lambda_q2``, ``lambda_k2``).
        device: Forwarded to every parameter and sub-module created here.
        dtype: Forwarded to sub-modules; lambda parameters are always float32.
    """
    super().__init__()
    factory_kwargs = {'device': device, 'dtype': dtype}
    assert dim % num_heads == 0, 'dim should be divisible by num_heads'
    assert not (scale_attn_norm or scale_mlp_norm), 'Scale norms not supported'

    head_dim = dim // num_heads // 2  # Half head_dim for diff attention
    self.num_heads = num_heads
    self.head_dim = head_dim
    self.scale = head_dim ** -0.5
    self.fused_attn = use_fused_attn()

    # Combined input projection: [mlp hidden | q | k | v]
    mlp_hidden_dim = int(mlp_ratio * dim)
    self.in_norm = norm_layer(dim, **factory_kwargs)
    self.in_proj = nn.Linear(dim, mlp_hidden_dim + 3 * dim, bias=qkv_bias, **factory_kwargs)
    self.in_split = [mlp_hidden_dim, dim, dim, dim]
    if not qkv_bias:
        self.mlp_bias = nn.Parameter(torch.empty(mlp_hidden_dim, **factory_kwargs))
    else:
        # mlp_bias is combined with qkv_bias in in_proj.bias
        self.register_parameter('mlp_bias', None)

    if qk_norm:
        self.q_norm = norm_layer(head_dim, **factory_kwargs)
        self.k_norm = norm_layer(head_dim, **factory_kwargs)
    else:
        self.q_norm = nn.Identity()
        self.k_norm = nn.Identity()
    self.attn_drop = nn.Dropout(attn_drop)
    self.attn_drop_p = attn_drop

    # Differential attention specific
    self.sub_norm = RmsNorm(2 * head_dim, eps=1e-5, **factory_kwargs)
    self.dual_lambda = dual_lambda

    def _new_lambda(*shape: int) -> nn.Parameter:
        # Lambda parameters are kept in float32 regardless of the requested dtype.
        return nn.Parameter(torch.empty(shape, dtype=torch.float32, device=device))

    if dual_lambda:
        self.lambda_a = _new_lambda()
        self.lambda_b = _new_lambda()
        self.lambda_q1 = self.lambda_k1 = self.lambda_q2 = self.lambda_k2 = None
    else:
        self.lambda_a = self.lambda_b = None
        self.lambda_q1 = _new_lambda(head_dim)
        self.lambda_k1 = _new_lambda(head_dim)
        self.lambda_q2 = _new_lambda(head_dim)
        self.lambda_k2 = _new_lambda(head_dim)

    self.mlp_drop = nn.Dropout(proj_drop)
    self.mlp_act = act_layer()

    # Fused output projection for both attention and MLP
    self.out_proj = nn.Linear(dim + mlp_hidden_dim, dim, bias=proj_bias, **factory_kwargs)

    if init_values is not None:
        self.ls = LayerScale(dim, init_values=init_values, **factory_kwargs)
    else:
        self.ls = nn.Identity()
    if drop_path > 0.:
        self.drop_path = DropPath(drop_path)
    else:
        self.drop_path = nn.Identity()

    # Placeholder value, immediately replaced by the depth-dependent one.
    self.lambda_init = 0.8
    self.set_lambda_init(depth)

    # TODO: skip init when on meta device when safe to do so
    self.reset_parameters()
- def set_lambda_init(self, depth: int):
- self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)
- def reset_parameters(self) -> None:
- """Initialize parameters and buffers."""
- if self.mlp_bias is not None:
- nn.init.zeros_(self.mlp_bias)
- if self.dual_lambda:
- nn.init.zeros_(self.lambda_a)
- nn.init.zeros_(self.lambda_b)
- else:
- nn.init.normal_(self.lambda_q1, mean=0, std=0.1)
- nn.init.normal_(self.lambda_k1, mean=0, std=0.1)
- nn.init.normal_(self.lambda_q2, mean=0, std=0.1)
- nn.init.normal_(self.lambda_k2, mean=0, std=0.1)
- def _compute_lambda(self) -> torch.Tensor:
- if self.lambda_a is not None:
- lambda_1 = torch.exp(self.lambda_a)
- lambda_2 = torch.exp(self.lambda_b)
- else:
- lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float())
- lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float())
- return lambda_1 - lambda_2 + self.lambda_init
- def forward(
- self,
- x: torch.Tensor,
- attn_mask: Optional[torch.Tensor] = None,
- is_causal: bool = False,
- ) -> torch.Tensor:
- B, N, C = x.shape
- # Combined MLP fc1 & qkv projections
- y = self.in_norm(x)
- y = self.in_proj(y)
- x_mlp, q, k, v = torch.split(y, self.in_split, dim=-1)
- if self.mlp_bias is not None:
- x_mlp = x_mlp + self.mlp_bias
- # Reshape for differential attention (2x heads with half head_dim for q/k)
- q = q.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
- k = k.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
- v = v.reshape(B, N, self.num_heads, 2 * self.head_dim).transpose(1, 2)
- q, k = self.q_norm(q), self.k_norm(k)
- lambda_full = self._compute_lambda().type_as(q)
- if self.fused_attn:
- q = q.reshape(B, self.num_heads, 2, N, self.head_dim)
- k = k.reshape(B, self.num_heads, 2, N, self.head_dim)
- q1, q2 = q.unbind(2)
- k1, k2 = k.unbind(2)
- dropout_p = self.attn_drop_p if self.training else 0.0
- attn1 = F.scaled_dot_product_attention(q1, k1, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
- attn2 = F.scaled_dot_product_attention(q2, k2, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
- x_attn = attn1 - lambda_full * attn2
- else:
- q = q * self.scale
- attn = q @ k.transpose(-2, -1)
- attn_bias = resolve_self_attn_mask(N, attn, attn_mask, is_causal=is_causal)
- attn = maybe_add_mask(attn, attn_bias)
- attn = attn.softmax(dim=-1)
- attn = self.attn_drop(attn)
- attn = attn.view(B, self.num_heads, 2, N, N)
- attn = attn[:, :, 0] - lambda_full * attn[:, :, 1]
- x_attn = attn @ v
- x_attn = self.sub_norm(x_attn)
- x_attn = x_attn * (1 - self.lambda_init)
- x_attn = x_attn.transpose(1, 2).reshape(B, N, C)
- # MLP activation & dropout
- x_mlp = self.mlp_act(x_mlp)
- x_mlp = self.mlp_drop(x_mlp)
- # Fused output projection
- y = self.out_proj(torch.cat((x_attn, x_mlp), dim=-1))
- # Add residual w/ drop path & layer scale applied
- x = x + self.drop_path(self.ls(y))
- return x
class ParallelThingsBlock(nn.Module):
    """ Parallel ViT block (N parallel attention followed by N parallel MLP)

    The outputs of the parallel attention branches are summed into the residual stream,
    then the outputs of the parallel MLP branches are summed into it.

    Based on:
      `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
    """
    def __init__(
            self,
            dim: int,
            num_heads: int,
            num_parallel: int = 2,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            scale_attn_norm: bool = False,
            scale_mlp_norm: bool = False,
            proj_bias: bool = True,
            init_values: Optional[float] = None,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            drop_path: float = 0.,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = LayerNorm,
            mlp_layer: Type[nn.Module] = Mlp,
            attn_layer: LayerType = Attention,
            depth: int = 0,
            device=None,
            dtype=None,
    ) -> None:
        """
        Args:
            dim: Number of channels.
            num_heads: Number of attention heads.
            num_parallel: Number of parallel attention branches and of parallel MLP branches.
            mlp_ratio: Ratio of MLP hidden dim to ``dim``.
            qkv_bias: Passed to the attention layer.
            qk_norm: Passed to the attention layer.
            scale_attn_norm: Passed to the attention layer as ``scale_norm``.
            scale_mlp_norm: If True, ``norm_layer`` is passed to the MLP layer, else None.
            proj_bias: Passed to the attention layer and to the MLP layer (as ``bias``).
            init_values: Layer-scale init value; layer-scale is used only if this is truthy.
            proj_drop: Passed to the attention layer and to the MLP layer (as ``drop``).
            attn_drop: Passed to the attention layer.
            drop_path: Stochastic depth rate (applied only if > 0).
            act_layer: MLP activation layer.
            norm_layer: Normalization layer.
            mlp_layer: MLP layer.
            attn_layer: Attention layer.
            depth: Passed to the attention layer.
            device: Device passed to created sub-modules.
            dtype: Dtype passed to created sub-modules.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.num_parallel = num_parallel
        self.attns = nn.ModuleList()
        self.ffns = nn.ModuleList()

        mlp_hidden_features = int(dim * mlp_ratio)
        mlp_norm_layer = norm_layer if scale_mlp_norm else None

        for _ in range(num_parallel):
            # Attention branch: norm -> attn -> layer scale -> drop path
            branch_norm = norm_layer(dim, **factory_kwargs)
            branch_attn = _create_attn(
                attn_layer,
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                scale_norm=scale_attn_norm,
                proj_bias=proj_bias,
                attn_drop=attn_drop,
                proj_drop=proj_drop,
                norm_layer=norm_layer,
                depth=depth,
                **factory_kwargs,
            )
            branch_ls = LayerScale(dim, init_values=init_values, **factory_kwargs) if init_values else nn.Identity()
            branch_dp = DropPath(drop_path) if drop_path > 0. else nn.Identity()
            self.attns.append(nn.Sequential(OrderedDict([
                ('norm', branch_norm),
                ('attn', branch_attn),
                ('ls', branch_ls),
                ('drop_path', branch_dp),
            ])))

            # MLP branch: norm -> mlp -> layer scale -> drop path
            branch_norm = norm_layer(dim, **factory_kwargs)
            branch_mlp = mlp_layer(
                dim,
                hidden_features=mlp_hidden_features,
                act_layer=act_layer,
                norm_layer=mlp_norm_layer,
                bias=proj_bias,
                drop=proj_drop,
                **factory_kwargs,
            )
            branch_ls = LayerScale(dim, init_values=init_values, **factory_kwargs) if init_values else nn.Identity()
            branch_dp = DropPath(drop_path) if drop_path > 0. else nn.Identity()
            self.ffns.append(nn.Sequential(OrderedDict([
                ('norm', branch_norm),
                ('mlp', branch_mlp),
                ('ls', branch_ls),
                ('drop_path', branch_dp),
            ])))

    def forward(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            is_causal: bool = False,
    ) -> torch.Tensor:
        """Apply the parallel attention branches, then the parallel MLP branches.

        Args:
            x: Input tensor.
            attn_mask: Optional attention mask passed to each attention layer.
            is_causal: If True, request causal masking from each attention layer.

        Returns:
            The input with both residual updates added.
        """
        if attn_mask is None and not is_causal:
            # No mask arguments to thread through, each branch runs as a plain Sequential
            x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
        else:
            # Run the branch stages by hand so the mask arguments reach the attention layer
            branch_outputs = []
            for attn in self.attns:
                y = attn.norm(x)
                y = attn.attn(y, attn_mask=attn_mask, is_causal=is_causal)
                y = attn.drop_path(attn.ls(y))
                branch_outputs.append(y)
            x = x + torch.stack(branch_outputs).sum(dim=0)
        x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
        return x
def global_pool_nlc(
        x: torch.Tensor,
        pool_type: str = 'token',
        num_prefix_tokens: int = 1,
        reduce_include_prefix: bool = False,
):
    """Pool a token sequence along dim 1.

    Args:
        x: Token tensor; dim 1 is indexed as the token dimension.
        pool_type: One of '' (no pooling), 'token', 'avg', 'avgmax' or 'max'.
        num_prefix_tokens: Number of leading tokens excluded from 'avg' / 'avgmax' / 'max'
            reductions unless ``reduce_include_prefix`` is set.
        reduce_include_prefix: Keep the prefix tokens in the reduction.

    Returns:
        ``x`` unchanged when ``pool_type`` is empty, the first token for 'token',
        otherwise the reduction over dim 1.

    Raises:
        AssertionError: If ``pool_type`` is not one of the supported values.
    """
    if not pool_type:
        return x

    if pool_type == 'token':
        x = x[:, 0]  # class token
    else:
        x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
        if pool_type == 'avg':
            x = x.mean(dim=1)
        elif pool_type == 'avgmax':
            x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
        elif pool_type == 'max':
            x = x.amax(dim=1)
        else:
            # Raised explicitly (same exception type as before) so the check survives
            # `python -O`, where an `assert` would be stripped and unpooled tokens returned.
            raise AssertionError(f'Unknown pool type {pool_type}')

    return x
- class VisionTransformer(nn.Module):
- """ Vision Transformer
- A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- - https://arxiv.org/abs/2010.11929
- """
- dynamic_img_size: Final[bool]
def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        num_classes: int = 1000,
        global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map', 'prr'] = 'token',
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        scale_attn_norm: bool = False,
        scale_mlp_norm: bool = False,
        proj_bias: bool = True,
        init_values: Optional[float] = None,
        class_token: bool = True,
        pos_embed: str = 'learn',
        no_embed_class: bool = False,
        reg_tokens: int = 0,
        pre_norm: bool = False,
        final_norm: bool = True,
        fc_norm: Optional[bool] = None,
        pool_include_prefix: bool = False,
        dynamic_img_size: bool = False,
        dynamic_img_pad: bool = False,
        drop_rate: float = 0.,
        pos_drop_rate: float = 0.,
        patch_drop_rate: float = 0.,
        proj_drop_rate: float = 0.,
        attn_drop_rate: float = 0.,
        drop_path_rate: float = 0.,
        weight_init: Literal['skip', 'reset', 'jax', 'jax_nlhb', 'moco', ''] = '',
        fix_init: bool = False,
        embed_layer: Callable = PatchEmbed,
        embed_norm_layer: Optional[LayerType] = None,
        norm_layer: Optional[LayerType] = None,
        act_layer: Optional[LayerType] = None,
        block_fn: Type[nn.Module] = Block,
        mlp_layer: Type[nn.Module] = Mlp,
        attn_layer: LayerType = Attention,
        device=None,
        dtype=None,
) -> None:
    """
    Args:
        img_size: Input image size.
        patch_size: Patch size.
        in_chans: Number of image input channels.
        num_classes: Number of classes for classification head.
        global_pool: Type of global pooling for final sequence (default: 'token').
        embed_dim: Transformer embedding dimension.
        depth: Depth of transformer.
        num_heads: Number of attention heads.
        mlp_ratio: Ratio of mlp hidden dim to embedding dim.
        qkv_bias: Enable bias for qkv projections if True.
        qk_norm: Passed through to every block.
        scale_attn_norm: Passed through to every block.
        scale_mlp_norm: Passed through to every block.
        proj_bias: Passed through to every block.
        init_values: Layer-scale init values (layer-scale enabled if not None).
        class_token: Use class token.
        pos_embed: Position embedding type, 'learn' for a learned parameter, '' or 'none' to disable.
        no_embed_class: Don't include position embeddings for class (or reg) tokens.
        reg_tokens: Number of register tokens.
        pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT).
        final_norm: Enable norm after transformer blocks, before head (standard in most ViT).
        fc_norm: Move final norm after pool (instead of before), if None, enabled when global_pool
            is one of 'avg', 'avgmax', 'max'.
        pool_include_prefix: Include prefix tokens when pooling (forced on for 'prr' pooling).
        dynamic_img_size: Resample position embeddings to the input grid on every forward pass.
        dynamic_img_pad: Passed through to the patch embedding layer.
        drop_rate: Head dropout rate.
        pos_drop_rate: Position embedding dropout rate.
        patch_drop_rate: Patch dropout rate, patch dropout is used only if > 0.
        proj_drop_rate: Passed to every block as ``proj_drop``.
        attn_drop_rate: Attention dropout rate.
        drop_path_rate: Stochastic depth rate.
        weight_init: Weight initialization scheme, 'skip' leaves weights uninitialized here.
        fix_init: Apply weight initialization fix (scaling w/ layer index).
        embed_layer: Patch embedding layer.
        embed_norm_layer: Normalization layer to use / override in patch embed module.
        norm_layer: Normalization layer.
        act_layer: MLP activation layer.
        block_fn: Transformer block layer.
        mlp_layer: Passed through to every block.
        attn_layer: Passed through to every block.
        device: Device passed to created parameters and sub-modules.
        dtype: Dtype passed to created parameters and sub-modules.
    """
    super().__init__()
    factory_kwargs = {'device': device, 'dtype': dtype}
    assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map', 'prr')
    assert class_token or global_pool != 'token'
    assert pos_embed in ('', 'none', 'learn')
    if fc_norm is None:
        use_fc_norm = global_pool in ('avg', 'avgmax', 'max')
    else:
        use_fc_norm = fc_norm
    norm_layer = get_norm_layer(norm_layer) or LayerNorm
    embed_norm_layer = get_norm_layer(embed_norm_layer)
    act_layer = get_act_layer(act_layer) or nn.GELU

    # Plain configuration attributes
    self.num_classes = num_classes
    self.in_chans = in_chans
    self.global_pool = global_pool
    # the three names below are kept for consistency with other models
    self.embed_dim = embed_dim
    self.head_hidden_size = embed_dim
    self.num_features = embed_dim
    self.num_prefix_tokens = (1 if class_token else 0) + reg_tokens
    self.num_reg_tokens = reg_tokens
    self.has_class_token = class_token
    self.no_embed_class = no_embed_class
    self.pool_include_prefix = pool_include_prefix
    self.dynamic_img_size = dynamic_img_size
    self.grad_checkpointing = False

    # Patch embedding
    patch_embed_kwargs = {}
    if dynamic_img_size:
        # flatten deferred until after pos embed
        patch_embed_kwargs['strict_img_size'] = False
        patch_embed_kwargs['output_fmt'] = 'NHWC'
    if embed_norm_layer is not None:
        patch_embed_kwargs['norm_layer'] = embed_norm_layer
    self.patch_embed = embed_layer(
        img_size=img_size,
        patch_size=patch_size,
        in_chans=in_chans,
        embed_dim=embed_dim,
        bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
        dynamic_img_pad=dynamic_img_pad,
        **patch_embed_kwargs,
        **factory_kwargs,
    )
    num_patches = self.patch_embed.num_patches
    if hasattr(self.patch_embed, 'feat_ratio'):
        reduction = self.patch_embed.feat_ratio()
    else:
        reduction = patch_size

    # Prefix tokens and position embedding
    self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, **factory_kwargs)) if class_token else None
    self.reg_token = nn.Parameter(torch.empty(1, reg_tokens, embed_dim, **factory_kwargs)) if reg_tokens else None
    if pos_embed and pos_embed != 'none':
        pos_embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = nn.Parameter(torch.empty(1, pos_embed_len, embed_dim, **factory_kwargs))
    else:
        self.pos_embed = None
    self.pos_drop = nn.Dropout(p=pos_drop_rate)
    self.patch_drop = (
        PatchDropout(patch_drop_rate, num_prefix_tokens=self.num_prefix_tokens)
        if patch_drop_rate > 0 else nn.Identity()
    )
    self.norm_pre = norm_layer(embed_dim, **factory_kwargs) if pre_norm else nn.Identity()

    # Transformer blocks, one drop-path rate per block (stochastic depth decay rule)
    drop_path_rates = calculate_drop_path_rates(drop_path_rate, depth)
    block_list = []
    feature_info = []
    for block_idx in range(depth):
        block_list.append(block_fn(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            scale_attn_norm=scale_attn_norm,
            scale_mlp_norm=scale_mlp_norm,
            proj_bias=proj_bias,
            init_values=init_values,
            proj_drop=proj_drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=drop_path_rates[block_idx],
            norm_layer=norm_layer,
            act_layer=act_layer,
            mlp_layer=mlp_layer,
            attn_layer=attn_layer,
            depth=block_idx,
            **factory_kwargs,
        ))
        feature_info.append(dict(module=f'blocks.{block_idx}', num_chs=embed_dim, reduction=reduction))
    self.blocks = nn.Sequential(*block_list)
    self.feature_info = feature_info
    self.norm = norm_layer(embed_dim, **factory_kwargs) if final_norm and not use_fc_norm else nn.Identity()

    # Classifier Head
    if global_pool == 'map':
        self.attn_pool = AttentionPoolLatent(
            self.embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            act_layer=act_layer,
            **factory_kwargs,
        )
    elif global_pool == 'prr':
        self.attn_pool = AttentionPoolPrr(
            self.embed_dim,
            num_heads=num_heads,
            pool_type='token' if class_token else 'avg',
            norm_layer=norm_layer,
            **factory_kwargs,
        )
        self.pool_include_prefix = True
    else:
        self.attn_pool = None
    self.fc_norm = norm_layer(embed_dim, **factory_kwargs) if final_norm and use_fc_norm else nn.Identity()
    self.head_drop = nn.Dropout(drop_rate)
    self.head = nn.Linear(self.embed_dim, num_classes, **factory_kwargs) if num_classes > 0 else nn.Identity()

    # Weight init: 'skip' defers initialization and records 'reset' as the mode for later calls
    skip_init = weight_init == 'skip'
    self.weight_init_mode = 'reset' if skip_init else weight_init
    self.fix_init = fix_init
    # TODO: skip init when on meta device when safe to do so
    if not skip_init:
        self.init_weights(needs_reset=False)
def fix_init_weight(self) -> None:
    """Divide each block's attn.proj and mlp.fc2 weights by sqrt(2 * (block index + 1)), in place."""
    with torch.no_grad():
        for depth_pos, blk in enumerate(self.blocks, start=1):
            divisor = math.sqrt(2.0 * depth_pos)
            blk.attn.proj.weight.div_(divisor)
            blk.mlp.fc2.weight.div_(divisor)
def init_weights(self, mode: str = '', needs_reset: bool = True) -> None:
    """Initialize model weights.

    Args:
        mode: Weight initialization mode ('jax', 'jax_nlhb', 'moco', 'reset', or '').
            An empty value falls back to ``self.weight_init_mode``.
        needs_reset: If True, call reset_parameters() on modules that have it.
            Set to False when modules have already self-initialized in __init__.
    """
    mode = mode or self.weight_init_mode
    assert mode in ('jax', 'jax_nlhb', 'moco', 'reset', '')
    # math.log(0) raises ValueError. With no classes the head is an nn.Identity,
    # so there is no head bias to set and 0. is a safe placeholder.
    if 'nlhb' in mode and self.num_classes > 0:
        head_bias = -math.log(self.num_classes)
    else:
        head_bias = 0.
    if self.pos_embed is not None:
        trunc_normal_(self.pos_embed, std=.02)
    if self.cls_token is not None:
        nn.init.normal_(self.cls_token, std=1e-6)
    if self.reg_token is not None:
        nn.init.normal_(self.reg_token, std=1e-6)

    named_apply(get_init_weights_vit(mode, head_bias, needs_reset=needs_reset), self)

    if self.fix_init:
        self.fix_init_weight()
def _init_weights(self, m: nn.Module) -> None:
    """Apply the original timm ViT init to one module (kept for downstream compatibility)."""
    # this fn left here for compat with downstream users
    init_weights_vit_timm(module=m)
@torch.jit.ignore()
def load_pretrained(self, checkpoint_path: str, prefix: str = '') -> None:
    """Load pretrained weights into this model via ``_load_weights``.

    Args:
        checkpoint_path: Path to checkpoint.
        prefix: Prefix for state dict keys.
    """
    _load_weights(self, checkpoint_path, prefix=prefix)
@torch.jit.ignore
def no_weight_decay(self) -> Set[str]:
    """Return the names of parameters that should be excluded from weight decay."""
    exempt_names = {'pos_embed', 'cls_token', 'dist_token'}
    return exempt_names
@torch.jit.ignore
def group_matcher(self, coarse: bool = False) -> Dict[str, Union[str, List]]:
    """Create regex patterns for parameter grouping.

    Args:
        coarse: Accepted for interface compatibility; the returned patterns do not depend on it.

    Returns:
        Dictionary mapping group names ('stem', 'blocks') to regex patterns.
    """
    stem_pattern = r'^cls_token|pos_embed|patch_embed'  # stem and embed
    block_patterns = [
        (r'^blocks\.(\d+)', None),
        (r'^norm', (99999,)),
    ]
    return {'stem': stem_pattern, 'blocks': block_patterns}
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True) -> None:
    """Enable or disable gradient checkpointing.

    The flag is also forwarded to the patch embedding module when it exposes
    a ``set_grad_checkpointing`` method.

    Args:
        enable: Whether to enable gradient checkpointing.
    """
    self.grad_checkpointing = enable
    if not hasattr(self.patch_embed, 'set_grad_checkpointing'):
        return
    self.patch_embed.set_grad_checkpointing(enable)
@torch.jit.ignore
def get_classifier(self) -> nn.Module:
    """Return the module stored as ``self.head``."""
    classifier = self.head
    return classifier
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
    """Reset the classifier head.

    All checks run before any attribute is changed, so a rejected call leaves the
    model untouched.

    Args:
        num_classes: Number of classes for new classifier; 0 replaces the head with nn.Identity.
        global_pool: Global pooling type, or None to keep the current one.

    Raises:
        AssertionError: If ``global_pool`` is unknown, or if it would require adding
            attention pooling or switching between attention pooling types.
    """
    if global_pool is not None:
        valid_pools = ('', 'avg', 'avgmax', 'max', 'token', 'map', 'prr')
        # Explicit raises (same exception type as the former asserts) so the checks
        # are not stripped when Python runs with -O.
        if global_pool not in valid_pools:
            raise AssertionError(f'Unknown global_pool {global_pool!r}, expected one of {valid_pools}.')
        wants_attn_pool = global_pool in ('map', 'prr')
        if wants_attn_pool and self.attn_pool is None:
            raise AssertionError("Cannot currently add attention pooling in reset_classifier().")
        if wants_attn_pool and self.global_pool != global_pool:
            raise AssertionError("Cannot currently change attention pooling type in reset_classifier().")
        if not wants_attn_pool and self.attn_pool is not None:
            self.attn_pool = None  # remove attention pooling
        self.global_pool = global_pool
    self.num_classes = num_classes
    self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def set_input_size(
        self,
        img_size: Optional[Tuple[int, int]] = None,
        patch_size: Optional[Tuple[int, int]] = None,
) -> None:
    """Update the input image resolution and patch size.

    The patch embedding is updated first; the position embedding is then resampled
    only if one exists and its token count no longer matches.

    Args:
        img_size: New input resolution, if None current resolution is used.
        patch_size: New patch size, if None existing patch size is used.
    """
    old_grid_size = self.patch_embed.grid_size
    self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
    if self.pos_embed is None:
        return

    num_prefix_tokens = 0 if self.no_embed_class else self.num_prefix_tokens
    expected_tokens = self.patch_embed.num_patches + num_prefix_tokens
    if expected_tokens == self.pos_embed.shape[1]:
        return

    resampled = resample_abs_pos_embed(
        self.pos_embed,
        new_size=self.patch_embed.grid_size,
        old_size=old_grid_size,
        num_prefix_tokens=num_prefix_tokens,
        verbose=True,
    )
    self.pos_embed = nn.Parameter(resampled)
- def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
- """Apply positional embedding to input."""
- to_cat = []
- if self.cls_token is not None:
- to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
- if self.reg_token is not None:
- to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
- if self.pos_embed is None:
- return torch.cat(to_cat + [x.view(x.shape[0], -1, x.shape[-1])], dim=1)
- if self.dynamic_img_size:
- B, H, W, C = x.shape
- prev_grid_size = self.patch_embed.grid_size
- pos_embed = resample_abs_pos_embed(
- self.pos_embed,
- new_size=(H, W),
- old_size=prev_grid_size,
- num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
- )
- x = x.view(B, -1, C)
- else:
- pos_embed = self.pos_embed
- if self.no_embed_class:
- # deit-3, updated JAX (big vision)
- # position embedding does not overlap with class token, add then concat
- x = x + pos_embed
- if to_cat:
- x = torch.cat(to_cat + [x], dim=1)
- else:
- # original timm, JAX, and deit vit impl
- # pos_embed has entry for class token, concat then add
- if to_cat:
- x = torch.cat(to_cat + [x], dim=1)
- x = x + pos_embed
- return self.pos_drop(x)
def forward_intermediates(
        self,
        x: torch.Tensor,
        indices: Optional[Union[int, List[int]]] = None,
        return_prefix_tokens: bool = False,
        norm: bool = False,
        stop_early: bool = False,
        output_fmt: str = 'NCHW',
        intermediates_only: bool = False,
        output_dict: bool = False,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]:
    """ Forward features that returns intermediates.

    Args:
        x: Input image tensor
        indices: Take last n blocks if int, all if None, select matching indices if sequence
        return_prefix_tokens: Return both prefix and spatial intermediate tokens
        norm: Apply norm layer to all intermediates
        stop_early: Stop iterating over blocks when last desired intermediate hit
        output_fmt: Shape of intermediate feature outputs
        intermediates_only: Only return intermediate features
        output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys
        attn_mask: Optional attention mask for masked attention (e.g., for NaFlex)
        is_causal: If True, use causal (autoregressive) masking in attention

    Returns:
        A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing
        'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix')
    """
    assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
    want_nchw = output_fmt == 'NCHW'
    take_indices, max_index = feature_take_indices(len(self.blocks), indices)

    # forward pass through the stem
    B, _, height, width = x.shape
    x = self.patch_embed(x)
    x = self._pos_embed(x)
    x = self.patch_drop(x)
    x = self.norm_pre(x)

    if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
        blocks = self.blocks
    else:
        blocks = self.blocks[:max_index + 1]

    intermediates = []
    for block_idx, blk in enumerate(blocks):
        if attn_mask is not None or is_causal:
            x = blk(x, attn_mask=attn_mask, is_causal=is_causal)
        elif self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint(blk, x)
        else:
            x = blk(x)
        if block_idx in take_indices:
            # normalize intermediates with final norm layer if enabled
            intermediates.append(self.norm(x) if norm else x)

    # split prefix (e.g. class, distill) and spatial feature tokens
    prefix_tokens = None
    if self.num_prefix_tokens:
        num_prefix = self.num_prefix_tokens
        prefix_tokens = [y[:, 0:num_prefix] for y in intermediates]
        intermediates = [y[:, num_prefix:] for y in intermediates]

    if want_nchw:
        # reshape to BCHW output format
        H, W = self.patch_embed.dynamic_feat_size((height, width))
        intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]

    # For dictionary output, prefix tokens get their own key
    if output_dict:
        # Intermediates are always included
        result_dict = {'image_intermediates': intermediates}
        if return_prefix_tokens and prefix_tokens is not None:
            result_dict['image_intermediates_prefix'] = prefix_tokens
        # Only include features if not intermediates_only
        if not intermediates_only:
            result_dict['image_features'] = self.norm(x)
        return result_dict

    # For non-dictionary output, maintain the original behavior
    if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None:
        # return_prefix not support in torchscript due to poor type handling
        intermediates = list(zip(intermediates, prefix_tokens))

    if intermediates_only:
        return intermediates

    return self.norm(x), intermediates
def prune_intermediate_layers(
        self,
        indices: Union[int, List[int]] = 1,
        prune_norm: bool = False,
        prune_head: bool = True,
) -> List[int]:
    """Prune layers not required for specified intermediates.

    Args:
        indices: Indices of intermediate layers to keep.
        prune_norm: Whether to prune normalization layer.
        prune_head: Whether to prune the classifier head (also replaces fc_norm with nn.Identity).

    Returns:
        List of indices that were kept.
    """
    take_indices, max_index = feature_take_indices(len(self.blocks), indices)
    num_kept_blocks = max_index + 1
    self.blocks = self.blocks[:num_kept_blocks]  # truncate blocks
    if prune_norm:
        self.norm = nn.Identity()
    if not prune_head:
        return take_indices
    self.fc_norm = nn.Identity()
    self.reset_classifier(0, '')
    return take_indices
def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, List[int], Tuple[int]] = 1,
        reshape: bool = False,
        return_prefix_tokens: bool = False,
        norm: bool = False,
        attn_mask: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
    """Get intermediate layer outputs (DINO interface compatibility).

    NOTE: This API is for backwards compat, favour using forward_intermediates() directly.

    Args:
        x: Input tensor.
        n: Number or indices of layers.
        reshape: Reshape to NCHW format (otherwise 'NLC' is requested).
        return_prefix_tokens: Return prefix tokens.
        norm: Apply normalization.
        attn_mask: Optional attention mask, forwarded unchanged.

    Returns:
        List of intermediate features.
    """
    output_fmt = 'NCHW' if reshape else 'NLC'
    return self.forward_intermediates(
        x, n,
        return_prefix_tokens=return_prefix_tokens,
        norm=norm,
        output_fmt=output_fmt,
        intermediates_only=True,
        attn_mask=attn_mask,
    )
def forward_features(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
) -> torch.Tensor:
    """Forward pass through feature layers (embeddings, transformer blocks, post-transformer norm).

    Args:
        x: Input tensor.
        attn_mask: Optional attention mask passed to every block.
        is_causal: If True, request causal masking from every block.

    Returns:
        Output of the final norm layer.
    """
    x = self.norm_pre(self.patch_drop(self._pos_embed(self.patch_embed(x))))

    if attn_mask is not None or is_causal:
        # If mask/causal provided, we need to apply blocks one by one
        for blk in self.blocks:
            x = blk(x, attn_mask=attn_mask, is_causal=is_causal)
    elif self.grad_checkpointing and not torch.jit.is_scripting():
        x = checkpoint_seq(self.blocks, x)
    else:
        x = self.blocks(x)

    return self.norm(x)
def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
    """Apply pooling to feature tokens.

    When an attention pooling module is present it is always used and ``pool_type`` is ignored.

    Args:
        x: Feature tensor.
        pool_type: Pooling type override, defaults to ``self.global_pool`` when None.

    Returns:
        Pooled features.
    """
    if self.attn_pool is None:
        resolved_type = self.global_pool if pool_type is None else pool_type
        return global_pool_nlc(
            x,
            pool_type=resolved_type,
            num_prefix_tokens=self.num_prefix_tokens,
            reduce_include_prefix=self.pool_include_prefix,
        )

    if not self.pool_include_prefix:
        # drop the prefix tokens before attention pooling
        x = x[:, self.num_prefix_tokens:]
    return self.attn_pool(x)
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
    """Forward pass through classifier head.

    Applies pooling, fc_norm and head dropout, then the classifier unless ``pre_logits`` is set.

    Args:
        x: Feature tensor.
        pre_logits: Return features before final classifier.

    Returns:
        Output tensor.
    """
    pooled = self.pool(x)
    features = self.head_drop(self.fc_norm(pooled))
    if pre_logits:
        return features
    return self.head(features)
def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
) -> torch.Tensor:
    """Run the feature layers, then the classifier head.

    Args:
        x: Input tensor.
        attn_mask: Optional attention mask, forwarded to forward_features().
        is_causal: Causal-masking flag, forwarded to forward_features().

    Returns:
        Output of forward_head().
    """
    features = self.forward_features(x, attn_mask=attn_mask, is_causal=is_causal)
    return self.forward_head(features)
def init_weights_vit_timm(module: nn.Module, name: str = '', needs_reset: bool = True) -> None:
    """ViT weight initialization, original timm impl (for reproducibility).

    nn.Linear modules get truncated-normal weights (std .02) and zero bias. Any other
    module is initialized through its own ``init_weights()`` if it has one, otherwise
    through ``reset_parameters()`` when ``needs_reset`` is set.

    Args:
        module: Module to initialize.
        name: Module name for context.
        needs_reset: If True, call reset_parameters() on modules that have it.
    """
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        return

    if hasattr(module, 'init_weights'):
        module.init_weights()
        return

    if needs_reset and hasattr(module, 'reset_parameters'):
        module.reset_parameters()
def init_weights_vit_jax(
        module: nn.Module,
        name: str = '',
        head_bias: float = 0.0,
        needs_reset: bool = True,
) -> None:
    """ViT weight initialization, matching JAX (Flax) impl.

    nn.Linear modules whose name starts with 'head' get zero weights and a constant bias.
    Other nn.Linear modules get Xavier-uniform weights and a bias that is near-zero normal
    (std 1e-6) when 'mlp' is in the name, zero otherwise. nn.Conv2d modules get LeCun-normal
    weights and zero bias. Any other module falls back to its own ``init_weights()`` or,
    when ``needs_reset`` is set, ``reset_parameters()``.

    Args:
        module: Module to initialize.
        name: Module name for context.
        head_bias: Bias value for head layer.
        needs_reset: If True, call reset_parameters() on modules that have it.
    """
    if isinstance(module, nn.Linear):
        if name.startswith('head'):
            nn.init.zeros_(module.weight)
            # A head built with bias=False has no bias tensor to fill.
            if module.bias is not None:
                nn.init.constant_(module.bias, head_bias)
        else:
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                if 'mlp' in name:
                    nn.init.normal_(module.bias, std=1e-6)
                else:
                    nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        module.init_weights()
    elif needs_reset and hasattr(module, 'reset_parameters'):
        module.reset_parameters()
def init_weights_vit_moco(module: nn.Module, name: str = '', needs_reset: bool = True) -> None:
    """ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed.

    nn.Linear modules get uniform weights and zero bias; the uniform bound for a module with
    'qkv' in its name is computed from a third of its output rows. Any other module falls
    back to its own ``init_weights()`` or, when ``needs_reset`` is set, ``reset_parameters()``.

    Args:
        module: Module to initialize.
        name: Module name for context.
        needs_reset: If True, call reset_parameters() on modules that have it.
    """
    if not isinstance(module, nn.Linear):
        if hasattr(module, 'init_weights'):
            module.init_weights()
        elif needs_reset and hasattr(module, 'reset_parameters'):
            module.reset_parameters()
        return

    weight = module.weight
    if 'qkv' in name:
        # treat the weights of Q, K, V separately
        rows_per_proj = weight.shape[0] // 3
        bound = math.sqrt(6. / float(rows_per_proj + weight.shape[1]))
        nn.init.uniform_(weight, -bound, bound)
    else:
        nn.init.xavier_uniform_(weight)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
def init_weights_reset_parameters(module: nn.Module, name: str = '', needs_reset: bool = True) -> None:
    """Call ``module.reset_parameters()`` when requested and available, nothing else.

    Args:
        module: Module to initialize.
        name: Module name for context (unused).
        needs_reset: If False, the module is left untouched.
    """
    if not needs_reset or not hasattr(module, 'reset_parameters'):
        return
    module.reset_parameters()
def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0, needs_reset: bool = True) -> Callable:
    """Return the per-module init function for ``mode`` with its options already bound.

    Args:
        mode: Modes starting with 'jax' or 'moco' select the matching scheme, 'reset' only
            calls reset_parameters(), anything else selects the timm default.
        head_bias: Head bias value, only bound for the 'jax' schemes.
        needs_reset: Bound into the returned function.

    Returns:
        A ``functools.partial`` over the selected init function.
    """
    if mode.startswith('jax'):
        return partial(init_weights_vit_jax, head_bias=head_bias, needs_reset=needs_reset)

    if mode.startswith('moco'):
        init_fn = init_weights_vit_moco
    elif mode == 'reset':
        # 'reset' means only call reset_parameters() on modules
        init_fn = init_weights_reset_parameters
    else:
        # timm init is default
        init_fn = init_weights_vit_timm
    return partial(init_fn, needs_reset=needs_reset)
def resize_pos_embed(
        posemb: torch.Tensor,
        posemb_new: torch.Tensor,
        num_prefix_tokens: int = 1,
        gs_new: Tuple[int, int] = (),
        interpolation: str = 'bicubic',
        antialias: bool = False,
) -> torch.Tensor:
    """ Rescale the grid of position embeddings when loading from state_dict.

    *DEPRECATED* This function is being deprecated in favour of using resample_abs_pos_embed

    The old grid is taken to be square, with its side derived from the number of non-prefix
    tokens in ``posemb``. The new grid is derived the same way from ``posemb_new`` when
    ``gs_new`` is empty.
    """
    num_tokens_new = posemb_new.shape[1] - num_prefix_tokens
    num_tokens_old = posemb.shape[1] - num_prefix_tokens
    old_side = int(math.sqrt(num_tokens_old))
    gs_old = [old_side, old_side]
    if len(gs_new) == 0:  # backwards compatibility
        new_side = int(math.sqrt(num_tokens_new))
        gs_new = [new_side, new_side]
    return resample_abs_pos_embed(
        posemb, gs_new, gs_old,
        num_prefix_tokens=num_prefix_tokens,
        interpolation=interpolation,
        antialias=antialias,
        verbose=True,
    )
@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = '', load_bfloat16: bool = False) -> None:
    """ Load weights from .npz checkpoints for official Google Brain Flax implementation

    Checkpoint arrays are converted to torch tensors and copied in place into the
    parameters of ``model`` via ``Tensor.copy_``.

    Args:
        model: Model whose parameters are overwritten.
        checkpoint_path: Path handed to ``np.load`` (``jnp.load`` when ``load_bfloat16`` is True).
        prefix: Key prefix inside the checkpoint. When empty, it is inferred by probing
            for known ``.../embedding/kernel`` keys.
        load_bfloat16: Load through ``jax.numpy`` and reinterpret each array as bfloat16
            before casting it to float32.
    """
    import numpy as np
    if load_bfloat16:
        # Optional dependencies, only needed for bfloat16 checkpoints
        import jax.numpy as jnp
        import ml_dtypes

    def _n2p(_w, t=True, idx=None):
        # Convert one checkpoint array to a torch tensor.
        # idx: when given, index the array first (used below for stacked per-block arrays).
        # t:   when True, permute the axes of 2D/3D/4D arrays so the last axis comes first.
        if idx is not None:
            _w = _w[idx]

        if load_bfloat16:
            _w = _w.view(ml_dtypes.bfloat16).astype(jnp.float32)
            _w = np.array(_w)

        # 4D arrays whose first three dims are all 1 are collapsed to 1D
        if _w.ndim == 4 and _w.shape[0] == _w.shape[1] == _w.shape[2] == 1:
            _w = _w.flatten()
        if t:
            # NOTE(review): presumably Flax HWIO / (in, out) layouts -> PyTorch OIHW / (out, in); confirm
            if _w.ndim == 4:
                _w = _w.transpose([3, 2, 0, 1])
            elif _w.ndim == 3:
                _w = _w.transpose([2, 0, 1])
            elif _w.ndim == 2:
                _w = _w.transpose([1, 0])

        _w = torch.from_numpy(_w)
        return _w

    if load_bfloat16:
        w = jnp.load(checkpoint_path)
    else:
        w = np.load(checkpoint_path)

    # Settings used for any patch-embed / pos-embed resampling below
    interpolation = 'bilinear'
    antialias = False
    big_vision = False
    if not prefix:
        # Infer the key prefix (and the big_vision key layout) from which embedding kernel key exists
        if 'opt/target/embedding/kernel' in w:
            prefix = 'opt/target/'
        elif 'params/embedding/kernel' in w:
            prefix = 'params/'
            big_vision = True
        elif 'params/img/embedding/kernel' in w:
            prefix = 'params/img/'
            big_vision = True

    if hasattr(model.patch_embed, 'backbone'):
        # hybrid
        backbone = model.patch_embed.backbone
        # A backbone without a `stem` attribute is treated as being the stem itself
        stem_only = not hasattr(backbone, 'stem')
        stem = backbone if stem_only else backbone.stem
        stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
        stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
        stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
        if not stem_only:
            for i, stage in enumerate(backbone.stages):
                for j, block in enumerate(stage.blocks):
                    # Checkpoint keys number blocks and units from 1
                    bp = f'{prefix}block{i + 1}/unit{j + 1}/'
                    for r in range(3):
                        getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
                        getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
                        getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
                    if block.downsample is not None:
                        block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
                        block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
                        block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
    else:
        embed_conv_w = adapt_input_conv(
            model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
    # Resample the patch-embedding kernel when its spatial size differs from the model's
    if embed_conv_w.shape[-2:] != model.patch_embed.proj.weight.shape[-2:]:
        embed_conv_w = resample_patch_embed(
            embed_conv_w,
            model.patch_embed.proj.weight.shape[-2:],
            interpolation=interpolation,
            antialias=antialias,
            verbose=True,
        )

    model.patch_embed.proj.weight.copy_(embed_conv_w)
    model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
    if model.cls_token is not None:
        model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
    # The position embedding is stored under a different key in the big_vision layout
    if big_vision:
        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
    else:
        pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
    if pos_embed_w.shape != model.pos_embed.shape:
        num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
        pos_embed_w = resample_abs_pos_embed(  # resize pos embedding when different size from pretrained weights
            pos_embed_w,
            new_size=model.patch_embed.grid_size,
            num_prefix_tokens=num_prefix_tokens,
            interpolation=interpolation,
            antialias=antialias,
            verbose=True,
        )
    model.pos_embed.copy_(pos_embed_w)
    model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
    model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
    # Copy the classifier only when the checkpoint has one and its size matches the model head
    if (isinstance(model.head, nn.Linear) and
            f'{prefix}head/bias' in w and
            model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]):
        model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
        model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
    # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights
    # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
    # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
    # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
    if isinstance(model.attn_pool, AttentionPoolLatent):
        # Attention-pooling head, stored under the MAPHead_0 key prefix
        block_prefix = f'{prefix}MAPHead_0/'
        mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
        model.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
        # key and value kernels/biases are stored separately and fused into the single kv projection
        model.attn_pool.kv.weight.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
        model.attn_pool.kv.bias.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
        model.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
        model.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
        model.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
        model.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
        model.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
        model.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
        for r in range(2):
            getattr(model.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
            getattr(model.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))

    # Sub-module index suffixes used in the checkpoint keys differ between the two layouts
    mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
    for i, block in enumerate(model.blocks.children()):
        if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
            # Single shared 'encoderblock' entry: every array is indexed by the block number
            block_prefix = f'{prefix}Transformer/encoderblock/'
            idx = i
        else:
            # One 'encoderblock_{i}' entry per block
            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
            idx = None
        mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
        # query/key/value are stored separately and concatenated into the fused qkv projection
        block.attn.qkv.weight.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
        block.attn.qkv.bias.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
        for r in range(2):
            getattr(block.mlp, f'fc{r + 1}').weight.copy_(
                _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
            getattr(block.mlp, f'fc{r + 1}').bias.copy_(
                _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))
- def _convert_openai_clip(
- state_dict: Dict[str, torch.Tensor],
- model: VisionTransformer,
- prefix: str = 'visual.',
- ) -> Dict[str, torch.Tensor]:
- out_dict = {}
- swaps = [
- ('conv1', 'patch_embed.proj'),
- ('positional_embedding', 'pos_embed'),
- ('transformer.resblocks.', 'blocks.'),
- ('ln_pre', 'norm_pre'),
- ('ln_post', 'norm'),
- ('ln_', 'norm'),
- ('in_proj_', 'qkv.'),
- ('out_proj', 'proj'),
- ('mlp.c_fc', 'mlp.fc1'),
- ('mlp.c_proj', 'mlp.fc2'),
- ]
- for k, v in state_dict.items():
- if not k.startswith(prefix):
- continue
- k = k.replace(prefix, '')
- for sp in swaps:
- k = k.replace(sp[0], sp[1])
- if k == 'proj':
- k = 'head.weight'
- v = v.transpose(0, 1)
- out_dict['head.bias'] = torch.zeros(v.shape[0])
- elif k == 'class_embedding':
- k = 'cls_token'
- v = v.unsqueeze(0).unsqueeze(1)
- elif k == 'pos_embed':
- v = v.unsqueeze(0)
- out_dict[k] = v
- return out_dict
- def _convert_dinov2(
- state_dict: Dict[str, torch.Tensor],
- model: VisionTransformer,
- ) -> Dict[str, torch.Tensor]:
- import re
- out_dict = {}
- state_dict.pop("mask_token", None)
- if 'register_tokens' in state_dict:
- # convert dinov2 w/ registers to no_embed_class timm model (neither cls or reg tokens overlap pos embed)
- out_dict['reg_token'] = state_dict.pop('register_tokens')
- out_dict['cls_token'] = state_dict.pop('cls_token') + state_dict['pos_embed'][:, 0]
- out_dict['pos_embed'] = state_dict.pop('pos_embed')[:, 1:]
- for k, v in state_dict.items():
- if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
- out_dict[k.replace("w12", "fc1")] = v
- continue
- elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
- out_dict[k.replace("w3", "fc2")] = v
- continue
- out_dict[k] = v
- return out_dict
- def _convert_aimv2(
- state_dict: Dict[str, torch.Tensor],
- model: VisionTransformer,
- ) -> Dict[str, torch.Tensor]:
- out_dict = {}
- for k, v in state_dict.items():
- k = k.replace('norm_1', 'norm1')
- k = k.replace('norm_2', 'norm2')
- k = k.replace('preprocessor.patchifier.', 'patch_embed.')
- k = k.replace('preprocessor.pos_embed', 'pos_embed')
- k = k.replace('trunk.', '')
- k = k.replace('post_trunk_norm.', 'norm.')
- k = k.replace('mlp.fc1', 'mlp.fc1_g')
- k = k.replace('mlp.fc3', 'mlp.fc1_x')
- out_dict[k] = v
- return out_dict
- def _convert_beit3(state_dict: dict, model):
- """
- Turn a BEiT-3 checkpoint into a standard VisionTransformer state-dict.
- """
- import re
- state_dict = state_dict.get("model", state_dict) # unwrap if needed
- # Prune unused
- for k in ("beit3.text_embed.weight", "beit3.vision_embed.mask_token"):
- state_dict.pop(k, None)
- # Key renaming rules
- rules = [
- (r"beit3\.", ""),
- (r"vision_embed\.cls_token", "cls_token"),
- (r"vision_embed\.", "patch_embed."),
- (r"embed_positions\.", "pos_embed."),
- (r"encoder\.", ""),
- (r"layers\.", "blocks."),
- (r"ffn_layernorm\.", "norm."), (r"ffn\.", "mlp."),
- (r"self_attn_layer_norm\.", "norm1."), (r"self_attn\.", "attn."),
- (r"final_layer_norm\.", "norm2."),
- (r"inner_attn_ln", "norm"),
- (r"out_proj", "proj"),
- (r"\.A\.", "."),
- ]
- # First pass, rename keys
- tmp = {}
- for k, v in state_dict.items():
- if ".B." in k:
- continue # use branch-A only
- for old, new in rules:
- k = re.sub(old, new, k)
- if k == "pos_embed.weight":
- # strip first two positions, [1, N+1, D]
- tmp["pos_embed"] = v[2:].unsqueeze(0)
- else:
- tmp[k] = v
- # Second pass, fuse q, k, v
- out, buf = {}, {}
- pat = re.compile(r"blocks\.(\d+)\.attn\.(q|k|v)_proj\.(weight|bias)$")
- for k, v in tmp.items():
- m = pat.fullmatch(k)
- if not m: # anything not q/k/v -> copy through
- out[k] = v
- continue
- blk, which, kind = m.groups() # block idx, 'q'/'k'/'v', 'weight'/'bias'
- stash = buf.setdefault((blk, kind), {}) # Gather by block & param type
- stash[which] = v
- if len(stash) == 3: # Have q, k, v -> concatenate
- out[f"blocks.{blk}.attn.qkv.{kind}"] = torch.cat(
- [stash['q'], stash['k'], stash['v']], dim=0
- )
- return out
def checkpoint_filter_fn(
        state_dict: Dict[str, torch.Tensor],
        model: VisionTransformer,
        adapt_layer_scale: bool = False,
        interpolation: str = 'bicubic',
        antialias: bool = True,
) -> Dict[str, torch.Tensor]:
    """Adapt a checkpoint state dict so its keys and tensor shapes line up with ``model``.

    The checkpoint is unwrapped from optional ``'model'`` / ``'state_dict'`` containers,
    its origin is detected from marker keys and converted or prefix-stripped
    accordingly, and then each entry is adjusted: patch-embedding weights are
    reshaped/resampled to the model's kernel size, ``pos_embed`` is resampled when its
    token count differs, layer-scale ``gamma_N`` keys are optionally remapped, and
    ``pre_logits`` entries are dropped.

    Args:
        state_dict: Loaded checkpoint mapping.
        model: Target model; read for its patch-embed, pos-embed and head attributes.
        adapt_layer_scale: Rewrite ``gamma_N`` keys to ``lsN.gamma``.
        interpolation: Forwarded to the resampling helpers.
        antialias: Forwarded to the resampling helpers.

    Returns:
        A new dict of filtered and adapted entries.
    """
    import re

    filtered = {}
    # Unwrap nested checkpoint containers, outermost first
    for wrapper_key in ('model', 'state_dict'):
        state_dict = state_dict.get(wrapper_key, state_dict)
    strip_prefix = ''

    # Detect the checkpoint flavour from marker keys; the order of these checks matters
    if 'visual.class_embedding' in state_dict:
        state_dict = _convert_openai_clip(state_dict, model)
    elif 'module.visual.class_embedding' in state_dict:
        state_dict = _convert_openai_clip(state_dict, model, prefix='module.visual.')
    elif "mask_token" in state_dict:
        state_dict = _convert_dinov2(state_dict, model)
    elif any('beit3.' in name for name in state_dict):
        # BEiT3 model - multimodal checkpoint with beit3.* prefix
        state_dict = _convert_beit3(state_dict, model)
    elif "encoder" in state_dict:
        # IJEPA, vit in an 'encoder' submodule
        state_dict = state_dict['encoder']
        strip_prefix = 'module.'
    elif 'visual.trunk.pos_embed' in state_dict or 'visual.trunk.blocks.0.norm1.weight' in state_dict:
        # OpenCLIP model with timm vision encoder
        strip_prefix = 'visual.trunk.'
        if 'visual.head.proj.weight' in state_dict and isinstance(model.head, nn.Linear):
            # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
            head_weight = state_dict['visual.head.proj.weight']
            filtered['head.weight'] = head_weight
            filtered['head.bias'] = torch.zeros(head_weight.shape[0])
    elif 'module.visual.trunk.pos_embed' in state_dict:
        strip_prefix = 'module.visual.trunk.'
    elif 'preprocessor.patchifier.proj.weight' in state_dict:
        state_dict = _convert_aimv2(state_dict, model)

    if strip_prefix:
        # filter on & remove prefix string from keys
        kept = {}
        for name, value in state_dict.items():
            if name.startswith(strip_prefix):
                kept[name[len(strip_prefix):]] = value
        state_dict = kept

    for name, value in state_dict.items():
        if 'patch_embed.proj.weight' in name:
            out_ch, _, kernel_h, kernel_w = model.patch_embed.proj.weight.shape
            if len(value.shape) < 4:
                # For old models that I trained prior to conv based patchification
                value = value.reshape(out_ch, -1, kernel_h, kernel_w)
            if value.shape[-1] != kernel_w or value.shape[-2] != kernel_h:
                value = resample_patch_embed(
                    value,
                    (kernel_h, kernel_w),
                    interpolation=interpolation,
                    antialias=antialias,
                    verbose=True,
                )
        elif name == 'pos_embed' and value.shape[1] != model.pos_embed.shape[1]:
            # To resize pos embedding when using model at different size from pretrained weights
            if getattr(model, 'no_embed_class', False):
                num_prefix_tokens = 0
            else:
                num_prefix_tokens = getattr(model, 'num_prefix_tokens', 1)
            value = resample_abs_pos_embed(
                value,
                new_size=model.patch_embed.grid_size,
                num_prefix_tokens=num_prefix_tokens,
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        elif adapt_layer_scale and 'gamma_' in name:
            # remap layer-scale gamma into sub-module (deit3 models)
            name = re.sub(r'gamma_([0-9])', r'ls\1.gamma', name)
        elif 'pre_logits' in name:
            # NOTE representation layer removed as not used in latest 21k/1k pretrained weights
            continue
        filtered[name] = value

    return filtered
def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
    """Build a pretrained-config dict from this file's defaults.

    Args:
        url: Value stored under ``'url'``.
        **kwargs: Extra entries; any key that matches a default replaces it.

    Returns:
        The merged config dict.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=0.9,
        interpolation='bicubic',
        fixed_input_size=True,
        mean=IMAGENET_INCEPTION_MEAN,
        std=IMAGENET_INCEPTION_STD,
        first_conv='patch_embed.proj',
        classifier='head',
        license='apache-2.0',
    )
    # Caller-supplied entries win over the defaults above
    cfg.update(kwargs)
    return cfg
- default_cfgs = {
- # re-finetuned augreg 21k FT on in1k weights
- 'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg(
- hf_hub_id='timm/'),
- 'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg(
- hf_hub_id='timm/'),
- # How to train your ViT (augreg) weights, pretrained on 21k FT on in1k
- 'vit_tiny_patch16_224.augreg_in21k_ft_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
- hf_hub_id='timm/',
- custom_load=True),
- 'vit_tiny_patch16_384.augreg_in21k_ft_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
- hf_hub_id='timm/',
- custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
- 'vit_small_patch32_224.augreg_in21k_ft_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
- hf_hub_id='timm/',
- custom_load=True),
- 'vit_small_patch32_384.augreg_in21k_ft_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
- hf_hub_id='timm/',
- custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
- 'vit_small_patch16_224.augreg_in21k_ft_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
- hf_hub_id='timm/',
- custom_load=True),
- 'vit_small_patch16_384.augreg_in21k_ft_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
- hf_hub_id='timm/',
- custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
- 'vit_base_patch32_224.augreg_in21k_ft_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
- hf_hub_id='timm/',
- custom_load=True),
- 'vit_base_patch32_384.augreg_in21k_ft_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
- hf_hub_id='timm/',
- custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
- 'vit_base_patch16_224.augreg_in21k_ft_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
- hf_hub_id='timm/',
- custom_load=True),
- 'vit_base_patch16_384.augreg_in21k_ft_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
- hf_hub_id='timm/',
- custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
- 'vit_base_patch8_224.augreg_in21k_ft_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
- hf_hub_id='timm/',
- custom_load=True),
- 'vit_large_patch16_224.augreg_in21k_ft_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
- hf_hub_id='timm/',
- custom_load=True),
- 'vit_large_patch16_384.augreg_in21k_ft_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
- hf_hub_id='timm/',
- custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
- # patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k
- 'vit_base_patch16_224.orig_in21k_ft_in1k': _cfg(
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
- hf_hub_id='timm/'),
- 'vit_base_patch16_384.orig_in21k_ft_in1k': _cfg(
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
- hf_hub_id='timm/',
- input_size=(3, 384, 384), crop_pct=1.0),
- 'vit_large_patch32_384.orig_in21k_ft_in1k': _cfg(
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
- hf_hub_id='timm/',
- input_size=(3, 384, 384), crop_pct=1.0),
- # How to train your ViT (augreg) weights trained on in1k only
- 'vit_small_patch16_224.augreg_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
- hf_hub_id='timm/',
- custom_load=True),
- 'vit_small_patch16_384.augreg_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
- hf_hub_id='timm/',
- custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
- 'vit_base_patch32_224.augreg_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
- hf_hub_id='timm/',
- custom_load=True),
- 'vit_base_patch32_384.augreg_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
- hf_hub_id='timm/',
- custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
- 'vit_base_patch16_224.augreg_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
- hf_hub_id='timm/',
- custom_load=True),
- 'vit_base_patch16_384.augreg_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
- hf_hub_id='timm/',
- custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
- 'vit_large_patch14_224.untrained': _cfg(url=''),
- 'vit_huge_patch14_224.untrained': _cfg(url=''),
- 'vit_giant_patch14_224.untrained': _cfg(url=''),
- 'vit_gigantic_patch14_224.untrained': _cfg(url=''),
- # patch models, imagenet21k (weights from official Google JAX impl), classifier not valid
- 'vit_base_patch32_224.orig_in21k': _cfg(
- #url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
- hf_hub_id='timm/',
- num_classes=0),
- 'vit_base_patch16_224.orig_in21k': _cfg(
- #url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
- hf_hub_id='timm/',
- num_classes=0),
- 'vit_large_patch32_224.orig_in21k': _cfg(
- #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
- hf_hub_id='timm/',
- num_classes=0),
- 'vit_large_patch16_224.orig_in21k': _cfg(
- #url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
- hf_hub_id='timm/',
- num_classes=0),
- 'vit_huge_patch14_224.orig_in21k': _cfg(
- hf_hub_id='timm/',
- num_classes=0),
- # How to train your ViT (augreg) weights, pretrained on in21k
- 'vit_tiny_patch16_224.augreg_in21k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
- hf_hub_id='timm/',
- custom_load=True, num_classes=21843),
- 'vit_small_patch32_224.augreg_in21k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
- hf_hub_id='timm/',
- custom_load=True, num_classes=21843),
- 'vit_small_patch16_224.augreg_in21k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
- hf_hub_id='timm/',
- custom_load=True, num_classes=21843),
- 'vit_base_patch32_224.augreg_in21k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
- hf_hub_id='timm/',
- custom_load=True, num_classes=21843),
- 'vit_base_patch16_224.augreg_in21k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
- hf_hub_id='timm/',
- custom_load=True, num_classes=21843),
- 'vit_base_patch8_224.augreg_in21k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
- hf_hub_id='timm/',
- custom_load=True, num_classes=21843),
- 'vit_large_patch16_224.augreg_in21k': _cfg(
- url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
- hf_hub_id='timm/',
- custom_load=True, num_classes=21843),
- # SAM trained models (https://arxiv.org/abs/2106.01548)
- 'vit_base_patch32_224.sam_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz', custom_load=True,
- hf_hub_id='timm/'),
- 'vit_base_patch16_224.sam_in1k': _cfg(
- url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz', custom_load=True,
- hf_hub_id='timm/'),
- # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only)
- 'vit_small_patch16_224.dino': _cfg(
- url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth',
- hf_hub_id='timm/',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
- 'vit_small_patch8_224.dino': _cfg(
- url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth',
- hf_hub_id='timm/',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
- 'vit_base_patch16_224.dino': _cfg(
- url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth',
- hf_hub_id='timm/',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
- 'vit_base_patch8_224.dino': _cfg(
- url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
- hf_hub_id='timm/',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
- # DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune/features only)
- 'vit_small_patch14_dinov2.lvd142m': _cfg(
- url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth',
- hf_hub_id='timm/',
- license='apache-2.0',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
- input_size=(3, 518, 518), crop_pct=1.0),
- 'vit_base_patch14_dinov2.lvd142m': _cfg(
- url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth',
- hf_hub_id='timm/',
- license='apache-2.0',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
- input_size=(3, 518, 518), crop_pct=1.0),
- 'vit_large_patch14_dinov2.lvd142m': _cfg(
- url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth',
- hf_hub_id='timm/',
- license='apache-2.0',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
- input_size=(3, 518, 518), crop_pct=1.0),
- 'vit_giant_patch14_dinov2.lvd142m': _cfg(
- url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth',
- hf_hub_id='timm/',
- license='apache-2.0',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
- input_size=(3, 518, 518), crop_pct=1.0),
- # DINOv2 pretrained w/ registers - https://arxiv.org/abs/2309.16588 (no classifier head, for fine-tune/features only)
- 'vit_small_patch14_reg4_dinov2.lvd142m': _cfg(
- url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth',
- hf_hub_id='timm/',
- license='apache-2.0',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
- input_size=(3, 518, 518), crop_pct=1.0),
- 'vit_base_patch14_reg4_dinov2.lvd142m': _cfg(
- url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth',
- hf_hub_id='timm/',
- license='apache-2.0',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
- input_size=(3, 518, 518), crop_pct=1.0),
- 'vit_large_patch14_reg4_dinov2.lvd142m': _cfg(
- url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth',
- hf_hub_id='timm/',
- license='apache-2.0',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
- input_size=(3, 518, 518), crop_pct=1.0),
- 'vit_giant_patch14_reg4_dinov2.lvd142m': _cfg(
- url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth',
- hf_hub_id='timm/',
- license='apache-2.0',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
- input_size=(3, 518, 518), crop_pct=1.0),
- # ViT ImageNet-21K-P pretraining by MIIL
- 'vit_base_patch16_224_miil.in21k': _cfg(
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth',
- hf_hub_id='timm/',
- mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221),
- 'vit_base_patch16_224_miil.in21k_ft_in1k': _cfg(
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth',
- hf_hub_id='timm/',
- mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'),
- # Custom timm variants
- 'vit_base_patch16_rpn_224.sw_in1k': _cfg(
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth',
- hf_hub_id='timm/'),
- 'vit_medium_patch16_gap_240.sw_in12k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821),
- 'vit_medium_patch16_gap_256.sw_in12k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_medium_patch16_gap_384.sw_in12k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384), crop_pct=0.95, crop_mode='squash'),
- 'vit_betwixt_patch16_gap_256.untrained': _cfg(
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_base_patch16_gap_224.untrained': _cfg(),
- # CLIP pretrained image tower and related fine-tuned weights
- 'vit_base_patch32_clip_224.laion2b_ft_in12k_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
- 'vit_base_patch32_clip_384.laion2b_ft_in12k_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
- 'vit_base_patch32_clip_448.laion2b_ft_in12k_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)),
- 'vit_base_patch16_clip_224.laion2b_ft_in12k_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
- 'vit_base_patch16_clip_384.laion2b_ft_in12k_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
- 'vit_large_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg(
- hf_hub_id='timm/',
- mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
- 'vit_large_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg(
- hf_hub_id='timm/',
- mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
- crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
- 'vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
- 'vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
- 'vit_base_patch32_clip_224.openai_ft_in12k_in1k': _cfg(
- # hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k_in1k', # FIXME weight exists, need to push
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
- 'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
- 'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
- 'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
- 'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
- 'vit_large_patch14_clip_336.openai_ft_in12k_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
- 'vit_base_patch32_clip_224.laion2b_ft_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
- 'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
- 'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
- 'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg(
- hf_hub_id='timm/',
- mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
- 'vit_large_patch14_clip_336.laion2b_ft_in1k': _cfg(
- hf_hub_id='timm/',
- mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
- crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
- 'vit_huge_patch14_clip_224.laion2b_ft_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
- 'vit_huge_patch14_clip_336.laion2b_ft_in1k': _cfg(
- hf_hub_id='',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
- 'vit_base_patch32_clip_224.openai_ft_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
- 'vit_base_patch16_clip_224.openai_ft_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
- 'vit_base_patch16_clip_384.openai_ft_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
- 'vit_large_patch14_clip_224.openai_ft_in1k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
- 'vit_base_patch16_clip_224.laion2b_ft_in12k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
- 'vit_large_patch14_clip_224.laion2b_ft_in12k': _cfg(
- hf_hub_id='timm/',
- mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=11821),
- 'vit_huge_patch14_clip_224.laion2b_ft_in12k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821),
- 'vit_base_patch16_clip_224.openai_ft_in12k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
- 'vit_large_patch14_clip_224.openai_ft_in12k': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821),
- 'vit_base_patch32_clip_224.laion2b': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
- 'vit_base_patch16_clip_224.laion2b': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
- 'vit_large_patch14_clip_224.laion2b': _cfg(
- hf_hub_id='timm/',
- mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768),
- 'vit_huge_patch14_clip_224.laion2b': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
- 'vit_giant_patch14_clip_224.laion2b': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
- 'vit_gigantic_patch14_clip_224.laion2b': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),
- 'vit_base_patch32_clip_224.laion400m_e32': _cfg(
- hf_hub_id='timm/',
- license='mit',
- notes=('natively QuickGELU, use quickgelu model variant for original results',),
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
- 'vit_base_patch16_clip_224.laion400m_e32': _cfg(
- hf_hub_id='timm/',
- license='mit',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
- 'vit_base_patch16_plus_clip_240.laion400m_e32': _cfg(
- hf_hub_id='timm/',
- license='mit',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 240, 240), crop_pct=1.0, num_classes=640),
- 'vit_large_patch14_clip_224.laion400m_e32': _cfg(
- hf_hub_id='timm/',
- license='mit',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
- 'vit_base_patch32_clip_224.datacompxl': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
- 'vit_base_patch32_clip_256.datacompxl': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- crop_pct=1.0, input_size=(3, 256, 256), num_classes=512),
- 'vit_base_patch16_clip_224.datacompxl': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
- 'vit_large_patch14_clip_224.datacompxl': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
- 'vit_base_patch16_clip_224.dfn2b': _cfg(
- hf_hub_id='timm/',
- license='apple-ascl',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
- 'vit_large_patch14_clip_224.dfn2b_s39b': _cfg(
- hf_hub_id='timm/',
- license='apple-ascl',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
- 'vit_large_patch14_clip_224.dfn2b': _cfg(
- hf_hub_id='timm/',
- license='apple-ascl',
- notes=('natively QuickGELU, use quickgelu model variant for original results',),
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
- 'vit_huge_patch14_clip_224.dfn5b': _cfg(
- hf_hub_id='timm/',
- license='apple-ascl',
- notes=('natively QuickGELU, use quickgelu model variant for original results',),
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
- 'vit_huge_patch14_clip_378.dfn5b': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- license='apple-ascl',
- notes=('natively QuickGELU, use quickgelu model variant for original results',),
- crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024),
- # 'vit_large_patch14_clip_224.metaclip2_worldwide': _cfg(
- # hf_hub_id='timm/',
- # license='cc-by-nc-4.0',
- # notes=('natively QuickGELU, use quickgelu model variant for original results',),
- # mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
- 'vit_huge_patch14_clip_224.metaclip2_worldwide': _cfg(
- hf_hub_id='timm/',
- license='cc-by-nc-4.0',
- notes=('natively QuickGELU, use quickgelu model variant for original results',),
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
- 'vit_huge_patch14_clip_378.metaclip2_worldwide': _cfg(
- hf_hub_id='timm/',
- license='cc-by-nc-4.0',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 378, 378), crop_pct=1.0, crop_mode='squash', num_classes=1024),
- 'vit_gigantic_patch14_clip_224.metaclip2_worldwide': _cfg(
- hf_hub_id='timm/',
- license='cc-by-nc-4.0',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),
- 'vit_gigantic_patch14_clip_378.metaclip2_worldwide': _cfg(
- hf_hub_id='timm/',
- license='cc-by-nc-4.0',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 378, 378), crop_pct=1.0, crop_mode='squash', num_classes=1280),
- 'vit_base_patch32_clip_224.metaclip_2pt5b': _cfg(
- hf_hub_id='timm/',
- license='cc-by-nc-4.0',
- notes=('natively QuickGELU, use quickgelu model variant for original results',),
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
- 'vit_base_patch16_clip_224.metaclip_2pt5b': _cfg(
- hf_hub_id='timm/',
- license='cc-by-nc-4.0',
- notes=('natively QuickGELU, use quickgelu model variant for original results',),
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
- 'vit_large_patch14_clip_224.metaclip_2pt5b': _cfg(
- hf_hub_id='timm/',
- license='cc-by-nc-4.0',
- notes=('natively QuickGELU, use quickgelu model variant for original results',),
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
- 'vit_huge_patch14_clip_224.metaclip_2pt5b': _cfg(
- hf_hub_id='timm/',
- license='cc-by-nc-4.0',
- notes=('natively QuickGELU, use quickgelu model variant for original results',),
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
- 'vit_huge_patch14_clip_224.metaclip_altogether': _cfg(
- hf_hub_id='timm/',
- license='cc-by-nc-4.0',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
- 'vit_gigantic_patch14_clip_224.metaclip_2pt5b': _cfg(
- hf_hub_id='timm/',
- license='cc-by-nc-4.0',
- notes=('natively QuickGELU, use quickgelu model variant for original results',),
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),
- 'vit_base_patch32_clip_224.metaclip_400m': _cfg(
- hf_hub_id='timm/',
- license='cc-by-nc-4.0',
- notes=('natively QuickGELU, use quickgelu model variant for original results',),
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
- 'vit_base_patch16_clip_224.metaclip_400m': _cfg(
- hf_hub_id='timm/',
- license='cc-by-nc-4.0',
- notes=('natively QuickGELU, use quickgelu model variant for original results',),
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
- 'vit_large_patch14_clip_224.metaclip_400m': _cfg(
- hf_hub_id='timm/',
- license='cc-by-nc-4.0',
- notes=('natively QuickGELU, use quickgelu model variant for original results',),
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
- 'vit_base_patch32_clip_224.openai': _cfg(
- hf_hub_id='timm/',
- notes=('natively QuickGELU, use quickgelu model variant for original results',),
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
- 'vit_base_patch16_clip_224.openai': _cfg(
- hf_hub_id='timm/',
- notes=('natively QuickGELU, use quickgelu model variant for original results',),
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
- 'vit_large_patch14_clip_224.openai': _cfg(
- hf_hub_id='timm/',
- notes=('natively QuickGELU, use quickgelu model variant for original results',),
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
- 'vit_large_patch14_clip_336.openai': _cfg(
- hf_hub_id='timm/',
- notes=('natively QuickGELU, use quickgelu model variant for original results',),
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- crop_pct=1.0, input_size=(3, 336, 336), num_classes=768),
- 'vit_large_patch14_clip_224.apple_mclip2_dfndr2b': _cfg(
- hf_hub_id='timm/',
- num_classes=768,
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0,
- license='apple-amlr'
- ),
- # experimental (may be removed)
- 'vit_base_patch32_plus_256.untrained': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_base_patch16_plus_240.untrained': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95),
- 'vit_small_patch16_36x1_224.untrained': _cfg(url=''),
- 'vit_small_patch16_18x2_224.untrained': _cfg(url=''),
- 'vit_base_patch16_18x2_224.untrained': _cfg(url=''),
- # EVA fine-tuned weights; pretrained with MAE-style MIM using EVA-CLIP features as the target
- # https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip
- 'eva_large_patch14_196.in22k_ft_in22k_in1k': _cfg(
- # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt',
- hf_hub_id='timm/', license='mit',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 196, 196), crop_pct=1.0),
- 'eva_large_patch14_336.in22k_ft_in22k_in1k': _cfg(
- # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt',
- hf_hub_id='timm/', license='mit',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
- 'eva_large_patch14_196.in22k_ft_in1k': _cfg(
- # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt',
- hf_hub_id='timm/', license='mit',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 196, 196), crop_pct=1.0),
- 'eva_large_patch14_336.in22k_ft_in1k': _cfg(
- # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt',
- hf_hub_id='timm/', license='mit',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
- input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
- 'flexivit_small.1200ep_in1k': _cfg(
- url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k.npz', custom_load=True,
- hf_hub_id='timm/',
- input_size=(3, 240, 240), crop_pct=0.95),
- 'flexivit_small.600ep_in1k': _cfg(
- url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_600ep.npz', custom_load=True,
- hf_hub_id='timm/',
- input_size=(3, 240, 240), crop_pct=0.95),
- 'flexivit_small.300ep_in1k': _cfg(
- url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_300ep.npz', custom_load=True,
- hf_hub_id='timm/',
- input_size=(3, 240, 240), crop_pct=0.95),
- 'flexivit_base.1200ep_in1k': _cfg(
- url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k.npz', custom_load=True,
- hf_hub_id='timm/',
- input_size=(3, 240, 240), crop_pct=0.95),
- 'flexivit_base.600ep_in1k': _cfg(
- url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_600ep.npz', custom_load=True,
- hf_hub_id='timm/',
- input_size=(3, 240, 240), crop_pct=0.95),
- 'flexivit_base.300ep_in1k': _cfg(
- url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_300ep.npz', custom_load=True,
- hf_hub_id='timm/',
- input_size=(3, 240, 240), crop_pct=0.95),
- 'flexivit_base.1000ep_in21k': _cfg(
- url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_1000ep.npz', custom_load=True,
- hf_hub_id='timm/',
- input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
- 'flexivit_base.300ep_in21k': _cfg(
- url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_300ep.npz', custom_load=True,
- hf_hub_id='timm/',
- input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
- 'flexivit_large.1200ep_in1k': _cfg(
- url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k.npz', custom_load=True,
- hf_hub_id='timm/',
- input_size=(3, 240, 240), crop_pct=0.95),
- 'flexivit_large.600ep_in1k': _cfg(
- url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_600ep.npz', custom_load=True,
- hf_hub_id='timm/',
- input_size=(3, 240, 240), crop_pct=0.95),
- 'flexivit_large.300ep_in1k': _cfg(
- url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_300ep.npz', custom_load=True,
- hf_hub_id='timm/',
- input_size=(3, 240, 240), crop_pct=0.95),
- 'flexivit_base.patch16_in21k': _cfg(
- url='https://storage.googleapis.com/big_vision/flexivit/vit_b16_i21k_300ep.npz', custom_load=True,
- hf_hub_id='timm/',
- input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
- 'flexivit_base.patch30_in21k': _cfg(
- url='https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz', custom_load=True,
- hf_hub_id='timm/',
- input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
- 'vit_base_patch16_xp_224.untrained': _cfg(url=''),
- 'vit_large_patch14_xp_224.untrained': _cfg(url=''),
- 'vit_huge_patch14_xp_224.untrained': _cfg(url=''),
- 'vit_base_patch16_224.mae': _cfg(
- url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth',
- hf_hub_id='timm/',
- license='cc-by-nc-4.0',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
- 'vit_large_patch16_224.mae': _cfg(
- url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_large.pth',
- hf_hub_id='timm/',
- license='cc-by-nc-4.0',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
- 'vit_huge_patch14_224.mae': _cfg(
- url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_huge.pth',
- hf_hub_id='timm/',
- license='cc-by-nc-4.0',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
- 'vit_huge_patch14_gap_224.in1k_ijepa': _cfg(
- url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar',
- # hf_hub_id='timm/',
- license='cc-by-nc-4.0',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
- 'vit_huge_patch14_gap_224.in22k_ijepa': _cfg(
- url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar',
- # hf_hub_id='timm/',
- license='cc-by-nc-4.0',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
- 'vit_huge_patch16_gap_448.in1k_ijepa': _cfg(
- url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar',
- # hf_hub_id='timm/',
- license='cc-by-nc-4.0',
- input_size=(3, 448, 448), crop_pct=1.0,
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
- 'vit_giant_patch16_gap_224.in22k_ijepa': _cfg(
- url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar',
- # hf_hub_id='timm/',
- license='cc-by-nc-4.0',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
- 'vit_base_patch32_siglip_256.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256),
- num_classes=0),
- 'vit_base_patch16_siglip_224.v2_webli': _cfg(
- hf_hub_id='timm/',
- num_classes=0),
- 'vit_base_patch16_siglip_224.webli': _cfg(
- hf_hub_id='timm/',
- num_classes=0),
- 'vit_base_patch16_siglip_256.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256),
- num_classes=0),
- 'vit_base_patch16_siglip_256.webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256),
- num_classes=0),
- 'vit_base_patch16_siglip_256.webli_i18n': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256),
- num_classes=0),
- 'vit_base_patch16_siglip_384.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384),
- num_classes=0),
- 'vit_base_patch16_siglip_384.webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384),
- num_classes=0),
- 'vit_base_patch16_siglip_512.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 512, 512),
- num_classes=0),
- 'vit_base_patch16_siglip_512.webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 512, 512),
- num_classes=0),
- 'vit_large_patch16_siglip_256.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256),
- num_classes=0),
- 'vit_large_patch16_siglip_256.webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256),
- num_classes=0),
- 'vit_large_patch16_siglip_384.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384),
- num_classes=0),
- 'vit_large_patch16_siglip_384.webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384),
- num_classes=0),
- 'vit_large_patch16_siglip_512.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 512, 512),
- num_classes=0),
- 'vit_so400m_patch14_siglip_224.v2_webli': _cfg(
- hf_hub_id='timm/',
- num_classes=0),
- 'vit_so400m_patch14_siglip_224.webli': _cfg(
- hf_hub_id='timm/',
- num_classes=0),
- 'vit_so400m_patch14_siglip_378.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 378, 378),
- num_classes=0),
- 'vit_so400m_patch14_siglip_378.webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 378, 378),
- num_classes=0),
- 'vit_so400m_patch14_siglip_384.webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384),
- num_classes=0),
- 'vit_so400m_patch16_siglip_256.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256),
- num_classes=0),
- 'vit_so400m_patch16_siglip_256.webli_i18n': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256),
- num_classes=0),
- 'vit_so400m_patch16_siglip_384.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384),
- num_classes=0),
- 'vit_so400m_patch16_siglip_512.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 512, 512),
- num_classes=0),
- 'vit_giantopt_patch16_siglip_256.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256),
- num_classes=0),
- 'vit_giantopt_patch16_siglip_384.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384),
- num_classes=0),
- 'vit_base_patch32_siglip_gap_256.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256),
- num_classes=0),
- 'vit_base_patch16_siglip_gap_224.v2_webli': _cfg(
- hf_hub_id='timm/',
- num_classes=0),
- 'vit_base_patch16_siglip_gap_224.webli': _cfg(
- hf_hub_id='timm/',
- num_classes=0),
- 'vit_base_patch16_siglip_gap_256.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256),
- num_classes=0),
- 'vit_base_patch16_siglip_gap_256.webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256),
- num_classes=0),
- 'vit_base_patch16_siglip_gap_256.webli_i18n': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256),
- num_classes=0),
- 'vit_base_patch16_siglip_gap_384.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384),
- num_classes=0),
- 'vit_base_patch16_siglip_gap_384.webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384),
- num_classes=0),
- 'vit_base_patch16_siglip_gap_512.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 512, 512),
- num_classes=0),
- 'vit_base_patch16_siglip_gap_512.webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 512, 512),
- num_classes=0),
- 'vit_large_patch16_siglip_gap_256.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256),
- num_classes=0),
- 'vit_large_patch16_siglip_gap_256.webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256),
- num_classes=0),
- 'vit_large_patch16_siglip_gap_384.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384),
- num_classes=0),
- 'vit_large_patch16_siglip_gap_384.webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384),
- num_classes=0),
- 'vit_large_patch16_siglip_gap_512.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 512, 512),
- num_classes=0),
- 'vit_so400m_patch14_siglip_gap_224.v2_webli': _cfg(
- hf_hub_id='timm/',
- num_classes=0),
- 'vit_so400m_patch14_siglip_gap_224.webli': _cfg(
- hf_hub_id='timm/',
- num_classes=0),
- 'vit_so400m_patch14_siglip_gap_224.pali_mix': _cfg(
- hf_hub_id='timm/',
- num_classes=0),
- 'vit_so400m_patch14_siglip_gap_224.pali_pt': _cfg(
- hf_hub_id='timm/',
- num_classes=0),
- 'vit_so400m_patch14_siglip_gap_224.pali2_3b_pt': _cfg(
- hf_hub_id='timm/',
- num_classes=0),
- 'vit_so400m_patch14_siglip_gap_224.pali2_10b_pt': _cfg(
- hf_hub_id='timm/',
- num_classes=0),
- # 'vit_so400m_patch14_siglip_gap_224.pali2_28b_pt': _cfg(
- # hf_hub_id='google/paligemma2-28b-pt-224-jax',
- # hf_hub_filename='pt_27b_224.npz',
- # custom_load='hf',
- # num_classes=0),
- 'vit_so400m_patch14_siglip_gap_378.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 378, 378),
- num_classes=0),
- 'vit_so400m_patch14_siglip_gap_378.webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 378, 378), crop_pct=1.0,
- num_classes=0),
- 'vit_so400m_patch14_siglip_gap_384.webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384), crop_pct=1.0,
- num_classes=0),
- 'vit_so400m_patch14_siglip_gap_448.pali_mix': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 448, 448), crop_pct=1.0,
- num_classes=0),
- 'vit_so400m_patch14_siglip_gap_448.pali_pt': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 448, 448), crop_pct=1.0,
- num_classes=0),
- 'vit_so400m_patch14_siglip_gap_448.pali_refcoco_seg': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 448, 448), crop_pct=1.0,
- num_classes=0),
- 'vit_so400m_patch14_siglip_gap_448.pali_ocrvqa': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 448, 448), crop_pct=1.0,
- num_classes=0),
- 'vit_so400m_patch14_siglip_gap_448.pali2_3b_pt': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 448, 448), crop_pct=1.0,
- num_classes=0),
- 'vit_so400m_patch14_siglip_gap_448.pali2_10b_pt': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 448, 448), crop_pct=1.0,
- num_classes=0),
- # 'vit_so400m_patch14_siglip_gap_448.pali2_28b_pt': _cfg(
- # hf_hub_id='google/paligemma2-28b-pt-448-jax',
- # hf_hub_filename='pt_27b_448.npz',
- # custom_load='hf',
- # input_size=(3, 448, 448), crop_pct=1.0,
- # num_classes=0),
- 'vit_so400m_patch14_siglip_gap_448.pali2_3b_docci': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 448, 448), crop_pct=1.0,
- num_classes=0),
- 'vit_so400m_patch14_siglip_gap_448.pali2_10b_docci': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 448, 448), crop_pct=1.0,
- num_classes=0),
- 'vit_so400m_patch14_siglip_gap_896.pali_pt': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 896, 896), crop_pct=1.0,
- num_classes=0),
- 'vit_so400m_patch14_siglip_gap_896.pali_refcoco_seg': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 896, 896), crop_pct=1.0,
- num_classes=0),
- 'vit_so400m_patch14_siglip_gap_896.pali_ocrvqa': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 896, 896), crop_pct=1.0,
- num_classes=0),
- 'vit_so400m_patch14_siglip_gap_896.pali2_3b_pt': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 896, 896), crop_pct=1.0,
- num_classes=0),
- 'vit_so400m_patch14_siglip_gap_896.pali2_10b_pt': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 896, 896), crop_pct=1.0,
- num_classes=0),
- # 'vit_so400m_patch14_siglip_gap_896.pali2_28b_pt': _cfg(
- # hf_hub_id='google/paligemma2-28b-pt-896-jax',
- # hf_hub_filename='pt_27b_896.npz',
- # custom_load='hf',
- # input_size=(3, 896, 896), crop_pct=1.0,
- # num_classes=0),
- 'vit_so400m_patch16_siglip_gap_256.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256),
- num_classes=0),
- 'vit_so400m_patch16_siglip_gap_256.webli_i18n': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256),
- num_classes=0),
- 'vit_so400m_patch16_siglip_gap_384.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384),
- num_classes=0),
- 'vit_so400m_patch16_siglip_gap_512.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 512, 512),
- num_classes=0),
- 'vit_giantopt_patch16_siglip_gap_256.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256),
- num_classes=0),
- 'vit_giantopt_patch16_siglip_gap_384.v2_webli': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384),
- num_classes=0),
- 'vit_so400m_patch14_siglip_378.webli_ft_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 378, 378), crop_pct=1.0, crop_mode='squash',
- ),
- 'vit_so400m_patch14_siglip_gap_378.webli_ft_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 378, 378), crop_pct=1.0, crop_mode='squash',
- ),
- 'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg(
- hf_hub_id='timm/',
- license='mit',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
- 'vit_medium_patch32_clip_224.tinyclip_laion400m': _cfg(
- hf_hub_id='timm/',
- license='mit',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
- 'vit_medium_patch16_clip_224.tinyclip_yfcc15m': _cfg(
- hf_hub_id='timm/',
- license='mit',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
- 'vit_betwixt_patch32_clip_224.tinyclip_laion400m': _cfg(
- hf_hub_id='timm/',
- license='mit',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
- 'vit_wee_patch16_reg1_gap_256.sbb_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_dwee_patch16_reg1_gap_256.sbb_nadamuon_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_dwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_dpwee_patch16_reg1_gap_256.sbb_nadamuon_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_dpwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_little_patch16_reg1_gap_256.sbb_in12k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_little_patch16_reg1_gap_256.sbb_in12k': _cfg(
- hf_hub_id='timm/',
- num_classes=11821,
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_little_patch16_reg4_gap_256.sbb_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_dlittle_patch16_reg1_gap_256.sbb_nadamuon_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_medium_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_medium_patch16_reg4_gap_256.sbb_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_medium_patch16_reg4_gap_256.sbb_in12k': _cfg(
- hf_hub_id='timm/',
- num_classes=11821,
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k': _cfg(
- hf_hub_id='timm/',
- num_classes=11821,
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg(
- hf_hub_id='timm/',
- num_classes=11821,
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_mediumd_patch16_reg4_gap_384.sbb2_e200_in12k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384), crop_pct=1.0),
- 'vit_betwixt_patch16_reg1_gap_256.sbb_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k': _cfg(
- hf_hub_id='timm/',
- num_classes=11821,
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg(
- hf_hub_id='timm/',
- num_classes=11821,
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_betwixt_patch16_reg4_gap_384.sbb2_e200_in12k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384), crop_pct=1.0),
- 'vit_base_patch16_reg4_gap_256.untrained': _cfg(
- input_size=(3, 256, 256)),
- 'vit_so150m_patch16_reg4_gap_256.sbb_e250_in12k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_so150m_patch16_reg4_gap_256.sbb_e250_in12k': _cfg(
- hf_hub_id='timm/',
- num_classes=11821,
- input_size=(3, 256, 256), crop_pct=0.95),
- 'vit_so150m_patch16_reg4_gap_384.sbb_e250_in12k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384), crop_pct=1.0),
- 'vit_so150m_patch16_reg4_map_256.untrained': _cfg(
- input_size=(3, 256, 256)),
- 'vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 256, 256), crop_pct=1.0),
- 'vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k': _cfg(
- hf_hub_id='timm/',
- num_classes=11821,
- input_size=(3, 256, 256), crop_pct=1.0),
- 'vit_so150m2_patch16_reg1_gap_384.sbb_e200_in12k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 384, 384), crop_pct=1.0),
- 'vit_so150m2_patch16_reg1_gap_448.sbb_e200_in12k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash'),
- 'vit_intern300m_patch14_448.ogvl_dist': _cfg(
- hf_hub_id='timm/',
- license='mit',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
- input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
- ),
- 'vit_intern300m_patch14_448.ogvl_2pt5': _cfg(
- hf_hub_id='timm/',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
- input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
- ),
- 'aimv2_large_patch14_224.apple_pt': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
- crop_pct=1.0, num_classes=0),
- 'aimv2_large_patch14_224.apple_pt_dist': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
- crop_pct=1.0, num_classes=0),
- 'aimv2_huge_patch14_224.apple_pt': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
- crop_pct=1.0, num_classes=0),
- 'aimv2_1b_patch14_224.apple_pt': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
- crop_pct=1.0, num_classes=0),
- 'aimv2_3b_patch14_224.apple_pt': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
- crop_pct=1.0, num_classes=0),
- 'aimv2_large_patch14_336.apple_pt': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
- input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
- 'aimv2_large_patch14_336.apple_pt_dist': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
- input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
- 'aimv2_huge_patch14_336.apple_pt': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
- input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
- 'aimv2_1b_patch14_336.apple_pt': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
- input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
- 'aimv2_3b_patch14_336.apple_pt': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
- input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
- 'aimv2_large_patch14_448.apple_pt': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
- input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
- 'aimv2_huge_patch14_448.apple_pt': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
- input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
- 'aimv2_1b_patch14_448.apple_pt': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
- input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
- 'aimv2_3b_patch14_448.apple_pt': _cfg(
- hf_hub_id='timm/',
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
- input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
- 'test_vit.r160_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 160, 160), crop_pct=0.95),
- 'test_vit2.r160_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 160, 160), crop_pct=0.95),
- 'test_vit3.r160_in1k': _cfg(
- hf_hub_id='timm/',
- input_size=(3, 160, 160), crop_pct=0.95),
- 'test_vit4.r160_in1k': _cfg(
- input_size=(3, 160, 160), crop_pct=0.95),
- # BEiT3 models (remapped to VisionTransformer with scale_attn_norm=True, scale_mlp_norm=True)
- 'beit3_base_patch16_224.in22k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
- 'beit3_base_patch16_224.indomain_in22k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
- 'beit3_large_patch16_224.in22k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
- 'beit3_large_patch16_224.indomain_in22k_ft_in1k': _cfg(
- hf_hub_id='timm/',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
- 'beit3_giant_patch14_224.untrained': _cfg(
- url='', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
- 'beit3_giant_patch14_336.untrained': _cfg(
- url='', input_size=(3, 336, 336), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
- 'beit3_base_patch16_224.pt': _cfg(
- hf_hub_id='timm/',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0,
- num_classes=0,
- ),
- 'beit3_base_patch16_224.indomain_pt': _cfg(
- hf_hub_id='timm/',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0,
- num_classes=0,
- ),
- 'beit3_large_patch16_224.pt': _cfg(
- hf_hub_id='timm/',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0,
- num_classes=0,
- ),
- 'beit3_large_patch16_224.indomain_pt': _cfg(
- hf_hub_id='timm/',
- mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0,
- num_classes=0,
- ),
- }
# Collect the config names whose first 'notes' entry mentions quickgelu; each
# of them gets a twin entry registered under a '_clip_quickgelu_' name.
_quick_gelu_cfgs = [
    cfg_name
    for cfg_name, cfg in default_cfgs.items()
    if cfg.get('notes', ()) and 'quickgelu' in cfg['notes'][0]
]
for n in _quick_gelu_cfgs:
    # Work on a deep copy so the source entry is left untouched.
    c = copy.deepcopy(default_cfgs[n])
    if c['hf_hub_id'] == 'timm/':
        # The hub id must use the non-quickgelu model name.
        c['hf_hub_id'] = 'timm/' + n
    default_cfgs[n.replace('_clip_', '_clip_quickgelu_')] = c

# Rebind the plain dict to whatever generate_default_cfgs builds from it.
default_cfgs = generate_default_cfgs(default_cfgs)
# Module-wide default for building NaFlexVit instead of VisionTransformer:
# enabled when the TIMM_USE_NAFLEXVIT environment variable equals 'true'
# (compared case-insensitively); an unset variable counts as 'false'.
_USE_NAFLEX_DEFAULT = os.getenv('TIMM_USE_NAFLEXVIT', 'false').lower() == 'true'
def _create_vision_transformer(
        variant: str,
        pretrained: bool = False,
        use_naflex: Optional[bool] = None,
        **kwargs,
) -> Union[VisionTransformer, 'NaFlexVit']:
    """Instantiate a ViT variant, optionally routing to the NaFlexVit builder.

    Args:
        variant: Variant name forwarded to the underlying builder.
        pretrained: Forwarded to the underlying builder unchanged.
        use_naflex: ``None`` defers to the module-level ``_USE_NAFLEX_DEFAULT``;
            a truthy value delegates to ``_create_naflexvit_from_classic``.
        **kwargs: Remaining options. ``out_indices`` (default 3) and
            ``pretrained_strict`` (default True) are consumed here on the
            classic path; everything else is passed through.

    Returns:
        The result of ``build_model_with_cfg`` or, on the NaFlex path, of
        ``_create_naflexvit_from_classic``.
    """
    naflex_enabled = _USE_NAFLEX_DEFAULT if use_naflex is None else use_naflex
    if naflex_enabled:
        # Imported lazily to avoid a circular import.
        from .naflexvit import _create_naflexvit_from_classic
        return _create_naflexvit_from_classic(variant, pretrained, **kwargs)

    feature_out_indices = kwargs.pop('out_indices', 3)

    # FIXME attn pool (currently only in siglip) params removed if pool disabled, is there a better soln?
    load_strict = kwargs.pop('pretrained_strict', True)
    if 'siglip' in variant and kwargs.get('global_pool', None) != 'map':
        load_strict = False

    # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
    # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.
    filter_fn = checkpoint_filter_fn
    if 'flexi' in variant:
        filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False)

    return build_model_with_cfg(
        VisionTransformer,
        variant,
        pretrained,
        pretrained_filter_fn=filter_fn,
        pretrained_strict=load_strict,
        feature_cfg=dict(out_indices=feature_out_indices, feature_cls='getter'),
        **kwargs,
    )
@register_model
def vit_tiny_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Tiny (ViT-Ti/16).

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {'patch_size': 16, 'embed_dim': 192, 'depth': 12, 'num_heads': 3}
    return _create_vision_transformer(
        'vit_tiny_patch16_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_tiny_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Tiny (ViT-Ti/16) @ 384x384.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {'patch_size': 16, 'embed_dim': 192, 'depth': 12, 'num_heads': 3}
    return _create_vision_transformer(
        'vit_tiny_patch16_384', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_small_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Small (ViT-S/32).

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {'patch_size': 32, 'embed_dim': 384, 'depth': 12, 'num_heads': 6}
    return _create_vision_transformer(
        'vit_small_patch32_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_small_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Small (ViT-S/32) at 384x384.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {'patch_size': 32, 'embed_dim': 384, 'depth': 12, 'num_heads': 6}
    return _create_vision_transformer(
        'vit_small_patch32_384', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_small_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Small (ViT-S/16).

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {'patch_size': 16, 'embed_dim': 384, 'depth': 12, 'num_heads': 6}
    return _create_vision_transformer(
        'vit_small_patch16_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_small_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Small (ViT-S/16), registered under the ``_384`` variant name.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {'patch_size': 16, 'embed_dim': 384, 'depth': 12, 'num_heads': 6}
    return _create_vision_transformer(
        'vit_small_patch16_384', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_small_patch8_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Small (ViT-S/8).

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {'patch_size': 8, 'embed_dim': 384, 'depth': 12, 'num_heads': 6}
    return _create_vision_transformer(
        'vit_small_patch8_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_base_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Base (ViT-B/32) from the original paper (https://arxiv.org/abs/2010.11929).

    ImageNet-1k weights fine-tuned from in21k, source
    https://github.com/google-research/vision_transformer.
    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {'patch_size': 32, 'embed_dim': 768, 'depth': 12, 'num_heads': 12}
    return _create_vision_transformer(
        'vit_base_patch32_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_base_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Base (ViT-B/32) from the original paper (https://arxiv.org/abs/2010.11929).

    ImageNet-1k weights fine-tuned from in21k @ 384x384, source
    https://github.com/google-research/vision_transformer.
    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {'patch_size': 32, 'embed_dim': 768, 'depth': 12, 'num_heads': 12}
    return _create_vision_transformer(
        'vit_base_patch32_384', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_base_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Base (ViT-B/16) from the original paper (https://arxiv.org/abs/2010.11929).

    ImageNet-1k weights fine-tuned from in21k @ 224x224, source
    https://github.com/google-research/vision_transformer.
    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12}
    return _create_vision_transformer(
        'vit_base_patch16_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_base_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Base (ViT-B/16) from the original paper (https://arxiv.org/abs/2010.11929).

    ImageNet-1k weights fine-tuned from in21k @ 384x384, source
    https://github.com/google-research/vision_transformer.
    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12}
    return _create_vision_transformer(
        'vit_base_patch16_384', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_base_patch8_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Base (ViT-B/8) from the original paper (https://arxiv.org/abs/2010.11929).

    ImageNet-1k weights fine-tuned from in21k @ 224x224, source
    https://github.com/google-research/vision_transformer.
    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {'patch_size': 8, 'embed_dim': 768, 'depth': 12, 'num_heads': 12}
    return _create_vision_transformer(
        'vit_base_patch8_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_large_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Large (ViT-L/32) from the original paper (https://arxiv.org/abs/2010.11929).

    No pretrained weights.
    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {'patch_size': 32, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16}
    return _create_vision_transformer(
        'vit_large_patch32_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_large_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Large (ViT-L/32) from the original paper (https://arxiv.org/abs/2010.11929).

    ImageNet-1k weights fine-tuned from in21k @ 384x384, source
    https://github.com/google-research/vision_transformer.
    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {'patch_size': 32, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16}
    return _create_vision_transformer(
        'vit_large_patch32_384', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_large_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Large (ViT-L/16) from the original paper (https://arxiv.org/abs/2010.11929).

    ImageNet-1k weights fine-tuned from in21k @ 224x224, source
    https://github.com/google-research/vision_transformer.
    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {'patch_size': 16, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16}
    return _create_vision_transformer(
        'vit_large_patch16_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_large_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Large (ViT-L/16) from the original paper (https://arxiv.org/abs/2010.11929).

    ImageNet-1k weights fine-tuned from in21k @ 384x384, source
    https://github.com/google-research/vision_transformer.
    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {'patch_size': 16, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16}
    return _create_vision_transformer(
        'vit_large_patch16_384', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_large_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Large (ViT-L/14).

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {'patch_size': 14, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16}
    return _create_vision_transformer(
        'vit_large_patch14_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_huge_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Huge (ViT-H/14) from the original paper (https://arxiv.org/abs/2010.11929).

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {'patch_size': 14, 'embed_dim': 1280, 'depth': 32, 'num_heads': 16}
    return _create_vision_transformer(
        'vit_huge_patch14_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_giant_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Giant (little-g, ViT-g/14) from `Scaling Vision Transformers`.

    Paper: https://arxiv.org/abs/2106.04560
    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 14, 'embed_dim': 1408, 'mlp_ratio': 48/11, 'depth': 40, 'num_heads': 16,
    }
    return _create_vision_transformer(
        'vit_giant_patch14_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_gigantic_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Gigantic (big-G, ViT-G/14) from `Scaling Vision Transformers`.

    Paper: https://arxiv.org/abs/2106.04560
    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 14, 'embed_dim': 1664, 'mlp_ratio': 64/13, 'depth': 48, 'num_heads': 16,
    }
    return _create_vision_transformer(
        'vit_gigantic_patch14_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_base_patch16_224_miil(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Base (ViT-B/16) from the original paper (https://arxiv.org/abs/2010.11929).

    Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12, 'qkv_bias': False,
    }
    return _create_vision_transformer(
        'vit_base_patch16_224_miil', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_medium_patch16_gap_240(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 240x240.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 16, 'embed_dim': 512, 'depth': 12, 'num_heads': 8,
        'class_token': False, 'global_pool': 'avg', 'qkv_bias': False,
        'init_values': 1e-6, 'fc_norm': False,
    }
    return _create_vision_transformer(
        'vit_medium_patch16_gap_240', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_medium_patch16_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 256x256.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 16, 'embed_dim': 512, 'depth': 12, 'num_heads': 8,
        'class_token': False, 'global_pool': 'avg', 'qkv_bias': False,
        'init_values': 1e-6, 'fc_norm': False,
    }
    return _create_vision_transformer(
        'vit_medium_patch16_gap_256', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_medium_patch16_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 384x384.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 16, 'embed_dim': 512, 'depth': 12, 'num_heads': 8,
        'class_token': False, 'global_pool': 'avg', 'qkv_bias': False,
        'init_values': 1e-6, 'fc_norm': False,
    }
    return _create_vision_transformer(
        'vit_medium_patch16_gap_384', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_betwixt_patch16_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Betwixt (ViT-b/16) w/o class token, w/ avg-pool @ 256x256.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 16, 'embed_dim': 640, 'depth': 12, 'num_heads': 10,
        'class_token': False, 'global_pool': 'avg', 'qkv_bias': False,
        'init_values': 1e-6, 'fc_norm': False,
    }
    return _create_vision_transformer(
        'vit_betwixt_patch16_gap_256', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_base_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224.

    Note that this variant uses 16 attention heads.
    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 16,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False,
    }
    return _create_vision_transformer(
        'vit_base_patch16_gap_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_huge_patch14_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Huge (ViT-H/14) w/ no class token, avg pool.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 14, 'embed_dim': 1280, 'depth': 32, 'num_heads': 16,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False,
    }
    return _create_vision_transformer(
        'vit_huge_patch14_gap_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_huge_patch16_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Huge (ViT-H/16) w/ no class token, avg pool @ 448x448.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 16, 'embed_dim': 1280, 'depth': 32, 'num_heads': 16,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False,
    }
    return _create_vision_transformer(
        'vit_huge_patch16_gap_448', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_giant_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Giant (little-gg, ViT-g/16) w/ no class token, avg pool.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 16, 'embed_dim': 1408, 'depth': 40, 'num_heads': 16, 'mlp_ratio': 48/11,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False,
    }
    return _create_vision_transformer(
        'vit_giant_patch16_gap_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_xsmall_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the extra-small CLIP ViT variant (TinyCLIP 8M).

    ``patch_size`` is not set here, so the constructor default applies unless
    the caller passes one. Entries in ``kwargs`` override these defaults.
    """
    arch_defaults = {
        'embed_dim': 256, 'depth': 10, 'num_heads': 4,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    return _create_vision_transformer(
        'vit_xsmall_patch16_clip_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_medium_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the medium patch-32 CLIP ViT variant (TinyCLIP 40M).

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 32, 'embed_dim': 512, 'depth': 12, 'num_heads': 8,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    return _create_vision_transformer(
        'vit_medium_patch32_clip_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_medium_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the medium CLIP ViT variant (TinyCLIP 39M).

    ``patch_size`` is not set here, so the constructor default applies unless
    the caller passes one. Entries in ``kwargs`` override these defaults.
    """
    arch_defaults = {
        'embed_dim': 512, 'depth': 12, 'num_heads': 8,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    return _create_vision_transformer(
        'vit_medium_patch16_clip_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_betwixt_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the betwixt patch-32 CLIP ViT variant (TinyCLIP 61M).

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 32, 'embed_dim': 640, 'depth': 12, 'num_heads': 10,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    return _create_vision_transformer(
        'vit_betwixt_patch32_clip_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_base_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the ViT-B/32 CLIP image tower @ 224x224.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 32, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    return _create_vision_transformer(
        'vit_base_patch32_clip_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_base_patch32_clip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the ViT-B/32 CLIP image tower @ 256x256.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 32, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    return _create_vision_transformer(
        'vit_base_patch32_clip_256', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_base_patch32_clip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the ViT-B/32 CLIP image tower @ 384x384.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 32, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    return _create_vision_transformer(
        'vit_base_patch32_clip_384', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_base_patch32_clip_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the ViT-B/32 CLIP image tower @ 448x448.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 32, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    return _create_vision_transformer(
        'vit_base_patch32_clip_448', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_base_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the ViT-B/16 CLIP image tower.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    return _create_vision_transformer(
        'vit_base_patch16_clip_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_base_patch16_clip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the ViT-B/16 CLIP image tower @ 384x384.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    return _create_vision_transformer(
        'vit_base_patch16_clip_384', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_base_patch16_plus_clip_240(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the ViT-Base (ViT-B/16+) CLIP image tower @ 240x240.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 16, 'embed_dim': 896, 'depth': 12, 'num_heads': 14,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    return _create_vision_transformer(
        'vit_base_patch16_plus_clip_240', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_large_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the ViT-Large (ViT-L/14) CLIP image tower.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 14, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    return _create_vision_transformer(
        'vit_large_patch14_clip_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_large_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the ViT-Large (ViT-L/14) CLIP image tower @ 336x336.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 14, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    return _create_vision_transformer(
        'vit_large_patch14_clip_336', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_huge_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the ViT-Huge (ViT-H/14) CLIP image tower.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 14, 'embed_dim': 1280, 'depth': 32, 'num_heads': 16,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    return _create_vision_transformer(
        'vit_huge_patch14_clip_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_huge_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the ViT-Huge (ViT-H/14) CLIP image tower @ 336x336.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 14, 'embed_dim': 1280, 'depth': 32, 'num_heads': 16,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    return _create_vision_transformer(
        'vit_huge_patch14_clip_336', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_huge_patch14_clip_378(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the ViT-Huge (ViT-H/14) CLIP image tower @ 378x378.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 14, 'embed_dim': 1280, 'depth': 32, 'num_heads': 16,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    return _create_vision_transformer(
        'vit_huge_patch14_clip_378', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_giant_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Giant (little-g, ViT-g/14) from `Scaling Vision Transformers`.

    Paper: https://arxiv.org/abs/2106.04560
    Pretrained weights from CLIP image tower.
    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 14, 'embed_dim': 1408, 'mlp_ratio': 48/11, 'depth': 40, 'num_heads': 16,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    return _create_vision_transformer(
        'vit_giant_patch14_clip_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_gigantic_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-bigG (ViT-G/14) from `Scaling Vision Transformers`.

    Paper: https://arxiv.org/abs/2106.04560
    Pretrained weights from CLIP image tower.
    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 14, 'embed_dim': 1664, 'mlp_ratio': 64/13, 'depth': 48, 'num_heads': 16,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    return _create_vision_transformer(
        'vit_gigantic_patch14_clip_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_gigantic_patch14_clip_378(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-bigG (ViT-G/14) from `Scaling Vision Transformers`, ``_378`` variant.

    Paper: https://arxiv.org/abs/2106.04560
    Pretrained weights from CLIP image tower.
    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 14, 'embed_dim': 1664, 'mlp_ratio': 64/13, 'depth': 48, 'num_heads': 16,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    return _create_vision_transformer(
        'vit_gigantic_patch14_clip_378', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_base_patch32_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the ViT-B/32 CLIP image tower @ 224x224 with ``act_layer='quick_gelu'``.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 32, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
        'act_layer': 'quick_gelu',
    }
    return _create_vision_transformer(
        'vit_base_patch32_clip_quickgelu_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_base_patch16_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the ViT-B/16 CLIP image tower w/ QuickGELU act.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
        'act_layer': 'quick_gelu',
    }
    return _create_vision_transformer(
        'vit_base_patch16_clip_quickgelu_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_large_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the ViT-Large (ViT-L/14) CLIP image tower w/ QuickGELU act.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 14, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
        'act_layer': 'quick_gelu',
    }
    return _create_vision_transformer(
        'vit_large_patch14_clip_quickgelu_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_large_patch14_clip_quickgelu_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the ViT-Large (ViT-L/14) CLIP image tower @ 336x336 w/ QuickGELU act.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 14, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
        'act_layer': 'quick_gelu',
    }
    return _create_vision_transformer(
        'vit_large_patch14_clip_quickgelu_336', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_huge_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the ViT-Huge (ViT-H/14) CLIP image tower w/ QuickGELU act.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 14, 'embed_dim': 1280, 'depth': 32, 'num_heads': 16,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
        'act_layer': 'quick_gelu',
    }
    return _create_vision_transformer(
        'vit_huge_patch14_clip_quickgelu_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_huge_patch14_clip_quickgelu_378(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the ViT-Huge (ViT-H/14) CLIP image tower @ 378x378 w/ QuickGELU act.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 14, 'embed_dim': 1280, 'depth': 32, 'num_heads': 16,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
        'act_layer': 'quick_gelu',
    }
    return _create_vision_transformer(
        'vit_huge_patch14_clip_quickgelu_378', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_gigantic_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-bigG (ViT-G/14) w/ QuickGELU act.

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 14, 'embed_dim': 1664, 'mlp_ratio': 64/13, 'depth': 48, 'num_heads': 16,
        'pre_norm': True, 'norm_layer': partial(LayerNorm, eps=1e-5),
        'act_layer': 'quick_gelu',
    }
    return _create_vision_transformer(
        'vit_gigantic_patch14_clip_quickgelu_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
- # Experimental models below
@register_model
def vit_base_patch32_plus_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Base (ViT-B/32+).

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 32, 'embed_dim': 896, 'depth': 12, 'num_heads': 14, 'init_values': 1e-5,
    }
    return _create_vision_transformer(
        'vit_base_patch32_plus_256', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_base_patch16_plus_240(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Base (ViT-B/16+).

    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 16, 'embed_dim': 896, 'depth': 12, 'num_heads': 14, 'init_values': 1e-5,
    }
    return _create_vision_transformer(
        'vit_base_patch16_plus_240', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_base_patch16_rpn_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Base (ViT-B/16) w/ residual post-norm.

    Uses ``ResPostBlock`` as the block type, no class token and average pooling.
    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'qkv_bias': False, 'init_values': 1e-5,
        'class_token': False, 'block_fn': ResPostBlock, 'global_pool': 'avg',
    }
    return _create_vision_transformer(
        'vit_base_patch16_rpn_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_small_patch16_36x1_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Small w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove.

    Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
    Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 16, 'embed_dim': 384, 'depth': 36, 'num_heads': 6, 'init_values': 1e-5,
    }
    return _create_vision_transformer(
        'vit_small_patch16_36x1_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_small_patch16_18x2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.

    Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
    Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 16, 'embed_dim': 384, 'depth': 18, 'num_heads': 6,
        'init_values': 1e-5, 'block_fn': ParallelThingsBlock,
    }
    return _create_vision_transformer(
        'vit_small_patch16_18x2_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def vit_base_patch16_18x2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.

    Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
    Entries in ``kwargs`` override the architecture defaults defined here.
    """
    arch_defaults = {
        'patch_size': 16, 'embed_dim': 768, 'depth': 18, 'num_heads': 12,
        'init_values': 1e-5, 'block_fn': ParallelThingsBlock,
    }
    return _create_vision_transformer(
        'vit_base_patch16_18x2_224', pretrained=pretrained, **{**arch_defaults, **kwargs})
@register_model
def eva_large_patch14_196(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain.

    Defaults: patch 14, embed dim 1024, depth 24, 16 heads, `global_pool='avg'`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {'patch_size': 14, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16, 'global_pool': 'avg'}
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'eva_large_patch14_196', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def eva_large_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain.

    Defaults: patch 14, embed dim 1024, depth 24, 16 heads, `global_pool='avg'`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {'patch_size': 14, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16, 'global_pool': 'avg'}
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'eva_large_patch14_336', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def flexivit_small(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ FlexiViT-Small

    Defaults: patch 16, embed dim 384, depth 12, 6 heads, `no_embed_class=True`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {'patch_size': 16, 'embed_dim': 384, 'depth': 12, 'num_heads': 6, 'no_embed_class': True}
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'flexivit_small', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def flexivit_base(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ FlexiViT-Base

    Defaults: patch 16, embed dim 768, depth 12, 12 heads, `no_embed_class=True`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12, 'no_embed_class': True}
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'flexivit_base', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def flexivit_large(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ FlexiViT-Large

    Defaults: patch 16, embed dim 1024, depth 24, 16 heads, `no_embed_class=True`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {'patch_size': 16, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16, 'no_embed_class': True}
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'flexivit_large', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_base_patch16_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ ViT-Base model (ViT-B/16) w/ parallel blocks and qk norm enabled.

    Defaults: patch 16, embed dim 768, depth 12, 12 heads, `pre_norm=True`,
    `no_embed_class=True`, `norm_layer=RmsNorm`, `block_fn=ParallelScalingBlock`,
    no qkv bias, `qk_norm=True`. Entries in `kwargs` override these defaults.
    """
    # NOTE: the docstring previously read "ViT-Large (ViT-L/14)"; the configuration
    # below (patch 16, width 768, depth 12) and the registered name are Base/16.
    defaults = {
        'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'pre_norm': True, 'no_embed_class': True,
        'norm_layer': RmsNorm, 'block_fn': ParallelScalingBlock,
        'qkv_bias': False, 'qk_norm': True,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_base_patch16_xp_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_large_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.

    Defaults: patch 14, embed dim 1024, depth 24, 16 heads, `pre_norm=True`,
    `no_embed_class=True`, `norm_layer=RmsNorm`, `block_fn=ParallelScalingBlock`,
    no qkv bias, `qk_norm=True`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 14, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16,
        'pre_norm': True, 'no_embed_class': True,
        'norm_layer': RmsNorm, 'block_fn': ParallelScalingBlock,
        'qkv_bias': False, 'qk_norm': True,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_large_patch14_xp_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_huge_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ ViT-Huge model (ViT-H/14) w/ parallel blocks and qk norm enabled.

    Defaults: patch 14, embed dim 1280, depth 32, 16 heads, `pre_norm=True`,
    `no_embed_class=True`, `norm_layer=RmsNorm`, `block_fn=ParallelScalingBlock`,
    no qkv bias, `qk_norm=True`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 14, 'embed_dim': 1280, 'depth': 32, 'num_heads': 16,
        'pre_norm': True, 'no_embed_class': True,
        'norm_layer': RmsNorm, 'block_fn': ParallelScalingBlock,
        'qkv_bias': False, 'qk_norm': True,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_huge_patch14_xp_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ ViT-S/14 for DINOv2

    Defaults: patch 14, embed dim 384, depth 12, 6 heads, `init_values=1e-5`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {'patch_size': 14, 'embed_dim': 384, 'depth': 12, 'num_heads': 6, 'init_values': 1e-5}
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_small_patch14_dinov2', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ ViT-B/14 for DINOv2

    Defaults: patch 14, embed dim 768, depth 12, 12 heads, `init_values=1e-5`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {'patch_size': 14, 'embed_dim': 768, 'depth': 12, 'num_heads': 12, 'init_values': 1e-5}
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_base_patch14_dinov2', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_large_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ ViT-L/14 for DINOv2

    Defaults: patch 14, embed dim 1024, depth 24, 16 heads, `init_values=1e-5`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {'patch_size': 14, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16, 'init_values': 1e-5}
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_large_patch14_dinov2', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_giant_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ ViT-G/14 for DINOv2

    Defaults: patch 14, embed dim 1536, depth 40, 24 heads, `init_values=1e-5`,
    `mlp_layer=SwiGLUPacked` with `act_layer=nn.SiLU`. Entries in `kwargs` override these defaults.
    """
    # How the mlp_ratio below was chosen (carried over from the original notes):
    #   SwiGLU hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
    #   With embed_dim=1536 that gives hidden_features=4096.
    #   SwiGLUPacked needs twice that, i.e. hidden_features = 2 * 4096 = 8192.
    defaults = {
        'patch_size': 14, 'embed_dim': 1536, 'depth': 40, 'num_heads': 24, 'init_values': 1e-5,
        'mlp_ratio': 2.66667 * 2, 'mlp_layer': SwiGLUPacked, 'act_layer': nn.SiLU,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_giant_patch14_dinov2', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_small_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ ViT-S/14 for DINOv2 w/ 4 registers

    Defaults: patch 14, embed dim 384, depth 12, 6 heads, `init_values=1e-5`,
    `reg_tokens=4`, `no_embed_class=True`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 14, 'embed_dim': 384, 'depth': 12, 'num_heads': 6, 'init_values': 1e-5,
        'reg_tokens': 4, 'no_embed_class': True,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_small_patch14_reg4_dinov2', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_base_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ ViT-B/14 for DINOv2 w/ 4 registers

    Defaults: patch 14, embed dim 768, depth 12, 12 heads, `init_values=1e-5`,
    `reg_tokens=4`, `no_embed_class=True`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 14, 'embed_dim': 768, 'depth': 12, 'num_heads': 12, 'init_values': 1e-5,
        'reg_tokens': 4, 'no_embed_class': True,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_base_patch14_reg4_dinov2', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_large_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ ViT-L/14 for DINOv2 w/ 4 registers

    Defaults: patch 14, embed dim 1024, depth 24, 16 heads, `init_values=1e-5`,
    `reg_tokens=4`, `no_embed_class=True`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 14, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16, 'init_values': 1e-5,
        'reg_tokens': 4, 'no_embed_class': True,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_large_patch14_reg4_dinov2', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_giant_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ ViT-G/14 for DINOv2 w/ 4 registers

    Defaults: patch 14, embed dim 1536, depth 40, 24 heads, `init_values=1e-5`,
    `mlp_layer=SwiGLUPacked` with `act_layer=nn.SiLU`, `reg_tokens=4`, `no_embed_class=True`.
    Entries in `kwargs` override these defaults.
    """
    # How the mlp_ratio below was chosen (carried over from the original notes):
    #   SwiGLU hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
    #   With embed_dim=1536 that gives hidden_features=4096.
    #   SwiGLUPacked needs twice that, i.e. hidden_features = 2 * 4096 = 8192.
    defaults = {
        'patch_size': 14, 'embed_dim': 1536, 'depth': 40, 'num_heads': 24, 'init_values': 1e-5,
        'mlp_ratio': 2.66667 * 2, 'mlp_layer': SwiGLUPacked, 'act_layer': nn.SiLU,
        'reg_tokens': 4, 'no_embed_class': True,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_giant_patch14_reg4_dinov2', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_base_patch32_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_base_patch32_siglip_256` model.

    Defaults: patch 32, embed dim 768, depth 12, 12 heads, no class token,
    `global_pool='map'`, `act_layer='gelu_tanh'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 32, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'class_token': False, 'global_pool': 'map',
        'act_layer': 'gelu_tanh',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_base_patch32_siglip_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_base_patch16_siglip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_base_patch16_siglip_224` model.

    Defaults: patch 16, embed dim 768, depth 12, 12 heads, no class token,
    `global_pool='map'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'class_token': False, 'global_pool': 'map',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_base_patch16_siglip_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_base_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_base_patch16_siglip_256` model.

    Defaults: patch 16, embed dim 768, depth 12, 12 heads, no class token,
    `global_pool='map'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'class_token': False, 'global_pool': 'map',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_base_patch16_siglip_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_base_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_base_patch16_siglip_384` model.

    Defaults: patch 16, embed dim 768, depth 12, 12 heads, no class token,
    `global_pool='map'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'class_token': False, 'global_pool': 'map',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_base_patch16_siglip_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_base_patch16_siglip_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_base_patch16_siglip_512` model.

    Defaults: patch 16, embed dim 768, depth 12, 12 heads, no class token,
    `global_pool='map'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'class_token': False, 'global_pool': 'map',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_base_patch16_siglip_512', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_large_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_large_patch16_siglip_256` model.

    Defaults: patch 16, embed dim 1024, depth 24, 16 heads, no class token,
    `global_pool='map'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16,
        'class_token': False, 'global_pool': 'map',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_large_patch16_siglip_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_large_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_large_patch16_siglip_384` model.

    Defaults: patch 16, embed dim 1024, depth 24, 16 heads, no class token,
    `global_pool='map'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16,
        'class_token': False, 'global_pool': 'map',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_large_patch16_siglip_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_large_patch16_siglip_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_large_patch16_siglip_512` model.

    Defaults: patch 16, embed dim 1024, depth 24, 16 heads, no class token,
    `global_pool='map'`, `act_layer='gelu_tanh'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16,
        'class_token': False, 'global_pool': 'map',
        'act_layer': 'gelu_tanh',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_large_patch16_siglip_512', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_so400m_patch14_siglip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_so400m_patch14_siglip_224` model.

    Defaults: patch 14, embed dim 1152, depth 27, 16 heads, `mlp_ratio=3.7362`,
    no class token, `global_pool='map'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 14, 'embed_dim': 1152, 'depth': 27, 'num_heads': 16,
        'mlp_ratio': 3.7362, 'class_token': False, 'global_pool': 'map',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_so400m_patch14_siglip_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_so400m_patch14_siglip_378(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_so400m_patch14_siglip_378` model.

    Defaults: patch 14, embed dim 1152, depth 27, 16 heads, `mlp_ratio=3.7362`,
    no class token, `global_pool='map'`. Entries in `kwargs` override these defaults.
    """
    # This is a corrected variant of the 384 model: its resolution is properly divisible
    # by the patch size, so no padding/truncation is needed.
    defaults = {
        'patch_size': 14, 'embed_dim': 1152, 'depth': 27, 'num_heads': 16,
        'mlp_ratio': 3.7362, 'class_token': False, 'global_pool': 'map',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_so400m_patch14_siglip_378', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_so400m_patch14_siglip_384` model.

    Defaults: patch 14, embed dim 1152, depth 27, 16 heads, `mlp_ratio=3.7362`,
    no class token, `global_pool='map'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 14, 'embed_dim': 1152, 'depth': 27, 'num_heads': 16,
        'mlp_ratio': 3.7362, 'class_token': False, 'global_pool': 'map',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_so400m_patch14_siglip_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_so400m_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_so400m_patch16_siglip_256` model.

    Defaults: patch 16, embed dim 1152, depth 27, 16 heads, `mlp_ratio=3.7362`, no class token,
    `global_pool='map'`, `act_layer='gelu_tanh'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 1152, 'depth': 27, 'num_heads': 16,
        'mlp_ratio': 3.7362, 'class_token': False, 'global_pool': 'map',
        'act_layer': 'gelu_tanh',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_so400m_patch16_siglip_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_so400m_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_so400m_patch16_siglip_384` model.

    Defaults: patch 16, embed dim 1152, depth 27, 16 heads, `mlp_ratio=3.7362`, no class token,
    `global_pool='map'`, `act_layer='gelu_tanh'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 1152, 'depth': 27, 'num_heads': 16,
        'mlp_ratio': 3.7362, 'class_token': False, 'global_pool': 'map',
        'act_layer': 'gelu_tanh',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_so400m_patch16_siglip_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_so400m_patch16_siglip_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_so400m_patch16_siglip_512` model.

    Defaults: patch 16, embed dim 1152, depth 27, 16 heads, `mlp_ratio=3.7362`, no class token,
    `global_pool='map'`, `act_layer='gelu_tanh'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 1152, 'depth': 27, 'num_heads': 16,
        'mlp_ratio': 3.7362, 'class_token': False, 'global_pool': 'map',
        'act_layer': 'gelu_tanh',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_so400m_patch16_siglip_512', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_giantopt_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_giantopt_patch16_siglip_256` model.

    Defaults: patch 16, embed dim 1536, depth 40, 16 heads, no class token,
    `global_pool='map'`, `act_layer='gelu_tanh'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 1536, 'depth': 40, 'num_heads': 16,
        'class_token': False, 'global_pool': 'map',
        'act_layer': 'gelu_tanh',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_giantopt_patch16_siglip_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_giantopt_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_giantopt_patch16_siglip_384` model.

    Defaults: patch 16, embed dim 1536, depth 40, 16 heads, no class token,
    `global_pool='map'`, `act_layer='gelu_tanh'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 1536, 'depth': 40, 'num_heads': 16,
        'class_token': False, 'global_pool': 'map',
        'act_layer': 'gelu_tanh',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_giantopt_patch16_siglip_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_base_patch32_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_base_patch32_siglip_gap_256` model.

    Defaults: patch 32, embed dim 768, depth 12, 12 heads, no class token, `global_pool='avg'`,
    `fc_norm=False`, `act_layer='gelu_tanh'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 32, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False,
        'act_layer': 'gelu_tanh',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_base_patch32_siglip_gap_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_base_patch16_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).

    Defaults: patch 16, embed dim 768, depth 12, 12 heads, no class token,
    `global_pool='avg'`, `fc_norm=False`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_base_patch16_siglip_gap_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_base_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).

    Defaults: patch 16, embed dim 768, depth 12, 12 heads, no class token,
    `global_pool='avg'`, `fc_norm=False`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_base_patch16_siglip_gap_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_base_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).

    Defaults: patch 16, embed dim 768, depth 12, 12 heads, no class token,
    `global_pool='avg'`, `fc_norm=False`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_base_patch16_siglip_gap_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_base_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).

    Defaults: patch 16, embed dim 768, depth 12, 12 heads, no class token,
    `global_pool='avg'`, `fc_norm=False`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_base_patch16_siglip_gap_512', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_large_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).

    Defaults: patch 16, embed dim 1024, depth 24, 16 heads, no class token,
    `global_pool='avg'`, `fc_norm=False`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_large_patch16_siglip_gap_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_large_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).

    Defaults: patch 16, embed dim 1024, depth 24, 16 heads, no class token,
    `global_pool='avg'`, `fc_norm=False`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_large_patch16_siglip_gap_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_large_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_large_patch16_siglip_gap_512` model.

    Defaults: patch 16, embed dim 1024, depth 24, 16 heads, no class token, `global_pool='avg'`,
    `fc_norm=False`, `act_layer='gelu_tanh'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False,
        'act_layer': 'gelu_tanh',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_large_patch16_siglip_gap_512', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_so400m_patch14_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).

    Defaults: patch 14, embed dim 1152, depth 27, 16 heads, `mlp_ratio=3.7362`, no class token,
    `global_pool='avg'`, `fc_norm=False`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 14, 'embed_dim': 1152, 'depth': 27, 'num_heads': 16, 'mlp_ratio': 3.7362,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_so400m_patch14_siglip_gap_224', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_so400m_patch14_siglip_gap_378(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).

    Defaults: patch 14, embed dim 1152, depth 27, 16 heads, `mlp_ratio=3.7362`, no class token,
    `global_pool='avg'`, `fc_norm=False`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 14, 'embed_dim': 1152, 'depth': 27, 'num_heads': 16, 'mlp_ratio': 3.7362,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_so400m_patch14_siglip_gap_378', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_so400m_patch14_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).

    Defaults: patch 14, embed dim 1152, depth 27, 16 heads, `mlp_ratio=3.7362`, no class token,
    `global_pool='avg'`, `fc_norm=False`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 14, 'embed_dim': 1152, 'depth': 27, 'num_heads': 16, 'mlp_ratio': 3.7362,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_so400m_patch14_siglip_gap_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_so400m_patch14_siglip_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).

    Defaults: patch 14, embed dim 1152, depth 27, 16 heads, `mlp_ratio=3.7362`, no class token,
    `global_pool='avg'`, `fc_norm=False`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 14, 'embed_dim': 1152, 'depth': 27, 'num_heads': 16, 'mlp_ratio': 3.7362,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_so400m_patch14_siglip_gap_448', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_so400m_patch14_siglip_gap_896(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).

    Defaults: patch 14, embed dim 1152, depth 27, 16 heads, `mlp_ratio=3.7362`, no class token,
    `global_pool='avg'`, `fc_norm=False`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 14, 'embed_dim': 1152, 'depth': 27, 'num_heads': 16, 'mlp_ratio': 3.7362,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_so400m_patch14_siglip_gap_896', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_so400m_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).

    Defaults: patch 16, embed dim 1152, depth 27, 16 heads, `mlp_ratio=3.7362`, no class token,
    `global_pool='avg'`, `fc_norm=False`, `act_layer='gelu_tanh'`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 1152, 'depth': 27, 'num_heads': 16, 'mlp_ratio': 3.7362,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False, 'act_layer': 'gelu_tanh',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_so400m_patch16_siglip_gap_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_so400m_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_so400m_patch16_siglip_gap_384` model.

    Defaults: patch 16, embed dim 1152, depth 27, 16 heads, `mlp_ratio=3.7362`, no class token,
    `global_pool='avg'`, `fc_norm=False`, `act_layer='gelu_tanh'`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 1152, 'depth': 27, 'num_heads': 16, 'mlp_ratio': 3.7362,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False, 'act_layer': 'gelu_tanh',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_so400m_patch16_siglip_gap_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_so400m_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_so400m_patch16_siglip_gap_512` model.

    Defaults: patch 16, embed dim 1152, depth 27, 16 heads, `mlp_ratio=3.7362`, no class token,
    `global_pool='avg'`, `fc_norm=False`, `act_layer='gelu_tanh'`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 1152, 'depth': 27, 'num_heads': 16, 'mlp_ratio': 3.7362,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False, 'act_layer': 'gelu_tanh',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_so400m_patch16_siglip_gap_512', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_giantopt_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_giantopt_patch16_siglip_gap_256` model.

    Defaults: patch 16, embed dim 1536, depth 40, 16 heads, no class token, `global_pool='avg'`,
    `fc_norm=False`, `act_layer='gelu_tanh'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 1536, 'depth': 40, 'num_heads': 16,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False,
        'act_layer': 'gelu_tanh',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_giantopt_patch16_siglip_gap_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_giantopt_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_giantopt_patch16_siglip_gap_384` model.

    Defaults: patch 16, embed dim 1536, depth 40, 16 heads, no class token, `global_pool='avg'`,
    `fc_norm=False`, `act_layer='gelu_tanh'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 1536, 'depth': 40, 'num_heads': 16,
        'class_token': False, 'global_pool': 'avg', 'fc_norm': False,
        'act_layer': 'gelu_tanh',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_giantopt_patch16_siglip_gap_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_wee_patch16_reg1_gap_256` model.

    Defaults: patch 16, embed dim 256, depth 14, 4 heads, `init_values=1e-5`, `mlp_ratio=5`,
    no class token, `no_embed_class=True`, `reg_tokens=1`, `global_pool='avg'`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 256, 'depth': 14, 'num_heads': 4,
        'init_values': 1e-5, 'mlp_ratio': 5,
        'class_token': False, 'no_embed_class': True, 'reg_tokens': 1, 'global_pool': 'avg',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_wee_patch16_reg1_gap_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_dwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_dwee_patch16_reg1_gap_256` model.

    Defaults: patch 16, embed dim 256, depth 14, 4 heads, `init_values=1e-5`, `mlp_ratio=5`,
    no class token, `no_embed_class=True`, `reg_tokens=1`, `global_pool='avg'`,
    `attn_layer='diff'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 256, 'depth': 14, 'num_heads': 4,
        'init_values': 1e-5, 'mlp_ratio': 5,
        'class_token': False, 'no_embed_class': True, 'reg_tokens': 1, 'global_pool': 'avg',
        'attn_layer': 'diff',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_dwee_patch16_reg1_gap_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_pwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_pwee_patch16_reg1_gap_256` model.

    Defaults: patch 16, embed dim 256, depth 16, 4 heads, `init_values=1e-5`, `mlp_ratio=5`,
    no class token, `no_embed_class=True`, `reg_tokens=1`, `global_pool='avg'`,
    `block_fn=ParallelScalingBlock`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 256, 'depth': 16, 'num_heads': 4,
        'init_values': 1e-5, 'mlp_ratio': 5,
        'class_token': False, 'no_embed_class': True, 'reg_tokens': 1, 'global_pool': 'avg',
        'block_fn': ParallelScalingBlock,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_pwee_patch16_reg1_gap_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_dpwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_dpwee_patch16_reg1_gap_256` model.

    Defaults: patch 16, embed dim 256, depth 16, 4 heads, `init_values=1e-5`, `mlp_ratio=5`,
    no class token, `no_embed_class=True`, `reg_tokens=1`, `global_pool='avg'`,
    `block_fn=DiffParallelScalingBlock`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 256, 'depth': 16, 'num_heads': 4,
        'init_values': 1e-5, 'mlp_ratio': 5,
        'class_token': False, 'no_embed_class': True, 'reg_tokens': 1, 'global_pool': 'avg',
        'block_fn': DiffParallelScalingBlock,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_dpwee_patch16_reg1_gap_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_little_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_little_patch16_reg1_gap_256` model.

    Defaults: patch 16, embed dim 320, depth 14, 5 heads, `init_values=1e-5`, `mlp_ratio=5.6`,
    no class token, `no_embed_class=True`, `reg_tokens=1`, `global_pool='avg'`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 320, 'depth': 14, 'num_heads': 5,
        'init_values': 1e-5, 'mlp_ratio': 5.6,
        'class_token': False, 'no_embed_class': True, 'reg_tokens': 1, 'global_pool': 'avg',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_little_patch16_reg1_gap_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_dlittle_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_dlittle_patch16_reg1_gap_256` model.

    Defaults: patch 16, embed dim 320, depth 14, 5 heads, `init_values=1e-5`, `mlp_ratio=5.6`,
    no class token, `no_embed_class=True`, `reg_tokens=1`, `global_pool='avg'`,
    `attn_layer='diff'`. Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 320, 'depth': 14, 'num_heads': 5,
        'init_values': 1e-5, 'mlp_ratio': 5.6,
        'class_token': False, 'no_embed_class': True, 'reg_tokens': 1, 'global_pool': 'avg',
        'attn_layer': 'diff',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_dlittle_patch16_reg1_gap_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_little_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_little_patch16_reg4_gap_256` model.

    Defaults: patch 16, embed dim 320, depth 14, 5 heads, `init_values=1e-5`, `mlp_ratio=5.6`,
    no class token, `no_embed_class=True`, `reg_tokens=4`, `global_pool='avg'`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 320, 'depth': 14, 'num_heads': 5,
        'init_values': 1e-5, 'mlp_ratio': 5.6,
        'class_token': False, 'no_embed_class': True, 'reg_tokens': 4, 'global_pool': 'avg',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_little_patch16_reg4_gap_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_medium_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_medium_patch16_reg1_gap_256` model.

    Defaults: patch 16, embed dim 512, depth 12, 8 heads, `init_values=1e-5`,
    no class token, `no_embed_class=True`, `reg_tokens=1`, `global_pool='avg'`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 512, 'depth': 12, 'num_heads': 8, 'init_values': 1e-5,
        'class_token': False, 'no_embed_class': True, 'reg_tokens': 1, 'global_pool': 'avg',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_medium_patch16_reg4_gap_256` model.

    Defaults: patch 16, embed dim 512, depth 12, 8 heads, `init_values=1e-5`,
    no class token, `no_embed_class=True`, `reg_tokens=4`, `global_pool='avg'`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 512, 'depth': 12, 'num_heads': 8, 'init_values': 1e-5,
        'class_token': False, 'no_embed_class': True, 'reg_tokens': 4, 'global_pool': 'avg',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_medium_patch16_reg4_gap_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_mediumd_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_mediumd_patch16_reg4_gap_256` model.

    Defaults: patch 16, embed dim 512, depth 20, 8 heads, `init_values=1e-5`,
    no class token, `no_embed_class=True`, `reg_tokens=4`, `global_pool='avg'`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 512, 'depth': 20, 'num_heads': 8, 'init_values': 1e-5,
        'class_token': False, 'no_embed_class': True, 'reg_tokens': 4, 'global_pool': 'avg',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_mediumd_patch16_reg4_gap_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_mediumd_patch16_reg4_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_mediumd_patch16_reg4_gap_384` model.

    Defaults: patch 16, embed dim 512, depth 20, 8 heads, `init_values=1e-5`,
    no class token, `no_embed_class=True`, `reg_tokens=4`, `global_pool='avg'`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 512, 'depth': 20, 'num_heads': 8, 'init_values': 1e-5,
        'class_token': False, 'no_embed_class': True, 'reg_tokens': 4, 'global_pool': 'avg',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_mediumd_patch16_reg4_gap_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_betwixt_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_betwixt_patch16_reg1_gap_256` model.

    Defaults: patch 16, embed dim 640, depth 12, 10 heads, `init_values=1e-5`,
    no class token, `no_embed_class=True`, `reg_tokens=1`, `global_pool='avg'`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 640, 'depth': 12, 'num_heads': 10, 'init_values': 1e-5,
        'class_token': False, 'no_embed_class': True, 'reg_tokens': 1, 'global_pool': 'avg',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_betwixt_patch16_reg1_gap_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_betwixt_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_betwixt_patch16_reg4_gap_256` model.

    Defaults: patch 16, embed dim 640, depth 12, 10 heads, `init_values=1e-5`,
    no class token, `no_embed_class=True`, `reg_tokens=4`, `global_pool='avg'`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 640, 'depth': 12, 'num_heads': 10, 'init_values': 1e-5,
        'class_token': False, 'no_embed_class': True, 'reg_tokens': 4, 'global_pool': 'avg',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_betwixt_patch16_reg4_gap_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_betwixt_patch16_reg4_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_betwixt_patch16_reg4_gap_384` model.

    Defaults: patch 16, embed dim 640, depth 12, 10 heads, `init_values=1e-5`,
    no class token, `no_embed_class=True`, `reg_tokens=4`, `global_pool='avg'`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 640, 'depth': 12, 'num_heads': 10, 'init_values': 1e-5,
        'class_token': False, 'no_embed_class': True, 'reg_tokens': 4, 'global_pool': 'avg',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_betwixt_patch16_reg4_gap_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ Create the `vit_base_patch16_reg4_gap_256` model.

    Defaults: patch 16, embed dim 768, depth 12, 12 heads, no class token,
    `no_embed_class=True`, `global_pool='avg'`, `reg_tokens=4`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12, 'class_token': False,
        'no_embed_class': True, 'global_pool': 'avg', 'reg_tokens': 4,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_base_patch16_reg4_gap_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_so150m_patch16_reg4_map_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ SO150M (shape optimized, but diff than paper def, optimized for GPU)

    Defaults: patch 16, embed dim 896, depth 18, 14 heads, `mlp_ratio=2.572`,
    no class token, `reg_tokens=4`, `global_pool='map'`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 896, 'depth': 18, 'num_heads': 14, 'mlp_ratio': 2.572,
        'class_token': False, 'reg_tokens': 4, 'global_pool': 'map',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_so150m_patch16_reg4_map_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ SO150M (shape optimized, but diff than paper def, optimized for GPU)

    Defaults: patch 16, embed dim 896, depth 18, 14 heads, `mlp_ratio=2.572`,
    no class token, `reg_tokens=4`, `global_pool='avg'`, `fc_norm=False`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 896, 'depth': 18, 'num_heads': 14, 'mlp_ratio': 2.572,
        'class_token': False, 'reg_tokens': 4, 'global_pool': 'avg', 'fc_norm': False,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_so150m_patch16_reg4_gap_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_so150m_patch16_reg4_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ SO150M (shape optimized, but diff than paper def, optimized for GPU)

    Defaults: patch 16, embed dim 896, depth 18, 14 heads, `mlp_ratio=2.572`,
    no class token, `reg_tokens=4`, `global_pool='avg'`, `fc_norm=False`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 896, 'depth': 18, 'num_heads': 14, 'mlp_ratio': 2.572,
        'class_token': False, 'reg_tokens': 4, 'global_pool': 'avg', 'fc_norm': False,
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_so150m_patch16_reg4_gap_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_so150m2_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ SO150M v2 (shape optimized, but diff than paper def, optimized for GPU)

    Defaults: patch 16, embed dim 832, depth 21, 13 heads, `mlp_ratio=34/13`, `init_values=1e-5`,
    no qkv bias, no class token, `reg_tokens=1`, `global_pool='avg'`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 832, 'depth': 21, 'num_heads': 13,
        'mlp_ratio': 34/13, 'init_values': 1e-5,
        'qkv_bias': False, 'class_token': False, 'reg_tokens': 1, 'global_pool': 'avg',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_so150m2_patch16_reg1_gap_256', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_so150m2_patch16_reg1_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ SO150M v2 (shape optimized, but diff than paper def, optimized for GPU)

    Defaults: patch 16, embed dim 832, depth 21, 13 heads, `mlp_ratio=34/13`, `init_values=1e-5`,
    no qkv bias, no class token, `reg_tokens=1`, `global_pool='avg'`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 832, 'depth': 21, 'num_heads': 13,
        'mlp_ratio': 34/13, 'init_values': 1e-5,
        'qkv_bias': False, 'class_token': False, 'reg_tokens': 1, 'global_pool': 'avg',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_so150m2_patch16_reg1_gap_384', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_so150m2_patch16_reg1_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ SO150M v2 (shape optimized, but diff than paper def, optimized for GPU)

    Defaults: patch 16, embed dim 832, depth 21, 13 heads, `mlp_ratio=34/13`, `init_values=1e-5`,
    no qkv bias, no class token, `reg_tokens=1`, `global_pool='avg'`.
    Entries in `kwargs` override these defaults.
    """
    defaults = {
        'patch_size': 16, 'embed_dim': 832, 'depth': 21, 'num_heads': 13,
        'mlp_ratio': 34/13, 'init_values': 1e-5,
        'qkv_bias': False, 'class_token': False, 'reg_tokens': 1, 'global_pool': 'avg',
    }
    # kwargs are merged last so caller-supplied values win over the defaults.
    return _create_vision_transformer(
        'vit_so150m2_patch16_reg1_gap_448', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the variant named 'vit_intern300m_patch14_448'.

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function.

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 14,
        'embed_dim': 1024,
        'depth': 24,
        'num_heads': 16,
        'init_values': 0.1,
        'final_norm': False,
        'dynamic_img_size': True,
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer(
        'vit_intern300m_patch14_448',
        pretrained=pretrained,
        **{**defaults, **kwargs},
    )
@register_model
def aimv2_large_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT Large AIM-v2 model (variant name 'aimv2_large_patch14_224').

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function.

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 14,
        'embed_dim': 1024,
        'depth': 24,
        'num_heads': 8,
        'class_token': False,
        'fc_norm': False,
        'mlp_ratio': 2.75,
        'global_pool': 'avg',
        'qkv_bias': False,
        'proj_bias': False,
        'act_layer': 'silu',
        # Two separate partials, one per norm slot, both RmsNorm with eps=1e-5.
        'norm_layer': partial(RmsNorm, eps=1e-5),
        'embed_norm_layer': partial(RmsNorm, eps=1e-5),
        'mlp_layer': SwiGLU,
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer(
        'aimv2_large_patch14_224',
        pretrained=pretrained,
        **{**defaults, **kwargs},
    )
@register_model
def aimv2_huge_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT Huge AIM-v2 model (variant name 'aimv2_huge_patch14_224').

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function.

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 14,
        'embed_dim': 1536,
        'depth': 24,
        'num_heads': 12,
        'class_token': False,
        'fc_norm': False,
        'mlp_ratio': 2.6667,
        'global_pool': 'avg',
        'qkv_bias': False,
        'proj_bias': False,
        'act_layer': 'silu',
        # Two separate partials, one per norm slot, both RmsNorm with eps=1e-5.
        'norm_layer': partial(RmsNorm, eps=1e-5),
        'embed_norm_layer': partial(RmsNorm, eps=1e-5),
        'mlp_layer': SwiGLU,
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer(
        'aimv2_huge_patch14_224',
        pretrained=pretrained,
        **{**defaults, **kwargs},
    )
@register_model
def aimv2_1b_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT 1B AIM-v2 model (variant name 'aimv2_1b_patch14_224').

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function.

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 14,
        'embed_dim': 2048,
        'depth': 24,
        'num_heads': 16,
        'class_token': False,
        'fc_norm': False,
        'mlp_ratio': 2.75,
        'global_pool': 'avg',
        'qkv_bias': False,
        'proj_bias': False,
        'act_layer': 'silu',
        # Two separate partials, one per norm slot, both RmsNorm with eps=1e-5.
        'norm_layer': partial(RmsNorm, eps=1e-5),
        'embed_norm_layer': partial(RmsNorm, eps=1e-5),
        'mlp_layer': SwiGLU,
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer(
        'aimv2_1b_patch14_224',
        pretrained=pretrained,
        **{**defaults, **kwargs},
    )
@register_model
def aimv2_3b_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT 3B AIM-v2 model (variant name 'aimv2_3b_patch14_224').

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function.

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 14,
        'embed_dim': 3072,
        'depth': 24,
        'num_heads': 24,
        'class_token': False,
        'fc_norm': False,
        'mlp_ratio': 2.6667,
        'global_pool': 'avg',
        'qkv_bias': False,
        'proj_bias': False,
        'act_layer': 'silu',
        # Two separate partials, one per norm slot, both RmsNorm with eps=1e-5.
        'norm_layer': partial(RmsNorm, eps=1e-5),
        'embed_norm_layer': partial(RmsNorm, eps=1e-5),
        'mlp_layer': SwiGLU,
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer(
        'aimv2_3b_patch14_224',
        pretrained=pretrained,
        **{**defaults, **kwargs},
    )
@register_model
def aimv2_large_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT Large AIM-v2 model (variant name 'aimv2_large_patch14_336').

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function.

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 14,
        'embed_dim': 1024,
        'depth': 24,
        'num_heads': 8,
        'class_token': False,
        'fc_norm': False,
        'mlp_ratio': 2.75,
        'global_pool': 'avg',
        'qkv_bias': False,
        'proj_bias': False,
        'act_layer': 'silu',
        # Two separate partials, one per norm slot, both RmsNorm with eps=1e-5.
        'norm_layer': partial(RmsNorm, eps=1e-5),
        'embed_norm_layer': partial(RmsNorm, eps=1e-5),
        'mlp_layer': SwiGLU,
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer(
        'aimv2_large_patch14_336',
        pretrained=pretrained,
        **{**defaults, **kwargs},
    )
@register_model
def aimv2_huge_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT Huge AIM-v2 model (variant name 'aimv2_huge_patch14_336').

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function.

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 14,
        'embed_dim': 1536,
        'depth': 24,
        'num_heads': 12,
        'class_token': False,
        'fc_norm': False,
        'mlp_ratio': 2.6667,
        'global_pool': 'avg',
        'qkv_bias': False,
        'proj_bias': False,
        'act_layer': 'silu',
        # Two separate partials, one per norm slot, both RmsNorm with eps=1e-5.
        'norm_layer': partial(RmsNorm, eps=1e-5),
        'embed_norm_layer': partial(RmsNorm, eps=1e-5),
        'mlp_layer': SwiGLU,
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer(
        'aimv2_huge_patch14_336',
        pretrained=pretrained,
        **{**defaults, **kwargs},
    )
@register_model
def aimv2_1b_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT 1B AIM-v2 model (variant name 'aimv2_1b_patch14_336').

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function.

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 14,
        'embed_dim': 2048,
        'depth': 24,
        'num_heads': 16,
        'class_token': False,
        'fc_norm': False,
        'mlp_ratio': 2.75,
        'global_pool': 'avg',
        'qkv_bias': False,
        'proj_bias': False,
        'act_layer': 'silu',
        # Two separate partials, one per norm slot, both RmsNorm with eps=1e-5.
        'norm_layer': partial(RmsNorm, eps=1e-5),
        'embed_norm_layer': partial(RmsNorm, eps=1e-5),
        'mlp_layer': SwiGLU,
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer(
        'aimv2_1b_patch14_336',
        pretrained=pretrained,
        **{**defaults, **kwargs},
    )
@register_model
def aimv2_3b_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT 3B AIM-v2 model (variant name 'aimv2_3b_patch14_336').

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function.

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 14,
        'embed_dim': 3072,
        'depth': 24,
        'num_heads': 24,
        'class_token': False,
        'fc_norm': False,
        'mlp_ratio': 2.6667,
        'global_pool': 'avg',
        'qkv_bias': False,
        'proj_bias': False,
        'act_layer': 'silu',
        # Two separate partials, one per norm slot, both RmsNorm with eps=1e-5.
        'norm_layer': partial(RmsNorm, eps=1e-5),
        'embed_norm_layer': partial(RmsNorm, eps=1e-5),
        'mlp_layer': SwiGLU,
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer(
        'aimv2_3b_patch14_336',
        pretrained=pretrained,
        **{**defaults, **kwargs},
    )
@register_model
def aimv2_large_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT Large AIM-v2 model (variant name 'aimv2_large_patch14_448').

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function.

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 14,
        'embed_dim': 1024,
        'depth': 24,
        'num_heads': 8,
        'class_token': False,
        'fc_norm': False,
        'mlp_ratio': 2.75,
        'global_pool': 'avg',
        'qkv_bias': False,
        'proj_bias': False,
        'act_layer': 'silu',
        # Two separate partials, one per norm slot, both RmsNorm with eps=1e-5.
        'norm_layer': partial(RmsNorm, eps=1e-5),
        'embed_norm_layer': partial(RmsNorm, eps=1e-5),
        'mlp_layer': SwiGLU,
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer(
        'aimv2_large_patch14_448',
        pretrained=pretrained,
        **{**defaults, **kwargs},
    )
@register_model
def aimv2_huge_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT Huge AIM-v2 model (variant name 'aimv2_huge_patch14_448').

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function.

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 14,
        'embed_dim': 1536,
        'depth': 24,
        'num_heads': 12,
        'class_token': False,
        'fc_norm': False,
        'mlp_ratio': 2.6667,
        'global_pool': 'avg',
        'qkv_bias': False,
        'proj_bias': False,
        'act_layer': 'silu',
        # Two separate partials, one per norm slot, both RmsNorm with eps=1e-5.
        'norm_layer': partial(RmsNorm, eps=1e-5),
        'embed_norm_layer': partial(RmsNorm, eps=1e-5),
        'mlp_layer': SwiGLU,
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer(
        'aimv2_huge_patch14_448',
        pretrained=pretrained,
        **{**defaults, **kwargs},
    )
@register_model
def aimv2_1b_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT 1B AIM-v2 model (variant name 'aimv2_1b_patch14_448').

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function.

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 14,
        'embed_dim': 2048,
        'depth': 24,
        'num_heads': 16,
        'class_token': False,
        'fc_norm': False,
        'mlp_ratio': 2.75,
        'global_pool': 'avg',
        'qkv_bias': False,
        'proj_bias': False,
        'act_layer': 'silu',
        # Two separate partials, one per norm slot, both RmsNorm with eps=1e-5.
        'norm_layer': partial(RmsNorm, eps=1e-5),
        'embed_norm_layer': partial(RmsNorm, eps=1e-5),
        'mlp_layer': SwiGLU,
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer(
        'aimv2_1b_patch14_448',
        pretrained=pretrained,
        **{**defaults, **kwargs},
    )
@register_model
def aimv2_3b_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT 3B AIM-v2 model (variant name 'aimv2_3b_patch14_448').

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function.

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 14,
        'embed_dim': 3072,
        'depth': 24,
        'num_heads': 24,
        'class_token': False,
        'fc_norm': False,
        'mlp_ratio': 2.6667,
        'global_pool': 'avg',
        'qkv_bias': False,
        'proj_bias': False,
        'act_layer': 'silu',
        # Two separate partials, one per norm slot, both RmsNorm with eps=1e-5.
        'norm_layer': partial(RmsNorm, eps=1e-5),
        'embed_norm_layer': partial(RmsNorm, eps=1e-5),
        'mlp_layer': SwiGLU,
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer(
        'aimv2_3b_patch14_448',
        pretrained=pretrained,
        **{**defaults, **kwargs},
    )
@register_model
def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT Test (variant name 'test_vit').

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function.

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 64,
        'depth': 6,
        'num_heads': 2,
        'mlp_ratio': 3,
        'dynamic_img_size': True,
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer('test_vit', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def test_vit2(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT Test (variant name 'test_vit2').

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function.

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 64,
        'depth': 8,
        'num_heads': 2,
        'mlp_ratio': 3,
        'class_token': False,
        'reg_tokens': 1,
        'global_pool': 'avg',
        'init_values': 1e-5,
        'dynamic_img_size': True,
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer('test_vit2', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def test_vit3(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT Test (variant name 'test_vit3').

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function.

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 96,
        'depth': 9,
        'num_heads': 3,
        'mlp_ratio': 2,
        'class_token': False,
        'reg_tokens': 1,
        'global_pool': 'map',
        'pool_include_prefix': True,
        'init_values': 1e-5,
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer('test_vit3', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def test_vit4(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT Test (variant name 'test_vit4').

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function.

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 96,
        'depth': 9,
        'num_heads': 3,
        'mlp_ratio': 3,
        'class_token': False,
        'reg_tokens': 1,
        'global_pool': 'avg',
        'init_values': 1e-5,
        'dynamic_img_size': True,
        'norm_layer': 'rmsnorm',
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer('test_vit4', pretrained=pretrained, **{**defaults, **kwargs})
@register_model
def beit3_base_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """BEiT3 Base model (ViT-Base size) with patch size 16x16.

    Remapped to VisionTransformer, configured here with
    ``scale_attn_norm=True`` and ``scale_mlp_norm=True``.

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function.

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 768,
        'depth': 12,
        'num_heads': 12,
        'mlp_ratio': 4,
        'scale_attn_norm': True,
        'scale_mlp_norm': True,
        'class_token': True,
        'global_pool': 'avg',
        'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer(
        'beit3_base_patch16_224',
        pretrained=pretrained,
        **{**defaults, **kwargs},
    )
@register_model
def beit3_large_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """BEiT3 Large model (ViT-Large size) with patch size 16x16.

    Remapped to VisionTransformer, configured here with
    ``scale_attn_norm=True`` and ``scale_mlp_norm=True``.

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function.

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 1024,
        'depth': 24,
        'num_heads': 16,
        'mlp_ratio': 4,
        'scale_attn_norm': True,
        'scale_mlp_norm': True,
        'class_token': True,
        'global_pool': 'avg',
        'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer(
        'beit3_large_patch16_224',
        pretrained=pretrained,
        **{**defaults, **kwargs},
    )
@register_model
def beit3_giant_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """BEiT3 Giant model with patch size 14x14.

    Remapped to VisionTransformer, configured here with
    ``scale_attn_norm=True`` and ``scale_mlp_norm=True``.

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function.

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'patch_size': 14,
        'embed_dim': 1408,
        'depth': 40,
        'num_heads': 16,
        'mlp_ratio': 4.3637,
        'scale_attn_norm': True,
        'scale_mlp_norm': True,
        'class_token': True,
        'global_pool': 'avg',
        'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer(
        'beit3_giant_patch14_224',
        pretrained=pretrained,
        **{**defaults, **kwargs},
    )
@register_model
def beit3_giant_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """BEiT3 Giant model with patch size 14x14 and image size 336x336.

    Remapped to VisionTransformer, configured here with
    ``scale_attn_norm=True`` and ``scale_mlp_norm=True``.

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; on a key collision they replace the
            defaults defined in this function (including ``img_size``).

    Returns:
        The object produced by ``_create_vision_transformer``.
    """
    defaults = {
        'img_size': 336,
        'patch_size': 14,
        'embed_dim': 1408,
        'depth': 40,
        'num_heads': 16,
        'mlp_ratio': 4.3637,
        'scale_attn_norm': True,
        'scale_mlp_norm': True,
        'class_token': True,
        'global_pool': 'avg',
        'norm_layer': partial(LayerNorm, eps=1e-5),
    }
    # Caller-supplied kwargs are unpacked last so they win over the defaults.
    return _create_vision_transformer(
        'beit3_giant_patch14_336',
        pretrained=pretrained,
        **{**defaults, **kwargs},
    )
# Old model names mapped to their current names, which take the form
# ``<model_name>.<tag>``. The mapping is handed to
# ``register_model_deprecations`` together with this module's ``__name__``.
# Key order is kept exactly as originally declared.
register_model_deprecations(
    __name__,
    {
        'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
        'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',
        'vit_small_patch16_224_in21k': 'vit_small_patch16_224.augreg_in21k',
        'vit_base_patch32_224_in21k': 'vit_base_patch32_224.augreg_in21k',
        'vit_base_patch16_224_in21k': 'vit_base_patch16_224.augreg_in21k',
        'vit_base_patch8_224_in21k': 'vit_base_patch8_224.augreg_in21k',
        'vit_large_patch32_224_in21k': 'vit_large_patch32_224.orig_in21k',
        'vit_large_patch16_224_in21k': 'vit_large_patch16_224.augreg_in21k',
        'vit_huge_patch14_224_in21k': 'vit_huge_patch14_224.orig_in21k',
        'vit_base_patch32_224_sam': 'vit_base_patch32_224.sam',
        'vit_base_patch16_224_sam': 'vit_base_patch16_224.sam',
        'vit_small_patch16_224_dino': 'vit_small_patch16_224.dino',
        'vit_small_patch8_224_dino': 'vit_small_patch8_224.dino',
        'vit_base_patch16_224_dino': 'vit_base_patch16_224.dino',
        'vit_base_patch8_224_dino': 'vit_base_patch8_224.dino',
        'vit_base_patch16_224_miil_in21k': 'vit_base_patch16_224_miil.in21k',
        'vit_base_patch32_224_clip_laion2b': 'vit_base_patch32_clip_224.laion2b',
        'vit_large_patch14_224_clip_laion2b': 'vit_large_patch14_clip_224.laion2b',
        'vit_huge_patch14_224_clip_laion2b': 'vit_huge_patch14_clip_224.laion2b',
        'vit_giant_patch14_224_clip_laion2b': 'vit_giant_patch14_clip_224.laion2b',
    },
)
|