pvt_v2.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594
  1. """ Pyramid Vision Transformer v2
  2. @misc{wang2021pvtv2,
  3. title={PVTv2: Improved Baselines with Pyramid Vision Transformer},
  4. author={Wenhai Wang and Enze Xie and Xiang Li and Deng-Ping Fan and Kaitao Song and Ding Liang and
  5. Tong Lu and Ping Luo and Ling Shao},
  6. year={2021},
  7. eprint={2106.13797},
  8. archivePrefix={arXiv},
  9. primaryClass={cs.CV}
  10. }
  11. Based on Apache 2.0 licensed code at https://github.com/whai362/PVT
  12. Modifications and timm support by / Copyright 2022, Ross Wightman
  13. """
  14. import math
  15. from typing import Callable, List, Optional, Tuple, Union, Type, Any
  16. import torch
  17. import torch.nn as nn
  18. import torch.nn.functional as F
  19. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  20. from timm.layers import DropPath, calculate_drop_path_rates, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn
  21. from ._builder import build_model_with_cfg
  22. from ._features import feature_take_indices
  23. from ._manipulate import checkpoint
  24. from ._registry import register_model, generate_default_cfgs
  25. __all__ = ['PyramidVisionTransformerV2']
  26. class MlpWithDepthwiseConv(nn.Module):
  27. def __init__(
  28. self,
  29. in_features: int,
  30. hidden_features: Optional[int] = None,
  31. out_features: Optional[int] = None,
  32. act_layer: Type[nn.Module] = nn.GELU,
  33. drop: float = 0.,
  34. extra_relu: bool = False,
  35. device=None,
  36. dtype=None,
  37. ):
  38. super().__init__()
  39. dd = {'device': device, 'dtype': dtype}
  40. out_features = out_features or in_features
  41. hidden_features = hidden_features or in_features
  42. self.fc1 = nn.Linear(in_features, hidden_features, **dd)
  43. self.relu = nn.ReLU() if extra_relu else nn.Identity()
  44. self.dwconv = nn.Conv2d(hidden_features, hidden_features, 3, 1, 1, bias=True, groups=hidden_features, **dd)
  45. self.act = act_layer()
  46. self.fc2 = nn.Linear(hidden_features, out_features, **dd)
  47. self.drop = nn.Dropout(drop)
  48. def forward(self, x, feat_size: List[int]):
  49. x = self.fc1(x)
  50. B, N, C = x.shape
  51. x = x.transpose(1, 2).view(B, C, feat_size[0], feat_size[1])
  52. x = self.relu(x)
  53. x = self.dwconv(x)
  54. x = x.flatten(2).transpose(1, 2)
  55. x = self.act(x)
  56. x = self.drop(x)
  57. x = self.fc2(x)
  58. x = self.drop(x)
  59. return x
  60. class Attention(nn.Module):
  61. fused_attn: torch.jit.Final[bool]
  62. def __init__(
  63. self,
  64. dim: int,
  65. num_heads: int = 8,
  66. sr_ratio: int = 1,
  67. linear_attn: bool = False,
  68. qkv_bias: bool = True,
  69. attn_drop: float = 0.,
  70. proj_drop: float = 0.,
  71. device=None,
  72. dtype=None,
  73. ):
  74. super().__init__()
  75. dd = {'device': device, 'dtype': dtype}
  76. assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
  77. self.dim = dim
  78. self.num_heads = num_heads
  79. self.head_dim = dim // num_heads
  80. self.scale = self.head_dim ** -0.5
  81. self.fused_attn = use_fused_attn()
  82. self.q = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  83. self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias, **dd)
  84. self.attn_drop = nn.Dropout(attn_drop)
  85. self.proj = nn.Linear(dim, dim, **dd)
  86. self.proj_drop = nn.Dropout(proj_drop)
  87. if not linear_attn:
  88. self.pool = None
  89. if sr_ratio > 1:
  90. self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio, **dd)
  91. self.norm = nn.LayerNorm(dim, **dd)
  92. else:
  93. self.sr = None
  94. self.norm = None
  95. self.act = None
  96. else:
  97. self.pool = nn.AdaptiveAvgPool2d(7)
  98. self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1, **dd)
  99. self.norm = nn.LayerNorm(dim, **dd)
  100. self.act = nn.GELU()
  101. def forward(self, x, feat_size: List[int]):
  102. B, N, C = x.shape
  103. H, W = feat_size
  104. q = self.q(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
  105. if self.pool is not None:
  106. x = x.permute(0, 2, 1).reshape(B, C, H, W)
  107. x = self.sr(self.pool(x)).reshape(B, C, -1).permute(0, 2, 1)
  108. x = self.norm(x)
  109. x = self.act(x)
  110. kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
  111. else:
  112. if self.sr is not None:
  113. x = x.permute(0, 2, 1).reshape(B, C, H, W)
  114. x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)
  115. x = self.norm(x)
  116. kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
  117. else:
  118. kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
  119. k, v = kv.unbind(0)
  120. if self.fused_attn:
  121. x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)
  122. else:
  123. q = q * self.scale
  124. attn = q @ k.transpose(-2, -1)
  125. attn = attn.softmax(dim=-1)
  126. attn = self.attn_drop(attn)
  127. x = attn @ v
  128. x = x.transpose(1, 2).reshape(B, N, C)
  129. x = self.proj(x)
  130. x = self.proj_drop(x)
  131. return x
  132. class Block(nn.Module):
  133. def __init__(
  134. self,
  135. dim: int,
  136. num_heads: int,
  137. mlp_ratio: float = 4.,
  138. sr_ratio: int = 1,
  139. linear_attn: bool = False,
  140. qkv_bias: bool = False,
  141. proj_drop: float = 0.,
  142. attn_drop: float = 0.,
  143. drop_path: float = 0.,
  144. act_layer: Type[nn.Module] = nn.GELU,
  145. norm_layer: Type[nn.Module] = LayerNorm,
  146. device=None,
  147. dtype=None,
  148. ):
  149. super().__init__()
  150. dd = {'device': device, 'dtype': dtype}
  151. self.norm1 = norm_layer(dim, **dd)
  152. self.attn = Attention(
  153. dim,
  154. num_heads=num_heads,
  155. sr_ratio=sr_ratio,
  156. linear_attn=linear_attn,
  157. qkv_bias=qkv_bias,
  158. attn_drop=attn_drop,
  159. proj_drop=proj_drop,
  160. **dd,
  161. )
  162. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  163. self.norm2 = norm_layer(dim, **dd)
  164. self.mlp = MlpWithDepthwiseConv(
  165. in_features=dim,
  166. hidden_features=int(dim * mlp_ratio),
  167. act_layer=act_layer,
  168. drop=proj_drop,
  169. extra_relu=linear_attn,
  170. **dd,
  171. )
  172. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  173. def forward(self, x, feat_size: List[int]):
  174. x = x + self.drop_path1(self.attn(self.norm1(x), feat_size))
  175. x = x + self.drop_path2(self.mlp(self.norm2(x), feat_size))
  176. return x
  177. class OverlapPatchEmbed(nn.Module):
  178. """ Image to Patch Embedding
  179. """
  180. def __init__(
  181. self,
  182. patch_size: Union[int, Tuple[int, int]] = 7,
  183. stride: int = 4,
  184. in_chans: int = 3,
  185. embed_dim: int = 768,
  186. device=None,
  187. dtype=None,
  188. ):
  189. super().__init__()
  190. dd = {'device': device, 'dtype': dtype}
  191. patch_size = to_2tuple(patch_size)
  192. assert max(patch_size) > stride, "Set larger patch_size than stride"
  193. self.patch_size = patch_size
  194. self.proj = nn.Conv2d(
  195. in_chans, embed_dim, patch_size,
  196. stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2), **dd)
  197. self.norm = nn.LayerNorm(embed_dim, **dd)
  198. def forward(self, x):
  199. x = self.proj(x)
  200. x = x.permute(0, 2, 3, 1)
  201. x = self.norm(x)
  202. return x
  203. class PyramidVisionTransformerStage(nn.Module):
  204. def __init__(
  205. self,
  206. dim: int,
  207. dim_out: int,
  208. depth: int,
  209. downsample: bool = True,
  210. num_heads: int = 8,
  211. sr_ratio: int = 1,
  212. linear_attn: bool = False,
  213. mlp_ratio: float = 4.0,
  214. qkv_bias: bool = True,
  215. proj_drop: float = 0.,
  216. attn_drop: float = 0.,
  217. drop_path: Union[List[float], float] = 0.0,
  218. norm_layer: Callable = LayerNorm,
  219. device=None,
  220. dtype=None,
  221. ):
  222. super().__init__()
  223. dd = {'device': device, 'dtype': dtype}
  224. self.grad_checkpointing = False
  225. if downsample:
  226. self.downsample = OverlapPatchEmbed(
  227. patch_size=3,
  228. stride=2,
  229. in_chans=dim,
  230. embed_dim=dim_out,
  231. **dd,
  232. )
  233. else:
  234. assert dim == dim_out
  235. self.downsample = None
  236. self.blocks = nn.ModuleList([Block(
  237. dim=dim_out,
  238. num_heads=num_heads,
  239. sr_ratio=sr_ratio,
  240. linear_attn=linear_attn,
  241. mlp_ratio=mlp_ratio,
  242. qkv_bias=qkv_bias,
  243. proj_drop=proj_drop,
  244. attn_drop=attn_drop,
  245. drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
  246. norm_layer=norm_layer,
  247. **dd,
  248. ) for i in range(depth)])
  249. self.norm = norm_layer(dim_out, **dd)
  250. def forward(self, x):
  251. # x is either B, C, H, W (if downsample) or B, H, W, C if not
  252. if self.downsample is not None:
  253. # input to downsample is B, C, H, W
  254. x = self.downsample(x) # output B, H, W, C
  255. B, H, W, C = x.shape
  256. feat_size = (H, W)
  257. x = x.reshape(B, -1, C)
  258. for blk in self.blocks:
  259. if self.grad_checkpointing and not torch.jit.is_scripting():
  260. x = checkpoint(blk, x, feat_size)
  261. else:
  262. x = blk(x, feat_size)
  263. x = self.norm(x)
  264. x = x.reshape(B, feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2).contiguous()
  265. return x
  266. class PyramidVisionTransformerV2(nn.Module):
  267. def __init__(
  268. self,
  269. in_chans: int = 3,
  270. num_classes: int = 1000,
  271. global_pool: str = 'avg',
  272. depths: Tuple[int, ...] = (3, 4, 6, 3),
  273. embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
  274. num_heads: Tuple[int, ...] = (1, 2, 4, 8),
  275. sr_ratios: Tuple[int, ...] = (8, 4, 2, 1),
  276. mlp_ratios: Tuple[float, ...] = (8., 8., 4., 4.),
  277. qkv_bias: bool = True,
  278. linear: bool = False,
  279. drop_rate: float = 0.,
  280. proj_drop_rate: float = 0.,
  281. attn_drop_rate: float = 0.,
  282. drop_path_rate: float = 0.,
  283. norm_layer: Type[nn.Module] = LayerNorm,
  284. device=None,
  285. dtype=None,
  286. ):
  287. super().__init__()
  288. dd = {'device': device, 'dtype': dtype}
  289. self.num_classes = num_classes
  290. self.in_chans = in_chans
  291. assert global_pool in ('avg', '')
  292. self.global_pool = global_pool
  293. self.depths = depths
  294. num_stages = len(depths)
  295. mlp_ratios = to_ntuple(num_stages)(mlp_ratios)
  296. num_heads = to_ntuple(num_stages)(num_heads)
  297. sr_ratios = to_ntuple(num_stages)(sr_ratios)
  298. assert(len(embed_dims)) == num_stages
  299. self.feature_info = []
  300. self.patch_embed = OverlapPatchEmbed(
  301. patch_size=7,
  302. stride=4,
  303. in_chans=in_chans,
  304. embed_dim=embed_dims[0],
  305. **dd,
  306. )
  307. dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
  308. cur = 0
  309. prev_dim = embed_dims[0]
  310. stages = []
  311. for i in range(num_stages):
  312. stages += [PyramidVisionTransformerStage(
  313. dim=prev_dim,
  314. dim_out=embed_dims[i],
  315. depth=depths[i],
  316. downsample=i > 0,
  317. num_heads=num_heads[i],
  318. sr_ratio=sr_ratios[i],
  319. mlp_ratio=mlp_ratios[i],
  320. linear_attn=linear,
  321. qkv_bias=qkv_bias,
  322. proj_drop=proj_drop_rate,
  323. attn_drop=attn_drop_rate,
  324. drop_path=dpr[i],
  325. norm_layer=norm_layer,
  326. **dd,
  327. )]
  328. prev_dim = embed_dims[i]
  329. cur += depths[i]
  330. self.feature_info += [dict(num_chs=prev_dim, reduction=4 * 2**i, module=f'stages.{i}')]
  331. self.stages = nn.Sequential(*stages)
  332. # classification head
  333. self.num_features = self.head_hidden_size = embed_dims[-1]
  334. self.head_drop = nn.Dropout(drop_rate)
  335. self.head = nn.Linear(embed_dims[-1], num_classes, **dd) if num_classes > 0 else nn.Identity()
  336. self.apply(self._init_weights)
  337. def _init_weights(self, m):
  338. if isinstance(m, nn.Linear):
  339. trunc_normal_(m.weight, std=.02)
  340. if isinstance(m, nn.Linear) and m.bias is not None:
  341. nn.init.constant_(m.bias, 0)
  342. elif isinstance(m, nn.Conv2d):
  343. fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  344. fan_out //= m.groups
  345. m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
  346. if m.bias is not None:
  347. m.bias.data.zero_()
  348. def freeze_patch_emb(self):
  349. self.patch_embed.requires_grad = False
  350. @torch.jit.ignore
  351. def no_weight_decay(self):
  352. return {}
  353. @torch.jit.ignore
  354. def group_matcher(self, coarse=False):
  355. matcher = dict(
  356. stem=r'^patch_embed', # stem and embed
  357. blocks=r'^stages\.(\d+)'
  358. )
  359. return matcher
  360. @torch.jit.ignore
  361. def set_grad_checkpointing(self, enable=True):
  362. for s in self.stages:
  363. s.grad_checkpointing = enable
  364. def get_classifier(self) -> nn.Module:
  365. return self.head
  366. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  367. self.num_classes = num_classes
  368. if global_pool is not None:
  369. assert global_pool in ('avg', '')
  370. self.global_pool = global_pool
  371. device = self.head.weight.device if hasattr(self.head, 'weight') else None
  372. dtype = self.head.weight.dtype if hasattr(self.head, 'weight') else None
  373. self.head = nn.Linear(self.num_features, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity()
  374. def forward_intermediates(
  375. self,
  376. x: torch.Tensor,
  377. indices: Optional[Union[int, List[int]]] = None,
  378. norm: bool = False,
  379. stop_early: bool = False,
  380. output_fmt: str = 'NCHW',
  381. intermediates_only: bool = False,
  382. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  383. """ Forward features that returns intermediates.
  384. Args:
  385. x: Input image tensor
  386. indices: Take last n blocks if int, all if None, select matching indices if sequence
  387. norm: Apply norm layer to compatible intermediates
  388. stop_early: Stop iterating over blocks when last desired intermediate hit
  389. output_fmt: Shape of intermediate feature outputs
  390. intermediates_only: Only return intermediate features
  391. Returns:
  392. """
  393. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  394. intermediates = []
  395. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  396. # forward pass
  397. x = self.patch_embed(x)
  398. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  399. stages = self.stages
  400. else:
  401. stages = self.stages[:max_index + 1]
  402. for feat_idx, stage in enumerate(stages):
  403. x = stage(x)
  404. if feat_idx in take_indices:
  405. intermediates.append(x)
  406. if intermediates_only:
  407. return intermediates
  408. return x, intermediates
  409. def prune_intermediate_layers(
  410. self,
  411. indices: Union[int, List[int]] = 1,
  412. prune_norm: bool = False,
  413. prune_head: bool = True,
  414. ):
  415. """ Prune layers not required for specified intermediates.
  416. """
  417. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  418. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  419. if prune_head:
  420. self.reset_classifier(0, '')
  421. return take_indices
  422. def forward_features(self, x):
  423. x = self.patch_embed(x)
  424. x = self.stages(x)
  425. return x
  426. def forward_head(self, x, pre_logits: bool = False):
  427. if self.global_pool:
  428. x = x.mean(dim=(-1, -2))
  429. x = self.head_drop(x)
  430. return x if pre_logits else self.head(x)
  431. def forward(self, x):
  432. x = self.forward_features(x)
  433. x = self.forward_head(x)
  434. return x
  435. def checkpoint_filter_fn(state_dict, model):
  436. """ Remap original checkpoints -> timm """
  437. if 'patch_embed.proj.weight' in state_dict:
  438. return state_dict # non-original checkpoint, no remapping needed
  439. out_dict = {}
  440. import re
  441. for k, v in state_dict.items():
  442. if k.startswith('patch_embed'):
  443. k = k.replace('patch_embed1', 'patch_embed')
  444. k = k.replace('patch_embed2', 'stages.1.downsample')
  445. k = k.replace('patch_embed3', 'stages.2.downsample')
  446. k = k.replace('patch_embed4', 'stages.3.downsample')
  447. k = k.replace('dwconv.dwconv', 'dwconv')
  448. k = re.sub(r'block(\d+).(\d+)', lambda x: f'stages.{int(x.group(1)) - 1}.blocks.{x.group(2)}', k)
  449. k = re.sub(r'^norm(\d+)', lambda x: f'stages.{int(x.group(1)) - 1}.norm', k)
  450. out_dict[k] = v
  451. return out_dict
  452. def _create_pvt2(variant, pretrained=False, **kwargs):
  453. default_out_indices = tuple(range(4))
  454. out_indices = kwargs.pop('out_indices', default_out_indices)
  455. model = build_model_with_cfg(
  456. PyramidVisionTransformerV2,
  457. variant,
  458. pretrained,
  459. pretrained_filter_fn=checkpoint_filter_fn,
  460. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  461. **kwargs,
  462. )
  463. return model
  464. def _cfg(url='', **kwargs):
  465. return {
  466. 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  467. 'crop_pct': 0.9, 'interpolation': 'bicubic',
  468. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  469. 'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': False,
  470. 'license': 'apache-2.0',
  471. **kwargs
  472. }
  473. default_cfgs = generate_default_cfgs({
  474. 'pvt_v2_b0.in1k': _cfg(hf_hub_id='timm/'),
  475. 'pvt_v2_b1.in1k': _cfg(hf_hub_id='timm/'),
  476. 'pvt_v2_b2.in1k': _cfg(hf_hub_id='timm/'),
  477. 'pvt_v2_b3.in1k': _cfg(hf_hub_id='timm/'),
  478. 'pvt_v2_b4.in1k': _cfg(hf_hub_id='timm/'),
  479. 'pvt_v2_b5.in1k': _cfg(hf_hub_id='timm/'),
  480. 'pvt_v2_b2_li.in1k': _cfg(hf_hub_id='timm/'),
  481. })
  482. @register_model
  483. def pvt_v2_b0(pretrained=False, **kwargs) -> PyramidVisionTransformerV2:
  484. model_args = dict(depths=(2, 2, 2, 2), embed_dims=(32, 64, 160, 256), num_heads=(1, 2, 5, 8))
  485. return _create_pvt2('pvt_v2_b0', pretrained=pretrained, **dict(model_args, **kwargs))
  486. @register_model
  487. def pvt_v2_b1(pretrained=False, **kwargs) -> PyramidVisionTransformerV2:
  488. model_args = dict(depths=(2, 2, 2, 2), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8))
  489. return _create_pvt2('pvt_v2_b1', pretrained=pretrained, **dict(model_args, **kwargs))
  490. @register_model
  491. def pvt_v2_b2(pretrained=False, **kwargs) -> PyramidVisionTransformerV2:
  492. model_args = dict(depths=(3, 4, 6, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8))
  493. return _create_pvt2('pvt_v2_b2', pretrained=pretrained, **dict(model_args, **kwargs))
  494. @register_model
  495. def pvt_v2_b3(pretrained=False, **kwargs) -> PyramidVisionTransformerV2:
  496. model_args = dict(depths=(3, 4, 18, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8))
  497. return _create_pvt2('pvt_v2_b3', pretrained=pretrained, **dict(model_args, **kwargs))
  498. @register_model
  499. def pvt_v2_b4(pretrained=False, **kwargs) -> PyramidVisionTransformerV2:
  500. model_args = dict(depths=(3, 8, 27, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8))
  501. return _create_pvt2('pvt_v2_b4', pretrained=pretrained, **dict(model_args, **kwargs))
  502. @register_model
  503. def pvt_v2_b5(pretrained=False, **kwargs) -> PyramidVisionTransformerV2:
  504. model_args = dict(
  505. depths=(3, 6, 40, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), mlp_ratios=(4, 4, 4, 4))
  506. return _create_pvt2('pvt_v2_b5', pretrained=pretrained, **dict(model_args, **kwargs))
  507. @register_model
  508. def pvt_v2_b2_li(pretrained=False, **kwargs) -> PyramidVisionTransformerV2:
  509. model_args = dict(
  510. depths=(3, 4, 6, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), linear=True)
  511. return _create_pvt2('pvt_v2_b2_li', pretrained=pretrained, **dict(model_args, **kwargs))