volo.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404
  1. """ Vision OutLOoker (VOLO) implementation
  2. Paper: `VOLO: Vision Outlooker for Visual Recognition` - https://arxiv.org/abs/2106.13112
  3. Code adapted from official impl at https://github.com/sail-sg/volo, original copyright in comment below
  4. Modifications and additions for timm by / Copyright 2022, Ross Wightman
  5. """
  6. # Copyright 2021 Sea Limited.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. import math
  20. from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Type
  21. import torch
  22. import torch.nn as nn
  23. import torch.nn.functional as F
  24. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  25. from timm.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_, use_fused_attn
  26. from ._builder import build_model_with_cfg
  27. from ._features import feature_take_indices
  28. from ._manipulate import checkpoint
  29. from ._registry import register_model, generate_default_cfgs
  30. __all__ = ['VOLO'] # model_registry will add each entrypoint fn to this
  31. class OutlookAttention(nn.Module):
  32. """Outlook attention mechanism for VOLO models."""
  33. def __init__(
  34. self,
  35. dim: int,
  36. num_heads: int,
  37. kernel_size: int = 3,
  38. padding: int = 1,
  39. stride: int = 1,
  40. qkv_bias: bool = False,
  41. attn_drop: float = 0.,
  42. proj_drop: float = 0.,
  43. device=None,
  44. dtype=None,
  45. ):
  46. """Initialize OutlookAttention.
  47. Args:
  48. dim: Input feature dimension.
  49. num_heads: Number of attention heads.
  50. kernel_size: Kernel size for attention computation.
  51. padding: Padding for attention computation.
  52. stride: Stride for attention computation.
  53. qkv_bias: Whether to use bias in linear layers.
  54. attn_drop: Attention dropout rate.
  55. proj_drop: Projection dropout rate.
  56. """
  57. dd = {'device': device, 'dtype': dtype}
  58. super().__init__()
  59. head_dim = dim // num_heads
  60. self.num_heads = num_heads
  61. self.kernel_size = kernel_size
  62. self.padding = padding
  63. self.stride = stride
  64. self.scale = head_dim ** -0.5
  65. self.v = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  66. self.attn = nn.Linear(dim, kernel_size ** 4 * num_heads, **dd)
  67. self.attn_drop = nn.Dropout(attn_drop)
  68. self.proj = nn.Linear(dim, dim, **dd)
  69. self.proj_drop = nn.Dropout(proj_drop)
  70. self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride)
  71. self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)
  72. def forward(self, x: torch.Tensor) -> torch.Tensor:
  73. """Forward pass.
  74. Args:
  75. x: Input tensor of shape (B, H, W, C).
  76. Returns:
  77. Output tensor of shape (B, H, W, C).
  78. """
  79. B, H, W, C = x.shape
  80. v = self.v(x).permute(0, 3, 1, 2) # B, C, H, W
  81. h, w = math.ceil(H / self.stride), math.ceil(W / self.stride)
  82. v = self.unfold(v).reshape(
  83. B, self.num_heads, C // self.num_heads,
  84. self.kernel_size * self.kernel_size, h * w).permute(0, 1, 4, 3, 2) # B,H,N,kxk,C/H
  85. attn = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
  86. attn = self.attn(attn).reshape(
  87. B, h * w, self.num_heads, self.kernel_size * self.kernel_size,
  88. self.kernel_size * self.kernel_size).permute(0, 2, 1, 3, 4) # B,H,N,kxk,kxk
  89. attn = attn * self.scale
  90. attn = attn.softmax(dim=-1)
  91. attn = self.attn_drop(attn)
  92. x = (attn @ v).permute(0, 1, 4, 3, 2).reshape(B, C * self.kernel_size * self.kernel_size, h * w)
  93. x = F.fold(x, output_size=(H, W), kernel_size=self.kernel_size, padding=self.padding, stride=self.stride)
  94. x = self.proj(x.permute(0, 2, 3, 1))
  95. x = self.proj_drop(x)
  96. return x
  97. class Outlooker(nn.Module):
  98. """Outlooker block that combines outlook attention with MLP."""
  99. def __init__(
  100. self,
  101. dim: int,
  102. kernel_size: int,
  103. padding: int,
  104. stride: int = 1,
  105. num_heads: int = 1,
  106. mlp_ratio: float = 3.,
  107. attn_drop: float = 0.,
  108. drop_path: float = 0.,
  109. act_layer: Type[nn.Module] = nn.GELU,
  110. norm_layer: Type[nn.Module] = nn.LayerNorm,
  111. qkv_bias: bool = False,
  112. device=None,
  113. dtype=None,
  114. ):
  115. """Initialize Outlooker block.
  116. Args:
  117. dim: Input feature dimension.
  118. kernel_size: Kernel size for outlook attention.
  119. padding: Padding for outlook attention.
  120. stride: Stride for outlook attention.
  121. num_heads: Number of attention heads.
  122. mlp_ratio: Ratio for MLP hidden dimension.
  123. attn_drop: Attention dropout rate.
  124. drop_path: Stochastic depth drop rate.
  125. act_layer: Activation layer type.
  126. norm_layer: Normalization layer type.
  127. qkv_bias: Whether to use bias in linear layers.
  128. """
  129. dd = {'device': device, 'dtype': dtype}
  130. super().__init__()
  131. self.norm1 = norm_layer(dim, **dd)
  132. self.attn = OutlookAttention(
  133. dim,
  134. num_heads,
  135. kernel_size=kernel_size,
  136. padding=padding,
  137. stride=stride,
  138. qkv_bias=qkv_bias,
  139. attn_drop=attn_drop,
  140. **dd,
  141. )
  142. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  143. self.norm2 = norm_layer(dim, **dd)
  144. self.mlp = Mlp(
  145. in_features=dim,
  146. hidden_features=int(dim * mlp_ratio),
  147. act_layer=act_layer,
  148. **dd,
  149. )
  150. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  151. def forward(self, x: torch.Tensor) -> torch.Tensor:
  152. """Forward pass.
  153. Args:
  154. x: Input tensor.
  155. Returns:
  156. Output tensor.
  157. """
  158. x = x + self.drop_path1(self.attn(self.norm1(x)))
  159. x = x + self.drop_path2(self.mlp(self.norm2(x)))
  160. return x
  161. class Attention(nn.Module):
  162. """Multi-head self-attention module."""
  163. fused_attn: torch.jit.Final[bool]
  164. def __init__(
  165. self,
  166. dim: int,
  167. num_heads: int = 8,
  168. qkv_bias: bool = False,
  169. attn_drop: float = 0.,
  170. proj_drop: float = 0.,
  171. device=None,
  172. dtype=None,
  173. ):
  174. """Initialize Attention module.
  175. Args:
  176. dim: Input feature dimension.
  177. num_heads: Number of attention heads.
  178. qkv_bias: Whether to use bias in QKV projection.
  179. attn_drop: Attention dropout rate.
  180. proj_drop: Projection dropout rate.
  181. """
  182. dd = {'device': device, 'dtype': dtype}
  183. super().__init__()
  184. self.num_heads = num_heads
  185. head_dim = dim // num_heads
  186. self.scale = head_dim ** -0.5
  187. self.fused_attn = use_fused_attn()
  188. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
  189. self.attn_drop = nn.Dropout(attn_drop)
  190. self.proj = nn.Linear(dim, dim, **dd)
  191. self.proj_drop = nn.Dropout(proj_drop)
  192. def forward(self, x: torch.Tensor) -> torch.Tensor:
  193. """Forward pass.
  194. Args:
  195. x: Input tensor of shape (B, H, W, C).
  196. Returns:
  197. Output tensor of shape (B, H, W, C).
  198. """
  199. B, H, W, C = x.shape
  200. qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  201. q, k, v = qkv.unbind(0)
  202. if self.fused_attn:
  203. x = F.scaled_dot_product_attention(
  204. q, k, v,
  205. dropout_p=self.attn_drop.p if self.training else 0.,
  206. )
  207. else:
  208. q = q * self.scale
  209. attn = q @ k.transpose(-2, -1)
  210. attn = attn.softmax(dim=-1)
  211. attn = self.attn_drop(attn)
  212. x = attn @ v
  213. x = x.transpose(1, 2).reshape(B, H, W, C)
  214. x = self.proj(x)
  215. x = self.proj_drop(x)
  216. return x
  217. class Transformer(nn.Module):
  218. """Transformer block with multi-head self-attention and MLP."""
  219. def __init__(
  220. self,
  221. dim: int,
  222. num_heads: int,
  223. mlp_ratio: float = 4.,
  224. qkv_bias: bool = False,
  225. attn_drop: float = 0.,
  226. drop_path: float = 0.,
  227. act_layer: Type[nn.Module] = nn.GELU,
  228. norm_layer: Type[nn.Module] = nn.LayerNorm,
  229. device=None,
  230. dtype=None,
  231. ):
  232. """Initialize Transformer block.
  233. Args:
  234. dim: Input feature dimension.
  235. num_heads: Number of attention heads.
  236. mlp_ratio: Ratio for MLP hidden dimension.
  237. qkv_bias: Whether to use bias in QKV projection.
  238. attn_drop: Attention dropout rate.
  239. drop_path: Stochastic depth drop rate.
  240. act_layer: Activation layer type.
  241. norm_layer: Normalization layer type.
  242. """
  243. dd = {'device': device, 'dtype': dtype}
  244. super().__init__()
  245. self.norm1 = norm_layer(dim, **dd)
  246. self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, **dd)
  247. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  248. self.norm2 = norm_layer(dim, **dd)
  249. self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, **dd)
  250. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  251. def forward(self, x: torch.Tensor) -> torch.Tensor:
  252. """Forward pass.
  253. Args:
  254. x: Input tensor.
  255. Returns:
  256. Output tensor.
  257. """
  258. x = x + self.drop_path1(self.attn(self.norm1(x)))
  259. x = x + self.drop_path2(self.mlp(self.norm2(x)))
  260. return x
  261. class ClassAttention(nn.Module):
  262. """Class attention mechanism for class token interaction."""
  263. def __init__(
  264. self,
  265. dim: int,
  266. num_heads: int = 8,
  267. head_dim: Optional[int] = None,
  268. qkv_bias: bool = False,
  269. attn_drop: float = 0.,
  270. proj_drop: float = 0.,
  271. device=None,
  272. dtype=None,
  273. ):
  274. """Initialize ClassAttention.
  275. Args:
  276. dim: Input feature dimension.
  277. num_heads: Number of attention heads.
  278. head_dim: Dimension per head. If None, computed as dim // num_heads.
  279. qkv_bias: Whether to use bias in QKV projection.
  280. attn_drop: Attention dropout rate.
  281. proj_drop: Projection dropout rate.
  282. """
  283. dd = {'device': device, 'dtype': dtype}
  284. super().__init__()
  285. self.num_heads = num_heads
  286. if head_dim is not None:
  287. self.head_dim = head_dim
  288. else:
  289. head_dim = dim // num_heads
  290. self.head_dim = head_dim
  291. self.scale = head_dim ** -0.5
  292. self.kv = nn.Linear(dim, self.head_dim * self.num_heads * 2, bias=qkv_bias, **dd)
  293. self.q = nn.Linear(dim, self.head_dim * self.num_heads, bias=qkv_bias, **dd)
  294. self.attn_drop = nn.Dropout(attn_drop)
  295. self.proj = nn.Linear(self.head_dim * self.num_heads, dim, **dd)
  296. self.proj_drop = nn.Dropout(proj_drop)
  297. def forward(self, x: torch.Tensor) -> torch.Tensor:
  298. """Forward pass.
  299. Args:
  300. x: Input tensor of shape (B, N, C) where first token is class token.
  301. Returns:
  302. Class token output of shape (B, 1, C).
  303. """
  304. B, N, C = x.shape
  305. kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
  306. k, v = kv.unbind(0)
  307. q = self.q(x[:, :1, :]).reshape(B, self.num_heads, 1, self.head_dim) * self.scale
  308. attn = q @ k.transpose(-2, -1)
  309. attn = attn.softmax(dim=-1)
  310. attn = self.attn_drop(attn)
  311. cls_embed = (attn @ v).transpose(1, 2).reshape(B, 1, self.head_dim * self.num_heads)
  312. cls_embed = self.proj(cls_embed)
  313. cls_embed = self.proj_drop(cls_embed)
  314. return cls_embed
  315. class ClassBlock(nn.Module):
  316. """Class block that combines class attention with MLP."""
  317. def __init__(
  318. self,
  319. dim: int,
  320. num_heads: int,
  321. head_dim: Optional[int] = None,
  322. mlp_ratio: float = 4.,
  323. qkv_bias: bool = False,
  324. drop: float = 0.,
  325. attn_drop: float = 0.,
  326. drop_path: float = 0.,
  327. act_layer: Type[nn.Module] = nn.GELU,
  328. norm_layer: Type[nn.Module] = nn.LayerNorm,
  329. device=None,
  330. dtype=None,
  331. ):
  332. """Initialize ClassBlock.
  333. Args:
  334. dim: Input feature dimension.
  335. num_heads: Number of attention heads.
  336. head_dim: Dimension per head. If None, computed as dim // num_heads.
  337. mlp_ratio: Ratio for MLP hidden dimension.
  338. qkv_bias: Whether to use bias in QKV projection.
  339. drop: Dropout rate.
  340. attn_drop: Attention dropout rate.
  341. drop_path: Stochastic depth drop rate.
  342. act_layer: Activation layer type.
  343. norm_layer: Normalization layer type.
  344. """
  345. dd = {'device': device, 'dtype': dtype}
  346. super().__init__()
  347. self.norm1 = norm_layer(dim, **dd)
  348. self.attn = ClassAttention(
  349. dim,
  350. num_heads=num_heads,
  351. head_dim=head_dim,
  352. qkv_bias=qkv_bias,
  353. attn_drop=attn_drop,
  354. proj_drop=drop,
  355. **dd,
  356. )
  357. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  358. self.norm2 = norm_layer(dim, **dd)
  359. self.mlp = Mlp(
  360. in_features=dim,
  361. hidden_features=int(dim * mlp_ratio),
  362. act_layer=act_layer,
  363. drop=drop,
  364. **dd,
  365. )
  366. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  367. def forward(self, x: torch.Tensor) -> torch.Tensor:
  368. """Forward pass.
  369. Args:
  370. x: Input tensor of shape (B, N, C) where first token is class token.
  371. Returns:
  372. Output tensor with updated class token.
  373. """
  374. cls_embed = x[:, :1]
  375. cls_embed = cls_embed + self.drop_path1(self.attn(self.norm1(x)))
  376. cls_embed = cls_embed + self.drop_path2(self.mlp(self.norm2(cls_embed)))
  377. return torch.cat([cls_embed, x[:, 1:]], dim=1)
  378. def get_block(block_type: str, **kwargs: Any) -> nn.Module:
  379. """Get block based on type.
  380. Args:
  381. block_type: Type of block ('ca' for ClassBlock).
  382. **kwargs: Additional keyword arguments for block.
  383. Returns:
  384. The requested block module.
  385. """
  386. if block_type == 'ca':
  387. return ClassBlock(**kwargs)
  388. else:
  389. assert False, f'Invalid block type: {block_type}'
  390. def rand_bbox(size: Tuple[int, ...], lam: float, scale: int = 1) -> Tuple[int, int, int, int]:
  391. """Get random bounding box for token labeling.
  392. Reference: https://github.com/zihangJiang/TokenLabeling
  393. Args:
  394. size: Input tensor size tuple.
  395. lam: Lambda parameter for cutmix.
  396. scale: Scaling factor.
  397. Returns:
  398. Bounding box coordinates (bbx1, bby1, bbx2, bby2).
  399. """
  400. W = size[1] // scale
  401. H = size[2] // scale
  402. W_t = torch.tensor(W, dtype=torch.float32)
  403. H_t = torch.tensor(H, dtype=torch.float32)
  404. cut_rat = torch.sqrt(1. - lam)
  405. cut_w = (W_t * cut_rat).int()
  406. cut_h = (H_t * cut_rat).int()
  407. # uniform
  408. cx = torch.randint(0, W, (1,))
  409. cy = torch.randint(0, H, (1,))
  410. bbx1 = torch.clamp(cx - cut_w // 2, 0, W)
  411. bby1 = torch.clamp(cy - cut_h // 2, 0, H)
  412. bbx2 = torch.clamp(cx + cut_w // 2, 0, W)
  413. bby2 = torch.clamp(cy + cut_h // 2, 0, H)
  414. return bbx1.item(), bby1.item(), bbx2.item(), bby2.item()
  415. class PatchEmbed(nn.Module):
  416. """Image to patch embedding with multi-layer convolution."""
  417. def __init__(
  418. self,
  419. img_size: int = 224,
  420. stem_conv: bool = False,
  421. stem_stride: int = 1,
  422. patch_size: int = 8,
  423. in_chans: int = 3,
  424. hidden_dim: int = 64,
  425. embed_dim: int = 384,
  426. device=None,
  427. dtype=None,
  428. ):
  429. """Initialize PatchEmbed.
  430. Different from ViT which uses 1 conv layer, VOLO uses multiple conv layers for patch embedding.
  431. Args:
  432. img_size: Input image size.
  433. stem_conv: Whether to use stem convolution layers.
  434. stem_stride: Stride for stem convolution.
  435. patch_size: Patch size (must be 4, 8, or 16).
  436. in_chans: Number of input channels.
  437. hidden_dim: Hidden dimension for stem convolution.
  438. embed_dim: Output embedding dimension.
  439. """
  440. dd = {'device': device, 'dtype': dtype}
  441. super().__init__()
  442. assert patch_size in [4, 8, 16]
  443. if stem_conv:
  444. self.conv = nn.Sequential(
  445. nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3, bias=False, **dd),
  446. nn.BatchNorm2d(hidden_dim, **dd),
  447. nn.ReLU(inplace=True),
  448. nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False, **dd),
  449. nn.BatchNorm2d(hidden_dim, **dd),
  450. nn.ReLU(inplace=True),
  451. nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False, **dd),
  452. nn.BatchNorm2d(hidden_dim, **dd),
  453. nn.ReLU(inplace=True),
  454. )
  455. else:
  456. self.conv = None
  457. self.proj = nn.Conv2d(
  458. hidden_dim,
  459. embed_dim,
  460. kernel_size=patch_size // stem_stride,
  461. stride=patch_size // stem_stride,
  462. **dd,
  463. )
  464. self.num_patches = (img_size // patch_size) * (img_size // patch_size)
  465. def forward(self, x: torch.Tensor) -> torch.Tensor:
  466. """Forward pass.
  467. Args:
  468. x: Input tensor of shape (B, C, H, W).
  469. Returns:
  470. Output tensor of shape (B, embed_dim, H', W').
  471. """
  472. if self.conv is not None:
  473. x = self.conv(x)
  474. x = self.proj(x) # B, C, H, W
  475. return x
  476. class Downsample(nn.Module):
  477. """Downsampling module between stages."""
  478. def __init__(
  479. self,
  480. in_embed_dim: int,
  481. out_embed_dim: int,
  482. patch_size: int = 2,
  483. device=None,
  484. dtype=None,
  485. ):
  486. """Initialize Downsample.
  487. Args:
  488. in_embed_dim: Input embedding dimension.
  489. out_embed_dim: Output embedding dimension.
  490. patch_size: Patch size for downsampling.
  491. """
  492. super().__init__()
  493. dd = {'device': device, 'dtype': dtype}
  494. self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=patch_size, stride=patch_size, **dd)
  495. def forward(self, x: torch.Tensor) -> torch.Tensor:
  496. """Forward pass.
  497. Args:
  498. x: Input tensor of shape (B, H, W, C).
  499. Returns:
  500. Output tensor of shape (B, H', W', C').
  501. """
  502. x = x.permute(0, 3, 1, 2)
  503. x = self.proj(x) # B, C, H, W
  504. x = x.permute(0, 2, 3, 1)
  505. return x
  506. def outlooker_blocks(
  507. block_fn: Callable,
  508. index: int,
  509. dim: int,
  510. layers: List[int],
  511. num_heads: int = 1,
  512. kernel_size: int = 3,
  513. padding: int = 1,
  514. stride: int = 2,
  515. mlp_ratio: float = 3.,
  516. qkv_bias: bool = False,
  517. attn_drop: float = 0,
  518. drop_path_rate: float = 0.,
  519. device=None,
  520. dtype=None,
  521. **kwargs: Any,
  522. ) -> nn.Sequential:
  523. """Generate outlooker layers for stage 1.
  524. Args:
  525. block_fn: Block function to use (typically Outlooker).
  526. index: Index of current stage.
  527. dim: Feature dimension.
  528. layers: List of layer counts for each stage.
  529. num_heads: Number of attention heads.
  530. kernel_size: Kernel size for outlook attention.
  531. padding: Padding for outlook attention.
  532. stride: Stride for outlook attention.
  533. mlp_ratio: Ratio for MLP hidden dimension.
  534. qkv_bias: Whether to use bias in QKV projection.
  535. attn_drop: Attention dropout rate.
  536. drop_path_rate: Stochastic depth drop rate.
  537. **kwargs: Additional keyword arguments.
  538. Returns:
  539. Sequential module containing outlooker blocks.
  540. """
  541. blocks = []
  542. for block_idx in range(layers[index]):
  543. block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
  544. blocks.append(block_fn(
  545. dim,
  546. kernel_size=kernel_size,
  547. padding=padding,
  548. stride=stride,
  549. num_heads=num_heads,
  550. mlp_ratio=mlp_ratio,
  551. qkv_bias=qkv_bias,
  552. attn_drop=attn_drop,
  553. drop_path=block_dpr,
  554. device=device,
  555. dtype=dtype,
  556. **kwargs,
  557. ))
  558. blocks = nn.Sequential(*blocks)
  559. return blocks
  560. def transformer_blocks(
  561. block_fn: Callable,
  562. index: int,
  563. dim: int,
  564. layers: List[int],
  565. num_heads: int,
  566. mlp_ratio: float = 3.,
  567. qkv_bias: bool = False,
  568. attn_drop: float = 0,
  569. drop_path_rate: float = 0.,
  570. **kwargs: Any,
  571. ) -> nn.Sequential:
  572. """Generate transformer layers for stage 2.
  573. Args:
  574. block_fn: Block function to use (typically Transformer).
  575. index: Index of current stage.
  576. dim: Feature dimension.
  577. layers: List of layer counts for each stage.
  578. num_heads: Number of attention heads.
  579. mlp_ratio: Ratio for MLP hidden dimension.
  580. qkv_bias: Whether to use bias in QKV projection.
  581. attn_drop: Attention dropout rate.
  582. drop_path_rate: Stochastic depth drop rate.
  583. **kwargs: Additional keyword arguments.
  584. Returns:
  585. Sequential module containing transformer blocks.
  586. """
  587. blocks = []
  588. for block_idx in range(layers[index]):
  589. block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
  590. blocks.append(block_fn(
  591. dim,
  592. num_heads,
  593. mlp_ratio=mlp_ratio,
  594. qkv_bias=qkv_bias,
  595. attn_drop=attn_drop,
  596. drop_path=block_dpr,
  597. **kwargs,
  598. ))
  599. blocks = nn.Sequential(*blocks)
  600. return blocks
  601. class VOLO(nn.Module):
  602. """Vision Outlooker (VOLO) model."""
    def __init__(
            self,
            layers: List[int],
            img_size: int = 224,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'token',
            patch_size: int = 8,
            stem_hidden_dim: int = 64,
            embed_dims: Optional[List[int]] = None,
            num_heads: Optional[List[int]] = None,
            downsamples: Tuple[bool, ...] = (True, False, False, False),
            outlook_attention: Tuple[bool, ...] = (True, False, False, False),
            mlp_ratio: float = 3.0,
            qkv_bias: bool = False,
            drop_rate: float = 0.,
            pos_drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            post_layers: Optional[Tuple[str, ...]] = ('ca', 'ca'),
            use_aux_head: bool = True,
            use_mix_token: bool = False,
            pooling_scale: int = 2,
            device=None,
            dtype=None,
    ):
        """Initialize VOLO model.

        Args:
            layers: Number of blocks in each stage.
            img_size: Input image size (int or (H, W)); determines the ``pos_embed`` grid.
            in_chans: Number of input channels.
            num_classes: Number of classes for classification (0 gives Identity heads).
            global_pool: Global pooling type ('token', 'avg', or ''); must be 'token'
                when ``use_mix_token`` is set.
            patch_size: Spatial reduction of the patch embedding.
            stem_hidden_dim: Hidden channel width of the convolutional stem.
            embed_dims: Embedding dimension of each stage. Required in practice despite
                the ``None`` default: ``embed_dims[-1]`` is read unconditionally.
            num_heads: Number of attention heads of each stage; likewise required in practice.
            downsamples: Per stage, whether a 2x ``Downsample`` to ``embed_dims[i + 1]``
                follows it (so it must be False for the last stage).
            outlook_attention: Per stage, build outlooker blocks (True) or transformer blocks (False).
            mlp_ratio: MLP hidden-dim ratio, expanded per stage with ``to_ntuple``.
            qkv_bias: Whether to use bias in QKV projections.
            drop_rate: Dropout rate before the classifier head.
            pos_drop_rate: Dropout rate applied with the positional embedding.
            attn_drop_rate: Attention dropout rate.
            drop_path_rate: Maximum stochastic depth rate. NOTE(review): only forwarded to
                transformer stages below; outlooker stages are built without it and so use
                drop_path 0. Presumably mirrors the official implementation; confirm intent.
            norm_layer: Normalization layer type.
            post_layers: Block type names for the post (class-attention) network
                (see ``get_block``), or None to build no post network and no ``cls_token``.
            use_aux_head: Whether to create ``aux_head``.
            use_mix_token: Whether to enable token-labeling MixToken training.
            pooling_scale: Extra reduction, beyond ``patch_size``, of the ``pos_embed`` grid.
            device: Device on which parameters are created.
            dtype: Dtype of the created parameters.
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        num_layers = len(layers)
        mlp_ratio = to_ntuple(num_layers)(mlp_ratio)
        img_size = to_2tuple(img_size)
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.global_pool = global_pool
        self.mix_token = use_mix_token
        self.pooling_scale = pooling_scale
        # NOTE(review): raises TypeError if embed_dims is left at its None default.
        self.num_features = self.head_hidden_size = embed_dims[-1]
        if use_mix_token:  # enable token mixing, see token labeling for details.
            self.beta = 1.0
            assert global_pool == 'token', "return all tokens if mix_token is enabled"
        self.grad_checkpointing = False
        # The stem is always on with stride 2; PatchEmbed's projection supplies the
        # remaining patch_size // 2 reduction.
        self.patch_embed = PatchEmbed(
            stem_conv=True,
            stem_stride=2,
            patch_size=patch_size,
            in_chans=in_chans,
            hidden_dim=stem_hidden_dim,
            embed_dim=embed_dims[0],
            **dd,
        )
        # Running spatial reduction of the current stage, reported in feature_info.
        r = patch_size
        # initial positional encoding, we add positional encoding after outlooker blocks
        # Channels-last (1, H', W', C) grid at patch_size * pooling_scale reduction, sized
        # for the last stage's embedding dim.
        patch_grid = (img_size[0] // patch_size // pooling_scale, img_size[1] // patch_size // pooling_scale)
        self.pos_embed = nn.Parameter(torch.zeros(1, patch_grid[0], patch_grid[1], embed_dims[-1], **dd))
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        # set the main block in network
        # stage_ends / feature_info index into self.network, where Downsample modules also
        # occupy slots, hence block_idx advances for both stages and downsamples.
        self.stage_ends = []
        self.feature_info = []
        network = []
        block_idx = 0
        for i in range(len(layers)):
            if outlook_attention[i]:
                # stage 1
                # NOTE(review): drop_path_rate is not passed here, so outlooker blocks get 0.
                stage = outlooker_blocks(
                    Outlooker,
                    i,
                    embed_dims[i],
                    layers,
                    num_heads[i],
                    mlp_ratio=mlp_ratio[i],
                    qkv_bias=qkv_bias,
                    attn_drop=attn_drop_rate,
                    norm_layer=norm_layer,
                    **dd,
                )
            else:
                # stage 2
                stage = transformer_blocks(
                    Transformer,
                    i,
                    embed_dims[i],
                    layers,
                    num_heads[i],
                    mlp_ratio=mlp_ratio[i],
                    qkv_bias=qkv_bias,
                    drop_path_rate=drop_path_rate,
                    attn_drop=attn_drop_rate,
                    norm_layer=norm_layer,
                    **dd,
                )
            network.append(stage)
            self.stage_ends.append(block_idx)
            self.feature_info.append(dict(num_chs=embed_dims[i], reduction=r, module=f'network.{block_idx}'))
            block_idx += 1
            if downsamples[i]:
                # downsampling between two stages
                network.append(Downsample(embed_dims[i], embed_dims[i + 1], 2, **dd))
                r *= 2
                block_idx += 1
        self.network = nn.ModuleList(network)
        # set post block, for example, class attention layers
        # Post blocks all use the last stage's width/heads/mlp ratio and no stochastic depth.
        self.post_network = None
        if post_layers is not None:
            self.post_network = nn.ModuleList([
                get_block(
                    post_layers[i],
                    dim=embed_dims[-1],
                    num_heads=num_heads[-1],
                    mlp_ratio=mlp_ratio[-1],
                    qkv_bias=qkv_bias,
                    attn_drop=attn_drop_rate,
                    drop_path=0.,
                    norm_layer=norm_layer,
                    **dd,
                )
                for i in range(len(post_layers))
            ])
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1], **dd))
            trunc_normal_(self.cls_token, std=.02)
        # set output type
        if use_aux_head:
            self.aux_head = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()
        else:
            self.aux_head = None
        self.norm = norm_layer(self.num_features, **dd)
        # Classifier head
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()
        trunc_normal_(self.pos_embed, std=.02)
        # Re-initializes every nn.Linear in the model (see _init_weights).
        self.apply(self._init_weights)
  759. def _init_weights(self, m: nn.Module) -> None:
  760. """Initialize weights for modules.
  761. Args:
  762. m: Module to initialize.
  763. """
  764. if isinstance(m, nn.Linear):
  765. trunc_normal_(m.weight, std=.02)
  766. if isinstance(m, nn.Linear) and m.bias is not None:
  767. nn.init.constant_(m.bias, 0)
  768. @torch.jit.ignore
  769. def no_weight_decay(self) -> set:
  770. """Get set of parameters that should not have weight decay.
  771. Returns:
  772. Set of parameter names.
  773. """
  774. return {'pos_embed', 'cls_token'}
  775. @torch.jit.ignore
  776. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  777. """Get parameter grouping for optimizer.
  778. Args:
  779. coarse: Whether to use coarse grouping.
  780. Returns:
  781. Parameter grouping dictionary.
  782. """
  783. return dict(
  784. stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
  785. blocks=[
  786. (r'^network\.(\d+)\.(\d+)', None),
  787. (r'^network\.(\d+)', (0,)),
  788. ],
  789. blocks2=[
  790. (r'^cls_token', (0,)),
  791. (r'^post_network\.(\d+)', None),
  792. (r'^norm', (99999,))
  793. ],
  794. )
  795. @torch.jit.ignore
  796. def set_grad_checkpointing(self, enable: bool = True) -> None:
  797. """Set gradient checkpointing.
  798. Args:
  799. enable: Whether to enable gradient checkpointing.
  800. """
  801. self.grad_checkpointing = enable
  802. @torch.jit.ignore
  803. def get_classifier(self) -> nn.Module:
  804. """Get classifier module.
  805. Returns:
  806. The classifier head module.
  807. """
  808. return self.head
  809. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
  810. """Reset classifier head.
  811. Args:
  812. num_classes: Number of classes for new classifier.
  813. global_pool: Global pooling type.
  814. """
  815. self.num_classes = num_classes
  816. if global_pool is not None:
  817. self.global_pool = global_pool
  818. device = self.head.weight.device if hasattr(self.head, 'weight') else None
  819. dtype = self.head.weight.dtype if hasattr(self.head, 'weight') else None
  820. self.head = nn.Linear(
  821. self.num_features, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity()
  822. if self.aux_head is not None:
  823. self.aux_head = nn.Linear(
  824. self.num_features, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity()
  825. def forward_tokens(self, x: torch.Tensor) -> torch.Tensor:
  826. """Forward pass through token processing stages.
  827. Args:
  828. x: Input tensor of shape (B, H, W, C).
  829. Returns:
  830. Token tensor of shape (B, N, C).
  831. """
  832. for idx, block in enumerate(self.network):
  833. if idx == 2:
  834. # add positional encoding after outlooker blocks
  835. x = x + self.pos_embed
  836. x = self.pos_drop(x)
  837. if self.grad_checkpointing and not torch.jit.is_scripting():
  838. x = checkpoint(block, x)
  839. else:
  840. x = block(x)
  841. B, H, W, C = x.shape
  842. x = x.reshape(B, -1, C)
  843. return x
  844. def forward_cls(self, x: torch.Tensor) -> torch.Tensor:
  845. """Forward pass through class attention blocks.
  846. Args:
  847. x: Input token tensor of shape (B, N, C).
  848. Returns:
  849. Output tensor with class token of shape (B, N+1, C).
  850. """
  851. B, N, C = x.shape
  852. cls_tokens = self.cls_token.expand(B, -1, -1)
  853. x = torch.cat([cls_tokens, x], dim=1)
  854. for block in self.post_network:
  855. if self.grad_checkpointing and not torch.jit.is_scripting():
  856. x = checkpoint(block, x)
  857. else:
  858. x = block(x)
  859. return x
  860. def forward_train(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, Tuple[int, int, int, int]]]:
  861. """Forward pass for training with mix token support.
  862. Args:
  863. x: Input tensor of shape (B, C, H, W).
  864. Returns:
  865. If training with mix_token: tuple of (class_token, aux_tokens, bbox).
  866. Otherwise: class_token tensor.
  867. """
  868. """ A separate forward fn for training with mix_token (if a train script supports).
  869. Combining multiple modes in as single forward with different return types is torchscript hell.
  870. """
  871. x = self.patch_embed(x)
  872. x = x.permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C
  873. # mix token, see token labeling for details.
  874. if self.mix_token and self.training:
  875. lam = torch.distributions.Beta(self.beta, self.beta).sample()
  876. patch_h, patch_w = x.shape[1] // self.pooling_scale, x.shape[2] // self.pooling_scale
  877. bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale)
  878. temp_x = x.clone()
  879. sbbx1, sbby1 = self.pooling_scale * bbx1, self.pooling_scale * bby1
  880. sbbx2, sbby2 = self.pooling_scale * bbx2, self.pooling_scale * bby2
  881. temp_x[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :]
  882. x = temp_x
  883. else:
  884. bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0
  885. # step2: tokens learning in the two stages
  886. x = self.forward_tokens(x)
  887. # step3: post network, apply class attention or not
  888. if self.post_network is not None:
  889. x = self.forward_cls(x)
  890. x = self.norm(x)
  891. if self.global_pool == 'avg':
  892. x_cls = x.mean(dim=1)
  893. elif self.global_pool == 'token':
  894. x_cls = x[:, 0]
  895. else:
  896. x_cls = x
  897. if self.aux_head is None:
  898. return x_cls
  899. x_aux = self.aux_head(x[:, 1:]) # generate classes in all feature tokens, see token labeling
  900. if not self.training:
  901. return x_cls + 0.5 * x_aux.max(1)[0]
  902. if self.mix_token and self.training: # reverse "mix token", see token labeling for details.
  903. x_aux = x_aux.reshape(x_aux.shape[0], patch_h, patch_w, x_aux.shape[-1])
  904. temp_x = x_aux.clone()
  905. temp_x[:, bbx1:bbx2, bby1:bby2, :] = x_aux.flip(0)[:, bbx1:bbx2, bby1:bby2, :]
  906. x_aux = temp_x
  907. x_aux = x_aux.reshape(x_aux.shape[0], patch_h * patch_w, x_aux.shape[-1])
  908. # return these: 1. class token, 2. classes from all feature tokens, 3. bounding box
  909. return x_cls, x_aux, (bbx1, bby1, bbx2, bby2)
  910. def forward_intermediates(
  911. self,
  912. x: torch.Tensor,
  913. indices: Optional[Union[int, List[int]]] = None,
  914. norm: bool = False,
  915. stop_early: bool = False,
  916. output_fmt: str = 'NCHW',
  917. intermediates_only: bool = False,
  918. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  919. """ Forward features that returns intermediates.
  920. Args:
  921. x: Input image tensor
  922. indices: Take last n blocks if int, all if None, select matching indices if sequence
  923. norm: Apply norm layer to all intermediates
  924. stop_early: Stop iterating over blocks when last desired intermediate hit
  925. output_fmt: Shape of intermediate feature outputs
  926. intermediates_only: Only return intermediate features
  927. Returns:
  928. """
  929. assert output_fmt in ('NCHW',), 'Output format must be NCHW.'
  930. intermediates = []
  931. take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
  932. take_indices = [self.stage_ends[i] for i in take_indices]
  933. max_index = self.stage_ends[max_index]
  934. # forward pass
  935. B, _, height, width = x.shape
  936. x = self.patch_embed(x).permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C
  937. # step2: tokens learning in the two stages
  938. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  939. network = self.network
  940. else:
  941. network = self.network[:max_index + 1]
  942. for idx, block in enumerate(network):
  943. if idx == 2:
  944. # add positional encoding after outlooker blocks
  945. x = x + self.pos_embed
  946. x = self.pos_drop(x)
  947. if self.grad_checkpointing and not torch.jit.is_scripting():
  948. x = checkpoint(block, x)
  949. else:
  950. x = block(x)
  951. if idx in take_indices:
  952. if norm and idx >= 2:
  953. x_inter = self.norm(x)
  954. else:
  955. x_inter = x
  956. intermediates.append(x_inter.permute(0, 3, 1, 2))
  957. if intermediates_only:
  958. return intermediates
  959. # NOTE not supporting return of class tokens
  960. # step3: post network, apply class attention or not
  961. B, H, W, C = x.shape
  962. x = x.reshape(B, -1, C)
  963. if self.post_network is not None:
  964. x = self.forward_cls(x)
  965. x = self.norm(x)
  966. return x, intermediates
  967. def prune_intermediate_layers(
  968. self,
  969. indices: Union[int, List[int]] = 1,
  970. prune_norm: bool = False,
  971. prune_head: bool = True,
  972. ) -> List[int]:
  973. """Prune layers not required for specified intermediates.
  974. Args:
  975. indices: Indices of intermediate layers to keep.
  976. prune_norm: Whether to prune normalization layer.
  977. prune_head: Whether to prune classification head.
  978. Returns:
  979. List of kept intermediate indices.
  980. """
  981. """ Prune layers not required for specified intermediates.
  982. """
  983. take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
  984. max_index = self.stage_ends[max_index]
  985. self.network = self.network[:max_index + 1] # truncate blocks
  986. if prune_norm:
  987. self.norm = nn.Identity()
  988. if prune_head:
  989. self.post_network = nn.ModuleList() # prune token blocks with head
  990. self.reset_classifier(0, '')
  991. return take_indices
  992. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  993. """Forward pass through feature extraction.
  994. Args:
  995. x: Input tensor of shape (B, C, H, W).
  996. Returns:
  997. Feature tensor.
  998. """
  999. x = self.patch_embed(x).permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C
  1000. # step2: tokens learning in the two stages
  1001. x = self.forward_tokens(x)
  1002. # step3: post network, apply class attention or not
  1003. if self.post_network is not None:
  1004. x = self.forward_cls(x)
  1005. x = self.norm(x)
  1006. return x
  1007. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  1008. """Forward pass through classification head.
  1009. Args:
  1010. x: Input feature tensor.
  1011. pre_logits: Whether to return pre-logits features.
  1012. Returns:
  1013. Classification logits or pre-logits features.
  1014. """
  1015. if self.global_pool == 'avg':
  1016. out = x.mean(dim=1)
  1017. elif self.global_pool == 'token':
  1018. out = x[:, 0]
  1019. else:
  1020. out = x
  1021. x = self.head_drop(x)
  1022. if pre_logits:
  1023. return out
  1024. out = self.head(out)
  1025. if self.aux_head is not None:
  1026. # generate classes in all feature tokens, see token labeling
  1027. aux = self.aux_head(x[:, 1:])
  1028. out = out + 0.5 * aux.max(1)[0]
  1029. return out
  1030. def forward(self, x: torch.Tensor) -> torch.Tensor:
  1031. """Forward pass (simplified, without mix token training).
  1032. Args:
  1033. x: Input tensor of shape (B, C, H, W).
  1034. Returns:
  1035. Classification logits.
  1036. """
  1037. """ simplified forward (without mix token training) """
  1038. x = self.forward_features(x)
  1039. x = self.forward_head(x)
  1040. return x
  1041. def _create_volo(variant: str, pretrained: bool = False, **kwargs: Any) -> VOLO:
  1042. """Create VOLO model.
  1043. Args:
  1044. variant: Model variant name.
  1045. pretrained: Whether to load pretrained weights.
  1046. **kwargs: Additional model arguments.
  1047. Returns:
  1048. VOLO model instance.
  1049. """
  1050. out_indices = kwargs.pop('out_indices', 3)
  1051. return build_model_with_cfg(
  1052. VOLO,
  1053. variant,
  1054. pretrained,
  1055. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  1056. **kwargs,
  1057. )
  1058. def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
  1059. """Create model configuration.
  1060. Args:
  1061. url: URL for pretrained weights.
  1062. **kwargs: Additional configuration options.
  1063. Returns:
  1064. Model configuration dictionary.
  1065. """
  1066. return {
  1067. 'url': url,
  1068. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  1069. 'crop_pct': .96, 'interpolation': 'bicubic', 'fixed_input_size': True,
  1070. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  1071. 'first_conv': 'patch_embed.conv.0', 'classifier': ('head', 'aux_head'),
  1072. 'license': 'apache-2.0',
  1073. **kwargs
  1074. }
# Pretrained-weight configs, keyed '<model_name>.<tag>'. Each entry starts from the
# `_cfg` defaults (224x224 input, crop_pct 0.96) and sets the checkpoint URL. Variants
# trained at a higher resolution also override input_size and crop_pct.
# NOTE(review): the 'sail_in1k' tag and github.com/sail-sg URLs suggest these are the
# original authors' ImageNet-1k releases; confirm before documenting further.
default_cfgs = generate_default_cfgs({
    'volo_d1_224.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d1_224_84.2.pth.tar',
        crop_pct=0.96),
    'volo_d1_384.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d1_384_85.2.pth.tar',
        crop_pct=1.0, input_size=(3, 384, 384)),
    'volo_d2_224.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d2_224_85.2.pth.tar',
        crop_pct=0.96),
    'volo_d2_384.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d2_384_86.0.pth.tar',
        crop_pct=1.0, input_size=(3, 384, 384)),
    'volo_d3_224.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d3_224_85.4.pth.tar',
        crop_pct=0.96),
    'volo_d3_448.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d3_448_86.3.pth.tar',
        crop_pct=1.0, input_size=(3, 448, 448)),
    'volo_d4_224.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d4_224_85.7.pth.tar',
        crop_pct=0.96),
    'volo_d4_448.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d4_448_86.79.pth.tar',
        crop_pct=1.15, input_size=(3, 448, 448)),
    'volo_d5_224.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_224_86.10.pth.tar',
        crop_pct=0.96),
    'volo_d5_448.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_448_87.0.pth.tar',
        crop_pct=1.15, input_size=(3, 448, 448)),
    'volo_d5_512.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_512_87.07.pth.tar',
        crop_pct=1.15, input_size=(3, 512, 512)),
})
  1121. @register_model
  1122. def volo_d1_224(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1123. """VOLO-D1 model, Params: 27M."""
  1124. model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
  1125. model = _create_volo('volo_d1_224', pretrained=pretrained, **model_args)
  1126. return model
  1127. @register_model
  1128. def volo_d1_384(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1129. """VOLO-D1 model, Params: 27M."""
  1130. model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
  1131. model = _create_volo('volo_d1_384', pretrained=pretrained, **model_args)
  1132. return model
  1133. @register_model
  1134. def volo_d2_224(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1135. """VOLO-D2 model, Params: 59M."""
  1136. model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
  1137. model = _create_volo('volo_d2_224', pretrained=pretrained, **model_args)
  1138. return model
  1139. @register_model
  1140. def volo_d2_384(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1141. """VOLO-D2 model, Params: 59M."""
  1142. model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
  1143. model = _create_volo('volo_d2_384', pretrained=pretrained, **model_args)
  1144. return model
  1145. @register_model
  1146. def volo_d3_224(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1147. """VOLO-D3 model, Params: 86M."""
  1148. model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
  1149. model = _create_volo('volo_d3_224', pretrained=pretrained, **model_args)
  1150. return model
  1151. @register_model
  1152. def volo_d3_448(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1153. """VOLO-D3 model, Params: 86M."""
  1154. model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
  1155. model = _create_volo('volo_d3_448', pretrained=pretrained, **model_args)
  1156. return model
  1157. @register_model
  1158. def volo_d4_224(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1159. """VOLO-D4 model, Params: 193M."""
  1160. model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
  1161. model = _create_volo('volo_d4_224', pretrained=pretrained, **model_args)
  1162. return model
  1163. @register_model
  1164. def volo_d4_448(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1165. """VOLO-D4 model, Params: 193M."""
  1166. model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
  1167. model = _create_volo('volo_d4_448', pretrained=pretrained, **model_args)
  1168. return model
  1169. @register_model
  1170. def volo_d5_224(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1171. """VOLO-D5 model, Params: 296M.
  1172. stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5.
  1173. """
  1174. model_args = dict(
  1175. layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16),
  1176. mlp_ratio=4, stem_hidden_dim=128, **kwargs)
  1177. model = _create_volo('volo_d5_224', pretrained=pretrained, **model_args)
  1178. return model
  1179. @register_model
  1180. def volo_d5_448(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1181. """VOLO-D5 model, Params: 296M.
  1182. stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5.
  1183. """
  1184. model_args = dict(
  1185. layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16),
  1186. mlp_ratio=4, stem_hidden_dim=128, **kwargs)
  1187. model = _create_volo('volo_d5_448', pretrained=pretrained, **model_args)
  1188. return model
  1189. @register_model
  1190. def volo_d5_512(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1191. """VOLO-D5 model, Params: 296M.
  1192. stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5.
  1193. """
  1194. model_args = dict(
  1195. layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16),
  1196. mlp_ratio=4, stem_hidden_dim=128, **kwargs)
  1197. model = _create_volo('volo_d5_512', pretrained=pretrained, **model_args)
  1198. return model