naflexvit.py 92 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267
  1. """ NaFlex Vision Transformer
  2. An improved version of the Vision Transformer with:
  3. 1. Encapsulated embedding and position encoding in a single module
  4. 2. Support for linear patch embedding on pre-patchified inputs
  5. 3. Support for NaFlex variable aspect, variable resolution
  6. 4. Support for FlexiViT variable patch size
  7. 5. Support for NaViT fractional/factorized position embedding
  8. Based on ideas from:
  9. - Original Vision Transformer: https://arxiv.org/abs/2010.11929
  10. - FlexiViT: https://arxiv.org/abs/2212.08013
  11. - NaViT: https://arxiv.org/abs/2307.06304
  12. - NaFlex (SigLip-2): https://arxiv.org/abs/2502.14786
  13. Hacked together by / Copyright 2025, Ross Wightman, Hugging Face
  14. """
  15. import logging
  16. import math
  17. from dataclasses import dataclass, fields, replace
  18. from functools import partial
  19. from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union, Any
  20. import torch
  21. import torch.nn as nn
  22. import torch.nn.functional as F
  23. from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
  24. from timm.layers import (
  25. AttentionPoolLatent,
  26. Mlp,
  27. LayerNorm,
  28. PatchDropoutWithIndices,
  29. PatchEmbedInterpolator,
  30. _assert,
  31. to_2tuple,
  32. get_act_layer,
  33. get_norm_layer,
  34. apply_keep_indices_nlc,
  35. disable_compiler,
  36. calculate_drop_path_rates,
  37. )
  38. from ._builder import build_model_with_cfg
  39. from ._features import feature_take_indices
  40. from ._features_fx import register_notrace_function, register_notrace_module
  41. from ._manipulate import checkpoint, named_apply
  42. from ._registry import register_model, generate_default_cfgs
  43. from .eva import EvaBlock
  44. from .vision_transformer import Block, global_pool_nlc
  45. __all__ = ['NaFlexVitCfg', 'NaFlexVit']
  46. _logger = logging.getLogger(__name__)
  47. @dataclass
  48. class NaFlexVitCfg:
  49. """Configuration for FlexVit model.
  50. This dataclass contains the bulk of model configuration parameters,
  51. with core parameters (img_size, in_chans, num_classes, etc.) remaining
  52. as direct constructor arguments for API compatibility.
  53. """
  54. # Architecture parameters
  55. patch_size: Union[int, Tuple[int, int]] = 16
  56. embed_dim: int = 768
  57. depth: int = 12
  58. num_heads: int = 12
  59. mlp_ratio: float = 4.0
  60. scale_mlp_norm: bool = False # Apply scaling norm to MLP
  61. # Attention parameters
  62. qkv_bias: bool = True
  63. qk_norm: bool = False
  64. proj_bias: bool = True
  65. attn_drop_rate: float = 0.0
  66. scale_attn_inner_norm: bool = False # Apply scaling norm to attn context
  67. # Regularization
  68. init_values: Optional[float] = None # Layer-scale init values (layer-scale enabled if not None)
  69. drop_rate: float = 0.0 # Dropout rate for classifier
  70. pos_drop_rate: float = 0.0 # Dropout rate for position embeddings
  71. patch_drop_rate: float = 0.0 # Dropout rate for patch tokens
  72. proj_drop_rate: float = 0.0 # Dropout rate for linear projections
  73. drop_path_rate: float = 0.0 # Stochastic depth drop rate
  74. # Prefix token configuration
  75. class_token: bool = False # Use class token
  76. reg_tokens: int = 0 # Number of register tokens
  77. # Position embedding configuration
  78. pos_embed: str = 'learned' # Type of position embedding ('learned', 'factorized', 'rope', 'none')
  79. pos_embed_grid_size: Optional[Tuple[int, int]] = (16, 16) # Grid size for position embedding initialization
  80. pos_embed_interp_mode: str = 'bicubic' # Interpolation mode for position embedding resizing
  81. pos_embed_ar_preserving: bool = False # Whether to preserve aspect ratio during position embedding interpolation
  82. pos_embed_use_grid_sample: bool = False # Whether to use grid_sample for naflex position embedding interpolation
  83. # ROPE specific configuration
  84. rope_type: str = '' # ROPE type: '' or 'none' for no ROPE, 'axial' for standard, 'mixed' for learnable frequencies
  85. rope_temperature: float = 10000.0 # Temperature for ROPE frequency computation
  86. rope_ref_feat_shape: Optional[Tuple[int, int]] = None
  87. rope_grid_offset: float = 0. # Grid offset for non-pixel ROPE mode
  88. rope_grid_indexing: str = 'ij' # Grid indexing mode for ROPE ('ij' or 'xy')
  89. # Image processing
  90. dynamic_img_pad: bool = False # Whether to enable dynamic padding for variable resolution
  91. # Other architecture choices
  92. pre_norm: bool = False # Whether to apply normalization before attention/MLP layers (start of blocks)
  93. final_norm: bool = True # Whether to apply final normalization before pooling and classifier (end of blocks)
  94. fc_norm: Optional[bool] = None # Whether to normalize features before final classifier (after pooling)
  95. # Global pooling setup
  96. global_pool: str = 'map' # Type of global pooling for final sequence
  97. pool_include_prefix: bool = False # Whether to include class/register prefix tokens in global pooling
  98. attn_pool_num_heads: Optional[int] = None # Override num_heads for attention pool
  99. attn_pool_mlp_ratio: Optional[float] = None # Override mlp_ratio for attention pool
  100. # Weight initialization
  101. weight_init: str = '' # Weight initialization scheme
  102. fix_init: bool = True # Apply weight initialization fix (scaling w/ layer index)
  103. # Embedding configuration
  104. embed_proj_type: str = 'linear' # Type of embedding layer ('conv' or 'linear')
  105. input_norm_layer: Optional[str] = None # Normalization layer for embeddings input (before input projection)
  106. embed_norm_layer: Optional[str] = None # Normalization layer for embeddings (after input projection)
  107. # Layer implementations
  108. norm_layer: Optional[str] = None # Normalization layer for transformer blocks
  109. act_layer: Optional[str] = None # Activation layer for MLP blocks
  110. block_fn: Optional[str] = None # Transformer block implementation class name
  111. mlp_layer: Optional[str] = None # MLP implementation class name
  112. attn_layer: Optional[str] = None # Attention layer implementation (e.g., 'attn', 'diff')
  113. # EVA-specific parameters
  114. attn_type: str = 'standard' # Attention type: 'standard', 'eva', 'rope'
  115. swiglu_mlp: bool = False # Use SwiGLU MLP variant
  116. qkv_fused: bool = True # Whether to use fused QKV projections
  117. # Variable patch size support
  118. enable_patch_interpolator: bool = False # Enable dynamic patch size support
  119. def _overlay_kwargs(cfg: NaFlexVitCfg, **kwargs) -> NaFlexVitCfg:
  120. """Overlay kwargs onto config, replacing config values with provided kwargs."""
  121. # Only update fields that exist in the config
  122. config_fields = set(cfg.__dataclass_fields__.keys())
  123. config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
  124. if config_kwargs:
  125. cfg = replace(cfg, **config_kwargs)
  126. return cfg
  127. def batch_patchify(
  128. x: torch.Tensor,
  129. patch_size: Tuple[int, int],
  130. pad: bool = True,
  131. ) -> Tuple[torch.Tensor, Tuple[int, int]]:
  132. """Patchify a batch of images.
  133. Args:
  134. x: Input tensor of shape [B, C, H, W].
  135. patch_size: Patch dimensions (patch_h, patch_w).
  136. pad: Whether to pad images to be divisible by patch size.
  137. Returns:
  138. Tuple of (patches, grid_size) where patches has shape [B, N, P*P*C]
  139. and grid_size is (num_patches_h, num_patches_w).
  140. """
  141. B, C, H, W = x.shape
  142. ph, pw = patch_size
  143. # Ensure the image is divisible by patch size
  144. if pad and (H % ph != 0 or W % pw != 0):
  145. pad_h = (ph - H % ph) % ph
  146. pad_w = (pw - W % pw) % pw
  147. x = F.pad(x, (0, pad_w, 0, pad_h))
  148. nh, nw = H // ph, W // pw
  149. patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C)
  150. # FIXME confirm we want 'channels last' in the patch channel layout, egg ph, ph, C instead of C, ph, hw
  151. return patches, (nh, nw)
  152. def calculate_naflex_grid_sizes(_coord: torch.Tensor):
  153. # Calculate the appropriate grid size from coords
  154. max_y = _coord[:, :, 0].amax(dim=1) + 1
  155. max_x = _coord[:, :, 1].amax(dim=1) + 1
  156. return [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)]
  157. class NaFlexRopeIterator:
  158. """Iterator for generating batched ROPE embeddings for mixed mode with multiple grid sizes."""
  159. def __init__(
  160. self,
  161. rope_module,
  162. size_to_indices: Dict[Tuple[int, int], List[int]],
  163. unique_sizes: List[Tuple[int, int]],
  164. batch_size: int,
  165. seq_len: int,
  166. device: torch.device,
  167. dtype: torch.dtype,
  168. ):
  169. self.rope = rope_module
  170. self.size_to_indices = size_to_indices
  171. self.unique_sizes = unique_sizes
  172. self.batch_size = batch_size
  173. self.seq_len = seq_len
  174. self.dtype = dtype
  175. self.device = device
  176. self.depth = rope_module.depth
  177. self.num_heads = rope_module.num_heads
  178. self.head_dim = 2 * rope_module.dim // rope_module.num_heads
  179. self._depth_idx = 0
  180. # Pre-compute embeddings for each unique size
  181. self._embeddings_per_size = {}
  182. for grid_size in unique_sizes:
  183. # get_embed returns all depths at once for mixed mode
  184. rope_embed = rope_module.get_embed(shape=grid_size)
  185. self._embeddings_per_size[grid_size] = rope_embed
  186. def __iter__(self):
  187. self._depth_idx = 0
  188. return self
  189. @disable_compiler
  190. def __next__(self):
  191. if self._depth_idx >= self.depth:
  192. raise StopIteration
  193. # Create batch tensor for current depth
  194. batch_embed = torch.zeros(
  195. self.batch_size, self.num_heads, self.seq_len, self.head_dim,
  196. dtype=self.dtype, device=self.device
  197. )
  198. # Fill in embeddings for each unique grid size
  199. for grid_size in self.unique_sizes:
  200. h, w = grid_size
  201. actual_len = h * w
  202. batch_indices = self.size_to_indices[grid_size]
  203. # Get pre-computed embeddings for this size at current depth
  204. embed = self._embeddings_per_size[grid_size][self._depth_idx] # [num_heads, H*W, dim]
  205. # Assign to batch indices
  206. for bi in batch_indices:
  207. batch_embed[bi, :, :actual_len, :] = embed[:, :actual_len, :]
  208. self._depth_idx += 1
  209. return batch_embed
  210. def get_block_fn(cfg: NaFlexVitCfg) -> Callable:
  211. """Get appropriate block function based on configuration.
  212. Returns a partially applied block constructor with EVA-specific
  213. or conflicting parameters pre-configured if needed.
  214. """
  215. # Check if we need EVA block features
  216. use_eva_features = (
  217. cfg.attn_type in ('eva', 'rope') or
  218. cfg.rope_type not in ('', 'none') or # Any ROPE type requires EVA blocks
  219. cfg.swiglu_mlp
  220. )
  221. if use_eva_features:
  222. # Determine attention type based on rope_type if not explicitly set
  223. attn_type = cfg.attn_type
  224. if attn_type == 'standard' and cfg.rope_type not in ('', 'none'):
  225. attn_type = 'rope'
  226. num_prefix_tokens = (1 if cfg.class_token else 0) + cfg.reg_tokens
  227. return partial(
  228. EvaBlock,
  229. attn_type=attn_type,
  230. swiglu_mlp=cfg.swiglu_mlp,
  231. scale_mlp=cfg.scale_mlp_norm,
  232. scale_attn_inner=cfg.scale_attn_inner_norm,
  233. qkv_fused=cfg.qkv_fused,
  234. num_prefix_tokens=num_prefix_tokens,
  235. )
  236. else:
  237. # Standard ViT block
  238. block_fn = cfg.block_fn or Block
  239. block_kwargs = {}
  240. if cfg.scale_mlp_norm or cfg.scale_attn_inner_norm:
  241. # param names differ between EVA vs non-EVA block types
  242. block_kwargs['scale_mlp_norm'] = cfg.scale_mlp_norm
  243. block_kwargs['scale_attn_norm'] = cfg.scale_attn_inner_norm
  244. if cfg.attn_layer:
  245. block_kwargs['attn_layer'] = cfg.attn_layer
  246. if block_kwargs:
  247. block_fn = partial(block_fn, **block_kwargs)
  248. return block_fn
  249. @register_notrace_module
  250. class NaFlexEmbeds(nn.Module):
  251. """NaFlex Embedding module for Vision Transformers.
  252. This module encapsulates the complete embedding process for Vision Transformers,
  253. supporting both standard and NaFlex (NaViT + FlexiViT) functionality:
  254. 1. Patch embedding (via Conv2d or Linear)
  255. 2. Class and register token preparation
  256. 3. Position embedding addition with interpolation support
  257. 4. Pre-normalization (if requested)
  258. 5. Dropout application
  259. NaFlex capabilities include:
  260. - Variable aspect ratio and resolution via patch coordinates
  261. - Patch type indicators for handling padding tokens in attention
  262. - Flexible position embedding interpolation for arbitrary grid sizes
  263. - Support for factorized position embeddings
  264. The patch embedding can be one of two types:
  265. - Conv2d-based (default): For standard image inputs [B, C, H, W]
  266. - Linear-based: For pre-patchified inputs [B, N, P*P*C]
  267. Args:
  268. patch_size: Size of patches for patch embedding
  269. in_chans: Number of input image channels
  270. embed_dim: Dimensionality of patch embedding
  271. proj_type: Type of embedding projection layer ('conv' or 'linear')
  272. input_norm_layer: Normalization layer applied to input (linear mode only)
  273. proj_norm_layer: Normalization layer applied after projection
  274. pos_embed: Type of position embedding ('learned', 'factorized', 'none')
  275. pos_drop_rate: Dropout rate for position embeddings
  276. class_token: Whether to include a class token
  277. reg_tokens: Number of register tokens to include
  278. bias: Whether to use bias in projection layers
  279. dynamic_img_pad: Whether to enable dynamic padding for variable resolution
  280. pos_embed_grid_size: Grid size for position embedding initialization
  281. pos_embed_interp_mode: Interpolation mode for position embedding resizing
  282. pos_embed_ar_preserving: Whether to preserve aspect ratio during position embedding interpolation
  283. default_img_size: Default image size for position embedding grid calculation
  284. """
  285. def __init__(
  286. self,
  287. patch_size: Union[int, Tuple[int, int]] = 16,
  288. in_chans: int = 3,
  289. embed_dim: int = 768,
  290. proj_type: Optional[str] = None,
  291. proj_bias: bool = True,
  292. class_token: bool = True,
  293. reg_tokens: int = 0,
  294. dynamic_img_pad: bool = False,
  295. default_img_size: Optional[Union[int, Tuple[int, int]]] = None,
  296. pos_embed: str = 'learned',
  297. pos_embed_grid_size: Optional[Tuple[int, int]] = (14, 14),
  298. pos_embed_interp_mode: str = 'bicubic',
  299. pos_embed_ar_preserving: bool = False,
  300. pos_embed_use_grid_sample: bool = False,
  301. input_norm_layer: Optional[Type[nn.Module]] = None,
  302. proj_norm_layer: Union[bool, Optional[Type[nn.Module]]] = None,
  303. norm_layer: Optional[Type[nn.Module]] = None,
  304. pos_drop_rate: float = 0.,
  305. enable_patch_interpolator: bool = False,
  306. device=None,
  307. dtype=None,
  308. ) -> None:
  309. """Initialize NaFlexEmbeds module.
  310. Args:
  311. patch_size: Size of patches for patch embedding.
  312. in_chans: Number of input image channels.
  313. embed_dim: Dimensionality of patch embedding.
  314. proj_type: Type of embedding projection layer ('conv' or 'linear').
  315. proj_bias: Whether to use bias in projection layers.
  316. class_token: Whether to include a class token.
  317. reg_tokens: Number of register tokens to include.
  318. dynamic_img_pad: Whether to enable dynamic padding for variable resolution.
  319. default_img_size: Default image size for position embedding grid calculation.
  320. pos_embed: Type of position embedding ('learned', 'factorized', 'none').
  321. pos_embed_grid_size: Grid size for position embedding initialization.
  322. pos_embed_interp_mode: Interpolation mode for position embedding resizing.
  323. pos_embed_ar_preserving: Whether to preserve aspect ratio during interpolation.
  324. input_norm_layer: Normalization layer applied to input (linear mode only).
  325. proj_norm_layer: Normalization layer applied after projection.
  326. norm_layer: Default normalization layer.
  327. pos_drop_rate: Dropout rate for position embeddings.
  328. enable_patch_interpolator: Enable dynamic patch size support.
  329. """
  330. dd = {'device': device, 'dtype': dtype}
  331. super().__init__()
  332. self.has_class_token = class_token
  333. self.num_reg_tokens = reg_tokens
  334. self.pos_embed_interp_mode = pos_embed_interp_mode
  335. self.pos_embed_ar_preserving = pos_embed_ar_preserving
  336. self.pos_embed_use_grid_sample = pos_embed_use_grid_sample
  337. self.patch_size = to_2tuple(patch_size)
  338. self.in_chans = in_chans
  339. self.embed_dim = embed_dim
  340. self.dynamic_img_pad = dynamic_img_pad
  341. self.enable_patch_interpolator = enable_patch_interpolator
  342. # Calculate number of prefix tokens
  343. self.num_prefix_tokens = 1 if class_token else 0
  344. self.num_prefix_tokens += reg_tokens
  345. # Create class and register tokens
  346. self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, **dd)) if class_token else None
  347. self.reg_token = nn.Parameter(torch.empty(1, reg_tokens, embed_dim, **dd)) if reg_tokens else None
  348. # Calculate grid size and number of patches
  349. self.default_img_size: Optional[Tuple[int, int]] = None
  350. self.pos_embed_grid_size: Optional[Tuple[int, int]] = None # Grid size used for learned pos embed init
  351. if pos_embed_grid_size is not None:
  352. # Highest priority, use provided pos_embed_grid_size
  353. self.pos_embed_grid_size = pos_embed_grid_size
  354. elif default_img_size is not None:
  355. # Fallback to calculating grid size from img_size + patch_size if img size provided.
  356. self.default_img_size = to_2tuple(default_img_size)
  357. self.pos_embed_grid_size = tuple([s // p for s, p in zip(self.default_img_size, self.patch_size)])
  358. # Determine patch embedding type (linear or conv2d)
  359. if proj_type == 'linear':
  360. # Create linear projection for pre-patchified inputs
  361. # Input dimension is patch_size^2 * in_chans
  362. patch_dim = self.patch_size[0] * self.patch_size[1] * in_chans
  363. assert not (input_norm_layer is True and norm_layer is None), \
  364. "`norm_layer` must be given when input_norm_layer=True"
  365. input_norm_layer = norm_layer if input_norm_layer is True else (input_norm_layer or None)
  366. self.norm_input = input_norm_layer(patch_dim) if input_norm_layer else None
  367. self.proj = nn.Linear(patch_dim, embed_dim, bias=proj_bias, **dd)
  368. self.flatten = False
  369. self.is_linear = True
  370. else:
  371. # Default to convolutional patch embedding for image inputs
  372. assert not input_norm_layer
  373. self.norm_input = None
  374. self.proj = nn.Conv2d(
  375. in_chans,
  376. embed_dim,
  377. kernel_size=patch_size,
  378. stride=patch_size,
  379. bias=proj_bias,
  380. **dd,
  381. )
  382. self.flatten = True
  383. self.is_linear = False
  384. # Create patch embedding interpolator if enabled
  385. if self.enable_patch_interpolator:
  386. self.patch_interpolator = PatchEmbedInterpolator(
  387. base_patch_size=self.patch_size,
  388. in_chans=in_chans,
  389. embed_dim=embed_dim,
  390. interpolation=pos_embed_interp_mode,
  391. antialias=True,
  392. )
  393. else:
  394. self.patch_interpolator = None
  395. # Create normalization layer after the projection
  396. assert not (proj_norm_layer is True and norm_layer is None), \
  397. "`norm_layer` must be given when proj_norm_layer=True"
  398. proj_norm_layer = norm_layer if proj_norm_layer is True else (proj_norm_layer or None)
  399. self.norm = proj_norm_layer(embed_dim) if proj_norm_layer else nn.Identity()
  400. # Create position embedding if needed - only for patches, never for prefix tokens
  401. if pos_embed in ('factorized', 'learned') and self.pos_embed_grid_size is None:
  402. raise ValueError(
  403. "Cannot initialize position embeddings without grid_size."
  404. "Please provide img_size or pos_embed_grid_size.")
  405. self.pos_embed: Optional[torch.Tensor] = None
  406. self.pos_embed_y: Optional[torch.Tensor] = None
  407. self.pos_embed_x: Optional[torch.Tensor] = None
  408. if not pos_embed or pos_embed == 'none':
  409. self.pos_embed_type = 'none'
  410. elif pos_embed == 'factorized':
  411. assert self.pos_embed_grid_size is not None
  412. h, w = self.pos_embed_grid_size
  413. self.pos_embed_type = 'factorized'
  414. self.pos_embed_y = nn.Parameter(torch.empty(1, h, embed_dim, **dd))
  415. self.pos_embed_x = nn.Parameter(torch.empty(1, w, embed_dim, **dd))
  416. else:
  417. assert self.pos_embed_grid_size is not None
  418. h, w = self.pos_embed_grid_size
  419. self.pos_embed = nn.Parameter(torch.empty(1, h, w, embed_dim, **dd))
  420. self.pos_embed_type = 'learned'
  421. # Dropout layer
  422. self.pos_drop = nn.Dropout(p=pos_drop_rate)
  423. # TODO: skip init when on meta device when safe to do so
  424. self.reset_parameters()
  425. def reset_parameters(self) -> None:
  426. if self.cls_token is not None:
  427. nn.init.normal_(self.cls_token, std=1e-6)
  428. if self.reg_token is not None:
  429. nn.init.normal_(self.reg_token, std=1e-6)
  430. if self.pos_embed is not None:
  431. nn.init.normal_(self.pos_embed, std=.02)
  432. if self.pos_embed_y is not None:
  433. nn.init.normal_(self.pos_embed_y, std=.02)
  434. if self.pos_embed_x is not None:
  435. nn.init.normal_(self.pos_embed_x, std=.02)
  436. def feature_info(self, location) -> Dict[str, Any]:
  437. """Get feature information for feature extraction.
  438. Args:
  439. location: Feature extraction location identifier
  440. Returns:
  441. Dictionary containing feature channel count and reduction factor
  442. """
  443. return dict(num_chs=self.embed_dim, reduction=self.patch_size)
  444. def feat_ratio(self, as_scalar: bool = True) -> Union[int, Tuple[int, int]]:
  445. """Get the feature reduction ratio (stride) of the patch embedding.
  446. Args:
  447. as_scalar: Whether to return the maximum dimension as a scalar
  448. Returns:
  449. Feature reduction ratio as scalar or tuple
  450. """
  451. if as_scalar:
  452. return max(self.patch_size)
  453. else:
  454. return self.patch_size
  455. def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
  456. """Calculate grid (feature) size for given image size.
  457. Takes into account dynamic padding when enabled.
  458. Args:
  459. img_size: Input image size as (height, width)
  460. Returns:
  461. Grid size as (grid_height, grid_width)
  462. """
  463. if self.dynamic_img_pad:
  464. return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1])
  465. else:
  466. return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
  467. @disable_compiler
  def _apply_learned_naflex_pos_embed(
          self,
          x: torch.Tensor,
          patch_coord: torch.Tensor,
  ) -> None:
      """Add learned 2D position embeddings to a NaFlex batch, in place.

      Samples are bucketed by grid size so that the learned table is resampled
      only once per distinct (h, w); the resampled, flattened table is then
      accumulated into every sample of that bucket.

      Args:
          x: Input tensor to add position embeddings to [B, N, C]
          patch_coord: Patch coordinates [B, N, 2] with (y, x) values
      """
      # One grid-size entry per sample; entries are used as dict keys and indexed as (h, w).
      grid_sizes = calculate_naflex_grid_sizes(patch_coord)
      base_h, base_w = self.pos_embed.shape[1:3]
      # Channels-first float copy of the table, shared by every interpolation below.
      table_nchw = self.pos_embed.permute(0, 3, 1, 2).float()  # B,C,H,W

      def _resample(size):
          """Return the table flattened to (1, H*W, C) at the requested grid size, in x's dtype."""
          if size[0] == base_h and size[1] == base_w:
              # Native resolution: no interpolation required, just flatten.
              flat = self.pos_embed.reshape(1, base_h * base_w, -1)
          else:
              if self.pos_embed_ar_preserving:
                  # Interpolate to a square of the longer side, then crop below.
                  target = to_2tuple(max(size))
              else:
                  target = size
              resized = F.interpolate(
                  table_nchw,
                  size=target,
                  mode=self.pos_embed_interp_mode,
                  align_corners=False,
                  antialias=True,
              )
              flat = resized[:, :, :size[0], :size[1]].flatten(2).transpose(1, 2)
          return flat.to(dtype=x.dtype)

      # Group batch indices by grid size, preserving first-seen order.
      # FIXME packing (h, w) into a single int key (h << 16 | w) could give jit compat
      buckets: Dict[Tuple[int, int], List[int]] = {}
      for sample_idx, size in enumerate(grid_sizes):
          if size in buckets:
              buckets[size].append(sample_idx)
          else:
              buckets[size] = [sample_idx]

      for size, sample_indices in buckets.items():
          flat = _resample(size)
          # Only the overlapping prefix of tokens receives an embedding.
          n_tokens = min(x.shape[1], flat.shape[1])
          index = torch.as_tensor(sample_indices, device=x.device)
          source = flat[:, :n_tokens].expand(len(sample_indices), -1, -1)
          # index_add_ on the slice view writes through to x itself.
          x[:, :n_tokens].index_add_(0, index, source)
  517. @disable_compiler
  518. def _apply_learned_naflex_pos_embed_grid_sample(
  519. self,
  520. x: torch.Tensor,
  521. patch_coord: torch.Tensor,
  522. ) -> None:
  523. """Apply learned position embeddings to NaFlex batch using grid_sample.
  524. Uses F.grid_sample for efficient interpolation of learned 2D position embeddings
  525. based on patch coordinates. Based on proposal by https://github.com/stas-sl
  526. Args:
  527. x: Input tensor to add position embeddings to [B, N, C]
  528. patch_coord: Patch coordinates [B, N, 2] with (y, x) values
  529. """
  530. device = x.device
  531. B, N, C = x.shape
  532. shapes = patch_coord.max(dim=1).values + 1 # (B, 2) containing [h_i, w_i]
  533. if self.pos_embed_ar_preserving:
  534. L_i = shapes.amax(dim=1) # (B,) max(h_i, w_i)
  535. L_global = L_i.amax()
  536. grid_size_y = grid_size_x = L_global
  537. scale_x = scale_y = L_global / L_i # uniform zoom (B,)
  538. else:
  539. grid_size_y, grid_size_x = shapes.amax(dim=0) # (2,)
  540. scale_y = grid_size_y / shapes[:, 0] # vertical zoom (B,)
  541. scale_x = grid_size_x / shapes[:, 1] # horizontal zoom (B,)
  542. theta = torch.zeros(B, 2, 3, device=device, dtype=torch.float32)
  543. theta[:, 0, 0] = scale_x
  544. theta[:, 1, 1] = scale_y
  545. theta[:, 0, 2] = scale_x - 1 # translate x
  546. theta[:, 1, 2] = scale_y - 1 # translate y
  547. grid = F.affine_grid(theta, (B, C, grid_size_y, grid_size_x), align_corners=False)
  548. pos_embed = F.grid_sample(
  549. self.pos_embed.permute(0, 3, 1, 2).expand(B, -1, -1, -1).float(),
  550. grid,
  551. mode=self.pos_embed_interp_mode,
  552. align_corners=False,
  553. padding_mode='border',
  554. ).to(dtype=x.dtype) # (B, C, H_out, W_out)
  555. bi = torch.arange(B, device=device, dtype=torch.long).unsqueeze(1)
  556. x += pos_embed[bi, :, patch_coord[..., 0], patch_coord[..., 1]] # NOTE leave as '+='
  557. def _apply_learned_pos_embed(
  558. self,
  559. x: torch.Tensor,
  560. grid_size: List[int],
  561. ) -> None:
  562. """Apply learned position embeddings to standard 2D batch in-place.
  563. Interpolates learned 2D position embeddings to match the specified grid size.
  564. Args:
  565. x: Input tensor to add position embeddings to [B, H*W, C]
  566. grid_size: Target grid size as [height, width]
  567. """
  568. orig_h, orig_w = self.pos_embed.shape[1:3]
  569. if grid_size[0] == orig_h and grid_size[1] == orig_w:
  570. # No resize needed, just flatten
  571. pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
  572. else:
  573. # Resize if needed - directly using F.interpolate
  574. if self.pos_embed_ar_preserving:
  575. L = max(grid_size)
  576. _interp_size = L, L
  577. else:
  578. _interp_size = grid_size
  579. pos_embed_flat = F.interpolate(
  580. self.pos_embed.permute(0, 3, 1, 2).float(), # B,C,H,W
  581. size=_interp_size,
  582. mode=self.pos_embed_interp_mode,
  583. align_corners=False,
  584. antialias=True,
  585. )[:, :, :grid_size[0], :grid_size[1]].flatten(2).transpose(1, 2)
  586. pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)
  587. x.add_(pos_embed_flat)
  588. @disable_compiler
  def _apply_factorized_naflex_pos_embed(
          self,
          x: torch.Tensor,
          patch_coord: torch.Tensor,
  ) -> None:
      """Add factorized position embeddings to a NaFlex batch, in place.

      The Y and X tables are resampled separately to each distinct grid size,
      summed into a 2D grid by broadcasting, flattened row-major and accumulated
      into every sample that shares that grid size.

      Args:
          x: Input tensor to add position embeddings to [B, N, C]
          patch_coord: Patch coordinates [B, N, 2] with (y, x) values
      """
      grid_sizes = calculate_naflex_grid_sizes(patch_coord)
      assert len(grid_sizes) == x.size(0)  # one (H,W) per sample

      base_h = self.pos_embed_y.shape[1]
      base_w = self.pos_embed_x.shape[1]

      def _resample_axis(table: torch.Tensor, new_length: int, orig_length: int) -> torch.Tensor:
          """Resample a (1, L, C) table to ``new_length`` and return it as (1, L_out, C) in x's dtype."""
          if new_length == orig_length:
              return table.to(dtype=x.dtype)
          channels_first = table.permute(0, 2, 1).float()  # (1,L,C) -> (1,C,L)
          resized = F.interpolate(
              channels_first,
              size=new_length,
              mode='linear',
              align_corners=False,
          )
          return resized.permute(0, 2, 1).to(dtype=x.dtype)  # (1,L_out,C)

      # Bucket samples that share the same (H, W) so each grid is built once.
      buckets: Dict[Tuple[int, int], List[int]] = {}
      for sample_idx, size in enumerate(grid_sizes):
          if size in buckets:
              buckets[size].append(sample_idx)
          else:
              buckets[size] = [sample_idx]

      for size, sample_indices in buckets.items():
          target_h, target_w = size
          if self.pos_embed_ar_preserving:
              # Resample both axes to the longer side, then crop each to its own length.
              len_y = len_x = max(target_h, target_w)
          else:
              len_y, len_x = target_h, target_w

          pe_y = _resample_axis(self.pos_embed_y, len_y, base_h)[:, :target_h]  # (1,H,C)
          pe_x = _resample_axis(self.pos_embed_x, len_x, base_w)[:, :target_w]  # (1,W,C)

          # Broadcast-add to (1,H,W,C), then flatten to a row-major sequence.
          pos = (pe_y.unsqueeze(2) + pe_x.unsqueeze(1)).flatten(1, 2)

          # Only the overlapping prefix of tokens receives an embedding.
          n_tokens = min(x.shape[1], pos.shape[1])
          index = torch.as_tensor(sample_indices, device=x.device)
          source = pos[:, :n_tokens].expand(len(sample_indices), -1, -1)
          # index_add_ on the slice view writes through to x itself.
          x[:, :n_tokens].index_add_(0, index, source)
  640. @disable_compiler
  641. def _apply_factorized_naflex_pos_embed_grid_sample(
  642. self,
  643. x: torch.Tensor,
  644. patch_coord: torch.Tensor,
  645. ) -> None:
  646. """Apply factorized position embeddings to NaFlex batch using grid_sample.
  647. Uses F.grid_sample for efficient interpolation of separate Y and X position
  648. embedding tables based on patch coordinates. Based on proposal by https://github.com/stas-sl
  649. Args:
  650. x: Input tensor to add position embeddings to [B, N, C]
  651. patch_coord: Patch coordinates [B, N, 2] with (y, x) values
  652. """
  653. device = x.device
  654. B, _, C = x.shape
  655. shapes = patch_coord.amax(dim=1) + 1
  656. if self.pos_embed_ar_preserving:
  657. # Aspect ratio preserving mode: use square grid with uniform scaling
  658. L_i = shapes.amax(dim=1) # (B,) max(h_i, w_i)
  659. L_global = L_i.amax()
  660. grid_size_y = grid_size_x = L_global
  661. scale_x = scale_y = L_global / L_i # uniform zoom (B,)
  662. else:
  663. # Standard mode: different scaling for x and y
  664. grid_size_y, grid_size_x = shapes.amax(0)
  665. scale_x = grid_size_x / shapes[:, 1] # horizontal zoom (B,)
  666. scale_y = grid_size_y / shapes[:, 0] # vertical zoom (B,)
  667. def _interp1d(table: torch.Tensor, scale: torch.Tensor, out_length: torch.Tensor) -> torch.Tensor:
  668. pe = table.permute(0, 2, 1).unsqueeze(2).expand(B, -1, -1, -1).float() # (1, L, C) -> (B, C, 1, L)
  669. theta = torch.zeros(B, 2, 3, device=x.device)
  670. theta[:, 0, 0] = scale
  671. theta[:, 0, 2] = scale - 1
  672. theta[:, 1, 1] = 1
  673. grid = F.affine_grid(theta, (B, C, 1, out_length), align_corners=False)
  674. pe = F.grid_sample(pe, grid, mode='bilinear', align_corners=False, padding_mode='border')
  675. return pe.to(x.dtype)
  676. # Interpolate along each axis
  677. pe_x = _interp1d(self.pos_embed_x, scale=scale_x, out_length=grid_size_x)
  678. pe_y = _interp1d(self.pos_embed_y, scale=scale_y, out_length=grid_size_y)
  679. bi = torch.arange(B, device=device, dtype=torch.long).unsqueeze(1)
  680. x += pe_x[bi, :, 0, patch_coord[..., 1]] + pe_y[bi, :, 0, patch_coord[..., 0]]
  681. def _apply_factorized_pos_embed(
  682. self,
  683. x: torch.Tensor,
  684. grid_size: List[int],
  685. ) -> None:
  686. """Apply factorized position embeddings to standard 2D batch in-place.
  687. Uses separate Y and X position embedding tables that are interpolated
  688. and combined for the specified grid size.
  689. Args:
  690. x: Input tensor to add position embeddings to [B, H*W, C]
  691. grid_size: Target grid size as [height, width]
  692. """
  693. orig_h, orig_w = self.pos_embed_y.shape[1], self.pos_embed_x.shape[1]
  694. target_h, target_w = grid_size
  695. if self.pos_embed_ar_preserving:
  696. len_y = len_x = max(target_h, target_w)
  697. else:
  698. len_y, len_x = target_h, target_w
  699. def _interp1d(table: torch.Tensor, new_length: int, orig_length: int) -> torch.Tensor:
  700. if new_length == orig_length:
  701. return table.to(dtype=x.dtype)
  702. return F.interpolate(
  703. table.permute(0, 2, 1).float(), # (1,L,C) -> (1,C,L)
  704. size=new_length,
  705. mode='linear',
  706. align_corners=False,
  707. ).permute(0, 2, 1).to(dtype=x.dtype) # (1,L,C)
  708. # Interpolate embeddings
  709. pe_y = _interp1d(self.pos_embed_y, len_y, orig_h)[:, :target_h] # (1,H,C)
  710. pe_x = _interp1d(self.pos_embed_x, len_x, orig_w)[:, :target_w] # (1,W,C)
  711. # Broadcast, add and flatten to sequence layout (row major)
  712. pos_embed = pe_y.unsqueeze(2) + pe_x.unsqueeze(1) # (1, H, W, C)
  713. pos_embed_flat = pos_embed.flatten(1, 2) # (1, H*W, C)
  714. x.add_(pos_embed_flat)
  def forward(
          self,
          x: torch.Tensor,
          patch_coord: Optional[torch.Tensor] = None,
          patch_valid: Optional[torch.Tensor] = None,
  ) -> Tuple[torch.Tensor, Optional[Tuple[int, int]]]:
      """Embed patches, add position embeddings and prepend prefix tokens.

      Args:
          x: Input tensor. Supported formats:
              - [B, C, H, W] for conv mode
              - [B, N, P*P*C] for pre-patchified linear mode (normal)
              - [B, N, Ph, Pw, C] for pre-patchified linear mode (variable patch size)
          patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode.
          patch_valid: Optional validity mask for patches [B, N] for NaFlex mode.
              Not read inside this method.

      Returns:
          Tuple of (embedded_tensor, grid_size) where:
              - embedded_tensor: [B, num_prefix_tokens + N, embed_dim]
              - grid_size: (H, W) for standard mode, None for NaFlex mode
      """
      batch_size = x.shape[0]
      grid_size: Optional[Tuple[int, int]] = None

      if not self.is_linear:
          # Convolutional embedding path: always a standard 2D image batch.
          _assert(x.ndim == 4, 'Convolutional input must be 4D')
          if self.dynamic_img_pad:
              # Pad bottom/right so both dims become multiples of the patch size.
              H, W = x.shape[-2:]
              patch_h, patch_w = self.patch_size[0], self.patch_size[1]
              pad_h = (patch_h - H % patch_h) % patch_h
              pad_w = (patch_w - W % patch_w) % patch_w
              x = F.pad(x, (0, pad_w, 0, pad_h))
          x = self.proj(x)
          grid_size = x.shape[-2:]
          if self.flatten:
              x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
      else:
          # Linear embedding path, works with NaFlex mode or standard 2D mode
          if patch_coord is None:
              # Standard 2D (B, C, H, W) mode: patchify here and remember the grid.
              _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4')
              x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad)
          else:
              # Pre-patchified NaFlex mode
              # Variable patch size mode: [B, N, Ph, Pw, C], normal mode: [B, N, P*P*C]
              _assert(x.ndim == 5 or x.ndim == 3, 'Expecting patchified input with ndim == 3 or 5.')

          if self.enable_patch_interpolator and x.ndim == 5:
              # Variable patch size: project with weights resampled to the incoming patch size.
              _assert(self.norm_input is None, 'input norm not supported with patch resizing')
              incoming_patch_size = tuple(x.shape[2:4])  # from [B, N, Ph, Pw, C]
              x = self.patch_interpolator(
                  x,
                  self.proj.weight,
                  self.proj.bias,
                  patch_size=incoming_patch_size,
                  is_linear=True,
              )
          else:
              # Standard projection
              x = x.flatten(2)  # ensure [B, N, P*P*C], flatten Ph*Pw*C if separate
              if self.norm_input is not None:
                  x = self.norm_input(x)
              x = self.proj(x)

      # Normalization runs on the token sequence for both paths.
      x = self.norm(x)

      # A known grid size means standard 2D mode; otherwise coordinates drive NaFlex mode.
      naflex_mode = grid_size is None
      if self.pos_embed_type == 'learned':
          if not naflex_mode:
              self._apply_learned_pos_embed(x, grid_size=grid_size)
          elif self.pos_embed_use_grid_sample:
              self._apply_learned_naflex_pos_embed_grid_sample(x, patch_coord=patch_coord)
          else:
              self._apply_learned_naflex_pos_embed(x, patch_coord=patch_coord)
      elif self.pos_embed_type == 'factorized':
          if not naflex_mode:
              self._apply_factorized_pos_embed(x, grid_size=grid_size)
          elif self.pos_embed_use_grid_sample:
              self._apply_factorized_naflex_pos_embed_grid_sample(x, patch_coord=patch_coord)
          else:
              self._apply_factorized_naflex_pos_embed(x, patch_coord=patch_coord)

      # Class token first, then register tokens, then patch tokens.
      prefix_tokens = [
          token.expand(batch_size, -1, -1)
          for token in (self.cls_token, self.reg_token)
          if token is not None
      ]
      if prefix_tokens:
          x = torch.cat(prefix_tokens + [x], dim=1)

      x = self.pos_drop(x)
      return x, grid_size
  808. @register_notrace_function
  def create_attention_mask(
          patch_valid: torch.Tensor,
          num_prefix_tokens: int = 0,
          symmetric: bool = True,
          q_len: Optional[int] = None,
          dtype: torch.dtype = torch.float32,
  ) -> Optional[torch.Tensor]:
      """Build an additive attention mask from patch validity flags.

      Two layouts are supported, selected by `symmetric`:

      1. `symmetric=True` (default): mask of shape [B, 1, seq_len, seq_len] where
         pair (i, j) is allowed only when both token i and token j are valid.
      2. `symmetric=False`: mask of shape [B, 1, q_len, kv_len] where pair (q, k)
         is allowed whenever key/value token k is valid; query validity is not
         checked. `q_len` may be given to produce a non-square mask.

      Used for NaFlex mode to handle variable token counts and padding tokens.

      Args:
          patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding.
          num_prefix_tokens: Number of prefix tokens (class token, register tokens)
              to prepend, which are always considered valid.
          symmetric: If True, create a symmetric mask.
              If False, create an expanded mask based only on key/value validity.
          q_len: Query sequence length override. Only used when `symmetric` is False.
              Falls back to the key/value length when None (or 0).
          dtype: Dtype of the output attention mask (e.g., torch.float32).

      Returns:
          None when `patch_valid` is None. Otherwise an additive mask holding 0 at
          allowed positions and `torch.finfo(dtype).min` at masked positions,
          shaped [B, 1, seq_len, seq_len] if symmetric=True,
          or [B, 1, q_len, kv_len] if symmetric=False.
      """
      if patch_valid is None:
          return None

      valid = patch_valid.bool()  # tolerate non-bool masks
      batch_size = valid.shape[0]

      if num_prefix_tokens > 0:
          # Prefix tokens are always valid; place them ahead of the patch flags.
          prefix = valid.new_ones((batch_size, num_prefix_tokens), dtype=torch.bool)
          valid = torch.cat([prefix, valid], dim=1)
      kv_len = valid.shape[1]  # patches plus any prefix tokens

      if symmetric:
          # Allowed only where BOTH query and key are valid; add the head dim afterwards.
          allowed = (valid.unsqueeze(-1) & valid.unsqueeze(1)).unsqueeze(1)
      else:
          # Key validity broadcast across every query row.
          q_len = q_len or kv_len
          allowed = valid[:, None, None, :].expand(batch_size, 1, q_len, kv_len)

      # Additive convention: 0 keeps a position, dtype-min suppresses it.
      mask = torch.zeros_like(allowed, dtype=dtype)
      mask.masked_fill_(~allowed, torch.finfo(dtype).min)
      return mask
  866. @register_notrace_function
  def global_pool_naflex(
          x: torch.Tensor,
          patch_valid: Optional[torch.Tensor] = None,
          pool_type: str = 'token',
          num_prefix_tokens: int = 1,
          reduce_include_prefix: bool = False,
  ) -> torch.Tensor:
      """Global pooling with NaFlex support for masked tokens.

      Applies global pooling while respecting patch validity masks so padding
      tokens do not contribute to the reduction. Without a mask, or for pool
      types other than 'avg' / 'avgmax' / 'max', this defers to
      ``global_pool_nlc``.

      Args:
          x: Input tensor with shape [B, N, C]
          patch_valid: Optional validity mask for patches [B, N-num_prefix_tokens].
              Any dtype is accepted; non-zero entries are treated as valid.
          pool_type: Type of pooling ('token', 'avg', 'avgmax', 'max')
          num_prefix_tokens: Number of prefix tokens (class/register)
          reduce_include_prefix: Whether to include prefix tokens in pooling reduction

      Returns:
          Pooled tensor with shape [B, C]
      """
      if patch_valid is None or pool_type not in ('avg', 'avgmax', 'max'):
          # Fall back to standard pooling
          return global_pool_nlc(
              x,
              pool_type=pool_type,
              num_prefix_tokens=num_prefix_tokens,
              reduce_include_prefix=reduce_include_prefix,
          )

      # Normalize to bool: on an int 0/1 mask `~mask` is a bitwise NOT and would
      # be used as integer indices rather than as a boolean selection.
      patch_valid = patch_valid.bool()

      if num_prefix_tokens > 0:
          if reduce_include_prefix:
              # patch_valid only covers patch tokens; prefix tokens are always valid.
              prefix_valid = patch_valid.new_ones(x.shape[0], num_prefix_tokens)
              patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
          else:
              # Exclude prefix tokens from pooling (default behavior)
              x = x[:, num_prefix_tokens:]

      pooled_avg = None
      pooled_max = None

      if pool_type in ('avg', 'avgmax'):
          # Masked mean: sum valid tokens, divide by their count (at least 1 to avoid 0/0).
          weights = patch_valid.to(x.dtype)
          masked_sums = (x * weights.unsqueeze(-1)).sum(dim=1)
          valid_counts = weights.sum(dim=1, keepdim=True).clamp(min=1)
          pooled_avg = masked_sums / valid_counts

      if pool_type in ('max', 'avgmax'):
          # Masked max: padding positions are pushed to the dtype minimum so they never win.
          # masked_fill is out-of-place, so the caller's tensor is left untouched.
          floor = torch.finfo(x.dtype).min
          pooled_max = x.masked_fill(~patch_valid.unsqueeze(-1), floor).amax(dim=1)

      if pool_type == 'avg':
          return pooled_avg
      if pool_type == 'max':
          return pooled_max
      # 'avgmax': mean of the two reductions
      return 0.5 * (pooled_avg + pooled_max)
  930. class NaFlexVit(nn.Module):
  931. """NaFlexVit: Vision Transformer with NaFlex support for flexible input handling.
  932. A flexible implementation of Vision Transformer that supports:
  933. - Standard image classification with various pooling strategies
  934. - NaFlex functionality for variable aspect ratios and resolutions
  935. - Linear patch embedding for pre-patchified inputs
  936. - Multiple position embedding strategies (learned, factorized, rope)
  937. - Comprehensive attention masking for efficient batch processing
  938. - Encapsulated embedding and position encoding in FlexEmbeds module
  939. - Compatible with standard ViT checkpoints through checkpoint filtering
  940. """
  def __init__(
          self,
          cfg: Optional[NaFlexVitCfg] = None,
          in_chans: int = 3,
          num_classes: int = 1000,
          img_size: Optional[Union[int, Tuple[int, int]]] = None,
          device=None,
          dtype=None,
          **kwargs,
  ) -> None:
      """Initialize NaFlexVit model.

      Builds, in order: the embedding module, optional pre-norm, optional ROPE,
      optional patch dropout, the transformer blocks, final norm, optional
      attention pooling, fc norm, head dropout and the classifier head, then
      runs weight initialization.

      Args:
          cfg: Model configuration. If None, uses default NaFlexVitCfg.
          in_chans: Number of input image channels.
          num_classes: Number of classification classes. A value <= 0 replaces
              the head with nn.Identity.
          img_size: Input image size (for backwards compatibility with classic vit).
              Forwarded to the embedding module as ``default_img_size``.
          device: Forwarded as ``device`` to the submodules constructed with ``**dd``.
          dtype: Forwarded as ``dtype`` to the submodules constructed with ``**dd``.
          **kwargs: Additional config parameters to override cfg values.

      Raises:
          ValueError: If ``cfg.rope_type`` is set to something other than
              'none', 'mixed' or 'axial'.
      """
      super().__init__()
      # Factory kwargs shared by the submodules constructed below.
      dd = {'device': device, 'dtype': dtype}

      # Initialize config
      cfg = cfg or NaFlexVitCfg()
      if kwargs:
          cfg = _overlay_kwargs(cfg, **kwargs)

      # Validate configuration
      assert cfg.global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
      # 'token' pooling needs a class token to pool from.
      assert cfg.class_token or cfg.global_pool != 'token'
      assert cfg.pos_embed in ('', 'none', 'learned', 'factorized')

      # Resolve layer implementations
      norm_layer = get_norm_layer(cfg.norm_layer) or LayerNorm
      embed_norm_layer = get_norm_layer(cfg.embed_norm_layer)
      act_layer = get_act_layer(cfg.act_layer) or nn.GELU
      block_fn = get_block_fn(cfg)
      mlp_layer = cfg.mlp_layer or Mlp  # TODO: Support configurable mlp_layer via string lookup

      # Store instance variables
      self.num_classes = num_classes
      self.in_chans = in_chans
      self.global_pool = cfg.global_pool
      self.num_features = self.head_hidden_size = self.embed_dim = cfg.embed_dim  # for consistency with other models
      # Prefix tokens = optional class token + register tokens.
      self.num_prefix_tokens = 1 if cfg.class_token else 0
      self.num_prefix_tokens += cfg.reg_tokens
      self.num_reg_tokens = cfg.reg_tokens
      self.has_class_token = cfg.class_token
      self.pool_include_prefix = cfg.pool_include_prefix
      self.grad_checkpointing = False

      # Initialize embedding module (includes patch, position embedding, and class/reg tokens)
      # NaFlexEmbeds is always used - handles both linear and conv embedding
      self.embeds = NaFlexEmbeds(
          patch_size=cfg.patch_size,
          in_chans=in_chans,
          embed_dim=cfg.embed_dim,
          proj_type=cfg.embed_proj_type,
          proj_bias=not cfg.pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
          class_token=cfg.class_token,
          reg_tokens=cfg.reg_tokens,
          default_img_size=img_size,
          dynamic_img_pad=cfg.dynamic_img_pad,
          pos_embed=cfg.pos_embed,
          pos_embed_grid_size=cfg.pos_embed_grid_size,
          pos_embed_interp_mode=cfg.pos_embed_interp_mode,
          pos_embed_ar_preserving=cfg.pos_embed_ar_preserving,
          pos_embed_use_grid_sample=cfg.pos_embed_use_grid_sample,
          proj_norm_layer=embed_norm_layer,
          pos_drop_rate=cfg.pos_drop_rate,
          # getattr with a default tolerates configs that lack this field.
          enable_patch_interpolator=getattr(cfg, 'enable_patch_interpolator', False),
          **dd,
      )
      self.norm_pre = norm_layer(cfg.embed_dim, **dd) if cfg.pre_norm else nn.Identity()

      # ROPE position embeddings at model level
      self.rope: Optional[nn.Module] = None
      self.rope_is_mixed = False
      if cfg.rope_type and cfg.rope_type != 'none':
          # Imported lazily, only when ROPE is requested.
          from timm.layers.pos_embed_sincos import RotaryEmbeddingCat, RotaryEmbeddingMixed
          if cfg.rope_type == 'mixed':
              self.rope = RotaryEmbeddingMixed(
                  cfg.embed_dim,
                  depth=cfg.depth,
                  num_heads=cfg.num_heads,
                  temperature=cfg.rope_temperature,
                  feat_shape=None,  # Dynamic shapes for NaFlex
                  grid_indexing=cfg.rope_grid_indexing,
                  **dd,
              )
              self.rope_is_mixed = True
          elif cfg.rope_type == 'axial':
              self.rope = RotaryEmbeddingCat(
                  cfg.embed_dim // cfg.num_heads,  # per-head dimension
                  temperature=cfg.rope_temperature,
                  in_pixels=False,
                  feat_shape=None,  # Dynamic shapes for NaFlex
                  ref_feat_shape=cfg.rope_ref_feat_shape,
                  grid_offset=cfg.rope_grid_offset,
                  grid_indexing=cfg.rope_grid_indexing,
                  **dd,
              )
              self.rope_is_mixed = False
          else:
              raise ValueError(f"Unknown rope_type: {cfg.rope_type}")

      # Patch dropout
      if cfg.patch_drop_rate > 0:
          self.patch_drop = PatchDropoutWithIndices(
              cfg.patch_drop_rate,
              num_prefix_tokens=self.num_prefix_tokens,
          )
      else:
          self.patch_drop = None

      # Transformer blocks
      dpr = calculate_drop_path_rates(cfg.drop_path_rate, cfg.depth)  # stochastic depth decay rule
      # Create transformer blocks; block i receives drop path rate dpr[i]
      self.blocks = nn.Sequential(*[
          block_fn(
              dim=cfg.embed_dim,
              num_heads=cfg.num_heads,
              mlp_ratio=cfg.mlp_ratio,
              qkv_bias=cfg.qkv_bias,
              qk_norm=cfg.qk_norm,
              proj_bias=cfg.proj_bias,
              init_values=cfg.init_values,
              proj_drop=cfg.proj_drop_rate,
              attn_drop=cfg.attn_drop_rate,
              drop_path=dpr[i],
              norm_layer=norm_layer,
              act_layer=act_layer,
              mlp_layer=mlp_layer,
              depth=i,
              **dd,
          )
          for i in range(cfg.depth)
      ])

      # Feature info for downstream tasks: every block reports the same channel count and reduction
      patch_reduction = self.embeds.feat_ratio(as_scalar=True)
      self.feature_info = [
          dict(module=f'blocks.{i}', num_chs=cfg.embed_dim, reduction=patch_reduction)
          for i in range(cfg.depth)
      ]

      # NOTE(review): this tests the raw cfg.fc_norm, while self.fc_norm below uses the
      # resolved default. With cfg.fc_norm=None and global_pool='avg' both norms are
      # created - confirm this is intended before changing (affects state_dict keys).
      self.norm = norm_layer(cfg.embed_dim, **dd) if cfg.final_norm and not cfg.fc_norm else nn.Identity()

      # Classifier Head
      if cfg.global_pool == 'map':
          # Attention pooling; head count and MLP ratio fall back to the block settings.
          self.attn_pool = AttentionPoolLatent(
              self.embed_dim,
              num_heads=cfg.attn_pool_num_heads or cfg.num_heads,
              mlp_ratio=cfg.attn_pool_mlp_ratio or cfg.mlp_ratio,
              norm_layer=norm_layer,
              act_layer=act_layer,
              **dd,
          )
      else:
          self.attn_pool = None

      # Handle fc_norm default value: unset means "on only for average pooling"
      fc_norm = cfg.fc_norm
      if fc_norm is None:
          fc_norm = cfg.global_pool == 'avg'
      self.fc_norm = norm_layer(cfg.embed_dim, **dd) if cfg.final_norm and fc_norm else nn.Identity()
      self.head_drop = nn.Dropout(cfg.drop_rate)
      self.head = nn.Linear(self.embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()

      # Stored so init_weights() can be re-run later with the same settings.
      self.weight_init_mode = cfg.weight_init
      self.fix_init = cfg.fix_init

      # TODO: skip init when on meta device when safe to do so
      self.init_weights(cfg.weight_init, needs_reset=False)
  def fix_init_weight(self) -> None:
      """Rescale output projection weights by depth.

      For the block at 1-based position ``n``, each output projection weight
      that the block exposes is divided in place by ``sqrt(2 * n)``.
      """
      def _scale_down(weight: torch.Tensor, position: int) -> None:
          # In-place update outside autograd tracking.
          with torch.no_grad():
              weight.div_(math.sqrt(2.0 * position))

      for position, block in enumerate(self.blocks, start=1):
          # Block variants expose different projection attributes; scale whichever exist.
          if hasattr(block, 'attn'):
              _scale_down(block.attn.proj.weight, position)
          if hasattr(block, 'mlp'):
              _scale_down(block.mlp.fc2.weight, position)
          if hasattr(block, 'attn_out_proj'):
              _scale_down(block.attn_out_proj.weight, position)
          if hasattr(block, 'mlp_out_proj'):
              _scale_down(block.mlp_out_proj.weight, position)
  def init_weights(self, mode: str = '', needs_reset: bool = True) -> None:
      """Initialize model weights according to the chosen scheme.

      Args:
          mode: Initialization mode ('jax', 'jax_nlhb', 'moco', or ''). An empty
              value falls back to ``self.weight_init_mode``.
          needs_reset: If True, call reset_parameters() on modules (default for after to_empty()).
              If False, skip reset_parameters() (for __init__ where modules already self-initialized).
      """
      if not mode:
          mode = self.weight_init_mode
      assert mode in ('jax', 'jax_nlhb', 'moco', '')

      # The 'nlhb' variants use a head bias of -log(num_classes); others use zero.
      if 'nlhb' in mode:
          head_bias = -math.log(self.num_classes)
      else:
          head_bias = 0.

      init_fn = get_init_weights_vit(mode, head_bias, needs_reset=needs_reset)
      named_apply(init_fn, self)

      if self.fix_init:
          self.fix_init_weight()
  1127. @torch.jit.ignore()
  def load_pretrained(self, checkpoint_path: str, prefix: str = '') -> None:
      """Load a checkpoint, remapping classic ViT keys onto this model's layout.

      Keys for the class token, register tokens and position embedding gain an
      ``embeds.`` prefix; ``patch_embed.*`` keys have that prefix replaced by
      ``embeds.``. The remapped dict is then handed to the ViT weight loader.

      Args:
          checkpoint_path: Path passed to ``torch.load``.
          prefix: Forwarded unchanged to the underlying loader.
      """
      # Custom loading for the new model structure
      from .vision_transformer import _load_weights as _orig_load_weights

      # NOTE(review): torch.load unpickles the file; only use with trusted checkpoints.
      state_dict = torch.load(checkpoint_path, map_location='cpu')
      if isinstance(state_dict, dict) and 'state_dict' in state_dict:
          # Unwrap checkpoints that nest the weights under 'state_dict'.
          state_dict = state_dict['state_dict']

      # Map original keys to new structure. Iterate over a snapshot of the keys
      # because entries are popped and re-inserted while looping.
      moved_under_embeds = ('cls_token', 'reg_token', 'pos_embed')
      for key in list(state_dict.keys()):
          if key.startswith(moved_under_embeds):
              state_dict['embeds.' + key] = state_dict.pop(key)
          elif key.startswith('patch_embed'):
              # Drop the first 12 characters (the length of 'patch_embed.') before re-prefixing.
              state_dict['embeds.' + key[12:]] = state_dict.pop(key)

      # NOTE(review): confirm _load_weights accepts a state dict as its second argument.
      _orig_load_weights(self, state_dict, prefix)
  1148. @torch.jit.ignore
  def no_weight_decay(self) -> Set:
      """Get set of parameter names that should not have weight decay applied.

      Always contains the position embedding, class token and register token
      names under ``embeds``. When a ROPE module is present and provides its own
      ``no_weight_decay()``, those names are added too.

      Returns:
          Set of parameter names to skip during weight decay
      """
      skip_list = {'embeds.pos_embed', 'embeds.cls_token', 'embeds.reg_token'}
      rope = self.rope
      # Compare against None explicitly: truthiness of an object that defines
      # __len__/__bool__ could be False even though it is present.
      if rope is not None and hasattr(rope, 'no_weight_decay'):
          skip_list.update(rope.no_weight_decay())
      return skip_list
  1158. @torch.jit.ignore
  def group_matcher(self, coarse: bool = False) -> Dict:
      """Build the parameter group matcher used for optimizer parameter grouping.

      Args:
          coarse: Whether to use coarse-grained grouping. The returned patterns
              are the same for either value.

      Returns:
          Dictionary mapping group names to regex patterns
      """
      stem_pattern = r'^embeds'  # stem and embed
      block_patterns = [
          (r'^blocks\.(\d+)', None),
          (r'^norm', (99999,)),
      ]
      return {'stem': stem_pattern, 'blocks': block_patterns}
  1170. @torch.jit.ignore
  def set_grad_checkpointing(self, enable: bool = True) -> None:
      """Turn gradient checkpointing on or off.

      Records the flag on the model and, when the embedding module has a
      ``patch_embed`` that supports it, forwards the setting there as well.

      Args:
          enable: Whether to enable gradient checkpointing
      """
      self.grad_checkpointing = enable
      patch_embed = getattr(self.embeds, 'patch_embed', None)
      if hasattr(patch_embed, 'set_grad_checkpointing'):
          patch_embed.set_grad_checkpointing(enable)
  1179. @torch.jit.ignore
  1180. def get_classifier(self) -> nn.Module:
  1181. """Get the classification head module.
  1182. Returns:
  1183. Classification head module
  1184. """
  1185. return self.head
  1186. @disable_compiler
  1187. def _generate_rope_naflex(
  1188. self,
  1189. x: torch.Tensor,
  1190. patch_coord: torch.Tensor,
  1191. ) -> Union[torch.Tensor, List[torch.Tensor], Any]:
  1192. """Generate ROPE position embeddings for NaFlex batch with variable grid sizes.
  1193. Args:
  1194. x: Input tensor [B, N, C]
  1195. patch_coord: Patch coordinates [B, N, 2] with (y, x) values
  1196. Returns:
  1197. ROPE embeddings:
  1198. - Axial mode: Tensor of shape [B, 1, N, dim*2]
  1199. - Mixed mode: List of tensors, each of shape [B, num_heads, N, dim], one per depth layer
  1200. - Mixed mode with iterator: Iterator yielding tensors per depth
  1201. """
  1202. # Calculate grid sizes for each sample
  1203. naflex_grid_sizes = calculate_naflex_grid_sizes(patch_coord)
  1204. # Build ROPE embeddings for each unique grid size
  1205. size_to_indices = {}
  1206. unique_sizes = []
  1207. for bi, grid_size in enumerate(naflex_grid_sizes):
  1208. if grid_size not in size_to_indices:
  1209. size_to_indices[grid_size] = []
  1210. unique_sizes.append(grid_size)
  1211. size_to_indices[grid_size].append(bi)
  1212. B, N, C = x.shape
  1213. seq_len = N - self.num_prefix_tokens
  1214. if self.rope_is_mixed:
  1215. # Use an iterator for Mixed mode, returns [batch_size, depth, num_heads, seq_len, dim]
  1216. return NaFlexRopeIterator(
  1217. self.rope,
  1218. size_to_indices,
  1219. unique_sizes,
  1220. B,
  1221. seq_len,
  1222. x.dtype,
  1223. x.device
  1224. )
  1225. # Axial mode: [batch_size, seq_len, dim*2]
  1226. rope_embeds = torch.zeros(B, seq_len, self.rope.dim * 2, dtype=x.dtype, device=x.device)
  1227. if hasattr(self.rope, 'get_batch_embeds'):
  1228. # Batch mode - generate unique embeds from one grid and then assign
  1229. unique_embeds = self.rope.get_batch_embeds(unique_sizes)
  1230. for grid_size, embed, batch_indices in zip(unique_sizes, unique_embeds, size_to_indices.values()):
  1231. h, w = grid_size
  1232. actual_len = h * w
  1233. for bi in batch_indices:
  1234. rope_embeds[bi, :actual_len] = embed[:actual_len]
  1235. else:
  1236. # Generate each unique size separately and assign
  1237. for grid_size, bi in size_to_indices.items():
  1238. rope_embed = self.rope.get_embed(shape=grid_size)
  1239. h, w = grid_size
  1240. actual_len = h * w
  1241. rope_embeds[bi, :actual_len] = rope_embed[:actual_len]
  1242. rope_embeds = rope_embeds.unsqueeze(1)
  1243. return rope_embeds
  1244. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
  1245. """Reset the classification head with new number of classes and pooling.
  1246. Args:
  1247. num_classes: Number of classes for new classification head
  1248. global_pool: Optional new global pooling type
  1249. """
  1250. self.num_classes = num_classes
  1251. if global_pool is not None:
  1252. assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
  1253. if global_pool == 'map' and self.attn_pool is None:
  1254. assert False, "Cannot currently add attention pooling in reset_classifier()."
  1255. elif global_pool != 'map' and self.attn_pool is not None:
  1256. self.attn_pool = None # remove attention pooling
  1257. self.global_pool = global_pool
  1258. self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
  1259. def _forward_embeds(
  1260. self,
  1261. x,
  1262. patch_coord,
  1263. patch_valid,
  1264. attn_mask,
  1265. ) -> Dict[str, torch.Tensor]:
  1266. """ Forward pass through patch / abs pos / rope pos embeds and patch dropout
  1267. """
  1268. naflex_mode = patch_coord is not None
  1269. # patch embed, abs pos embed, returns global grid size as calculated from 'standard' NCHW batches
  1270. x, grid_size = self.embeds(
  1271. x,
  1272. patch_coord=patch_coord,
  1273. patch_valid=patch_valid,
  1274. )
  1275. # Generate ROPE embeddings at model level
  1276. rope_embeds = None
  1277. if self.rope is not None:
  1278. if patch_coord is not None:
  1279. # NaFlex mode - variable grid sizes
  1280. rope_embeds = self._generate_rope_naflex(x, patch_coord)
  1281. elif grid_size is not None:
  1282. # Standard mode - fixed grid size
  1283. rope_embeds = self.rope.get_embed(shape=grid_size)
  1284. else:
  1285. assert False, 'Expected one of patch_coord or grid_size to be valid'
  1286. # Apply patch dropout with coordinated updates
  1287. keep_indices: Optional[torch.Tensor] = None
  1288. if self.training and self.patch_drop is not None:
  1289. x, keep_indices = self.patch_drop(x)
  1290. # keep_indices excludes prefix tokens, can use directly on patch_valid & rope embeds
  1291. if patch_valid is not None:
  1292. patch_valid = patch_valid.gather(1, keep_indices)
  1293. if rope_embeds is not None and not self.rope_is_mixed:
  1294. # Update ROPE embeddings to match dropped tokens (only for axial mode)
  1295. # Batch dim already present in NaFlex mode, but will be added in standard mode.
  1296. rope_embeds = apply_keep_indices_nlc(x, rope_embeds, keep_indices, pos_embed_has_batch=naflex_mode)
  1297. if not naflex_mode:
  1298. # B, N, dim -> B, 1, N, dim. Need head dim added for standard mode, already added in NaFlex.
  1299. rope_embeds = rope_embeds.unsqueeze(1)
  1300. # Create attention mask from patch_valid after patch dropout applied
  1301. if attn_mask is None:
  1302. attn_mask = create_attention_mask(
  1303. patch_valid,
  1304. num_prefix_tokens=self.num_prefix_tokens,
  1305. dtype=x.dtype
  1306. )
  1307. x = self.norm_pre(x)
  1308. return {
  1309. 'patches': x,
  1310. 'patch_valid': patch_valid,
  1311. 'rope_embeds': rope_embeds,
  1312. 'attn_mask': attn_mask,
  1313. 'keep_indices': keep_indices,
  1314. }
  1315. def forward_intermediates(
  1316. self,
  1317. x: Union[torch.Tensor, Dict[str, torch.Tensor]],
  1318. indices: Optional[Union[int, List[int]]] = None,
  1319. return_prefix_tokens: bool = False,
  1320. norm: bool = False,
  1321. stop_early: bool = False,
  1322. output_fmt: str = 'NCHW',
  1323. intermediates_only: bool = False,
  1324. output_dict: bool = False,
  1325. patch_coord: Optional[torch.Tensor] = None,
  1326. patch_valid: Optional[torch.Tensor] = None,
  1327. attn_mask: Optional[torch.Tensor] = None,
  1328. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]:
  1329. """ Forward features that returns intermediates.
  1330. Args:
  1331. x: Input image tensor
  1332. indices: Take last n blocks if int, all if None, select matching indices if sequence
  1333. return_prefix_tokens: Return both prefix and spatial intermediate tokens
  1334. norm: Apply norm layer to all intermediates
  1335. stop_early: Stop iterating over blocks when last desired intermediate hit
  1336. output_fmt: Shape of intermediate feature outputs
  1337. intermediates_only: Only return intermediate features
  1338. output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys
  1339. patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode
  1340. patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex
  1341. attn_mask: Optional attention mask for masked attention
  1342. Returns:
  1343. A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing
  1344. 'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix')
  1345. """
  1346. # FIXME unfinished / untested
  1347. assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
  1348. reshape = output_fmt == 'NCHW'
  1349. intermediates = []
  1350. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  1351. if isinstance(x, Dict):
  1352. # Handle dictionary input from NaFlex collator
  1353. patch_coord = x['patch_coord']
  1354. patch_valid = x['patch_valid']
  1355. patches = x['patches']
  1356. assert False, 'WIP, patch mode needs more work'
  1357. else:
  1358. patches = x
  1359. height, width = x.shape[-2:]
  1360. H, W = self.embeds.dynamic_feat_size((height, width))
  1361. # Forward pass through patch and abs position embedding
  1362. embeds = self._forward_embeds(
  1363. patches,
  1364. patch_coord=patch_coord,
  1365. patch_valid=patch_valid,
  1366. attn_mask=attn_mask,
  1367. )
  1368. x = embeds['patches']
  1369. rope_embeds = embeds.get('rope_embeds', None)
  1370. keep_indices = embeds.get('keep_indices', None)
  1371. attn_mask = embeds.get('attn_mask', None)
  1372. # Forward pass through blocks
  1373. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  1374. blocks = self.blocks
  1375. else:
  1376. blocks = self.blocks[:max_index + 1]
  1377. do_checkpointing = self.grad_checkpointing and not torch.jit.is_scripting()
  1378. if self.rope_is_mixed and rope_embeds is not None:
  1379. # Mixed mode with per-layer embeddings (list or iterator)
  1380. for i, (blk, rope_embed) in enumerate(zip(self.blocks, rope_embeds)):
  1381. # Apply patch dropout to rope_embed if needed
  1382. if self.training and self.patch_drop is not None and keep_indices is not None:
  1383. # Apply patch dropout to rope_embed if needed (batch dim already present in naflex mode)
  1384. rope_embed = apply_keep_indices_nlc(
  1385. x,
  1386. rope_embed,
  1387. keep_indices,
  1388. pos_embed_has_batch=embeds.get('naflex_mode', False),
  1389. )
  1390. if do_checkpointing:
  1391. x = checkpoint(blk, x, rope=rope_embed, attn_mask=attn_mask)
  1392. else:
  1393. x = blk(x, rope=rope_embed, attn_mask=attn_mask)
  1394. if i in take_indices:
  1395. # normalize intermediates with final norm layer if enabled
  1396. intermediates.append(self.norm(x) if norm else x)
  1397. else:
  1398. for i, blk in enumerate(blocks):
  1399. # Axial ROPE mode with shared embeddings
  1400. if rope_embeds is not None:
  1401. if do_checkpointing:
  1402. x = checkpoint(blk, x, rope=rope_embeds, attn_mask=attn_mask)
  1403. else:
  1404. x = blk(x, rope=rope_embeds, attn_mask=attn_mask)
  1405. else:
  1406. if do_checkpointing:
  1407. x = checkpoint(blk, x, attn_mask=attn_mask)
  1408. else:
  1409. x = blk(x, attn_mask=attn_mask)
  1410. if i in take_indices:
  1411. # normalize intermediates with final norm layer if enabled
  1412. intermediates.append(self.norm(x) if norm else x)
  1413. # Process intermediates
  1414. if self.num_prefix_tokens:
  1415. # split prefix (e.g. class, distill) and spatial feature tokens
  1416. prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
  1417. intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
  1418. else:
  1419. prefix_tokens = None
  1420. if reshape:
  1421. # reshape to BCHW output format
  1422. intermediates = [
  1423. y.reshape(y.shape[0], H, W, -1).permute(0, 3, 1, 2).contiguous()
  1424. for y in intermediates
  1425. ]
  1426. # FIXME always use dict for NaFlex mode to return masks and more?
  1427. # For dictionary output
  1428. if output_dict:
  1429. result_dict = {}
  1430. # Intermediates are always included
  1431. result_dict['image_intermediates'] = intermediates
  1432. if prefix_tokens is not None and return_prefix_tokens:
  1433. result_dict['image_intermediates_prefix'] = prefix_tokens
  1434. # Only include features if not intermediates_only
  1435. if not intermediates_only:
  1436. x_final = self.norm(x)
  1437. result_dict['image_features'] = x_final
  1438. return result_dict
  1439. # For non-dictionary output, maintain the original behavior
  1440. if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None:
  1441. # return_prefix not support in torchscript due to poor type handling
  1442. intermediates = list(zip(intermediates, prefix_tokens))
  1443. if intermediates_only:
  1444. return intermediates
  1445. x = self.norm(x)
  1446. return x, intermediates
  1447. def forward_features(
  1448. self,
  1449. patches: torch.Tensor,
  1450. patch_coord: Optional[torch.Tensor] = None,
  1451. patch_valid: Optional[torch.Tensor] = None,
  1452. attn_mask: Optional[torch.Tensor] = None,
  1453. ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
  1454. """
  1455. """
  1456. naflex_mode = patch_coord is not None
  1457. # Pass through patch & abs position embedding module with patch coordinate/type support
  1458. embeds = self._forward_embeds(
  1459. patches,
  1460. patch_coord=patch_coord,
  1461. patch_valid=patch_valid,
  1462. attn_mask=attn_mask,
  1463. )
  1464. x = embeds['patches']
  1465. rope_embeds = embeds.get('rope_embeds', None)
  1466. keep_indices = embeds.get('keep_indices', None)
  1467. attn_mask = embeds.get('attn_mask', None)
  1468. # Apply transformer blocks with masked attention and/or ROPE if provided
  1469. do_checkpointing = self.grad_checkpointing and not torch.jit.is_scripting()
  1470. if self.rope_is_mixed and rope_embeds is not None:
  1471. # Mixed mode with per-layer embeddings (list or iterator)
  1472. for i, (blk, rope_embed) in enumerate(zip(self.blocks, rope_embeds)):
  1473. if self.training and self.patch_drop is not None and keep_indices is not None:
  1474. # Apply patch dropout to rope_embed if needed (batch dim already present in naflex mode)
  1475. rope_embed = apply_keep_indices_nlc(
  1476. x,
  1477. rope_embed,
  1478. keep_indices,
  1479. pos_embed_has_batch=naflex_mode,
  1480. )
  1481. if do_checkpointing:
  1482. x = checkpoint(blk, x, rope=rope_embed, attn_mask=attn_mask)
  1483. else:
  1484. x = blk(x, rope=rope_embed, attn_mask=attn_mask)
  1485. elif rope_embeds is not None:
  1486. # Axial ROPE mode with shared embeddings
  1487. for blk in self.blocks:
  1488. if do_checkpointing:
  1489. x = checkpoint(blk, x, rope=rope_embeds, attn_mask=attn_mask)
  1490. else:
  1491. x = blk(x, rope=rope_embeds, attn_mask=attn_mask)
  1492. else:
  1493. for blk in self.blocks:
  1494. if do_checkpointing:
  1495. x = checkpoint(blk, x, attn_mask=attn_mask)
  1496. else:
  1497. x = blk(x, attn_mask=attn_mask)
  1498. x = self.norm(x)
  1499. if naflex_mode:
  1500. return {
  1501. 'patches': x,
  1502. 'patch_valid': embeds.get('patch_valid', None),
  1503. }
  1504. return x
  1505. def _pool(
  1506. self,
  1507. x: torch.Tensor,
  1508. pool_type: Optional[str] = None,
  1509. patch_valid: Optional[torch.Tensor] = None,
  1510. ) -> torch.Tensor:
  1511. if self.attn_pool is not None:
  1512. attn_mask = create_attention_mask(
  1513. patch_valid,
  1514. num_prefix_tokens=self.num_prefix_tokens if self.pool_include_prefix else 0,
  1515. symmetric=False,
  1516. q_len=1,
  1517. dtype=x.dtype,
  1518. )
  1519. if not self.pool_include_prefix:
  1520. x = x[:, self.num_prefix_tokens:]
  1521. x = self.attn_pool(x, attn_mask=attn_mask)
  1522. return x
  1523. pool_type = self.global_pool if pool_type is None else pool_type
  1524. x = global_pool_naflex(
  1525. x,
  1526. patch_valid,
  1527. pool_type=pool_type,
  1528. num_prefix_tokens=self.num_prefix_tokens,
  1529. reduce_include_prefix=self.pool_include_prefix,
  1530. )
  1531. return x
  1532. def forward_head(
  1533. self,
  1534. patches: torch.Tensor,
  1535. pre_logits: bool = False,
  1536. patch_valid: Optional[torch.Tensor] = None,
  1537. ) -> torch.Tensor:
  1538. x = self._pool(patches, patch_valid=patch_valid)
  1539. x = self.fc_norm(x)
  1540. x = self.head_drop(x)
  1541. return x if pre_logits else self.head(x)
  1542. def forward(
  1543. self,
  1544. x: Union[torch.Tensor, Dict[str, torch.Tensor]],
  1545. patch_coord: Optional[torch.Tensor] = None,
  1546. patch_valid: Optional[torch.Tensor] = None,
  1547. attn_mask: Optional[torch.Tensor] = None,
  1548. ) -> torch.Tensor:
  1549. """Forward pass with optional NaFlex support.
  1550. Args:
  1551. x: Input tensor. Supported formats:
  1552. - [B, C, H, W] standard image input
  1553. - [B, N, P*P*C] pre-patchified tensor (flattened patches)
  1554. - [B, N, Ph, Pw, C] pre-patchified tensor (variable patch size)
  1555. - Dict from NaFlex collator
  1556. patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode.
  1557. patch_valid: Optional patch validity indicators for NaFlex.
  1558. attn_mask: Optional attn mask to override defaults generated from patch_valid
  1559. Returns:
  1560. Model output tensor.
  1561. """
  1562. input_is_dict = isinstance(x, Dict)
  1563. naflex_mode = input_is_dict or patch_coord is not None
  1564. if naflex_mode:
  1565. if input_is_dict:
  1566. # Handle dictionary input from NaFlex collator, dict inputs take priority over args
  1567. patches = x['patches']
  1568. patch_valid = x.get('patch_valid', patch_valid)
  1569. patch_coord = x.get('patch_coord', patch_coord)
  1570. attn_mask = x.get('attn_mask', attn_mask)
  1571. else:
  1572. patches = x
  1573. _assert(patch_coord is not None, "patch_coord is required in naflex mode")
  1574. _assert(patch_valid is not None, "patch_valid is required in naflex mode")
  1575. features = self.forward_features(
  1576. patches=patches,
  1577. patch_valid=patch_valid,
  1578. patch_coord=patch_coord,
  1579. attn_mask=attn_mask,
  1580. )
  1581. # Pass patches & patch_valid to forward_head for masked pooling
  1582. x = self.forward_head(**features)
  1583. else:
  1584. x = self.forward_features(x)
  1585. x = self.forward_head(x)
  1586. return x
  1587. def _debug_dump_patches(x):
  1588. # DEBUG, reconstruct patches & save
  1589. patch_coord = x['patch_coord']
  1590. patch_valid = x['patch_valid']
  1591. patches = x['patches']
  1592. for i in range(len(patches)):
  1593. patch = patches[i][patch_valid[i]]
  1594. h = (patch_coord[i, :, 0].max() + 1).item()
  1595. w = (patch_coord[i, :, 1].max() + 1).item()
  1596. patch = patch.reshape(h, w, 16, 16, 3).permute(4, 0, 2, 1, 3)
  1597. patch = patch.reshape(3, h*16, w*16)
  1598. from torchvision.utils import save_image
  1599. save_image(patch, f'patch_{i}.jpg', normalize=True)
  1600. def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0, needs_reset: bool = True) -> Callable:
  1601. """Function imported from vision_transformer.py to maintain compatibility"""
  1602. from .vision_transformer import (
  1603. init_weights_vit_jax,
  1604. init_weights_vit_moco,
  1605. init_weights_vit_timm,
  1606. init_weights_reset_parameters,
  1607. )
  1608. if 'jax' in mode:
  1609. return partial(init_weights_vit_jax, head_bias=head_bias, needs_reset=needs_reset)
  1610. elif 'moco' in mode:
  1611. return partial(init_weights_vit_moco, needs_reset=needs_reset)
  1612. else:
  1613. return partial(init_weights_vit_timm, needs_reset=needs_reset)
  1614. def checkpoint_filter_fn(state_dict: Dict[str, Any], model: NaFlexVit) -> Dict[str, Any]:
  1615. """Handle state dict conversion from original ViT to the new version with combined embedding."""
  1616. # Handle CombinedEmbed module pattern
  1617. out_dict = {}
  1618. for k, v in state_dict.items():
  1619. # Convert tokens and embeddings to combined_embed structure
  1620. if k == 'pos_embed':
  1621. # Handle position embedding format conversion - from (1, N, C) to (1, H, W, C)
  1622. if hasattr(model.embeds, 'pos_embed') and v.ndim == 3:
  1623. num_cls_token = 0
  1624. num_reg_token = 0
  1625. if 'reg_token' in state_dict:
  1626. num_reg_token = state_dict['reg_token'].shape[1]
  1627. if 'cls_token' in state_dict:
  1628. num_cls_token = state_dict['cls_token'].shape[1]
  1629. num_prefix_tokens = num_cls_token + num_reg_token
  1630. # Original format is (1, N, C), need to reshape to (1, H, W, C)
  1631. num_patches = v.shape[1]
  1632. num_patches_no_prefix = num_patches - num_prefix_tokens
  1633. grid_size_no_prefix = math.sqrt(num_patches_no_prefix)
  1634. grid_size = math.sqrt(num_patches)
  1635. if (grid_size_no_prefix != grid_size
  1636. and (grid_size_no_prefix.is_integer() and not grid_size.is_integer())
  1637. ):
  1638. # make a decision, did the pos_embed of the original include the prefix tokens?
  1639. num_patches = num_patches_no_prefix
  1640. cls_token_emb = v[:, 0:num_cls_token]
  1641. if cls_token_emb.numel():
  1642. state_dict['cls_token'] += cls_token_emb
  1643. reg_token_emb = v[:, num_cls_token:num_reg_token]
  1644. if reg_token_emb.numel():
  1645. state_dict['reg_token'] += reg_token_emb
  1646. v = v[:, num_prefix_tokens:]
  1647. grid_size = grid_size_no_prefix
  1648. grid_size = int(grid_size)
  1649. # Check if it's a perfect square for a standard grid
  1650. if grid_size * grid_size == num_patches:
  1651. # Reshape from (1, N, C) to (1, H, W, C)
  1652. v = v.reshape(1, grid_size, grid_size, v.shape[2])
  1653. else:
  1654. # Not a square grid, we need to get the actual dimensions
  1655. if hasattr(model.embeds.patch_embed, 'grid_size'):
  1656. h, w = model.embeds.patch_embed.grid_size
  1657. if h * w == num_patches:
  1658. # We have the right dimensions
  1659. v = v.reshape(1, h, w, v.shape[2])
  1660. else:
  1661. # Dimensions don't match, use interpolation
  1662. _logger.warning(
  1663. f"Position embedding size mismatch: checkpoint={num_patches}, model={(h * w)}. "
  1664. f"Using default initialization and will resize in forward pass."
  1665. )
  1666. # Keep v as is, the forward pass will handle resizing
  1667. out_dict['embeds.pos_embed'] = v
  1668. elif k == 'cls_token':
  1669. out_dict['embeds.cls_token'] = v
  1670. elif k == 'reg_token':
  1671. out_dict['embeds.reg_token'] = v
  1672. # Convert patch_embed.X to embeds.patch_embed.X
  1673. elif k.startswith('patch_embed.'):
  1674. suffix = k[12:]
  1675. if suffix == 'proj.weight':
  1676. v = v.permute(0, 2, 3, 1).flatten(1)
  1677. new_key = 'embeds.' + suffix
  1678. out_dict[new_key] = v
  1679. else:
  1680. out_dict[k] = v
  1681. return out_dict
  1682. def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
  1683. return {
  1684. 'url': url,
  1685. 'num_classes': 1000,
  1686. 'input_size': (3, 384, 384),
  1687. 'pool_size': None,
  1688. 'crop_pct': 1.0,
  1689. 'interpolation': 'bicubic',
  1690. 'mean': IMAGENET_INCEPTION_MEAN,
  1691. 'std': IMAGENET_INCEPTION_STD,
  1692. 'first_conv': 'embeds.proj',
  1693. 'classifier': 'head',
  1694. 'license': 'apache-2.0',
  1695. **kwargs,
  1696. }
# Pretrained config registry: each key names a model variant and maps to the
# config dict produced by _cfg(); the whole mapping is passed through
# generate_default_cfgs(). Entries tagged '.untrained' use the plain _cfg()
# defaults with no hub id.
default_cfgs = generate_default_cfgs({
    'naflexvit_base_patch16_gap.e300_s576_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'naflexvit_base_patch16_par_gap.e300_s576_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'naflexvit_base_patch16_parfac_gap.e300_s576_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'naflexvit_base_patch16_map.untrained': _cfg(),
    'naflexvit_so150m2_patch16_reg1_gap.untrained': _cfg(),
    'naflexvit_so150m2_patch16_reg1_map.untrained': _cfg(),

    # SigLIP-2 NaFlex vit encoder weights (configured with num_classes=0)
    'naflexvit_base_patch16_siglip.v2_webli': _cfg(
        hf_hub_id='timm/',
        num_classes=0),
    'naflexvit_so400m_patch16_siglip.v2_webli': _cfg(
        hf_hub_id='timm/',
        num_classes=0),
})
  1718. def _create_naflexvit(variant: str, pretrained: bool = False, **kwargs) -> NaFlexVit:
  1719. out_indices = kwargs.pop('out_indices', 3)
  1720. cfg = kwargs.pop('cfg', NaFlexVitCfg())
  1721. cfg_field_names = {f.name for f in fields(NaFlexVitCfg)}
  1722. # pop in-place so the original kwargs is emptied of cfg-specific keys
  1723. cfg_updates = {k: kwargs.pop(k) for k in list(kwargs) if k in cfg_field_names}
  1724. if cfg_updates:
  1725. cfg = _overlay_kwargs(cfg, **cfg_updates)
  1726. model = build_model_with_cfg(
  1727. NaFlexVit, variant, pretrained,
  1728. pretrained_filter_fn=checkpoint_filter_fn,
  1729. cfg=cfg,
  1730. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  1731. **kwargs,
  1732. )
  1733. return model
  1734. def _create_naflexvit_from_classic(
  1735. variant: str,
  1736. pretrained: bool = False,
  1737. **kwargs,
  1738. ) -> NaFlexVit:
  1739. """Create FlexVit model from classic VisionTransformer configuration.
  1740. This function handles the parameter mapping and configuration logic needed
  1741. to create FlexVit models that are compatible with classic VisionTransformer
  1742. configurations and pretrained weights.
  1743. Args:
  1744. variant: Model variant name
  1745. pretrained: Whether to load pretrained weights
  1746. **kwargs: Classic VisionTransformer parameters
  1747. Returns:
  1748. FlexVit model instance
  1749. """
  1750. # Remove VisionTransformer-specific parameters that don't apply to FlexVit
  1751. kwargs.pop('no_embed_class', None)
  1752. kwargs.pop('dynamic_img_size', None)
  1753. # Handle global pooling and fc_norm defaults that differ between ViT and FlexVit
  1754. gp = kwargs.pop('global_pool', 'token') # Original ViTs default to cls token pooling
  1755. fc_norm = kwargs.pop('fc_norm', None) # Original ViTs used fc_norm when not set and avg pooling used
  1756. if fc_norm is None and gp == 'avg':
  1757. fc_norm = True
  1758. # Set FlexVit-specific defaults that differ from VisionTransformer
  1759. flex_kwargs = {
  1760. 'pos_embed_grid_size': None, # rely on img_size (// patch_size) that will be passed through
  1761. 'class_token': kwargs.get('class_token', True),
  1762. 'global_pool': gp,
  1763. 'fc_norm': fc_norm,
  1764. 'scale_mlp_norm': kwargs.pop('scale_mlp_norm', False),
  1765. 'scale_attn_inner_norm': kwargs.pop('scale_attn_norm', False),
  1766. **kwargs # User overrides take precedence
  1767. }
  1768. return _create_naflexvit(variant, pretrained, **flex_kwargs)
  1769. def _create_naflexvit_from_eva(
  1770. variant: str,
  1771. pretrained: bool = False,
  1772. **kwargs,
  1773. ) -> NaFlexVit:
  1774. """Create NaFlexVit model from EVA configuration.
  1775. This function handles the parameter mapping and configuration logic needed
  1776. to create NaFlexVit models that are compatible with EVA configurations
  1777. and pretrained weights.
  1778. Args:
  1779. variant: Model variant name
  1780. pretrained: Whether to load pretrained weights
  1781. **kwargs: EVA model parameters
  1782. Returns:
  1783. NaFlexVit model instance
  1784. """
  1785. # Handle EVA's unique parameters & block args
  1786. kwargs.pop('no_embed_class', None) # EVA specific, not used in NaFlexVit (always no-embed)
  1787. # Map EVA's rope parameters
  1788. use_rot_pos_emb = kwargs.pop('use_rot_pos_emb', False)
  1789. rope_mixed_mode = kwargs.pop('rope_mixed_mode', False)
  1790. rope_temperature = kwargs.pop('rope_temperature', 10000.)
  1791. rope_grid_offset = kwargs.pop('rope_grid_offset', 0.)
  1792. rope_grid_indexing = kwargs.pop('rope_grid_indexing', 'ij')
  1793. if use_rot_pos_emb:
  1794. rope_type = 'mixed' if rope_mixed_mode else 'axial'
  1795. else:
  1796. rope_type = 'none'
  1797. # Handle norm/pool resolution logic to mirror EVA
  1798. gp = kwargs.pop('global_pool', 'avg')
  1799. use_pre_transformer_norm = kwargs.pop('use_pre_transformer_norm', False)
  1800. use_post_transformer_norm = kwargs.pop('use_post_transformer_norm', True)
  1801. use_fc_norm = kwargs.pop('use_fc_norm', None)
  1802. if use_fc_norm is None:
  1803. use_fc_norm = gp == 'avg' # default on if avg pool used
  1804. # Set NaFlexVit-specific parameters
  1805. naflex_kwargs = {
  1806. 'pos_embed_grid_size': None, # rely on img_size (// patch_size)
  1807. 'class_token': kwargs.get('class_token', True),
  1808. 'reg_tokens': kwargs.pop('num_reg_tokens', kwargs.get('reg_tokens', 0)),
  1809. 'global_pool': gp,
  1810. 'pre_norm': use_pre_transformer_norm,
  1811. 'final_norm': use_post_transformer_norm,
  1812. 'fc_norm': use_fc_norm,
  1813. 'pos_embed': 'learned' if kwargs.pop('use_abs_pos_emb', True) else 'none',
  1814. 'rope_type': rope_type,
  1815. 'rope_temperature': rope_temperature,
  1816. 'rope_grid_offset': rope_grid_offset,
  1817. 'rope_grid_indexing': rope_grid_indexing,
  1818. 'rope_ref_feat_shape': kwargs.get('ref_feat_shape', None),
  1819. 'attn_type': kwargs.pop('attn_type', 'eva'),
  1820. 'swiglu_mlp': kwargs.pop('swiglu_mlp', False),
  1821. 'qkv_fused': kwargs.pop('qkv_fused', True),
  1822. 'scale_mlp_norm': kwargs.pop('scale_mlp', False),
  1823. 'scale_attn_inner_norm': kwargs.pop('scale_attn_inner', False),
  1824. **kwargs # Pass remaining kwargs through
  1825. }
  1826. return _create_naflexvit(variant, pretrained, **naflex_kwargs)
  1827. @register_model
  1828. def naflexvit_base_patch16_gap(pretrained: bool = False, **kwargs) -> NaFlexVit:
  1829. """ViT-Base with NaFlex functionality and global average pooling.
  1830. """
  1831. cfg = NaFlexVitCfg(
  1832. patch_size=16,
  1833. embed_dim=768,
  1834. depth=12,
  1835. num_heads=12,
  1836. init_values=1e-5,
  1837. global_pool='avg',
  1838. reg_tokens=4,
  1839. fc_norm=True,
  1840. )
  1841. model = _create_naflexvit('naflexvit_base_patch16_gap', pretrained=pretrained, cfg=cfg, **kwargs)
  1842. return model
  1843. @register_model
  1844. def naflexvit_base_patch16_par_gap(pretrained: bool = False, **kwargs) -> NaFlexVit:
  1845. """ViT-Base with NaFlex functionality, aspect preserving pos embed, global average pooling.
  1846. """
  1847. cfg = NaFlexVitCfg(
  1848. patch_size=16,
  1849. embed_dim=768,
  1850. depth=12,
  1851. num_heads=12,
  1852. init_values=1e-5,
  1853. pos_embed_ar_preserving=True,
  1854. global_pool='avg',
  1855. reg_tokens=4,
  1856. fc_norm=True,
  1857. )
  1858. model = _create_naflexvit('naflexvit_base_patch16_par_gap', pretrained=pretrained, cfg=cfg, **kwargs)
  1859. return model
  1860. @register_model
  1861. def naflexvit_base_patch16_parfac_gap(pretrained: bool = False, **kwargs) -> NaFlexVit:
  1862. """ViT-Base with NaFlex functionality, aspect preserving & factorized pos embed, global average pooling.
  1863. """
  1864. cfg = NaFlexVitCfg(
  1865. patch_size=16,
  1866. embed_dim=768,
  1867. depth=12,
  1868. num_heads=12,
  1869. init_values=1e-5,
  1870. pos_embed_ar_preserving=True,
  1871. pos_embed='factorized',
  1872. global_pool='avg',
  1873. reg_tokens=4,
  1874. fc_norm=True,
  1875. )
  1876. model = _create_naflexvit('naflexvit_base_patch16_parfac_gap', pretrained=pretrained, cfg=cfg, **kwargs)
  1877. return model
  1878. @register_model
  1879. def naflexvit_base_patch16_map(pretrained: bool = False, **kwargs) -> NaFlexVit:
  1880. """ViT-Base with NaFlex functionality and MAP attention pooling.
  1881. """
  1882. cfg = NaFlexVitCfg(
  1883. patch_size=16,
  1884. embed_dim=768,
  1885. depth=12,
  1886. num_heads=12,
  1887. init_values=1e-5,
  1888. global_pool='map',
  1889. reg_tokens=1,
  1890. )
  1891. model = _create_naflexvit('naflexvit_base_patch16_map', pretrained=pretrained, cfg=cfg, **kwargs)
  1892. return model
  1893. @register_model
  1894. def naflexvit_so150m2_patch16_reg1_gap(pretrained: bool = False, **kwargs) -> NaFlexVit:
  1895. """ViT-SO150M2 with NaFlex functionality for variable aspect ratios and resolutions.
  1896. This model supports:
  1897. 1. Variable aspect ratios and resolutions via patch coordinates
  1898. 2. Position embedding interpolation for arbitrary grid sizes
  1899. 3. Explicit patch coordinates and valid token masking
  1900. """
  1901. cfg = NaFlexVitCfg(
  1902. patch_size=16,
  1903. embed_dim=832,
  1904. depth=21,
  1905. num_heads=13,
  1906. mlp_ratio=34/13,
  1907. init_values=1e-5,
  1908. qkv_bias=False,
  1909. reg_tokens=1,
  1910. global_pool='avg',
  1911. fc_norm=True,
  1912. )
  1913. model = _create_naflexvit('naflexvit_so150m2_patch16_reg1_gap', pretrained=pretrained, cfg=cfg, **kwargs)
  1914. return model
  1915. @register_model
  1916. def naflexvit_so150m2_patch16_reg1_map(pretrained: bool = False, **kwargs) -> NaFlexVit:
  1917. """ViT-SO150M2 with NaFlex functionality for variable aspect ratios and resolutions.
  1918. This model supports:
  1919. 1. Variable aspect ratios and resolutions via patch coordinates
  1920. 2. Position embedding interpolation for arbitrary grid sizes
  1921. 3. Explicit patch coordinates and valid token masking
  1922. """
  1923. cfg = NaFlexVitCfg(
  1924. patch_size=16,
  1925. embed_dim=832,
  1926. depth=21,
  1927. num_heads=13,
  1928. mlp_ratio=34/13,
  1929. init_values=1e-5,
  1930. qkv_bias=False,
  1931. reg_tokens=1,
  1932. global_pool='map',
  1933. )
  1934. model = _create_naflexvit('naflexvit_so150m2_patch16_reg1_map', pretrained=pretrained, cfg=cfg, **kwargs)
  1935. return model
  1936. @register_model
  1937. def naflexvit_base_patch16_siglip(pretrained: bool = False, **kwargs) -> NaFlexVit:
  1938. """ViT-Base with NaFlex functionality and SigLIP-style configuration.
  1939. """
  1940. cfg = NaFlexVitCfg(
  1941. patch_size=16,
  1942. embed_dim=768,
  1943. depth=12,
  1944. num_heads=12,
  1945. act_layer='gelu_tanh',
  1946. global_pool='map',
  1947. )
  1948. model = _create_naflexvit('naflexvit_base_patch16_siglip', pretrained=pretrained, cfg=cfg, **kwargs)
  1949. return model
  1950. @register_model
  1951. def naflexvit_so400m_patch16_siglip(pretrained: bool = False, **kwargs) -> NaFlexVit:
  1952. """ViT-SO400M with NaFlex functionality for variable aspect ratios and resolutions.
  1953. """
  1954. cfg = NaFlexVitCfg(
  1955. patch_size=16,
  1956. embed_dim=1152,
  1957. depth=27,
  1958. num_heads=16,
  1959. mlp_ratio=3.7362,
  1960. act_layer='gelu_tanh',
  1961. global_pool='map',
  1962. )
  1963. model = _create_naflexvit('naflexvit_so400m_patch16_siglip', pretrained=pretrained, cfg=cfg, **kwargs)
  1964. return model