coat.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844
  1. """
  2. CoaT architecture.
  3. Paper: Co-Scale Conv-Attentional Image Transformers - https://arxiv.org/abs/2104.06399
  4. Official CoaT code at: https://github.com/mlpc-ucsd/CoaT
  5. Modified from timm/models/vision_transformer.py
  6. """
  7. from typing import List, Optional, Tuple, Union, Type, Any
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  12. from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, LayerNorm
  13. from ._builder import build_model_with_cfg
  14. from ._registry import register_model, generate_default_cfgs
# Names exported by `from <module> import *`; only the model class is listed here.
__all__ = ['CoaT']
  16. class ConvRelPosEnc(nn.Module):
  17. """ Convolutional relative position encoding. """
  18. def __init__(
  19. self,
  20. head_chs: int,
  21. num_heads: int,
  22. window: Union[int, dict],
  23. device=None,
  24. dtype=None,
  25. ):
  26. """
  27. Initialization.
  28. Ch: Channels per head.
  29. h: Number of heads.
  30. window: Window size(s) in convolutional relative positional encoding. It can have two forms:
  31. 1. An integer of window size, which assigns all attention heads with the same window s
  32. size in ConvRelPosEnc.
  33. 2. A dict mapping window size to #attention head splits (
  34. e.g. {window size 1: #attention head split 1, window size 2: #attention head split 2})
  35. It will apply different window size to the attention head splits.
  36. """
  37. dd = {'device': device, 'dtype': dtype}
  38. super().__init__()
  39. if isinstance(window, int):
  40. # Set the same window size for all attention heads.
  41. window = {window: num_heads}
  42. self.window = window
  43. elif isinstance(window, dict):
  44. self.window = window
  45. else:
  46. raise ValueError()
  47. self.conv_list = nn.ModuleList()
  48. self.head_splits = []
  49. for cur_window, cur_head_split in window.items():
  50. dilation = 1
  51. # Determine padding size.
  52. # Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338
  53. padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2
  54. cur_conv = nn.Conv2d(
  55. cur_head_split * head_chs,
  56. cur_head_split * head_chs,
  57. kernel_size=(cur_window, cur_window),
  58. padding=(padding_size, padding_size),
  59. dilation=(dilation, dilation),
  60. groups=cur_head_split * head_chs,
  61. **dd,
  62. )
  63. self.conv_list.append(cur_conv)
  64. self.head_splits.append(cur_head_split)
  65. self.channel_splits = [x * head_chs for x in self.head_splits]
  66. def forward(self, q, v, size: Tuple[int, int]):
  67. B, num_heads, N, C = q.shape
  68. H, W = size
  69. _assert(N == 1 + H * W, '')
  70. # Convolutional relative position encoding.
  71. q_img = q[:, :, 1:, :] # [B, h, H*W, Ch]
  72. v_img = v[:, :, 1:, :] # [B, h, H*W, Ch]
  73. v_img = v_img.transpose(-1, -2).reshape(B, num_heads * C, H, W)
  74. v_img_list = torch.split(v_img, self.channel_splits, dim=1) # Split according to channels
  75. conv_v_img_list = []
  76. for i, conv in enumerate(self.conv_list):
  77. conv_v_img_list.append(conv(v_img_list[i]))
  78. conv_v_img = torch.cat(conv_v_img_list, dim=1)
  79. conv_v_img = conv_v_img.reshape(B, num_heads, C, H * W).transpose(-1, -2)
  80. EV_hat = q_img * conv_v_img
  81. EV_hat = F.pad(EV_hat, (0, 0, 1, 0, 0, 0)) # [B, h, N, Ch].
  82. return EV_hat
  83. class FactorAttnConvRelPosEnc(nn.Module):
  84. """ Factorized attention with convolutional relative position encoding class. """
  85. def __init__(
  86. self,
  87. dim: int,
  88. num_heads: int = 8,
  89. qkv_bias: bool = False,
  90. attn_drop: float = 0.,
  91. proj_drop: float = 0.,
  92. shared_crpe: Optional[Any] = None,
  93. device=None,
  94. dtype=None,
  95. ):
  96. dd = {'device': device, 'dtype': dtype}
  97. super().__init__()
  98. self.num_heads = num_heads
  99. head_dim = dim // num_heads
  100. self.scale = head_dim ** -0.5
  101. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
  102. self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used.
  103. self.proj = nn.Linear(dim, dim, **dd)
  104. self.proj_drop = nn.Dropout(proj_drop)
  105. # Shared convolutional relative position encoding.
  106. self.crpe = shared_crpe
  107. def forward(self, x, size: Tuple[int, int]):
  108. B, N, C = x.shape
  109. # Generate Q, K, V.
  110. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  111. q, k, v = qkv.unbind(0) # [B, h, N, Ch]
  112. # Factorized attention.
  113. k_softmax = k.softmax(dim=2)
  114. factor_att = k_softmax.transpose(-1, -2) @ v
  115. factor_att = q @ factor_att
  116. # Convolutional relative position encoding.
  117. crpe = self.crpe(q, v, size=size) # [B, h, N, Ch]
  118. # Merge and reshape.
  119. x = self.scale * factor_att + crpe
  120. x = x.transpose(1, 2).reshape(B, N, C) # [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C]
  121. # Output projection.
  122. x = self.proj(x)
  123. x = self.proj_drop(x)
  124. return x
  125. class ConvPosEnc(nn.Module):
  126. """ Convolutional Position Encoding.
  127. Note: This module is similar to the conditional position encoding in CPVT.
  128. """
  129. def __init__(
  130. self,
  131. dim: int,
  132. k: int = 3,
  133. device=None,
  134. dtype=None,
  135. ):
  136. dd = {'device': device, 'dtype': dtype}
  137. super().__init__()
  138. self.proj = nn.Conv2d(dim, dim, k, 1, k//2, groups=dim, **dd)
  139. def forward(self, x, size: Tuple[int, int]):
  140. B, N, C = x.shape
  141. H, W = size
  142. _assert(N == 1 + H * W, '')
  143. # Extract CLS token and image tokens.
  144. cls_token, img_tokens = x[:, :1], x[:, 1:] # [B, 1, C], [B, H*W, C]
  145. # Depthwise convolution.
  146. feat = img_tokens.transpose(1, 2).view(B, C, H, W)
  147. x = self.proj(feat) + feat
  148. x = x.flatten(2).transpose(1, 2)
  149. # Combine with CLS token.
  150. x = torch.cat((cls_token, x), dim=1)
  151. return x
  152. class SerialBlock(nn.Module):
  153. """ Serial block class.
  154. Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """
  155. def __init__(
  156. self,
  157. dim: int,
  158. num_heads: int,
  159. mlp_ratio: float = 4.,
  160. qkv_bias: bool = False,
  161. proj_drop: float = 0.,
  162. attn_drop: float = 0.,
  163. drop_path: float = 0.,
  164. act_layer: Type[nn.Module] = nn.GELU,
  165. norm_layer: Type[nn.Module] = nn.LayerNorm,
  166. shared_cpe: Optional[Any] = None,
  167. shared_crpe: Optional[Any] = None,
  168. device=None,
  169. dtype=None,
  170. ):
  171. dd = {'device': device, 'dtype': dtype}
  172. super().__init__()
  173. # Conv-Attention.
  174. self.cpe = shared_cpe
  175. self.norm1 = norm_layer(dim, **dd)
  176. self.factoratt_crpe = FactorAttnConvRelPosEnc(
  177. dim,
  178. num_heads=num_heads,
  179. qkv_bias=qkv_bias,
  180. attn_drop=attn_drop,
  181. proj_drop=proj_drop,
  182. shared_crpe=shared_crpe,
  183. **dd,
  184. )
  185. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  186. # MLP.
  187. self.norm2 = norm_layer(dim, **dd)
  188. mlp_hidden_dim = int(dim * mlp_ratio)
  189. self.mlp = Mlp(
  190. in_features=dim,
  191. hidden_features=mlp_hidden_dim,
  192. act_layer=act_layer,
  193. drop=proj_drop,
  194. **dd,
  195. )
  196. def forward(self, x, size: Tuple[int, int]):
  197. # Conv-Attention.
  198. x = self.cpe(x, size)
  199. cur = self.norm1(x)
  200. cur = self.factoratt_crpe(cur, size)
  201. x = x + self.drop_path(cur)
  202. # MLP.
  203. cur = self.norm2(x)
  204. cur = self.mlp(cur)
  205. x = x + self.drop_path(cur)
  206. return x
  207. class ParallelBlock(nn.Module):
  208. """ Parallel block class. """
  209. def __init__(
  210. self,
  211. dims: List[int],
  212. num_heads: int,
  213. mlp_ratios: List[float] = None,
  214. qkv_bias: bool = False,
  215. proj_drop: float = 0.,
  216. attn_drop: float = 0.,
  217. drop_path: float = 0.,
  218. act_layer: Type[nn.Module] = nn.GELU,
  219. norm_layer: Type[nn.Module] = nn.LayerNorm,
  220. shared_crpes: Optional[List[Any]] = None,
  221. device=None,
  222. dtype=None,
  223. ):
  224. dd = {'device': device, 'dtype': dtype}
  225. super().__init__()
  226. if mlp_ratios is None:
  227. mlp_ratios = []
  228. # Conv-Attention.
  229. self.norm12 = norm_layer(dims[1], **dd)
  230. self.norm13 = norm_layer(dims[2], **dd)
  231. self.norm14 = norm_layer(dims[3], **dd)
  232. self.factoratt_crpe2 = FactorAttnConvRelPosEnc(
  233. dims[1],
  234. num_heads=num_heads,
  235. qkv_bias=qkv_bias,
  236. attn_drop=attn_drop,
  237. proj_drop=proj_drop,
  238. shared_crpe=shared_crpes[1],
  239. **dd,
  240. )
  241. self.factoratt_crpe3 = FactorAttnConvRelPosEnc(
  242. dims[2],
  243. num_heads=num_heads,
  244. qkv_bias=qkv_bias,
  245. attn_drop=attn_drop,
  246. proj_drop=proj_drop,
  247. shared_crpe=shared_crpes[2],
  248. **dd,
  249. )
  250. self.factoratt_crpe4 = FactorAttnConvRelPosEnc(
  251. dims[3],
  252. num_heads=num_heads,
  253. qkv_bias=qkv_bias,
  254. attn_drop=attn_drop,
  255. proj_drop=proj_drop,
  256. shared_crpe=shared_crpes[3],
  257. **dd,
  258. )
  259. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  260. # MLP.
  261. self.norm22 = norm_layer(dims[1], **dd)
  262. self.norm23 = norm_layer(dims[2], **dd)
  263. self.norm24 = norm_layer(dims[3], **dd)
  264. # In parallel block, we assume dimensions are the same and share the linear transformation.
  265. assert dims[1] == dims[2] == dims[3]
  266. assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3]
  267. mlp_hidden_dim = int(dims[1] * mlp_ratios[1])
  268. self.mlp2 = self.mlp3 = self.mlp4 = Mlp(
  269. in_features=dims[1],
  270. hidden_features=mlp_hidden_dim,
  271. act_layer=act_layer,
  272. drop=proj_drop,
  273. **dd,
  274. )
  275. def upsample(self, x, factor: float, size: Tuple[int, int]):
  276. """ Feature map up-sampling. """
  277. return self.interpolate(x, scale_factor=factor, size=size)
  278. def downsample(self, x, factor: float, size: Tuple[int, int]):
  279. """ Feature map down-sampling. """
  280. return self.interpolate(x, scale_factor=1.0/factor, size=size)
  281. def interpolate(self, x, scale_factor: float, size: Tuple[int, int]):
  282. """ Feature map interpolation. """
  283. B, N, C = x.shape
  284. H, W = size
  285. _assert(N == 1 + H * W, '')
  286. cls_token = x[:, :1, :]
  287. img_tokens = x[:, 1:, :]
  288. img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W)
  289. img_tokens = F.interpolate(
  290. img_tokens,
  291. scale_factor=scale_factor,
  292. recompute_scale_factor=False,
  293. mode='bilinear',
  294. align_corners=False,
  295. )
  296. img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2)
  297. out = torch.cat((cls_token, img_tokens), dim=1)
  298. return out
  299. def forward(self, x1, x2, x3, x4, sizes: List[Tuple[int, int]]):
  300. _, S2, S3, S4 = sizes
  301. cur2 = self.norm12(x2)
  302. cur3 = self.norm13(x3)
  303. cur4 = self.norm14(x4)
  304. cur2 = self.factoratt_crpe2(cur2, size=S2)
  305. cur3 = self.factoratt_crpe3(cur3, size=S3)
  306. cur4 = self.factoratt_crpe4(cur4, size=S4)
  307. upsample3_2 = self.upsample(cur3, factor=2., size=S3)
  308. upsample4_3 = self.upsample(cur4, factor=2., size=S4)
  309. upsample4_2 = self.upsample(cur4, factor=4., size=S4)
  310. downsample2_3 = self.downsample(cur2, factor=2., size=S2)
  311. downsample3_4 = self.downsample(cur3, factor=2., size=S3)
  312. downsample2_4 = self.downsample(cur2, factor=4., size=S2)
  313. cur2 = cur2 + upsample3_2 + upsample4_2
  314. cur3 = cur3 + upsample4_3 + downsample2_3
  315. cur4 = cur4 + downsample3_4 + downsample2_4
  316. x2 = x2 + self.drop_path(cur2)
  317. x3 = x3 + self.drop_path(cur3)
  318. x4 = x4 + self.drop_path(cur4)
  319. # MLP.
  320. cur2 = self.norm22(x2)
  321. cur3 = self.norm23(x3)
  322. cur4 = self.norm24(x4)
  323. cur2 = self.mlp2(cur2)
  324. cur3 = self.mlp3(cur3)
  325. cur4 = self.mlp4(cur4)
  326. x2 = x2 + self.drop_path(cur2)
  327. x3 = x3 + self.drop_path(cur3)
  328. x4 = x4 + self.drop_path(cur4)
  329. return x1, x2, x3, x4
  330. class CoaT(nn.Module):
  331. """ CoaT class. """
  332. def __init__(
  333. self,
  334. img_size: int = 224,
  335. patch_size: int = 16,
  336. in_chans: int = 3,
  337. num_classes: int = 1000,
  338. embed_dims: Tuple[int, int, int, int] = (64, 128, 320, 512),
  339. serial_depths: Tuple[int, int, int, int] = (3, 4, 6, 3),
  340. parallel_depth: int = 0,
  341. num_heads: int = 8,
  342. mlp_ratios: Tuple[float, float, float, float] = (4, 4, 4, 4),
  343. qkv_bias: bool = True,
  344. drop_rate: float = 0.,
  345. proj_drop_rate: float = 0.,
  346. attn_drop_rate: float = 0.,
  347. drop_path_rate: float = 0.,
  348. norm_layer: Type[nn.Module] = LayerNorm,
  349. return_interm_layers: bool = False,
  350. out_features: Optional[List[str]] = None,
  351. crpe_window: Optional[dict] = None,
  352. global_pool: str = 'token',
  353. device=None,
  354. dtype=None,
  355. ):
  356. super().__init__()
  357. dd = {'device': device, 'dtype': dtype}
  358. assert global_pool in ('token', 'avg')
  359. crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
  360. self.return_interm_layers = return_interm_layers
  361. self.out_features = out_features
  362. self.embed_dims = embed_dims
  363. self.num_features = self.head_hidden_size = embed_dims[-1]
  364. self.num_classes = num_classes
  365. self.in_chans = in_chans
  366. self.global_pool = global_pool
  367. # Patch embeddings.
  368. img_size = to_2tuple(img_size)
  369. self.patch_embed1 = PatchEmbed(
  370. img_size=img_size, patch_size=patch_size, in_chans=in_chans,
  371. embed_dim=embed_dims[0], norm_layer=nn.LayerNorm, **dd)
  372. self.patch_embed2 = PatchEmbed(
  373. img_size=[x // 4 for x in img_size], patch_size=2, in_chans=embed_dims[0],
  374. embed_dim=embed_dims[1], norm_layer=nn.LayerNorm, **dd)
  375. self.patch_embed3 = PatchEmbed(
  376. img_size=[x // 8 for x in img_size], patch_size=2, in_chans=embed_dims[1],
  377. embed_dim=embed_dims[2], norm_layer=nn.LayerNorm, **dd)
  378. self.patch_embed4 = PatchEmbed(
  379. img_size=[x // 16 for x in img_size], patch_size=2, in_chans=embed_dims[2],
  380. embed_dim=embed_dims[3], norm_layer=nn.LayerNorm, **dd)
  381. # Class tokens.
  382. self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0], **dd))
  383. self.cls_token2 = nn.Parameter(torch.zeros(1, 1, embed_dims[1], **dd))
  384. self.cls_token3 = nn.Parameter(torch.zeros(1, 1, embed_dims[2], **dd))
  385. self.cls_token4 = nn.Parameter(torch.zeros(1, 1, embed_dims[3], **dd))
  386. # Convolutional position encodings.
  387. self.cpe1 = ConvPosEnc(dim=embed_dims[0], k=3, **dd)
  388. self.cpe2 = ConvPosEnc(dim=embed_dims[1], k=3, **dd)
  389. self.cpe3 = ConvPosEnc(dim=embed_dims[2], k=3, **dd)
  390. self.cpe4 = ConvPosEnc(dim=embed_dims[3], k=3, **dd)
  391. # Convolutional relative position encodings.
  392. self.crpe1 = ConvRelPosEnc(head_chs=embed_dims[0] // num_heads, num_heads=num_heads, window=crpe_window, **dd)
  393. self.crpe2 = ConvRelPosEnc(head_chs=embed_dims[1] // num_heads, num_heads=num_heads, window=crpe_window, **dd)
  394. self.crpe3 = ConvRelPosEnc(head_chs=embed_dims[2] // num_heads, num_heads=num_heads, window=crpe_window, **dd)
  395. self.crpe4 = ConvRelPosEnc(head_chs=embed_dims[3] // num_heads, num_heads=num_heads, window=crpe_window, **dd)
  396. dpr = drop_path_rate
  397. skwargs = dict(
  398. num_heads=num_heads,
  399. qkv_bias=qkv_bias,
  400. proj_drop=proj_drop_rate,
  401. attn_drop=attn_drop_rate,
  402. drop_path=dpr,
  403. norm_layer=norm_layer,
  404. )
  405. # Serial blocks 1.
  406. self.serial_blocks1 = nn.ModuleList([
  407. SerialBlock(
  408. dim=embed_dims[0],
  409. mlp_ratio=mlp_ratios[0],
  410. shared_cpe=self.cpe1,
  411. shared_crpe=self.crpe1,
  412. **skwargs,
  413. **dd,
  414. )
  415. for _ in range(serial_depths[0])]
  416. )
  417. # Serial blocks 2.
  418. self.serial_blocks2 = nn.ModuleList([
  419. SerialBlock(
  420. dim=embed_dims[1],
  421. mlp_ratio=mlp_ratios[1],
  422. shared_cpe=self.cpe2,
  423. shared_crpe=self.crpe2,
  424. **skwargs,
  425. **dd,
  426. )
  427. for _ in range(serial_depths[1])]
  428. )
  429. # Serial blocks 3.
  430. self.serial_blocks3 = nn.ModuleList([
  431. SerialBlock(
  432. dim=embed_dims[2],
  433. mlp_ratio=mlp_ratios[2],
  434. shared_cpe=self.cpe3,
  435. shared_crpe=self.crpe3,
  436. **skwargs,
  437. **dd,
  438. )
  439. for _ in range(serial_depths[2])]
  440. )
  441. # Serial blocks 4.
  442. self.serial_blocks4 = nn.ModuleList([
  443. SerialBlock(
  444. dim=embed_dims[3],
  445. mlp_ratio=mlp_ratios[3],
  446. shared_cpe=self.cpe4,
  447. shared_crpe=self.crpe4,
  448. **skwargs,
  449. **dd,
  450. )
  451. for _ in range(serial_depths[3])]
  452. )
  453. # Parallel blocks.
  454. self.parallel_depth = parallel_depth
  455. if self.parallel_depth > 0:
  456. self.parallel_blocks = nn.ModuleList([
  457. ParallelBlock(
  458. dims=embed_dims,
  459. mlp_ratios=mlp_ratios,
  460. shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4),
  461. **skwargs,
  462. **dd,
  463. )
  464. for _ in range(parallel_depth)]
  465. )
  466. else:
  467. self.parallel_blocks = None
  468. # Classification head(s).
  469. if not self.return_interm_layers:
  470. if self.parallel_blocks is not None:
  471. self.norm2 = norm_layer(embed_dims[1], **dd)
  472. self.norm3 = norm_layer(embed_dims[2], **dd)
  473. else:
  474. self.norm2 = self.norm3 = None
  475. self.norm4 = norm_layer(embed_dims[3], **dd)
  476. if self.parallel_depth > 0:
  477. # CoaT series: Aggregate features of last three scales for classification.
  478. assert embed_dims[1] == embed_dims[2] == embed_dims[3]
  479. self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1, **dd)
  480. self.head_drop = nn.Dropout(drop_rate)
  481. self.head = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()
  482. else:
  483. # CoaT-Lite series: Use feature of last scale for classification.
  484. self.aggregate = None
  485. self.head_drop = nn.Dropout(drop_rate)
  486. self.head = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()
  487. # Initialize weights.
  488. trunc_normal_(self.cls_token1, std=.02)
  489. trunc_normal_(self.cls_token2, std=.02)
  490. trunc_normal_(self.cls_token3, std=.02)
  491. trunc_normal_(self.cls_token4, std=.02)
  492. self.apply(self._init_weights)
  493. def _init_weights(self, m):
  494. if isinstance(m, nn.Linear):
  495. trunc_normal_(m.weight, std=.02)
  496. if isinstance(m, nn.Linear) and m.bias is not None:
  497. nn.init.constant_(m.bias, 0)
  498. elif isinstance(m, nn.LayerNorm):
  499. nn.init.constant_(m.bias, 0)
  500. nn.init.constant_(m.weight, 1.0)
  501. @torch.jit.ignore
  502. def no_weight_decay(self):
  503. return {'cls_token1', 'cls_token2', 'cls_token3', 'cls_token4'}
  504. @torch.jit.ignore
  505. def set_grad_checkpointing(self, enable=True):
  506. assert not enable, 'gradient checkpointing not supported'
  507. @torch.jit.ignore
  508. def group_matcher(self, coarse=False):
  509. matcher = dict(
  510. stem1=r'^cls_token1|patch_embed1|crpe1|cpe1',
  511. serial_blocks1=r'^serial_blocks1\.(\d+)',
  512. stem2=r'^cls_token2|patch_embed2|crpe2|cpe2',
  513. serial_blocks2=r'^serial_blocks2\.(\d+)',
  514. stem3=r'^cls_token3|patch_embed3|crpe3|cpe3',
  515. serial_blocks3=r'^serial_blocks3\.(\d+)',
  516. stem4=r'^cls_token4|patch_embed4|crpe4|cpe4',
  517. serial_blocks4=r'^serial_blocks4\.(\d+)',
  518. parallel_blocks=[ # FIXME (partially?) overlap parallel w/ serial blocks??
  519. (r'^parallel_blocks\.(\d+)', None),
  520. (r'^norm|aggregate', (99999,)),
  521. ]
  522. )
  523. return matcher
  524. @torch.jit.ignore
  525. def get_classifier(self) -> nn.Module:
  526. return self.head
  527. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  528. self.num_classes = num_classes
  529. if global_pool is not None:
  530. assert global_pool in ('token', 'avg')
  531. self.global_pool = global_pool
  532. self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  533. def forward_features(self, x0):
  534. B = x0.shape[0]
  535. # Serial blocks 1.
  536. x1 = self.patch_embed1(x0)
  537. H1, W1 = self.patch_embed1.grid_size
  538. x1 = insert_cls(x1, self.cls_token1)
  539. for blk in self.serial_blocks1:
  540. x1 = blk(x1, size=(H1, W1))
  541. x1_nocls = remove_cls(x1).reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
  542. # Serial blocks 2.
  543. x2 = self.patch_embed2(x1_nocls)
  544. H2, W2 = self.patch_embed2.grid_size
  545. x2 = insert_cls(x2, self.cls_token2)
  546. for blk in self.serial_blocks2:
  547. x2 = blk(x2, size=(H2, W2))
  548. x2_nocls = remove_cls(x2).reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()
  549. # Serial blocks 3.
  550. x3 = self.patch_embed3(x2_nocls)
  551. H3, W3 = self.patch_embed3.grid_size
  552. x3 = insert_cls(x3, self.cls_token3)
  553. for blk in self.serial_blocks3:
  554. x3 = blk(x3, size=(H3, W3))
  555. x3_nocls = remove_cls(x3).reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()
  556. # Serial blocks 4.
  557. x4 = self.patch_embed4(x3_nocls)
  558. H4, W4 = self.patch_embed4.grid_size
  559. x4 = insert_cls(x4, self.cls_token4)
  560. for blk in self.serial_blocks4:
  561. x4 = blk(x4, size=(H4, W4))
  562. x4_nocls = remove_cls(x4).reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()
  563. # Only serial blocks: Early return.
  564. if self.parallel_blocks is None:
  565. if not torch.jit.is_scripting() and self.return_interm_layers:
  566. # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
  567. feat_out = {}
  568. if 'x1_nocls' in self.out_features:
  569. feat_out['x1_nocls'] = x1_nocls
  570. if 'x2_nocls' in self.out_features:
  571. feat_out['x2_nocls'] = x2_nocls
  572. if 'x3_nocls' in self.out_features:
  573. feat_out['x3_nocls'] = x3_nocls
  574. if 'x4_nocls' in self.out_features:
  575. feat_out['x4_nocls'] = x4_nocls
  576. return feat_out
  577. else:
  578. # Return features for classification.
  579. x4 = self.norm4(x4)
  580. return x4
  581. # Parallel blocks.
  582. for blk in self.parallel_blocks:
  583. x2, x3, x4 = self.cpe2(x2, (H2, W2)), self.cpe3(x3, (H3, W3)), self.cpe4(x4, (H4, W4))
  584. x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)])
  585. if not torch.jit.is_scripting() and self.return_interm_layers:
  586. # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
  587. feat_out = {}
  588. if 'x1_nocls' in self.out_features:
  589. x1_nocls = remove_cls(x1).reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
  590. feat_out['x1_nocls'] = x1_nocls
  591. if 'x2_nocls' in self.out_features:
  592. x2_nocls = remove_cls(x2).reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()
  593. feat_out['x2_nocls'] = x2_nocls
  594. if 'x3_nocls' in self.out_features:
  595. x3_nocls = remove_cls(x3).reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()
  596. feat_out['x3_nocls'] = x3_nocls
  597. if 'x4_nocls' in self.out_features:
  598. x4_nocls = remove_cls(x4).reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()
  599. feat_out['x4_nocls'] = x4_nocls
  600. return feat_out
  601. else:
  602. x2 = self.norm2(x2)
  603. x3 = self.norm3(x3)
  604. x4 = self.norm4(x4)
  605. return [x2, x3, x4]
  606. def forward_head(self, x_feat: Union[torch.Tensor, List[torch.Tensor]], pre_logits: bool = False):
  607. if isinstance(x_feat, list):
  608. assert self.aggregate is not None
  609. if self.global_pool == 'avg':
  610. x = torch.cat([xl[:, 1:].mean(dim=1, keepdim=True) for xl in x_feat], dim=1) # [B, 3, C]
  611. else:
  612. x = torch.stack([xl[:, 0] for xl in x_feat], dim=1) # [B, 3, C]
  613. x = self.aggregate(x).squeeze(dim=1) # Shape: [B, C]
  614. else:
  615. x = x_feat[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x_feat[:, 0]
  616. x = self.head_drop(x)
  617. return x if pre_logits else self.head(x)
  618. def forward(self, x) -> torch.Tensor:
  619. if not torch.jit.is_scripting() and self.return_interm_layers:
  620. # Return intermediate features (for down-stream tasks).
  621. return self.forward_features(x)
  622. else:
  623. # Return features for classification.
  624. x_feat = self.forward_features(x)
  625. x = self.forward_head(x_feat)
  626. return x
  627. def insert_cls(x, cls_token):
  628. """ Insert CLS token. """
  629. cls_tokens = cls_token.expand(x.shape[0], -1, -1)
  630. x = torch.cat((cls_tokens, x), dim=1)
  631. return x
  632. def remove_cls(x):
  633. """ Remove CLS token. """
  634. return x[:, 1:, :]
  635. def checkpoint_filter_fn(state_dict, model):
  636. out_dict = {}
  637. state_dict = state_dict.get('model', state_dict)
  638. for k, v in state_dict.items():
  639. # original model had unused norm layers, removing them requires filtering pretrained checkpoints
  640. if k.startswith('norm1') or \
  641. (k.startswith('norm2') and getattr(model, 'norm2', None) is None) or \
  642. (k.startswith('norm3') and getattr(model, 'norm3', None) is None) or \
  643. (k.startswith('norm4') and getattr(model, 'norm4', None) is None) or \
  644. (k.startswith('aggregate') and getattr(model, 'aggregate', None) is None) or \
  645. (k.startswith('head') and getattr(model, 'head', None) is None):
  646. continue
  647. out_dict[k] = v
  648. return out_dict
  649. def _create_coat(variant, pretrained=False, default_cfg=None, **kwargs):
  650. if kwargs.get('features_only', None):
  651. raise RuntimeError('features_only not implemented for Vision Transformer models.')
  652. model = build_model_with_cfg(
  653. CoaT,
  654. variant,
  655. pretrained,
  656. pretrained_filter_fn=checkpoint_filter_fn,
  657. **kwargs,
  658. )
  659. return model
  660. def _cfg_coat(url='', **kwargs):
  661. return {
  662. 'url': url,
  663. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  664. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  665. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  666. 'first_conv': 'patch_embed1.proj', 'classifier': 'head',
  667. 'license': 'apache-2.0',
  668. **kwargs
  669. }
  670. default_cfgs = generate_default_cfgs({
  671. 'coat_tiny.in1k': _cfg_coat(hf_hub_id='timm/'),
  672. 'coat_mini.in1k': _cfg_coat(hf_hub_id='timm/'),
  673. 'coat_small.in1k': _cfg_coat(hf_hub_id='timm/'),
  674. 'coat_lite_tiny.in1k': _cfg_coat(hf_hub_id='timm/'),
  675. 'coat_lite_mini.in1k': _cfg_coat(hf_hub_id='timm/'),
  676. 'coat_lite_small.in1k': _cfg_coat(hf_hub_id='timm/'),
  677. 'coat_lite_medium.in1k': _cfg_coat(hf_hub_id='timm/'),
  678. 'coat_lite_medium_384.in1k': _cfg_coat(
  679. hf_hub_id='timm/',
  680. input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash',
  681. ),
  682. })
  683. @register_model
  684. def coat_tiny(pretrained=False, **kwargs) -> CoaT:
  685. model_cfg = dict(
  686. patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6)
  687. model = _create_coat('coat_tiny', pretrained=pretrained, **dict(model_cfg, **kwargs))
  688. return model
  689. @register_model
  690. def coat_mini(pretrained=False, **kwargs) -> CoaT:
  691. model_cfg = dict(
  692. patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6)
  693. model = _create_coat('coat_mini', pretrained=pretrained, **dict(model_cfg, **kwargs))
  694. return model
  695. @register_model
  696. def coat_small(pretrained=False, **kwargs) -> CoaT:
  697. model_cfg = dict(
  698. patch_size=4, embed_dims=[152, 320, 320, 320], serial_depths=[2, 2, 2, 2], parallel_depth=6, **kwargs)
  699. model = _create_coat('coat_small', pretrained=pretrained, **dict(model_cfg, **kwargs))
  700. return model
  701. @register_model
  702. def coat_lite_tiny(pretrained=False, **kwargs) -> CoaT:
  703. model_cfg = dict(
  704. patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], mlp_ratios=[8, 8, 4, 4])
  705. model = _create_coat('coat_lite_tiny', pretrained=pretrained, **dict(model_cfg, **kwargs))
  706. return model
  707. @register_model
  708. def coat_lite_mini(pretrained=False, **kwargs) -> CoaT:
  709. model_cfg = dict(
  710. patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], mlp_ratios=[8, 8, 4, 4])
  711. model = _create_coat('coat_lite_mini', pretrained=pretrained, **dict(model_cfg, **kwargs))
  712. return model
  713. @register_model
  714. def coat_lite_small(pretrained=False, **kwargs) -> CoaT:
  715. model_cfg = dict(
  716. patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], mlp_ratios=[8, 8, 4, 4])
  717. model = _create_coat('coat_lite_small', pretrained=pretrained, **dict(model_cfg, **kwargs))
  718. return model
  719. @register_model
  720. def coat_lite_medium(pretrained=False, **kwargs) -> CoaT:
  721. model_cfg = dict(
  722. patch_size=4, embed_dims=[128, 256, 320, 512], serial_depths=[3, 6, 10, 8])
  723. model = _create_coat('coat_lite_medium', pretrained=pretrained, **dict(model_cfg, **kwargs))
  724. return model
  725. @register_model
  726. def coat_lite_medium_384(pretrained=False, **kwargs) -> CoaT:
  727. model_cfg = dict(
  728. img_size=384, patch_size=4, embed_dims=[128, 256, 320, 512], serial_depths=[3, 6, 10, 8])
  729. model = _create_coat('coat_lite_medium_384', pretrained=pretrained, **dict(model_cfg, **kwargs))
  730. return model