tnt.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595
""" Transformer in Transformer (TNT) in PyTorch

A PyTorch implementation of TNT as described in
'Transformer in Transformer' - https://arxiv.org/abs/2103.00112

The official MindSpore code is available at
https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
The official PyTorch code is available at
https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch
"""
  9. import math
  10. from typing import List, Optional, Tuple, Union, Type, Any
  11. import torch
  12. import torch.nn as nn
  13. from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
  14. from timm.layers import Mlp, DropPath, calculate_drop_path_rates, trunc_normal_, _assert, to_2tuple, resample_abs_pos_embed
  15. from ._builder import build_model_with_cfg
  16. from ._features import feature_take_indices
  17. from ._manipulate import checkpoint
  18. from ._registry import generate_default_cfgs, register_model
  19. __all__ = ['TNT'] # model_registry will add each entrypoint fn to this
  20. class Attention(nn.Module):
  21. """ Multi-Head Attention
  22. """
  23. def __init__(
  24. self,
  25. dim: int,
  26. hidden_dim: int,
  27. num_heads: int = 8,
  28. qkv_bias: bool = False,
  29. attn_drop: float = 0.,
  30. proj_drop: float = 0.,
  31. device=None,
  32. dtype=None,
  33. ):
  34. dd = {'device': device, 'dtype': dtype}
  35. super().__init__()
  36. self.hidden_dim = hidden_dim
  37. self.num_heads = num_heads
  38. head_dim = hidden_dim // num_heads
  39. self.head_dim = head_dim
  40. self.scale = head_dim ** -0.5
  41. self.qk = nn.Linear(dim, hidden_dim * 2, bias=qkv_bias, **dd)
  42. self.v = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  43. self.attn_drop = nn.Dropout(attn_drop, inplace=True)
  44. self.proj = nn.Linear(dim, dim, **dd)
  45. self.proj_drop = nn.Dropout(proj_drop, inplace=True)
  46. def forward(self, x):
  47. B, N, C = x.shape
  48. qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
  49. q, k = qk.unbind(0) # make torchscript happy (cannot use tensor as tuple)
  50. v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
  51. attn = (q @ k.transpose(-2, -1)) * self.scale
  52. attn = attn.softmax(dim=-1)
  53. attn = self.attn_drop(attn)
  54. x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
  55. x = self.proj(x)
  56. x = self.proj_drop(x)
  57. return x
  58. class Block(nn.Module):
  59. """ TNT Block
  60. """
  61. def __init__(
  62. self,
  63. dim: int,
  64. dim_out: int,
  65. num_pixel: int,
  66. num_heads_in: int = 4,
  67. num_heads_out: int = 12,
  68. mlp_ratio: float = 4.,
  69. qkv_bias: bool = False,
  70. proj_drop: float = 0.,
  71. attn_drop: float = 0.,
  72. drop_path: float = 0.,
  73. act_layer: Type[nn.Module] = nn.GELU,
  74. norm_layer: Type[nn.Module] = nn.LayerNorm,
  75. legacy: bool = False,
  76. device=None,
  77. dtype=None,
  78. ):
  79. dd = {'device': device, 'dtype': dtype}
  80. super().__init__()
  81. # Inner transformer
  82. self.norm_in = norm_layer(dim, **dd)
  83. self.attn_in = Attention(
  84. dim,
  85. dim,
  86. num_heads=num_heads_in,
  87. qkv_bias=qkv_bias,
  88. attn_drop=attn_drop,
  89. proj_drop=proj_drop,
  90. **dd,
  91. )
  92. self.norm_mlp_in = norm_layer(dim, **dd)
  93. self.mlp_in = Mlp(
  94. in_features=dim,
  95. hidden_features=int(dim * 4),
  96. out_features=dim,
  97. act_layer=act_layer,
  98. drop=proj_drop,
  99. **dd,
  100. )
  101. self.legacy = legacy
  102. if self.legacy:
  103. self.norm1_proj = norm_layer(dim, **dd)
  104. self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True, **dd)
  105. self.norm2_proj = None
  106. else:
  107. self.norm1_proj = norm_layer(dim * num_pixel, **dd)
  108. self.proj = nn.Linear(dim * num_pixel, dim_out, bias=False, **dd)
  109. self.norm2_proj = norm_layer(dim_out, **dd)
  110. # Outer transformer
  111. self.norm_out = norm_layer(dim_out, **dd)
  112. self.attn_out = Attention(
  113. dim_out,
  114. dim_out,
  115. num_heads=num_heads_out,
  116. qkv_bias=qkv_bias,
  117. attn_drop=attn_drop,
  118. proj_drop=proj_drop,
  119. **dd,
  120. )
  121. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  122. self.norm_mlp = norm_layer(dim_out, **dd)
  123. self.mlp = Mlp(
  124. in_features=dim_out,
  125. hidden_features=int(dim_out * mlp_ratio),
  126. out_features=dim_out,
  127. act_layer=act_layer,
  128. drop=proj_drop,
  129. **dd,
  130. )
  131. def forward(self, pixel_embed, patch_embed):
  132. # inner
  133. pixel_embed = pixel_embed + self.drop_path(self.attn_in(self.norm_in(pixel_embed)))
  134. pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))
  135. # outer
  136. B, N, C = patch_embed.size()
  137. if self.norm2_proj is None:
  138. patch_embed = torch.cat([
  139. patch_embed[:, 0:1],
  140. patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),
  141. ], dim=1)
  142. else:
  143. patch_embed = torch.cat([
  144. patch_embed[:, 0:1],
  145. patch_embed[:, 1:] + self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, N - 1, -1)))),
  146. ], dim=1)
  147. patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
  148. patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
  149. return pixel_embed, patch_embed
  150. class PixelEmbed(nn.Module):
  151. """ Image to Pixel Embedding
  152. """
  153. def __init__(
  154. self,
  155. img_size: Union[int, Tuple[int, int]] = 224,
  156. patch_size: Union[int, Tuple[int, int]] = 16,
  157. in_chans: int = 3,
  158. in_dim: int = 48,
  159. stride: int = 4,
  160. legacy: bool = False,
  161. device=None,
  162. dtype=None,
  163. ):
  164. dd = {'device': device, 'dtype': dtype}
  165. super().__init__()
  166. img_size = to_2tuple(img_size)
  167. patch_size = to_2tuple(patch_size)
  168. # grid_size property necessary for resizing positional embedding
  169. self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
  170. num_patches = (self.grid_size[0]) * (self.grid_size[1])
  171. self.img_size = img_size
  172. self.patch_size = patch_size
  173. self.legacy = legacy
  174. self.num_patches = num_patches
  175. self.in_dim = in_dim
  176. new_patch_size = [math.ceil(ps / stride) for ps in patch_size]
  177. self.new_patch_size = new_patch_size
  178. self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride, **dd)
  179. if self.legacy:
  180. self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size)
  181. else:
  182. self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
  183. def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
  184. if as_scalar:
  185. return max(self.patch_size)
  186. else:
  187. return self.patch_size
  188. def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
  189. return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
  190. def forward(self, x: torch.Tensor, pixel_pos: torch.Tensor) -> torch.Tensor:
  191. B, C, H, W = x.shape
  192. _assert(
  193. H == self.img_size[0],
  194. f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
  195. _assert(
  196. W == self.img_size[1],
  197. f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
  198. if self.legacy:
  199. x = self.proj(x)
  200. x = self.unfold(x)
  201. x = x.transpose(1, 2).reshape(
  202. B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
  203. else:
  204. x = self.unfold(x)
  205. x = x.transpose(1, 2).reshape(B * self.num_patches, C, self.patch_size[0], self.patch_size[1])
  206. x = self.proj(x)
  207. x = x + pixel_pos
  208. x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2)
  209. return x
  210. class TNT(nn.Module):
  211. """ Transformer in Transformer - https://arxiv.org/abs/2103.00112
  212. """
  213. def __init__(
  214. self,
  215. img_size: Union[int, Tuple[int, int]] = 224,
  216. patch_size: Union[int, Tuple[int, int]] = 16,
  217. in_chans: int = 3,
  218. num_classes: int = 1000,
  219. global_pool: str = 'token',
  220. embed_dim: int = 768,
  221. inner_dim: int = 48,
  222. depth: int = 12,
  223. num_heads_inner: int = 4,
  224. num_heads_outer: int = 12,
  225. mlp_ratio: float = 4.,
  226. qkv_bias: bool = False,
  227. drop_rate: float = 0.,
  228. pos_drop_rate: float = 0.,
  229. proj_drop_rate: float = 0.,
  230. attn_drop_rate: float = 0.,
  231. drop_path_rate: float = 0.,
  232. norm_layer: Type[nn.Module] = nn.LayerNorm,
  233. first_stride: int = 4,
  234. legacy: bool = False,
  235. device=None,
  236. dtype=None,
  237. ):
  238. super().__init__()
  239. dd = {'device': device, 'dtype': dtype}
  240. assert global_pool in ('', 'token', 'avg')
  241. self.num_classes = num_classes
  242. self.in_chans = in_chans
  243. self.global_pool = global_pool
  244. self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
  245. self.num_prefix_tokens = 1
  246. self.grad_checkpointing = False
  247. self.pixel_embed = PixelEmbed(
  248. img_size=img_size,
  249. patch_size=patch_size,
  250. in_chans=in_chans,
  251. in_dim=inner_dim,
  252. stride=first_stride,
  253. legacy=legacy,
  254. **dd,
  255. )
  256. num_patches = self.pixel_embed.num_patches
  257. r = self.pixel_embed.feat_ratio() if hasattr(self.pixel_embed, 'feat_ratio') else patch_size
  258. self.num_patches = num_patches
  259. new_patch_size = self.pixel_embed.new_patch_size
  260. num_pixel = new_patch_size[0] * new_patch_size[1]
  261. self.norm1_proj = norm_layer(num_pixel * inner_dim, **dd)
  262. self.proj = nn.Linear(num_pixel * inner_dim, embed_dim, **dd)
  263. self.norm2_proj = norm_layer(embed_dim, **dd)
  264. self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd))
  265. self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim, **dd))
  266. self.pixel_pos = nn.Parameter(torch.zeros(1, inner_dim, new_patch_size[0], new_patch_size[1], **dd))
  267. self.pos_drop = nn.Dropout(p=pos_drop_rate)
  268. dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule
  269. blocks = []
  270. for i in range(depth):
  271. blocks.append(Block(
  272. dim=inner_dim,
  273. dim_out=embed_dim,
  274. num_pixel=num_pixel,
  275. num_heads_in=num_heads_inner,
  276. num_heads_out=num_heads_outer,
  277. mlp_ratio=mlp_ratio,
  278. qkv_bias=qkv_bias,
  279. proj_drop=proj_drop_rate,
  280. attn_drop=attn_drop_rate,
  281. drop_path=dpr[i],
  282. norm_layer=norm_layer,
  283. legacy=legacy,
  284. **dd,
  285. ))
  286. self.blocks = nn.ModuleList(blocks)
  287. self.feature_info = [
  288. dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
  289. self.norm = norm_layer(embed_dim, **dd)
  290. self.head_drop = nn.Dropout(drop_rate)
  291. self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()
  292. trunc_normal_(self.cls_token, std=.02)
  293. trunc_normal_(self.patch_pos, std=.02)
  294. trunc_normal_(self.pixel_pos, std=.02)
  295. self.apply(self._init_weights)
  296. def _init_weights(self, m):
  297. if isinstance(m, nn.Linear):
  298. trunc_normal_(m.weight, std=.02)
  299. if isinstance(m, nn.Linear) and m.bias is not None:
  300. nn.init.constant_(m.bias, 0)
  301. elif isinstance(m, nn.LayerNorm):
  302. nn.init.constant_(m.bias, 0)
  303. nn.init.constant_(m.weight, 1.0)
  304. @torch.jit.ignore
  305. def no_weight_decay(self):
  306. return {'patch_pos', 'pixel_pos', 'cls_token'}
  307. @torch.jit.ignore
  308. def group_matcher(self, coarse=False):
  309. matcher = dict(
  310. stem=r'^cls_token|patch_pos|pixel_pos|pixel_embed|norm[12]_proj|proj', # stem and embed / pos
  311. blocks=[
  312. (r'^blocks\.(\d+)', None),
  313. (r'^norm', (99999,)),
  314. ]
  315. )
  316. return matcher
  317. @torch.jit.ignore
  318. def set_grad_checkpointing(self, enable=True):
  319. self.grad_checkpointing = enable
  320. @torch.jit.ignore
  321. def get_classifier(self) -> nn.Module:
  322. return self.head
  323. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  324. self.num_classes = num_classes
  325. if global_pool is not None:
  326. assert global_pool in ('', 'token', 'avg')
  327. self.global_pool = global_pool
  328. device = self.head.weight.device if hasattr(self.head, 'weight') else None
  329. dtype = self.head.weight.dtype if hasattr(self.head, 'weight') else None
  330. self.head = nn.Linear(self.embed_dim, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity()
  331. def forward_intermediates(
  332. self,
  333. x: torch.Tensor,
  334. indices: Optional[Union[int, List[int]]] = None,
  335. return_prefix_tokens: bool = False,
  336. norm: bool = False,
  337. stop_early: bool = False,
  338. output_fmt: str = 'NCHW',
  339. intermediates_only: bool = False,
  340. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  341. """ Forward features that returns intermediates.
  342. Args:
  343. x: Input image tensor
  344. indices: Take last n blocks if an int, if is a sequence, select by matching indices
  345. return_prefix_tokens: Return both prefix and spatial intermediate tokens
  346. norm: Apply norm layer to all intermediates
  347. stop_early: Stop iterating over blocks when last desired intermediate hit
  348. output_fmt: Shape of intermediate feature outputs
  349. intermediates_only: Only return intermediate features
  350. Returns:
  351. """
  352. assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
  353. reshape = output_fmt == 'NCHW'
  354. intermediates = []
  355. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  356. # forward pass
  357. B, _, height, width = x.shape
  358. pixel_embed = self.pixel_embed(x, self.pixel_pos)
  359. patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
  360. patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
  361. patch_embed = patch_embed + self.patch_pos
  362. patch_embed = self.pos_drop(patch_embed)
  363. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  364. blocks = self.blocks
  365. else:
  366. blocks = self.blocks[:max_index + 1]
  367. for i, blk in enumerate(blocks):
  368. if self.grad_checkpointing and not torch.jit.is_scripting():
  369. pixel_embed, patch_embed = checkpoint(blk, pixel_embed, patch_embed)
  370. else:
  371. pixel_embed, patch_embed = blk(pixel_embed, patch_embed)
  372. if i in take_indices:
  373. # normalize intermediates with final norm layer if enabled
  374. intermediates.append(self.norm(patch_embed) if norm else patch_embed)
  375. # process intermediates
  376. if self.num_prefix_tokens:
  377. # split prefix (e.g. class, distill) and spatial feature tokens
  378. prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
  379. intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
  380. if reshape:
  381. # reshape to BCHW output format
  382. H, W = self.pixel_embed.dynamic_feat_size((height, width))
  383. intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
  384. if not torch.jit.is_scripting() and return_prefix_tokens:
  385. # return_prefix not support in torchscript due to poor type handling
  386. intermediates = list(zip(intermediates, prefix_tokens))
  387. if intermediates_only:
  388. return intermediates
  389. patch_embed = self.norm(patch_embed)
  390. return patch_embed, intermediates
  391. def prune_intermediate_layers(
  392. self,
  393. indices: Union[int, List[int]] = 1,
  394. prune_norm: bool = False,
  395. prune_head: bool = True,
  396. ):
  397. """ Prune layers not required for specified intermediates.
  398. """
  399. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  400. self.blocks = self.blocks[:max_index + 1] # truncate blocks
  401. if prune_norm:
  402. self.norm = nn.Identity()
  403. if prune_head:
  404. self.reset_classifier(0, '')
  405. return take_indices
  406. def forward_features(self, x):
  407. B = x.shape[0]
  408. pixel_embed = self.pixel_embed(x, self.pixel_pos)
  409. patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
  410. patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
  411. patch_embed = patch_embed + self.patch_pos
  412. patch_embed = self.pos_drop(patch_embed)
  413. for blk in self.blocks:
  414. if self.grad_checkpointing and not torch.jit.is_scripting():
  415. pixel_embed, patch_embed = checkpoint(blk, pixel_embed, patch_embed)
  416. else:
  417. pixel_embed, patch_embed = blk(pixel_embed, patch_embed)
  418. patch_embed = self.norm(patch_embed)
  419. return patch_embed
  420. def forward_head(self, x, pre_logits: bool = False):
  421. if self.global_pool:
  422. x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
  423. x = self.head_drop(x)
  424. return x if pre_logits else self.head(x)
  425. def forward(self, x):
  426. x = self.forward_features(x)
  427. x = self.forward_head(x)
  428. return x
  429. def _cfg(url='', **kwargs):
  430. return {
  431. 'url': url,
  432. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  433. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  434. 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
  435. 'first_conv': 'pixel_embed.proj', 'classifier': 'head',
  436. 'paper_ids': 'arXiv:2103.00112',
  437. 'paper_name': 'Transformer in Transformer',
  438. 'origin_url': 'https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch',
  439. 'license': 'apache-2.0',
  440. **kwargs
  441. }
  442. default_cfgs = generate_default_cfgs({
  443. 'tnt_s_legacy_patch16_224.in1k': _cfg(
  444. hf_hub_id='timm/',
  445. #url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
  446. ),
  447. 'tnt_s_patch16_224.in1k': _cfg(
  448. hf_hub_id='timm/',
  449. #url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
  450. ),
  451. 'tnt_b_patch16_224.in1k': _cfg(
  452. hf_hub_id='timm/',
  453. #url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
  454. ),
  455. })
  456. def checkpoint_filter_fn(state_dict, model):
  457. state_dict.pop('outer_tokens', None)
  458. if 'patch_pos' in state_dict:
  459. out_dict = state_dict
  460. else:
  461. out_dict = {}
  462. for k, v in state_dict.items():
  463. k = k.replace('outer_pos', 'patch_pos')
  464. k = k.replace('inner_pos', 'pixel_pos')
  465. k = k.replace('patch_embed', 'pixel_embed')
  466. k = k.replace('proj_norm1', 'norm1_proj')
  467. k = k.replace('proj_norm2', 'norm2_proj')
  468. k = k.replace('inner_norm1', 'norm_in')
  469. k = k.replace('inner_attn', 'attn_in')
  470. k = k.replace('inner_norm2', 'norm_mlp_in')
  471. k = k.replace('inner_mlp', 'mlp_in')
  472. k = k.replace('outer_norm1', 'norm_out')
  473. k = k.replace('outer_attn', 'attn_out')
  474. k = k.replace('outer_norm2', 'norm_mlp')
  475. k = k.replace('outer_mlp', 'mlp')
  476. if k == 'pixel_pos' and model.pixel_embed.legacy == False:
  477. B, N, C = v.shape
  478. H = W = int(N ** 0.5)
  479. assert H * W == N
  480. v = v.permute(0, 2, 1).reshape(B, C, H, W)
  481. out_dict[k] = v
  482. """ convert patch embedding weight from manual patchify + linear proj to conv"""
  483. if out_dict['patch_pos'].shape != model.patch_pos.shape:
  484. out_dict['patch_pos'] = resample_abs_pos_embed(
  485. out_dict['patch_pos'],
  486. new_size=model.pixel_embed.grid_size,
  487. num_prefix_tokens=1,
  488. )
  489. return out_dict
  490. def _create_tnt(variant, pretrained=False, **kwargs):
  491. out_indices = kwargs.pop('out_indices', 3)
  492. model = build_model_with_cfg(
  493. TNT, variant, pretrained,
  494. pretrained_filter_fn=checkpoint_filter_fn,
  495. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  496. **kwargs)
  497. return model
  498. @register_model
  499. def tnt_s_legacy_patch16_224(pretrained=False, **kwargs) -> TNT:
  500. model_cfg = dict(
  501. patch_size=16, embed_dim=384, inner_dim=24, depth=12, num_heads_outer=6,
  502. qkv_bias=False, legacy=True)
  503. model = _create_tnt('tnt_s_legacy_patch16_224', pretrained=pretrained, **dict(model_cfg, **kwargs))
  504. return model
  505. @register_model
  506. def tnt_s_patch16_224(pretrained=False, **kwargs) -> TNT:
  507. model_cfg = dict(
  508. patch_size=16, embed_dim=384, inner_dim=24, depth=12, num_heads_outer=6,
  509. qkv_bias=False)
  510. model = _create_tnt('tnt_s_patch16_224', pretrained=pretrained, **dict(model_cfg, **kwargs))
  511. return model
  512. @register_model
  513. def tnt_b_patch16_224(pretrained=False, **kwargs) -> TNT:
  514. model_cfg = dict(
  515. patch_size=16, embed_dim=640, inner_dim=40, depth=12, num_heads_outer=10,
  516. qkv_bias=False)
  517. model = _create_tnt('tnt_b_patch16_224', pretrained=pretrained, **dict(model_cfg, **kwargs))
  518. return model