vision_transformer.py 203 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
747784779478047814782478347844785478647874788478947904791479247934794479547964797479847994800480148024803480448054806480748084809481048114812481348144815481648174818481948204821482248234824482548264827482848294830483148324833483448354836483748384839
  1. """ Vision Transformer (ViT) in PyTorch
  2. A PyTorch implementation of Vision Transformers as described in:
  3. 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
  4. - https://arxiv.org/abs/2010.11929
  5. `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
  6. - https://arxiv.org/abs/2106.10270
  7. `FlexiViT: One Model for All Patch Sizes`
  8. - https://arxiv.org/abs/2212.08013
  9. The official jax code is released and available at
  10. * https://github.com/google-research/vision_transformer
  11. * https://github.com/google-research/big_vision
  12. Acknowledgments:
  13. * The paper authors for releasing code and weights, thanks!
  14. * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch
  15. * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
  16. * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
  17. Hacked together by / Copyright 2020, Ross Wightman
  18. """
  19. import copy
  20. import logging
  21. import math
  22. import os
  23. from collections import OrderedDict
  24. from functools import partial
  25. from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union, List
  26. try:
  27. from typing import Literal
  28. except ImportError:
  29. from typing_extensions import Literal
  30. import torch
  31. import torch.nn as nn
  32. import torch.nn.functional as F
  33. from torch.jit import Final
  34. from timm.data import (
  35. IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD,
  36. IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD,
  37. OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
  38. )
  39. from timm.layers import (
  40. Attention,
  41. DiffAttention,
  42. AttentionPoolLatent,
  43. AttentionPoolPrr,
  44. PatchEmbed,
  45. Mlp,
  46. SwiGLUPacked,
  47. SwiGLU,
  48. LayerNorm,
  49. RmsNorm,
  50. DropPath,
  51. calculate_drop_path_rates,
  52. PatchDropout,
  53. trunc_normal_,
  54. lecun_normal_,
  55. resample_patch_embed,
  56. resample_abs_pos_embed,
  57. use_fused_attn,
  58. get_act_layer,
  59. get_norm_layer,
  60. maybe_add_mask,
  61. resolve_self_attn_mask,
  62. LayerType,
  63. LayerScale,
  64. )
  65. from ._builder import build_model_with_cfg
  66. from ._features import feature_take_indices
  67. from ._manipulate import named_apply, checkpoint, checkpoint_seq, adapt_input_conv
  68. from ._registry import generate_default_cfgs, register_model, register_model_deprecations
  69. __all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to this
  70. _logger = logging.getLogger(__name__)
  71. ATTN_LAYERS = {
  72. '': Attention,
  73. 'attn': Attention,
  74. 'diff': DiffAttention,
  75. }
  76. def _create_attn(
  77. attn_layer: LayerType,
  78. dim: int,
  79. num_heads: int,
  80. qkv_bias: bool = False,
  81. qk_norm: bool = False,
  82. scale_norm: bool = False,
  83. proj_bias: bool = True,
  84. attn_drop: float = 0.,
  85. proj_drop: float = 0.,
  86. norm_layer: Optional[Type[nn.Module]] = None,
  87. depth: int = 0,
  88. **kwargs,
  89. ) -> nn.Module:
  90. if isinstance(attn_layer, str):
  91. attn_layer = ATTN_LAYERS.get(attn_layer, None)
  92. assert attn_layer is not None, f'Unknown attn_layer: {attn_layer}'
  93. # Only pass depth to attention layers that use it
  94. if issubclass(attn_layer, DiffAttention):
  95. kwargs['depth'] = depth
  96. return attn_layer(
  97. dim,
  98. num_heads=num_heads,
  99. qkv_bias=qkv_bias,
  100. qk_norm=qk_norm,
  101. scale_norm=scale_norm,
  102. proj_bias=proj_bias,
  103. attn_drop=attn_drop,
  104. proj_drop=proj_drop,
  105. norm_layer=norm_layer,
  106. **kwargs,
  107. )
  108. class Block(nn.Module):
  109. """Transformer block with pre-normalization."""
  110. def __init__(
  111. self,
  112. dim: int,
  113. num_heads: int,
  114. mlp_ratio: float = 4.,
  115. qkv_bias: bool = False,
  116. qk_norm: bool = False,
  117. scale_attn_norm: bool = False,
  118. scale_mlp_norm: bool = False,
  119. proj_bias: bool = True,
  120. proj_drop: float = 0.,
  121. attn_drop: float = 0.,
  122. init_values: Optional[float] = None,
  123. drop_path: float = 0.,
  124. act_layer: Type[nn.Module] = nn.GELU,
  125. norm_layer: Type[nn.Module] = LayerNorm,
  126. mlp_layer: Type[nn.Module] = Mlp,
  127. attn_layer: LayerType = Attention,
  128. depth: int = 0,
  129. device=None,
  130. dtype=None,
  131. ) -> None:
  132. """Initialize Block.
  133. Args:
  134. dim: Number of input channels.
  135. num_heads: Number of attention heads.
  136. mlp_ratio: Ratio of mlp hidden dim to embedding dim.
  137. qkv_bias: If True, add a learnable bias to query, key, value.
  138. qk_norm: If True, apply normalization to query and key.
  139. proj_bias: If True, add bias to output projection.
  140. proj_drop: Projection dropout rate.
  141. attn_drop: Attention dropout rate.
  142. init_values: Initial values for layer scale.
  143. drop_path: Stochastic depth rate.
  144. act_layer: Activation layer.
  145. norm_layer: Normalization layer.
  146. mlp_layer: MLP layer.
  147. attn_layer: Attention layer type (class or string).
  148. depth: Block index, passed to attention layer for depth-dependent init.
  149. """
  150. super().__init__()
  151. dd = {'device': device, 'dtype': dtype}
  152. self.norm1 = norm_layer(dim, **dd)
  153. self.attn = _create_attn(
  154. attn_layer,
  155. dim,
  156. num_heads=num_heads,
  157. qkv_bias=qkv_bias,
  158. qk_norm=qk_norm,
  159. scale_norm=scale_attn_norm,
  160. proj_bias=proj_bias,
  161. attn_drop=attn_drop,
  162. proj_drop=proj_drop,
  163. norm_layer=norm_layer,
  164. depth=depth,
  165. **dd,
  166. )
  167. self.ls1 = LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity()
  168. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  169. self.norm2 = norm_layer(dim, **dd)
  170. self.mlp = mlp_layer(
  171. in_features=dim,
  172. hidden_features=int(dim * mlp_ratio),
  173. act_layer=act_layer,
  174. norm_layer=norm_layer if scale_mlp_norm else None,
  175. bias=proj_bias,
  176. drop=proj_drop,
  177. **dd,
  178. )
  179. self.ls2 = LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity()
  180. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  181. def forward(
  182. self,
  183. x: torch.Tensor,
  184. attn_mask: Optional[torch.Tensor] = None,
  185. is_causal: bool = False,
  186. ) -> torch.Tensor:
  187. x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), attn_mask=attn_mask, is_causal=is_causal)))
  188. x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
  189. return x
  190. class ResPostBlock(nn.Module):
  191. def __init__(
  192. self,
  193. dim: int,
  194. num_heads: int,
  195. mlp_ratio: float = 4.,
  196. qkv_bias: bool = False,
  197. qk_norm: bool = False,
  198. scale_attn_norm: bool = False,
  199. scale_mlp_norm: bool = False,
  200. proj_bias: bool = True,
  201. proj_drop: float = 0.,
  202. attn_drop: float = 0.,
  203. init_values: Optional[float] = None,
  204. drop_path: float = 0.,
  205. act_layer: Type[nn.Module] = nn.GELU,
  206. norm_layer: Type[nn.Module] = LayerNorm,
  207. mlp_layer: Type[nn.Module] = Mlp,
  208. attn_layer: LayerType = Attention,
  209. depth: int = 0,
  210. device=None,
  211. dtype=None,
  212. ) -> None:
  213. super().__init__()
  214. dd = {'device': device, 'dtype': dtype}
  215. self.init_values = init_values
  216. self.attn = _create_attn(
  217. attn_layer,
  218. dim,
  219. num_heads=num_heads,
  220. qkv_bias=qkv_bias,
  221. qk_norm=qk_norm,
  222. scale_norm=scale_attn_norm,
  223. proj_bias=proj_bias,
  224. attn_drop=attn_drop,
  225. proj_drop=proj_drop,
  226. norm_layer=norm_layer,
  227. depth=depth,
  228. **dd,
  229. )
  230. self.norm1 = norm_layer(dim, **dd)
  231. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  232. self.mlp = mlp_layer(
  233. in_features=dim,
  234. hidden_features=int(dim * mlp_ratio),
  235. act_layer=act_layer,
  236. norm_layer=norm_layer if scale_mlp_norm else None,
  237. bias=proj_bias,
  238. drop=proj_drop,
  239. **dd,
  240. )
  241. self.norm2 = norm_layer(dim, **dd)
  242. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  243. self.init_weights()
  244. def init_weights(self) -> None:
  245. # NOTE this init overrides that base model init with specific changes for the block type
  246. if self.init_values is not None:
  247. nn.init.constant_(self.norm1.weight, self.init_values)
  248. nn.init.constant_(self.norm2.weight, self.init_values)
  249. def forward(
  250. self,
  251. x: torch.Tensor,
  252. attn_mask: Optional[torch.Tensor] = None,
  253. is_causal: bool = False,
  254. ) -> torch.Tensor:
  255. x = x + self.drop_path1(self.norm1(self.attn(x, attn_mask=attn_mask, is_causal=is_causal)))
  256. x = x + self.drop_path2(self.norm2(self.mlp(x)))
  257. return x
class ParallelScalingBlock(nn.Module):
    """ Parallel ViT block (MLP & Attention in parallel)

    Based on:
        'Scaling Vision Transformers to 22 Billion Parameters' - https://arxiv.org/abs/2302.05442

    A single pre-norm (``in_norm``) feeds one combined input projection (``in_proj``). Its output
    is split into the MLP hidden activations and q, k, v. The attention output and the MLP output
    are projected back to ``dim`` and summed. This uses either one fused ``out_proj`` over their
    concatenation or two separate projections. The sum is added to the residual through layer
    scale and drop path.
    """
    # Chosen once at construction from use_fused_attn(); selects the SDPA path in forward().
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            scale_attn_norm: bool = False,
            scale_mlp_norm: bool = False,
            proj_bias: bool = True,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            init_values: Optional[float] = None,
            drop_path: float = 0.,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = LayerNorm,
            mlp_layer: Optional[Type[nn.Module]] = None,  # not used
            attn_layer: Optional[LayerType] = None,  # not used
            depth: int = 0,  # not used
            fuse_out_proj: bool = False,
            device=None,
            dtype=None,
    ) -> None:
        """Initialize ParallelScalingBlock.

        Args:
            dim: Embedding dimension. Must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            mlp_ratio: MLP hidden width as a multiple of ``dim``.
            qkv_bias: If True, ``in_proj`` has a bias covering both the MLP and qkv slices.
                Otherwise a separate ``mlp_bias`` (zeroed in ``reset_parameters``) is added
                to the MLP slice only.
            qk_norm: If True, apply ``norm_layer`` per head to queries and keys.
            scale_attn_norm: Not supported; must be False (asserted).
            scale_mlp_norm: Not supported; must be False (asserted).
            proj_bias: If True, add bias to the output projection(s).
            proj_drop: Dropout rate applied to the activated MLP hidden features.
            attn_drop: Attention-probability dropout rate.
            init_values: If not None, wrap the summed branch output in LayerScale with this
                initial value. None uses identity.
            drop_path: Stochastic depth rate. Values <= 0. disable it.
            act_layer: Activation for the MLP hidden features.
            norm_layer: Normalization layer for ``in_norm`` and, if ``qk_norm``, for q/k.
            mlp_layer: Accepted but not used by this block.
            attn_layer: Accepted but not used by this block.
            depth: Accepted but not used by this block.
            fuse_out_proj: If True, use one Linear over ``cat(attn, mlp)``. Otherwise use
                separate attention and MLP output Linears whose results are summed.
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        assert not scale_attn_norm and not scale_mlp_norm, 'Scale norms not supported'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()
        mlp_hidden_dim = int(mlp_ratio * dim)
        # One projection produces [mlp hidden | q | k | v] along the last dim.
        in_proj_out_dim = mlp_hidden_dim + 3 * dim

        self.in_norm = norm_layer(dim, **dd)
        self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias, **dd)
        # Split sizes used by torch.split in forward(); order must match the layout above.
        self.in_split = [mlp_hidden_dim] + [dim] * 3
        if qkv_bias:
            # mlp_bias is combined with qkv_bias in in_proj.bias
            self.register_parameter('mlp_bias', None)
        else:
            self.mlp_bias = nn.Parameter(torch.empty(mlp_hidden_dim, **dd))

        # Per-head q/k norms operate on the last (head_dim) axis.
        self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.mlp_drop = nn.Dropout(proj_drop)
        self.mlp_act = act_layer()

        if fuse_out_proj:
            # Fused output projection for both attention and MLP
            self.out_proj = nn.Linear(dim + mlp_hidden_dim, dim, bias=proj_bias, **dd)
            self.attn_out_proj = None
            self.mlp_out_proj = None
        else:
            # Separate output projections
            self.out_proj = None
            self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias, **dd)
            self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim, bias=proj_bias, **dd)

        self.ls = LayerScale(dim, init_values=init_values, **dd) if init_values is not None else nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # TODO: skip init when on meta device when safe to do so
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize parameters and buffers."""
        # mlp_bias is allocated with torch.empty, so it must be zeroed explicitly.
        if self.mlp_bias is not None:
            nn.init.zeros_(self.mlp_bias)

    def forward(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            is_causal: bool = False,
    ) -> torch.Tensor:
        # x is unpacked as (B, N, C) with C == dim.
        B, N, C = x.shape

        # Combined MLP fc1 & qkv projections
        y = self.in_norm(x)
        y = self.in_proj(y)
        x_mlp, q, k, v = torch.split(y, self.in_split, dim=-1)
        if self.mlp_bias is not None:
            x_mlp = x_mlp + self.mlp_bias

        # Dot product attention w/ qk norm
        # q/k/v: (B, N, dim) -> (B, N, num_heads, head_dim) -> (B, num_heads, N, head_dim)
        q = self.q_norm(q.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
        k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
        v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        if self.fused_attn:
            # SDPA does not see module train/eval mode, so dropout is gated on self.training here.
            # NOTE(review): relies on SDPA's default scale, presumably equal to self.scale.
            x_attn = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
                is_causal=is_causal,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn_bias = resolve_self_attn_mask(N, attn, attn_mask, is_causal=is_causal)
            attn = maybe_add_mask(attn, attn_bias)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x_attn = attn @ v
        # Merge heads back: (B, num_heads, N, head_dim) -> (B, N, C)
        x_attn = x_attn.transpose(1, 2).reshape(B, N, C)

        # MLP activation & dropout
        x_mlp = self.mlp_act(x_mlp)
        x_mlp = self.mlp_drop(x_mlp)

        # Output projection (fused or separate)
        if self.out_proj is not None:
            y = self.out_proj(torch.cat((x_attn, x_mlp), dim=-1))
        else:
            y = self.attn_out_proj(x_attn) + self.mlp_out_proj(x_mlp)

        # Add residual w/ drop path & layer scale applied
        x = x + self.drop_path(self.ls(y))
        return x
  372. class DiffParallelScalingBlock(nn.Module):
  373. """ Parallel ViT block with Differential Attention (MLP & Attention in parallel).
  374. Combines the parallel MLP+Attention structure from 'Scaling Vision Transformers to
  375. 22 Billion Parameters' (https://arxiv.org/abs/2302.05442) with differential attention
  376. from 'Differential Transformer' (https://arxiv.org/abs/2410.05258).
  377. """
  378. fused_attn: Final[bool]
  379. def __init__(
  380. self,
  381. dim: int,
  382. num_heads: int,
  383. mlp_ratio: float = 4.,
  384. qkv_bias: bool = False,
  385. qk_norm: bool = False,
  386. scale_attn_norm: bool = False,
  387. scale_mlp_norm: bool = False,
  388. proj_bias: bool = True,
  389. proj_drop: float = 0.,
  390. attn_drop: float = 0.,
  391. init_values: Optional[float] = None,
  392. drop_path: float = 0.,
  393. act_layer: Type[nn.Module] = nn.GELU,
  394. norm_layer: Type[nn.Module] = LayerNorm,
  395. mlp_layer: Optional[Type[nn.Module]] = None,
  396. attn_layer: Optional[LayerType] = None,
  397. depth: int = 0,
  398. dual_lambda: bool = False,
  399. device=None,
  400. dtype=None,
  401. ) -> None:
  402. super().__init__()
  403. dd = {'device': device, 'dtype': dtype}
  404. assert dim % num_heads == 0, 'dim should be divisible by num_heads'
  405. assert not scale_attn_norm and not scale_mlp_norm, 'Scale norms not supported'
  406. self.num_heads = num_heads
  407. self.head_dim = dim // num_heads // 2 # Half head_dim for diff attention
  408. self.scale = self.head_dim ** -0.5
  409. self.fused_attn = use_fused_attn()
  410. mlp_hidden_dim = int(mlp_ratio * dim)
  411. in_proj_out_dim = mlp_hidden_dim + 3 * dim
  412. self.in_norm = norm_layer(dim, **dd)
  413. self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias, **dd)
  414. self.in_split = [mlp_hidden_dim] + [dim] * 3
  415. if qkv_bias:
  416. # mlp_bias is combined with qkv_bias in in_proj.bias
  417. self.register_parameter('mlp_bias', None)
  418. else:
  419. self.mlp_bias = nn.Parameter(torch.empty(mlp_hidden_dim, **dd))
  420. self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
  421. self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
  422. self.attn_drop = nn.Dropout(attn_drop)
  423. self.attn_drop_p = attn_drop
  424. # Differential attention specific
  425. self.sub_norm = RmsNorm(2 * self.head_dim, eps=1e-5, **dd)
  426. self.dual_lambda = dual_lambda
  427. if dual_lambda:
  428. self.lambda_a = nn.Parameter(torch.empty((), dtype=torch.float32, device=device))
  429. self.lambda_b = nn.Parameter(torch.empty((), dtype=torch.float32, device=device))
  430. self.lambda_q1 = self.lambda_k1 = self.lambda_q2 = self.lambda_k2 = None
  431. else:
  432. self.lambda_a = self.lambda_b = None
  433. self.lambda_q1 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
  434. self.lambda_k1 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
  435. self.lambda_q2 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
  436. self.lambda_k2 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
  437. self.mlp_drop = nn.Dropout(proj_drop)
  438. self.mlp_act = act_layer()
  439. # Fused output projection for both attention and MLP
  440. self.out_proj = nn.Linear(dim + mlp_hidden_dim, dim, bias=proj_bias, **dd)
  441. self.ls = LayerScale(dim, init_values=init_values, **dd) if init_values is not None else nn.Identity()
  442. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  443. self.lambda_init = 0.8
  444. self.set_lambda_init(depth)
  445. # TODO: skip init when on meta device when safe to do so
  446. self.reset_parameters()
  447. def set_lambda_init(self, depth: int):
  448. self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)
  449. def reset_parameters(self) -> None:
  450. """Initialize parameters and buffers."""
  451. if self.mlp_bias is not None:
  452. nn.init.zeros_(self.mlp_bias)
  453. if self.dual_lambda:
  454. nn.init.zeros_(self.lambda_a)
  455. nn.init.zeros_(self.lambda_b)
  456. else:
  457. nn.init.normal_(self.lambda_q1, mean=0, std=0.1)
  458. nn.init.normal_(self.lambda_k1, mean=0, std=0.1)
  459. nn.init.normal_(self.lambda_q2, mean=0, std=0.1)
  460. nn.init.normal_(self.lambda_k2, mean=0, std=0.1)
  461. def _compute_lambda(self) -> torch.Tensor:
  462. if self.lambda_a is not None:
  463. lambda_1 = torch.exp(self.lambda_a)
  464. lambda_2 = torch.exp(self.lambda_b)
  465. else:
  466. lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float())
  467. lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float())
  468. return lambda_1 - lambda_2 + self.lambda_init
  469. def forward(
  470. self,
  471. x: torch.Tensor,
  472. attn_mask: Optional[torch.Tensor] = None,
  473. is_causal: bool = False,
  474. ) -> torch.Tensor:
  475. B, N, C = x.shape
  476. # Combined MLP fc1 & qkv projections
  477. y = self.in_norm(x)
  478. y = self.in_proj(y)
  479. x_mlp, q, k, v = torch.split(y, self.in_split, dim=-1)
  480. if self.mlp_bias is not None:
  481. x_mlp = x_mlp + self.mlp_bias
  482. # Reshape for differential attention (2x heads with half head_dim for q/k)
  483. q = q.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
  484. k = k.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
  485. v = v.reshape(B, N, self.num_heads, 2 * self.head_dim).transpose(1, 2)
  486. q, k = self.q_norm(q), self.k_norm(k)
  487. lambda_full = self._compute_lambda().type_as(q)
  488. if self.fused_attn:
  489. q = q.reshape(B, self.num_heads, 2, N, self.head_dim)
  490. k = k.reshape(B, self.num_heads, 2, N, self.head_dim)
  491. q1, q2 = q.unbind(2)
  492. k1, k2 = k.unbind(2)
  493. dropout_p = self.attn_drop_p if self.training else 0.0
  494. attn1 = F.scaled_dot_product_attention(q1, k1, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
  495. attn2 = F.scaled_dot_product_attention(q2, k2, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
  496. x_attn = attn1 - lambda_full * attn2
  497. else:
  498. q = q * self.scale
  499. attn = q @ k.transpose(-2, -1)
  500. attn_bias = resolve_self_attn_mask(N, attn, attn_mask, is_causal=is_causal)
  501. attn = maybe_add_mask(attn, attn_bias)
  502. attn = attn.softmax(dim=-1)
  503. attn = self.attn_drop(attn)
  504. attn = attn.view(B, self.num_heads, 2, N, N)
  505. attn = attn[:, :, 0] - lambda_full * attn[:, :, 1]
  506. x_attn = attn @ v
  507. x_attn = self.sub_norm(x_attn)
  508. x_attn = x_attn * (1 - self.lambda_init)
  509. x_attn = x_attn.transpose(1, 2).reshape(B, N, C)
  510. # MLP activation & dropout
  511. x_mlp = self.mlp_act(x_mlp)
  512. x_mlp = self.mlp_drop(x_mlp)
  513. # Fused output projection
  514. y = self.out_proj(torch.cat((x_attn, x_mlp), dim=-1))
  515. # Add residual w/ drop path & layer scale applied
  516. x = x + self.drop_path(self.ls(y))
  517. return x
  518. class ParallelThingsBlock(nn.Module):
  519. """ Parallel ViT block (N parallel attention followed by N parallel MLP)
  520. Based on:
  521. `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
  522. """
  523. def __init__(
  524. self,
  525. dim: int,
  526. num_heads: int,
  527. num_parallel: int = 2,
  528. mlp_ratio: float = 4.,
  529. qkv_bias: bool = False,
  530. qk_norm: bool = False,
  531. scale_attn_norm: bool = False,
  532. scale_mlp_norm: bool = False,
  533. proj_bias: bool = True,
  534. init_values: Optional[float] = None,
  535. proj_drop: float = 0.,
  536. attn_drop: float = 0.,
  537. drop_path: float = 0.,
  538. act_layer: Type[nn.Module] = nn.GELU,
  539. norm_layer: Type[nn.Module] = LayerNorm,
  540. mlp_layer: Type[nn.Module] = Mlp,
  541. attn_layer: LayerType = Attention,
  542. depth: int = 0,
  543. device=None,
  544. dtype=None,
  545. ) -> None:
  546. dd = {'device': device, 'dtype': dtype}
  547. super().__init__()
  548. self.num_parallel = num_parallel
  549. self.attns = nn.ModuleList()
  550. self.ffns = nn.ModuleList()
  551. for _ in range(num_parallel):
  552. self.attns.append(nn.Sequential(OrderedDict([
  553. ('norm', norm_layer(dim, **dd)),
  554. ('attn', _create_attn(
  555. attn_layer,
  556. dim,
  557. num_heads=num_heads,
  558. qkv_bias=qkv_bias,
  559. qk_norm=qk_norm,
  560. scale_norm=scale_attn_norm,
  561. proj_bias=proj_bias,
  562. attn_drop=attn_drop,
  563. proj_drop=proj_drop,
  564. norm_layer=norm_layer,
  565. depth=depth,
  566. **dd,
  567. )),
  568. ('ls', LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity()),
  569. ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
  570. ])))
  571. self.ffns.append(nn.Sequential(OrderedDict([
  572. ('norm', norm_layer(dim, **dd)),
  573. ('mlp', mlp_layer(
  574. dim,
  575. hidden_features=int(dim * mlp_ratio),
  576. act_layer=act_layer,
  577. norm_layer=norm_layer if scale_mlp_norm else None,
  578. bias=proj_bias,
  579. drop=proj_drop,
  580. **dd,
  581. )),
  582. ('ls', LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity()),
  583. ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
  584. ])))
  585. def forward(
  586. self,
  587. x: torch.Tensor,
  588. attn_mask: Optional[torch.Tensor] = None,
  589. is_causal: bool = False,
  590. ) -> torch.Tensor:
  591. if attn_mask is not None or is_causal:
  592. attn_out = []
  593. for attn in self.attns:
  594. x_attn = attn.norm(x)
  595. x_attn = attn.attn(x_attn, attn_mask=attn_mask, is_causal=is_causal)
  596. x_attn = attn.ls(x_attn)
  597. x_attn = attn.drop_path(x_attn)
  598. attn_out.append(x_attn)
  599. x = x + torch.stack(attn_out).sum(dim=0)
  600. else:
  601. x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
  602. x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
  603. return x
  604. def global_pool_nlc(
  605. x: torch.Tensor,
  606. pool_type: str = 'token',
  607. num_prefix_tokens: int = 1,
  608. reduce_include_prefix: bool = False,
  609. ):
  610. if not pool_type:
  611. return x
  612. if pool_type == 'token':
  613. x = x[:, 0] # class token
  614. else:
  615. x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
  616. if pool_type == 'avg':
  617. x = x.mean(dim=1)
  618. elif pool_type == 'avgmax':
  619. x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
  620. elif pool_type == 'max':
  621. x = x.amax(dim=1)
  622. else:
  623. assert not pool_type, f'Unknown pool type {pool_type}'
  624. return x
class VisionTransformer(nn.Module):
    """ Vision Transformer

    A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929

    Pipeline: patch embed -> (class / register tokens + position embed) -> patch dropout ->
    optional pre-norm -> transformer blocks -> norm -> pool -> fc_norm -> head dropout -> head.
    """
    # Set once in __init__; read in _pos_embed to decide whether pos embed is resampled per call.
    dynamic_img_size: Final[bool]

    def __init__(
            self,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 16,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map', 'prr'] = 'token',
            embed_dim: int = 768,
            depth: int = 12,
            num_heads: int = 12,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            scale_attn_norm: bool = False,
            scale_mlp_norm: bool = False,
            proj_bias: bool = True,
            init_values: Optional[float] = None,
            class_token: bool = True,
            pos_embed: str = 'learn',
            no_embed_class: bool = False,
            reg_tokens: int = 0,
            pre_norm: bool = False,
            final_norm: bool = True,
            fc_norm: Optional[bool] = None,
            pool_include_prefix: bool = False,
            dynamic_img_size: bool = False,
            dynamic_img_pad: bool = False,
            drop_rate: float = 0.,
            pos_drop_rate: float = 0.,
            patch_drop_rate: float = 0.,
            proj_drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            weight_init: Literal['skip', 'reset', 'jax', 'jax_nlhb', 'moco', ''] = '',
            fix_init: bool = False,
            embed_layer: Callable = PatchEmbed,
            embed_norm_layer: Optional[LayerType] = None,
            norm_layer: Optional[LayerType] = None,
            act_layer: Optional[LayerType] = None,
            block_fn: Type[nn.Module] = Block,
            mlp_layer: Type[nn.Module] = Mlp,
            attn_layer: LayerType = Attention,
            device=None,
            dtype=None,
    ) -> None:
        """
        Args:
            img_size: Input image size.
            patch_size: Patch size.
            in_chans: Number of image input channels.
            num_classes: Number of classes for classification head (0 -> identity head).
            global_pool: Type of global pooling for final sequence (default: 'token').
                'map' / 'prr' add an attention pooling module (``attn_pool``).
            embed_dim: Transformer embedding dimension.
            depth: Depth of transformer.
            num_heads: Number of attention heads.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            qkv_bias: Enable bias for qkv projections if True.
            qk_norm: Forwarded to each block as ``qk_norm``.
            scale_attn_norm: Forwarded to each block as ``scale_attn_norm``.
            scale_mlp_norm: Forwarded to each block as ``scale_mlp_norm``.
            proj_bias: Forwarded to each block as ``proj_bias``.
            init_values: Layer-scale init values (layer-scale enabled if not None).
            class_token: Use class token.
            pos_embed: Position embedding type: '' or 'none' disables it, 'learn' uses a learned table.
            no_embed_class: Don't include position embeddings for class (or reg) tokens.
            reg_tokens: Number of register tokens.
            pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT).
                Also disables the patch embedding bias.
            final_norm: Enable norm after transformer blocks, before head (standard in most ViT).
            fc_norm: Move final norm after pool (instead of before), if None, enabled when
                global_pool is one of 'avg', 'avgmax', 'max'.
            pool_include_prefix: Keep prefix tokens in pooling (forced True for 'prr').
            dynamic_img_size: Accept inputs whose size differs from ``img_size``; the patch embed
                emits NHWC and the position embedding is resampled on every forward.
            dynamic_img_pad: Forwarded to ``embed_layer`` as ``dynamic_img_pad``.
            drop_rate: Head dropout rate.
            pos_drop_rate: Position embedding dropout rate.
            patch_drop_rate: Patch dropout rate (PatchDropout enabled when > 0).
            proj_drop_rate: Forwarded to each block as ``proj_drop``.
            attn_drop_rate: Attention dropout rate.
            drop_path_rate: Stochastic depth rate.
            weight_init: Weight initialization scheme; 'skip' skips init in __init__ and
                stores 'reset' as the default mode for a later ``init_weights()`` call.
            fix_init: Apply weight initialization fix (scaling w/ layer index).
            embed_layer: Patch embedding layer.
            embed_norm_layer: Normalization layer to use / override in patch embed module.
            norm_layer: Normalization layer (defaults to LayerNorm).
            act_layer: MLP activation layer (defaults to nn.GELU).
            block_fn: Transformer block layer.
            mlp_layer: Forwarded to each block as ``mlp_layer``.
            attn_layer: Forwarded to each block as ``attn_layer``.
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map', 'prr')
        assert class_token or global_pool != 'token'
        assert pos_embed in ('', 'none', 'learn')
        # fc_norm=None: apply the final norm after pooling only for the reducing pool types
        use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
        norm_layer = get_norm_layer(norm_layer) or LayerNorm
        embed_norm_layer = get_norm_layer(embed_norm_layer)
        act_layer = get_act_layer(act_layer) or nn.GELU

        self.num_classes = num_classes
        self.in_chans = in_chans
        self.global_pool = global_pool
        self.num_features = self.head_hidden_size = self.embed_dim = embed_dim  # for consistency with other models
        # prefix tokens = optional class token + register tokens
        self.num_prefix_tokens = 1 if class_token else 0
        self.num_prefix_tokens += reg_tokens
        self.num_reg_tokens = reg_tokens
        self.has_class_token = class_token
        self.no_embed_class = no_embed_class
        self.pool_include_prefix = pool_include_prefix
        self.dynamic_img_size = dynamic_img_size
        self.grad_checkpointing = False

        embed_args = {}
        if dynamic_img_size:
            # flatten deferred until after pos embed
            embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
        if embed_norm_layer is not None:
            embed_args['norm_layer'] = embed_norm_layer
        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            dynamic_img_pad=dynamic_img_pad,
            **embed_args,
            **dd,
        )
        num_patches = self.patch_embed.num_patches
        reduction = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size

        self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, **dd)) if class_token else None
        self.reg_token = nn.Parameter(torch.empty(1, reg_tokens, embed_dim, **dd)) if reg_tokens else None
        # With no_embed_class, the pos embed table covers patch tokens only.
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        if not pos_embed or pos_embed == 'none':
            self.pos_embed = None
        else:
            self.pos_embed = nn.Parameter(torch.empty(1, embed_len, embed_dim, **dd))
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        if patch_drop_rate > 0:
            self.patch_drop = PatchDropout(
                patch_drop_rate,
                num_prefix_tokens=self.num_prefix_tokens,
            )
        else:
            self.patch_drop = nn.Identity()
        self.norm_pre = norm_layer(embed_dim, **dd) if pre_norm else nn.Identity()

        dpr = calculate_drop_path_rates(drop_path_rate, depth)  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                scale_attn_norm=scale_attn_norm,
                scale_mlp_norm=scale_mlp_norm,
                proj_bias=proj_bias,
                init_values=init_values,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                mlp_layer=mlp_layer,
                attn_layer=attn_layer,
                depth=i,
                **dd,
            )
            for i in range(depth)])
        self.feature_info = [
            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)]
        # Final norm sits before pooling unless fc_norm moves it after pooling.
        self.norm = norm_layer(embed_dim, **dd) if final_norm and not use_fc_norm else nn.Identity()

        # Classifier Head
        if global_pool == 'map':
            self.attn_pool = AttentionPoolLatent(
                self.embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
                act_layer=act_layer,
                **dd,
            )
        elif global_pool == 'prr':
            self.attn_pool = AttentionPoolPrr(
                self.embed_dim,
                num_heads=num_heads,
                pool_type='token' if class_token else 'avg',
                norm_layer=norm_layer,
                **dd,
            )
            # 'prr' always receives prefix tokens (overrides the constructor argument).
            self.pool_include_prefix = True
        else:
            self.attn_pool = None
        self.fc_norm = norm_layer(embed_dim, **dd) if final_norm and use_fc_norm else nn.Identity()
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(self.embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()

        self.weight_init_mode = 'reset' if weight_init == 'skip' else weight_init
        self.fix_init = fix_init
        # TODO: skip init when on meta device when safe to do so
        if weight_init != 'skip':
            self.init_weights(needs_reset=False)

    def fix_init_weight(self) -> None:
        """Apply weight initialization fix (scaling w/ layer index).

        Divides each block's ``attn.proj.weight`` and ``mlp.fc2.weight`` by sqrt(2 * (layer_id + 1)).
        NOTE(review): assumes every block exposes ``attn.proj`` and ``mlp.fc2`` — block types
        without these attributes would raise AttributeError here; confirm before using fix_init
        with non-default ``block_fn``.
        """
        with torch.no_grad():
            for layer_id, layer in enumerate(self.blocks):
                scale = math.sqrt(2.0 * (layer_id + 1))
                layer.attn.proj.weight.div_(scale)
                layer.mlp.fc2.weight.div_(scale)

    def init_weights(self, mode: str = '', needs_reset: bool = True) -> None:
        """Initialize model weights.

        Args:
            mode: Weight initialization mode ('jax', 'jax_nlhb', 'moco', 'reset', or '');
                an empty value falls back to ``self.weight_init_mode``.
            needs_reset: If True, call reset_parameters() on modules that have it.
                Set to False when modules have already self-initialized in __init__.
        """
        mode = mode or self.weight_init_mode
        assert mode in ('jax', 'jax_nlhb', 'moco', 'reset', '')
        # 'nlhb' modes bias the head to -log(num_classes)
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        if self.reg_token is not None:
            nn.init.normal_(self.reg_token, std=1e-6)
        named_apply(get_init_weights_vit(mode, head_bias, needs_reset=needs_reset), self)
        if self.fix_init:
            self.fix_init_weight()

    def _init_weights(self, m: nn.Module) -> None:
        """Initialize weights for a single module (compatibility method)."""
        # this fn left here for compat with downstream users
        init_weights_vit_timm(m)

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path: str, prefix: str = '') -> None:
        """Load pretrained weights.

        Args:
            checkpoint_path: Path to checkpoint.
            prefix: Prefix for state dict keys.
        """
        _load_weights(self, checkpoint_path, prefix)

    @torch.jit.ignore
    def no_weight_decay(self) -> Set[str]:
        """Set of parameters that should not use weight decay."""
        # NOTE(review): 'dist_token' is not created by this class (presumably for subclasses);
        # 'reg_token' is not listed — confirm register tokens are meant to be decayed.
        return {'pos_embed', 'cls_token', 'dist_token'}

    @torch.jit.ignore
    def group_matcher(self, coarse: bool = False) -> Dict[str, Union[str, List]]:
        """Create regex patterns for parameter grouping.

        Args:
            coarse: Use coarse grouping (not used by this implementation).

        Returns:
            Dictionary mapping group names to regex patterns.
        """
        # NOTE(review): `^` anchors only `cls_token` in the stem alternation, and `reg_token`
        # is not matched by any pattern — confirm the intended grouping for register tokens.
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True) -> None:
        """Enable or disable gradient checkpointing.

        Also forwarded to the patch embed module when it supports it.

        Args:
            enable: Whether to enable gradient checkpointing.
        """
        self.grad_checkpointing = enable
        if hasattr(self.patch_embed, 'set_grad_checkpointing'):
            self.patch_embed.set_grad_checkpointing(enable)

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        """Get the classifier head."""
        return self.head

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
        """Reset the classifier head.

        Attention pooling can be removed but not added or swapped for another type.
        NOTE(review): ``pool_include_prefix`` is not touched here, so after switching away from
        'prr' (which forces it True in __init__) avg/max pooling would include prefix tokens.
        The new head is created without the device/dtype used in __init__.

        Args:
            num_classes: Number of classes for new classifier (0 -> identity head).
            global_pool: Global pooling type; None keeps the current one.
        """
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map', 'prr')
            if global_pool in ('map', 'prr') and self.attn_pool is None:
                assert False, "Cannot currently add attention pooling in reset_classifier()."
            elif global_pool not in ('map', 'prr') and self.attn_pool is not None:
                self.attn_pool = None  # remove attention pooling
            elif global_pool in ('map', 'prr') and self.global_pool != global_pool:
                assert False, "Cannot currently change attention pooling type in reset_classifier()."
            self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def set_input_size(
            self,
            img_size: Optional[Tuple[int, int]] = None,
            patch_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """Update the input image resolution and patch size.

        The learned position embedding is resampled (and replaced by a new Parameter) only
        when the number of position tokens changes.

        Args:
            img_size: New input resolution, if None current resolution is used.
            patch_size: New patch size, if None existing patch size is used.
        """
        prev_grid_size = self.patch_embed.grid_size
        self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
        if self.pos_embed is not None:
            num_prefix_tokens = 0 if self.no_embed_class else self.num_prefix_tokens
            num_new_tokens = self.patch_embed.num_patches + num_prefix_tokens
            if num_new_tokens != self.pos_embed.shape[1]:
                self.pos_embed = nn.Parameter(resample_abs_pos_embed(
                    self.pos_embed,
                    new_size=self.patch_embed.grid_size,
                    old_size=prev_grid_size,
                    num_prefix_tokens=num_prefix_tokens,
                    verbose=True,
                ))

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        """Apply positional embedding to input.

        Prepends class / register tokens and adds the position embedding. In dynamic size
        mode ``x`` arrives as (B, H, W, C) and the embedding is resampled to (H, W).
        """
        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
        if self.reg_token is not None:
            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))

        if self.pos_embed is None:
            # No position embedding: flatten spatial dims and prepend prefix tokens.
            # NOTE(review): this early return also skips self.pos_drop — confirm intended.
            return torch.cat(to_cat + [x.view(x.shape[0], -1, x.shape[-1])], dim=1)

        if self.dynamic_img_size:
            B, H, W, C = x.shape
            prev_grid_size = self.patch_embed.grid_size
            pos_embed = resample_abs_pos_embed(
                self.pos_embed,
                new_size=(H, W),
                old_size=prev_grid_size,
                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
            )
            x = x.view(B, -1, C)
        else:
            pos_embed = self.pos_embed

        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + pos_embed

        return self.pos_drop(x)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            return_prefix_tokens: bool = False,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
            output_dict: bool = False,
            attn_mask: Optional[torch.Tensor] = None,
            is_causal: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            return_prefix_tokens: Return both prefix and spatial intermediate tokens
            norm: Apply norm layer to all intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
            output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys
            attn_mask: Optional attention mask for masked attention (e.g., for NaFlex)
            is_causal: If True, use causal (autoregressive) masking in attention

        Returns:
            A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing
            'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix')
        """
        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
        reshape = output_fmt == 'NCHW'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.blocks), indices)

        # forward pass
        B, _, height, width = x.shape
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)

        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            blocks = self.blocks
        else:
            blocks = self.blocks[:max_index + 1]
        for i, blk in enumerate(blocks):
            # A mask / causal flag bypasses gradient checkpointing.
            if attn_mask is not None or is_causal:
                x = blk(x, attn_mask=attn_mask, is_causal=is_causal)
            elif self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(blk, x)
            else:
                x = blk(x)
            if i in take_indices:
                # normalize intermediates with final norm layer if enabled
                intermediates.append(self.norm(x) if norm else x)

        # process intermediates
        if self.num_prefix_tokens:
            # split prefix (e.g. class, distill) and spatial feature tokens
            prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
            intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
        else:
            prefix_tokens = None

        if reshape:
            # reshape to BCHW output format
            H, W = self.patch_embed.dynamic_feat_size((height, width))
            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]

        # For dictionary output, handle prefix tokens separately
        if output_dict:
            result_dict = {}
            # Intermediates are always included
            result_dict['image_intermediates'] = intermediates
            if prefix_tokens is not None and return_prefix_tokens:
                result_dict['image_intermediates_prefix'] = prefix_tokens

            # Only include features if not intermediates_only
            if not intermediates_only:
                x_final = self.norm(x)
                result_dict['image_features'] = x_final

            return result_dict

        # For non-dictionary output, maintain the original behavior
        if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None:
            # return_prefix not supported in torchscript due to poor type handling
            intermediates = list(zip(intermediates, prefix_tokens))

        if intermediates_only:
            return intermediates

        x = self.norm(x)

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ) -> List[int]:
        """Prune layers not required for specified intermediates.

        Truncates ``self.blocks`` after the last requested index; optionally replaces the
        final norm and the fc_norm / classifier head with identities.

        Args:
            indices: Indices of intermediate layers to keep.
            prune_norm: Whether to prune normalization layer.
            prune_head: Whether to prune the classifier head.

        Returns:
            List of indices that were kept.
        """
        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
        if prune_norm:
            self.norm = nn.Identity()
        if prune_head:
            self.fc_norm = nn.Identity()
            self.reset_classifier(0, '')
        return take_indices

    def get_intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, List[int], Tuple[int]] = 1,
            reshape: bool = False,
            return_prefix_tokens: bool = False,
            norm: bool = False,
            attn_mask: Optional[torch.Tensor] = None,
    ) -> List[torch.Tensor]:
        """Get intermediate layer outputs (DINO interface compatibility).

        NOTE: This API is for backwards compat, favour using forward_intermediates() directly.

        Args:
            x: Input tensor.
            n: Number or indices of layers.
            reshape: Reshape to NCHW format.
            return_prefix_tokens: Return prefix tokens.
            norm: Apply normalization.
            attn_mask: Optional attention mask forwarded to forward_intermediates().

        Returns:
            List of intermediate features.
        """
        return self.forward_intermediates(
            x, n,
            return_prefix_tokens=return_prefix_tokens,
            norm=norm,
            output_fmt='NCHW' if reshape else 'NLC',
            intermediates_only=True,
            attn_mask=attn_mask,
        )

    def forward_features(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            is_causal: bool = False,
    ) -> torch.Tensor:
        """Forward pass through feature layers (embeddings, transformer blocks, post-transformer norm)."""
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)

        if attn_mask is not None or is_causal:
            # If mask/causal provided, we need to apply blocks one by one
            # (gradient checkpointing is not used on this path)
            for blk in self.blocks:
                x = blk(x, attn_mask=attn_mask, is_causal=is_causal)
        elif self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)

        x = self.norm(x)
        return x

    def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
        """Apply pooling to feature tokens.

        When an attention pool is present it is always used and ``pool_type`` is ignored.

        Args:
            x: Feature tensor.
            pool_type: Pooling type override.

        Returns:
            Pooled features.
        """
        if self.attn_pool is not None:
            if not self.pool_include_prefix:
                x = x[:, self.num_prefix_tokens:]
            x = self.attn_pool(x)
            return x
        pool_type = self.global_pool if pool_type is None else pool_type
        x = global_pool_nlc(
            x,
            pool_type=pool_type,
            num_prefix_tokens=self.num_prefix_tokens,
            reduce_include_prefix=self.pool_include_prefix,
        )
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        """Forward pass through classifier head.

        Args:
            x: Feature tensor.
            pre_logits: Return features before final classifier.

        Returns:
            Output tensor.
        """
        x = self.pool(x)
        x = self.fc_norm(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            is_causal: bool = False,
    ) -> torch.Tensor:
        """Full forward pass: forward_features() followed by forward_head()."""
        x = self.forward_features(x, attn_mask=attn_mask, is_causal=is_causal)
        x = self.forward_head(x)
        return x
  1155. def init_weights_vit_timm(module: nn.Module, name: str = '', needs_reset: bool = True) -> None:
  1156. """ViT weight initialization, original timm impl (for reproducibility).
  1157. Args:
  1158. module: Module to initialize.
  1159. name: Module name for context.
  1160. needs_reset: If True, call reset_parameters() on modules that have it.
  1161. """
  1162. if isinstance(module, nn.Linear):
  1163. trunc_normal_(module.weight, std=.02)
  1164. if module.bias is not None:
  1165. nn.init.zeros_(module.bias)
  1166. elif hasattr(module, 'init_weights'):
  1167. module.init_weights()
  1168. elif needs_reset and hasattr(module, 'reset_parameters'):
  1169. module.reset_parameters()
  1170. def init_weights_vit_jax(
  1171. module: nn.Module,
  1172. name: str = '',
  1173. head_bias: float = 0.0,
  1174. needs_reset: bool = True,
  1175. ) -> None:
  1176. """ViT weight initialization, matching JAX (Flax) impl.
  1177. Args:
  1178. module: Module to initialize.
  1179. name: Module name for context.
  1180. head_bias: Bias value for head layer.
  1181. needs_reset: If True, call reset_parameters() on modules that have it.
  1182. """
  1183. if isinstance(module, nn.Linear):
  1184. if name.startswith('head'):
  1185. nn.init.zeros_(module.weight)
  1186. nn.init.constant_(module.bias, head_bias)
  1187. else:
  1188. nn.init.xavier_uniform_(module.weight)
  1189. if module.bias is not None:
  1190. nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
  1191. elif isinstance(module, nn.Conv2d):
  1192. lecun_normal_(module.weight)
  1193. if module.bias is not None:
  1194. nn.init.zeros_(module.bias)
  1195. elif hasattr(module, 'init_weights'):
  1196. module.init_weights()
  1197. elif needs_reset and hasattr(module, 'reset_parameters'):
  1198. module.reset_parameters()
  1199. def init_weights_vit_moco(module: nn.Module, name: str = '', needs_reset: bool = True) -> None:
  1200. """ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed.
  1201. Args:
  1202. module: Module to initialize.
  1203. name: Module name for context.
  1204. needs_reset: If True, call reset_parameters() on modules that have it.
  1205. """
  1206. if isinstance(module, nn.Linear):
  1207. if 'qkv' in name:
  1208. # treat the weights of Q, K, V separately
  1209. val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))
  1210. nn.init.uniform_(module.weight, -val, val)
  1211. else:
  1212. nn.init.xavier_uniform_(module.weight)
  1213. if module.bias is not None:
  1214. nn.init.zeros_(module.bias)
  1215. elif hasattr(module, 'init_weights'):
  1216. module.init_weights()
  1217. elif needs_reset and hasattr(module, 'reset_parameters'):
  1218. module.reset_parameters()
  1219. def init_weights_reset_parameters(module: nn.Module, name: str = '', needs_reset: bool = True) -> None:
  1220. if needs_reset and hasattr(module, 'reset_parameters'):
  1221. module.reset_parameters()
  1222. def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0, needs_reset: bool = True) -> Callable:
  1223. if mode.startswith('jax'):
  1224. return partial(init_weights_vit_jax, head_bias=head_bias, needs_reset=needs_reset)
  1225. elif mode.startswith('moco'):
  1226. return partial(init_weights_vit_moco, needs_reset=needs_reset)
  1227. elif mode == 'reset':
  1228. # 'reset' means only call reset_parameters() on modules
  1229. return partial(init_weights_reset_parameters, needs_reset=needs_reset)
  1230. else:
  1231. # timm init is default
  1232. return partial(init_weights_vit_timm, needs_reset=needs_reset)
  1233. def resize_pos_embed(
  1234. posemb: torch.Tensor,
  1235. posemb_new: torch.Tensor,
  1236. num_prefix_tokens: int = 1,
  1237. gs_new: Tuple[int, int] = (),
  1238. interpolation: str = 'bicubic',
  1239. antialias: bool = False,
  1240. ) -> torch.Tensor:
  1241. """ Rescale the grid of position embeddings when loading from state_dict.
  1242. *DEPRECATED* This function is being deprecated in favour of using resample_abs_pos_embed
  1243. """
  1244. ntok_new = posemb_new.shape[1] - num_prefix_tokens
  1245. ntok_old = posemb.shape[1] - num_prefix_tokens
  1246. gs_old = [int(math.sqrt(ntok_old))] * 2
  1247. if not len(gs_new): # backwards compatibility
  1248. gs_new = [int(math.sqrt(ntok_new))] * 2
  1249. return resample_abs_pos_embed(
  1250. posemb, gs_new, gs_old,
  1251. num_prefix_tokens=num_prefix_tokens,
  1252. interpolation=interpolation,
  1253. antialias=antialias,
  1254. verbose=True,
  1255. )
@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = '', load_bfloat16: bool = False) -> None:
    """ Load weights from .npz checkpoints for official Google Brain Flax implementation

    Parameters of ``model`` are overwritten in place (the function runs under
    ``torch.no_grad()``). Covered:
    - the hybrid CNN stem, only when ``model.patch_embed`` has a ``backbone``;
    - patch embedding, cls token, absolute position embedding and final norm;
    - the classifier head, only when it is an ``nn.Linear`` whose size matches the checkpoint;
    - an ``AttentionPoolLatent`` (MAP) head, when present;
    - every transformer block.
    Patch-embed kernels and position embeddings whose shapes differ from the model are
    resampled (bilinear, no antialias).

    Args:
        model: Target VisionTransformer.
        checkpoint_path: Path of the ``.npz`` archive.
        prefix: Key prefix inside the archive. When empty, it is auto-detected if one of the
            known layouts ('opt/target/', 'params/', 'params/img/') is found. The two
            'params/...' layouts also switch on big_vision key naming.
        load_bfloat16: Load via ``jax.numpy``, reinterpret each array's raw data as bfloat16
            and convert it to float32 (requires ``jax`` and ``ml_dtypes``).
    """
    import numpy as np
    if load_bfloat16:
        import jax.numpy as jnp
        import ml_dtypes

    def _n2p(_w, t=True, idx=None):
        # Convert one checkpoint array to a torch tensor.
        #   idx: pick entry `idx` along the leading axis (used below when all encoder blocks are
        #        stored stacked under a single key).
        #   t:   permute axes to torch order: 4-D -> (3, 2, 0, 1), 3-D -> (2, 0, 1), 2-D -> (1, 0).
        #        NOTE(review): presumably Flax kernel layouts -> torch (out, in, ...) layouts; confirm.
        if idx is not None:
            _w = _w[idx]
        if load_bfloat16:
            # reinterpret the stored bits as bfloat16, upcast to float32, convert to numpy
            _w = _w.view(ml_dtypes.bfloat16).astype(jnp.float32)
            _w = np.array(_w)
        # arrays shaped (1, 1, 1, N) collapse to a 1-D vector of length N
        if _w.ndim == 4 and _w.shape[0] == _w.shape[1] == _w.shape[2] == 1:
            _w = _w.flatten()
        if t:
            if _w.ndim == 4:
                _w = _w.transpose([3, 2, 0, 1])
            elif _w.ndim == 3:
                _w = _w.transpose([2, 0, 1])
            elif _w.ndim == 2:
                _w = _w.transpose([1, 0])
        _w = torch.from_numpy(_w)
        return _w

    if load_bfloat16:
        w = jnp.load(checkpoint_path)
    else:
        w = np.load(checkpoint_path)

    # resampling settings used below when patch-embed / pos-embed shapes differ from the model
    interpolation = 'bilinear'
    antialias = False
    # big_vision checkpoints use different key names / submodule indices (see below)
    big_vision = False
    if not prefix:
        # auto-detect the key prefix from known checkpoint layouts
        if 'opt/target/embedding/kernel' in w:
            prefix = 'opt/target/'
        elif 'params/embedding/kernel' in w:
            prefix = 'params/'
            big_vision = True
        elif 'params/img/embedding/kernel' in w:
            prefix = 'params/img/'
            big_vision = True

    if hasattr(model.patch_embed, 'backbone'):
        # hybrid
        backbone = model.patch_embed.backbone
        # a backbone without a `.stem` attribute is itself the stem
        stem_only = not hasattr(backbone, 'stem')
        stem = backbone if stem_only else backbone.stem
        stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
        stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
        stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
        if not stem_only:
            # residual stages: checkpoint keys are 1-based ('block{i}/unit{j}/')
            for i, stage in enumerate(backbone.stages):
                for j, block in enumerate(stage.blocks):
                    bp = f'{prefix}block{i + 1}/unit{j + 1}/'
                    for r in range(3):
                        getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
                        getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
                        getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
                    if block.downsample is not None:
                        block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
                        block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
                        block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
    else:
        # adapt the kernel's input-channel count to the model's patch_embed input channels
        embed_conv_w = adapt_input_conv(
            model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
    if embed_conv_w.shape[-2:] != model.patch_embed.proj.weight.shape[-2:]:
        # patch size differs from the checkpoint: resample the kernel spatially
        embed_conv_w = resample_patch_embed(
            embed_conv_w,
            model.patch_embed.proj.weight.shape[-2:],
            interpolation=interpolation,
            antialias=antialias,
            verbose=True,
        )
    model.patch_embed.proj.weight.copy_(embed_conv_w)
    model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
    if model.cls_token is not None:
        model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
    if big_vision:
        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
    else:
        pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
    if pos_embed_w.shape != model.pos_embed.shape:
        # prefix tokens are excluded from resampling; none when the model sets no_embed_class
        num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
        pos_embed_w = resample_abs_pos_embed(  # resize pos embedding when different size from pretrained weights
            pos_embed_w,
            new_size=model.patch_embed.grid_size,
            num_prefix_tokens=num_prefix_tokens,
            interpolation=interpolation,
            antialias=antialias,
            verbose=True,
        )
    model.pos_embed.copy_(pos_embed_w)
    model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
    model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
    # only load the classifier when present in the checkpoint and its class count matches
    if (isinstance(model.head, nn.Linear) and
            f'{prefix}head/bias' in w and
            model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]):
        model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
        model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
    # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights
    # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
    # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
    # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
    if isinstance(model.attn_pool, AttentionPoolLatent):
        # MAP (multihead attention pooling) head
        block_prefix = f'{prefix}MAPHead_0/'
        mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
        model.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
        # key / value kernels are stored separately: flatten trailing dims, transpose, and
        # concatenate (k then v) into the fused kv projection
        model.attn_pool.kv.weight.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
        model.attn_pool.kv.bias.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
        model.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
        model.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
        model.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
        model.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
        model.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
        model.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
        for r in range(2):
            getattr(model.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
            getattr(model.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))
    # submodule indices of the attention, MLP and second LayerNorm inside each encoder block;
    # these differ between big_vision and the original checkpoint layout
    mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
    for i, block in enumerate(model.blocks.children()):
        # blocks are either stacked under one 'encoderblock/' group (select entry i via idx)
        # or stored per block under 'encoderblock_{i}/'
        if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
            block_prefix = f'{prefix}Transformer/encoderblock/'
            idx = i
        else:
            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
            idx = None
        mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
        # separate q / k / v kernels are flattened, transposed and concatenated (q, k, v order)
        # into the fused qkv projection
        block.attn.qkv.weight.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
        block.attn.qkv.bias.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
        for r in range(2):
            getattr(block.mlp, f'fc{r + 1}').weight.copy_(
                _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
            getattr(block.mlp, f'fc{r + 1}').bias.copy_(
                _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))
  1400. def _convert_openai_clip(
  1401. state_dict: Dict[str, torch.Tensor],
  1402. model: VisionTransformer,
  1403. prefix: str = 'visual.',
  1404. ) -> Dict[str, torch.Tensor]:
  1405. out_dict = {}
  1406. swaps = [
  1407. ('conv1', 'patch_embed.proj'),
  1408. ('positional_embedding', 'pos_embed'),
  1409. ('transformer.resblocks.', 'blocks.'),
  1410. ('ln_pre', 'norm_pre'),
  1411. ('ln_post', 'norm'),
  1412. ('ln_', 'norm'),
  1413. ('in_proj_', 'qkv.'),
  1414. ('out_proj', 'proj'),
  1415. ('mlp.c_fc', 'mlp.fc1'),
  1416. ('mlp.c_proj', 'mlp.fc2'),
  1417. ]
  1418. for k, v in state_dict.items():
  1419. if not k.startswith(prefix):
  1420. continue
  1421. k = k.replace(prefix, '')
  1422. for sp in swaps:
  1423. k = k.replace(sp[0], sp[1])
  1424. if k == 'proj':
  1425. k = 'head.weight'
  1426. v = v.transpose(0, 1)
  1427. out_dict['head.bias'] = torch.zeros(v.shape[0])
  1428. elif k == 'class_embedding':
  1429. k = 'cls_token'
  1430. v = v.unsqueeze(0).unsqueeze(1)
  1431. elif k == 'pos_embed':
  1432. v = v.unsqueeze(0)
  1433. out_dict[k] = v
  1434. return out_dict
  1435. def _convert_dinov2(
  1436. state_dict: Dict[str, torch.Tensor],
  1437. model: VisionTransformer,
  1438. ) -> Dict[str, torch.Tensor]:
  1439. import re
  1440. out_dict = {}
  1441. state_dict.pop("mask_token", None)
  1442. if 'register_tokens' in state_dict:
  1443. # convert dinov2 w/ registers to no_embed_class timm model (neither cls or reg tokens overlap pos embed)
  1444. out_dict['reg_token'] = state_dict.pop('register_tokens')
  1445. out_dict['cls_token'] = state_dict.pop('cls_token') + state_dict['pos_embed'][:, 0]
  1446. out_dict['pos_embed'] = state_dict.pop('pos_embed')[:, 1:]
  1447. for k, v in state_dict.items():
  1448. if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
  1449. out_dict[k.replace("w12", "fc1")] = v
  1450. continue
  1451. elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
  1452. out_dict[k.replace("w3", "fc2")] = v
  1453. continue
  1454. out_dict[k] = v
  1455. return out_dict
  1456. def _convert_aimv2(
  1457. state_dict: Dict[str, torch.Tensor],
  1458. model: VisionTransformer,
  1459. ) -> Dict[str, torch.Tensor]:
  1460. out_dict = {}
  1461. for k, v in state_dict.items():
  1462. k = k.replace('norm_1', 'norm1')
  1463. k = k.replace('norm_2', 'norm2')
  1464. k = k.replace('preprocessor.patchifier.', 'patch_embed.')
  1465. k = k.replace('preprocessor.pos_embed', 'pos_embed')
  1466. k = k.replace('trunk.', '')
  1467. k = k.replace('post_trunk_norm.', 'norm.')
  1468. k = k.replace('mlp.fc1', 'mlp.fc1_g')
  1469. k = k.replace('mlp.fc3', 'mlp.fc1_x')
  1470. out_dict[k] = v
  1471. return out_dict
  1472. def _convert_beit3(state_dict: dict, model):
  1473. """
  1474. Turn a BEiT-3 checkpoint into a standard VisionTransformer state-dict.
  1475. """
  1476. import re
  1477. state_dict = state_dict.get("model", state_dict) # unwrap if needed
  1478. # Prune unused
  1479. for k in ("beit3.text_embed.weight", "beit3.vision_embed.mask_token"):
  1480. state_dict.pop(k, None)
  1481. # Key renaming rules
  1482. rules = [
  1483. (r"beit3\.", ""),
  1484. (r"vision_embed\.cls_token", "cls_token"),
  1485. (r"vision_embed\.", "patch_embed."),
  1486. (r"embed_positions\.", "pos_embed."),
  1487. (r"encoder\.", ""),
  1488. (r"layers\.", "blocks."),
  1489. (r"ffn_layernorm\.", "norm."), (r"ffn\.", "mlp."),
  1490. (r"self_attn_layer_norm\.", "norm1."), (r"self_attn\.", "attn."),
  1491. (r"final_layer_norm\.", "norm2."),
  1492. (r"inner_attn_ln", "norm"),
  1493. (r"out_proj", "proj"),
  1494. (r"\.A\.", "."),
  1495. ]
  1496. # First pass, rename keys
  1497. tmp = {}
  1498. for k, v in state_dict.items():
  1499. if ".B." in k:
  1500. continue # use branch-A only
  1501. for old, new in rules:
  1502. k = re.sub(old, new, k)
  1503. if k == "pos_embed.weight":
  1504. # strip first two positions, [1, N+1, D]
  1505. tmp["pos_embed"] = v[2:].unsqueeze(0)
  1506. else:
  1507. tmp[k] = v
  1508. # Second pass, fuse q, k, v
  1509. out, buf = {}, {}
  1510. pat = re.compile(r"blocks\.(\d+)\.attn\.(q|k|v)_proj\.(weight|bias)$")
  1511. for k, v in tmp.items():
  1512. m = pat.fullmatch(k)
  1513. if not m: # anything not q/k/v -> copy through
  1514. out[k] = v
  1515. continue
  1516. blk, which, kind = m.groups() # block idx, 'q'/'k'/'v', 'weight'/'bias'
  1517. stash = buf.setdefault((blk, kind), {}) # Gather by block & param type
  1518. stash[which] = v
  1519. if len(stash) == 3: # Have q, k, v -> concatenate
  1520. out[f"blocks.{blk}.attn.qkv.{kind}"] = torch.cat(
  1521. [stash['q'], stash['k'], stash['v']], dim=0
  1522. )
  1523. return out
  1524. def checkpoint_filter_fn(
  1525. state_dict: Dict[str, torch.Tensor],
  1526. model: VisionTransformer,
  1527. adapt_layer_scale: bool = False,
  1528. interpolation: str = 'bicubic',
  1529. antialias: bool = True,
  1530. ) -> Dict[str, torch.Tensor]:
  1531. """ convert patch embedding weight from manual patchify + linear proj to conv"""
  1532. import re
  1533. out_dict = {}
  1534. state_dict = state_dict.get('model', state_dict)
  1535. state_dict = state_dict.get('state_dict', state_dict)
  1536. prefix = ''
  1537. if 'visual.class_embedding' in state_dict:
  1538. state_dict = _convert_openai_clip(state_dict, model)
  1539. elif 'module.visual.class_embedding' in state_dict:
  1540. state_dict = _convert_openai_clip(state_dict, model, prefix='module.visual.')
  1541. elif "mask_token" in state_dict:
  1542. state_dict = _convert_dinov2(state_dict, model)
  1543. elif any('beit3.' in k for k in state_dict.keys()):
  1544. # BEiT3 model - multimodal checkpoint with beit3.* prefix
  1545. state_dict = _convert_beit3(state_dict, model)
  1546. elif "encoder" in state_dict:
  1547. # IJEPA, vit in an 'encoder' submodule
  1548. state_dict = state_dict['encoder']
  1549. prefix = 'module.'
  1550. elif 'visual.trunk.pos_embed' in state_dict or 'visual.trunk.blocks.0.norm1.weight' in state_dict:
  1551. # OpenCLIP model with timm vision encoder
  1552. prefix = 'visual.trunk.'
  1553. if 'visual.head.proj.weight' in state_dict and isinstance(model.head, nn.Linear):
  1554. # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
  1555. out_dict['head.weight'] = state_dict['visual.head.proj.weight']
  1556. out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
  1557. elif 'module.visual.trunk.pos_embed' in state_dict:
  1558. prefix = 'module.visual.trunk.'
  1559. elif 'preprocessor.patchifier.proj.weight' in state_dict:
  1560. state_dict = _convert_aimv2(state_dict, model)
  1561. if prefix:
  1562. # filter on & remove prefix string from keys
  1563. state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
  1564. for k, v in state_dict.items():
  1565. if 'patch_embed.proj.weight' in k:
  1566. O, I, H, W = model.patch_embed.proj.weight.shape
  1567. if len(v.shape) < 4:
  1568. # For old models that I trained prior to conv based patchification
  1569. O, I, H, W = model.patch_embed.proj.weight.shape
  1570. v = v.reshape(O, -1, H, W)
  1571. if v.shape[-1] != W or v.shape[-2] != H:
  1572. v = resample_patch_embed(
  1573. v,
  1574. (H, W),
  1575. interpolation=interpolation,
  1576. antialias=antialias,
  1577. verbose=True,
  1578. )
  1579. elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
  1580. # To resize pos embedding when using model at different size from pretrained weights
  1581. num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
  1582. v = resample_abs_pos_embed(
  1583. v,
  1584. new_size=model.patch_embed.grid_size,
  1585. num_prefix_tokens=num_prefix_tokens,
  1586. interpolation=interpolation,
  1587. antialias=antialias,
  1588. verbose=True,
  1589. )
  1590. elif adapt_layer_scale and 'gamma_' in k:
  1591. # remap layer-scale gamma into sub-module (deit3 models)
  1592. k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
  1593. elif 'pre_logits' in k:
  1594. # NOTE representation layer removed as not used in latest 21k/1k pretrained weights
  1595. continue
  1596. out_dict[k] = v
  1597. return out_dict
  1598. def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
  1599. return {
  1600. 'url': url,
  1601. 'num_classes': 1000,
  1602. 'input_size': (3, 224, 224),
  1603. 'pool_size': None,
  1604. 'crop_pct': 0.9,
  1605. 'interpolation': 'bicubic',
  1606. 'fixed_input_size': True,
  1607. 'mean': IMAGENET_INCEPTION_MEAN,
  1608. 'std': IMAGENET_INCEPTION_STD,
  1609. 'first_conv': 'patch_embed.proj',
  1610. 'classifier': 'head',
  1611. 'license': 'apache-2.0',
  1612. **kwargs,
  1613. }
  1614. default_cfgs = {
  1615. # re-finetuned augreg 21k FT on in1k weights
  1616. 'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg(
  1617. hf_hub_id='timm/'),
  1618. 'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg(
  1619. hf_hub_id='timm/'),
  1620. # How to train your ViT (augreg) weights, pretrained on 21k FT on in1k
  1621. 'vit_tiny_patch16_224.augreg_in21k_ft_in1k': _cfg(
  1622. url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
  1623. hf_hub_id='timm/',
  1624. custom_load=True),
  1625. 'vit_tiny_patch16_384.augreg_in21k_ft_in1k': _cfg(
  1626. url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
  1627. hf_hub_id='timm/',
  1628. custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
  1629. 'vit_small_patch32_224.augreg_in21k_ft_in1k': _cfg(
  1630. url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
  1631. hf_hub_id='timm/',
  1632. custom_load=True),
  1633. 'vit_small_patch32_384.augreg_in21k_ft_in1k': _cfg(
  1634. url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
  1635. hf_hub_id='timm/',
  1636. custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
  1637. 'vit_small_patch16_224.augreg_in21k_ft_in1k': _cfg(
  1638. url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
  1639. hf_hub_id='timm/',
  1640. custom_load=True),
  1641. 'vit_small_patch16_384.augreg_in21k_ft_in1k': _cfg(
  1642. url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
  1643. hf_hub_id='timm/',
  1644. custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
  1645. 'vit_base_patch32_224.augreg_in21k_ft_in1k': _cfg(
  1646. url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
  1647. hf_hub_id='timm/',
  1648. custom_load=True),
  1649. 'vit_base_patch32_384.augreg_in21k_ft_in1k': _cfg(
  1650. url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
  1651. hf_hub_id='timm/',
  1652. custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
  1653. 'vit_base_patch16_224.augreg_in21k_ft_in1k': _cfg(
  1654. url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
  1655. hf_hub_id='timm/',
  1656. custom_load=True),
  1657. 'vit_base_patch16_384.augreg_in21k_ft_in1k': _cfg(
  1658. url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
  1659. hf_hub_id='timm/',
  1660. custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
  1661. 'vit_base_patch8_224.augreg_in21k_ft_in1k': _cfg(
  1662. url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
  1663. hf_hub_id='timm/',
  1664. custom_load=True),
  1665. 'vit_large_patch16_224.augreg_in21k_ft_in1k': _cfg(
  1666. url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
  1667. hf_hub_id='timm/',
  1668. custom_load=True),
  1669. 'vit_large_patch16_384.augreg_in21k_ft_in1k': _cfg(
  1670. url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
  1671. hf_hub_id='timm/',
  1672. custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
  1673. # patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k
  1674. 'vit_base_patch16_224.orig_in21k_ft_in1k': _cfg(
  1675. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
  1676. hf_hub_id='timm/'),
  1677. 'vit_base_patch16_384.orig_in21k_ft_in1k': _cfg(
  1678. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
  1679. hf_hub_id='timm/',
  1680. input_size=(3, 384, 384), crop_pct=1.0),
  1681. 'vit_large_patch32_384.orig_in21k_ft_in1k': _cfg(
  1682. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
  1683. hf_hub_id='timm/',
  1684. input_size=(3, 384, 384), crop_pct=1.0),
  1685. # How to train your ViT (augreg) weights trained on in1k only
  1686. 'vit_small_patch16_224.augreg_in1k': _cfg(
  1687. url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
  1688. hf_hub_id='timm/',
  1689. custom_load=True),
  1690. 'vit_small_patch16_384.augreg_in1k': _cfg(
  1691. url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
  1692. hf_hub_id='timm/',
  1693. custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
  1694. 'vit_base_patch32_224.augreg_in1k': _cfg(
  1695. url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
  1696. hf_hub_id='timm/',
  1697. custom_load=True),
  1698. 'vit_base_patch32_384.augreg_in1k': _cfg(
  1699. url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
  1700. hf_hub_id='timm/',
  1701. custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
  1702. 'vit_base_patch16_224.augreg_in1k': _cfg(
  1703. url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
  1704. hf_hub_id='timm/',
  1705. custom_load=True),
  1706. 'vit_base_patch16_384.augreg_in1k': _cfg(
  1707. url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
  1708. hf_hub_id='timm/',
  1709. custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
  1710. 'vit_large_patch14_224.untrained': _cfg(url=''),
  1711. 'vit_huge_patch14_224.untrained': _cfg(url=''),
  1712. 'vit_giant_patch14_224.untrained': _cfg(url=''),
  1713. 'vit_gigantic_patch14_224.untrained': _cfg(url=''),
  1714. # patch models, imagenet21k (weights from official Google JAX impl), classifier not valid
  1715. 'vit_base_patch32_224.orig_in21k': _cfg(
  1716. #url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
  1717. hf_hub_id='timm/',
  1718. num_classes=0),
  1719. 'vit_base_patch16_224.orig_in21k': _cfg(
  1720. #url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
  1721. hf_hub_id='timm/',
  1722. num_classes=0),
  1723. 'vit_large_patch32_224.orig_in21k': _cfg(
  1724. #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
  1725. hf_hub_id='timm/',
  1726. num_classes=0),
  1727. 'vit_large_patch16_224.orig_in21k': _cfg(
  1728. #url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
  1729. hf_hub_id='timm/',
  1730. num_classes=0),
  1731. 'vit_huge_patch14_224.orig_in21k': _cfg(
  1732. hf_hub_id='timm/',
  1733. num_classes=0),
  1734. # How to train your ViT (augreg) weights, pretrained on in21k
  1735. 'vit_tiny_patch16_224.augreg_in21k': _cfg(
  1736. url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
  1737. hf_hub_id='timm/',
  1738. custom_load=True, num_classes=21843),
  1739. 'vit_small_patch32_224.augreg_in21k': _cfg(
  1740. url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
  1741. hf_hub_id='timm/',
  1742. custom_load=True, num_classes=21843),
  1743. 'vit_small_patch16_224.augreg_in21k': _cfg(
  1744. url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
  1745. hf_hub_id='timm/',
  1746. custom_load=True, num_classes=21843),
  1747. 'vit_base_patch32_224.augreg_in21k': _cfg(
  1748. url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
  1749. hf_hub_id='timm/',
  1750. custom_load=True, num_classes=21843),
  1751. 'vit_base_patch16_224.augreg_in21k': _cfg(
  1752. url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
  1753. hf_hub_id='timm/',
  1754. custom_load=True, num_classes=21843),
  1755. 'vit_base_patch8_224.augreg_in21k': _cfg(
  1756. url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
  1757. hf_hub_id='timm/',
  1758. custom_load=True, num_classes=21843),
  1759. 'vit_large_patch16_224.augreg_in21k': _cfg(
  1760. url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
  1761. hf_hub_id='timm/',
  1762. custom_load=True, num_classes=21843),
  1763. # SAM trained models (https://arxiv.org/abs/2106.01548)
  1764. 'vit_base_patch32_224.sam_in1k': _cfg(
  1765. url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz', custom_load=True,
  1766. hf_hub_id='timm/'),
  1767. 'vit_base_patch16_224.sam_in1k': _cfg(
  1768. url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz', custom_load=True,
  1769. hf_hub_id='timm/'),
  1770. # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only)
  1771. 'vit_small_patch16_224.dino': _cfg(
  1772. url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth',
  1773. hf_hub_id='timm/',
  1774. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  1775. 'vit_small_patch8_224.dino': _cfg(
  1776. url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth',
  1777. hf_hub_id='timm/',
  1778. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  1779. 'vit_base_patch16_224.dino': _cfg(
  1780. url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth',
  1781. hf_hub_id='timm/',
  1782. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  1783. 'vit_base_patch8_224.dino': _cfg(
  1784. url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
  1785. hf_hub_id='timm/',
  1786. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  1787. # DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune/features only)
  1788. 'vit_small_patch14_dinov2.lvd142m': _cfg(
  1789. url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth',
  1790. hf_hub_id='timm/',
  1791. license='apache-2.0',
  1792. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
  1793. input_size=(3, 518, 518), crop_pct=1.0),
  1794. 'vit_base_patch14_dinov2.lvd142m': _cfg(
  1795. url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth',
  1796. hf_hub_id='timm/',
  1797. license='apache-2.0',
  1798. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
  1799. input_size=(3, 518, 518), crop_pct=1.0),
  1800. 'vit_large_patch14_dinov2.lvd142m': _cfg(
  1801. url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth',
  1802. hf_hub_id='timm/',
  1803. license='apache-2.0',
  1804. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
  1805. input_size=(3, 518, 518), crop_pct=1.0),
  1806. 'vit_giant_patch14_dinov2.lvd142m': _cfg(
  1807. url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth',
  1808. hf_hub_id='timm/',
  1809. license='apache-2.0',
  1810. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
  1811. input_size=(3, 518, 518), crop_pct=1.0),
  1812. # DINOv2 pretrained w/ registers - https://arxiv.org/abs/2309.16588 (no classifier head, for fine-tune/features only)
  1813. 'vit_small_patch14_reg4_dinov2.lvd142m': _cfg(
  1814. url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth',
  1815. hf_hub_id='timm/',
  1816. license='apache-2.0',
  1817. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
  1818. input_size=(3, 518, 518), crop_pct=1.0),
  1819. 'vit_base_patch14_reg4_dinov2.lvd142m': _cfg(
  1820. url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth',
  1821. hf_hub_id='timm/',
  1822. license='apache-2.0',
  1823. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
  1824. input_size=(3, 518, 518), crop_pct=1.0),
  1825. 'vit_large_patch14_reg4_dinov2.lvd142m': _cfg(
  1826. url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth',
  1827. hf_hub_id='timm/',
  1828. license='apache-2.0',
  1829. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
  1830. input_size=(3, 518, 518), crop_pct=1.0),
  1831. 'vit_giant_patch14_reg4_dinov2.lvd142m': _cfg(
  1832. url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth',
  1833. hf_hub_id='timm/',
  1834. license='apache-2.0',
  1835. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
  1836. input_size=(3, 518, 518), crop_pct=1.0),
  1837. # ViT ImageNet-21K-P pretraining by MIIL
  1838. 'vit_base_patch16_224_miil.in21k': _cfg(
  1839. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth',
  1840. hf_hub_id='timm/',
  1841. mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221),
  1842. 'vit_base_patch16_224_miil.in21k_ft_in1k': _cfg(
  1843. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth',
  1844. hf_hub_id='timm/',
  1845. mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'),
  1846. # Custom timm variants
  1847. 'vit_base_patch16_rpn_224.sw_in1k': _cfg(
  1848. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth',
  1849. hf_hub_id='timm/'),
  1850. 'vit_medium_patch16_gap_240.sw_in12k': _cfg(
  1851. hf_hub_id='timm/',
  1852. input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821),
  1853. 'vit_medium_patch16_gap_256.sw_in12k_ft_in1k': _cfg(
  1854. hf_hub_id='timm/',
  1855. input_size=(3, 256, 256), crop_pct=0.95),
  1856. 'vit_medium_patch16_gap_384.sw_in12k_ft_in1k': _cfg(
  1857. hf_hub_id='timm/',
  1858. input_size=(3, 384, 384), crop_pct=0.95, crop_mode='squash'),
  1859. 'vit_betwixt_patch16_gap_256.untrained': _cfg(
  1860. input_size=(3, 256, 256), crop_pct=0.95),
  1861. 'vit_base_patch16_gap_224.untrained': _cfg(),
  1862. # CLIP pretrained image tower and related fine-tuned weights
  1863. 'vit_base_patch32_clip_224.laion2b_ft_in12k_in1k': _cfg(
  1864. hf_hub_id='timm/',
  1865. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
  1866. 'vit_base_patch32_clip_384.laion2b_ft_in12k_in1k': _cfg(
  1867. hf_hub_id='timm/',
  1868. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
  1869. 'vit_base_patch32_clip_448.laion2b_ft_in12k_in1k': _cfg(
  1870. hf_hub_id='timm/',
  1871. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)),
  1872. 'vit_base_patch16_clip_224.laion2b_ft_in12k_in1k': _cfg(
  1873. hf_hub_id='timm/',
  1874. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
  1875. 'vit_base_patch16_clip_384.laion2b_ft_in12k_in1k': _cfg(
  1876. hf_hub_id='timm/',
  1877. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1878. crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
  1879. 'vit_large_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg(
  1880. hf_hub_id='timm/',
  1881. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
  1882. 'vit_large_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg(
  1883. hf_hub_id='timm/',
  1884. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
  1885. crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
  1886. 'vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg(
  1887. hf_hub_id='timm/',
  1888. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
  1889. 'vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg(
  1890. hf_hub_id='timm/',
  1891. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1892. crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
  1893. 'vit_base_patch32_clip_224.openai_ft_in12k_in1k': _cfg(
  1894. # hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k_in1k', # FIXME weight exists, need to push
  1895. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
  1896. 'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg(
  1897. hf_hub_id='timm/',
  1898. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1899. crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
  1900. 'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg(
  1901. hf_hub_id='timm/',
  1902. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
  1903. 'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg(
  1904. hf_hub_id='timm/',
  1905. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1906. crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
  1907. 'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg(
  1908. hf_hub_id='timm/',
  1909. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
  1910. 'vit_large_patch14_clip_336.openai_ft_in12k_in1k': _cfg(
  1911. hf_hub_id='timm/',
  1912. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1913. crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
  1914. 'vit_base_patch32_clip_224.laion2b_ft_in1k': _cfg(
  1915. hf_hub_id='timm/',
  1916. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
  1917. 'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg(
  1918. hf_hub_id='timm/',
  1919. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
  1920. 'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg(
  1921. hf_hub_id='timm/',
  1922. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1923. crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
  1924. 'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg(
  1925. hf_hub_id='timm/',
  1926. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
  1927. 'vit_large_patch14_clip_336.laion2b_ft_in1k': _cfg(
  1928. hf_hub_id='timm/',
  1929. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
  1930. crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
  1931. 'vit_huge_patch14_clip_224.laion2b_ft_in1k': _cfg(
  1932. hf_hub_id='timm/',
  1933. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
  1934. 'vit_huge_patch14_clip_336.laion2b_ft_in1k': _cfg(
  1935. hf_hub_id='',
  1936. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1937. crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
  1938. 'vit_base_patch32_clip_224.openai_ft_in1k': _cfg(
  1939. hf_hub_id='timm/',
  1940. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
  1941. 'vit_base_patch16_clip_224.openai_ft_in1k': _cfg(
  1942. hf_hub_id='timm/',
  1943. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
  1944. 'vit_base_patch16_clip_384.openai_ft_in1k': _cfg(
  1945. hf_hub_id='timm/',
  1946. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1947. crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
  1948. 'vit_large_patch14_clip_224.openai_ft_in1k': _cfg(
  1949. hf_hub_id='timm/',
  1950. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
  1951. 'vit_base_patch16_clip_224.laion2b_ft_in12k': _cfg(
  1952. hf_hub_id='timm/',
  1953. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
  1954. 'vit_large_patch14_clip_224.laion2b_ft_in12k': _cfg(
  1955. hf_hub_id='timm/',
  1956. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=11821),
  1957. 'vit_huge_patch14_clip_224.laion2b_ft_in12k': _cfg(
  1958. hf_hub_id='timm/',
  1959. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821),
  1960. 'vit_base_patch16_clip_224.openai_ft_in12k': _cfg(
  1961. hf_hub_id='timm/',
  1962. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
  1963. 'vit_large_patch14_clip_224.openai_ft_in12k': _cfg(
  1964. hf_hub_id='timm/',
  1965. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821),
  1966. 'vit_base_patch32_clip_224.laion2b': _cfg(
  1967. hf_hub_id='timm/',
  1968. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
  1969. 'vit_base_patch16_clip_224.laion2b': _cfg(
  1970. hf_hub_id='timm/',
  1971. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
  1972. 'vit_large_patch14_clip_224.laion2b': _cfg(
  1973. hf_hub_id='timm/',
  1974. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768),
  1975. 'vit_huge_patch14_clip_224.laion2b': _cfg(
  1976. hf_hub_id='timm/',
  1977. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
  1978. 'vit_giant_patch14_clip_224.laion2b': _cfg(
  1979. hf_hub_id='timm/',
  1980. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
  1981. 'vit_gigantic_patch14_clip_224.laion2b': _cfg(
  1982. hf_hub_id='timm/',
  1983. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),
  1984. 'vit_base_patch32_clip_224.laion400m_e32': _cfg(
  1985. hf_hub_id='timm/',
  1986. license='mit',
  1987. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  1988. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
  1989. 'vit_base_patch16_clip_224.laion400m_e32': _cfg(
  1990. hf_hub_id='timm/',
  1991. license='mit',
  1992. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
  1993. 'vit_base_patch16_plus_clip_240.laion400m_e32': _cfg(
  1994. hf_hub_id='timm/',
  1995. license='mit',
  1996. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1997. input_size=(3, 240, 240), crop_pct=1.0, num_classes=640),
  1998. 'vit_large_patch14_clip_224.laion400m_e32': _cfg(
  1999. hf_hub_id='timm/',
  2000. license='mit',
  2001. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
  2002. 'vit_base_patch32_clip_224.datacompxl': _cfg(
  2003. hf_hub_id='timm/',
  2004. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
  2005. 'vit_base_patch32_clip_256.datacompxl': _cfg(
  2006. hf_hub_id='timm/',
  2007. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  2008. crop_pct=1.0, input_size=(3, 256, 256), num_classes=512),
  2009. 'vit_base_patch16_clip_224.datacompxl': _cfg(
  2010. hf_hub_id='timm/',
  2011. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
  2012. 'vit_large_patch14_clip_224.datacompxl': _cfg(
  2013. hf_hub_id='timm/',
  2014. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
  2015. 'vit_base_patch16_clip_224.dfn2b': _cfg(
  2016. hf_hub_id='timm/',
  2017. license='apple-ascl',
  2018. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
  2019. 'vit_large_patch14_clip_224.dfn2b_s39b': _cfg(
  2020. hf_hub_id='timm/',
  2021. license='apple-ascl',
  2022. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
  2023. 'vit_large_patch14_clip_224.dfn2b': _cfg(
  2024. hf_hub_id='timm/',
  2025. license='apple-ascl',
  2026. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  2027. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
  2028. 'vit_huge_patch14_clip_224.dfn5b': _cfg(
  2029. hf_hub_id='timm/',
  2030. license='apple-ascl',
  2031. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  2032. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
  2033. 'vit_huge_patch14_clip_378.dfn5b': _cfg(
  2034. hf_hub_id='timm/',
  2035. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  2036. license='apple-ascl',
  2037. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  2038. crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024),
  2039. # 'vit_large_patch14_clip_224.metaclip2_worldwide': _cfg(
  2040. # hf_hub_id='timm/',
  2041. # license='cc-by-nc-4.0',
  2042. # notes=('natively QuickGELU, use quickgelu model variant for original results',),
  2043. # mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
  2044. 'vit_huge_patch14_clip_224.metaclip2_worldwide': _cfg(
  2045. hf_hub_id='timm/',
  2046. license='cc-by-nc-4.0',
  2047. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  2048. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
  2049. 'vit_huge_patch14_clip_378.metaclip2_worldwide': _cfg(
  2050. hf_hub_id='timm/',
  2051. license='cc-by-nc-4.0',
  2052. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  2053. input_size=(3, 378, 378), crop_pct=1.0, crop_mode='squash', num_classes=1024),
  2054. 'vit_gigantic_patch14_clip_224.metaclip2_worldwide': _cfg(
  2055. hf_hub_id='timm/',
  2056. license='cc-by-nc-4.0',
  2057. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),
  2058. 'vit_gigantic_patch14_clip_378.metaclip2_worldwide': _cfg(
  2059. hf_hub_id='timm/',
  2060. license='cc-by-nc-4.0',
  2061. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  2062. input_size=(3, 378, 378), crop_pct=1.0, crop_mode='squash', num_classes=1280),
  2063. 'vit_base_patch32_clip_224.metaclip_2pt5b': _cfg(
  2064. hf_hub_id='timm/',
  2065. license='cc-by-nc-4.0',
  2066. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  2067. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
  2068. 'vit_base_patch16_clip_224.metaclip_2pt5b': _cfg(
  2069. hf_hub_id='timm/',
  2070. license='cc-by-nc-4.0',
  2071. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  2072. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
  2073. 'vit_large_patch14_clip_224.metaclip_2pt5b': _cfg(
  2074. hf_hub_id='timm/',
  2075. license='cc-by-nc-4.0',
  2076. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  2077. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
  2078. 'vit_huge_patch14_clip_224.metaclip_2pt5b': _cfg(
  2079. hf_hub_id='timm/',
  2080. license='cc-by-nc-4.0',
  2081. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  2082. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
  2083. 'vit_huge_patch14_clip_224.metaclip_altogether': _cfg(
  2084. hf_hub_id='timm/',
  2085. license='cc-by-nc-4.0',
  2086. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
  2087. 'vit_gigantic_patch14_clip_224.metaclip_2pt5b': _cfg(
  2088. hf_hub_id='timm/',
  2089. license='cc-by-nc-4.0',
  2090. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  2091. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),
  2092. 'vit_base_patch32_clip_224.metaclip_400m': _cfg(
  2093. hf_hub_id='timm/',
  2094. license='cc-by-nc-4.0',
  2095. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  2096. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
  2097. 'vit_base_patch16_clip_224.metaclip_400m': _cfg(
  2098. hf_hub_id='timm/',
  2099. license='cc-by-nc-4.0',
  2100. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  2101. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
  2102. 'vit_large_patch14_clip_224.metaclip_400m': _cfg(
  2103. hf_hub_id='timm/',
  2104. license='cc-by-nc-4.0',
  2105. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  2106. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
  2107. 'vit_base_patch32_clip_224.openai': _cfg(
  2108. hf_hub_id='timm/',
  2109. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  2110. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
  2111. 'vit_base_patch16_clip_224.openai': _cfg(
  2112. hf_hub_id='timm/',
  2113. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  2114. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
  2115. 'vit_large_patch14_clip_224.openai': _cfg(
  2116. hf_hub_id='timm/',
  2117. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  2118. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
  2119. 'vit_large_patch14_clip_336.openai': _cfg(
  2120. hf_hub_id='timm/',
  2121. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  2122. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  2123. crop_pct=1.0, input_size=(3, 336, 336), num_classes=768),
  2124. 'vit_large_patch14_clip_224.apple_mclip2_dfndr2b': _cfg(
  2125. hf_hub_id='timm/',
  2126. num_classes=768,
  2127. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0,
  2128. license='apple-amlr'
  2129. ),
  2130. # experimental (may be removed)
  2131. 'vit_base_patch32_plus_256.untrained': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
  2132. 'vit_base_patch16_plus_240.untrained': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95),
  2133. 'vit_small_patch16_36x1_224.untrained': _cfg(url=''),
  2134. 'vit_small_patch16_18x2_224.untrained': _cfg(url=''),
  2135. 'vit_base_patch16_18x2_224.untrained': _cfg(url=''),
  2136. # EVA fine-tuned weights from MAE style MIM - EVA-CLIP target pretrain
  2137. # https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip
  2138. 'eva_large_patch14_196.in22k_ft_in22k_in1k': _cfg(
  2139. # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt',
  2140. hf_hub_id='timm/', license='mit',
  2141. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  2142. input_size=(3, 196, 196), crop_pct=1.0),
  2143. 'eva_large_patch14_336.in22k_ft_in22k_in1k': _cfg(
  2144. # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt',
  2145. hf_hub_id='timm/', license='mit',
  2146. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  2147. input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
  2148. 'eva_large_patch14_196.in22k_ft_in1k': _cfg(
  2149. # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt',
  2150. hf_hub_id='timm/', license='mit',
  2151. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  2152. input_size=(3, 196, 196), crop_pct=1.0),
  2153. 'eva_large_patch14_336.in22k_ft_in1k': _cfg(
  2154. # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt',
  2155. hf_hub_id='timm/', license='mit',
  2156. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  2157. input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
  2158. 'flexivit_small.1200ep_in1k': _cfg(
  2159. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k.npz', custom_load=True,
  2160. hf_hub_id='timm/',
  2161. input_size=(3, 240, 240), crop_pct=0.95),
  2162. 'flexivit_small.600ep_in1k': _cfg(
  2163. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_600ep.npz', custom_load=True,
  2164. hf_hub_id='timm/',
  2165. input_size=(3, 240, 240), crop_pct=0.95),
  2166. 'flexivit_small.300ep_in1k': _cfg(
  2167. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_300ep.npz', custom_load=True,
  2168. hf_hub_id='timm/',
  2169. input_size=(3, 240, 240), crop_pct=0.95),
  2170. 'flexivit_base.1200ep_in1k': _cfg(
  2171. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k.npz', custom_load=True,
  2172. hf_hub_id='timm/',
  2173. input_size=(3, 240, 240), crop_pct=0.95),
  2174. 'flexivit_base.600ep_in1k': _cfg(
  2175. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_600ep.npz', custom_load=True,
  2176. hf_hub_id='timm/',
  2177. input_size=(3, 240, 240), crop_pct=0.95),
  2178. 'flexivit_base.300ep_in1k': _cfg(
  2179. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_300ep.npz', custom_load=True,
  2180. hf_hub_id='timm/',
  2181. input_size=(3, 240, 240), crop_pct=0.95),
  2182. 'flexivit_base.1000ep_in21k': _cfg(
  2183. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_1000ep.npz', custom_load=True,
  2184. hf_hub_id='timm/',
  2185. input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
  2186. 'flexivit_base.300ep_in21k': _cfg(
  2187. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_300ep.npz', custom_load=True,
  2188. hf_hub_id='timm/',
  2189. input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
  2190. 'flexivit_large.1200ep_in1k': _cfg(
  2191. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k.npz', custom_load=True,
  2192. hf_hub_id='timm/',
  2193. input_size=(3, 240, 240), crop_pct=0.95),
  2194. 'flexivit_large.600ep_in1k': _cfg(
  2195. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_600ep.npz', custom_load=True,
  2196. hf_hub_id='timm/',
  2197. input_size=(3, 240, 240), crop_pct=0.95),
  2198. 'flexivit_large.300ep_in1k': _cfg(
  2199. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_300ep.npz', custom_load=True,
  2200. hf_hub_id='timm/',
  2201. input_size=(3, 240, 240), crop_pct=0.95),
  2202. 'flexivit_base.patch16_in21k': _cfg(
  2203. url='https://storage.googleapis.com/big_vision/flexivit/vit_b16_i21k_300ep.npz', custom_load=True,
  2204. hf_hub_id='timm/',
  2205. input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
  2206. 'flexivit_base.patch30_in21k': _cfg(
  2207. url='https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz', custom_load=True,
  2208. hf_hub_id='timm/',
  2209. input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
  2210. 'vit_base_patch16_xp_224.untrained': _cfg(url=''),
  2211. 'vit_large_patch14_xp_224.untrained': _cfg(url=''),
  2212. 'vit_huge_patch14_xp_224.untrained': _cfg(url=''),
  2213. 'vit_base_patch16_224.mae': _cfg(
  2214. url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth',
  2215. hf_hub_id='timm/',
  2216. license='cc-by-nc-4.0',
  2217. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  2218. 'vit_large_patch16_224.mae': _cfg(
  2219. url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_large.pth',
  2220. hf_hub_id='timm/',
  2221. license='cc-by-nc-4.0',
  2222. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  2223. 'vit_huge_patch14_224.mae': _cfg(
  2224. url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_huge.pth',
  2225. hf_hub_id='timm/',
  2226. license='cc-by-nc-4.0',
  2227. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  2228. 'vit_huge_patch14_gap_224.in1k_ijepa': _cfg(
  2229. url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar',
  2230. # hf_hub_id='timm/',
  2231. license='cc-by-nc-4.0',
  2232. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  2233. 'vit_huge_patch14_gap_224.in22k_ijepa': _cfg(
  2234. url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar',
  2235. # hf_hub_id='timm/',
  2236. license='cc-by-nc-4.0',
  2237. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  2238. 'vit_huge_patch16_gap_448.in1k_ijepa': _cfg(
  2239. url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar',
  2240. # hf_hub_id='timm/',
  2241. license='cc-by-nc-4.0',
  2242. input_size=(3, 448, 448), crop_pct=1.0,
  2243. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  2244. 'vit_giant_patch16_gap_224.in22k_ijepa': _cfg(
  2245. url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar',
  2246. # hf_hub_id='timm/',
  2247. license='cc-by-nc-4.0',
  2248. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  2249. 'vit_base_patch32_siglip_256.v2_webli': _cfg(
  2250. hf_hub_id='timm/',
  2251. input_size=(3, 256, 256),
  2252. num_classes=0),
  2253. 'vit_base_patch16_siglip_224.v2_webli': _cfg(
  2254. hf_hub_id='timm/',
  2255. num_classes=0),
  2256. 'vit_base_patch16_siglip_224.webli': _cfg(
  2257. hf_hub_id='timm/',
  2258. num_classes=0),
  2259. 'vit_base_patch16_siglip_256.v2_webli': _cfg(
  2260. hf_hub_id='timm/',
  2261. input_size=(3, 256, 256),
  2262. num_classes=0),
  2263. 'vit_base_patch16_siglip_256.webli': _cfg(
  2264. hf_hub_id='timm/',
  2265. input_size=(3, 256, 256),
  2266. num_classes=0),
  2267. 'vit_base_patch16_siglip_256.webli_i18n': _cfg(
  2268. hf_hub_id='timm/',
  2269. input_size=(3, 256, 256),
  2270. num_classes=0),
  2271. 'vit_base_patch16_siglip_384.v2_webli': _cfg(
  2272. hf_hub_id='timm/',
  2273. input_size=(3, 384, 384),
  2274. num_classes=0),
  2275. 'vit_base_patch16_siglip_384.webli': _cfg(
  2276. hf_hub_id='timm/',
  2277. input_size=(3, 384, 384),
  2278. num_classes=0),
  2279. 'vit_base_patch16_siglip_512.v2_webli': _cfg(
  2280. hf_hub_id='timm/',
  2281. input_size=(3, 512, 512),
  2282. num_classes=0),
  2283. 'vit_base_patch16_siglip_512.webli': _cfg(
  2284. hf_hub_id='timm/',
  2285. input_size=(3, 512, 512),
  2286. num_classes=0),
  2287. 'vit_large_patch16_siglip_256.v2_webli': _cfg(
  2288. hf_hub_id='timm/',
  2289. input_size=(3, 256, 256),
  2290. num_classes=0),
  2291. 'vit_large_patch16_siglip_256.webli': _cfg(
  2292. hf_hub_id='timm/',
  2293. input_size=(3, 256, 256),
  2294. num_classes=0),
  2295. 'vit_large_patch16_siglip_384.v2_webli': _cfg(
  2296. hf_hub_id='timm/',
  2297. input_size=(3, 384, 384),
  2298. num_classes=0),
  2299. 'vit_large_patch16_siglip_384.webli': _cfg(
  2300. hf_hub_id='timm/',
  2301. input_size=(3, 384, 384),
  2302. num_classes=0),
  2303. 'vit_large_patch16_siglip_512.v2_webli': _cfg(
  2304. hf_hub_id='timm/',
  2305. input_size=(3, 512, 512),
  2306. num_classes=0),
  2307. 'vit_so400m_patch14_siglip_224.v2_webli': _cfg(
  2308. hf_hub_id='timm/',
  2309. num_classes=0),
  2310. 'vit_so400m_patch14_siglip_224.webli': _cfg(
  2311. hf_hub_id='timm/',
  2312. num_classes=0),
  2313. 'vit_so400m_patch14_siglip_378.v2_webli': _cfg(
  2314. hf_hub_id='timm/',
  2315. input_size=(3, 378, 378),
  2316. num_classes=0),
  2317. 'vit_so400m_patch14_siglip_378.webli': _cfg(
  2318. hf_hub_id='timm/',
  2319. input_size=(3, 378, 378),
  2320. num_classes=0),
  2321. 'vit_so400m_patch14_siglip_384.webli': _cfg(
  2322. hf_hub_id='timm/',
  2323. input_size=(3, 384, 384),
  2324. num_classes=0),
  2325. 'vit_so400m_patch16_siglip_256.v2_webli': _cfg(
  2326. hf_hub_id='timm/',
  2327. input_size=(3, 256, 256),
  2328. num_classes=0),
  2329. 'vit_so400m_patch16_siglip_256.webli_i18n': _cfg(
  2330. hf_hub_id='timm/',
  2331. input_size=(3, 256, 256),
  2332. num_classes=0),
  2333. 'vit_so400m_patch16_siglip_384.v2_webli': _cfg(
  2334. hf_hub_id='timm/',
  2335. input_size=(3, 384, 384),
  2336. num_classes=0),
  2337. 'vit_so400m_patch16_siglip_512.v2_webli': _cfg(
  2338. hf_hub_id='timm/',
  2339. input_size=(3, 512, 512),
  2340. num_classes=0),
  2341. 'vit_giantopt_patch16_siglip_256.v2_webli': _cfg(
  2342. hf_hub_id='timm/',
  2343. input_size=(3, 256, 256),
  2344. num_classes=0),
  2345. 'vit_giantopt_patch16_siglip_384.v2_webli': _cfg(
  2346. hf_hub_id='timm/',
  2347. input_size=(3, 384, 384),
  2348. num_classes=0),
  2349. 'vit_base_patch32_siglip_gap_256.v2_webli': _cfg(
  2350. hf_hub_id='timm/',
  2351. input_size=(3, 256, 256),
  2352. num_classes=0),
  2353. 'vit_base_patch16_siglip_gap_224.v2_webli': _cfg(
  2354. hf_hub_id='timm/',
  2355. num_classes=0),
  2356. 'vit_base_patch16_siglip_gap_224.webli': _cfg(
  2357. hf_hub_id='timm/',
  2358. num_classes=0),
  2359. 'vit_base_patch16_siglip_gap_256.v2_webli': _cfg(
  2360. hf_hub_id='timm/',
  2361. input_size=(3, 256, 256),
  2362. num_classes=0),
  2363. 'vit_base_patch16_siglip_gap_256.webli': _cfg(
  2364. hf_hub_id='timm/',
  2365. input_size=(3, 256, 256),
  2366. num_classes=0),
  2367. 'vit_base_patch16_siglip_gap_256.webli_i18n': _cfg(
  2368. hf_hub_id='timm/',
  2369. input_size=(3, 256, 256),
  2370. num_classes=0),
  2371. 'vit_base_patch16_siglip_gap_384.v2_webli': _cfg(
  2372. hf_hub_id='timm/',
  2373. input_size=(3, 384, 384),
  2374. num_classes=0),
  2375. 'vit_base_patch16_siglip_gap_384.webli': _cfg(
  2376. hf_hub_id='timm/',
  2377. input_size=(3, 384, 384),
  2378. num_classes=0),
  2379. 'vit_base_patch16_siglip_gap_512.v2_webli': _cfg(
  2380. hf_hub_id='timm/',
  2381. input_size=(3, 512, 512),
  2382. num_classes=0),
  2383. 'vit_base_patch16_siglip_gap_512.webli': _cfg(
  2384. hf_hub_id='timm/',
  2385. input_size=(3, 512, 512),
  2386. num_classes=0),
  2387. 'vit_large_patch16_siglip_gap_256.v2_webli': _cfg(
  2388. hf_hub_id='timm/',
  2389. input_size=(3, 256, 256),
  2390. num_classes=0),
  2391. 'vit_large_patch16_siglip_gap_256.webli': _cfg(
  2392. hf_hub_id='timm/',
  2393. input_size=(3, 256, 256),
  2394. num_classes=0),
  2395. 'vit_large_patch16_siglip_gap_384.v2_webli': _cfg(
  2396. hf_hub_id='timm/',
  2397. input_size=(3, 384, 384),
  2398. num_classes=0),
  2399. 'vit_large_patch16_siglip_gap_384.webli': _cfg(
  2400. hf_hub_id='timm/',
  2401. input_size=(3, 384, 384),
  2402. num_classes=0),
  2403. 'vit_large_patch16_siglip_gap_512.v2_webli': _cfg(
  2404. hf_hub_id='timm/',
  2405. input_size=(3, 512, 512),
  2406. num_classes=0),
  2407. 'vit_so400m_patch14_siglip_gap_224.v2_webli': _cfg(
  2408. hf_hub_id='timm/',
  2409. num_classes=0),
  2410. 'vit_so400m_patch14_siglip_gap_224.webli': _cfg(
  2411. hf_hub_id='timm/',
  2412. num_classes=0),
  2413. 'vit_so400m_patch14_siglip_gap_224.pali_mix': _cfg(
  2414. hf_hub_id='timm/',
  2415. num_classes=0),
  2416. 'vit_so400m_patch14_siglip_gap_224.pali_pt': _cfg(
  2417. hf_hub_id='timm/',
  2418. num_classes=0),
  2419. 'vit_so400m_patch14_siglip_gap_224.pali2_3b_pt': _cfg(
  2420. hf_hub_id='timm/',
  2421. num_classes=0),
  2422. 'vit_so400m_patch14_siglip_gap_224.pali2_10b_pt': _cfg(
  2423. hf_hub_id='timm/',
  2424. num_classes=0),
  2425. # 'vit_so400m_patch14_siglip_gap_224.pali2_28b_pt': _cfg(
  2426. # hf_hub_id='google/paligemma2-28b-pt-224-jax',
  2427. # hf_hub_filename='pt_27b_224.npz',
  2428. # custom_load='hf',
  2429. # num_classes=0),
  2430. 'vit_so400m_patch14_siglip_gap_378.v2_webli': _cfg(
  2431. hf_hub_id='timm/',
  2432. input_size=(3, 378, 378),
  2433. num_classes=0),
  2434. 'vit_so400m_patch14_siglip_gap_378.webli': _cfg(
  2435. hf_hub_id='timm/',
  2436. input_size=(3, 378, 378), crop_pct=1.0,
  2437. num_classes=0),
  2438. 'vit_so400m_patch14_siglip_gap_384.webli': _cfg(
  2439. hf_hub_id='timm/',
  2440. input_size=(3, 384, 384), crop_pct=1.0,
  2441. num_classes=0),
  2442. 'vit_so400m_patch14_siglip_gap_448.pali_mix': _cfg(
  2443. hf_hub_id='timm/',
  2444. input_size=(3, 448, 448), crop_pct=1.0,
  2445. num_classes=0),
  2446. 'vit_so400m_patch14_siglip_gap_448.pali_pt': _cfg(
  2447. hf_hub_id='timm/',
  2448. input_size=(3, 448, 448), crop_pct=1.0,
  2449. num_classes=0),
  2450. 'vit_so400m_patch14_siglip_gap_448.pali_refcoco_seg': _cfg(
  2451. hf_hub_id='timm/',
  2452. input_size=(3, 448, 448), crop_pct=1.0,
  2453. num_classes=0),
  2454. 'vit_so400m_patch14_siglip_gap_448.pali_ocrvqa': _cfg(
  2455. hf_hub_id='timm/',
  2456. input_size=(3, 448, 448), crop_pct=1.0,
  2457. num_classes=0),
  2458. 'vit_so400m_patch14_siglip_gap_448.pali2_3b_pt': _cfg(
  2459. hf_hub_id='timm/',
  2460. input_size=(3, 448, 448), crop_pct=1.0,
  2461. num_classes=0),
  2462. 'vit_so400m_patch14_siglip_gap_448.pali2_10b_pt': _cfg(
  2463. hf_hub_id='timm/',
  2464. input_size=(3, 448, 448), crop_pct=1.0,
  2465. num_classes=0),
  2466. # 'vit_so400m_patch14_siglip_gap_448.pali2_28b_pt': _cfg(
  2467. # hf_hub_id='google/paligemma2-28b-pt-448-jax',
  2468. # hf_hub_filename='pt_27b_448.npz',
  2469. # custom_load='hf',
  2470. # input_size=(3, 448, 448), crop_pct=1.0,
  2471. # num_classes=0),
  2472. 'vit_so400m_patch14_siglip_gap_448.pali2_3b_docci': _cfg(
  2473. hf_hub_id='timm/',
  2474. input_size=(3, 448, 448), crop_pct=1.0,
  2475. num_classes=0),
  2476. 'vit_so400m_patch14_siglip_gap_448.pali2_10b_docci': _cfg(
  2477. hf_hub_id='timm/',
  2478. input_size=(3, 448, 448), crop_pct=1.0,
  2479. num_classes=0),
  2480. 'vit_so400m_patch14_siglip_gap_896.pali_pt': _cfg(
  2481. hf_hub_id='timm/',
  2482. input_size=(3, 896, 896), crop_pct=1.0,
  2483. num_classes=0),
  2484. 'vit_so400m_patch14_siglip_gap_896.pali_refcoco_seg': _cfg(
  2485. hf_hub_id='timm/',
  2486. input_size=(3, 896, 896), crop_pct=1.0,
  2487. num_classes=0),
  2488. 'vit_so400m_patch14_siglip_gap_896.pali_ocrvqa': _cfg(
  2489. hf_hub_id='timm/',
  2490. input_size=(3, 896, 896), crop_pct=1.0,
  2491. num_classes=0),
  2492. 'vit_so400m_patch14_siglip_gap_896.pali2_3b_pt': _cfg(
  2493. hf_hub_id='timm/',
  2494. input_size=(3, 896, 896), crop_pct=1.0,
  2495. num_classes=0),
  2496. 'vit_so400m_patch14_siglip_gap_896.pali2_10b_pt': _cfg(
  2497. hf_hub_id='timm/',
  2498. input_size=(3, 896, 896), crop_pct=1.0,
  2499. num_classes=0),
  2500. # 'vit_so400m_patch14_siglip_gap_896.pali2_28b_pt': _cfg(
  2501. # hf_hub_id='google/paligemma2-28b-pt-896-jax',
  2502. # hf_hub_filename='pt_27b_896.npz',
  2503. # custom_load='hf',
  2504. # input_size=(3, 896, 896), crop_pct=1.0,
  2505. # num_classes=0),
  2506. 'vit_so400m_patch16_siglip_gap_256.v2_webli': _cfg(
  2507. hf_hub_id='timm/',
  2508. input_size=(3, 256, 256),
  2509. num_classes=0),
  2510. 'vit_so400m_patch16_siglip_gap_256.webli_i18n': _cfg(
  2511. hf_hub_id='timm/',
  2512. input_size=(3, 256, 256),
  2513. num_classes=0),
  2514. 'vit_so400m_patch16_siglip_gap_384.v2_webli': _cfg(
  2515. hf_hub_id='timm/',
  2516. input_size=(3, 384, 384),
  2517. num_classes=0),
  2518. 'vit_so400m_patch16_siglip_gap_512.v2_webli': _cfg(
  2519. hf_hub_id='timm/',
  2520. input_size=(3, 512, 512),
  2521. num_classes=0),
  2522. 'vit_giantopt_patch16_siglip_gap_256.v2_webli': _cfg(
  2523. hf_hub_id='timm/',
  2524. input_size=(3, 256, 256),
  2525. num_classes=0),
  2526. 'vit_giantopt_patch16_siglip_gap_384.v2_webli': _cfg(
  2527. hf_hub_id='timm/',
  2528. input_size=(3, 384, 384),
  2529. num_classes=0),
  2530. 'vit_so400m_patch14_siglip_378.webli_ft_in1k': _cfg(
  2531. hf_hub_id='timm/',
  2532. input_size=(3, 378, 378), crop_pct=1.0, crop_mode='squash',
  2533. ),
  2534. 'vit_so400m_patch14_siglip_gap_378.webli_ft_in1k': _cfg(
  2535. hf_hub_id='timm/',
  2536. input_size=(3, 378, 378), crop_pct=1.0, crop_mode='squash',
  2537. ),
  2538. 'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg(
  2539. hf_hub_id='timm/',
  2540. license='mit',
  2541. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
  2542. 'vit_medium_patch32_clip_224.tinyclip_laion400m': _cfg(
  2543. hf_hub_id='timm/',
  2544. license='mit',
  2545. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
  2546. 'vit_medium_patch16_clip_224.tinyclip_yfcc15m': _cfg(
  2547. hf_hub_id='timm/',
  2548. license='mit',
  2549. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
  2550. 'vit_betwixt_patch32_clip_224.tinyclip_laion400m': _cfg(
  2551. hf_hub_id='timm/',
  2552. license='mit',
  2553. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
  2554. 'vit_wee_patch16_reg1_gap_256.sbb_in1k': _cfg(
  2555. hf_hub_id='timm/',
  2556. input_size=(3, 256, 256), crop_pct=0.95),
  2557. 'vit_dwee_patch16_reg1_gap_256.sbb_nadamuon_in1k': _cfg(
  2558. hf_hub_id='timm/',
  2559. input_size=(3, 256, 256), crop_pct=0.95),
  2560. 'vit_dwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
  2561. hf_hub_id='timm/',
  2562. input_size=(3, 256, 256), crop_pct=0.95),
  2563. 'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
  2564. hf_hub_id='timm/',
  2565. input_size=(3, 256, 256), crop_pct=0.95),
  2566. 'vit_dpwee_patch16_reg1_gap_256.sbb_nadamuon_in1k': _cfg(
  2567. hf_hub_id='timm/',
  2568. input_size=(3, 256, 256), crop_pct=0.95),
  2569. 'vit_dpwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
  2570. hf_hub_id='timm/',
  2571. input_size=(3, 256, 256), crop_pct=0.95),
  2572. 'vit_little_patch16_reg1_gap_256.sbb_in12k_ft_in1k': _cfg(
  2573. hf_hub_id='timm/',
  2574. input_size=(3, 256, 256), crop_pct=0.95),
  2575. 'vit_little_patch16_reg1_gap_256.sbb_in12k': _cfg(
  2576. hf_hub_id='timm/',
  2577. num_classes=11821,
  2578. input_size=(3, 256, 256), crop_pct=0.95),
  2579. 'vit_little_patch16_reg4_gap_256.sbb_in1k': _cfg(
  2580. hf_hub_id='timm/',
  2581. input_size=(3, 256, 256), crop_pct=0.95),
  2582. 'vit_dlittle_patch16_reg1_gap_256.sbb_nadamuon_in1k': _cfg(
  2583. hf_hub_id='timm/',
  2584. input_size=(3, 256, 256), crop_pct=0.95),
  2585. 'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg(
  2586. hf_hub_id='timm/',
  2587. input_size=(3, 256, 256), crop_pct=0.95),
  2588. 'vit_medium_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
  2589. hf_hub_id='timm/',
  2590. input_size=(3, 256, 256), crop_pct=0.95),
  2591. 'vit_medium_patch16_reg4_gap_256.sbb_in1k': _cfg(
  2592. hf_hub_id='timm/',
  2593. input_size=(3, 256, 256), crop_pct=0.95),
  2594. 'vit_medium_patch16_reg4_gap_256.sbb_in12k': _cfg(
  2595. hf_hub_id='timm/',
  2596. num_classes=11821,
  2597. input_size=(3, 256, 256), crop_pct=0.95),
  2598. 'vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k': _cfg(
  2599. hf_hub_id='timm/',
  2600. input_size=(3, 256, 256), crop_pct=0.95),
  2601. 'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
  2602. hf_hub_id='timm/',
  2603. input_size=(3, 256, 256), crop_pct=0.95),
  2604. 'vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k': _cfg(
  2605. hf_hub_id='timm/',
  2606. num_classes=11821,
  2607. input_size=(3, 256, 256), crop_pct=0.95),
  2608. 'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg(
  2609. hf_hub_id='timm/',
  2610. num_classes=11821,
  2611. input_size=(3, 256, 256), crop_pct=0.95),
  2612. 'vit_mediumd_patch16_reg4_gap_384.sbb2_e200_in12k_ft_in1k': _cfg(
  2613. hf_hub_id='timm/',
  2614. input_size=(3, 384, 384), crop_pct=1.0),
  2615. 'vit_betwixt_patch16_reg1_gap_256.sbb_in1k': _cfg(
  2616. hf_hub_id='timm/',
  2617. input_size=(3, 256, 256), crop_pct=0.95),
  2618. 'vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k': _cfg(
  2619. hf_hub_id='timm/',
  2620. input_size=(3, 256, 256), crop_pct=0.95),
  2621. 'vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
  2622. hf_hub_id='timm/',
  2623. input_size=(3, 256, 256), crop_pct=0.95),
  2624. 'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
  2625. hf_hub_id='timm/',
  2626. input_size=(3, 256, 256), crop_pct=0.95),
  2627. 'vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k': _cfg(
  2628. hf_hub_id='timm/',
  2629. num_classes=11821,
  2630. input_size=(3, 256, 256), crop_pct=0.95),
  2631. 'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg(
  2632. hf_hub_id='timm/',
  2633. num_classes=11821,
  2634. input_size=(3, 256, 256), crop_pct=0.95),
  2635. 'vit_betwixt_patch16_reg4_gap_384.sbb2_e200_in12k_ft_in1k': _cfg(
  2636. hf_hub_id='timm/',
  2637. input_size=(3, 384, 384), crop_pct=1.0),
  2638. 'vit_base_patch16_reg4_gap_256.untrained': _cfg(
  2639. input_size=(3, 256, 256)),
  2640. 'vit_so150m_patch16_reg4_gap_256.sbb_e250_in12k_ft_in1k': _cfg(
  2641. hf_hub_id='timm/',
  2642. input_size=(3, 256, 256), crop_pct=0.95),
  2643. 'vit_so150m_patch16_reg4_gap_256.sbb_e250_in12k': _cfg(
  2644. hf_hub_id='timm/',
  2645. num_classes=11821,
  2646. input_size=(3, 256, 256), crop_pct=0.95),
  2647. 'vit_so150m_patch16_reg4_gap_384.sbb_e250_in12k_ft_in1k': _cfg(
  2648. hf_hub_id='timm/',
  2649. input_size=(3, 384, 384), crop_pct=1.0),
  2650. 'vit_so150m_patch16_reg4_map_256.untrained': _cfg(
  2651. input_size=(3, 256, 256)),
  2652. 'vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k_ft_in1k': _cfg(
  2653. hf_hub_id='timm/',
  2654. input_size=(3, 256, 256), crop_pct=1.0),
  2655. 'vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k': _cfg(
  2656. hf_hub_id='timm/',
  2657. num_classes=11821,
  2658. input_size=(3, 256, 256), crop_pct=1.0),
  2659. 'vit_so150m2_patch16_reg1_gap_384.sbb_e200_in12k_ft_in1k': _cfg(
  2660. hf_hub_id='timm/',
  2661. input_size=(3, 384, 384), crop_pct=1.0),
  2662. 'vit_so150m2_patch16_reg1_gap_448.sbb_e200_in12k_ft_in1k': _cfg(
  2663. hf_hub_id='timm/',
  2664. input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash'),
  2665. 'vit_intern300m_patch14_448.ogvl_dist': _cfg(
  2666. hf_hub_id='timm/',
  2667. license='mit',
  2668. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  2669. input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
  2670. ),
  2671. 'vit_intern300m_patch14_448.ogvl_2pt5': _cfg(
  2672. hf_hub_id='timm/',
  2673. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  2674. input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
  2675. ),
  2676. 'aimv2_large_patch14_224.apple_pt': _cfg(
  2677. hf_hub_id='timm/',
  2678. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2679. crop_pct=1.0, num_classes=0),
  2680. 'aimv2_large_patch14_224.apple_pt_dist': _cfg(
  2681. hf_hub_id='timm/',
  2682. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2683. crop_pct=1.0, num_classes=0),
  2684. 'aimv2_huge_patch14_224.apple_pt': _cfg(
  2685. hf_hub_id='timm/',
  2686. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2687. crop_pct=1.0, num_classes=0),
  2688. 'aimv2_1b_patch14_224.apple_pt': _cfg(
  2689. hf_hub_id='timm/',
  2690. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2691. crop_pct=1.0, num_classes=0),
  2692. 'aimv2_3b_patch14_224.apple_pt': _cfg(
  2693. hf_hub_id='timm/',
  2694. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2695. crop_pct=1.0, num_classes=0),
  2696. 'aimv2_large_patch14_336.apple_pt': _cfg(
  2697. hf_hub_id='timm/',
  2698. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2699. input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
  2700. 'aimv2_large_patch14_336.apple_pt_dist': _cfg(
  2701. hf_hub_id='timm/',
  2702. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2703. input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
  2704. 'aimv2_huge_patch14_336.apple_pt': _cfg(
  2705. hf_hub_id='timm/',
  2706. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2707. input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
  2708. 'aimv2_1b_patch14_336.apple_pt': _cfg(
  2709. hf_hub_id='timm/',
  2710. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2711. input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
  2712. 'aimv2_3b_patch14_336.apple_pt': _cfg(
  2713. hf_hub_id='timm/',
  2714. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2715. input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
  2716. 'aimv2_large_patch14_448.apple_pt': _cfg(
  2717. hf_hub_id='timm/',
  2718. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2719. input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
  2720. 'aimv2_huge_patch14_448.apple_pt': _cfg(
  2721. hf_hub_id='timm/',
  2722. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2723. input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
  2724. 'aimv2_1b_patch14_448.apple_pt': _cfg(
  2725. hf_hub_id='timm/',
  2726. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2727. input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
  2728. 'aimv2_3b_patch14_448.apple_pt': _cfg(
  2729. hf_hub_id='timm/',
  2730. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2731. input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
  2732. 'test_vit.r160_in1k': _cfg(
  2733. hf_hub_id='timm/',
  2734. input_size=(3, 160, 160), crop_pct=0.95),
  2735. 'test_vit2.r160_in1k': _cfg(
  2736. hf_hub_id='timm/',
  2737. input_size=(3, 160, 160), crop_pct=0.95),
  2738. 'test_vit3.r160_in1k': _cfg(
  2739. hf_hub_id='timm/',
  2740. input_size=(3, 160, 160), crop_pct=0.95),
  2741. 'test_vit4.r160_in1k': _cfg(
  2742. input_size=(3, 160, 160), crop_pct=0.95),
  2743. # BEiT3 models (remapped to VisionTransformer with scale_attn_norm=True, scale_mlp_norm=True)
  2744. 'beit3_base_patch16_224.in22k_ft_in1k': _cfg(
  2745. hf_hub_id='timm/',
  2746. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
  2747. 'beit3_base_patch16_224.indomain_in22k_ft_in1k': _cfg(
  2748. hf_hub_id='timm/',
  2749. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
  2750. 'beit3_large_patch16_224.in22k_ft_in1k': _cfg(
  2751. hf_hub_id='timm/',
  2752. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
  2753. 'beit3_large_patch16_224.indomain_in22k_ft_in1k': _cfg(
  2754. hf_hub_id='timm/',
  2755. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
  2756. 'beit3_giant_patch14_224.untrained': _cfg(
  2757. url='', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
  2758. 'beit3_giant_patch14_336.untrained': _cfg(
  2759. url='', input_size=(3, 336, 336), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
  2760. 'beit3_base_patch16_224.pt': _cfg(
  2761. hf_hub_id='timm/',
  2762. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0,
  2763. num_classes=0,
  2764. ),
  2765. 'beit3_base_patch16_224.indomain_pt': _cfg(
  2766. hf_hub_id='timm/',
  2767. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0,
  2768. num_classes=0,
  2769. ),
  2770. 'beit3_large_patch16_224.pt': _cfg(
  2771. hf_hub_id='timm/',
  2772. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0,
  2773. num_classes=0,
  2774. ),
  2775. 'beit3_large_patch16_224.indomain_pt': _cfg(
  2776. hf_hub_id='timm/',
  2777. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0,
  2778. num_classes=0,
  2779. ),
  2780. }
  2781. _quick_gelu_cfgs = [n for n, c in default_cfgs.items() if c.get('notes', ()) and 'quickgelu' in c['notes'][0]]
  2782. for n in _quick_gelu_cfgs:
  2783. # generate quickgelu default cfgs based on contents of notes field
  2784. c = copy.deepcopy(default_cfgs[n])
  2785. if c['hf_hub_id'] == 'timm/':
  2786. c['hf_hub_id'] = 'timm/' + n # need to use non-quickgelu model name for hub id
  2787. default_cfgs[n.replace('_clip_', '_clip_quickgelu_')] = c
  2788. default_cfgs = generate_default_cfgs(default_cfgs)
  2789. # Global flag to use NaFlexVit instead of VisionTransformer
  2790. _USE_NAFLEX_DEFAULT = os.environ.get('TIMM_USE_NAFLEXVIT', 'false').lower() == 'true'
  2791. def _create_vision_transformer(
  2792. variant: str,
  2793. pretrained: bool = False,
  2794. use_naflex: Optional[bool] = None,
  2795. **kwargs,
  2796. ) -> Union[VisionTransformer, 'NaFlexVit']:
  2797. # Check if we should use NaFlexVit instead
  2798. if use_naflex is None:
  2799. use_naflex = _USE_NAFLEX_DEFAULT
  2800. if use_naflex:
  2801. # Import here to avoid circular imports
  2802. from .naflexvit import _create_naflexvit_from_classic
  2803. return _create_naflexvit_from_classic(variant, pretrained, **kwargs)
  2804. out_indices = kwargs.pop('out_indices', 3)
  2805. if 'flexi' in variant:
  2806. # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
  2807. # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.
  2808. _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False)
  2809. else:
  2810. _filter_fn = checkpoint_filter_fn
  2811. # FIXME attn pool (currently only in siglip) params removed if pool disabled, is there a better soln?
  2812. strict = kwargs.pop('pretrained_strict', True)
  2813. if 'siglip' in variant and kwargs.get('global_pool', None) != 'map':
  2814. strict = False
  2815. return build_model_with_cfg(
  2816. VisionTransformer,
  2817. variant,
  2818. pretrained,
  2819. pretrained_filter_fn=_filter_fn,
  2820. pretrained_strict=strict,
  2821. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  2822. **kwargs,
  2823. )
  2824. @register_model
  2825. def vit_tiny_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2826. """ ViT-Tiny (Vit-Ti/16)
  2827. """
  2828. model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
  2829. model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2830. return model
  2831. @register_model
  2832. def vit_tiny_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2833. """ ViT-Tiny (Vit-Ti/16) @ 384x384.
  2834. """
  2835. model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
  2836. model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  2837. return model
  2838. @register_model
  2839. def vit_small_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2840. """ ViT-Small (ViT-S/32)
  2841. """
  2842. model_args = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6)
  2843. model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2844. return model
  2845. @register_model
  2846. def vit_small_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2847. """ ViT-Small (ViT-S/32) at 384x384.
  2848. """
  2849. model_args = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6)
  2850. model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **dict(model_args, **kwargs))
  2851. return model
  2852. @register_model
  2853. def vit_small_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2854. """ ViT-Small (ViT-S/16)
  2855. """
  2856. model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
  2857. model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2858. return model
  2859. @register_model
  2860. def vit_small_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2861. """ ViT-Small (ViT-S/16)
  2862. """
  2863. model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
  2864. model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  2865. return model
  2866. @register_model
  2867. def vit_small_patch8_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2868. """ ViT-Small (ViT-S/8)
  2869. """
  2870. model_args = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6)
  2871. model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2872. return model
  2873. @register_model
  2874. def vit_base_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2875. """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
  2876. ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer.
  2877. """
  2878. model_args = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12)
  2879. model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2880. return model
  2881. @register_model
  2882. def vit_base_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2883. """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
  2884. ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
  2885. """
  2886. model_args = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12)
  2887. model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **dict(model_args, **kwargs))
  2888. return model
  2889. @register_model
  2890. def vit_base_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2891. """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
  2892. ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
  2893. """
  2894. model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
  2895. model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2896. return model
  2897. @register_model
  2898. def vit_base_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2899. """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
  2900. ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
  2901. """
  2902. model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
  2903. model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  2904. return model
  2905. @register_model
  2906. def vit_base_patch8_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2907. """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
  2908. ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
  2909. """
  2910. model_args = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12)
  2911. model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2912. return model
  2913. @register_model
  2914. def vit_large_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2915. """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
  2916. """
  2917. model_args = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16)
  2918. model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2919. return model
  2920. @register_model
  2921. def vit_large_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2922. """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
  2923. ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
  2924. """
  2925. model_args = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16)
  2926. model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **dict(model_args, **kwargs))
  2927. return model
  2928. @register_model
  2929. def vit_large_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2930. """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
  2931. ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
  2932. """
  2933. model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
  2934. model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2935. return model
  2936. @register_model
  2937. def vit_large_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2938. """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
  2939. ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
  2940. """
  2941. model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
  2942. model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  2943. return model
  2944. @register_model
  2945. def vit_large_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2946. """ ViT-Large model (ViT-L/14)
  2947. """
  2948. model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16)
  2949. model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2950. return model
  2951. @register_model
  2952. def vit_huge_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2953. """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
  2954. """
  2955. model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16)
  2956. model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2957. return model
  2958. @register_model
  2959. def vit_giant_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2960. """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
  2961. """
  2962. model_args = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16)
  2963. model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2964. return model
  2965. @register_model
  2966. def vit_gigantic_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2967. """ ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
  2968. """
  2969. model_args = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16)
  2970. model = _create_vision_transformer(
  2971. 'vit_gigantic_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2972. return model
  2973. @register_model
  2974. def vit_base_patch16_224_miil(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2975. """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
  2976. Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
  2977. """
  2978. model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False)
  2979. model = _create_vision_transformer(
  2980. 'vit_base_patch16_224_miil', pretrained=pretrained, **dict(model_args, **kwargs))
  2981. return model
  2982. @register_model
  2983. def vit_medium_patch16_gap_240(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2984. """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 240x240
  2985. """
  2986. model_args = dict(
  2987. patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
  2988. global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
  2989. model = _create_vision_transformer(
  2990. 'vit_medium_patch16_gap_240', pretrained=pretrained, **dict(model_args, **kwargs))
  2991. return model
  2992. @register_model
  2993. def vit_medium_patch16_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2994. """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 256x256
  2995. """
  2996. model_args = dict(
  2997. patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
  2998. global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
  2999. model = _create_vision_transformer(
  3000. 'vit_medium_patch16_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3001. return model
  3002. @register_model
  3003. def vit_medium_patch16_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3004. """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 384x384
  3005. """
  3006. model_args = dict(
  3007. patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
  3008. global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
  3009. model = _create_vision_transformer(
  3010. 'vit_medium_patch16_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3011. return model
  3012. @register_model
  3013. def vit_betwixt_patch16_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3014. """ ViT-Betwixt (ViT-b/16) w/o class token, w/ avg-pool @ 256x256
  3015. """
  3016. model_args = dict(
  3017. patch_size=16, embed_dim=640, depth=12, num_heads=10, class_token=False,
  3018. global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
  3019. model = _create_vision_transformer(
  3020. 'vit_betwixt_patch16_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3021. return model
  3022. @register_model
  3023. def vit_base_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3024. """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224
  3025. """
  3026. model_args = dict(
  3027. patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
  3028. model = _create_vision_transformer(
  3029. 'vit_base_patch16_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3030. return model
  3031. @register_model
  3032. def vit_huge_patch14_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3033. """ ViT-Huge model (ViT-H/14) w/ no class token, avg pool
  3034. """
  3035. model_args = dict(
  3036. patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
  3037. model = _create_vision_transformer(
  3038. 'vit_huge_patch14_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3039. return model
  3040. @register_model
  3041. def vit_huge_patch16_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3042. """ ViT-Huge model (ViT-H/16) w/ no class token, avg pool @ 448x448
  3043. """
  3044. model_args = dict(
  3045. patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
  3046. model = _create_vision_transformer(
  3047. 'vit_huge_patch16_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
  3048. return model
  3049. @register_model
  3050. def vit_giant_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3051. """ ViT-Giant (little-gg) model (ViT-g/16) w/ no class token, avg pool
  3052. """
  3053. model_args = dict(
  3054. patch_size=16, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
  3055. class_token=False, global_pool='avg', fc_norm=False)
  3056. model = _create_vision_transformer(
  3057. 'vit_giant_patch16_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3058. return model
  3059. @register_model
  3060. def vit_xsmall_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3061. # TinyCLIP 8M
  3062. model_args = dict(embed_dim=256, depth=10, num_heads=4, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  3063. model = _create_vision_transformer(
  3064. 'vit_xsmall_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3065. return model
  3066. @register_model
  3067. def vit_medium_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3068. # TinyCLIP 40M
  3069. model_args = dict(
  3070. patch_size=32, embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  3071. model = _create_vision_transformer(
  3072. 'vit_medium_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3073. return model
  3074. @register_model
  3075. def vit_medium_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3076. # TinyCLIP 39M
  3077. model_args = dict(embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  3078. model = _create_vision_transformer(
  3079. 'vit_medium_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3080. return model
  3081. @register_model
  3082. def vit_betwixt_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3083. # TinyCLIP 61M
  3084. model_args = dict(
  3085. patch_size=32, embed_dim=640, depth=12, num_heads=10, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  3086. model = _create_vision_transformer(
  3087. 'vit_betwixt_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3088. return model
  3089. @register_model
  3090. def vit_base_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3091. """ ViT-B/32 CLIP image tower @ 224x224
  3092. """
  3093. model_args = dict(
  3094. patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  3095. model = _create_vision_transformer(
  3096. 'vit_base_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3097. return model
  3098. @register_model
  3099. def vit_base_patch32_clip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3100. """ ViT-B/32 CLIP image tower @ 256x256
  3101. """
  3102. model_args = dict(
  3103. patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  3104. model = _create_vision_transformer(
  3105. 'vit_base_patch32_clip_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3106. return model
  3107. @register_model
  3108. def vit_base_patch32_clip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3109. """ ViT-B/32 CLIP image tower @ 384x384
  3110. """
  3111. model_args = dict(
  3112. patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  3113. model = _create_vision_transformer(
  3114. 'vit_base_patch32_clip_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3115. return model
  3116. @register_model
  3117. def vit_base_patch32_clip_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3118. """ ViT-B/32 CLIP image tower @ 448x448
  3119. """
  3120. model_args = dict(
  3121. patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  3122. model = _create_vision_transformer(
  3123. 'vit_base_patch32_clip_448', pretrained=pretrained, **dict(model_args, **kwargs))
  3124. return model
  3125. @register_model
  3126. def vit_base_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3127. """ ViT-B/16 CLIP image tower
  3128. """
  3129. model_args = dict(
  3130. patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  3131. model = _create_vision_transformer(
  3132. 'vit_base_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3133. return model
  3134. @register_model
  3135. def vit_base_patch16_clip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3136. """ ViT-B/16 CLIP image tower @ 384x384
  3137. """
  3138. model_args = dict(
  3139. patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  3140. model = _create_vision_transformer(
  3141. 'vit_base_patch16_clip_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3142. return model
  3143. @register_model
  3144. def vit_base_patch16_plus_clip_240(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3145. """ ViT-Base (ViT-B/16+) CLIP image tower @ 240x240
  3146. """
  3147. model_args = dict(
  3148. patch_size=16, embed_dim=896, depth=12, num_heads=14, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  3149. model = _create_vision_transformer(
  3150. 'vit_base_patch16_plus_clip_240', pretrained=pretrained, **dict(model_args, **kwargs))
  3151. return model
  3152. @register_model
  3153. def vit_large_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3154. """ ViT-Large model (ViT-L/14) CLIP image tower
  3155. """
  3156. model_args = dict(
  3157. patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  3158. model = _create_vision_transformer(
  3159. 'vit_large_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3160. return model
  3161. @register_model
  3162. def vit_large_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3163. """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336
  3164. """
  3165. model_args = dict(
  3166. patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  3167. model = _create_vision_transformer(
  3168. 'vit_large_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs))
  3169. return model
  3170. @register_model
  3171. def vit_huge_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3172. """ ViT-Huge model (ViT-H/14) CLIP image tower.
  3173. """
  3174. model_args = dict(
  3175. patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  3176. model = _create_vision_transformer(
  3177. 'vit_huge_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3178. return model
  3179. @register_model
  3180. def vit_huge_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3181. """ ViT-Huge model (ViT-H/14) CLIP image tower @ 336x336
  3182. """
  3183. model_args = dict(
  3184. patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  3185. model = _create_vision_transformer(
  3186. 'vit_huge_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs))
  3187. return model
  3188. @register_model
  3189. def vit_huge_patch14_clip_378(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3190. """ ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378
  3191. """
  3192. model_args = dict(
  3193. patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  3194. model = _create_vision_transformer(
  3195. 'vit_huge_patch14_clip_378', pretrained=pretrained, **dict(model_args, **kwargs))
  3196. return model
  3197. @register_model
  3198. def vit_giant_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3199. """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
  3200. Pretrained weights from CLIP image tower.
  3201. """
  3202. model_args = dict(
  3203. patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True,
  3204. norm_layer=partial(LayerNorm, eps=1e-5),
  3205. )
  3206. model = _create_vision_transformer(
  3207. 'vit_giant_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3208. return model
  3209. @register_model
  3210. def vit_gigantic_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3211. """ ViT-bigG model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
  3212. Pretrained weights from CLIP image tower.
  3213. """
  3214. model_args = dict(
  3215. patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True,
  3216. norm_layer=partial(LayerNorm, eps=1e-5),
  3217. )
  3218. model = _create_vision_transformer(
  3219. 'vit_gigantic_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3220. return model
  3221. @register_model
  3222. def vit_gigantic_patch14_clip_378(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3223. """ ViT-bigG model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
  3224. Pretrained weights from CLIP image tower.
  3225. """
  3226. model_args = dict(
  3227. patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True,
  3228. norm_layer=partial(LayerNorm, eps=1e-5),
  3229. )
  3230. model = _create_vision_transformer(
  3231. 'vit_gigantic_patch14_clip_378', pretrained=pretrained, **dict(model_args, **kwargs))
  3232. return model
  3233. @register_model
  3234. def vit_base_patch32_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3235. """ ViT-B/32 CLIP image tower @ 224x224
  3236. """
  3237. model_args = dict(
  3238. patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True,
  3239. norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu'
  3240. )
  3241. model = _create_vision_transformer(
  3242. 'vit_base_patch32_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3243. return model
  3244. @register_model
  3245. def vit_base_patch16_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3246. """ ViT-B/16 CLIP image tower w/ QuickGELU act
  3247. """
  3248. model_args = dict(
  3249. patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True,
  3250. norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu'
  3251. )
  3252. model = _create_vision_transformer(
  3253. 'vit_base_patch16_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3254. return model
  3255. @register_model
  3256. def vit_large_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3257. """ ViT-Large model (ViT-L/14) CLIP image tower w/ QuickGELU act
  3258. """
  3259. model_args = dict(
  3260. patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
  3261. norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu'
  3262. )
  3263. model = _create_vision_transformer(
  3264. 'vit_large_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3265. return model
  3266. @register_model
  3267. def vit_large_patch14_clip_quickgelu_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3268. """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 w/ QuickGELU act
  3269. """
  3270. model_args = dict(
  3271. patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
  3272. norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu'
  3273. )
  3274. model = _create_vision_transformer(
  3275. 'vit_large_patch14_clip_quickgelu_336', pretrained=pretrained, **dict(model_args, **kwargs))
  3276. return model
  3277. @register_model
  3278. def vit_huge_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3279. """ ViT-Huge model (ViT-H/14) CLIP image tower w/ QuickGELU act.
  3280. """
  3281. model_args = dict(
  3282. patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True,
  3283. norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu'
  3284. )
  3285. model = _create_vision_transformer(
  3286. 'vit_huge_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3287. return model
  3288. @register_model
  3289. def vit_huge_patch14_clip_quickgelu_378(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3290. """ ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378 w/ QuickGELU act
  3291. """
  3292. model_args = dict(
  3293. patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True,
  3294. norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu'
  3295. )
  3296. model = _create_vision_transformer(
  3297. 'vit_huge_patch14_clip_quickgelu_378', pretrained=pretrained, **dict(model_args, **kwargs))
  3298. return model
  3299. @register_model
  3300. def vit_gigantic_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3301. """ ViT-bigG model (ViT-G/14) w/ QuickGELU act
  3302. """
  3303. model_args = dict(
  3304. patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True,
  3305. norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu'
  3306. )
  3307. model = _create_vision_transformer(
  3308. 'vit_gigantic_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3309. return model
  3310. # Experimental models below
  3311. @register_model
  3312. def vit_base_patch32_plus_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3313. """ ViT-Base (ViT-B/32+)
  3314. """
  3315. model_args = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5)
  3316. model = _create_vision_transformer(
  3317. 'vit_base_patch32_plus_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3318. return model
  3319. @register_model
  3320. def vit_base_patch16_plus_240(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3321. """ ViT-Base (ViT-B/16+)
  3322. """
  3323. model_args = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5)
  3324. model = _create_vision_transformer(
  3325. 'vit_base_patch16_plus_240', pretrained=pretrained, **dict(model_args, **kwargs))
  3326. return model
  3327. @register_model
  3328. def vit_base_patch16_rpn_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3329. """ ViT-Base (ViT-B/16) w/ residual post-norm
  3330. """
  3331. model_args = dict(
  3332. patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5,
  3333. class_token=False, block_fn=ResPostBlock, global_pool='avg')
  3334. model = _create_vision_transformer(
  3335. 'vit_base_patch16_rpn_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3336. return model
  3337. @register_model
  3338. def vit_small_patch16_36x1_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3339. """ ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove.
  3340. Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
  3341. Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
  3342. """
  3343. model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5)
  3344. model = _create_vision_transformer(
  3345. 'vit_small_patch16_36x1_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3346. return model
  3347. @register_model
  3348. def vit_small_patch16_18x2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3349. """ ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
  3350. Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
  3351. Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
  3352. """
  3353. model_args = dict(
  3354. patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelThingsBlock)
  3355. model = _create_vision_transformer(
  3356. 'vit_small_patch16_18x2_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3357. return model
  3358. @register_model
  3359. def vit_base_patch16_18x2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3360. """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
  3361. Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
  3362. """
  3363. model_args = dict(
  3364. patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelThingsBlock)
  3365. model = _create_vision_transformer(
  3366. 'vit_base_patch16_18x2_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3367. return model
  3368. @register_model
  3369. def eva_large_patch14_196(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3370. """ EVA-large model https://arxiv.org/abs/2211.07636 /via MAE MIM pretrain"""
  3371. model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg')
  3372. model = _create_vision_transformer(
  3373. 'eva_large_patch14_196', pretrained=pretrained, **dict(model_args, **kwargs))
  3374. return model
  3375. @register_model
  3376. def eva_large_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3377. """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain"""
  3378. model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg')
  3379. model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
  3380. return model
  3381. @register_model
  3382. def flexivit_small(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3383. """ FlexiViT-Small
  3384. """
  3385. model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True)
  3386. model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **dict(model_args, **kwargs))
  3387. return model
  3388. @register_model
  3389. def flexivit_base(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3390. """ FlexiViT-Base
  3391. """
  3392. model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
  3393. model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **dict(model_args, **kwargs))
  3394. return model
  3395. @register_model
  3396. def flexivit_large(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3397. """ FlexiViT-Large
  3398. """
  3399. model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True)
  3400. model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **dict(model_args, **kwargs))
  3401. return model
  3402. @register_model
  3403. def vit_base_patch16_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3404. """ ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.
  3405. """
  3406. model_args = dict(
  3407. patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, no_embed_class=True,
  3408. norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
  3409. )
  3410. model = _create_vision_transformer(
  3411. 'vit_base_patch16_xp_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3412. return model
  3413. @register_model
  3414. def vit_large_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3415. """ ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.
  3416. """
  3417. model_args = dict(
  3418. patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, no_embed_class=True,
  3419. norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
  3420. )
  3421. model = _create_vision_transformer(
  3422. 'vit_large_patch14_xp_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3423. return model
  3424. @register_model
  3425. def vit_huge_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3426. """ ViT-Huge model (ViT-H/14) w/ parallel blocks and qk norm enabled.
  3427. """
  3428. model_args = dict(
  3429. patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, no_embed_class=True,
  3430. norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
  3431. )
  3432. model = _create_vision_transformer(
  3433. 'vit_huge_patch14_xp_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3434. return model
  3435. @register_model
  3436. def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3437. """ ViT-S/14 for DINOv2
  3438. """
  3439. model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5)
  3440. model = _create_vision_transformer(
  3441. 'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
  3442. return model
  3443. @register_model
  3444. def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3445. """ ViT-B/14 for DINOv2
  3446. """
  3447. model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5)
  3448. model = _create_vision_transformer(
  3449. 'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
  3450. return model
  3451. @register_model
  3452. def vit_large_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3453. """ ViT-L/14 for DINOv2
  3454. """
  3455. model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5)
  3456. model = _create_vision_transformer(
  3457. 'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
  3458. return model
  3459. @register_model
  3460. def vit_giant_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3461. """ ViT-G/14 for DINOv2
  3462. """
  3463. # The hidden_features of SwiGLU is calculated by:
  3464. # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
  3465. # When embed_dim=1536, hidden_features=4096
  3466. # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
  3467. model_args = dict(
  3468. patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5,
  3469. mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, act_layer=nn.SiLU
  3470. )
  3471. model = _create_vision_transformer(
  3472. 'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
  3473. return model
  3474. @register_model
  3475. def vit_small_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3476. """ ViT-S/14 for DINOv2 w/ 4 registers
  3477. """
  3478. model_args = dict(
  3479. patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5,
  3480. reg_tokens=4, no_embed_class=True,
  3481. )
  3482. model = _create_vision_transformer(
  3483. 'vit_small_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
  3484. return model
  3485. @register_model
  3486. def vit_base_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3487. """ ViT-B/14 for DINOv2 w/ 4 registers
  3488. """
  3489. model_args = dict(
  3490. patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
  3491. reg_tokens=4, no_embed_class=True,
  3492. )
  3493. model = _create_vision_transformer(
  3494. 'vit_base_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
  3495. return model
  3496. @register_model
  3497. def vit_large_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3498. """ ViT-L/14 for DINOv2 w/ 4 registers
  3499. """
  3500. model_args = dict(
  3501. patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5,
  3502. reg_tokens=4, no_embed_class=True,
  3503. )
  3504. model = _create_vision_transformer(
  3505. 'vit_large_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
  3506. return model
  3507. @register_model
  3508. def vit_giant_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3509. """ ViT-G/14 for DINOv2
  3510. """
  3511. # The hidden_features of SwiGLU is calculated by:
  3512. # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
  3513. # When embed_dim=1536, hidden_features=4096
  3514. # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
  3515. model_args = dict(
  3516. patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5, mlp_ratio=2.66667 * 2,
  3517. mlp_layer=SwiGLUPacked, act_layer=nn.SiLU, reg_tokens=4, no_embed_class=True,
  3518. )
  3519. model = _create_vision_transformer(
  3520. 'vit_giant_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
  3521. return model
  3522. @register_model
  3523. def vit_base_patch32_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3524. model_args = dict(
  3525. patch_size=32, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
  3526. act_layer='gelu_tanh',
  3527. )
  3528. model = _create_vision_transformer(
  3529. 'vit_base_patch32_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3530. return model
  3531. @register_model
  3532. def vit_base_patch16_siglip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3533. model_args = dict(
  3534. patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
  3535. )
  3536. model = _create_vision_transformer(
  3537. 'vit_base_patch16_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3538. return model
  3539. @register_model
  3540. def vit_base_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3541. model_args = dict(
  3542. patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
  3543. )
  3544. model = _create_vision_transformer(
  3545. 'vit_base_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3546. return model
  3547. @register_model
  3548. def vit_base_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3549. model_args = dict(
  3550. patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
  3551. )
  3552. model = _create_vision_transformer(
  3553. 'vit_base_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3554. return model
  3555. @register_model
  3556. def vit_base_patch16_siglip_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3557. model_args = dict(
  3558. patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
  3559. )
  3560. model = _create_vision_transformer(
  3561. 'vit_base_patch16_siglip_512', pretrained=pretrained, **dict(model_args, **kwargs))
  3562. return model
  3563. @register_model
  3564. def vit_large_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3565. model_args = dict(
  3566. patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map',
  3567. )
  3568. model = _create_vision_transformer(
  3569. 'vit_large_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3570. return model
  3571. @register_model
  3572. def vit_large_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3573. model_args = dict(
  3574. patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map',
  3575. )
  3576. model = _create_vision_transformer(
  3577. 'vit_large_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3578. return model
  3579. @register_model
  3580. def vit_large_patch16_siglip_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3581. model_args = dict(
  3582. patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map',
  3583. act_layer='gelu_tanh'
  3584. )
  3585. model = _create_vision_transformer(
  3586. 'vit_large_patch16_siglip_512', pretrained=pretrained, **dict(model_args, **kwargs))
  3587. return model
  3588. @register_model
  3589. def vit_so400m_patch14_siglip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3590. model_args = dict(
  3591. patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
  3592. )
  3593. model = _create_vision_transformer(
  3594. 'vit_so400m_patch14_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3595. return model
  3596. @register_model
  3597. def vit_so400m_patch14_siglip_378(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3598. # this is a corrected variant of the 384 with a res properly divisible by patch size (no padding/truncation)
  3599. model_args = dict(
  3600. patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
  3601. )
  3602. model = _create_vision_transformer(
  3603. 'vit_so400m_patch14_siglip_378', pretrained=pretrained, **dict(model_args, **kwargs))
  3604. return model
  3605. @register_model
  3606. def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3607. model_args = dict(
  3608. patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
  3609. )
  3610. model = _create_vision_transformer(
  3611. 'vit_so400m_patch14_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3612. return model
  3613. @register_model
  3614. def vit_so400m_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3615. model_args = dict(
  3616. patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
  3617. act_layer='gelu_tanh',
  3618. )
  3619. model = _create_vision_transformer(
  3620. 'vit_so400m_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3621. return model
  3622. @register_model
  3623. def vit_so400m_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3624. model_args = dict(
  3625. patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
  3626. act_layer='gelu_tanh',
  3627. )
  3628. model = _create_vision_transformer(
  3629. 'vit_so400m_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3630. return model
  3631. @register_model
  3632. def vit_so400m_patch16_siglip_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3633. model_args = dict(
  3634. patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
  3635. act_layer='gelu_tanh',
  3636. )
  3637. model = _create_vision_transformer(
  3638. 'vit_so400m_patch16_siglip_512', pretrained=pretrained, **dict(model_args, **kwargs))
  3639. return model
  3640. @register_model
  3641. def vit_giantopt_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3642. model_args = dict(
  3643. patch_size=16, embed_dim=1536, depth=40, num_heads=16, class_token=False, global_pool='map',
  3644. act_layer='gelu_tanh',
  3645. )
  3646. model = _create_vision_transformer(
  3647. 'vit_giantopt_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3648. return model
  3649. @register_model
  3650. def vit_giantopt_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3651. model_args = dict(
  3652. patch_size=16, embed_dim=1536, depth=40, num_heads=16, class_token=False, global_pool='map',
  3653. act_layer='gelu_tanh',
  3654. )
  3655. model = _create_vision_transformer(
  3656. 'vit_giantopt_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3657. return model
  3658. @register_model
  3659. def vit_base_patch32_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3660. model_args = dict(
  3661. patch_size=32, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
  3662. act_layer='gelu_tanh',
  3663. )
  3664. model = _create_vision_transformer(
  3665. 'vit_base_patch32_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3666. return model
  3667. @register_model
  3668. def vit_base_patch16_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3669. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3670. model_args = dict(
  3671. patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
  3672. )
  3673. model = _create_vision_transformer(
  3674. 'vit_base_patch16_siglip_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3675. return model
  3676. @register_model
  3677. def vit_base_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3678. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3679. model_args = dict(
  3680. patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
  3681. )
  3682. model = _create_vision_transformer(
  3683. 'vit_base_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3684. return model
  3685. @register_model
  3686. def vit_base_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3687. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3688. model_args = dict(
  3689. patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
  3690. )
  3691. model = _create_vision_transformer(
  3692. 'vit_base_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3693. return model
  3694. @register_model
  3695. def vit_base_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3696. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3697. model_args = dict(
  3698. patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
  3699. )
  3700. model = _create_vision_transformer(
  3701. 'vit_base_patch16_siglip_gap_512', pretrained=pretrained, **dict(model_args, **kwargs))
  3702. return model
  3703. @register_model
  3704. def vit_large_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3705. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3706. model_args = dict(
  3707. patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='avg', fc_norm=False,
  3708. )
  3709. model = _create_vision_transformer(
  3710. 'vit_large_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3711. return model
  3712. @register_model
  3713. def vit_large_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3714. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3715. model_args = dict(
  3716. patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='avg', fc_norm=False,
  3717. )
  3718. model = _create_vision_transformer(
  3719. 'vit_large_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3720. return model
  3721. @register_model
  3722. def vit_large_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3723. model_args = dict(
  3724. patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False,
  3725. global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
  3726. )
  3727. model = _create_vision_transformer(
  3728. 'vit_large_patch16_siglip_gap_512', pretrained=pretrained, **dict(model_args, **kwargs))
  3729. return model
  3730. @register_model
  3731. def vit_so400m_patch14_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3732. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3733. model_args = dict(
  3734. patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
  3735. class_token=False, global_pool='avg', fc_norm=False,
  3736. )
  3737. model = _create_vision_transformer(
  3738. 'vit_so400m_patch14_siglip_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3739. return model
  3740. @register_model
  3741. def vit_so400m_patch14_siglip_gap_378(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3742. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3743. model_args = dict(
  3744. patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
  3745. class_token=False, global_pool='avg', fc_norm=False,
  3746. )
  3747. model = _create_vision_transformer(
  3748. 'vit_so400m_patch14_siglip_gap_378', pretrained=pretrained, **dict(model_args, **kwargs))
  3749. return model
  3750. @register_model
  3751. def vit_so400m_patch14_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3752. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3753. model_args = dict(
  3754. patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
  3755. class_token=False, global_pool='avg', fc_norm=False,
  3756. )
  3757. model = _create_vision_transformer(
  3758. 'vit_so400m_patch14_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3759. return model
  3760. @register_model
  3761. def vit_so400m_patch14_siglip_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3762. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3763. model_args = dict(
  3764. patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
  3765. class_token=False, global_pool='avg', fc_norm=False,
  3766. )
  3767. model = _create_vision_transformer(
  3768. 'vit_so400m_patch14_siglip_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
  3769. return model
  3770. @register_model
  3771. def vit_so400m_patch14_siglip_gap_896(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3772. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3773. model_args = dict(
  3774. patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
  3775. class_token=False, global_pool='avg', fc_norm=False,
  3776. )
  3777. model = _create_vision_transformer(
  3778. 'vit_so400m_patch14_siglip_gap_896', pretrained=pretrained, **dict(model_args, **kwargs))
  3779. return model
  3780. @register_model
  3781. def vit_so400m_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3782. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3783. model_args = dict(
  3784. patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
  3785. class_token=False, global_pool='avg', fc_norm=False, act_layer='gelu_tanh',
  3786. )
  3787. model = _create_vision_transformer(
  3788. 'vit_so400m_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3789. return model
  3790. @register_model
  3791. def vit_so400m_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3792. model_args = dict(
  3793. patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False,
  3794. global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
  3795. )
  3796. model = _create_vision_transformer(
  3797. 'vit_so400m_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3798. return model
  3799. @register_model
  3800. def vit_so400m_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3801. model_args = dict(
  3802. patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False,
  3803. global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
  3804. )
  3805. model = _create_vision_transformer(
  3806. 'vit_so400m_patch16_siglip_gap_512', pretrained=pretrained, **dict(model_args, **kwargs))
  3807. return model
  3808. @register_model
  3809. def vit_giantopt_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3810. model_args = dict(
  3811. patch_size=16, embed_dim=1536, depth=40, num_heads=16, class_token=False,
  3812. global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
  3813. )
  3814. model = _create_vision_transformer(
  3815. 'vit_giantopt_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3816. return model
  3817. @register_model
  3818. def vit_giantopt_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3819. model_args = dict(
  3820. patch_size=16, embed_dim=1536, depth=40, num_heads=16, class_token=False,
  3821. global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
  3822. )
  3823. model = _create_vision_transformer(
  3824. 'vit_giantopt_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3825. return model
  3826. @register_model
  3827. def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3828. model_args = dict(
  3829. patch_size=16, embed_dim=256, depth=14, num_heads=4, init_values=1e-5, mlp_ratio=5,
  3830. class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg',
  3831. )
  3832. model = _create_vision_transformer(
  3833. 'vit_wee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3834. return model
  3835. @register_model
  3836. def vit_dwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3837. model_args = dict(
  3838. patch_size=16, embed_dim=256, depth=14, num_heads=4, init_values=1e-5, mlp_ratio=5,
  3839. class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', attn_layer='diff',
  3840. )
  3841. model = _create_vision_transformer(
  3842. 'vit_dwee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3843. return model
  3844. @register_model
  3845. def vit_pwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3846. model_args = dict(
  3847. patch_size=16, embed_dim=256, depth=16, num_heads=4, init_values=1e-5, mlp_ratio=5,
  3848. class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', block_fn=ParallelScalingBlock,
  3849. )
  3850. model = _create_vision_transformer(
  3851. 'vit_pwee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3852. return model
  3853. @register_model
  3854. def vit_dpwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3855. model_args = dict(
  3856. patch_size=16, embed_dim=256, depth=16, num_heads=4, init_values=1e-5, mlp_ratio=5,
  3857. class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', block_fn=DiffParallelScalingBlock,
  3858. )
  3859. model = _create_vision_transformer(
  3860. 'vit_dpwee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3861. return model
  3862. @register_model
  3863. def vit_little_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3864. model_args = dict(
  3865. patch_size=16, embed_dim=320, depth=14, num_heads=5, init_values=1e-5, mlp_ratio=5.6,
  3866. class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg',
  3867. )
  3868. model = _create_vision_transformer(
  3869. 'vit_little_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3870. return model
  3871. @register_model
  3872. def vit_dlittle_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3873. model_args = dict(
  3874. patch_size=16, embed_dim=320, depth=14, num_heads=5, init_values=1e-5, mlp_ratio=5.6,
  3875. class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', attn_layer='diff',
  3876. )
  3877. model = _create_vision_transformer(
  3878. 'vit_dlittle_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3879. return model
  3880. @register_model
  3881. def vit_little_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3882. model_args = dict(
  3883. patch_size=16, embed_dim=320, depth=14, num_heads=5, init_values=1e-5, mlp_ratio=5.6,
  3884. class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
  3885. )
  3886. model = _create_vision_transformer(
  3887. 'vit_little_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3888. return model
  3889. @register_model
  3890. def vit_medium_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3891. model_args = dict(
  3892. patch_size=16, embed_dim=512, depth=12, num_heads=8, init_values=1e-5,
  3893. class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg',
  3894. )
  3895. model = _create_vision_transformer(
  3896. 'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3897. return model
  3898. @register_model
  3899. def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3900. model_args = dict(
  3901. patch_size=16, embed_dim=512, depth=12, num_heads=8, init_values=1e-5,
  3902. class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
  3903. )
  3904. model = _create_vision_transformer(
  3905. 'vit_medium_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3906. return model
  3907. @register_model
  3908. def vit_mediumd_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3909. model_args = dict(
  3910. patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
  3911. class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
  3912. )
  3913. model = _create_vision_transformer(
  3914. 'vit_mediumd_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3915. return model
  3916. @register_model
  3917. def vit_mediumd_patch16_reg4_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3918. model_args = dict(
  3919. patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
  3920. class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
  3921. )
  3922. model = _create_vision_transformer(
  3923. 'vit_mediumd_patch16_reg4_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3924. return model
  3925. @register_model
  3926. def vit_betwixt_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3927. model_args = dict(
  3928. patch_size=16, embed_dim=640, depth=12, num_heads=10, init_values=1e-5,
  3929. class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg',
  3930. )
  3931. model = _create_vision_transformer(
  3932. 'vit_betwixt_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3933. return model
  3934. @register_model
  3935. def vit_betwixt_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3936. model_args = dict(
  3937. patch_size=16, embed_dim=640, depth=12, num_heads=10, init_values=1e-5,
  3938. class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
  3939. )
  3940. model = _create_vision_transformer(
  3941. 'vit_betwixt_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3942. return model
  3943. @register_model
  3944. def vit_betwixt_patch16_reg4_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3945. model_args = dict(
  3946. patch_size=16, embed_dim=640, depth=12, num_heads=10, init_values=1e-5,
  3947. class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
  3948. )
  3949. model = _create_vision_transformer(
  3950. 'vit_betwixt_patch16_reg4_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3951. return model
  3952. @register_model
  3953. def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3954. model_args = dict(
  3955. patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False,
  3956. no_embed_class=True, global_pool='avg', reg_tokens=4,
  3957. )
  3958. model = _create_vision_transformer(
  3959. 'vit_base_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3960. return model
  3961. @register_model
  3962. def vit_so150m_patch16_reg4_map_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3963. """ SO150M (shape optimized, but diff than paper def, optimized for GPU) """
  3964. model_args = dict(
  3965. patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
  3966. class_token=False, reg_tokens=4, global_pool='map',
  3967. )
  3968. model = _create_vision_transformer(
  3969. 'vit_so150m_patch16_reg4_map_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3970. return model
  3971. @register_model
  3972. def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3973. """ SO150M (shape optimized, but diff than paper def, optimized for GPU) """
  3974. model_args = dict(
  3975. patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
  3976. class_token=False, reg_tokens=4, global_pool='avg', fc_norm=False,
  3977. )
  3978. model = _create_vision_transformer(
  3979. 'vit_so150m_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3980. return model
  3981. @register_model
  3982. def vit_so150m_patch16_reg4_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3983. """ SO150M (shape optimized, but diff than paper def, optimized for GPU) """
  3984. model_args = dict(
  3985. patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
  3986. class_token=False, reg_tokens=4, global_pool='avg', fc_norm=False,
  3987. )
  3988. model = _create_vision_transformer(
  3989. 'vit_so150m_patch16_reg4_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3990. return model
  3991. @register_model
  3992. def vit_so150m2_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3993. """ SO150M v2 (shape optimized, but diff than paper def, optimized for GPU) """
  3994. model_args = dict(
  3995. patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
  3996. qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg',
  3997. )
  3998. model = _create_vision_transformer(
  3999. 'vit_so150m2_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  4000. return model
  4001. @register_model
  4002. def vit_so150m2_patch16_reg1_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4003. """ SO150M v2 (shape optimized, but diff than paper def, optimized for GPU) """
  4004. model_args = dict(
  4005. patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
  4006. qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg',
  4007. )
  4008. model = _create_vision_transformer(
  4009. 'vit_so150m2_patch16_reg1_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
  4010. return model
  4011. @register_model
  4012. def vit_so150m2_patch16_reg1_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4013. """ SO150M v2 (shape optimized, but diff than paper def, optimized for GPU) """
  4014. model_args = dict(
  4015. patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
  4016. qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg',
  4017. )
  4018. model = _create_vision_transformer(
  4019. 'vit_so150m2_patch16_reg1_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
  4020. return model
  4021. @register_model
  4022. def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4023. model_args = dict(
  4024. patch_size=14, embed_dim=1024, depth=24, num_heads=16,
  4025. init_values=0.1, final_norm=False, dynamic_img_size=True,
  4026. )
  4027. model = _create_vision_transformer(
  4028. 'vit_intern300m_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  4029. return model
  4030. @register_model
  4031. def aimv2_large_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4032. """ ViT Large AIM-v2 model
  4033. """
  4034. model_args = dict(
  4035. patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False,
  4036. mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  4037. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  4038. )
  4039. model = _create_vision_transformer(
  4040. 'aimv2_large_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  4041. return model
  4042. @register_model
  4043. def aimv2_huge_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4044. """ ViT Huge AIM-v2 model
  4045. """
  4046. model_args = dict(
  4047. patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False,
  4048. mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  4049. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  4050. )
  4051. model = _create_vision_transformer(
  4052. 'aimv2_huge_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  4053. return model
  4054. @register_model
  4055. def aimv2_1b_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4056. """ ViT 1B AIM-v2 model
  4057. """
  4058. model_args = dict(
  4059. patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False,
  4060. mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  4061. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  4062. )
  4063. model = _create_vision_transformer(
  4064. 'aimv2_1b_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  4065. return model
  4066. @register_model
  4067. def aimv2_3b_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4068. """ ViT 3B AIM-v2 model
  4069. """
  4070. model_args = dict(
  4071. patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False,
  4072. mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  4073. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  4074. )
  4075. model = _create_vision_transformer(
  4076. 'aimv2_3b_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  4077. return model
  4078. @register_model
  4079. def aimv2_large_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4080. """ ViT Large AIM-v2 model
  4081. """
  4082. model_args = dict(
  4083. patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False,
  4084. mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  4085. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  4086. )
  4087. model = _create_vision_transformer(
  4088. 'aimv2_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
  4089. return model
  4090. @register_model
  4091. def aimv2_huge_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4092. """ ViT Huge AIM-v2 model
  4093. """
  4094. model_args = dict(
  4095. patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False,
  4096. mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  4097. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  4098. )
  4099. model = _create_vision_transformer(
  4100. 'aimv2_huge_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
  4101. return model
  4102. @register_model
  4103. def aimv2_1b_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4104. """ ViT 1B AIM-v2 model
  4105. """
  4106. model_args = dict(
  4107. patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False,
  4108. mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  4109. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  4110. )
  4111. model = _create_vision_transformer(
  4112. 'aimv2_1b_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
  4113. return model
  4114. @register_model
  4115. def aimv2_3b_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4116. """ ViT 3B AIM-v2 model
  4117. """
  4118. model_args = dict(
  4119. patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False,
  4120. mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  4121. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  4122. )
  4123. model = _create_vision_transformer(
  4124. 'aimv2_3b_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
  4125. return model
  4126. @register_model
  4127. def aimv2_large_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4128. """ ViT Large AIM-v2 model
  4129. """
  4130. model_args = dict(
  4131. patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False,
  4132. mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  4133. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  4134. )
  4135. model = _create_vision_transformer(
  4136. 'aimv2_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  4137. return model
  4138. @register_model
  4139. def aimv2_huge_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4140. """ ViT Huge AIM-v2 model
  4141. """
  4142. model_args = dict(
  4143. patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False,
  4144. mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  4145. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  4146. )
  4147. model = _create_vision_transformer(
  4148. 'aimv2_huge_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  4149. return model
  4150. @register_model
  4151. def aimv2_1b_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4152. """ ViT 1B AIM-v2 model
  4153. """
  4154. model_args = dict(
  4155. patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False,
  4156. mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  4157. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  4158. )
  4159. model = _create_vision_transformer(
  4160. 'aimv2_1b_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  4161. return model
  4162. @register_model
  4163. def aimv2_3b_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4164. """ ViT 3B AIM-v2 model
  4165. """
  4166. model_args = dict(
  4167. patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False,
  4168. mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  4169. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  4170. )
  4171. model = _create_vision_transformer(
  4172. 'aimv2_3b_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  4173. return model
  4174. @register_model
  4175. def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4176. """ ViT Test
  4177. """
  4178. model_args = dict(patch_size=16, embed_dim=64, depth=6, num_heads=2, mlp_ratio=3, dynamic_img_size=True)
  4179. model = _create_vision_transformer('test_vit', pretrained=pretrained, **dict(model_args, **kwargs))
  4180. return model
  4181. @register_model
  4182. def test_vit2(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4183. """ ViT Test
  4184. """
  4185. model_args = dict(
  4186. patch_size=16, embed_dim=64, depth=8, num_heads=2, mlp_ratio=3,
  4187. class_token=False, reg_tokens=1, global_pool='avg', init_values=1e-5, dynamic_img_size=True)
  4188. model = _create_vision_transformer('test_vit2', pretrained=pretrained, **dict(model_args, **kwargs))
  4189. return model
  4190. @register_model
  4191. def test_vit3(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4192. """ ViT Test
  4193. """
  4194. model_args = dict(
  4195. patch_size=16, embed_dim=96, depth=9, num_heads=3, mlp_ratio=2,
  4196. class_token=False, reg_tokens=1, global_pool='map', pool_include_prefix=True, init_values=1e-5)
  4197. model = _create_vision_transformer('test_vit3', pretrained=pretrained, **dict(model_args, **kwargs))
  4198. return model
  4199. @register_model
  4200. def test_vit4(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4201. """ ViT Test
  4202. """
  4203. model_args = dict(
  4204. patch_size=16, embed_dim=96, depth=9, num_heads=3, mlp_ratio=3,
  4205. class_token=False, reg_tokens=1, global_pool='avg', init_values=1e-5, dynamic_img_size=True,
  4206. norm_layer='rmsnorm',
  4207. )
  4208. model = _create_vision_transformer('test_vit4', pretrained=pretrained, **dict(model_args, **kwargs))
  4209. return model
  4210. @register_model
  4211. def beit3_base_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4212. """ BEiT3 Base model (ViT-Base size) with patch size 16x16.
  4213. Remapped to VisionTransformer with scale_norm=True.
  4214. """
  4215. model_args = dict(
  4216. patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
  4217. scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg',
  4218. norm_layer=partial(LayerNorm, eps=1e-5)
  4219. )
  4220. model = _create_vision_transformer('beit3_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  4221. return model
  4222. @register_model
  4223. def beit3_large_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4224. """ BEiT3 Large model (ViT-Large size) with patch size 16x16.
  4225. Remapped to VisionTransformer with scale_norm=True.
  4226. """
  4227. model_args = dict(
  4228. patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
  4229. scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg',
  4230. norm_layer=partial(LayerNorm, eps=1e-5),
  4231. )
  4232. model = _create_vision_transformer('beit3_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  4233. return model
  4234. @register_model
  4235. def beit3_giant_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4236. """ BEiT3 Giant model with patch size 14x14.
  4237. Remapped to VisionTransformer with scale_norm=True.
  4238. """
  4239. model_args = dict(
  4240. patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637,
  4241. scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg',
  4242. norm_layer=partial(LayerNorm, eps=1e-5),
  4243. )
  4244. model = _create_vision_transformer('beit3_giant_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  4245. return model
  4246. @register_model
  4247. def beit3_giant_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
  4248. """ BEiT3 Giant model with patch size 14x14 and image size 336x336.
  4249. Remapped to VisionTransformer with scale_norm=True.
  4250. """
  4251. model_args = dict(
  4252. img_size=336, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637,
  4253. scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg',
  4254. norm_layer=partial(LayerNorm, eps=1e-5),
  4255. )
  4256. model = _create_vision_transformer('beit3_giant_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
  4257. return model
  4258. register_model_deprecations(__name__, {
  4259. 'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
  4260. 'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',
  4261. 'vit_small_patch16_224_in21k': 'vit_small_patch16_224.augreg_in21k',
  4262. 'vit_base_patch32_224_in21k': 'vit_base_patch32_224.augreg_in21k',
  4263. 'vit_base_patch16_224_in21k': 'vit_base_patch16_224.augreg_in21k',
  4264. 'vit_base_patch8_224_in21k': 'vit_base_patch8_224.augreg_in21k',
  4265. 'vit_large_patch32_224_in21k': 'vit_large_patch32_224.orig_in21k',
  4266. 'vit_large_patch16_224_in21k': 'vit_large_patch16_224.augreg_in21k',
  4267. 'vit_huge_patch14_224_in21k': 'vit_huge_patch14_224.orig_in21k',
  4268. 'vit_base_patch32_224_sam': 'vit_base_patch32_224.sam',
  4269. 'vit_base_patch16_224_sam': 'vit_base_patch16_224.sam',
  4270. 'vit_small_patch16_224_dino': 'vit_small_patch16_224.dino',
  4271. 'vit_small_patch8_224_dino': 'vit_small_patch8_224.dino',
  4272. 'vit_base_patch16_224_dino': 'vit_base_patch16_224.dino',
  4273. 'vit_base_patch8_224_dino': 'vit_base_patch8_224.dino',
  4274. 'vit_base_patch16_224_miil_in21k': 'vit_base_patch16_224_miil.in21k',
  4275. 'vit_base_patch32_224_clip_laion2b': 'vit_base_patch32_clip_224.laion2b',
  4276. 'vit_large_patch14_224_clip_laion2b': 'vit_large_patch14_clip_224.laion2b',
  4277. 'vit_huge_patch14_224_clip_laion2b': 'vit_huge_patch14_clip_224.laion2b',
  4278. 'vit_giant_patch14_224_clip_laion2b': 'vit_giant_patch14_clip_224.laion2b',
  4279. })