csatv2.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871
  1. """CSATv2
  2. A frequency-domain vision model using DCT transforms with spatial attention.
  3. Paper: TBD
  4. This model was created by members of MLPA Lab. Feedback, suggestions, and questions are welcome.
  5. gusdlf93@naver.com
  6. juno.demie.oh@gmail.com
  7. Refined for timm by Ross Wightman
  8. """
  9. import math
  10. import warnings
  11. from functools import partial, reduce
  12. from typing import List, Optional, Tuple, Union
  13. import numpy as np
  14. import torch
  15. import torch.nn as nn
  16. import torch.nn.functional as F
  17. from timm.layers import trunc_normal_, DropPath, Mlp, LayerNorm2d, Attention, NormMlpClassifierHead, LayerScale, LayerScale2d
  18. from timm.layers.grn import GlobalResponseNorm
  19. from timm.models._builder import build_model_with_cfg
  20. from timm.models._features import feature_take_indices
  21. from timm.models._manipulate import checkpoint, checkpoint_seq
  22. from ._registry import register_model, generate_default_cfgs
# Names exported by `from <module> import *`.
__all__ = ['CSATv2', 'csatv2']

# DCT frequency normalization statistics (Y, Cb, Cr channels x 64 coefficients)
#
# _DCT_MEAN holds one row per colour channel (Y, Cb, Cr, in that order) and one
# column per DCT coefficient. LearnableDct2d copies the table into its `mean`
# buffer and subtracts it from the coefficients *after* they have been
# reordered with the zigzag permutation, so column j belongs to the j-th
# coefficient in zigzag order.
# NOTE(review): presumably measured on 8x8 block DCTs of [0, 255]-range YCbCr
# images (that is what LearnableDct2d feeds in) - confirm with the authors.
_DCT_MEAN = (
    (932.42657, -0.00260, 0.33415, -0.02840, 0.00003, -0.02792, -0.00183, 0.00006,
     0.00032, 0.03402, -0.00571, 0.00020, 0.00006, -0.00038, -0.00558, -0.00116,
     -0.00000, -0.00047, -0.00008, -0.00030, 0.00942, 0.00161, -0.00009, -0.00006,
     -0.00014, -0.00035, 0.00001, -0.00220, 0.00033, -0.00002, -0.00003, -0.00020,
     0.00007, -0.00000, 0.00005, 0.00293, -0.00004, 0.00006, 0.00019, 0.00004,
     0.00006, -0.00015, -0.00002, 0.00007, 0.00010, -0.00004, 0.00008, 0.00000,
     0.00008, -0.00001, 0.00015, 0.00002, 0.00007, 0.00003, 0.00004, -0.00001,
     0.00004, -0.00000, 0.00002, -0.00000, -0.00008, -0.00000, -0.00003, 0.00003),
    (962.34735, -0.00428, 0.09835, 0.00152, -0.00009, 0.00312, -0.00141, -0.00001,
     -0.00013, 0.01050, 0.00065, 0.00006, -0.00000, 0.00003, 0.00264, 0.00000,
     0.00001, 0.00007, -0.00006, 0.00003, 0.00341, 0.00163, 0.00004, 0.00003,
     -0.00001, 0.00008, -0.00000, 0.00090, 0.00018, -0.00006, -0.00001, 0.00007,
     -0.00003, -0.00001, 0.00006, 0.00084, -0.00000, -0.00001, 0.00000, 0.00004,
     -0.00001, -0.00002, 0.00000, 0.00001, 0.00002, 0.00001, 0.00004, 0.00011,
     0.00000, -0.00003, 0.00011, -0.00002, 0.00001, 0.00001, 0.00001, 0.00001,
     -0.00007, -0.00003, 0.00001, 0.00000, 0.00001, 0.00002, 0.00001, 0.00000),
    (1053.16101, -0.00213, -0.09207, 0.00186, 0.00013, 0.00034, -0.00119, 0.00002,
     0.00011, -0.00984, 0.00046, -0.00007, -0.00001, -0.00005, 0.00180, 0.00042,
     0.00002, -0.00010, 0.00004, 0.00003, -0.00301, 0.00125, -0.00002, -0.00003,
     -0.00001, -0.00001, -0.00001, 0.00056, 0.00021, 0.00001, -0.00001, 0.00002,
     -0.00001, -0.00001, 0.00005, -0.00070, -0.00002, -0.00002, 0.00005, -0.00004,
     -0.00000, 0.00002, -0.00002, 0.00001, 0.00000, -0.00003, 0.00004, 0.00007,
     0.00001, 0.00000, 0.00013, -0.00000, 0.00000, 0.00002, -0.00000, -0.00001,
     -0.00004, -0.00003, 0.00000, 0.00001, -0.00001, 0.00001, -0.00000, 0.00000),
)
# Per-coefficient variances, laid out as three rows (Y, Cb, Cr) of 64 columns.
# LearnableDct2d copies the table into its `var` buffer and divides the
# zigzag-ordered, mean-centred DCT coefficients by `var ** 0.5 + 1e-8`.
_DCT_VAR = (
    (270372.37500, 6287.10645, 5974.94043, 1653.10889, 1463.91748, 1832.58997, 755.92468, 692.41528,
     648.57184, 641.46881, 285.79288, 301.62100, 380.43405, 349.84027, 374.15891, 190.30960,
     190.76746, 221.64578, 200.82646, 145.87979, 126.92046, 62.14622, 67.75562, 102.42001,
     129.74922, 130.04631, 103.12189, 97.76417, 53.17402, 54.81048, 73.48712, 81.04342,
     69.35100, 49.06024, 33.96053, 37.03279, 20.48858, 24.94830, 33.90822, 44.54912,
     47.56363, 40.03160, 30.43313, 22.63899, 26.53739, 26.57114, 21.84404, 17.41557,
     15.18253, 10.69678, 11.24111, 12.97229, 15.08971, 15.31646, 8.90409, 7.44213,
     6.66096, 6.97719, 4.17834, 3.83882, 4.51073, 2.36646, 2.41363, 1.48266),
    (18839.21094, 321.70932, 300.15259, 77.47830, 76.02293, 89.04748, 33.99642, 34.74807,
     32.12333, 28.19588, 12.04675, 14.26871, 18.45779, 16.59588, 15.67892, 7.37718,
     8.56312, 10.28946, 9.41013, 6.69090, 5.16453, 2.55186, 3.03073, 4.66765,
     5.85418, 5.74644, 4.33702, 3.66948, 1.95107, 2.26034, 3.06380, 3.50705,
     3.06359, 2.19284, 1.54454, 1.57860, 0.97078, 1.13941, 1.48653, 1.89996,
     1.95544, 1.64950, 1.24754, 0.93677, 1.09267, 1.09516, 0.94163, 0.78966,
     0.72489, 0.50841, 0.50909, 0.55664, 0.63111, 0.64125, 0.38847, 0.33378,
     0.30918, 0.33463, 0.20875, 0.19298, 0.21903, 0.13380, 0.13444, 0.09554),
    (17127.39844, 292.81421, 271.45209, 66.64056, 63.60253, 76.35437, 28.06587, 27.84831,
     25.96656, 23.60370, 9.99173, 11.34992, 14.46955, 12.92553, 12.69353, 5.91537,
     6.60187, 7.90891, 7.32825, 5.32785, 4.29660, 2.13459, 2.44135, 3.66021,
     4.50335, 4.38959, 3.34888, 2.97181, 1.60633, 1.77010, 2.35118, 2.69018,
     2.38189, 1.74596, 1.26014, 1.31684, 0.79327, 0.92046, 1.17670, 1.47609,
     1.50914, 1.28725, 0.99898, 0.74832, 0.85736, 0.85800, 0.74663, 0.63508,
     0.58748, 0.41098, 0.41121, 0.44663, 0.50277, 0.51519, 0.31729, 0.27336,
     0.25399, 0.27241, 0.17353, 0.16255, 0.18440, 0.11602, 0.11511, 0.08450),
)
  77. def _zigzag_permutation(rows: int, cols: int) -> List[int]:
  78. """Generate zigzag scan order for DCT coefficients."""
  79. idx_matrix = np.arange(0, rows * cols, 1).reshape(rows, cols).tolist()
  80. dia = [[] for _ in range(rows + cols - 1)]
  81. zigzag = []
  82. for i in range(rows):
  83. for j in range(cols):
  84. s = i + j
  85. if s % 2 == 0:
  86. dia[s].insert(0, idx_matrix[i][j])
  87. else:
  88. dia[s].append(idx_matrix[i][j])
  89. for d in dia:
  90. zigzag.extend(d)
  91. return zigzag
  92. def _dct_kernel_type_2(
  93. kernel_size: int,
  94. orthonormal: bool,
  95. device=None,
  96. dtype=None,
  97. ) -> torch.Tensor:
  98. """Generate Type-II DCT kernel matrix."""
  99. dd = dict(device=device, dtype=dtype)
  100. x = torch.eye(kernel_size, **dd)
  101. v = x.clone().contiguous().view(-1, kernel_size)
  102. v = torch.cat([v, v.flip([1])], dim=-1)
  103. v = torch.fft.fft(v, dim=-1)[:, :kernel_size]
  104. k = (
  105. torch.tensor(-1j, device=device, dtype=torch.complex64) * torch.pi
  106. * torch.arange(kernel_size, device=device, dtype=torch.long)[None, :]
  107. )
  108. k = torch.exp(k / (kernel_size * 2))
  109. v = v * k
  110. v = v.real
  111. if orthonormal:
  112. v[:, 0] = v[:, 0] * torch.sqrt(torch.tensor(1 / (kernel_size * 4), **dd))
  113. v[:, 1:] = v[:, 1:] * torch.sqrt(torch.tensor(1 / (kernel_size * 2), **dd))
  114. v = v.contiguous().view(*x.shape)
  115. return v
  116. def _dct_kernel_type_3(
  117. kernel_size: int,
  118. orthonormal: bool,
  119. device=None,
  120. dtype=None,
  121. ) -> torch.Tensor:
  122. """Generate Type-III DCT kernel matrix (inverse of Type-II)."""
  123. return torch.linalg.inv(_dct_kernel_type_2(kernel_size, orthonormal, device, dtype))
  124. class Dct1d(nn.Module):
  125. """1D Discrete Cosine Transform layer."""
  126. def __init__(
  127. self,
  128. kernel_size: int,
  129. kernel_type: int = 2,
  130. orthonormal: bool = True,
  131. device=None,
  132. dtype=None,
  133. ) -> None:
  134. dd = dict(device=device, dtype=dtype)
  135. super().__init__()
  136. kernel = {'2': _dct_kernel_type_2, '3': _dct_kernel_type_3}
  137. dct_weights = kernel[f'{kernel_type}'](kernel_size, orthonormal, **dd).T
  138. self.register_buffer('weights', dct_weights.contiguous())
  139. self.register_parameter('bias', None)
  140. def forward(self, x: torch.Tensor) -> torch.Tensor:
  141. return F.linear(x, self.weights, self.bias)
  142. class Dct2d(nn.Module):
  143. """2D Discrete Cosine Transform layer."""
  144. def __init__(
  145. self,
  146. kernel_size: int,
  147. kernel_type: int = 2,
  148. orthonormal: bool = True,
  149. device=None,
  150. dtype=None,
  151. ) -> None:
  152. dd = dict(device=device, dtype=dtype)
  153. super().__init__()
  154. self.transform = Dct1d(kernel_size, kernel_type, orthonormal, **dd)
  155. def forward(self, x: torch.Tensor) -> torch.Tensor:
  156. return self.transform(self.transform(x).transpose(-1, -2)).transpose(-1, -2)
  157. def _split_out_chs(out_chs: int, ratio=(24, 4, 4)):
  158. # reduce ratio to smallest integers (24,4,4) -> (6,1,1)
  159. g = reduce(math.gcd, ratio)
  160. r = tuple(x // g for x in ratio)
  161. denom = sum(r)
  162. assert out_chs % denom == 0 and out_chs >= denom, (
  163. f"out_chs={out_chs} can't be split into Y/Cb/Cr with ratio {ratio} "
  164. f"(reduced {r}); out_chs must be a multiple of {denom}."
  165. )
  166. unit = out_chs // denom
  167. y, cb, cr = (ri * unit for ri in r)
  168. assert y + cb + cr == out_chs and min(y, cb, cr) > 0
  169. return y, cb, cr
  170. class LearnableDct2d(nn.Module):
  171. """Learnable 2D DCT stem with RGB to YCbCr conversion and frequency selection."""
  172. def __init__(
  173. self,
  174. kernel_size: int,
  175. kernel_type: int = 2,
  176. orthonormal: bool = True,
  177. out_chs: int = 32,
  178. device=None,
  179. dtype=None,
  180. ) -> None:
  181. dd = dict(device=device, dtype=dtype)
  182. super().__init__()
  183. self.k = kernel_size
  184. self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **dd)
  185. self.permutation = _zigzag_permutation(kernel_size, kernel_size)
  186. y_ch, cb_ch, cr_ch = _split_out_chs(out_chs, ratio=(24, 4, 4))
  187. self.conv_y = nn.Conv2d(kernel_size ** 2, y_ch, kernel_size=1, padding=0, **dd)
  188. self.conv_cb = nn.Conv2d(kernel_size ** 2, cb_ch, kernel_size=1, padding=0, **dd)
  189. self.conv_cr = nn.Conv2d(kernel_size ** 2, cr_ch, kernel_size=1, padding=0, **dd)
  190. # Register empty buffers for DCT normalization statistics
  191. self.register_buffer('mean', torch.empty(3, 64, device=device, dtype=dtype), persistent=False)
  192. self.register_buffer('var', torch.empty(3, 64, device=device, dtype=dtype), persistent=False)
  193. # Shape (3, 1, 1) for BCHW broadcasting
  194. self.register_buffer('imagenet_mean', torch.empty(3, 1, 1, device=device, dtype=dtype), persistent=False)
  195. self.register_buffer('imagenet_std', torch.empty(3, 1, 1, device=device, dtype=dtype), persistent=False)
  196. # TODO: skip init when on meta device when safe to do so
  197. self.reset_parameters()
  198. def reset_parameters(self) -> None:
  199. """Initialize buffers."""
  200. self._init_buffers()
  201. def _init_buffers(self) -> None:
  202. """Compute and fill non-persistent buffer values."""
  203. self.mean.copy_(torch.tensor(_DCT_MEAN))
  204. self.var.copy_(torch.tensor(_DCT_VAR))
  205. self.imagenet_mean.copy_(torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1))
  206. self.imagenet_std.copy_(torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1))
  207. def init_non_persistent_buffers(self) -> None:
  208. """Initialize non-persistent buffers."""
  209. self._init_buffers()
  210. def _denormalize(self, x: torch.Tensor) -> torch.Tensor:
  211. """Convert from ImageNet normalized to [0, 255] range."""
  212. return x.mul(self.imagenet_std).add_(self.imagenet_mean) * 255
  213. def _rgb_to_ycbcr(self, x: torch.Tensor) -> torch.Tensor:
  214. """Convert RGB to YCbCr color space (BCHW input/output)."""
  215. r, g, b = x[:, 0], x[:, 1], x[:, 2]
  216. y = r * 0.299 + g * 0.587 + b * 0.114
  217. cb = 0.564 * (b - y) + 128
  218. cr = 0.713 * (r - y) + 128
  219. return torch.stack([y, cb, cr], dim=1)
  220. def _frequency_normalize(self, x: torch.Tensor) -> torch.Tensor:
  221. """Normalize DCT coefficients using precomputed statistics."""
  222. std = self.var ** 0.5 + 1e-8
  223. return (x - self.mean) / std
  224. def forward(self, x: torch.Tensor) -> torch.Tensor:
  225. b, c, h, w = x.shape
  226. x = self._denormalize(x)
  227. x = self._rgb_to_ycbcr(x)
  228. # Extract non-overlapping k x k patches
  229. x = x.reshape(b, c, h // self.k, self.k, w // self.k, self.k) # (B, C, H//k, k, W//k, k)
  230. x = x.permute(0, 2, 4, 1, 3, 5) # (B, H//k, W//k, C, k, k)
  231. x = self.transform(x)
  232. x = x.reshape(-1, c, self.k * self.k)
  233. x = x[:, :, self.permutation]
  234. x = self._frequency_normalize(x)
  235. x = x.reshape(b, h // self.k, w // self.k, c, -1)
  236. x = x.permute(0, 3, 4, 1, 2).contiguous()
  237. x_y = self.conv_y(x[:, 0])
  238. x_cb = self.conv_cb(x[:, 1])
  239. x_cr = self.conv_cr(x[:, 2])
  240. return torch.cat([x_y, x_cb, x_cr], dim=1)
  241. class Dct2dStats(nn.Module):
  242. """Utility module to compute DCT coefficient statistics."""
  243. def __init__(
  244. self,
  245. kernel_size: int,
  246. kernel_type: int = 2,
  247. orthonormal: bool = True,
  248. device=None,
  249. dtype=None,
  250. ) -> None:
  251. dd = dict(device=device, dtype=dtype)
  252. super().__init__()
  253. self.k = kernel_size
  254. self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **dd)
  255. self.permutation = _zigzag_permutation(kernel_size, kernel_size)
  256. def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  257. b, c, h, w = x.shape
  258. # Extract non-overlapping k x k patches
  259. x = x.reshape(b, c, h // self.k, self.k, w // self.k, self.k) # (B, C, H//k, k, W//k, k)
  260. x = x.permute(0, 2, 4, 1, 3, 5) # (B, H//k, W//k, C, k, k)
  261. x = self.transform(x)
  262. x = x.reshape(-1, c, self.k * self.k)
  263. x = x[:, :, self.permutation]
  264. x = x.reshape(b * (h // self.k) * (w // self.k), c, -1)
  265. mean_list = torch.zeros([3, 64])
  266. var_list = torch.zeros([3, 64])
  267. for i in range(3):
  268. mean_list[i] = torch.mean(x[:, i], dim=0)
  269. var_list[i] = torch.var(x[:, i], dim=0)
  270. return mean_list, var_list
  271. class Block(nn.Module):
  272. """ConvNeXt-style block with spatial attention."""
  273. def __init__(
  274. self,
  275. dim: int,
  276. drop_path: float = 0.,
  277. ls_init_value: Optional[float] = None,
  278. device=None,
  279. dtype=None,
  280. ) -> None:
  281. dd = dict(device=device, dtype=dtype)
  282. super().__init__()
  283. self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, **dd)
  284. self.norm = nn.LayerNorm(dim, eps=1e-6, **dd)
  285. self.pwconv1 = nn.Linear(dim, 4 * dim, **dd)
  286. self.act = nn.GELU()
  287. self.grn = GlobalResponseNorm(4 * dim, channels_last=True, **dd)
  288. self.pwconv2 = nn.Linear(4 * dim, dim, **dd)
  289. self.ls = LayerScale2d(dim, init_values=ls_init_value, **dd) if ls_init_value else nn.Identity()
  290. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  291. self.attn = SpatialAttention(**dd)
  292. def forward(self, x: torch.Tensor) -> torch.Tensor:
  293. shortcut = x
  294. x = self.dwconv(x)
  295. x = x.permute(0, 2, 3, 1)
  296. x = self.norm(x)
  297. x = self.pwconv1(x)
  298. x = self.act(x)
  299. x = self.grn(x)
  300. x = self.pwconv2(x)
  301. x = x.permute(0, 3, 1, 2)
  302. attn = self.attn(x)
  303. attn = F.interpolate(attn, size=x.shape[2:], mode='bilinear', align_corners=True)
  304. x = x * attn
  305. x = self.ls(x)
  306. return shortcut + self.drop_path(x)
  307. class SpatialTransformerBlock(nn.Module):
  308. """Lightweight transformer block for spatial attention (1-channel, 7x7 grid).
  309. This is a simplified transformer with single-head, 1-dim attention over spatial
  310. positions. Used inside SpatialAttention where input is 1 channel at 7x7 resolution.
  311. """
  312. def __init__(
  313. self,
  314. device=None,
  315. dtype=None,
  316. ) -> None:
  317. dd = dict(device=device, dtype=dtype)
  318. super().__init__()
  319. # Single-head attention with 1-dim q/k/v (no output projection needed)
  320. self.pos_embed = PosConv(in_chans=1, **dd)
  321. self.norm1 = nn.LayerNorm(1, **dd)
  322. self.qkv = nn.Linear(1, 3, bias=False, **dd)
  323. # Feedforward: 1 -> 4 -> 1
  324. self.norm2 = nn.LayerNorm(1, **dd)
  325. self.mlp = Mlp(1, 4, 1, act_layer=nn.GELU, **dd)
  326. def forward(self, x: torch.Tensor) -> torch.Tensor:
  327. B, C, H, W = x.shape
  328. # Attention block
  329. shortcut = x
  330. x_t = x.flatten(2).transpose(1, 2) # (B, N, 1)
  331. x_t = self.norm1(x_t)
  332. x_t = self.pos_embed(x_t, (H, W))
  333. # Simple single-head attention with scalar q/k/v
  334. qkv = self.qkv(x_t) # (B, N, 3)
  335. q, k, v = qkv.unbind(-1) # each (B, N)
  336. attn = (q @ k.transpose(-1, -2)).softmax(dim=-1) # (B, N, N)
  337. x_t = (attn @ v).unsqueeze(-1) # (B, N, 1)
  338. x_t = x_t.transpose(1, 2).reshape(B, C, H, W)
  339. x = shortcut + x_t
  340. # Feedforward block
  341. shortcut = x
  342. x_t = x.flatten(2).transpose(1, 2)
  343. x_t = self.mlp(self.norm2(x_t))
  344. x_t = x_t.transpose(1, 2).reshape(B, C, H, W)
  345. x = shortcut + x_t
  346. return x
  347. class SpatialAttention(nn.Module):
  348. """Spatial attention module using channel statistics and transformer."""
  349. def __init__(
  350. self,
  351. device=None,
  352. dtype=None,
  353. ) -> None:
  354. dd = dict(device=device, dtype=dtype)
  355. super().__init__()
  356. self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
  357. self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, **dd)
  358. self.attn = SpatialTransformerBlock(**dd)
  359. def forward(self, x: torch.Tensor) -> torch.Tensor:
  360. x_avg = x.mean(dim=1, keepdim=True)
  361. x_max = x.amax(dim=1, keepdim=True)
  362. x = torch.cat([x_avg, x_max], dim=1)
  363. x = self.avgpool(x)
  364. x = self.conv(x)
  365. x = self.attn(x)
  366. return x
  367. class TransformerBlock(nn.Module):
  368. """Transformer block with optional downsampling and convolutional position encoding."""
  369. def __init__(
  370. self,
  371. inp: int,
  372. oup: int,
  373. num_heads: int = 8,
  374. attn_head_dim: int = 32,
  375. downsample: bool = False,
  376. attn_drop: float = 0.,
  377. proj_drop: float = 0.,
  378. drop_path: float = 0.,
  379. ls_init_value: Optional[float] = None,
  380. device=None,
  381. dtype=None,
  382. ) -> None:
  383. dd = dict(device=device, dtype=dtype)
  384. super().__init__()
  385. hidden_dim = int(inp * 4)
  386. self.downsample = downsample
  387. if self.downsample:
  388. self.pool1 = nn.MaxPool2d(3, 2, 1)
  389. self.pool2 = nn.MaxPool2d(3, 2, 1)
  390. self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False, **dd)
  391. else:
  392. self.pool1 = nn.Identity()
  393. self.pool2 = nn.Identity()
  394. self.proj = nn.Identity()
  395. self.pos_embed = PosConv(in_chans=inp, **dd)
  396. self.norm1 = nn.LayerNorm(inp, **dd)
  397. self.attn = Attention(
  398. dim=inp,
  399. num_heads=num_heads,
  400. attn_head_dim=attn_head_dim,
  401. dim_out=oup,
  402. attn_drop=attn_drop,
  403. proj_drop=proj_drop,
  404. **dd,
  405. )
  406. self.ls1 = LayerScale(oup, init_values=ls_init_value, **dd) if ls_init_value else nn.Identity()
  407. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  408. self.norm2 = nn.LayerNorm(oup, **dd)
  409. self.mlp = Mlp(oup, hidden_dim, oup, act_layer=nn.GELU, drop=proj_drop, **dd)
  410. self.ls2 = LayerScale(oup, init_values=ls_init_value, **dd) if ls_init_value else nn.Identity()
  411. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  412. def forward(self, x: torch.Tensor) -> torch.Tensor:
  413. if self.downsample:
  414. shortcut = self.proj(self.pool1(x))
  415. x_t = self.pool2(x)
  416. B, C, H, W = x_t.shape
  417. x_t = x_t.flatten(2).transpose(1, 2)
  418. x_t = self.norm1(x_t)
  419. x_t = self.pos_embed(x_t, (H, W))
  420. x_t = self.ls1(self.attn(x_t))
  421. x_t = x_t.transpose(1, 2).reshape(B, -1, H, W)
  422. x = shortcut + self.drop_path1(x_t)
  423. else:
  424. B, C, H, W = x.shape
  425. shortcut = x
  426. x_t = x.flatten(2).transpose(1, 2)
  427. x_t = self.norm1(x_t)
  428. x_t = self.pos_embed(x_t, (H, W))
  429. x_t = self.ls1(self.attn(x_t))
  430. x_t = x_t.transpose(1, 2).reshape(B, -1, H, W)
  431. x = shortcut + self.drop_path1(x_t)
  432. # MLP block
  433. B, C, H, W = x.shape
  434. shortcut = x
  435. x_t = x.flatten(2).transpose(1, 2)
  436. x_t = self.ls2(self.mlp(self.norm2(x_t)))
  437. x_t = x_t.transpose(1, 2).reshape(B, C, H, W)
  438. x = shortcut + self.drop_path2(x_t)
  439. return x
  440. class PosConv(nn.Module):
  441. """Convolutional position encoding."""
  442. def __init__(
  443. self,
  444. in_chans: int,
  445. device=None,
  446. dtype=None,
  447. ) -> None:
  448. dd = dict(device=device, dtype=dtype)
  449. super().__init__()
  450. self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans, **dd)
  451. def forward(self, x: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
  452. B, N, C = x.shape
  453. H, W = size
  454. cnn_feat = x.transpose(1, 2).view(B, C, H, W)
  455. x = self.proj(cnn_feat) + cnn_feat
  456. return x.flatten(2).transpose(1, 2)
  457. class CSATv2(nn.Module):
  458. """CSATv2: Frequency-domain vision model with spatial attention.
  459. A hybrid architecture that processes images in the DCT frequency domain
  460. with ConvNeXt-style blocks and transformer attention.
  461. """
  462. def __init__(
  463. self,
  464. num_classes: int = 1000,
  465. in_chans: int = 3,
  466. dims: Tuple[int, ...] = (32, 72, 168, 386),
  467. depths: Tuple[int, ...] = (2, 2, 8, 6),
  468. transformer_depths: Tuple[int, ...] = (0, 0, 2, 2),
  469. drop_path_rate: float = 0.0,
  470. transformer_drop_path: bool = False,
  471. ls_init_value: Optional[float] = None,
  472. global_pool: str = 'avg',
  473. device=None,
  474. dtype=None,
  475. **kwargs,
  476. ) -> None:
  477. dd = dict(device=device, dtype=dtype)
  478. super().__init__()
  479. if in_chans != 3:
  480. warnings.warn(
  481. f'CSATv2 is designed for 3-channel RGB input. '
  482. f'in_chans={in_chans} may not work correctly with the DCT stem.'
  483. )
  484. self.num_classes = num_classes
  485. self.in_chans = in_chans
  486. self.global_pool = global_pool
  487. self.grad_checkpointing = False
  488. self.num_features = dims[-1]
  489. self.head_hidden_size = self.num_features
  490. # Build feature_info dynamically
  491. self.feature_info = [dict(num_chs=dims[0], reduction=8, module='stem_dct')]
  492. reduction = 8
  493. for i, dim in enumerate(dims):
  494. if i > 0:
  495. reduction *= 2
  496. self.feature_info.append(dict(num_chs=dim, reduction=reduction, module=f'stages.{i}'))
  497. # Build drop path rates for all blocks (0 for transformer blocks when transformer_drop_path=False)
  498. total_blocks = sum(depths) if transformer_drop_path else sum(d - t for d, t in zip(depths, transformer_depths))
  499. dp_iter = iter(torch.linspace(0, drop_path_rate, total_blocks).tolist())
  500. dp_rates = []
  501. for depth, t_depth in zip(depths, transformer_depths):
  502. dp_rates += [next(dp_iter) for _ in range(depth - t_depth)]
  503. dp_rates += [next(dp_iter) if transformer_drop_path else 0. for _ in range(t_depth)]
  504. self.stem_dct = LearnableDct2d(8, out_chs=dims[0], **dd)
  505. # Build stages dynamically
  506. dp_iter = iter(dp_rates)
  507. stages = []
  508. for i, (dim, depth, t_depth) in enumerate(zip(dims, depths, transformer_depths)):
  509. layers = (
  510. # Downsample at start of stage (except first stage)
  511. ([nn.Conv2d(dims[i - 1], dim, kernel_size=2, stride=2, **dd)] if i > 0 else []) +
  512. # Conv blocks
  513. [Block(dim=dim, drop_path=next(dp_iter), ls_init_value=ls_init_value, **dd) for _ in range(depth - t_depth)] +
  514. # Transformer blocks at end of stage
  515. [TransformerBlock(inp=dim, oup=dim, drop_path=next(dp_iter), ls_init_value=ls_init_value, **dd) for _ in range(t_depth)] +
  516. # Trailing LayerNorm (except last stage)
  517. ([LayerNorm2d(dim, eps=1e-6, **dd)] if i < len(depths) - 1 else [])
  518. )
  519. stages.append(nn.Sequential(*layers))
  520. self.stages = nn.Sequential(*stages)
  521. self.head = NormMlpClassifierHead(dims[-1], num_classes, pool_type=global_pool, **dd)
  522. # TODO: skip init when on meta device when safe to do so
  523. self.init_weights(needs_reset=False)
  524. def init_weights(self, needs_reset: bool = True):
  525. self.apply(partial(self._init_weights, needs_reset=needs_reset))
  526. def _init_weights(self, m: nn.Module, needs_reset: bool = True) -> None:
  527. if isinstance(m, (nn.Conv2d, nn.Linear)):
  528. trunc_normal_(m.weight, std=0.02)
  529. if m.bias is not None:
  530. nn.init.constant_(m.bias, 0)
  531. elif needs_reset and hasattr(m, 'reset_parameters'):
  532. m.reset_parameters()
  533. @torch.jit.ignore
  534. def get_classifier(self) -> nn.Module:
  535. return self.head.fc
  536. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
  537. self.num_classes = num_classes
  538. if global_pool is not None:
  539. self.global_pool = global_pool
  540. self.head.reset(num_classes, pool_type=global_pool)
  541. @torch.jit.ignore
  542. def set_grad_checkpointing(self, enable: bool = True) -> None:
  543. self.grad_checkpointing = enable
  544. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  545. x = self.stem_dct(x)
  546. if self.grad_checkpointing and not torch.jit.is_scripting():
  547. x = checkpoint_seq(self.stages, x)
  548. else:
  549. x = self.stages(x)
  550. return x
  551. def forward_intermediates(
  552. self,
  553. x: torch.Tensor,
  554. indices: Optional[Union[int, List[int]]] = None,
  555. norm: bool = False,
  556. stop_early: bool = False,
  557. output_fmt: str = 'NCHW',
  558. intermediates_only: bool = False,
  559. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  560. """Forward pass returning intermediate features.
  561. Args:
  562. x: Input image tensor.
  563. indices: Indices of features to return (0=stem_dct, 1-4=stages). None returns all.
  564. norm: Apply norm layer to final intermediate (unused, for API compat).
  565. stop_early: Stop iterating when last desired intermediate is reached.
  566. output_fmt: Output format, must be 'NCHW'.
  567. intermediates_only: Only return intermediate features.
  568. Returns:
  569. List of intermediate features or tuple of (final features, intermediates).
  570. """
  571. assert output_fmt == 'NCHW', 'Output format must be NCHW.'
  572. intermediates = []
  573. # 5 feature levels: stem_dct (0) + stages 0-3 (1-4)
  574. take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
  575. x = self.stem_dct(x)
  576. if 0 in take_indices:
  577. intermediates.append(x)
  578. if torch.jit.is_scripting() or not stop_early:
  579. stages = self.stages
  580. else:
  581. # max_index is 0-4, stages are 1-4, so we need max_index stages
  582. stages = self.stages[:max_index] if max_index > 0 else []
  583. for feat_idx, stage in enumerate(stages):
  584. if self.grad_checkpointing and not torch.jit.is_scripting():
  585. x = checkpoint(stage, x)
  586. else:
  587. x = stage(x)
  588. if feat_idx + 1 in take_indices: # +1 because stem is index 0
  589. intermediates.append(x)
  590. if intermediates_only:
  591. return intermediates
  592. return x, intermediates
  593. def prune_intermediate_layers(
  594. self,
  595. indices: Union[int, List[int]] = 1,
  596. prune_norm: bool = False,
  597. prune_head: bool = True,
  598. ) -> List[int]:
  599. """Prune layers not required for specified intermediates.
  600. Args:
  601. indices: Indices of intermediate layers to keep (0=stem_dct, 1-4=stages).
  602. prune_norm: Whether to prune the final norm layer.
  603. prune_head: Whether to prune the classifier head.
  604. Returns:
  605. List of indices that were kept.
  606. """
  607. # 5 feature levels: stem_dct (0) + stages 0-3 (1-4)
  608. take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
  609. # max_index is 0-4, stages are 1-4, so we keep max_index stages
  610. self.stages = self.stages[:max_index] if max_index > 0 else nn.Sequential()
  611. if prune_norm:
  612. self.head.norm = nn.Identity()
  613. if prune_head:
  614. self.reset_classifier(0, '')
  615. return take_indices
  616. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  617. return self.head(x, pre_logits=pre_logits)
  618. def forward(self, x: torch.Tensor) -> torch.Tensor:
  619. x = self.forward_features(x)
  620. return self.forward_head(x)
  621. def _cfg(url='', **kwargs):
  622. return {
  623. 'url': url,
  624. 'num_classes': 1000, 'input_size': (3, 512, 512), 'pool_size': (8, 8),
  625. 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
  626. 'interpolation': 'bilinear', 'crop_pct': 1.0,
  627. 'classifier': 'head.fc', 'first_conv': [],
  628. **kwargs,
  629. }
  630. default_cfgs = generate_default_cfgs({
  631. 'csatv2.r512_in1k': _cfg(
  632. hf_hub_id='timm/',
  633. ),
  634. 'csatv2_21m.sw_r640_in1k': _cfg(
  635. hf_hub_id='timm/',
  636. input_size=(3, 640, 640),
  637. interpolation='bicubic',
  638. ),
  639. 'csatv2_21m.sw_r512_in1k': _cfg(
  640. hf_hub_id='timm/',
  641. pool_size=(10, 10),
  642. interpolation='bicubic',
  643. ),
  644. })
  645. def checkpoint_filter_fn(state_dict: dict, model: nn.Module) -> dict:
  646. """Remap original CSATv2 checkpoint to timm format.
  647. Handles two key structural changes:
  648. 1) Stage naming: stages1/2/3/4 -> stages.0/1/2/3
  649. 2) Downsample position: moved from end of stage N to start of stage N+1
  650. """
  651. if "stages.0.0.grn.weight" in state_dict:
  652. return state_dict # already in timm format
  653. import re
  654. # FIXME this downsample idx is wired to the original 'csatv2' model size
  655. downsample_idx = {1: 3, 2: 3, 3: 9} # original stage -> downsample index
  656. dct_re = re.compile(r"^dct\.")
  657. stage_re = re.compile(r"^stages([1-4])\.(\d+)\.(.*)$")
  658. head_re = re.compile(r"^head\.")
  659. norm_re = re.compile(r"^norm\.")
  660. def remap_stage(m: re.Match) -> str:
  661. stage, idx, rest = int(m.group(1)), int(m.group(2)), m.group(3)
  662. if stage in downsample_idx and idx == downsample_idx[stage]:
  663. return f"stages.{stage}.0.{rest}" # move downsample to next stage @0
  664. if stage == 1:
  665. return f"stages.0.{idx}.{rest}" # stage1 -> stages.0
  666. return f"stages.{stage - 1}.{idx + 1}.{rest}" # stage2-4 -> stages.1-3, shift +1
  667. out = {}
  668. for k, v in state_dict.items():
  669. # dct -> stem_dct, and Y/Cb/Cr conv names
  670. k = dct_re.sub("stem_dct.", k)
  671. k = (k.replace(".Y_Conv.", ".conv_y.")
  672. .replace(".Cb_Conv.", ".conv_cb.")
  673. .replace(".Cr_Conv.", ".conv_cr."))
  674. # stage remap + downsample relocation
  675. k = stage_re.sub(remap_stage, k)
  676. # GRN: gamma/beta -> weight/bias (reshape)
  677. if "grn.gamma" in k:
  678. k, v = k.replace("grn.gamma", "grn.weight"), v.reshape(-1)
  679. elif "grn.beta" in k:
  680. k, v = k.replace("grn.beta", "grn.bias"), v.reshape(-1)
  681. # FeedForward(nn.Sequential) -> Mlp + norm renames
  682. if ".ff.net.0." in k:
  683. k = k.replace(".ff.net.0.", ".mlp.fc1.")
  684. elif ".ff.net.3." in k:
  685. k = k.replace(".ff.net.3.", ".mlp.fc2.")
  686. elif ".ff_norm." in k:
  687. k = k.replace(".ff_norm.", ".norm2.")
  688. elif ".attn_norm." in k:
  689. k = k.replace(".attn_norm.", ".norm1.")
  690. # attention -> attn (handle nested first)
  691. if ".attention.attention." in k:
  692. k = (k.replace(".attention.attention.attn.to_qkv.", ".attn.attn.qkv.")
  693. .replace(".attention.attention.attn.", ".attn.attn.")
  694. .replace(".attention.attention.", ".attn.attn."))
  695. elif ".attention." in k:
  696. k = k.replace(".attention.", ".attn.")
  697. # TransformerBlock attention name remaps
  698. if ".attn.to_qkv." in k:
  699. k = k.replace(".attn.to_qkv.", ".attn.qkv.")
  700. elif ".attn.to_out.0." in k:
  701. k = k.replace(".attn.to_out.0.", ".attn.proj.")
  702. # .attn.pos_embed -> .pos_embed (but not SpatialTransformerBlock's .attn.attn.pos_embed)
  703. if ".attn.pos_embed." in k and ".attn.attn." not in k:
  704. k = k.replace(".attn.pos_embed.", ".pos_embed.")
  705. # head -> head.fc, norm -> head.norm (order matters)
  706. k = head_re.sub("head.fc.", k)
  707. k = norm_re.sub("head.norm.", k)
  708. out[k] = v
  709. return out
  710. def _create_csatv2(variant: str, pretrained: bool = False, **kwargs) -> CSATv2:
  711. out_indices = kwargs.pop('out_indices', (1, 2, 3, 4))
  712. return build_model_with_cfg(
  713. CSATv2,
  714. variant,
  715. pretrained,
  716. pretrained_filter_fn=checkpoint_filter_fn,
  717. feature_cfg=dict(out_indices=out_indices, flatten_sequential=True),
  718. default_cfg=default_cfgs[variant],
  719. **kwargs,
  720. )
  721. @register_model
  722. def csatv2(pretrained: bool = False, **kwargs) -> CSATv2:
  723. return _create_csatv2('csatv2', pretrained, **kwargs)
  724. @register_model
  725. def csatv2_21m(pretrained: bool = False, **kwargs) -> CSATv2:
  726. # experimental ~20-21M param larger model to validate flexible arch spec
  727. model_args = dict(
  728. dims = (48, 96, 224, 448),
  729. depths = (3, 3, 10, 8),
  730. transformer_depths = (0, 0, 4, 3)
  731. )
  732. return _create_csatv2('csatv2_21m', pretrained, **dict(model_args, **kwargs))