sequencer.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. """ Sequencer
  2. Paper: `Sequencer: Deep LSTM for Image Classification` - https://arxiv.org/pdf/2205.01972.pdf
  3. """
  4. # Copyright (c) 2022. Yuki Tatsunami
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. import math
  7. from functools import partial
  8. from itertools import accumulate
  9. from typing import List, Optional, Tuple, Type, Union
  10. import torch
  11. import torch.nn as nn
  12. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
  13. from timm.layers import lecun_normal_, DropPath, Mlp, PatchEmbed, ClassifierHead
  14. from ._builder import build_model_with_cfg
  15. from ._manipulate import named_apply
  16. from ._registry import register_model, generate_default_cfgs
# Public API of this module. Per the upstream note, the model registry also adds each
# @register_model entrypoint function to this list.
__all__ = ['Sequencer2d']  # model_registry will add each entrypoint fn to this
  18. def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False):
  19. if isinstance(module, nn.Linear):
  20. if name.startswith('head'):
  21. nn.init.zeros_(module.weight)
  22. nn.init.constant_(module.bias, head_bias)
  23. else:
  24. if flax:
  25. # Flax defaults
  26. lecun_normal_(module.weight)
  27. if module.bias is not None:
  28. nn.init.zeros_(module.bias)
  29. else:
  30. nn.init.xavier_uniform_(module.weight)
  31. if module.bias is not None:
  32. if 'mlp' in name:
  33. nn.init.normal_(module.bias, std=1e-6)
  34. else:
  35. nn.init.zeros_(module.bias)
  36. elif isinstance(module, nn.Conv2d):
  37. lecun_normal_(module.weight)
  38. if module.bias is not None:
  39. nn.init.zeros_(module.bias)
  40. elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
  41. nn.init.ones_(module.weight)
  42. nn.init.zeros_(module.bias)
  43. elif isinstance(module, (nn.RNN, nn.GRU, nn.LSTM)):
  44. stdv = 1.0 / math.sqrt(module.hidden_size)
  45. for weight in module.parameters():
  46. nn.init.uniform_(weight, -stdv, stdv)
  47. elif hasattr(module, 'init_weights'):
  48. module.init_weights()
  49. class RNNIdentity(nn.Module):
  50. def __init__(self, *args, **kwargs):
  51. super().__init__()
  52. def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None]:
  53. return x, None
  54. class RNN2dBase(nn.Module):
  55. def __init__(
  56. self,
  57. input_size: int,
  58. hidden_size: int,
  59. num_layers: int = 1,
  60. bias: bool = True,
  61. bidirectional: bool = True,
  62. union: str = "cat",
  63. with_fc: bool = True,
  64. device=None,
  65. dtype=None,
  66. ):
  67. dd = {'device': device, 'dtype': dtype}
  68. super().__init__()
  69. self.input_size = input_size
  70. self.hidden_size = hidden_size
  71. self.output_size = 2 * hidden_size if bidirectional else hidden_size
  72. self.union = union
  73. self.with_vertical = True
  74. self.with_horizontal = True
  75. self.with_fc = with_fc
  76. self.fc = None
  77. if with_fc:
  78. if union == "cat":
  79. self.fc = nn.Linear(2 * self.output_size, input_size, **dd)
  80. elif union == "add":
  81. self.fc = nn.Linear(self.output_size, input_size, **dd)
  82. elif union == "vertical":
  83. self.fc = nn.Linear(self.output_size, input_size, **dd)
  84. self.with_horizontal = False
  85. elif union == "horizontal":
  86. self.fc = nn.Linear(self.output_size, input_size, **dd)
  87. self.with_vertical = False
  88. else:
  89. raise ValueError("Unrecognized union: " + union)
  90. elif union == "cat":
  91. pass
  92. if 2 * self.output_size != input_size:
  93. raise ValueError(f"The output channel {2 * self.output_size} is different from the input channel {input_size}.")
  94. elif union == "add":
  95. pass
  96. if self.output_size != input_size:
  97. raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.")
  98. elif union == "vertical":
  99. if self.output_size != input_size:
  100. raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.")
  101. self.with_horizontal = False
  102. elif union == "horizontal":
  103. if self.output_size != input_size:
  104. raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.")
  105. self.with_vertical = False
  106. else:
  107. raise ValueError("Unrecognized union: " + union)
  108. self.rnn_v = RNNIdentity()
  109. self.rnn_h = RNNIdentity()
  110. def forward(self, x):
  111. B, H, W, C = x.shape
  112. if self.with_vertical:
  113. v = x.permute(0, 2, 1, 3)
  114. v = v.reshape(-1, H, C)
  115. v, _ = self.rnn_v(v)
  116. v = v.reshape(B, W, H, -1)
  117. v = v.permute(0, 2, 1, 3)
  118. else:
  119. v = None
  120. if self.with_horizontal:
  121. h = x.reshape(-1, W, C)
  122. h, _ = self.rnn_h(h)
  123. h = h.reshape(B, H, W, -1)
  124. else:
  125. h = None
  126. if v is not None and h is not None:
  127. if self.union == "cat":
  128. x = torch.cat([v, h], dim=-1)
  129. else:
  130. x = v + h
  131. elif v is not None:
  132. x = v
  133. elif h is not None:
  134. x = h
  135. if self.fc is not None:
  136. x = self.fc(x)
  137. return x
  138. class LSTM2d(RNN2dBase):
  139. def __init__(
  140. self,
  141. input_size: int,
  142. hidden_size: int,
  143. num_layers: int = 1,
  144. bias: bool = True,
  145. bidirectional: bool = True,
  146. union: str = "cat",
  147. with_fc: bool = True,
  148. device=None,
  149. dtype=None,
  150. ):
  151. dd = {'device': device, 'dtype': dtype}
  152. super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc, device, dtype)
  153. if self.with_vertical:
  154. self.rnn_v = nn.LSTM(
  155. input_size,
  156. hidden_size,
  157. num_layers,
  158. batch_first=True,
  159. bias=bias,
  160. bidirectional=bidirectional,
  161. **dd,
  162. )
  163. if self.with_horizontal:
  164. self.rnn_h = nn.LSTM(
  165. input_size,
  166. hidden_size,
  167. num_layers,
  168. batch_first=True,
  169. bias=bias,
  170. bidirectional=bidirectional,
  171. **dd,
  172. )
  173. class Sequencer2dBlock(nn.Module):
  174. def __init__(
  175. self,
  176. dim: int,
  177. hidden_size: int,
  178. mlp_ratio: float = 3.0,
  179. rnn_layer: Type[nn.Module] = LSTM2d,
  180. mlp_layer: Type[nn.Module] = Mlp,
  181. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  182. act_layer: Type[nn.Module] = nn.GELU,
  183. num_layers: int = 1,
  184. bidirectional: bool = True,
  185. union: str = "cat",
  186. with_fc: bool = True,
  187. drop: float = 0.,
  188. drop_path: float = 0.,
  189. device=None,
  190. dtype=None,
  191. ):
  192. dd = {'device': device, 'dtype': dtype}
  193. super().__init__()
  194. channels_dim = int(mlp_ratio * dim)
  195. self.norm1 = norm_layer(dim, **dd)
  196. self.rnn_tokens = rnn_layer(
  197. dim,
  198. hidden_size,
  199. num_layers=num_layers,
  200. bidirectional=bidirectional,
  201. union=union,
  202. with_fc=with_fc,
  203. **dd,
  204. )
  205. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  206. self.norm2 = norm_layer(dim, **dd)
  207. self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop, **dd)
  208. def forward(self, x):
  209. x = x + self.drop_path(self.rnn_tokens(self.norm1(x)))
  210. x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
  211. return x
  212. class Shuffle(nn.Module):
  213. def __init__(self):
  214. super().__init__()
  215. def forward(self, x):
  216. if self.training:
  217. B, H, W, C = x.shape
  218. r = torch.randperm(H * W)
  219. x = x.reshape(B, -1, C)
  220. x = x[:, r, :].reshape(B, H, W, -1)
  221. return x
  222. class Downsample2d(nn.Module):
  223. def __init__(
  224. self,
  225. input_dim: int,
  226. output_dim: int,
  227. patch_size: int,
  228. device=None,
  229. dtype=None,
  230. ):
  231. dd = {'device': device, 'dtype': dtype}
  232. super().__init__()
  233. self.down = nn.Conv2d(input_dim, output_dim, kernel_size=patch_size, stride=patch_size, **dd)
  234. def forward(self, x):
  235. x = x.permute(0, 3, 1, 2)
  236. x = self.down(x)
  237. x = x.permute(0, 2, 3, 1)
  238. return x
  239. class Sequencer2dStage(nn.Module):
  240. def __init__(
  241. self,
  242. dim: int,
  243. dim_out: int,
  244. depth: int,
  245. patch_size: int,
  246. hidden_size: int,
  247. mlp_ratio: float,
  248. downsample: bool = False,
  249. block_layer: Type[nn.Module] = Sequencer2dBlock,
  250. rnn_layer: Type[nn.Module] = LSTM2d,
  251. mlp_layer: Type[nn.Module] = Mlp,
  252. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  253. act_layer: Type[nn.Module] = nn.GELU,
  254. num_layers: int = 1,
  255. bidirectional: bool = True,
  256. union: str = "cat",
  257. with_fc: bool = True,
  258. drop: float = 0.,
  259. drop_path: Union[float, List[float]] = 0.,
  260. device=None,
  261. dtype=None,
  262. ):
  263. super().__init__()
  264. dd = {'device': device, 'dtype': dtype}
  265. if downsample:
  266. self.downsample = Downsample2d(dim, dim_out, patch_size, **dd)
  267. else:
  268. assert dim == dim_out
  269. self.downsample = nn.Identity()
  270. blocks = []
  271. for block_idx in range(depth):
  272. blocks.append(block_layer(
  273. dim_out,
  274. hidden_size,
  275. mlp_ratio=mlp_ratio,
  276. rnn_layer=rnn_layer,
  277. mlp_layer=mlp_layer,
  278. norm_layer=norm_layer,
  279. act_layer=act_layer,
  280. num_layers=num_layers,
  281. bidirectional=bidirectional,
  282. union=union,
  283. with_fc=with_fc,
  284. drop=drop,
  285. drop_path=drop_path[block_idx] if isinstance(drop_path, (list, tuple)) else drop_path,
  286. **dd,
  287. ))
  288. self.blocks = nn.Sequential(*blocks)
  289. def forward(self, x):
  290. x = self.downsample(x)
  291. x = self.blocks(x)
  292. return x
  293. class Sequencer2d(nn.Module):
  294. def __init__(
  295. self,
  296. num_classes: int = 1000,
  297. img_size: int = 224,
  298. in_chans: int = 3,
  299. global_pool: str = 'avg',
  300. layers: Tuple[int, ...] = (4, 3, 8, 3),
  301. patch_sizes: Tuple[int, ...] = (7, 2, 2, 1),
  302. embed_dims: Tuple[int, ...] = (192, 384, 384, 384),
  303. hidden_sizes: Tuple[int, ...] = (48, 96, 96, 96),
  304. mlp_ratios: Tuple[float, ...] = (3.0, 3.0, 3.0, 3.0),
  305. block_layer: Type[nn.Module] = Sequencer2dBlock,
  306. rnn_layer: Type[nn.Module] = LSTM2d,
  307. mlp_layer: Type[nn.Module] = Mlp,
  308. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  309. act_layer: Type[nn.Module] = nn.GELU,
  310. num_rnn_layers: int = 1,
  311. bidirectional: bool = True,
  312. union: str = "cat",
  313. with_fc: bool = True,
  314. drop_rate: float = 0.,
  315. drop_path_rate: float = 0.,
  316. nlhb: bool = False,
  317. stem_norm: bool = False,
  318. device=None,
  319. dtype=None,
  320. ):
  321. super().__init__()
  322. dd = {'device': device, 'dtype': dtype}
  323. assert global_pool in ('', 'avg')
  324. self.num_classes = num_classes
  325. self.in_chans = in_chans
  326. self.global_pool = global_pool
  327. self.num_features = self.head_hidden_size = embed_dims[-1] # for consistency with other models
  328. self.feature_dim = -1 # channel dim index for feature outputs (rank 4, NHWC)
  329. self.output_fmt = 'NHWC'
  330. self.feature_info = []
  331. self.stem = PatchEmbed(
  332. img_size=None,
  333. patch_size=patch_sizes[0],
  334. in_chans=in_chans,
  335. embed_dim=embed_dims[0],
  336. norm_layer=norm_layer if stem_norm else None,
  337. flatten=False,
  338. output_fmt='NHWC',
  339. **dd,
  340. )
  341. assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios)
  342. reductions = list(accumulate(patch_sizes, lambda x, y: x * y))
  343. stages = []
  344. prev_dim = embed_dims[0]
  345. for i, _ in enumerate(embed_dims):
  346. stages += [Sequencer2dStage(
  347. prev_dim,
  348. embed_dims[i],
  349. depth=layers[i],
  350. downsample=i > 0,
  351. patch_size=patch_sizes[i],
  352. hidden_size=hidden_sizes[i],
  353. mlp_ratio=mlp_ratios[i],
  354. block_layer=block_layer,
  355. rnn_layer=rnn_layer,
  356. mlp_layer=mlp_layer,
  357. norm_layer=norm_layer,
  358. act_layer=act_layer,
  359. num_layers=num_rnn_layers,
  360. bidirectional=bidirectional,
  361. union=union,
  362. with_fc=with_fc,
  363. drop=drop_rate,
  364. drop_path=drop_path_rate,
  365. **dd,
  366. )]
  367. prev_dim = embed_dims[i]
  368. self.feature_info += [dict(num_chs=prev_dim, reduction=reductions[i], module=f'stages.{i}')]
  369. self.stages = nn.Sequential(*stages)
  370. self.norm = norm_layer(embed_dims[-1], **dd)
  371. self.head = ClassifierHead(
  372. self.num_features,
  373. num_classes,
  374. pool_type=global_pool,
  375. drop_rate=drop_rate,
  376. input_fmt=self.output_fmt,
  377. **dd,
  378. )
  379. self.init_weights(nlhb=nlhb)
  380. def init_weights(self, nlhb=False):
  381. head_bias = -math.log(self.num_classes) if nlhb else 0.
  382. named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first
  383. @torch.jit.ignore
  384. def group_matcher(self, coarse=False):
  385. return dict(
  386. stem=r'^stem',
  387. blocks=[
  388. (r'^stages\.(\d+)', None),
  389. (r'^norm', (99999,))
  390. ] if coarse else [
  391. (r'^stages\.(\d+)\.blocks\.(\d+)', None),
  392. (r'^stages\.(\d+)\.downsample', (0,)),
  393. (r'^norm', (99999,))
  394. ]
  395. )
  396. @torch.jit.ignore
  397. def set_grad_checkpointing(self, enable=True):
  398. assert not enable, 'gradient checkpointing not supported'
  399. @torch.jit.ignore
  400. def get_classifier(self) -> nn.Module:
  401. return self.head
  402. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  403. self.num_classes = num_classes
  404. self.head.reset(num_classes, pool_type=global_pool)
  405. def forward_features(self, x):
  406. x = self.stem(x)
  407. x = self.stages(x)
  408. x = self.norm(x)
  409. return x
  410. def forward_head(self, x, pre_logits: bool = False):
  411. return self.head(x, pre_logits=True) if pre_logits else self.head(x)
  412. def forward(self, x):
  413. x = self.forward_features(x)
  414. x = self.forward_head(x)
  415. return x
  416. def checkpoint_filter_fn(state_dict, model):
  417. """ Remap original checkpoints -> timm """
  418. if 'stages.0.blocks.0.norm1.weight' in state_dict:
  419. return state_dict # already translated checkpoint
  420. if 'model' in state_dict:
  421. state_dict = state_dict['model']
  422. import re
  423. out_dict = {}
  424. for k, v in state_dict.items():
  425. k = re.sub(r'blocks.([0-9]+).([0-9]+).down', lambda x: f'stages.{int(x.group(1)) + 1}.downsample.down', k)
  426. k = re.sub(r'blocks.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
  427. k = k.replace('head.', 'head.fc.')
  428. out_dict[k] = v
  429. return out_dict
  430. def _create_sequencer2d(variant, pretrained=False, **kwargs):
  431. default_out_indices = tuple(range(3))
  432. out_indices = kwargs.pop('out_indices', default_out_indices)
  433. model = build_model_with_cfg(
  434. Sequencer2d,
  435. variant,
  436. pretrained,
  437. pretrained_filter_fn=checkpoint_filter_fn,
  438. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  439. **kwargs,
  440. )
  441. return model
  442. def _cfg(url='', **kwargs):
  443. return {
  444. 'url': url,
  445. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  446. 'crop_pct': DEFAULT_CROP_PCT, 'interpolation': 'bicubic', 'fixed_input_size': True,
  447. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  448. 'first_conv': 'stem.proj', 'classifier': 'head.fc',
  449. 'license': 'apache-2.0',
  450. **kwargs
  451. }
# Pretrained configs keyed '<model_name>.<tag>'. All three variants use the shared
# _cfg defaults, with weights referenced via hf_hub_id 'timm/'.
default_cfgs = generate_default_cfgs({
    'sequencer2d_s.in1k': _cfg(hf_hub_id='timm/'),
    'sequencer2d_m.in1k': _cfg(hf_hub_id='timm/'),
    'sequencer2d_l.in1k': _cfg(hf_hub_id='timm/'),
})
  457. @register_model
  458. def sequencer2d_s(pretrained=False, **kwargs) -> Sequencer2d:
  459. model_args = dict(
  460. layers=[4, 3, 8, 3],
  461. patch_sizes=[7, 2, 1, 1],
  462. embed_dims=[192, 384, 384, 384],
  463. hidden_sizes=[48, 96, 96, 96],
  464. mlp_ratios=[3.0, 3.0, 3.0, 3.0],
  465. rnn_layer=LSTM2d,
  466. bidirectional=True,
  467. union="cat",
  468. with_fc=True,
  469. )
  470. model = _create_sequencer2d('sequencer2d_s', pretrained=pretrained, **dict(model_args, **kwargs))
  471. return model
  472. @register_model
  473. def sequencer2d_m(pretrained=False, **kwargs) -> Sequencer2d:
  474. model_args = dict(
  475. layers=[4, 3, 14, 3],
  476. patch_sizes=[7, 2, 1, 1],
  477. embed_dims=[192, 384, 384, 384],
  478. hidden_sizes=[48, 96, 96, 96],
  479. mlp_ratios=[3.0, 3.0, 3.0, 3.0],
  480. rnn_layer=LSTM2d,
  481. bidirectional=True,
  482. union="cat",
  483. with_fc=True,
  484. **kwargs)
  485. model = _create_sequencer2d('sequencer2d_m', pretrained=pretrained, **dict(model_args, **kwargs))
  486. return model
  487. @register_model
  488. def sequencer2d_l(pretrained=False, **kwargs) -> Sequencer2d:
  489. model_args = dict(
  490. layers=[8, 8, 16, 4],
  491. patch_sizes=[7, 2, 1, 1],
  492. embed_dims=[192, 384, 384, 384],
  493. hidden_sizes=[48, 96, 96, 96],
  494. mlp_ratios=[3.0, 3.0, 3.0, 3.0],
  495. rnn_layer=LSTM2d,
  496. bidirectional=True,
  497. union="cat",
  498. with_fc=True,
  499. **kwargs)
  500. model = _create_sequencer2d('sequencer2d_l', pretrained=pretrained, **dict(model_args, **kwargs))
  501. return model