pit.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555
  1. """ Pooling-based Vision Transformer (PiT) in PyTorch
  2. A PyTorch implementation of Pooling-based Vision Transformers as described in
  3. 'Rethinking Spatial Dimensions of Vision Transformers' - https://arxiv.org/abs/2103.16302
  4. This code was adapted from the original version at https://github.com/naver-ai/pit, original copyright below.
  5. Modifications for timm by / Copyright 2020 Ross Wightman
  6. """
  7. # PiT
  8. # Copyright 2021-present NAVER Corp.
  9. # Apache License v2.0
  10. import math
  11. import re
  12. from functools import partial
  13. from typing import List, Optional, Sequence, Tuple, Union, Type, Any
  14. import torch
  15. from torch import nn
  16. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  17. from timm.layers import trunc_normal_, to_2tuple, calculate_drop_path_rates
  18. from ._builder import build_model_with_cfg
  19. from ._features import feature_take_indices
  20. from ._registry import register_model, generate_default_cfgs
  21. from .vision_transformer import Block
  22. __all__ = ['PoolingVisionTransformer'] # model_registry will add each entrypoint fn to this
  23. class SequentialTuple(nn.Sequential):
  24. """ This module exists to work around torchscript typing issues list -> list"""
  25. def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
  26. for module in self:
  27. x = module(x)
  28. return x
  29. class Transformer(nn.Module):
  30. def __init__(
  31. self,
  32. base_dim: int,
  33. depth: int,
  34. heads: int,
  35. mlp_ratio: float,
  36. pool: Optional[Any] = None,
  37. proj_drop: float = .0,
  38. attn_drop: float = .0,
  39. drop_path_prob: Optional[List[float]] = None,
  40. norm_layer: Optional[Type[nn.Module]] = None,
  41. device=None,
  42. dtype=None,
  43. ):
  44. dd = {'device': device, 'dtype': dtype}
  45. super().__init__()
  46. embed_dim = base_dim * heads
  47. self.pool = pool
  48. self.norm = norm_layer(embed_dim, **dd) if norm_layer else nn.Identity()
  49. self.blocks = nn.Sequential(*[
  50. Block(
  51. dim=embed_dim,
  52. num_heads=heads,
  53. mlp_ratio=mlp_ratio,
  54. qkv_bias=True,
  55. proj_drop=proj_drop,
  56. attn_drop=attn_drop,
  57. drop_path=drop_path_prob[i],
  58. norm_layer=partial(nn.LayerNorm, eps=1e-6),
  59. **dd,
  60. )
  61. for i in range(depth)])
  62. def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
  63. x, cls_tokens = x
  64. token_length = cls_tokens.shape[1]
  65. if self.pool is not None:
  66. x, cls_tokens = self.pool(x, cls_tokens)
  67. B, C, H, W = x.shape
  68. x = x.flatten(2).transpose(1, 2)
  69. x = torch.cat((cls_tokens, x), dim=1)
  70. x = self.norm(x)
  71. x = self.blocks(x)
  72. cls_tokens = x[:, :token_length]
  73. x = x[:, token_length:]
  74. x = x.transpose(1, 2).reshape(B, C, H, W)
  75. return x, cls_tokens
  76. class Pooling(nn.Module):
  77. def __init__(
  78. self,
  79. in_feature: int,
  80. out_feature: int,
  81. stride: int,
  82. padding_mode: str = 'zeros',
  83. device=None,
  84. dtype=None,
  85. ):
  86. dd = {'device': device, 'dtype': dtype}
  87. super().__init__()
  88. self.conv = nn.Conv2d(
  89. in_feature,
  90. out_feature,
  91. kernel_size=stride + 1,
  92. padding=stride // 2,
  93. stride=stride,
  94. padding_mode=padding_mode,
  95. groups=in_feature,
  96. **dd,
  97. )
  98. self.fc = nn.Linear(in_feature, out_feature, **dd)
  99. def forward(self, x, cls_token) -> Tuple[torch.Tensor, torch.Tensor]:
  100. x = self.conv(x)
  101. cls_token = self.fc(cls_token)
  102. return x, cls_token
  103. class ConvEmbedding(nn.Module):
  104. def __init__(
  105. self,
  106. in_channels: int,
  107. out_channels: int,
  108. img_size: int = 224,
  109. patch_size: int = 16,
  110. stride: int = 8,
  111. padding: int = 0,
  112. device=None,
  113. dtype=None,
  114. ):
  115. dd = {'device': device, 'dtype': dtype}
  116. super().__init__()
  117. padding = padding
  118. self.img_size = to_2tuple(img_size)
  119. self.patch_size = to_2tuple(patch_size)
  120. self.height = math.floor((self.img_size[0] + 2 * padding - self.patch_size[0]) / stride + 1)
  121. self.width = math.floor((self.img_size[1] + 2 * padding - self.patch_size[1]) / stride + 1)
  122. self.grid_size = (self.height, self.width)
  123. self.conv = nn.Conv2d(
  124. in_channels,
  125. out_channels,
  126. kernel_size=patch_size,
  127. stride=stride,
  128. padding=padding,
  129. bias=True,
  130. **dd,
  131. )
  132. def forward(self, x):
  133. x = self.conv(x)
  134. return x
  135. class PoolingVisionTransformer(nn.Module):
  136. """ Pooling-based Vision Transformer
  137. A PyTorch implement of 'Rethinking Spatial Dimensions of Vision Transformers'
  138. - https://arxiv.org/abs/2103.16302
  139. """
  140. def __init__(
  141. self,
  142. img_size: int = 224,
  143. patch_size: int = 16,
  144. stride: int = 8,
  145. stem_type: str = 'overlap',
  146. base_dims: Sequence[int] = (48, 48, 48),
  147. depth: Sequence[int] = (2, 6, 4),
  148. heads: Sequence[int] = (2, 4, 8),
  149. mlp_ratio: float = 4,
  150. num_classes: int = 1000,
  151. in_chans: int = 3,
  152. global_pool: str = 'token',
  153. distilled: bool = False,
  154. drop_rate: float = 0.,
  155. pos_drop_drate: float = 0.,
  156. proj_drop_rate: float = 0.,
  157. attn_drop_rate: float = 0.,
  158. drop_path_rate: float = 0.,
  159. device=None,
  160. dtype=None,
  161. ):
  162. super().__init__()
  163. dd = {'device': device, 'dtype': dtype}
  164. assert global_pool in ('token',)
  165. self.base_dims = base_dims
  166. self.heads = heads
  167. embed_dim = base_dims[0] * heads[0]
  168. self.num_classes = num_classes
  169. self.in_chans = in_chans
  170. self.global_pool = global_pool
  171. self.num_tokens = 2 if distilled else 1
  172. self.feature_info = []
  173. self.patch_embed = ConvEmbedding(in_chans, embed_dim, img_size, patch_size, stride, **dd)
  174. self.pos_embed = nn.Parameter(torch.randn(1, embed_dim, self.patch_embed.height, self.patch_embed.width, **dd))
  175. self.cls_token = nn.Parameter(torch.randn(1, self.num_tokens, embed_dim, **dd))
  176. self.pos_drop = nn.Dropout(p=pos_drop_drate)
  177. transformers = []
  178. # stochastic depth decay rule
  179. dpr = calculate_drop_path_rates(drop_path_rate, depth, stagewise=True)
  180. prev_dim = embed_dim
  181. for i in range(len(depth)):
  182. pool = None
  183. embed_dim = base_dims[i] * heads[i]
  184. if i > 0:
  185. pool = Pooling(
  186. prev_dim,
  187. embed_dim,
  188. stride=2,
  189. **dd,
  190. )
  191. transformers += [Transformer(
  192. base_dims[i],
  193. depth[i],
  194. heads[i],
  195. mlp_ratio,
  196. pool=pool,
  197. proj_drop=proj_drop_rate,
  198. attn_drop=attn_drop_rate,
  199. drop_path_prob=dpr[i],
  200. **dd,
  201. )]
  202. prev_dim = embed_dim
  203. self.feature_info += [dict(num_chs=prev_dim, reduction=(stride - 1) * 2**i, module=f'transformers.{i}')]
  204. self.transformers = SequentialTuple(*transformers)
  205. self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6, **dd)
  206. self.num_features = self.head_hidden_size = self.embed_dim = embed_dim
  207. # Classifier head
  208. self.head_drop = nn.Dropout(drop_rate)
  209. self.head = nn.Linear(self.embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()
  210. self.head_dist = None
  211. if distilled:
  212. self.head_dist = nn.Linear(self.embed_dim, self.num_classes, **dd) if num_classes > 0 else nn.Identity()
  213. self.distilled_training = False # must set this True to train w/ distillation token
  214. trunc_normal_(self.pos_embed, std=.02)
  215. trunc_normal_(self.cls_token, std=.02)
  216. self.apply(self._init_weights)
  217. def _init_weights(self, m):
  218. if isinstance(m, nn.LayerNorm):
  219. nn.init.constant_(m.bias, 0)
  220. nn.init.constant_(m.weight, 1.0)
  221. @torch.jit.ignore
  222. def no_weight_decay(self):
  223. return {'pos_embed', 'cls_token'}
  224. @torch.jit.ignore
  225. def set_distilled_training(self, enable=True):
  226. self.distilled_training = enable
  227. @torch.jit.ignore
  228. def set_grad_checkpointing(self, enable=True):
  229. assert not enable, 'gradient checkpointing not supported'
  230. def get_classifier(self) -> nn.Module:
  231. if self.head_dist is not None:
  232. return self.head, self.head_dist
  233. else:
  234. return self.head
  235. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  236. self.num_classes = num_classes
  237. if global_pool is not None:
  238. self.global_pool = global_pool
  239. device = self.head.weight.device if hasattr(self.head, 'weight') else None
  240. dtype = self.head.weight.dtype if hasattr(self.head, 'weight') else None
  241. self.head = nn.Linear(self.embed_dim, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity()
  242. if self.head_dist is not None:
  243. self.head_dist = nn.Linear(self.embed_dim, self.num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity()
  244. def forward_intermediates(
  245. self,
  246. x: torch.Tensor,
  247. indices: Optional[Union[int, List[int]]] = None,
  248. norm: bool = False,
  249. stop_early: bool = False,
  250. output_fmt: str = 'NCHW',
  251. intermediates_only: bool = False,
  252. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  253. """ Forward features that returns intermediates.
  254. Args:
  255. x: Input image tensor
  256. indices: Take last n blocks if int, all if None, select matching indices if sequence
  257. norm: Apply norm layer to compatible intermediates
  258. stop_early: Stop iterating over blocks when last desired intermediate hit
  259. output_fmt: Shape of intermediate feature outputs
  260. intermediates_only: Only return intermediate features
  261. Returns:
  262. """
  263. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  264. intermediates = []
  265. take_indices, max_index = feature_take_indices(len(self.transformers), indices)
  266. # forward pass
  267. x = self.patch_embed(x)
  268. x = self.pos_drop(x + self.pos_embed)
  269. cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
  270. last_idx = len(self.transformers) - 1
  271. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  272. stages = self.transformers
  273. else:
  274. stages = self.transformers[:max_index + 1]
  275. for feat_idx, stage in enumerate(stages):
  276. x, cls_tokens = stage((x, cls_tokens))
  277. if feat_idx in take_indices:
  278. intermediates.append(x)
  279. if intermediates_only:
  280. return intermediates
  281. if feat_idx == last_idx:
  282. cls_tokens = self.norm(cls_tokens)
  283. return cls_tokens, intermediates
  284. def prune_intermediate_layers(
  285. self,
  286. indices: Union[int, List[int]] = 1,
  287. prune_norm: bool = False,
  288. prune_head: bool = True,
  289. ):
  290. """ Prune layers not required for specified intermediates.
  291. """
  292. take_indices, max_index = feature_take_indices(len(self.transformers), indices)
  293. self.transformers = self.transformers[:max_index + 1] # truncate blocks w/ stem as idx 0
  294. if prune_norm:
  295. self.norm = nn.Identity()
  296. if prune_head:
  297. self.reset_classifier(0, '')
  298. return take_indices
  299. def forward_features(self, x):
  300. x = self.patch_embed(x)
  301. x = self.pos_drop(x + self.pos_embed)
  302. cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
  303. x, cls_tokens = self.transformers((x, cls_tokens))
  304. cls_tokens = self.norm(cls_tokens)
  305. return cls_tokens
  306. def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
  307. if self.head_dist is not None:
  308. assert self.global_pool == 'token'
  309. x, x_dist = x[:, 0], x[:, 1]
  310. x = self.head_drop(x)
  311. x_dist = self.head_drop(x_dist)
  312. if not pre_logits:
  313. x = self.head(x)
  314. x_dist = self.head_dist(x_dist)
  315. if self.distilled_training and self.training and not torch.jit.is_scripting():
  316. # only return separate classification predictions when training in distilled mode
  317. return x, x_dist
  318. else:
  319. # during standard train / finetune, inference average the classifier predictions
  320. return (x + x_dist) / 2
  321. else:
  322. if self.global_pool == 'token':
  323. x = x[:, 0]
  324. x = self.head_drop(x)
  325. if not pre_logits:
  326. x = self.head(x)
  327. return x
  328. def forward(self, x):
  329. x = self.forward_features(x)
  330. x = self.forward_head(x)
  331. return x
  332. def checkpoint_filter_fn(state_dict, model):
  333. """ preprocess checkpoints """
  334. out_dict = {}
  335. p_blocks = re.compile(r'pools\.(\d)\.')
  336. for k, v in state_dict.items():
  337. # FIXME need to update resize for PiT impl
  338. # if k == 'pos_embed' and v.shape != model.pos_embed.shape:
  339. # # To resize pos embedding when using model at different size from pretrained weights
  340. # v = resize_pos_embed(v, model.pos_embed)
  341. k = p_blocks.sub(lambda exp: f'transformers.{int(exp.group(1)) + 1}.pool.', k)
  342. out_dict[k] = v
  343. return out_dict
  344. def _create_pit(variant, pretrained=False, **kwargs):
  345. default_out_indices = tuple(range(3))
  346. out_indices = kwargs.pop('out_indices', default_out_indices)
  347. model = build_model_with_cfg(
  348. PoolingVisionTransformer,
  349. variant,
  350. pretrained,
  351. pretrained_filter_fn=checkpoint_filter_fn,
  352. feature_cfg=dict(feature_cls='hook', out_indices=out_indices),
  353. **kwargs,
  354. )
  355. return model
  356. def _cfg(url='', **kwargs):
  357. return {
  358. 'url': url,
  359. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  360. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  361. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  362. 'first_conv': 'patch_embed.conv', 'classifier': 'head',
  363. 'license': 'apache-2.0',
  364. **kwargs
  365. }
  366. default_cfgs = generate_default_cfgs({
  367. # deit models (FB weights)
  368. 'pit_ti_224.in1k': _cfg(hf_hub_id='timm/'),
  369. 'pit_xs_224.in1k': _cfg(hf_hub_id='timm/'),
  370. 'pit_s_224.in1k': _cfg(hf_hub_id='timm/'),
  371. 'pit_b_224.in1k': _cfg(hf_hub_id='timm/'),
  372. 'pit_ti_distilled_224.in1k': _cfg(
  373. hf_hub_id='timm/',
  374. classifier=('head', 'head_dist')),
  375. 'pit_xs_distilled_224.in1k': _cfg(
  376. hf_hub_id='timm/',
  377. classifier=('head', 'head_dist')),
  378. 'pit_s_distilled_224.in1k': _cfg(
  379. hf_hub_id='timm/',
  380. classifier=('head', 'head_dist')),
  381. 'pit_b_distilled_224.in1k': _cfg(
  382. hf_hub_id='timm/',
  383. classifier=('head', 'head_dist')),
  384. })
  385. @register_model
  386. def pit_b_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
  387. model_args = dict(
  388. patch_size=14,
  389. stride=7,
  390. base_dims=[64, 64, 64],
  391. depth=[3, 6, 4],
  392. heads=[4, 8, 16],
  393. mlp_ratio=4,
  394. )
  395. return _create_pit('pit_b_224', pretrained, **dict(model_args, **kwargs))
  396. @register_model
  397. def pit_s_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
  398. model_args = dict(
  399. patch_size=16,
  400. stride=8,
  401. base_dims=[48, 48, 48],
  402. depth=[2, 6, 4],
  403. heads=[3, 6, 12],
  404. mlp_ratio=4,
  405. )
  406. return _create_pit('pit_s_224', pretrained, **dict(model_args, **kwargs))
  407. @register_model
  408. def pit_xs_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
  409. model_args = dict(
  410. patch_size=16,
  411. stride=8,
  412. base_dims=[48, 48, 48],
  413. depth=[2, 6, 4],
  414. heads=[2, 4, 8],
  415. mlp_ratio=4,
  416. )
  417. return _create_pit('pit_xs_224', pretrained, **dict(model_args, **kwargs))
  418. @register_model
  419. def pit_ti_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
  420. model_args = dict(
  421. patch_size=16,
  422. stride=8,
  423. base_dims=[32, 32, 32],
  424. depth=[2, 6, 4],
  425. heads=[2, 4, 8],
  426. mlp_ratio=4,
  427. )
  428. return _create_pit('pit_ti_224', pretrained, **dict(model_args, **kwargs))
  429. @register_model
  430. def pit_b_distilled_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
  431. model_args = dict(
  432. patch_size=14,
  433. stride=7,
  434. base_dims=[64, 64, 64],
  435. depth=[3, 6, 4],
  436. heads=[4, 8, 16],
  437. mlp_ratio=4,
  438. distilled=True,
  439. )
  440. return _create_pit('pit_b_distilled_224', pretrained, **dict(model_args, **kwargs))
  441. @register_model
  442. def pit_s_distilled_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
  443. model_args = dict(
  444. patch_size=16,
  445. stride=8,
  446. base_dims=[48, 48, 48],
  447. depth=[2, 6, 4],
  448. heads=[3, 6, 12],
  449. mlp_ratio=4,
  450. distilled=True,
  451. )
  452. return _create_pit('pit_s_distilled_224', pretrained, **dict(model_args, **kwargs))
  453. @register_model
  454. def pit_xs_distilled_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
  455. model_args = dict(
  456. patch_size=16,
  457. stride=8,
  458. base_dims=[48, 48, 48],
  459. depth=[2, 6, 4],
  460. heads=[2, 4, 8],
  461. mlp_ratio=4,
  462. distilled=True,
  463. )
  464. return _create_pit('pit_xs_distilled_224', pretrained, **dict(model_args, **kwargs))
  465. @register_model
  466. def pit_ti_distilled_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
  467. model_args = dict(
  468. patch_size=16,
  469. stride=8,
  470. base_dims=[32, 32, 32],
  471. depth=[2, 6, 4],
  472. heads=[2, 4, 8],
  473. mlp_ratio=4,
  474. distilled=True,
  475. )
  476. return _create_pit('pit_ti_distilled_224', pretrained, **dict(model_args, **kwargs))