eva.py 111 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
7277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026
  1. """ EVA
  2. EVA ViT from https://github.com/baaivision/EVA , paper: https://arxiv.org/abs/2211.07636
  3. This file contains a number of ViT variants that utilise ROPE position embeddings, SwiGLU and other additions:
  4. * EVA & EVA02 model implementations that evolved from BEiT, additional models in vision_transformer.py.
  5. * `timm` original SBB ViT w/ ROPE position embeddings
  6. * Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)
  7. * ROPE-ViT from Naver AI (https://arxiv.org/abs/2403.13298)
  8. * DINOv3 from META AI Research (https://arxiv.org/abs/2508.10104)
  9. @article{EVA,
  10. title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale},
  11. author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang,
  12. Tiejun and Wang, Xinlong and Cao, Yue},
  13. journal={arXiv preprint arXiv:2211.07636},
  14. year={2022}
  15. }
  16. EVA-02: A Visual Representation for Neon Genesis - https://arxiv.org/abs/2303.11331
  17. @article{EVA02,
  18. title={EVA-02: A Visual Representation for Neon Genesis},
  19. author={Fang, Yuxin and Sun, Quan and Wang, Xinggang and Huang, Tiejun and Wang, Xinlong and Cao, Yue},
  20. journal={arXiv preprint arXiv:2303.11331},
  21. year={2023}
  22. }
  23. @article{bolya2025perception,
  24. title={Perception encoder: The best visual embeddings are not at the output of the network},
  25. author={Bolya, Daniel and Huang, Po-Yao and Sun, Peize and Cho, Jang Hyun and Madotto, Andrea and Wei, Chen and Ma,
  26. Tengyu and Zhi, Jiale and Rajasegaran, Jathushan and Rasheed, Hanoona and others},
  27. journal={arXiv preprint arXiv:2504.13181},
  28. year={2025}
  29. }
  30. @inproceedings{heo2024rotary,
  31. title={Rotary position embedding for vision transformer},
  32. author={Heo, Byeongho and Park, Song and Han, Dongyoon and Yun, Sangdoo},
  33. booktitle={European Conference on Computer Vision},
  34. pages={289--305},
  35. year={2024},
  36. organization={Springer}
  37. }
  38. @article{simeoni2025dinov3,
  39. title={{DINOv3}},
  40. author={Sim{\'e}oni, Oriane and Vo, Huy V. and Seitzer, Maximilian and Baldassarre, Federico and Oquab, Maxime
  41. and Jose, Cijo and Khalidov, Vasil and Szafraniec, Marc and Yi, Seungeun and Ramamonjisoa, Micha{\"e}l
  42. and Massa, Francisco and Haziza, Daniel and Wehrstedt, Luca and Wang, Jianyuan and Darcet, Timoth{\'e}e
  43. and Moutakanni, Th{\'e}o and Sentana, Leonel and Roberts, Claire and Vedaldi, Andrea and Tolan, Jamie
  44. and Brandt, John and Couprie, Camille and Mairal, Julien and J{\'e}gou, Herv{\'e} and Labatut, Patrick
  45. and Bojanowski, Piotr},
  46. year={2025},
  47. eprint={2508.10104},
  48. url={https://arxiv.org/abs/2508.10104},
  49. }
  50. DINOv3 code was a modification of existing EVA model and support modules, so licensed under Apache-2.0 like timm.
  51. Weights from META remain under DINOv3 License (https://ai.meta.com/resources/models-and-libraries/dinov3-license/).
  52. Modifications by / Copyright 2023 Ross Wightman, original copyrights below
  53. """
  54. # EVA models Copyright (c) 2022 BAAI-Vision
  55. # EVA02 models Copyright (c) 2023 BAAI-Vision
  56. import math
  57. import os
  58. from functools import partial
  59. from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
  60. import torch
  61. import torch.nn as nn
  62. import torch.nn.functional as F
  63. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
  64. from timm.layers import (
  65. PatchEmbed,
  66. Mlp,
  67. GluMlp,
  68. SwiGLU,
  69. LayerNorm,
  70. DropPath, calculate_drop_path_rates,
  71. PatchDropoutWithIndices,
  72. create_rope_embed,
  73. apply_rot_embed_cat,
  74. apply_keep_indices_nlc,
  75. trunc_normal_,
  76. resample_patch_embed,
  77. resample_abs_pos_embed,
  78. global_pool_nlc,
  79. to_2tuple,
  80. use_fused_attn,
  81. maybe_add_mask,
  82. resolve_self_attn_mask,
  83. AttentionRope,
  84. AttentionPoolLatent,
  85. )
  86. from ._builder import build_model_with_cfg
  87. from ._features import feature_take_indices
  88. from ._manipulate import checkpoint
  89. from ._registry import generate_default_cfgs, register_model
# Public API of this module: only the `Eva` model class is exported by `from <module> import *`.
__all__ = ['Eva']
  91. class EvaAttention(nn.Module):
  92. """ EVA Attention with ROPE, no k-bias, and fused/unfused qkv options
  93. """
  94. fused_attn: torch.jit.Final[bool]
  95. def __init__(
  96. self,
  97. dim: int,
  98. num_heads: int = 8,
  99. qkv_bias: bool = True,
  100. qkv_fused: bool = True,
  101. qkv_bias_separate: bool = False,
  102. num_prefix_tokens: int = 1,
  103. attn_drop: float = 0.,
  104. proj_drop: float = 0.,
  105. attn_head_dim: Optional[int] = None,
  106. norm_layer: Optional[Callable] = None,
  107. qk_norm: bool = False,
  108. scale_norm: bool = True,
  109. rotate_half: bool = False,
  110. device=None,
  111. dtype=None,
  112. ):
  113. """
  114. Args:
  115. dim: Input dimension of the token embeddings
  116. num_heads: Number of attention heads
  117. qkv_bias: Whether to add a bias term to the query, key, and value projections
  118. qkv_fused: Whether qkv projections are fused into one projection or separate
  119. qkv_bias_separate: Whether to apply bias to qkv as a separate addition or part of F.linear() call
  120. num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that
  121. should not have position embeddings applied
  122. attn_drop: Dropout rate for attention weights
  123. proj_drop: Dropout rate for the output projection
  124. attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
  125. norm_layer: Normalization layer constructor to use for QK and scale normalization
  126. qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
  127. scale_norm: Enable normalization (scaling) of attention output with norm_layer
  128. rotate_half: Use half rotation layout instead of interleaved
  129. """
  130. dd = {'device': device, 'dtype': dtype}
  131. super().__init__()
  132. if scale_norm or qk_norm:
  133. assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
  134. self.num_heads = num_heads
  135. self.head_dim = dim // num_heads
  136. if attn_head_dim is not None:
  137. self.head_dim = attn_head_dim
  138. attn_dim = self.head_dim * self.num_heads
  139. self.scale = self.head_dim ** -0.5
  140. self.num_prefix_tokens = num_prefix_tokens
  141. self.fused_attn = use_fused_attn()
  142. self.qkv_bias_separate = qkv_bias_separate
  143. self.rotate_half = rotate_half
  144. if qkv_fused:
  145. self.qkv = nn.Linear(dim, attn_dim * 3, bias=False, **dd)
  146. self.q_proj = self.k_proj = self.v_proj = None
  147. if qkv_bias:
  148. self.q_bias = nn.Parameter(torch.empty(attn_dim, **dd))
  149. self.register_buffer('k_bias', torch.empty(attn_dim, **dd), persistent=False)
  150. self.v_bias = nn.Parameter(torch.empty(attn_dim, **dd))
  151. else:
  152. self.q_bias = self.k_bias = self.v_bias = None
  153. else:
  154. self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
  155. self.k_proj = nn.Linear(dim, attn_dim, bias=False, **dd)
  156. self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
  157. self.qkv = None
  158. self.q_bias = self.k_bias = self.v_bias = None
  159. self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
  160. self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
  161. self.attn_drop = nn.Dropout(attn_drop)
  162. self.norm = norm_layer(attn_dim, **dd) if scale_norm else nn.Identity()
  163. self.proj = nn.Linear(attn_dim, dim, **dd)
  164. self.proj_drop = nn.Dropout(proj_drop)
  165. # TODO: skip init when on meta device when safe to do so
  166. self.reset_parameters()
  167. def reset_parameters(self) -> None:
  168. """Initialize parameters and buffers."""
  169. if self.q_bias is not None:
  170. nn.init.zeros_(self.q_bias)
  171. nn.init.zeros_(self.v_bias)
  172. self._init_buffers()
  173. def _init_buffers(self) -> None:
  174. """Compute and fill non-persistent buffer values."""
  175. if self.k_bias is not None:
  176. self.k_bias.zero_()
  177. def forward(
  178. self,
  179. x,
  180. rope: Optional[torch.Tensor] = None,
  181. attn_mask: Optional[torch.Tensor] = None,
  182. is_causal: bool = False,
  183. ):
  184. """Forward pass for the attention module.
  185. Args:
  186. x: Input tensor of shape (batch_size, sequence_length, embedding_dim)
  187. rope: Rotary position embeddings tensor for position-aware attention
  188. attn_mask: Optional attention mask to apply during attention computation
  189. is_causal: If True, use causal (autoregressive) masking
  190. Returns:
  191. Tensor of shape (batch_size, sequence_length, embedding_dim)
  192. """
  193. B, N, C = x.shape
  194. if self.qkv is not None:
  195. if self.q_bias is None:
  196. qkv = self.qkv(x)
  197. else:
  198. qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
  199. if self.qkv_bias_separate:
  200. qkv = self.qkv(x)
  201. qkv += qkv_bias
  202. else:
  203. qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
  204. qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
  205. q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
  206. else:
  207. q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) # B, num_heads, N, C
  208. k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
  209. v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
  210. q, k = self.q_norm(q), self.k_norm(k)
  211. if rope is not None:
  212. npt = self.num_prefix_tokens
  213. half = getattr(self, 'rotate_half', False)
  214. q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope, half=half)], dim=2).type_as(v)
  215. k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope, half=half)], dim=2).type_as(v)
  216. if self.fused_attn:
  217. x = F.scaled_dot_product_attention(
  218. q, k, v,
  219. attn_mask=attn_mask,
  220. dropout_p=self.attn_drop.p if self.training else 0.,
  221. is_causal=is_causal,
  222. )
  223. else:
  224. q = q * self.scale
  225. attn = q @ k.transpose(-2, -1)
  226. attn_bias = resolve_self_attn_mask(N, attn, attn_mask, is_causal=is_causal)
  227. attn = maybe_add_mask(attn, attn_bias)
  228. attn = attn.softmax(dim=-1)
  229. attn = self.attn_drop(attn)
  230. x = attn @ v
  231. x = x.transpose(1, 2).reshape(B, N, C)
  232. x = self.norm(x)
  233. x = self.proj(x)
  234. x = self.proj_drop(x)
  235. return x
  236. def init_non_persistent_buffers(self) -> None:
  237. """Initialize non-persistent buffers."""
  238. self._init_buffers()
  239. class EvaBlock(nn.Module):
  240. def __init__(
  241. self,
  242. dim: int,
  243. num_heads: int,
  244. qkv_bias: bool = True,
  245. qkv_fused: bool = True,
  246. mlp_ratio: float = 4.,
  247. swiglu_mlp: bool = False,
  248. swiglu_align_to: int = 0,
  249. scale_mlp: bool = False,
  250. scale_attn_inner: bool = False,
  251. num_prefix_tokens: int = 1,
  252. attn_type: str = 'eva',
  253. rotate_half: bool = False,
  254. proj_drop: float = 0.,
  255. attn_drop: float = 0.,
  256. drop_path: float = 0.,
  257. init_values: Optional[float] = None,
  258. act_layer: Callable = nn.GELU,
  259. norm_layer: Callable = LayerNorm,
  260. attn_head_dim: Optional[int] = None,
  261. device=None,
  262. dtype=None,
  263. **kwargs,
  264. ):
  265. """ Initialize the EVA transformer block.
  266. Args:
  267. dim: Input dimension of the token embeddings
  268. num_heads: Number of attention heads
  269. qkv_bias: Whether to use bias terms in query, key, value projections
  270. qkv_fused: Whether to use a single projection for query, key, value
  271. mlp_ratio: Ratio of MLP hidden dimension to input dimension
  272. swiglu_mlp: Whether to use SwiGLU activation in the MLP
  273. scale_mlp: Whether to use normalization in the MLP
  274. scale_attn_inner: Whether to use normalization within the attention mechanism
  275. num_prefix_tokens: Number of tokens at the beginning of the sequence (class tokens, etc.)
  276. attn_type: Type of attention module to use ('eva' or 'rope')
  277. proj_drop: Dropout rate for projection layers
  278. attn_drop: Dropout rate for attention matrix
  279. drop_path: Stochastic depth rate
  280. init_values: Initial value for LayerScale, None = no LayerScale
  281. act_layer: Activation layer constructor
  282. norm_layer: Normalization layer constructor
  283. attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
  284. """
  285. dd = {'device': device, 'dtype': dtype}
  286. super().__init__()
  287. self.norm1 = norm_layer(dim, **dd)
  288. attn_cls = AttentionRope if attn_type == 'rope' else EvaAttention
  289. self.attn = attn_cls(
  290. dim,
  291. num_heads=num_heads,
  292. qkv_bias=qkv_bias,
  293. qkv_fused=qkv_fused,
  294. num_prefix_tokens=num_prefix_tokens,
  295. attn_drop=attn_drop,
  296. proj_drop=proj_drop,
  297. attn_head_dim=attn_head_dim,
  298. norm_layer=norm_layer,
  299. scale_norm=scale_attn_inner,
  300. rotate_half=rotate_half,
  301. **dd,
  302. )
  303. self.init_values = init_values
  304. self.gamma_1 = nn.Parameter(torch.empty(dim, **dd)) if init_values is not None else None
  305. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  306. self.norm2 = norm_layer(dim, **dd)
  307. hidden_features = int(dim * mlp_ratio)
  308. if swiglu_mlp:
  309. if scale_mlp or swiglu_align_to:
  310. # when norm in SwiGLU used or alignment enabled, an impl with separate fc for gate & x is used
  311. self.mlp = SwiGLU(
  312. in_features=dim,
  313. hidden_features=hidden_features,
  314. norm_layer=norm_layer if scale_mlp else None,
  315. drop=proj_drop,
  316. align_to=swiglu_align_to,
  317. **dd,
  318. )
  319. else:
  320. # w/o any extra norm, an impl with packed weights is used
  321. self.mlp = GluMlp(
  322. in_features=dim,
  323. hidden_features=hidden_features * 2,
  324. norm_layer=norm_layer if scale_mlp else None,
  325. act_layer=nn.SiLU,
  326. gate_last=False,
  327. drop=proj_drop,
  328. **dd,
  329. )
  330. else:
  331. self.mlp = Mlp(
  332. in_features=dim,
  333. hidden_features=hidden_features,
  334. act_layer=act_layer,
  335. norm_layer=norm_layer if scale_mlp else None,
  336. drop=proj_drop,
  337. **dd,
  338. )
  339. self.gamma_2 = nn.Parameter(torch.empty(dim, **dd)) if init_values is not None else None
  340. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  341. # TODO: skip init when on meta device when safe to do so
  342. self.reset_parameters()
  343. def reset_parameters(self) -> None:
  344. """Initialize parameters."""
  345. if self.gamma_1 is not None:
  346. nn.init.constant_(self.gamma_1, self.init_values)
  347. nn.init.constant_(self.gamma_2, self.init_values)
  348. def forward(
  349. self,
  350. x: torch.Tensor,
  351. rope: Optional[torch.Tensor] = None,
  352. attn_mask: Optional[torch.Tensor] = None,
  353. is_causal: bool = False,
  354. ) -> torch.Tensor:
  355. if self.gamma_1 is None:
  356. x = x + self.drop_path1(self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask, is_causal=is_causal))
  357. x = x + self.drop_path2(self.mlp(self.norm2(x)))
  358. else:
  359. x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask, is_causal=is_causal))
  360. x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x)))
  361. return x
  362. class EvaBlockPostNorm(nn.Module):
  363. """ EVA block w/ post-norm and support for swiglu, MLP norm scale, ROPE. """
  364. def __init__(
  365. self,
  366. dim: int,
  367. num_heads: int,
  368. qkv_bias: bool = True,
  369. qkv_fused: bool = True,
  370. mlp_ratio: float = 4.,
  371. attn_type: str = 'eva',
  372. rotate_half: bool = False,
  373. swiglu_mlp: bool = False,
  374. swiglu_align_to: int = 0,
  375. scale_mlp: bool = False,
  376. scale_attn_inner: bool = False,
  377. num_prefix_tokens: int = 1,
  378. proj_drop: float = 0.,
  379. attn_drop: float = 0.,
  380. drop_path: float = 0.,
  381. init_values: Optional[float] = None, # ignore for post-norm
  382. act_layer: Callable = nn.GELU,
  383. norm_layer: Callable = nn.LayerNorm,
  384. attn_head_dim: Optional[int] = None,
  385. device=None,
  386. dtype=None,
  387. **kwargs,
  388. ):
  389. """ Initialize the post-norm EVA transformer block.
  390. Args:
  391. dim: Input dimension of the token embeddings
  392. num_heads: Number of attention heads
  393. qkv_bias: Whether to use bias terms in query, key, value projections
  394. qkv_fused: Whether to use a single projection for query, key, value
  395. mlp_ratio: Ratio of MLP hidden dimension to input dimension
  396. swiglu_mlp: Whether to use SwiGLU activation in the MLP
  397. scale_mlp: Whether to use normalization in the MLP
  398. scale_attn_inner: Whether to use normalization within the attention mechanism
  399. num_prefix_tokens: Number of tokens at the beginning of the sequence (class tokens, etc.)
  400. attn_type: Type of attention module to use ('eva' or 'rope')
  401. proj_drop: Dropout rate for projection layers
  402. attn_drop: Dropout rate for attention matrix
  403. drop_path: Stochastic depth rate
  404. init_values: Initial value for LayerScale, None = no LayerScale (NOTE: ignored for post-norm block)
  405. act_layer: Activation layer constructor
  406. norm_layer: Normalization layer constructor
  407. attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
  408. """
  409. dd = {'device': device, 'dtype': dtype}
  410. super().__init__()
  411. attn_cls = AttentionRope if attn_type == 'rope' else EvaAttention
  412. self.attn = attn_cls(
  413. dim,
  414. num_heads=num_heads,
  415. qkv_bias=qkv_bias,
  416. qkv_fused=qkv_fused,
  417. num_prefix_tokens=num_prefix_tokens,
  418. attn_drop=attn_drop,
  419. proj_drop=proj_drop,
  420. attn_head_dim=attn_head_dim,
  421. norm_layer=norm_layer,
  422. scale_norm=scale_attn_inner,
  423. rotate_half=rotate_half,
  424. **dd,
  425. )
  426. self.norm1 = norm_layer(dim, **dd)
  427. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  428. hidden_features = int(dim * mlp_ratio)
  429. if swiglu_mlp:
  430. if scale_mlp:
  431. # when norm in SwiGLU used, an impl with separate fc for gate & x is used
  432. self.mlp = SwiGLU(
  433. in_features=dim,
  434. hidden_features=hidden_features,
  435. norm_layer=norm_layer if scale_mlp else None,
  436. drop=proj_drop,
  437. align_to=swiglu_align_to,
  438. **dd,
  439. )
  440. else:
  441. # w/o any extra norm, an impl with packed fc1 weights is used, matches existing GluMLP
  442. self.mlp = GluMlp(
  443. in_features=dim,
  444. hidden_features=hidden_features * 2,
  445. norm_layer=norm_layer if scale_mlp else None,
  446. act_layer=nn.SiLU,
  447. gate_last=False,
  448. drop=proj_drop,
  449. **dd,
  450. )
  451. else:
  452. self.mlp = Mlp(
  453. in_features=dim,
  454. hidden_features=hidden_features,
  455. act_layer=act_layer,
  456. norm_layer=norm_layer if scale_mlp else None,
  457. drop=proj_drop,
  458. **dd,
  459. )
  460. self.norm2 = norm_layer(dim, **dd)
  461. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  462. def forward(
  463. self,
  464. x: torch.Tensor,
  465. rope: Optional[torch.Tensor] = None,
  466. attn_mask: Optional[torch.Tensor] = None,
  467. is_causal: bool = False,
  468. ) -> torch.Tensor:
  469. x = x + self.drop_path1(self.norm1(self.attn(x, rope=rope, attn_mask=attn_mask, is_causal=is_causal)))
  470. x = x + self.drop_path2(self.norm2(self.mlp(x)))
  471. return x
  472. class Eva(nn.Module):
  473. """ Eva Vision Transformer w/ Abs & Rotary Pos Embed
  474. This class implements the EVA and EVA02 models that were based on the BEiT ViT variant
  475. * EVA - abs pos embed, global avg pool
  476. * EVA02 - abs + rope pos embed, global avg pool, SwiGLU, scale Norm in MLP (ala normformer)
  477. """
  478. def __init__(
  479. self,
  480. img_size: Union[int, Tuple[int, int]] = 224,
  481. patch_size: Union[int, Tuple[int, int]] = 16,
  482. in_chans: int = 3,
  483. num_classes: int = 1000,
  484. global_pool: str = 'avg',
  485. embed_dim: int = 768,
  486. depth: int = 12,
  487. num_heads: int = 12,
  488. qkv_bias: bool = True,
  489. qkv_fused: bool = True,
  490. mlp_ratio: float = 4.,
  491. swiglu_mlp: bool = False,
  492. swiglu_align_to: int = 0,
  493. scale_mlp: bool = False,
  494. scale_attn_inner: bool = False,
  495. attn_type: str = 'eva',
  496. drop_rate: float = 0.,
  497. pos_drop_rate: float = 0.,
  498. patch_drop_rate: float = 0.,
  499. proj_drop_rate: float = 0.,
  500. attn_drop_rate: float = 0.,
  501. drop_path_rate: float = 0.,
  502. norm_layer: Callable = LayerNorm,
  503. init_values: Optional[float] = None,
  504. class_token: bool = True,
  505. num_reg_tokens: int = 0,
  506. no_embed_class: bool = False,
  507. use_abs_pos_emb: bool = True,
  508. use_rot_pos_emb: bool = False,
  509. rope_type: Optional[str] = 'cat',
  510. rope_grid_offset: float = 0.,
  511. rope_grid_indexing: str = 'ij',
  512. rope_temperature: float = 10000.,
  513. rope_rotate_half: bool = False,
  514. use_post_norm: bool = False,
  515. use_pre_transformer_norm: bool = False,
  516. use_post_transformer_norm: Optional[bool] = None,
  517. use_fc_norm: Optional[bool] = None,
  518. attn_pool_num_heads: Optional[int] = None,
  519. attn_pool_mlp_ratio: Optional[float] = None,
  520. dynamic_img_size: bool = False,
  521. dynamic_img_pad: bool = False,
  522. ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
  523. head_init_scale: float = 0.001,
  524. device=None,
  525. dtype=None,
  526. ):
  527. """Initialize the EVA Vision Transformer model.
  528. Args:
  529. img_size: Input image size (single int for square, or tuple for rectangular)
  530. patch_size: Patch size to divide image into tokens (single int for square, or tuple)
  531. in_chans: Number of input image channels
  532. num_classes: Number of classes (output dim) for classification head (final projection), 0 for pass-through
  533. global_pool: Type of global pooling for final sequence ('avg', 'token', 'map', etc.)
  534. embed_dim: Embedding dimension for tokens
  535. depth: Number of transformer blocks
  536. num_heads: Number of attention heads
  537. qkv_bias: Enable bias for query, key, value projections
  538. qkv_fused: Use a single projection for query, key, value
  539. mlp_ratio: Ratio of mlp hidden dim to embedding dim
  540. swiglu_mlp: Use SwiGLU activation in MLP
  541. scale_mlp: Apply scaling normalization in MLP (normformer style)
  542. scale_attn_inner: Apply scaling normalization inside attention
  543. attn_type: Type of attention module to use
  544. drop_rate: Dropout rate after final projection and pooling
  545. pos_drop_rate: Dropout rate for positional embeddings
  546. patch_drop_rate: Rate of dropping patches during training
  547. proj_drop_rate: Dropout rate for projections
  548. attn_drop_rate: Dropout rate for attention
  549. drop_path_rate: Stochastic depth rate
  550. norm_layer: Normalization layer constructor
  551. init_values: Initial layer-scale values
  552. class_token: Use class token
  553. num_reg_tokens: Number of additional learnable 'register' tokens to add to the sequence
  554. no_embed_class: Don't include position embeddings for class (or reg) tokens
  555. use_abs_pos_emb: Use absolute (learned) positional embeddings
  556. use_rot_pos_emb: Use rotary position embeddings
  557. rope_type: Type of RoPE to use ('cat', 'mixed', 'dinov3', etc.).
  558. rope_grid_offset: Offset for rotary position embedding grid
  559. rope_grid_indexing: Indexing mode for rotary position embeddings ('ij' or 'xy')
  560. rope_temperature: Temperature parameter for ROPE frequency computation
  561. rope_rotate_half: Use half rotation layout (rotate D/2 dims), else use interleaved rotation layout
  562. use_post_norm: Use post-norm transformer block type
  563. use_pre_transformer_norm: Use normalization layer before transformer blocks
  564. use_post_transformer_norm: Use normalization layer after transformer blocks
  565. use_fc_norm: Use normalization layer after pooling, before final classifier
  566. attn_pool_num_heads: Number of heads in attention pooling
  567. attn_pool_mlp_ratio: MLP ratio in attention pooling
  568. dynamic_img_size: Support dynamic image sizes in forward pass
  569. dynamic_img_pad: Apply dynamic padding for irregular image sizes
  570. ref_feat_shape: Reference feature shape for rotary position embedding scale
  571. head_init_scale: Initialization scale for classification head weights
  572. """
  573. super().__init__()
  574. dd = {'device': device, 'dtype': dtype}
  575. assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
  576. self.num_classes = num_classes
  577. self.in_chans = in_chans
  578. self.global_pool = global_pool
  579. self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
  580. self.num_prefix_tokens = (1 if class_token else 0) + num_reg_tokens
  581. self.no_embed_class = no_embed_class
  582. self.dynamic_img_size = dynamic_img_size
  583. self.grad_checkpointing = False
  584. # resolve norm / pool usage
  585. activate_pre_norm = use_pre_transformer_norm
  586. if use_fc_norm is not None:
  587. activate_fc_norm = use_fc_norm # pass through if explicit
  588. else:
  589. activate_fc_norm = global_pool == 'avg' # default on if avg pool used
  590. if use_post_transformer_norm is not None:
  591. activate_post_norm = use_post_transformer_norm # pass through if explicit
  592. else:
  593. activate_post_norm = not activate_fc_norm # default on if fc_norm isn't active
  594. embed_args = {}
  595. if dynamic_img_size:
  596. # flatten deferred until after pos embed
  597. embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
  598. self.patch_embed = PatchEmbed(
  599. img_size=img_size,
  600. patch_size=patch_size,
  601. in_chans=in_chans,
  602. embed_dim=embed_dim,
  603. dynamic_img_pad=dynamic_img_pad,
  604. bias=not use_pre_transformer_norm,
  605. **embed_args,
  606. **dd,
  607. )
  608. num_patches = self.patch_embed.num_patches
  609. r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
  610. self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, **dd)) if class_token else None
  611. self.reg_token = nn.Parameter(torch.empty(1, num_reg_tokens, embed_dim, **dd)) if num_reg_tokens else None
  612. self.cls_embed = class_token and self.reg_token is None
  613. num_pos_tokens = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
  614. self.pos_embed = nn.Parameter(torch.empty(1, num_pos_tokens, embed_dim, **dd)) if use_abs_pos_emb else None
  615. self.pos_drop = nn.Dropout(p=pos_drop_rate)
  616. if patch_drop_rate > 0:
  617. self.patch_drop = PatchDropoutWithIndices(patch_drop_rate, num_prefix_tokens=self.num_prefix_tokens)
  618. else:
  619. self.patch_drop = None
  620. self.rope_mixed = False
  621. if use_rot_pos_emb:
  622. ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None
  623. # Setup RoPE kwargs
  624. rope_kwargs = dict(
  625. dim=embed_dim,
  626. num_heads=num_heads,
  627. feat_shape=None if dynamic_img_size else self.patch_embed.grid_size,
  628. temperature=rope_temperature,
  629. grid_indexing=rope_grid_indexing,
  630. **dd,
  631. )
  632. if rope_type == 'mixed':
  633. rope_kwargs.update(dict(depth=depth))
  634. self.rope_mixed = True
  635. elif rope_type == 'cat':
  636. rope_kwargs.update(dict(
  637. in_pixels=False,
  638. grid_offset=rope_grid_offset,
  639. ref_feat_shape=ref_feat_shape,
  640. ))
  641. self.rope = create_rope_embed(rope_type=rope_type, **rope_kwargs)
  642. else:
  643. self.rope = None
  644. self.norm_pre = norm_layer(embed_dim, **dd) if activate_pre_norm else nn.Identity()
  645. dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule
  646. block_fn = EvaBlockPostNorm if use_post_norm else EvaBlock
  647. self.blocks = nn.ModuleList([
  648. block_fn(
  649. dim=embed_dim,
  650. num_heads=num_heads,
  651. qkv_bias=qkv_bias,
  652. qkv_fused=qkv_fused,
  653. mlp_ratio=mlp_ratio,
  654. swiglu_mlp=swiglu_mlp,
  655. swiglu_align_to=swiglu_align_to,
  656. scale_mlp=scale_mlp,
  657. scale_attn_inner=scale_attn_inner,
  658. attn_type=attn_type,
  659. rotate_half=rope_rotate_half,
  660. num_prefix_tokens=self.num_prefix_tokens,
  661. proj_drop=proj_drop_rate,
  662. attn_drop=attn_drop_rate,
  663. drop_path=dpr[i],
  664. norm_layer=norm_layer,
  665. init_values=init_values,
  666. **dd,
  667. )
  668. for i in range(depth)])
  669. self.feature_info = [
  670. dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
  671. self.norm = norm_layer(embed_dim, **dd) if activate_post_norm else nn.Identity()
  672. if global_pool == 'map':
  673. self.attn_pool = AttentionPoolLatent(
  674. self.embed_dim,
  675. num_heads=attn_pool_num_heads or num_heads,
  676. mlp_ratio=attn_pool_mlp_ratio or mlp_ratio,
  677. norm_layer=norm_layer,
  678. act_layer=nn.GELU,
  679. **dd,
  680. )
  681. else:
  682. self.attn_pool = None
  683. self.fc_norm = norm_layer(embed_dim, **dd) if activate_fc_norm else nn.Identity()
  684. self.head_drop = nn.Dropout(drop_rate)
  685. self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()
  686. self.head_init_scale = head_init_scale
  687. # TODO: skip init when on meta device when safe to do so
  688. self.init_weights(needs_reset=False)
  689. def init_weights(self, needs_reset: bool = True):
  690. self.apply(partial(self._init_weights, needs_reset=needs_reset))
  691. if self.pos_embed is not None:
  692. trunc_normal_(self.pos_embed, std=.02)
  693. if self.cls_token is not None:
  694. trunc_normal_(self.cls_token, std=.02)
  695. if self.reg_token is not None:
  696. trunc_normal_(self.reg_token, std=.02)
  697. self.fix_init_weight()
  698. if self.head_init_scale and isinstance(self.head, nn.Linear):
  699. trunc_normal_(self.head.weight, std=.02)
  700. with torch.no_grad():
  701. self.head.weight.mul_(self.head_init_scale)
  702. self.head.bias.mul_(self.head_init_scale)
  703. def fix_init_weight(self) -> None:
  704. """Fix initialization weights by rescaling based on layer depth."""
  705. with torch.no_grad():
  706. for layer_id, layer in enumerate(self.blocks):
  707. scale = math.sqrt(2.0 * (layer_id + 1))
  708. layer.attn.proj.weight.div_(scale)
  709. layer.mlp.fc2.weight.div_(scale)
  710. def _init_weights(self, m: nn.Module, needs_reset: bool = True) -> None:
  711. """Initialize weights for Linear layers and call reset_parameters on modules.
  712. Args:
  713. m: Module to initialize.
  714. needs_reset: Whether to call reset_parameters() on modules.
  715. """
  716. if isinstance(m, nn.Linear):
  717. trunc_normal_(m.weight, std=.02)
  718. if m.bias is not None:
  719. nn.init.zeros_(m.bias)
  720. elif needs_reset and hasattr(m, 'reset_parameters') and m is not self:
  721. m.reset_parameters()
  722. @torch.jit.ignore
  723. def no_weight_decay(self) -> Set[str]:
  724. """Parameters to exclude from weight decay."""
  725. nwd = {'pos_embed', 'cls_token'}
  726. if (rope := getattr(self, "rope", None)) and hasattr(rope, "no_weight_decay"):
  727. return nwd | {f"rope.{p}" for p in rope.no_weight_decay()}
  728. return nwd
  729. @torch.jit.ignore
  730. def set_grad_checkpointing(self, enable: bool = True) -> None:
  731. """Enable or disable gradient checkpointing."""
  732. self.grad_checkpointing = enable
  733. @torch.jit.ignore
  734. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  735. """Create layer groupings for optimization."""
  736. matcher = dict(
  737. stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
  738. blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))],
  739. )
  740. return matcher
  741. @torch.jit.ignore
  742. def get_classifier(self) -> nn.Module:
  743. return self.head
  744. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
  745. """Reset the classifier head.
  746. Args:
  747. num_classes: Number of output classes.
  748. global_pool: Global pooling type.
  749. """
  750. self.num_classes = num_classes
  751. if global_pool is not None:
  752. self.global_pool = global_pool
  753. self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
  754. def set_input_size(
  755. self,
  756. img_size: Optional[Tuple[int, int]] = None,
  757. patch_size: Optional[Tuple[int, int]] = None,
  758. ) -> None:
  759. """Update the input image resolution and patch size.
  760. Args:
  761. img_size: New input resolution, if None current resolution is used.
  762. patch_size: New patch size, if None existing patch size is used.
  763. """
  764. prev_grid_size = self.patch_embed.grid_size
  765. self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
  766. if self.pos_embed is not None:
  767. num_prefix_tokens = 0 if self.no_embed_class else self.num_prefix_tokens
  768. num_new_tokens = self.patch_embed.num_patches + num_prefix_tokens
  769. if num_new_tokens != self.pos_embed.shape[1]:
  770. self.pos_embed = nn.Parameter(resample_abs_pos_embed(
  771. self.pos_embed,
  772. new_size=self.patch_embed.grid_size,
  773. old_size=prev_grid_size,
  774. num_prefix_tokens=num_prefix_tokens,
  775. verbose=True,
  776. ))
  777. if self.rope is not None:
  778. self.rope.update_feat_shape(self.patch_embed.grid_size)
  779. def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  780. if self.dynamic_img_size:
  781. B, H, W, C = x.shape
  782. if self.pos_embed is not None:
  783. prev_grid_size = self.patch_embed.grid_size
  784. pos_embed = resample_abs_pos_embed(
  785. self.pos_embed,
  786. new_size=(H, W),
  787. old_size=prev_grid_size,
  788. num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
  789. )
  790. else:
  791. pos_embed = None
  792. x = x.view(B, -1, C)
  793. rot_pos_embed = self.rope.get_embed(shape=(H, W)) if self.rope is not None else None
  794. else:
  795. pos_embed = self.pos_embed
  796. rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
  797. to_cat = []
  798. if self.cls_token is not None:
  799. to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
  800. if self.reg_token is not None:
  801. to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
  802. if self.no_embed_class:
  803. # position embedding does not overlap with class / reg token
  804. if pos_embed is not None:
  805. x = x + pos_embed
  806. if to_cat:
  807. x = torch.cat(to_cat + [x], dim=1)
  808. else:
  809. # pos_embed has entry for class / reg token, concat then add
  810. if to_cat:
  811. x = torch.cat(to_cat + [x], dim=1)
  812. if pos_embed is not None:
  813. x = x + pos_embed
  814. x = self.pos_drop(x)
  815. # apply patch dropout to patches and rotary position embedding
  816. if self.patch_drop is not None:
  817. x, keep_indices = self.patch_drop(x)
  818. if rot_pos_embed is not None and keep_indices is not None:
  819. rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
  820. # After applying keep indices to rope embeds, batch dim is added
  821. if getattr(self, 'rope_mixed', False):
  822. # B, D, nH, N, dim -> D, B, nH, N, dim. For consistent iteration over depth at index 0.
  823. rot_pos_embed = rot_pos_embed.transpose(0, 1)
  824. else:
  825. # B, N, dim -> B, 1, N, dim. Need head dim singleton for correct dim alignment in axial mode.
  826. rot_pos_embed = rot_pos_embed.unsqueeze(1)
  827. return x, rot_pos_embed
  828. def forward_intermediates(
  829. self,
  830. x: torch.Tensor,
  831. indices: Optional[Union[int, List[int]]] = None,
  832. return_prefix_tokens: bool = False,
  833. norm: bool = False,
  834. stop_early: bool = False,
  835. output_fmt: str = 'NCHW',
  836. intermediates_only: bool = False,
  837. attn_mask: Optional[torch.Tensor] = None,
  838. is_causal: bool = False,
  839. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  840. """ Forward features that returns intermediates.
  841. Args:
  842. x: Input image tensor
  843. indices: Take last n blocks if an int, if is a sequence, select by matching indices
  844. return_prefix_tokens: Return both prefix and spatial intermediate tokens
  845. norm: Apply norm layer to all intermediates
  846. stop_early: Stop iterating over blocks when last desired intermediate hit
  847. output_fmt: Shape of intermediate feature outputs
  848. intermediates_only: Only return intermediate features
  849. attn_mask: Optional attention mask for masked attention
  850. is_causal: If True, use causal (autoregressive) masking in attention
  851. """
  852. assert output_fmt in ('NCHW', 'NLC'), 'Output format for EVA-ViT features must be one of NCHW or NLC.'
  853. reshape = output_fmt == 'NCHW'
  854. intermediates = []
  855. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  856. # forward pass
  857. B, _, height, width = x.shape
  858. x = self.patch_embed(x)
  859. x, rot_pos_embed = self._pos_embed(x)
  860. x = self.norm_pre(x)
  861. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  862. blocks = self.blocks
  863. else:
  864. blocks = self.blocks[:max_index + 1]
  865. # Handle depth-dependent embeddings for mixed mode
  866. if getattr(self, 'rope_mixed', False) and rot_pos_embed is not None:
  867. for i, blk in enumerate(blocks):
  868. if self.grad_checkpointing and not torch.jit.is_scripting():
  869. x = checkpoint(blk, x, rope=rot_pos_embed[i], attn_mask=attn_mask, is_causal=is_causal)
  870. else:
  871. x = blk(x, rope=rot_pos_embed[i], attn_mask=attn_mask, is_causal=is_causal)
  872. if i in take_indices:
  873. intermediates.append(self.norm(x) if norm else x)
  874. else:
  875. for i, blk in enumerate(blocks):
  876. if self.grad_checkpointing and not torch.jit.is_scripting():
  877. x = checkpoint(blk, x, rope=rot_pos_embed, attn_mask=attn_mask, is_causal=is_causal)
  878. else:
  879. x = blk(x, rope=rot_pos_embed, attn_mask=attn_mask, is_causal=is_causal)
  880. if i in take_indices:
  881. intermediates.append(self.norm(x) if norm else x)
  882. # process intermediates
  883. if self.num_prefix_tokens:
  884. # split prefix (e.g. class, distill) and spatial feature tokens
  885. prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
  886. intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
  887. if reshape:
  888. # reshape to BCHW output format
  889. H, W = self.patch_embed.dynamic_feat_size((height, width))
  890. intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
  891. if not torch.jit.is_scripting() and return_prefix_tokens:
  892. # return_prefix not support in torchscript due to poor type handling
  893. intermediates = list(zip(intermediates, prefix_tokens))
  894. if intermediates_only:
  895. return intermediates
  896. x = self.norm(x)
  897. return x, intermediates
  898. def prune_intermediate_layers(
  899. self,
  900. indices: Union[int, List[int]] = 1,
  901. prune_norm: bool = False,
  902. prune_head: bool = True,
  903. ):
  904. """ Prune layers not required for specified intermediates.
  905. """
  906. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  907. self.blocks = self.blocks[:max_index + 1] # truncate blocks
  908. if prune_norm:
  909. self.norm = nn.Identity()
  910. if prune_head:
  911. self.attn_pool = None
  912. self.fc_norm = nn.Identity()
  913. self.reset_classifier(0, '')
  914. return take_indices
  915. def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
  916. if self.attn_pool is not None:
  917. x = self.attn_pool(x)
  918. return x
  919. pool_type = self.global_pool if pool_type is None else pool_type
  920. x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
  921. return x
  922. def forward_features(
  923. self,
  924. x: torch.Tensor,
  925. attn_mask: Optional[torch.Tensor] = None,
  926. is_causal: bool = False,
  927. ) -> torch.Tensor:
  928. """Forward pass through feature extraction layers.
  929. Args:
  930. x: Input tensor.
  931. attn_mask: Optional attention mask for masked attention
  932. is_causal: If True, use causal (autoregressive) masking in attention.
  933. Returns:
  934. Feature tensor.
  935. """
  936. x = self.patch_embed(x)
  937. x, rot_pos_embed = self._pos_embed(x)
  938. x = self.norm_pre(x)
  939. if getattr(self, 'rope_mixed', False) and rot_pos_embed is not None:
  940. # Handle depth-dependent embeddings for mixed mode
  941. # pos embed has shape (depth, num_heads, H*W, dim) or (depth, batch_size, num_heads, H*W, dim)
  942. for i, blk in enumerate(self.blocks):
  943. if self.grad_checkpointing and not torch.jit.is_scripting():
  944. x = checkpoint(blk, x, rope=rot_pos_embed[i], attn_mask=attn_mask, is_causal=is_causal)
  945. else:
  946. x = blk(x, rope=rot_pos_embed[i], attn_mask=attn_mask, is_causal=is_causal)
  947. else:
  948. # Standard path for non-mixed mode
  949. for blk in self.blocks:
  950. if self.grad_checkpointing and not torch.jit.is_scripting():
  951. x = checkpoint(blk, x, rope=rot_pos_embed, attn_mask=attn_mask, is_causal=is_causal)
  952. else:
  953. x = blk(x, rope=rot_pos_embed, attn_mask=attn_mask, is_causal=is_causal)
  954. x = self.norm(x)
  955. return x
  956. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  957. """Forward pass through classifier head.
  958. Args:
  959. x: Feature tensor.
  960. pre_logits: Return pre-logits if True.
  961. Returns:
  962. Output tensor.
  963. """
  964. x = self.pool(x)
  965. x = self.fc_norm(x)
  966. x = self.head_drop(x)
  967. return x if pre_logits else self.head(x)
  968. def forward(
  969. self,
  970. x: torch.Tensor,
  971. attn_mask: Optional[torch.Tensor] = None,
  972. is_causal: bool = False,
  973. ) -> torch.Tensor:
  974. """Forward pass.
  975. Args:
  976. x: Input tensor.
  977. attn_mask: Optional attention mask for masked attention
  978. is_causal: If True, use causal (autoregressive) masking in attention.
  979. Returns:
  980. Output tensor.
  981. """
  982. x = self.forward_features(x, attn_mask=attn_mask, is_causal=is_causal)
  983. x = self.forward_head(x)
  984. return x
  985. def _convert_pe(
  986. state_dict: Dict[str, torch.Tensor],
  987. model: nn.Module,
  988. prefix: str = 'visual.',
  989. ) -> Dict[str, torch.Tensor]:
  990. """Convert Perception Encoder weights.
  991. Args:
  992. state_dict: State dictionary to convert.
  993. model: Target model instance.
  994. prefix: Prefix to strip from keys.
  995. Returns:
  996. Converted state dictionary.
  997. """
  998. state_dict = state_dict.get('model', state_dict)
  999. state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
  1000. out_dict = {}
  1001. swaps = [
  1002. ('conv1', 'patch_embed.proj'),
  1003. ('positional_embedding', 'pos_embed'),
  1004. ('transformer.resblocks.', 'blocks.'),
  1005. ('ln_pre', 'norm_pre'),
  1006. ('ln_post', 'norm'),
  1007. ('ln_', 'norm'),
  1008. ('ls_1.gamma', 'gamma_1'),
  1009. ('ls_2.gamma', 'gamma_2'),
  1010. ('in_proj_', 'qkv.'),
  1011. ('out_proj', 'proj'),
  1012. ('mlp.c_fc', 'mlp.fc1'),
  1013. ('mlp.c_proj', 'mlp.fc2'),
  1014. ]
  1015. len_prefix = len(prefix)
  1016. for k, v in state_dict.items():
  1017. if prefix:
  1018. if not k.startswith(prefix):
  1019. continue
  1020. k = k[len_prefix:]
  1021. for sp in swaps:
  1022. k = k.replace(sp[0], sp[1])
  1023. if k.startswith('attn_pool'):
  1024. k = k.replace('attn_pool.attn', 'attn_pool')
  1025. k = k.replace('attn_pool.layernorm', 'attn_pool.norm')
  1026. k = k.replace('attn_pool.probe', 'attn_pool.latent')
  1027. if k.startswith('attn_pool.qkv'):
  1028. dim = v.shape[0] // 3
  1029. if k.endswith('weight'):
  1030. out_dict['attn_pool.q.weight'] = v[:dim]
  1031. out_dict['attn_pool.kv.weight'] = v[dim:]
  1032. elif k.endswith('bias'):
  1033. out_dict['attn_pool.q.bias'] = v[:dim]
  1034. out_dict['attn_pool.kv.bias'] = v[dim:]
  1035. continue
  1036. elif k == 'proj':
  1037. k = 'head.weight'
  1038. v = v.transpose(0, 1)
  1039. out_dict['head.bias'] = torch.zeros(v.shape[0])
  1040. elif k == 'class_embedding':
  1041. k = 'cls_token'
  1042. v = v.unsqueeze(0).unsqueeze(1)
  1043. elif k == 'pos_embed':
  1044. v = v.unsqueeze(0)
  1045. out_dict[k] = v
  1046. return out_dict
  1047. def checkpoint_filter_fn(
  1048. state_dict: Dict[str, torch.Tensor],
  1049. model: nn.Module,
  1050. interpolation: str = 'bicubic',
  1051. antialias: bool = True,
  1052. ) -> Dict[str, torch.Tensor]:
  1053. """Convert patch embedding weight from manual patchify + linear proj to conv.
  1054. Args:
  1055. state_dict: Checkpoint state dictionary.
  1056. model: Target model instance.
  1057. interpolation: Interpolation method for resizing.
  1058. antialias: Whether to use antialiasing when resizing.
  1059. Returns:
  1060. Filtered state dictionary.
  1061. """
  1062. out_dict = {}
  1063. # Standard EVA checkpoint processing
  1064. state_dict = state_dict.get('model_ema', state_dict)
  1065. state_dict = state_dict.get('model', state_dict)
  1066. state_dict = state_dict.get('module', state_dict)
  1067. state_dict = state_dict.get('state_dict', state_dict)
  1068. # Loading Meta PE (Perception Encoder) weights
  1069. if 'visual.conv1.weight' in state_dict:
  1070. return _convert_pe(state_dict, model)
  1071. elif 'conv1.weight' in state_dict:
  1072. return _convert_pe(state_dict, model, prefix='')
  1073. # prefix for loading OpenCLIP compatible weights
  1074. if 'visual.trunk.pos_embed' in state_dict:
  1075. prefix = 'visual.trunk.'
  1076. elif 'visual.pos_embed' in state_dict:
  1077. prefix = 'visual.'
  1078. else:
  1079. prefix = ''
  1080. dinov3_weights = 'storage_tokens' in state_dict
  1081. mim_weights = not dinov3_weights and prefix + 'mask_token' in state_dict
  1082. no_qkv = prefix + 'blocks.0.attn.q_proj.weight' in state_dict
  1083. len_prefix = len(prefix)
  1084. for k, v in state_dict.items():
  1085. if prefix:
  1086. if not k.startswith(prefix):
  1087. continue
  1088. k = k[len_prefix:]
  1089. if 'rope' in k and not k == 'rope.freqs':
  1090. # fixed embedding no need to load buffer from checkpoint
  1091. continue
  1092. if dinov3_weights:
  1093. if any([k.endswith(f) for f in ['.periods', '.bias_mask', 'mask_token']]):
  1094. # discard unused/non-persistent/pretrain only params
  1095. continue
  1096. if k.startswith('local_cls_norm'):
  1097. # discard, only used for 7b dinov3 pretrain w/ local crops
  1098. continue
  1099. if k.endswith('qkv.bias'):
  1100. q_bias_k = k.replace('qkv.bias', 'q_bias')
  1101. try:
  1102. # the distilled b,l,h models ended up with all zero biases, so timm
  1103. # has both qkv_bias=True and qkv_bias=False impl, test which
  1104. model.get_parameter(q_bias_k)
  1105. except Exception as e:
  1106. print(e)
  1107. # skip as target model has no bias parameter
  1108. continue
  1109. # split bias into components and skip the k as its supposed to be fixed at 0
  1110. qv, kv, vv = v.chunk(3, dim=-1)
  1111. out_dict[q_bias_k] = qv
  1112. out_dict[k.replace('qkv.bias', 'v_bias')] = vv
  1113. continue
  1114. k = k.replace('ls1.gamma', 'gamma_1') # match EVA ls naming
  1115. k = k.replace('ls2.gamma', 'gamma_2') # match EVA ls naming
  1116. k = k.replace('storage_tokens', 'reg_token') # rename storage to existing register naming
  1117. elif mim_weights and k in ('mask_token', 'lm_head.weight', 'lm_head.bias', 'norm.weight', 'norm.bias'):
  1118. if k == 'norm.weight' or k == 'norm.bias':
  1119. # try moving norm -> fc norm on fine-tune, probably a better starting point than new init
  1120. k = k.replace('norm', 'fc_norm')
  1121. else:
  1122. # skip pretrain mask token & head weights
  1123. continue
  1124. if 'patch_embed.proj.weight' in k:
  1125. _, _, H, W = model.patch_embed.proj.weight.shape
  1126. if v.shape[-1] != W or v.shape[-2] != H:
  1127. v = resample_patch_embed(
  1128. v,
  1129. (H, W),
  1130. interpolation=interpolation,
  1131. antialias=antialias,
  1132. verbose=True,
  1133. )
  1134. elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
  1135. # To resize pos embedding when using model at different size from pretrained weights
  1136. num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
  1137. v = resample_abs_pos_embed(
  1138. v,
  1139. new_size=model.patch_embed.grid_size,
  1140. num_prefix_tokens=num_prefix_tokens,
  1141. interpolation=interpolation,
  1142. antialias=antialias,
  1143. verbose=True,
  1144. )
  1145. k = k.replace('mlp.ffn_ln', 'mlp.norm')
  1146. k = k.replace('attn.inner_attn_ln', 'attn.norm')
  1147. k = k.replace('mlp.w12', 'mlp.fc1')
  1148. k = k.replace('mlp.w1', 'mlp.fc1_g')
  1149. k = k.replace('mlp.w2', 'mlp.fc1_x')
  1150. k = k.replace('mlp.w3', 'mlp.fc2')
  1151. if no_qkv:
  1152. k = k.replace('q_bias', 'q_proj.bias')
  1153. k = k.replace('v_bias', 'v_proj.bias')
  1154. out_dict[k] = v
  1155. return out_dict
  1156. def _create_eva(variant: str, pretrained: bool = False, **kwargs) -> Eva:
  1157. """Create an EVA model.
  1158. Args:
  1159. variant: Model variant name.
  1160. pretrained: Load pretrained weights.
  1161. **kwargs: Additional model arguments.
  1162. Returns:
  1163. Instantiated Eva model.
  1164. """
  1165. # Check if we should use NaFlexVit implementation
  1166. use_naflex = kwargs.pop('use_naflex', None)
  1167. _USE_NAFLEX_DEFAULT = os.environ.get('TIMM_USE_NAFLEX', '0') == '1'
  1168. if use_naflex is None:
  1169. use_naflex = _USE_NAFLEX_DEFAULT
  1170. if use_naflex:
  1171. # Import here to avoid circular import
  1172. from .naflexvit import _create_naflexvit_from_eva
  1173. return _create_naflexvit_from_eva(variant, pretrained, **kwargs)
  1174. out_indices = kwargs.pop('out_indices', 3)
  1175. model = build_model_with_cfg(
  1176. Eva, variant, pretrained,
  1177. pretrained_filter_fn=checkpoint_filter_fn,
  1178. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  1179. **kwargs,
  1180. )
  1181. return model
  1182. def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
  1183. """Generate default configuration for EVA models.
  1184. Args:
  1185. url: Model weights URL.
  1186. **kwargs: Additional configuration parameters.
  1187. Returns:
  1188. Model configuration dictionary.
  1189. """
  1190. return {
  1191. 'url': url,
  1192. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  1193. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  1194. 'mean': OPENAI_CLIP_MEAN, 'std': OPENAI_CLIP_STD,
  1195. 'first_conv': 'patch_embed.proj', 'classifier': 'head',
  1196. 'license': 'mit', **kwargs
  1197. }
  1198. def _pe_cfg(url: str = '', **kwargs) -> Dict[str, Any]:
  1199. """Generate default configuration for Perception Encoder models.
  1200. Args:
  1201. url: Model weights URL.
  1202. **kwargs: Additional configuration parameters.
  1203. Returns:
  1204. Model configuration dictionary.
  1205. """
  1206. return {
  1207. 'url': url,
  1208. 'num_classes': 0, 'input_size': (3, 224, 224), 'pool_size': None,
  1209. 'crop_pct': 1.0, 'interpolation': 'bicubic', 'fixed_input_size': True,
  1210. 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
  1211. 'first_conv': 'patch_embed.proj', 'classifier': 'head',
  1212. 'license': 'apache-2.0', **kwargs
  1213. }
  def _dinov3_cfg(url: str = '', **kwargs) -> Dict[str, Any]:
      """Build a pretrained-weight config carrying the DINOv3 defaults.

      Defaults:
      - no classifier classes (num_classes=0);
      - 256x256 bicubic input with crop_pct 1.0;
      - ImageNet mean/std normalization;
      - the 'dinov3-license' license tag.

      Any keyword passed in replaces the matching default or adds a new entry.

      Note: Original DINOv3 uses CLS-token pooling for representations. timm defaults to avg
      pooling for the Eva architecture. Pass global_pool='token' at model creation to match
      upstream behavior, which may be preferred for tasks like retrieval and few-shot classification.

      Args:
          url: Model weights URL.
          **kwargs: Additional or overriding configuration entries.

      Returns:
          Model configuration dictionary.
      """
      cfg = dict(
          url=url,
          num_classes=0,
          input_size=(3, 256, 256),
          pool_size=None,
          crop_pct=1.0,
          interpolation='bicubic',
          fixed_input_size=True,
          mean=IMAGENET_DEFAULT_MEAN,
          std=IMAGENET_DEFAULT_STD,
          first_conv='patch_embed.proj',
          classifier='head',
          license='dinov3-license',
      )
      # Caller-supplied entries win; brand-new keys land after the defaults.
      cfg.update(kwargs)
      return cfg
  # Pretrained-weight configs for every architecture registered in this module, keyed by
  # '<architecture>.<pretrained tag>'.
  # NOTE(review): generate_default_cfgs presumably treats the first tag listed for an
  # architecture as that architecture's default, so entry order is likely significant.
  # Confirm before reordering.
  default_cfgs = generate_default_cfgs({
      # EVA 01 CLIP fine-tuned on imagenet-1k
      'eva_giant_patch14_224.clip_ft_in1k': _cfg(
          # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
          hf_hub_id='timm/',
      ),
      'eva_giant_patch14_336.clip_ft_in1k': _cfg(
          # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
          hf_hub_id='timm/',
          input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),

      # MIM EVA 01 pretrain, ft on in22k -> in1k
      'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
          # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
          hf_hub_id='timm/',
          mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
          input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
      'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
          # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
          hf_hub_id='timm/',
          mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
          input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),

      # in22k or m38m MIM pretrain w/ intermediate in22k fine-tune and final in1k fine-tune
      'eva02_base_patch14_448.mim_in22k_ft_in22k_in1k': _cfg(
          # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_B_pt_in21k_medft_in21k_ft_in1k_p14.pt',
          hf_hub_id='timm/',
          input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
      ),
      'eva02_large_patch14_448.mim_in22k_ft_in22k_in1k': _cfg(
          # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_in21k_medft_in21k_ft_in1k_p14.pt',
          hf_hub_id='timm/',
          input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
      ),
      'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k': _cfg(
          hf_hub_id='timm/',
          # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_m38m_medft_in21k_ft_in1k_p14.pt',
          input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
      ),

      # in22k or m38m MIM pretrain w/ in1k fine-tune
      'eva02_tiny_patch14_336.mim_in22k_ft_in1k': _cfg(
          # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_Ti_pt_in21k_ft_in1k_p14.pt',
          hf_hub_id='timm/',
          input_size=(3, 336, 336), crop_pct=1.0,
      ),
      'eva02_small_patch14_336.mim_in22k_ft_in1k': _cfg(
          # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_S_pt_in21k_ft_in1k_p14.pt',
          hf_hub_id='timm/',
          input_size=(3, 336, 336), crop_pct=1.0,
      ),
      'eva02_base_patch14_448.mim_in22k_ft_in1k': _cfg(
          # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_B_pt_in21k_ft_in1k_p14.pt',
          hf_hub_id='timm/',
          input_size=(3, 448, 448), crop_pct=1.0,
      ),
      'eva02_large_patch14_448.mim_in22k_ft_in1k': _cfg(
          # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_L_pt_in21k_ft_in1k_p14.pt',
          hf_hub_id='timm/',
          input_size=(3, 448, 448), crop_pct=1.0,
      ),
      'eva02_large_patch14_448.mim_m38m_ft_in1k': _cfg(
          # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_L_pt_m38m_ft_in1k_p14.pt',
          hf_hub_id='timm/',
          input_size=(3, 448, 448), crop_pct=1.0,
      ),

      # in22k or m38m MIM pretrain w/ in22k fine-tune
      'eva02_base_patch14_448.mim_in22k_ft_in22k': _cfg(
          # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_B_pt_in21k_medft_in21k_p14.pt',
          hf_hub_id='timm/',
          input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
      ),
      'eva02_large_patch14_448.mim_in22k_ft_in22k': _cfg(
          # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_in21k_medft_in21k_p14.pt',
          hf_hub_id='timm/',
          input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
      ),
      'eva02_large_patch14_448.mim_m38m_ft_in22k': _cfg(
          # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_m38m_medft_in21k_p14.pt',
          hf_hub_id='timm/',
          input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
      ),

      # in22k or m38m MIM pretrain (no classifier head, num_classes=0)
      'eva02_tiny_patch14_224.mim_in22k': _cfg(
          # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_Ti_pt_in21k_p14.pt',
          hf_hub_id='timm/',
          num_classes=0,
      ),
      'eva02_small_patch14_224.mim_in22k': _cfg(
          # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_S_pt_in21k_p14.pt',
          hf_hub_id='timm/',
          num_classes=0,
      ),
      'eva02_base_patch14_224.mim_in22k': _cfg(
          # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_B_pt_in21k_p14.pt',
          hf_hub_id='timm/',
          num_classes=0,
      ),
      'eva02_large_patch14_224.mim_in22k': _cfg(
          # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_in21k_p14.pt',
          hf_hub_id='timm/',
          num_classes=0,
      ),
      'eva02_large_patch14_224.mim_m38m': _cfg(
          # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_m38m_p14.pt',
          hf_hub_id='timm/',
          num_classes=0,
      ),

      # EVA01 and EVA02 CLIP image towers; num_classes here is the image embedding/projection width
      'eva_giant_patch14_clip_224.laion400m': _cfg(
          # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA01_CLIP_g_14_plus_psz14_s11B.pt',
          # hf_hub_id='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k',  # float16 weights
          # hf_hub_filename='open_clip_pytorch_model.bin',
          hf_hub_id='timm/',
          num_classes=1024,
      ),
      'eva_giant_patch14_clip_224.merged2b': _cfg(
          # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA01_CLIP_g_14_plus_psz14_s11B.pt',
          # hf_hub_id='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k',  # float16 weights
          # hf_hub_filename='open_clip_pytorch_model.bin',
          hf_hub_id='timm/',
          num_classes=1024,
      ),
      'eva02_base_patch16_clip_224.merged2b': _cfg(
          # NOTE(review): the upstream filename below names the L/14 checkpoint, which looks
          # copy-pasted from the large entry. Confirm the actual B/16 source file.
          # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
          # hf_hub_id='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k',  # float16 weights
          # hf_hub_filename='open_clip_pytorch_model.bin',
          hf_hub_id='timm/',
          num_classes=512,
      ),
      'eva02_large_patch14_clip_224.merged2b': _cfg(
          # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
          # hf_hub_id='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k',  # float16 weights
          # hf_hub_filename='open_clip_pytorch_model.bin',
          hf_hub_id='timm/',
          num_classes=768,
      ),
      'eva02_large_patch14_clip_336.merged2b': _cfg(
          # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
          # hf_hub_id='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k',  # float16 weights
          # hf_hub_filename='open_clip_pytorch_model.bin',
          hf_hub_id='timm/',
          input_size=(3, 336, 336), crop_pct=1.0,
          num_classes=768,
      ),
      'eva02_enormous_patch14_clip_224.laion2b': _cfg(
          # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_E_psz14_plus_s9B.pt',
          # hf_hub_id='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k',  # float16 weights
          # hf_hub_filename='open_clip_pytorch_model.bin',
          hf_hub_id='timm/',
          num_classes=1024,
      ),
      'eva02_enormous_patch14_clip_224.laion2b_plus': _cfg(
          # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_E_psz14_plus_s9B.pt',
          # hf_hub_id='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k',  # bfloat16 weights
          # hf_hub_filename='open_clip_pytorch_model.bin',
          hf_hub_id='timm/',
          num_classes=1024,
      ),
      # No hf_hub_id or url is given for this tag.
      'eva02_enormous_patch14_clip_224.pretrain': _cfg(
          # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_E_psz14.pt',
          num_classes=0,
      ),

      # timm SBB ViTs with RoPE, 256x256 input
      'vit_medium_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
          hf_hub_id='timm/',
          input_size=(3, 256, 256), crop_pct=0.95,
          mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
      ),
      'vit_mediumd_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
          hf_hub_id='timm/',
          input_size=(3, 256, 256), crop_pct=0.95,
          mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
      ),
      # NOTE(review): unlike its SBB siblings this entry keeps _cfg's default (OpenAI CLIP)
      # mean/std. Confirm this matches how the weights were trained.
      'vit_betwixt_patch16_rope_reg4_gap_256.sbb_in1k': _cfg(
          hf_hub_id='timm/',
          input_size=(3, 256, 256), crop_pct=0.95,
      ),
      'vit_base_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
          hf_hub_id='timm/',
          input_size=(3, 256, 256), crop_pct=0.95,
          mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
      ),

      # Perception Encoder weights
      'vit_pe_core_tiny_patch16_384.fb': _pe_cfg(
          hf_hub_id='timm/',
          # hf_hub_id='facebook/PE-Core-T16-384',
          # hf_hub_filename='PE-Core-T16-384.pt',
          input_size=(3, 384, 384),
          num_classes=512,  # output proj dim
      ),
      'vit_pe_core_small_patch16_384.fb': _pe_cfg(
          hf_hub_id='timm/',
          # hf_hub_id='facebook/PE-Core-S16-384',
          # hf_hub_filename='PE-Core-S16-384.pt',
          input_size=(3, 384, 384),
          num_classes=512,  # output proj dim
      ),
      'vit_pe_core_base_patch16_224.fb': _pe_cfg(
          hf_hub_id='timm/',
          # hf_hub_id='facebook/PE-Core-B16-224',
          # hf_hub_filename='PE-Core-B16-224.pt',
          input_size=(3, 224, 224),
          num_classes=1024,  # output proj dim
      ),
      'vit_pe_core_large_patch14_336.fb': _pe_cfg(
          hf_hub_id='timm/',
          # hf_hub_id='facebook/PE-Core-L14-336',
          # hf_hub_filename='PE-Core-L14-336.pt',
          input_size=(3, 336, 336),
          num_classes=1024,  # output proj dim
      ),
      'vit_pe_core_gigantic_patch14_448.fb': _pe_cfg(
          hf_hub_id='timm/',
          # hf_hub_id='facebook/PE-Core-G14-448',
          # hf_hub_filename='PE-Core-G14-448.pt',
          input_size=(3, 448, 448),
          num_classes=1280,  # output proj dim
      ),
      'vit_pe_lang_large_patch14_448.fb': _pe_cfg(
          hf_hub_id='timm/',
          # hf_hub_id='facebook/PE-Lang-L14-448',
          # hf_hub_filename='PE-Lang-L14-448.pt',
          input_size=(3, 448, 448),
          num_classes=0,
      ),
      'vit_pe_lang_large_patch14_448.fb_tiling': _pe_cfg(
          hf_hub_id='timm/',
          # hf_hub_id='facebook/PE-Lang-L14-448-Tiling',
          # hf_hub_filename='PE-Lang-L14-448-Tiling.pt',
          input_size=(3, 448, 448),
          num_classes=0,
      ),
      'vit_pe_lang_gigantic_patch14_448.fb': _pe_cfg(
          hf_hub_id='timm/',
          # hf_hub_id='facebook/PE-Lang-G14-448',
          # hf_hub_filename='PE-Lang-G14-448.pt',
          input_size=(3, 448, 448),
          num_classes=0,
      ),
      'vit_pe_lang_gigantic_patch14_448.fb_tiling': _pe_cfg(
          hf_hub_id='timm/',
          # hf_hub_id='facebook/PE-Lang-G14-448-Tiling',
          # hf_hub_filename='PE-Lang-G14-448-Tiling.pt',
          input_size=(3, 448, 448),
          num_classes=0,
      ),
      'vit_pe_spatial_tiny_patch16_512.fb': _pe_cfg(
          hf_hub_id='timm/',
          # hf_hub_id='facebook/PE-Spatial-T16-512',
          # hf_hub_filename='PE-Spatial-T16-512.pt',
          input_size=(3, 512, 512),
          num_classes=0,
      ),
      'vit_pe_spatial_small_patch16_512.fb': _pe_cfg(
          hf_hub_id='timm/',
          # hf_hub_id='facebook/PE-Spatial-S16-512',
          # hf_hub_filename='PE-Spatial-S16-512.pt',
          input_size=(3, 512, 512),
          num_classes=0,
      ),
      'vit_pe_spatial_base_patch16_512.fb': _pe_cfg(
          hf_hub_id='timm/',
          # hf_hub_id='facebook/PE-Spatial-B16-512',
          # hf_hub_filename='PE-Spatial-B16-512.pt',
          input_size=(3, 512, 512),
          num_classes=0,
      ),
      'vit_pe_spatial_large_patch14_448.fb': _pe_cfg(
          hf_hub_id='timm/',
          # hf_hub_id='facebook/PE-Spatial-L14-448',
          # hf_hub_filename='PE-Spatial-L14-448.pt',
          input_size=(3, 448, 448),
          num_classes=0,
      ),
      'vit_pe_spatial_gigantic_patch14_448.fb': _pe_cfg(
          hf_hub_id='timm/',
          # hf_hub_id='facebook/PE-Spatial-G14-448',
          # hf_hub_filename='PE-Spatial-G14-448.pt',
          input_size=(3, 448, 448),
          num_classes=0,
      ),

      # RoPE-ViT models from Naver (ImageNet mean/std, Apache-2.0 license)
      'vit_small_patch16_rope_224.naver_in1k': _cfg(
          hf_hub_id='timm/',
          mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
          license='apache-2.0',
      ),
      'vit_base_patch16_rope_224.naver_in1k': _cfg(
          hf_hub_id='timm/',
          mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
          license='apache-2.0',
      ),
      'vit_large_patch16_rope_224.naver_in1k': _cfg(
          hf_hub_id='timm/',
          mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
          license='apache-2.0',
      ),
      'vit_small_patch16_rope_mixed_224.naver_in1k': _cfg(
          hf_hub_id='timm/',
          mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
          license='apache-2.0',
      ),
      'vit_base_patch16_rope_mixed_224.naver_in1k': _cfg(
          hf_hub_id='timm/',
          mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
          license='apache-2.0',
      ),
      'vit_large_patch16_rope_mixed_224.naver_in1k': _cfg(
          hf_hub_id='timm/',
          mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
          license='apache-2.0',
      ),
      'vit_small_patch16_rope_ape_224.naver_in1k': _cfg(
          hf_hub_id='timm/',
          mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
          license='apache-2.0',
      ),
      'vit_base_patch16_rope_ape_224.naver_in1k': _cfg(
          hf_hub_id='timm/',
          mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
          license='apache-2.0',
      ),
      'vit_large_patch16_rope_ape_224.naver_in1k': _cfg(
          hf_hub_id='timm/',
          mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
          license='apache-2.0',
      ),
      'vit_small_patch16_rope_mixed_ape_224.naver_in1k': _cfg(
          hf_hub_id='timm/',
          mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
          license='apache-2.0',
      ),
      'vit_base_patch16_rope_mixed_ape_224.naver_in1k': _cfg(
          hf_hub_id='timm/',
          mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
          license='apache-2.0',
      ),
      'vit_large_patch16_rope_mixed_ape_224.naver_in1k': _cfg(
          hf_hub_id='timm/',
          mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
          license='apache-2.0',
      ),

      # DINOv3 weights are under a specific license with redistribution terms, please see
      # https://github.com/facebookresearch/dinov3/blob/main/LICENSE.md
      # NOTE: Original DINOv3 uses CLS-token pooling (global_pool='token') which may be better
      # for some tasks. Default here is avg pooling inherited from the Eva base class.
      'vit_small_patch16_dinov3.lvd1689m': _dinov3_cfg(
          hf_hub_id='timm/',
      ),
      'vit_small_patch16_dinov3_qkvb.lvd1689m': _dinov3_cfg(
          hf_hub_id='timm/',
      ),
      'vit_small_plus_patch16_dinov3.lvd1689m': _dinov3_cfg(
          hf_hub_id='timm/',
      ),
      'vit_small_plus_patch16_dinov3_qkvb.lvd1689m': _dinov3_cfg(
          hf_hub_id='timm/',
      ),
      'vit_base_patch16_dinov3.lvd1689m': _dinov3_cfg(
          hf_hub_id='timm/',
      ),
      'vit_base_patch16_dinov3_qkvb.lvd1689m': _dinov3_cfg(
          hf_hub_id='timm/',
      ),
      'vit_large_patch16_dinov3.lvd1689m': _dinov3_cfg(
          hf_hub_id='timm/',
      ),
      'vit_large_patch16_dinov3_qkvb.lvd1689m': _dinov3_cfg(
          hf_hub_id='timm/',
      ),
      # sat493m tags override mean/std with dataset-specific statistics
      'vit_large_patch16_dinov3.sat493m': _dinov3_cfg(
          hf_hub_id='timm/',
          mean=(0.430, 0.411, 0.296), std=(0.213, 0.156, 0.143),
      ),
      'vit_large_patch16_dinov3_qkvb.sat493m': _dinov3_cfg(
          hf_hub_id='timm/',
          mean=(0.430, 0.411, 0.296), std=(0.213, 0.156, 0.143),
      ),
      'vit_huge_plus_patch16_dinov3.lvd1689m': _dinov3_cfg(
          hf_hub_id='timm/',
      ),
      'vit_huge_plus_patch16_dinov3_qkvb.lvd1689m': _dinov3_cfg(
          hf_hub_id='timm/',
      ),
      'vit_7b_patch16_dinov3.lvd1689m': _dinov3_cfg(
          hf_hub_id='timm/',
      ),
      'vit_7b_patch16_dinov3.sat493m': _dinov3_cfg(
          hf_hub_id='timm/',
          mean=(0.430, 0.411, 0.296), std=(0.213, 0.156, 0.143),
      ),
  })
  @register_model
  def eva_giant_patch14_224(pretrained: bool = False, **kwargs) -> Eva:
      """EVA-g (giant) ViT with 14px patches, 224 variant (https://arxiv.org/abs/2211.07636).

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'patch_size': 14,
          'embed_dim': 1408,
          'depth': 40,
          'num_heads': 16,
          'mlp_ratio': 6144 / 1408,  # 6144-wide MLP hidden layer
      }
      arch.update(kwargs)
      return _create_eva('eva_giant_patch14_224', pretrained=pretrained, **arch)
  @register_model
  def eva_giant_patch14_336(pretrained: bool = False, **kwargs) -> Eva:
      """EVA-g (giant) ViT with 14px patches, 336 variant (https://arxiv.org/abs/2211.07636).

      img_size is not set here. It is presumably supplied from the pretrained config's
      input_size (TODO confirm).

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'patch_size': 14,
          'embed_dim': 1408,
          'depth': 40,
          'num_heads': 16,
          'mlp_ratio': 6144 / 1408,  # 6144-wide MLP hidden layer
      }
      arch.update(kwargs)
      return _create_eva('eva_giant_patch14_336', pretrained=pretrained, **arch)
  @register_model
  def eva_giant_patch14_560(pretrained: bool = False, **kwargs) -> Eva:
      """EVA-g (giant) ViT with 14px patches, 560 variant (https://arxiv.org/abs/2211.07636).

      img_size is not set here. It is presumably supplied from the pretrained config's
      input_size (TODO confirm).

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'patch_size': 14,
          'embed_dim': 1408,
          'depth': 40,
          'num_heads': 16,
          'mlp_ratio': 6144 / 1408,  # 6144-wide MLP hidden layer
      }
      arch.update(kwargs)
      return _create_eva('eva_giant_patch14_560', pretrained=pretrained, **arch)
  @register_model
  def eva02_tiny_patch14_224(pretrained: bool = False, **kwargs) -> Eva:
      """EVA02 Tiny, 14px patches, 224x224 input (https://arxiv.org/abs/2303.11331).

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'img_size': 224,
          'patch_size': 14,
          'embed_dim': 192,
          'depth': 12,
          'num_heads': 3,
          'mlp_ratio': 4 * 2 / 3,  # 2/3 of the usual 4x ratio, paired with swiglu_mlp
          'swiglu_mlp': True,
          'use_rot_pos_emb': True,
          'ref_feat_shape': (16, 16),  # 224 / 14 = 16 patches per side
      }
      arch.update(kwargs)
      return _create_eva('eva02_tiny_patch14_224', pretrained=pretrained, **arch)
  @register_model
  def eva02_small_patch14_224(pretrained: bool = False, **kwargs) -> Eva:
      """EVA02 Small, 14px patches, 224x224 input (https://arxiv.org/abs/2303.11331).

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'img_size': 224,
          'patch_size': 14,
          'embed_dim': 384,
          'depth': 12,
          'num_heads': 6,
          'mlp_ratio': 4 * 2 / 3,  # 2/3 of the usual 4x ratio, paired with swiglu_mlp
          'swiglu_mlp': True,
          'use_rot_pos_emb': True,
          'ref_feat_shape': (16, 16),  # 224 / 14 = 16 patches per side
      }
      arch.update(kwargs)
      return _create_eva('eva02_small_patch14_224', pretrained=pretrained, **arch)
  @register_model
  def eva02_base_patch14_224(pretrained: bool = False, **kwargs) -> Eva:
      """EVA02 Base, 14px patches, 224x224 input (https://arxiv.org/abs/2303.11331).

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'img_size': 224,
          'patch_size': 14,
          'embed_dim': 768,
          'depth': 12,
          'num_heads': 12,
          'qkv_fused': False,
          'mlp_ratio': 4 * 2 / 3,  # 2/3 of the usual 4x ratio, paired with swiglu_mlp
          'swiglu_mlp': True,
          'scale_mlp': True,
          'use_rot_pos_emb': True,
          'ref_feat_shape': (16, 16),  # 224 / 14 = 16 patches per side
      }
      arch.update(kwargs)
      return _create_eva('eva02_base_patch14_224', pretrained=pretrained, **arch)
  @register_model
  def eva02_large_patch14_224(pretrained: bool = False, **kwargs) -> Eva:
      """EVA02 Large, 14px patches, 224x224 input (https://arxiv.org/abs/2303.11331).

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'img_size': 224,
          'patch_size': 14,
          'embed_dim': 1024,
          'depth': 24,
          'num_heads': 16,
          'mlp_ratio': 4 * 2 / 3,  # 2/3 of the usual 4x ratio, paired with swiglu_mlp
          'qkv_fused': False,
          'swiglu_mlp': True,
          'scale_mlp': True,
          'use_rot_pos_emb': True,
          'ref_feat_shape': (16, 16),  # 224 / 14 = 16 patches per side
      }
      arch.update(kwargs)
      return _create_eva('eva02_large_patch14_224', pretrained=pretrained, **arch)
  @register_model
  def eva02_tiny_patch14_336(pretrained: bool = False, **kwargs) -> Eva:
      """EVA02 Tiny, 14px patches, 336x336 input (https://arxiv.org/abs/2303.11331).

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'img_size': 336,
          'patch_size': 14,
          'embed_dim': 192,
          'depth': 12,
          'num_heads': 3,
          'mlp_ratio': 4 * 2 / 3,  # 2/3 of the usual 4x ratio, paired with swiglu_mlp
          'swiglu_mlp': True,
          'use_rot_pos_emb': True,
          # 224/14 reference grid (16x16), not this model's 336/14 = 24x24 grid.
          # Presumably it matches the 224px pretraining resolution; confirm.
          'ref_feat_shape': (16, 16),
      }
      arch.update(kwargs)
      return _create_eva('eva02_tiny_patch14_336', pretrained=pretrained, **arch)
  @register_model
  def eva02_small_patch14_336(pretrained: bool = False, **kwargs) -> Eva:
      """EVA02 Small, 14px patches, 336x336 input (https://arxiv.org/abs/2303.11331).

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'img_size': 336,
          'patch_size': 14,
          'embed_dim': 384,
          'depth': 12,
          'num_heads': 6,
          'mlp_ratio': 4 * 2 / 3,  # 2/3 of the usual 4x ratio, paired with swiglu_mlp
          'swiglu_mlp': True,
          'use_rot_pos_emb': True,
          # 224/14 reference grid (16x16), not this model's 336/14 = 24x24 grid.
          # Presumably it matches the 224px pretraining resolution; confirm.
          'ref_feat_shape': (16, 16),
      }
      arch.update(kwargs)
      return _create_eva('eva02_small_patch14_336', pretrained=pretrained, **arch)
  @register_model
  def eva02_base_patch14_448(pretrained: bool = False, **kwargs) -> Eva:
      """EVA02 Base, 14px patches, 448x448 input (https://arxiv.org/abs/2303.11331).

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'img_size': 448,
          'patch_size': 14,
          'embed_dim': 768,
          'depth': 12,
          'num_heads': 12,
          'qkv_fused': False,
          'mlp_ratio': 4 * 2 / 3,  # 2/3 of the usual 4x ratio, paired with swiglu_mlp
          'swiglu_mlp': True,
          'scale_mlp': True,
          'use_rot_pos_emb': True,
          # 224/14 reference grid (16x16), not this model's 448/14 = 32x32 grid.
          # Presumably it matches the 224px pretraining resolution; confirm.
          'ref_feat_shape': (16, 16),
      }
      arch.update(kwargs)
      return _create_eva('eva02_base_patch14_448', pretrained=pretrained, **arch)
  @register_model
  def eva02_large_patch14_448(pretrained: bool = False, **kwargs) -> Eva:
      """EVA02 Large, 14px patches, 448x448 input (https://arxiv.org/abs/2303.11331).

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'img_size': 448,
          'patch_size': 14,
          'embed_dim': 1024,
          'depth': 24,
          'num_heads': 16,
          'mlp_ratio': 4 * 2 / 3,  # 2/3 of the usual 4x ratio, paired with swiglu_mlp
          'qkv_fused': False,
          'swiglu_mlp': True,
          'scale_mlp': True,
          'use_rot_pos_emb': True,
          # 224/14 reference grid (16x16), not this model's 448/14 = 32x32 grid.
          # Presumably it matches the 224px pretraining resolution; confirm.
          'ref_feat_shape': (16, 16),
      }
      arch.update(kwargs)
      return _create_eva('eva02_large_patch14_448', pretrained=pretrained, **arch)
  @register_model
  def eva_giant_patch14_clip_224(pretrained: bool = False, **kwargs) -> Eva:
      """EVA-g CLIP image tower.

      The architecture matches eva_giant_patch14_224. The only difference is that pooling
      defaults to 'token' unless the caller passes global_pool.

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'patch_size': 14,
          'embed_dim': 1408,
          'depth': 40,
          'num_heads': 16,
          'mlp_ratio': 6144 / 1408,  # 6144-wide MLP hidden layer
          'global_pool': kwargs.pop('global_pool', 'token'),
      }
      arch.update(kwargs)
      return _create_eva('eva_giant_patch14_clip_224', pretrained=pretrained, **arch)
  @register_model
  def eva02_base_patch16_clip_224(pretrained: bool = False, **kwargs) -> Eva:
      """EVA-CLIP variant of eva02_base with an extra inner attention scale layer-norm.

      Pooling defaults to 'token' unless the caller passes global_pool.

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'img_size': 224,
          'patch_size': 16,
          'embed_dim': 768,
          'depth': 12,
          'num_heads': 12,
          'qkv_fused': False,
          'mlp_ratio': 4 * 2 / 3,  # 2/3 of the usual 4x ratio, paired with swiglu_mlp
          'swiglu_mlp': True,
          'scale_mlp': True,
          'scale_attn_inner': True,
          'use_rot_pos_emb': True,
          # 224/14 reference grid (16x16), not this model's own 224/16 = 14x14 grid.
          # Presumably it is inherited from the patch-14 EVA02 pretraining; confirm.
          'ref_feat_shape': (16, 16),
          'global_pool': kwargs.pop('global_pool', 'token'),
      }
      arch.update(kwargs)
      return _create_eva('eva02_base_patch16_clip_224', pretrained=pretrained, **arch)
  @register_model
  def eva02_large_patch14_clip_224(pretrained: bool = False, **kwargs) -> Eva:
      """EVA-CLIP variant of eva02_large with an extra inner attention scale layer-norm.

      Pooling defaults to 'token' unless the caller passes global_pool.

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'img_size': 224,
          'patch_size': 14,
          'embed_dim': 1024,
          'depth': 24,
          'num_heads': 16,
          'mlp_ratio': 4 * 2 / 3,  # 2/3 of the usual 4x ratio, paired with swiglu_mlp
          'qkv_fused': False,
          'swiglu_mlp': True,
          'scale_mlp': True,
          'scale_attn_inner': True,
          'use_rot_pos_emb': True,
          'ref_feat_shape': (16, 16),  # 224 / 14 = 16 patches per side
          'global_pool': kwargs.pop('global_pool', 'token'),
      }
      arch.update(kwargs)
      return _create_eva('eva02_large_patch14_clip_224', pretrained=pretrained, **arch)
  @register_model
  def eva02_large_patch14_clip_336(pretrained: bool = False, **kwargs) -> Eva:
      """EVA-CLIP variant of eva02_large at 336x336, with an extra inner attention scale layer-norm.

      Pooling defaults to 'token' unless the caller passes global_pool.

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'img_size': 336,
          'patch_size': 14,
          'embed_dim': 1024,
          'depth': 24,
          'num_heads': 16,
          'mlp_ratio': 4 * 2 / 3,  # 2/3 of the usual 4x ratio, paired with swiglu_mlp
          'qkv_fused': False,
          'swiglu_mlp': True,
          'scale_mlp': True,
          'scale_attn_inner': True,
          'use_rot_pos_emb': True,
          # 224/14 reference grid (16x16), not this model's 336/14 = 24x24 grid.
          # Presumably it matches the 224px pretraining resolution; confirm.
          'ref_feat_shape': (16, 16),
          'global_pool': kwargs.pop('global_pool', 'token'),
      }
      arch.update(kwargs)
      return _create_eva('eva02_large_patch14_clip_336', pretrained=pretrained, **arch)
  @register_model
  def eva02_enormous_patch14_clip_224(pretrained: bool = False, **kwargs) -> Eva:
      """EVA-CLIP enormous (E/14) image tower that uses residual post-norm in its blocks.

      Pooling defaults to 'token' unless the caller passes global_pool.

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'img_size': 224,
          'patch_size': 14,
          'embed_dim': 1792,
          'depth': 64,
          'num_heads': 16,
          'mlp_ratio': 15360 / 1792,  # 15360-wide MLP hidden layer
          'use_post_norm': True,
          'global_pool': kwargs.pop('global_pool', 'token'),
      }
      arch.update(kwargs)
      return _create_eva('eva02_enormous_patch14_clip_224', pretrained=pretrained, **arch)
  @register_model
  def vit_medium_patch16_rope_reg1_gap_256(pretrained: bool = False, **kwargs) -> Eva:
      """timm SBB ViT-Medium/16 with RoPE and 1 register token, no class token, 256x256 input.

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'img_size': 256,
          'patch_size': 16,
          'embed_dim': 512,
          'depth': 12,
          'num_heads': 8,
          'qkv_fused': True,
          'qkv_bias': True,
          'init_values': 1e-5,
          'class_token': False,
          'num_reg_tokens': 1,
          'use_rot_pos_emb': True,
          'use_abs_pos_emb': False,
          'ref_feat_shape': (16, 16),  # 256 / 16 = 16 patches per side
      }
      arch.update(kwargs)
      return _create_eva('vit_medium_patch16_rope_reg1_gap_256', pretrained=pretrained, **arch)
  @register_model
  def vit_mediumd_patch16_rope_reg1_gap_256(pretrained: bool = False, **kwargs) -> Eva:
      """timm SBB ViT-Medium (deep, 20 blocks)/16 with RoPE and 1 register token, 256x256 input.

      Unlike the other SBB RoPE variants, qkv_bias is disabled.

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'img_size': 256,
          'patch_size': 16,
          'embed_dim': 512,
          'depth': 20,
          'num_heads': 8,
          'qkv_fused': True,
          'qkv_bias': False,
          'init_values': 1e-5,
          'class_token': False,
          'num_reg_tokens': 1,
          'use_rot_pos_emb': True,
          'use_abs_pos_emb': False,
          'ref_feat_shape': (16, 16),  # 256 / 16 = 16 patches per side
      }
      arch.update(kwargs)
      return _create_eva('vit_mediumd_patch16_rope_reg1_gap_256', pretrained=pretrained, **arch)
  @register_model
  def vit_betwixt_patch16_rope_reg4_gap_256(pretrained: bool = False, **kwargs) -> Eva:
      """timm SBB ViT-Betwixt/16 with RoPE and 4 register tokens, no class token, 256x256 input.

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'img_size': 256,
          'patch_size': 16,
          'embed_dim': 640,
          'depth': 12,
          'num_heads': 10,
          'qkv_fused': True,
          'qkv_bias': True,
          'init_values': 1e-5,
          'class_token': False,
          'num_reg_tokens': 4,
          'use_rot_pos_emb': True,
          'use_abs_pos_emb': False,
          'ref_feat_shape': (16, 16),  # 256 / 16 = 16 patches per side
      }
      arch.update(kwargs)
      return _create_eva('vit_betwixt_patch16_rope_reg4_gap_256', pretrained=pretrained, **arch)
  @register_model
  def vit_base_patch16_rope_reg1_gap_256(pretrained: bool = False, **kwargs) -> Eva:
      """timm SBB ViT-Base/16 with RoPE and 1 register token, no class token, 256x256 input.

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'img_size': 256,
          'patch_size': 16,
          'embed_dim': 768,
          'depth': 12,
          'num_heads': 12,
          'qkv_fused': True,
          'qkv_bias': True,
          'init_values': 1e-5,
          'class_token': False,
          'num_reg_tokens': 1,
          'use_rot_pos_emb': True,
          'use_abs_pos_emb': False,
          'ref_feat_shape': (16, 16),  # 256 / 16 = 16 patches per side
      }
      arch.update(kwargs)
      return _create_eva('vit_base_patch16_rope_reg1_gap_256', pretrained=pretrained, **arch)
  @register_model
  def vit_pe_core_tiny_patch16_384(pretrained: bool = False, **kwargs) -> Eva:
      """Perception Encoder (PE) Core Tiny/16 from Meta (https://arxiv.org/abs/2504.13181).

      Uses 'map' global pooling, RoPE attention and a pre-transformer norm. img_size is not
      set here; presumably it comes from the pretrained config's input_size (TODO confirm).

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'patch_size': 16,
          'embed_dim': 192,
          'depth': 12,
          'num_heads': 3,
          'mlp_ratio': 4.0,
          'global_pool': 'map',
          'attn_type': 'rope',
          'use_pre_transformer_norm': True,
          'use_rot_pos_emb': True,
          'ref_feat_shape': (24, 24),  # 384 / 16 = 24 patches per side
          'rope_grid_offset': 1.,
          'rope_grid_indexing': 'xy',
          'attn_pool_num_heads': 8,
          'attn_pool_mlp_ratio': 4.,
          'norm_layer': partial(LayerNorm, eps=1e-5),
          # dynamic_img_size=True is intentionally left off by default
      }
      arch.update(kwargs)
      return _create_eva('vit_pe_core_tiny_patch16_384', pretrained=pretrained, **arch)
  @register_model
  def vit_pe_core_small_patch16_384(pretrained: bool = False, **kwargs) -> Eva:
      """Perception Encoder (PE) Core Small/16 from Meta (https://arxiv.org/abs/2504.13181).

      Uses 'map' global pooling, RoPE attention and a pre-transformer norm. img_size is not
      set here; presumably it comes from the pretrained config's input_size (TODO confirm).

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'patch_size': 16,
          'embed_dim': 384,
          'depth': 12,
          'num_heads': 6,
          'mlp_ratio': 4.0,
          'global_pool': 'map',
          'attn_type': 'rope',
          'use_pre_transformer_norm': True,
          'use_rot_pos_emb': True,
          'ref_feat_shape': (24, 24),  # 384 / 16 = 24 patches per side
          'rope_grid_offset': 1.,
          'rope_grid_indexing': 'xy',
          'attn_pool_num_heads': 8,
          'attn_pool_mlp_ratio': 4.,
          'norm_layer': partial(LayerNorm, eps=1e-5),
          # dynamic_img_size=True is intentionally left off by default
      }
      arch.update(kwargs)
      return _create_eva('vit_pe_core_small_patch16_384', pretrained=pretrained, **arch)
  @register_model
  def vit_pe_core_base_patch16_224(pretrained: bool = False, **kwargs) -> Eva:
      """Perception Encoder (PE) Core Base/16 from Meta (https://arxiv.org/abs/2504.13181).

      Uses 'map' global pooling, RoPE attention and a pre-transformer norm. img_size is not
      set here; presumably it comes from the pretrained config's input_size (TODO confirm).

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'patch_size': 16,
          'embed_dim': 768,
          'depth': 12,
          'num_heads': 12,
          'mlp_ratio': 4.0,
          'global_pool': 'map',
          'attn_type': 'rope',
          'use_pre_transformer_norm': True,
          'use_rot_pos_emb': True,
          'ref_feat_shape': (14, 14),  # 224 / 16 = 14 patches per side
          'rope_grid_offset': 1.,
          'rope_grid_indexing': 'xy',
          'attn_pool_num_heads': 8,
          'attn_pool_mlp_ratio': 4.,
          'norm_layer': partial(LayerNorm, eps=1e-5),
          # dynamic_img_size=True is intentionally left off by default
      }
      arch.update(kwargs)
      return _create_eva('vit_pe_core_base_patch16_224', pretrained=pretrained, **arch)
  @register_model
  def vit_pe_core_large_patch14_336(pretrained: bool = False, **kwargs) -> Eva:
      """Perception Encoder (PE) Core Large/14 from Meta (https://arxiv.org/abs/2504.13181).

      Uses 'map' global pooling, RoPE attention and a pre-transformer norm. img_size is not
      set here; presumably it comes from the pretrained config's input_size (TODO confirm).

      Args:
          pretrained: Forwarded to ``_create_eva``.
          **kwargs: Overrides merged on top of the architecture arguments below.

      Returns:
          The model built by ``_create_eva``.
      """
      arch = {
          'patch_size': 14,
          'embed_dim': 1024,
          'depth': 24,
          'num_heads': 16,
          'mlp_ratio': 4.0,
          'global_pool': 'map',
          'attn_type': 'rope',
          'use_pre_transformer_norm': True,
          'use_rot_pos_emb': True,
          'ref_feat_shape': (24, 24),  # 336 / 14 = 24 patches per side
          'rope_grid_offset': 1.,
          'rope_grid_indexing': 'xy',
          'attn_pool_num_heads': 8,
          'attn_pool_mlp_ratio': 4.,
          'norm_layer': partial(LayerNorm, eps=1e-5),
          # dynamic_img_size=True is intentionally left off by default
      }
      arch.update(kwargs)
      return _create_eva('vit_pe_core_large_patch14_336', pretrained=pretrained, **arch)
  2027. @register_model
  2028. def vit_pe_core_gigantic_patch14_448(pretrained: bool = False, **kwargs) -> Eva:
  2029. """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)"""
  2030. model_args = dict(
  2031. patch_size=14,
  2032. embed_dim=1536,
  2033. depth=50,
  2034. num_heads=16,
  2035. mlp_ratio=8960 / 1536,
  2036. global_pool='map',
  2037. attn_type='rope',
  2038. class_token=False,
  2039. use_pre_transformer_norm=True,
  2040. use_rot_pos_emb=True,
  2041. ref_feat_shape=(32, 32),
  2042. rope_grid_indexing='xy',
  2043. attn_pool_num_heads=8,
  2044. attn_pool_mlp_ratio=4.,
  2045. norm_layer=partial(LayerNorm, eps=1e-5),
  2046. #dynamic_img_size=True,
  2047. )
  2048. return _create_eva('vit_pe_core_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  2049. @register_model
  2050. def vit_pe_lang_large_patch14_448(pretrained: bool = False, **kwargs) -> Eva:
  2051. """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)"""
  2052. model_args = dict(
  2053. patch_size=14,
  2054. embed_dim=1024,
  2055. depth=23,
  2056. num_heads=16,
  2057. mlp_ratio=4.0,
  2058. attn_type='rope',
  2059. class_token=True,
  2060. use_rot_pos_emb=True,
  2061. ref_feat_shape=(32, 32),
  2062. rope_grid_offset=1.,
  2063. rope_grid_indexing='xy',
  2064. use_pre_transformer_norm=True,
  2065. use_post_transformer_norm=False,
  2066. use_fc_norm=False, # explicitly disable
  2067. init_values=0.1,
  2068. norm_layer=partial(LayerNorm, eps=1e-5),
  2069. #dynamic_img_size=True,
  2070. )
  2071. return _create_eva('vit_pe_lang_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  2072. @register_model
  2073. def vit_pe_lang_gigantic_patch14_448(pretrained: bool = False, **kwargs) -> Eva:
  2074. """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)"""
  2075. model_args = dict(
  2076. patch_size=14,
  2077. embed_dim=1536,
  2078. depth=47,
  2079. num_heads=16,
  2080. mlp_ratio=8960 / 1536,
  2081. attn_type='rope',
  2082. class_token=False,
  2083. use_rot_pos_emb=True,
  2084. ref_feat_shape=(32, 32),
  2085. rope_grid_indexing='xy',
  2086. use_pre_transformer_norm=True,
  2087. use_post_transformer_norm=False,
  2088. use_fc_norm=False, # explicitly disable
  2089. init_values=0.1,
  2090. norm_layer=partial(LayerNorm, eps=1e-5),
  2091. #dynamic_img_size=True,
  2092. )
  2093. return _create_eva('vit_pe_lang_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  2094. @register_model
  2095. def vit_pe_spatial_tiny_patch16_512(pretrained: bool = False, **kwargs) -> Eva:
  2096. """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)"""
  2097. model_args = dict(
  2098. patch_size=16,
  2099. embed_dim=192,
  2100. depth=12,
  2101. num_heads=3,
  2102. mlp_ratio=4.0,
  2103. attn_type='rope',
  2104. use_pre_transformer_norm=True,
  2105. use_post_transformer_norm=False,
  2106. use_fc_norm=False, # explicitly disable
  2107. use_rot_pos_emb=True,
  2108. ref_feat_shape=(32, 32),
  2109. rope_grid_offset=1.,
  2110. rope_grid_indexing='xy',
  2111. norm_layer=partial(LayerNorm, eps=1e-5),
  2112. #dynamic_img_size=True
  2113. )
  2114. return _create_eva('vit_pe_spatial_tiny_patch16_512', pretrained=pretrained, **dict(model_args, **kwargs))
  2115. @register_model
  2116. def vit_pe_spatial_small_patch16_512(pretrained: bool = False, **kwargs) -> Eva:
  2117. """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)"""
  2118. model_args = dict(
  2119. patch_size=16,
  2120. embed_dim=384,
  2121. depth=12,
  2122. num_heads=6,
  2123. mlp_ratio=4.0,
  2124. attn_type='rope',
  2125. use_pre_transformer_norm=True,
  2126. use_post_transformer_norm=False,
  2127. use_fc_norm=False, # explicitly disable
  2128. use_rot_pos_emb=True,
  2129. ref_feat_shape=(32, 32),
  2130. rope_grid_offset=1.,
  2131. rope_grid_indexing='xy',
  2132. norm_layer=partial(LayerNorm, eps=1e-5),
  2133. #dynamic_img_size=True
  2134. )
  2135. return _create_eva('vit_pe_spatial_small_patch16_512', pretrained=pretrained, **dict(model_args, **kwargs))
  2136. @register_model
  2137. def vit_pe_spatial_base_patch16_512(pretrained: bool = False, **kwargs) -> Eva:
  2138. """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)"""
  2139. model_args = dict(
  2140. patch_size=16,
  2141. embed_dim=768,
  2142. depth=12,
  2143. num_heads=12,
  2144. mlp_ratio=4.0,
  2145. attn_type='rope',
  2146. use_pre_transformer_norm=True,
  2147. use_post_transformer_norm=False,
  2148. use_fc_norm=False, # explicitly disable
  2149. use_rot_pos_emb=True,
  2150. ref_feat_shape=(32, 32),
  2151. rope_grid_offset=1.,
  2152. rope_grid_indexing='xy',
  2153. norm_layer=partial(LayerNorm, eps=1e-5),
  2154. #dynamic_img_size=True
  2155. )
  2156. return _create_eva('vit_pe_spatial_base_patch16_512', pretrained=pretrained, **dict(model_args, **kwargs))
  2157. @register_model
  2158. def vit_pe_spatial_large_patch14_448(pretrained: bool = False, **kwargs) -> Eva:
  2159. """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)"""
  2160. model_args = dict(
  2161. patch_size=14,
  2162. embed_dim=1024,
  2163. depth=24,
  2164. num_heads=16,
  2165. mlp_ratio=4.0,
  2166. attn_type='rope',
  2167. use_pre_transformer_norm=True,
  2168. use_post_transformer_norm=False,
  2169. use_fc_norm=False, # explicitly disable
  2170. use_rot_pos_emb=True,
  2171. ref_feat_shape=(32, 32),
  2172. rope_grid_offset=1.,
  2173. rope_grid_indexing='xy',
  2174. norm_layer=partial(LayerNorm, eps=1e-5),
  2175. #dynamic_img_size=True,
  2176. )
  2177. return _create_eva('vit_pe_spatial_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  2178. @register_model
  2179. def vit_pe_spatial_gigantic_patch14_448(pretrained: bool = False, **kwargs) -> Eva:
  2180. """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)"""
  2181. model_args = dict(
  2182. patch_size=14,
  2183. embed_dim=1536,
  2184. depth=50,
  2185. num_heads=16,
  2186. mlp_ratio=8960 / 1536,
  2187. attn_type='rope',
  2188. class_token=False,
  2189. use_rot_pos_emb=True,
  2190. ref_feat_shape=(32, 32),
  2191. rope_grid_indexing='xy',
  2192. use_pre_transformer_norm=True,
  2193. use_post_transformer_norm=False,
  2194. use_fc_norm=False, # explicitly disable
  2195. init_values=0.1,
  2196. norm_layer=partial(LayerNorm, eps=1e-5),
  2197. #dynamic_img_size=True,
  2198. )
  2199. return _create_eva('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  2200. # RoPE-ViT models from https://github.com/naver-ai/rope-vit
  2201. @register_model
  2202. def vit_small_patch16_rope_224(pretrained: bool = False, **kwargs) -> Eva:
  2203. """RoPE-Axial ViT-S/16 from https://github.com/naver-ai/rope-vit"""
  2204. model_args = dict(
  2205. patch_size=16,
  2206. embed_dim=384,
  2207. depth=12,
  2208. num_heads=6,
  2209. mlp_ratio=4,
  2210. attn_type='rope',
  2211. qkv_bias=True,
  2212. init_values=1e-5,
  2213. class_token=True,
  2214. global_pool='token',
  2215. use_abs_pos_emb=False,
  2216. use_rot_pos_emb=True,
  2217. rope_grid_indexing='xy',
  2218. rope_temperature=100.0,
  2219. )
  2220. model = _create_eva('vit_small_patch16_rope_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2221. return model
  2222. @register_model
  2223. def vit_base_patch16_rope_224(pretrained: bool = False, **kwargs) -> Eva:
  2224. """RoPE-Axial ViT-B/16 from https://github.com/naver-ai/rope-vit"""
  2225. model_args = dict(
  2226. patch_size=16,
  2227. embed_dim=768,
  2228. depth=12,
  2229. num_heads=12,
  2230. mlp_ratio=4,
  2231. attn_type='rope',
  2232. use_fc_norm=False,
  2233. qkv_bias=True,
  2234. init_values=1e-5,
  2235. class_token=True,
  2236. global_pool='token',
  2237. use_abs_pos_emb=False,
  2238. use_rot_pos_emb=True,
  2239. rope_grid_indexing='xy',
  2240. rope_temperature=100.0,
  2241. )
  2242. model = _create_eva('vit_base_patch16_rope_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2243. return model
  2244. @register_model
  2245. def vit_large_patch16_rope_224(pretrained: bool = False, **kwargs) -> Eva:
  2246. """RoPE-Axial ViT-L/16 from https://github.com/naver-ai/rope-vit"""
  2247. model_args = dict(
  2248. patch_size=16,
  2249. embed_dim=1024,
  2250. depth=24,
  2251. num_heads=16,
  2252. mlp_ratio=4,
  2253. attn_type='rope',
  2254. qkv_bias=True,
  2255. init_values=1e-5,
  2256. class_token=True,
  2257. global_pool='token',
  2258. use_abs_pos_emb=False,
  2259. use_rot_pos_emb=True,
  2260. rope_grid_indexing='xy',
  2261. rope_temperature=100.0,
  2262. )
  2263. model = _create_eva('vit_large_patch16_rope_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2264. return model
  2265. @register_model
  2266. def vit_small_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva:
  2267. """RoPE-Mixed ViT-S/16 from https://github.com/naver-ai/rope-vit"""
  2268. model_args = dict(
  2269. patch_size=16,
  2270. embed_dim=384,
  2271. depth=12,
  2272. num_heads=6,
  2273. mlp_ratio=4,
  2274. attn_type='rope',
  2275. qkv_bias=True,
  2276. init_values=1e-5,
  2277. class_token=True,
  2278. global_pool='token',
  2279. use_abs_pos_emb=False,
  2280. use_rot_pos_emb=True,
  2281. rope_grid_indexing='xy',
  2282. rope_temperature=10.0,
  2283. rope_type='mixed'
  2284. )
  2285. model = _create_eva('vit_small_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2286. return model
  2287. @register_model
  2288. def vit_base_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva:
  2289. """RoPE-Mixed ViT-B/16 from https://github.com/naver-ai/rope-vit"""
  2290. model_args = dict(
  2291. patch_size=16,
  2292. embed_dim=768,
  2293. depth=12,
  2294. num_heads=12,
  2295. mlp_ratio=4,
  2296. qkv_bias=True,
  2297. attn_type='rope',
  2298. init_values=1e-5,
  2299. class_token=True,
  2300. global_pool='token',
  2301. use_abs_pos_emb=False,
  2302. use_rot_pos_emb=True,
  2303. rope_grid_indexing='xy',
  2304. rope_temperature=10.0,
  2305. rope_type='mixed'
  2306. )
  2307. model = _create_eva('vit_base_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2308. return model
  2309. @register_model
  2310. def vit_large_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva:
  2311. """RoPE-Mixed ViT-L/16 from https://github.com/naver-ai/rope-vit"""
  2312. model_args = dict(
  2313. patch_size=16,
  2314. embed_dim=1024,
  2315. depth=24,
  2316. num_heads=16,
  2317. mlp_ratio=4,
  2318. attn_type='rope',
  2319. qkv_bias=True,
  2320. init_values=1e-5,
  2321. class_token=True,
  2322. global_pool='token',
  2323. use_abs_pos_emb=False,
  2324. use_rot_pos_emb=True,
  2325. rope_grid_indexing='xy',
  2326. rope_temperature=10.0,
  2327. rope_type='mixed'
  2328. )
  2329. model = _create_eva('vit_large_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2330. return model
  2331. # APE variants (with absolute position embeddings)
  2332. @register_model
  2333. def vit_small_patch16_rope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
  2334. """RoPE-Axial + APE ViT-S/16 from https://github.com/naver-ai/rope-vit"""
  2335. model_args = dict(
  2336. patch_size=16,
  2337. embed_dim=384,
  2338. depth=12,
  2339. num_heads=6,
  2340. mlp_ratio=4,
  2341. attn_type='rope',
  2342. qkv_bias=True,
  2343. init_values=1e-5,
  2344. class_token=True,
  2345. global_pool='token',
  2346. no_embed_class=True,
  2347. use_abs_pos_emb=True,
  2348. use_rot_pos_emb=True,
  2349. rope_grid_indexing='xy',
  2350. rope_temperature=100.0,
  2351. )
  2352. model = _create_eva('vit_small_patch16_rope_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2353. return model
  2354. @register_model
  2355. def vit_base_patch16_rope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
  2356. """RoPE-Axial + APE ViT-B/16 from https://github.com/naver-ai/rope-vit"""
  2357. model_args = dict(
  2358. patch_size=16,
  2359. embed_dim=768,
  2360. depth=12,
  2361. num_heads=12,
  2362. mlp_ratio=4,
  2363. attn_type='rope',
  2364. qkv_bias=True,
  2365. init_values=1e-5,
  2366. class_token=True,
  2367. global_pool='token',
  2368. no_embed_class=True,
  2369. use_abs_pos_emb=True,
  2370. use_rot_pos_emb=True,
  2371. rope_grid_indexing='xy',
  2372. rope_temperature=100.0,
  2373. )
  2374. model = _create_eva('vit_base_patch16_rope_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2375. return model
  2376. @register_model
  2377. def vit_large_patch16_rope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
  2378. """RoPE-Axial + APE ViT-L/16 from https://github.com/naver-ai/rope-vit"""
  2379. model_args = dict(
  2380. patch_size=16,
  2381. embed_dim=1024,
  2382. depth=24,
  2383. num_heads=16,
  2384. mlp_ratio=4,
  2385. attn_type='rope',
  2386. qkv_bias=True,
  2387. init_values=1e-5,
  2388. class_token=True,
  2389. global_pool='token',
  2390. no_embed_class=True,
  2391. use_abs_pos_emb=True,
  2392. use_rot_pos_emb=True,
  2393. rope_grid_indexing='xy',
  2394. rope_temperature=100.0,
  2395. )
  2396. model = _create_eva('vit_large_patch16_rope_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2397. return model
  2398. @register_model
  2399. def vit_small_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva:
  2400. """RoPE-Mixed + APE ViT-S/16 from https://github.com/naver-ai/rope-vit"""
  2401. model_args = dict(
  2402. patch_size=16,
  2403. embed_dim=384,
  2404. depth=12,
  2405. num_heads=6,
  2406. mlp_ratio=4,
  2407. attn_type='rope',
  2408. qkv_bias=True,
  2409. init_values=1e-5,
  2410. class_token=True,
  2411. global_pool='token',
  2412. no_embed_class=True,
  2413. use_abs_pos_emb=True,
  2414. use_rot_pos_emb=True,
  2415. rope_grid_indexing='xy',
  2416. rope_temperature=10.0,
  2417. rope_type='mixed'
  2418. )
  2419. model = _create_eva('vit_small_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2420. return model
  2421. @register_model
  2422. def vit_base_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva:
  2423. """RoPE-Mixed + APE ViT-B/16 from https://github.com/naver-ai/rope-vit"""
  2424. model_args = dict(
  2425. patch_size=16,
  2426. embed_dim=768,
  2427. depth=12,
  2428. num_heads=12,
  2429. mlp_ratio=4,
  2430. attn_type='rope',
  2431. qkv_bias=True,
  2432. init_values=1e-5,
  2433. class_token=True,
  2434. global_pool='token',
  2435. no_embed_class=True,
  2436. use_abs_pos_emb=True,
  2437. use_rot_pos_emb=True,
  2438. rope_grid_indexing='xy',
  2439. rope_temperature=10.0,
  2440. rope_type='mixed'
  2441. )
  2442. model = _create_eva('vit_base_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2443. return model
  2444. @register_model
  2445. def vit_large_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva:
  2446. """RoPE-Mixed + APE ViT-L/16 from https://github.com/naver-ai/rope-vit"""
  2447. model_args = dict(
  2448. patch_size=16,
  2449. embed_dim=1024,
  2450. depth=24,
  2451. num_heads=16,
  2452. mlp_ratio=4,
  2453. attn_type='rope',
  2454. qkv_bias=True,
  2455. init_values=1e-5,
  2456. class_token=True,
  2457. global_pool='token',
  2458. no_embed_class=True,
  2459. use_abs_pos_emb=True,
  2460. use_rot_pos_emb=True,
  2461. rope_grid_indexing='xy',
  2462. rope_temperature=10.0,
  2463. rope_type='mixed'
  2464. )
  2465. model = _create_eva('vit_large_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2466. return model
  2467. @register_model
  2468. def vit_small_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva:
  2469. """DINOv3 S/16 https://arxiv.org/abs/2508.10104
  2470. NOTE: Pass global_pool='token' to use CLS-token pooling (matches upstream DINOv3).
  2471. """
  2472. model_args = dict(
  2473. patch_size=16,
  2474. dynamic_img_size=True,
  2475. embed_dim=384,
  2476. depth=12,
  2477. num_heads=6,
  2478. qkv_bias=False,
  2479. init_values=1.0e-05, # layer-scale
  2480. rope_type='dinov3',
  2481. rope_temperature=100,
  2482. #rope_rescale_coords=2, # haven't added to interface
  2483. rope_rotate_half=True,
  2484. use_rot_pos_emb=True,
  2485. use_abs_pos_emb=False,
  2486. num_reg_tokens=4,
  2487. use_fc_norm=False,
  2488. norm_layer=partial(LayerNorm, eps=1e-5),
  2489. )
  2490. model = _create_eva('vit_small_patch16_dinov3', pretrained=pretrained, **dict(model_args, **kwargs))
  2491. return model
  2492. @register_model
  2493. def vit_small_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva:
  2494. """DINOv3 S/16 w/ QKV bias enabled (but zero) https://arxiv.org/abs/2508.10104
  2495. NOTE: Pass global_pool='token' to use CLS-token pooling (matches upstream DINOv3).
  2496. """
  2497. model_args = dict(
  2498. patch_size=16,
  2499. dynamic_img_size=True,
  2500. embed_dim=384,
  2501. depth=12,
  2502. num_heads=6,
  2503. qkv_bias=True,
  2504. init_values=1.0e-05, # layer-scale
  2505. rope_type='dinov3',
  2506. rope_temperature=100,
  2507. #rope_rescale_coords=2, # haven't added to interface
  2508. rope_rotate_half=True,
  2509. use_rot_pos_emb=True,
  2510. use_abs_pos_emb=False,
  2511. num_reg_tokens=4,
  2512. use_fc_norm=False,
  2513. norm_layer=partial(LayerNorm, eps=1e-5),
  2514. )
  2515. model = _create_eva('vit_small_patch16_dinov3_qkvb', pretrained=pretrained, **dict(model_args, **kwargs))
  2516. return model
  2517. @register_model
  2518. def vit_small_plus_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva:
  2519. """DINOv3 S/16 Plus https://arxiv.org/abs/2508.10104
  2520. NOTE: Pass global_pool='token' to use CLS-token pooling (matches upstream DINOv3).
  2521. """
  2522. model_args = dict(
  2523. patch_size=16,
  2524. dynamic_img_size=True,
  2525. embed_dim=384,
  2526. depth=12,
  2527. num_heads=6,
  2528. qkv_bias=False,
  2529. init_values=1.0e-05, # layer-scale
  2530. rope_type='dinov3',
  2531. rope_temperature=100,
  2532. #rope_rescale_coords=2, # haven't added to interface
  2533. rope_rotate_half=True,
  2534. use_rot_pos_emb=True,
  2535. use_abs_pos_emb=False,
  2536. swiglu_mlp=True,
  2537. swiglu_align_to=8,
  2538. num_reg_tokens=4,
  2539. use_fc_norm=False,
  2540. norm_layer=partial(LayerNorm, eps=1e-5),
  2541. )
  2542. model = _create_eva('vit_small_plus_patch16_dinov3', pretrained=pretrained, **dict(model_args, **kwargs))
  2543. return model
  2544. @register_model
  2545. def vit_small_plus_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva:
  2546. """DINOv3 S/16 Plus w/ QKV bias enabled (but 0) https://arxiv.org/abs/2508.10104
  2547. NOTE: Pass global_pool='token' to use CLS-token pooling (matches upstream DINOv3).
  2548. """
  2549. model_args = dict(
  2550. patch_size=16,
  2551. dynamic_img_size=True,
  2552. embed_dim=384,
  2553. depth=12,
  2554. num_heads=6,
  2555. qkv_bias=True,
  2556. init_values=1.0e-05, # layer-scale
  2557. rope_type='dinov3',
  2558. rope_temperature=100,
  2559. #rope_rescale_coords=2, # haven't added to interface
  2560. rope_rotate_half=True,
  2561. use_rot_pos_emb=True,
  2562. use_abs_pos_emb=False,
  2563. swiglu_mlp=True,
  2564. swiglu_align_to=8,
  2565. num_reg_tokens=4,
  2566. use_fc_norm=False,
  2567. norm_layer=partial(LayerNorm, eps=1e-5),
  2568. )
  2569. model = _create_eva('vit_small_plus_patch16_dinov3_qkvb', pretrained=pretrained, **dict(model_args, **kwargs))
  2570. return model
  2571. @register_model
  2572. def vit_base_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva:
  2573. """DINOv3 B/16 https://arxiv.org/abs/2508.10104
  2574. NOTE: Pass global_pool='token' to use CLS-token pooling (matches upstream DINOv3).
  2575. """
  2576. model_args = dict(
  2577. patch_size=16,
  2578. dynamic_img_size=True,
  2579. embed_dim=768,
  2580. depth=12,
  2581. num_heads=12,
  2582. qkv_bias=False,
  2583. init_values=1.0e-05, # layer-scale
  2584. rope_type='dinov3',
  2585. rope_temperature=100,
  2586. #rope_rescale_coords=2, # haven't added to interface
  2587. rope_rotate_half=True,
  2588. use_rot_pos_emb=True,
  2589. use_abs_pos_emb=False,
  2590. num_reg_tokens=4,
  2591. use_fc_norm=False,
  2592. norm_layer=partial(LayerNorm, eps=1e-5),
  2593. )
  2594. model = _create_eva('vit_base_patch16_dinov3', pretrained=pretrained, **dict(model_args, **kwargs))
  2595. return model
  2596. @register_model
  2597. def vit_base_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva:
  2598. """DINOv3 B/16 w/ QKV bias enabled (but zero) https://arxiv.org/abs/2508.10104
  2599. NOTE: Pass global_pool='token' to use CLS-token pooling (matches upstream DINOv3).
  2600. """
  2601. model_args = dict(
  2602. patch_size=16,
  2603. dynamic_img_size=True,
  2604. embed_dim=768,
  2605. depth=12,
  2606. num_heads=12,
  2607. qkv_bias=True,
  2608. init_values=1.0e-05, # layer-scale
  2609. rope_type='dinov3',
  2610. rope_temperature=100,
  2611. #rope_rescale_coords=2, # haven't added to interface
  2612. rope_rotate_half=True,
  2613. use_rot_pos_emb=True,
  2614. use_abs_pos_emb=False,
  2615. num_reg_tokens=4,
  2616. use_fc_norm=False,
  2617. norm_layer=partial(LayerNorm, eps=1e-5),
  2618. )
  2619. model = _create_eva('vit_base_patch16_dinov3_qkvb', pretrained=pretrained, **dict(model_args, **kwargs))
  2620. return model
  2621. @register_model
  2622. def vit_large_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva:
  2623. """DINOv3 L/16 https://arxiv.org/abs/2508.10104
  2624. NOTE: Pass global_pool='token' to use CLS-token pooling (matches upstream DINOv3).
  2625. """
  2626. model_args = dict(
  2627. patch_size=16,
  2628. dynamic_img_size=True,
  2629. embed_dim=1024,
  2630. depth=24,
  2631. num_heads=16,
  2632. qkv_bias=False,
  2633. init_values=1.0e-5, # layer-scale
  2634. rope_type='dinov3',
  2635. rope_temperature=100,
  2636. use_rot_pos_emb=True,
  2637. use_abs_pos_emb=False,
  2638. rope_rotate_half=True,
  2639. #rope_rescale_coords=2, # haven't added to interface
  2640. num_reg_tokens=4,
  2641. use_fc_norm=False,
  2642. norm_layer=partial(LayerNorm, eps=1e-5),
  2643. )
  2644. model = _create_eva('vit_large_patch16_dinov3', pretrained=pretrained, **dict(model_args, **kwargs))
  2645. return model
  2646. @register_model
  2647. def vit_large_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva:
  2648. """DINOv3 w/ QKV bias enabled (but zero) https://arxiv.org/abs/2508.10104
  2649. NOTE: Pass global_pool='token' to use CLS-token pooling (matches upstream DINOv3).
  2650. """
  2651. model_args = dict(
  2652. patch_size=16,
  2653. dynamic_img_size=True,
  2654. embed_dim=1024,
  2655. depth=24,
  2656. num_heads=16,
  2657. qkv_bias=True,
  2658. init_values=1.0e-5, # layer-scale
  2659. rope_type='dinov3',
  2660. rope_temperature=100,
  2661. use_rot_pos_emb=True,
  2662. use_abs_pos_emb=False,
  2663. rope_rotate_half=True,
  2664. #rope_rescale_coords=2, # haven't added to interface
  2665. num_reg_tokens=4,
  2666. use_fc_norm=False,
  2667. norm_layer=partial(LayerNorm, eps=1e-5),
  2668. )
  2669. model = _create_eva('vit_large_patch16_dinov3_qkvb', pretrained=pretrained, **dict(model_args, **kwargs))
  2670. return model
  2671. @register_model
  2672. def vit_huge_plus_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva:
  2673. """DINOv3 H/16 Plus https://arxiv.org/abs/2508.10104
  2674. NOTE: Pass global_pool='token' to use CLS-token pooling (matches upstream DINOv3).
  2675. """
  2676. model_args = dict(
  2677. patch_size=16,
  2678. dynamic_img_size=True,
  2679. embed_dim=1280,
  2680. depth=32,
  2681. num_heads=20,
  2682. qkv_bias=False,
  2683. init_values=1.0e-5, # layer-scale
  2684. rope_type='dinov3',
  2685. rope_temperature=100,
  2686. use_rot_pos_emb=True,
  2687. use_abs_pos_emb=False,
  2688. rope_rotate_half=True,
  2689. swiglu_mlp=True,
  2690. swiglu_align_to=8,
  2691. #rope_rescale_coords=2, # haven't added to interface
  2692. num_reg_tokens=4,
  2693. use_fc_norm=False,
  2694. norm_layer=partial(LayerNorm, eps=1e-5),
  2695. )
  2696. model = _create_eva('vit_huge_plus_patch16_dinov3', pretrained=pretrained, **dict(model_args, **kwargs))
  2697. return model
  2698. @register_model
  2699. def vit_huge_plus_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva:
  2700. """DINOv3 H/16 Plus w/ QKV bias enabled (but zero) https://arxiv.org/abs/2508.10104
  2701. NOTE: Pass global_pool='token' to use CLS-token pooling (matches upstream DINOv3).
  2702. """
  2703. model_args = dict(
  2704. patch_size=16,
  2705. dynamic_img_size=True,
  2706. embed_dim=1280,
  2707. depth=32,
  2708. num_heads=20,
  2709. qkv_bias=True,
  2710. init_values=1.0e-5, # layer-scale
  2711. rope_type='dinov3',
  2712. rope_temperature=100,
  2713. use_rot_pos_emb=True,
  2714. use_abs_pos_emb=False,
  2715. rope_rotate_half=True,
  2716. swiglu_mlp=True,
  2717. swiglu_align_to=8,
  2718. #rope_rescale_coords=2, # haven't added to interface
  2719. num_reg_tokens=4,
  2720. use_fc_norm=False,
  2721. norm_layer=partial(LayerNorm, eps=1e-5),
  2722. )
  2723. model = _create_eva('vit_huge_plus_patch16_dinov3_qkvb', pretrained=pretrained, **dict(model_args, **kwargs))
  2724. return model
  2725. @register_model
  2726. def vit_7b_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva:
  2727. """DINOv3 7B/16 https://arxiv.org/abs/2508.10104
  2728. NOTE: Pass global_pool='token' to use CLS-token pooling (matches upstream DINOv3).
  2729. """
  2730. model_args = dict(
  2731. patch_size=16,
  2732. dynamic_img_size=True,
  2733. embed_dim=4096,
  2734. depth=40,
  2735. num_heads=32,
  2736. qkv_bias=False,
  2737. mlp_ratio=2,
  2738. init_values=1.0e-5, # layer-scale
  2739. rope_type='dinov3',
  2740. rope_temperature=100,
  2741. use_rot_pos_emb=True,
  2742. use_abs_pos_emb=False,
  2743. rope_rotate_half=True,
  2744. swiglu_mlp=True,
  2745. swiglu_align_to=64,
  2746. #rope_rescale_coords=2, # haven't added to interface
  2747. num_reg_tokens=4,
  2748. use_fc_norm=False,
  2749. norm_layer=partial(LayerNorm, eps=1e-5),
  2750. )
  2751. model = _create_eva('vit_7b_patch16_dinov3', pretrained=pretrained, **dict(model_args, **kwargs))
  2752. return model