nest.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699
  1. """ Nested Transformer (NesT) in PyTorch
  2. A PyTorch implementation of Aggregating Nested Transformers as described in:
  3. 'Aggregating Nested Transformers'
  4. - https://arxiv.org/abs/2105.12723
  5. The official Jax code is released and available at https://github.com/google-research/nested-transformer. The weights
  6. have been converted with convert/convert_nest_flax.py
  7. Acknowledgments:
  8. * The paper authors for sharing their research, code, and model weights
  9. * Ross Wightman's existing code off which I based this
  10. Copyright 2021 Alexander Soare
  11. """
  12. import collections.abc
  13. import logging
  14. import math
  15. from functools import partial
  16. from typing import List, Optional, Tuple, Type, Union
  17. import torch
  18. import torch.nn.functional as F
  19. from torch import nn
  20. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  21. from timm.layers import (
  22. PatchEmbed,
  23. Mlp,
  24. DropPath,
  25. calculate_drop_path_rates,
  26. create_classifier,
  27. trunc_normal_,
  28. _assert,
  29. create_conv2d,
  30. create_pool2d,
  31. to_ntuple,
  32. use_fused_attn,
  33. LayerNorm,
  34. )
  35. from ._builder import build_model_with_cfg
  36. from ._features import feature_take_indices
  37. from ._features_fx import register_notrace_function
  38. from ._manipulate import checkpoint_seq, named_apply
  39. from ._registry import register_model, generate_default_cfgs, register_model_deprecations
# Explicit public API of this module: only the model class is listed here.
__all__ = ['Nest']  # model_registry will add each entrypoint fn to this
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
  42. class Attention(nn.Module):
  43. """
  44. This is much like `.vision_transformer.Attention` but uses *localised* self attention by accepting an input with
  45. an extra "image block" dim
  46. """
  47. fused_attn: torch.jit.Final[bool]
  48. def __init__(
  49. self,
  50. dim: int,
  51. num_heads: int = 8,
  52. qkv_bias: bool = False,
  53. attn_drop: float = 0.,
  54. proj_drop: float = 0.,
  55. device=None,
  56. dtype=None,
  57. ):
  58. dd = {'device': device, 'dtype': dtype}
  59. super().__init__()
  60. self.num_heads = num_heads
  61. head_dim = dim // num_heads
  62. self.scale = head_dim ** -0.5
  63. self.fused_attn = use_fused_attn()
  64. self.qkv = nn.Linear(dim, 3*dim, bias=qkv_bias, **dd)
  65. self.attn_drop = nn.Dropout(attn_drop)
  66. self.proj = nn.Linear(dim, dim, **dd)
  67. self.proj_drop = nn.Dropout(proj_drop)
  68. def forward(self, x):
  69. """
  70. x is shape: B (batch_size), T (image blocks), N (seq length per image block), C (embed dim)
  71. """
  72. B, T, N, C = x.shape
  73. # result of next line is (qkv, B, num (H)eads, T, N, (C')hannels per head)
  74. qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5)
  75. q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
  76. if self.fused_attn:
  77. x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)
  78. else:
  79. q = q * self.scale
  80. attn = q @ k.transpose(-2, -1) # (B, H, T, N, N)
  81. attn = attn.softmax(dim=-1)
  82. attn = self.attn_drop(attn)
  83. x = attn @ v
  84. # (B, H, T, N, C'), permute -> (B, T, N, C', H)
  85. x = x.permute(0, 2, 3, 4, 1).reshape(B, T, N, C)
  86. x = self.proj(x)
  87. x = self.proj_drop(x)
  88. return x # (B, T, N, C)
  89. class TransformerLayer(nn.Module):
  90. """
  91. This is much like `.vision_transformer.Block` but:
  92. - Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks")
  93. - Uses modified Attention layer that handles the "block" dimension
  94. """
  95. def __init__(
  96. self,
  97. dim: int,
  98. num_heads: int,
  99. mlp_ratio: float = 4.,
  100. qkv_bias: bool = False,
  101. proj_drop: float = 0.,
  102. attn_drop: float = 0.,
  103. drop_path: float = 0.,
  104. act_layer: Type[nn.Module] = nn.GELU,
  105. norm_layer: Type[nn.Module] = nn.LayerNorm,
  106. device=None,
  107. dtype=None,
  108. ):
  109. dd = {'device': device, 'dtype': dtype}
  110. super().__init__()
  111. self.norm1 = norm_layer(dim, **dd)
  112. self.attn = Attention(
  113. dim,
  114. num_heads=num_heads,
  115. qkv_bias=qkv_bias,
  116. attn_drop=attn_drop,
  117. proj_drop=proj_drop,
  118. **dd,
  119. )
  120. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  121. self.norm2 = norm_layer(dim, **dd)
  122. mlp_hidden_dim = int(dim * mlp_ratio)
  123. self.mlp = Mlp(
  124. in_features=dim,
  125. hidden_features=mlp_hidden_dim,
  126. act_layer=act_layer,
  127. drop=proj_drop,
  128. **dd,
  129. )
  130. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  131. def forward(self, x):
  132. y = self.norm1(x)
  133. x = x + self.drop_path1(self.attn(y))
  134. x = x + self.drop_path2(self.mlp(self.norm2(x)))
  135. return x
  136. class ConvPool(nn.Module):
  137. def __init__(
  138. self,
  139. in_channels: int,
  140. out_channels: int,
  141. norm_layer: Type[nn.Module],
  142. pad_type: str = '',
  143. device=None,
  144. dtype=None,
  145. ):
  146. dd = {'device': device, 'dtype': dtype}
  147. super().__init__()
  148. self.conv = create_conv2d(in_channels, out_channels, kernel_size=3, padding=pad_type, bias=True, **dd)
  149. self.norm = norm_layer(out_channels, **dd)
  150. self.pool = create_pool2d('max', kernel_size=3, stride=2, padding=pad_type)
  151. def forward(self, x):
  152. """
  153. x is expected to have shape (B, C, H, W)
  154. """
  155. _assert(x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims')
  156. _assert(x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims')
  157. x = self.conv(x)
  158. # Layer norm done over channel dim only
  159. x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
  160. x = self.pool(x)
  161. return x # (B, C, H//2, W//2)
  162. def blockify(x, block_size: int):
  163. """image to blocks
  164. Args:
  165. x (Tensor): with shape (B, H, W, C)
  166. block_size (int): edge length of a single square block in units of H, W
  167. """
  168. B, H, W, C = x.shape
  169. _assert(H % block_size == 0, '`block_size` must divide input height evenly')
  170. _assert(W % block_size == 0, '`block_size` must divide input width evenly')
  171. grid_height = H // block_size
  172. grid_width = W // block_size
  173. x = x.reshape(B, grid_height, block_size, grid_width, block_size, C)
  174. x = x.transpose(2, 3).reshape(B, grid_height * grid_width, -1, C)
  175. return x # (B, T, N, C)
  176. @register_notrace_function # reason: int receives Proxy
  177. def deblockify(x, block_size: int):
  178. """blocks to image
  179. Args:
  180. x (Tensor): with shape (B, T, N, C) where T is number of blocks and N is sequence size per block
  181. block_size (int): edge length of a single square block in units of desired H, W
  182. """
  183. B, T, _, C = x.shape
  184. grid_size = int(math.sqrt(T))
  185. height = width = grid_size * block_size
  186. x = x.reshape(B, grid_size, grid_size, block_size, block_size, C)
  187. x = x.transpose(2, 3).reshape(B, height, width, C)
  188. return x # (B, H, W, C)
  189. class NestLevel(nn.Module):
  190. """ Single hierarchical level of a Nested Transformer
  191. """
  192. def __init__(
  193. self,
  194. num_blocks: int,
  195. block_size: int,
  196. seq_length: int,
  197. num_heads: int,
  198. depth: int,
  199. embed_dim: int,
  200. prev_embed_dim: Optional[int] = None,
  201. mlp_ratio: float = 4.,
  202. qkv_bias: bool = True,
  203. proj_drop: float = 0.,
  204. attn_drop: float = 0.,
  205. drop_path: Optional[List[float]] = None,
  206. norm_layer: Optional[Type[nn.Module]] = None,
  207. act_layer: Optional[Type[nn.Module]] = None,
  208. pad_type: str = '',
  209. device=None,
  210. dtype=None,
  211. ):
  212. dd = {'device': device, 'dtype': dtype}
  213. super().__init__()
  214. self.block_size = block_size
  215. self.grad_checkpointing = False
  216. self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim, **dd))
  217. if prev_embed_dim is not None:
  218. self.pool = ConvPool(prev_embed_dim, embed_dim, norm_layer=norm_layer, pad_type=pad_type, **dd)
  219. else:
  220. self.pool = nn.Identity()
  221. # Transformer encoder
  222. if len(drop_path):
  223. assert len(drop_path) == depth, 'Must provide as many drop path rates as there are transformer layers'
  224. self.transformer_encoder = nn.Sequential(*[
  225. TransformerLayer(
  226. dim=embed_dim,
  227. num_heads=num_heads,
  228. mlp_ratio=mlp_ratio,
  229. qkv_bias=qkv_bias,
  230. proj_drop=proj_drop,
  231. attn_drop=attn_drop,
  232. drop_path=drop_path[i] if drop_path else None,
  233. norm_layer=norm_layer,
  234. act_layer=act_layer,
  235. **dd,
  236. )
  237. for i in range(depth)])
  238. def forward(self, x):
  239. """
  240. expects x as (B, C, H, W)
  241. """
  242. x = self.pool(x)
  243. x = x.permute(0, 2, 3, 1) # (B, H', W', C), switch to channels last for transformer
  244. x = blockify(x, self.block_size) # (B, T, N, C')
  245. x = x + self.pos_embed
  246. if self.grad_checkpointing and not torch.jit.is_scripting():
  247. x = checkpoint_seq(self.transformer_encoder, x)
  248. else:
  249. x = self.transformer_encoder(x) # (B, T, N, C')
  250. x = deblockify(x, self.block_size) # (B, H', W', C')
  251. # Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
  252. return x.permute(0, 3, 1, 2) # (B, C, H', W')
  253. class Nest(nn.Module):
  254. """ Nested Transformer (NesT)
  255. A PyTorch impl of : `Aggregating Nested Transformers`
  256. - https://arxiv.org/abs/2105.12723
  257. """
  258. def __init__(
  259. self,
  260. img_size: int = 224,
  261. in_chans: int = 3,
  262. patch_size: int = 4,
  263. num_levels: int = 3,
  264. embed_dims: Tuple[int, ...] = (128, 256, 512),
  265. num_heads: Tuple[int, ...] = (4, 8, 16),
  266. depths: Tuple[int, ...] = (2, 2, 20),
  267. num_classes: int = 1000,
  268. mlp_ratio: float = 4.,
  269. qkv_bias: bool = True,
  270. drop_rate: float = 0.,
  271. proj_drop_rate: float = 0.,
  272. attn_drop_rate: float = 0.,
  273. drop_path_rate: float = 0.5,
  274. norm_layer: Optional[Type[nn.Module]] = None,
  275. act_layer: Optional[Type[nn.Module]] = None,
  276. pad_type: str = '',
  277. weight_init: str = '',
  278. global_pool: str = 'avg',
  279. device=None,
  280. dtype=None,
  281. ):
  282. """
  283. Args:
  284. img_size (int, tuple): input image size
  285. in_chans (int): number of input channels
  286. patch_size (int): patch size
  287. num_levels (int): number of block hierarchies (T_d in the paper)
  288. embed_dims (int, tuple): embedding dimensions of each level
  289. num_heads (int, tuple): number of attention heads for each level
  290. depths (int, tuple): number of transformer layers for each level
  291. num_classes (int): number of classes for classification head
  292. mlp_ratio (int): ratio of mlp hidden dim to embedding dim for MLP of transformer layers
  293. qkv_bias (bool): enable bias for qkv if True
  294. drop_rate (float): dropout rate for MLP of transformer layers, MSA final projection layer, and classifier
  295. attn_drop_rate (float): attention dropout rate
  296. drop_path_rate (float): stochastic depth rate
  297. norm_layer: (nn.Module): normalization layer for transformer layers
  298. act_layer: (nn.Module): activation layer in MLP of transformer layers
  299. pad_type: str: Type of padding to use '' for PyTorch symmetric, 'same' for TF SAME
  300. weight_init: (str): weight init scheme
  301. global_pool: (str): type of pooling operation to apply to final feature map
  302. Notes:
  303. - Default values follow NesT-B from the original Jax code.
  304. - `embed_dims`, `num_heads`, `depths` should be ints or tuples with length `num_levels`.
  305. - For those following the paper, Table A1 may have errors!
  306. - https://github.com/google-research/nested-transformer/issues/2
  307. """
  308. super().__init__()
  309. dd = {'device': device, 'dtype': dtype}
  310. for param_name in ['embed_dims', 'num_heads', 'depths']:
  311. param_value = locals()[param_name]
  312. if isinstance(param_value, collections.abc.Sequence):
  313. assert len(param_value) == num_levels, f'Require `len({param_name}) == num_levels`'
  314. embed_dims = to_ntuple(num_levels)(embed_dims)
  315. num_heads = to_ntuple(num_levels)(num_heads)
  316. depths = to_ntuple(num_levels)(depths)
  317. self.num_classes = num_classes
  318. self.in_chans = in_chans
  319. self.num_features = self.head_hidden_size = embed_dims[-1]
  320. self.feature_info = []
  321. norm_layer = norm_layer or LayerNorm
  322. act_layer = act_layer or nn.GELU
  323. self.drop_rate = drop_rate
  324. self.num_levels = num_levels
  325. if isinstance(img_size, collections.abc.Sequence):
  326. assert img_size[0] == img_size[1], 'Model only handles square inputs'
  327. img_size = img_size[0]
  328. assert img_size % patch_size == 0, '`patch_size` must divide `img_size` evenly'
  329. self.patch_size = patch_size
  330. # Number of blocks at each level
  331. self.num_blocks = (4 ** torch.arange(num_levels, device='cpu', dtype=torch.long)).flip(0).tolist()
  332. assert (img_size // patch_size) % math.sqrt(self.num_blocks[0]) == 0, \
  333. 'First level blocks don\'t fit evenly. Check `img_size`, `patch_size`, and `num_levels`'
  334. # Block edge size in units of patches
  335. # Hint: (img_size // patch_size) gives number of patches along edge of image. sqrt(self.num_blocks[0]) is the
  336. # number of blocks along edge of image
  337. self.block_size = int((img_size // patch_size) // math.sqrt(self.num_blocks[0]))
  338. # Patch embedding
  339. self.patch_embed = PatchEmbed(
  340. img_size=img_size,
  341. patch_size=patch_size,
  342. in_chans=in_chans,
  343. embed_dim=embed_dims[0],
  344. flatten=False,
  345. **dd,
  346. )
  347. self.num_patches = self.patch_embed.num_patches
  348. self.seq_length = self.num_patches // self.num_blocks[0]
  349. # Build up each hierarchical level
  350. levels = []
  351. dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
  352. prev_dim = None
  353. curr_stride = 4
  354. for i in range(len(self.num_blocks)):
  355. dim = embed_dims[i]
  356. levels.append(NestLevel(
  357. self.num_blocks[i],
  358. self.block_size,
  359. self.seq_length,
  360. num_heads[i],
  361. depths[i],
  362. dim,
  363. prev_dim,
  364. mlp_ratio=mlp_ratio,
  365. qkv_bias=qkv_bias,
  366. proj_drop=proj_drop_rate,
  367. attn_drop=attn_drop_rate,
  368. drop_path=dp_rates[i],
  369. norm_layer=norm_layer,
  370. act_layer=act_layer,
  371. pad_type=pad_type,
  372. **dd,
  373. ))
  374. self.feature_info += [dict(num_chs=dim, reduction=curr_stride, module=f'levels.{i}')]
  375. prev_dim = dim
  376. curr_stride *= 2
  377. self.levels = nn.Sequential(*levels)
  378. # Final normalization layer
  379. self.norm = norm_layer(embed_dims[-1], **dd)
  380. # Classifier
  381. global_pool, head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool, **dd)
  382. self.global_pool = global_pool
  383. self.head_drop = nn.Dropout(drop_rate)
  384. self.head = head
  385. self.init_weights(weight_init)
  386. @torch.jit.ignore
  387. def init_weights(self, mode=''):
  388. assert mode in ('nlhb', '')
  389. head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
  390. for level in self.levels:
  391. trunc_normal_(level.pos_embed, std=.02, a=-2, b=2)
  392. named_apply(partial(_init_nest_weights, head_bias=head_bias), self)
  393. @torch.jit.ignore
  394. def no_weight_decay(self):
  395. return {f'level.{i}.pos_embed' for i in range(len(self.levels))}
  396. @torch.jit.ignore
  397. def group_matcher(self, coarse=False):
  398. matcher = dict(
  399. stem=r'^patch_embed', # stem and embed
  400. blocks=[
  401. (r'^levels\.(\d+)' if coarse else r'^levels\.(\d+)\.transformer_encoder\.(\d+)', None),
  402. (r'^levels\.(\d+)\.(?:pool|pos_embed)', (0,)),
  403. (r'^norm', (99999,))
  404. ]
  405. )
  406. return matcher
  407. @torch.jit.ignore
  408. def set_grad_checkpointing(self, enable=True):
  409. for l in self.levels:
  410. l.grad_checkpointing = enable
  411. @torch.jit.ignore
  412. def get_classifier(self) -> nn.Module:
  413. return self.head
  414. def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
  415. self.num_classes = num_classes
  416. self.global_pool, self.head = create_classifier(
  417. self.num_features, self.num_classes, pool_type=global_pool)
  418. def forward_intermediates(
  419. self,
  420. x: torch.Tensor,
  421. indices: Optional[Union[int, List[int]]] = None,
  422. norm: bool = False,
  423. stop_early: bool = False,
  424. output_fmt: str = 'NCHW',
  425. intermediates_only: bool = False,
  426. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  427. """ Forward features that returns intermediates.
  428. Args:
  429. x: Input image tensor
  430. indices: Take last n blocks if int, all if None, select matching indices if sequence
  431. norm: Apply norm layer to compatible intermediates
  432. stop_early: Stop iterating over blocks when last desired intermediate hit
  433. output_fmt: Shape of intermediate feature outputs
  434. intermediates_only: Only return intermediate features
  435. Returns:
  436. """
  437. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  438. intermediates = []
  439. take_indices, max_index = feature_take_indices(len(self.levels), indices)
  440. # forward pass
  441. x = self.patch_embed(x)
  442. last_idx = len(self.num_blocks) - 1
  443. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  444. stages = self.levels
  445. else:
  446. stages = self.levels[:max_index + 1]
  447. for feat_idx, stage in enumerate(stages):
  448. x = stage(x)
  449. if feat_idx in take_indices:
  450. if norm and feat_idx == last_idx:
  451. x_inter = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
  452. intermediates.append(x_inter)
  453. else:
  454. intermediates.append(x)
  455. if intermediates_only:
  456. return intermediates
  457. if feat_idx == last_idx:
  458. # Layer norm done over channel dim only (to NHWC and back)
  459. x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
  460. return x, intermediates
  461. def prune_intermediate_layers(
  462. self,
  463. indices: Union[int, List[int]] = 1,
  464. prune_norm: bool = False,
  465. prune_head: bool = True,
  466. ):
  467. """ Prune layers not required for specified intermediates.
  468. """
  469. take_indices, max_index = feature_take_indices(len(self.levels), indices)
  470. self.levels = self.levels[:max_index + 1] # truncate blocks w/ stem as idx 0
  471. if prune_norm:
  472. self.norm = nn.Identity()
  473. if prune_head:
  474. self.reset_classifier(0, '')
  475. return take_indices
  476. def forward_features(self, x):
  477. x = self.patch_embed(x)
  478. x = self.levels(x)
  479. # Layer norm done over channel dim only (to NHWC and back)
  480. x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
  481. return x
  482. def forward_head(self, x, pre_logits: bool = False):
  483. x = self.global_pool(x)
  484. x = self.head_drop(x)
  485. return x if pre_logits else self.head(x)
  486. def forward(self, x):
  487. x = self.forward_features(x)
  488. x = self.forward_head(x)
  489. return x
  490. def _init_nest_weights(module: nn.Module, name: str = '', head_bias: float = 0.):
  491. """ NesT weight initialization
  492. Can replicate Jax implementation. Otherwise follows vision_transformer.py
  493. """
  494. if isinstance(module, nn.Linear):
  495. if name.startswith('head'):
  496. trunc_normal_(module.weight, std=.02, a=-2, b=2)
  497. nn.init.constant_(module.bias, head_bias)
  498. else:
  499. trunc_normal_(module.weight, std=.02, a=-2, b=2)
  500. if module.bias is not None:
  501. nn.init.zeros_(module.bias)
  502. elif isinstance(module, nn.Conv2d):
  503. trunc_normal_(module.weight, std=.02, a=-2, b=2)
  504. if module.bias is not None:
  505. nn.init.zeros_(module.bias)
  506. def resize_pos_embed(posemb, posemb_new):
  507. """
  508. Rescale the grid of position embeddings when loading from state_dict
  509. Expected shape of position embeddings is (1, T, N, C), and considers only square images
  510. """
  511. _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
  512. seq_length_old = posemb.shape[2]
  513. num_blocks_new, seq_length_new = posemb_new.shape[1:3]
  514. size_new = int(math.sqrt(num_blocks_new*seq_length_new))
  515. # First change to (1, C, H, W)
  516. posemb = deblockify(posemb, int(math.sqrt(seq_length_old))).permute(0, 3, 1, 2)
  517. posemb = F.interpolate(posemb, size=[size_new, size_new], mode='bicubic', align_corners=False)
  518. # Now change to new (1, T, N, C)
  519. posemb = blockify(posemb.permute(0, 2, 3, 1), int(math.sqrt(seq_length_new)))
  520. return posemb
  521. def checkpoint_filter_fn(state_dict, model):
  522. """ resize positional embeddings of pretrained weights """
  523. pos_embed_keys = [k for k in state_dict.keys() if k.startswith('pos_embed_')]
  524. for k in pos_embed_keys:
  525. if state_dict[k].shape != getattr(model, k).shape:
  526. state_dict[k] = resize_pos_embed(state_dict[k], getattr(model, k))
  527. return state_dict
  528. def _create_nest(variant, pretrained=False, **kwargs):
  529. model = build_model_with_cfg(
  530. Nest,
  531. variant,
  532. pretrained,
  533. feature_cfg=dict(out_indices=(0, 1, 2), flatten_sequential=True),
  534. pretrained_filter_fn=checkpoint_filter_fn,
  535. **kwargs,
  536. )
  537. return model
  538. def _cfg(url='', **kwargs):
  539. return {
  540. 'url': url,
  541. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': [14, 14],
  542. 'crop_pct': .875, 'interpolation': 'bicubic', 'fixed_input_size': True,
  543. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  544. 'first_conv': 'patch_embed.proj', 'classifier': 'head',
  545. 'license': 'apache-2.0',
  546. **kwargs
  547. }
# Pretrained configs keyed by `<model name>.<tag>`. The `.untrained` entries use the bare
# defaults from `_cfg()`; the `_jx` entries additionally set `hf_hub_id`.
default_cfgs = generate_default_cfgs({
    'nest_base.untrained': _cfg(),
    'nest_small.untrained': _cfg(),
    'nest_tiny.untrained': _cfg(),
    # (weights from official Google JAX impl, require 'SAME' padding)
    'nest_base_jx.goog_in1k': _cfg(hf_hub_id='timm/'),
    'nest_small_jx.goog_in1k': _cfg(hf_hub_id='timm/'),
    'nest_tiny_jx.goog_in1k': _cfg(hf_hub_id='timm/'),
})
  557. @register_model
  558. def nest_base(pretrained=False, **kwargs) -> Nest:
  559. """ Nest-B @ 224x224
  560. """
  561. model_kwargs = dict(
  562. embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), **kwargs)
  563. model = _create_nest('nest_base', pretrained=pretrained, **model_kwargs)
  564. return model
  565. @register_model
  566. def nest_small(pretrained=False, **kwargs) -> Nest:
  567. """ Nest-S @ 224x224
  568. """
  569. model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), **kwargs)
  570. model = _create_nest('nest_small', pretrained=pretrained, **model_kwargs)
  571. return model
  572. @register_model
  573. def nest_tiny(pretrained=False, **kwargs) -> Nest:
  574. """ Nest-T @ 224x224
  575. """
  576. model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), **kwargs)
  577. model = _create_nest('nest_tiny', pretrained=pretrained, **model_kwargs)
  578. return model
  579. @register_model
  580. def nest_base_jx(pretrained=False, **kwargs) -> Nest:
  581. """ Nest-B @ 224x224
  582. """
  583. kwargs.setdefault('pad_type', 'same')
  584. model_kwargs = dict(
  585. embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), **kwargs)
  586. model = _create_nest('nest_base_jx', pretrained=pretrained, **model_kwargs)
  587. return model
  588. @register_model
  589. def nest_small_jx(pretrained=False, **kwargs) -> Nest:
  590. """ Nest-S @ 224x224
  591. """
  592. kwargs.setdefault('pad_type', 'same')
  593. model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), **kwargs)
  594. model = _create_nest('nest_small_jx', pretrained=pretrained, **model_kwargs)
  595. return model
  596. @register_model
  597. def nest_tiny_jx(pretrained=False, **kwargs) -> Nest:
  598. """ Nest-T @ 224x224
  599. """
  600. kwargs.setdefault('pad_type', 'same')
  601. model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), **kwargs)
  602. model = _create_nest('nest_tiny_jx', pretrained=pretrained, **model_kwargs)
  603. return model
# Map the old `jx_nest_*` names onto the current `nest_*_jx` entrypoints.
register_model_deprecations(__name__, {
    'jx_nest_base': 'nest_base_jx',
    'jx_nest_small': 'nest_small_jx',
    'jx_nest_tiny': 'nest_tiny_jx',
})