starnet.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
"""
Implementation of the Proof-of-Concept Network: StarNet.

StarNet is deliberately kept as simple as possible, to show the key contribution of
element-wise multiplication. In particular, it uses:
- NO layer-scale in the network design, and
- NO EMA during training,
both of which would further improve performance if added.

Created by: Xu Ma (Email: ma.xu1@northeastern.edu)
Modified Date: Mar/29/2024
"""
  10. from typing import Any, Dict, List, Optional, Set, Tuple, Union, Type
  11. import torch
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  15. from timm.layers import DropPath, SelectAdaptivePool2d, Linear, LayerType, trunc_normal_, calculate_drop_path_rates
  16. from ._builder import build_model_with_cfg
  17. from ._features import feature_take_indices
  18. from ._manipulate import checkpoint_seq
  19. from ._registry import register_model, generate_default_cfgs
# Names exported by `from <this module> import *`.
__all__ = ['StarNet']
  21. class ConvBN(nn.Sequential):
  22. def __init__(
  23. self,
  24. in_channels: int,
  25. out_channels: int,
  26. kernel_size: int = 1,
  27. stride: int = 1,
  28. padding: int = 0,
  29. with_bn: bool = True,
  30. device=None,
  31. dtype=None,
  32. **kwargs,
  33. ):
  34. dd = {'device': device, 'dtype': dtype}
  35. super().__init__()
  36. self.add_module('conv', nn.Conv2d(
  37. in_channels, out_channels, kernel_size, stride=stride, padding=padding, **dd, **kwargs))
  38. if with_bn:
  39. self.add_module('bn', nn.BatchNorm2d(out_channels, **dd))
  40. nn.init.constant_(self.bn.weight, 1)
  41. nn.init.constant_(self.bn.bias, 0)
  42. class Block(nn.Module):
  43. def __init__(
  44. self,
  45. dim: int,
  46. mlp_ratio: int = 3,
  47. drop_path: float = 0.,
  48. act_layer: Type[nn.Module] = nn.ReLU6,
  49. device=None,
  50. dtype=None,
  51. ):
  52. dd = {'device': device, 'dtype': dtype}
  53. super().__init__()
  54. self.dwconv = ConvBN(dim, dim, 7, 1, 3, groups=dim, with_bn=True, **dd)
  55. self.f1 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False, **dd)
  56. self.f2 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False, **dd)
  57. self.g = ConvBN(mlp_ratio * dim, dim, 1, with_bn=True, **dd)
  58. self.dwconv2 = ConvBN(dim, dim, 7, 1, 3, groups=dim, with_bn=False, **dd)
  59. self.act = act_layer()
  60. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  61. def forward(self, x: torch.Tensor) -> torch.Tensor:
  62. residual = x
  63. x = self.dwconv(x)
  64. x1, x2 = self.f1(x), self.f2(x)
  65. x = self.act(x1) * x2
  66. x = self.dwconv2(self.g(x))
  67. x = residual + self.drop_path(x)
  68. return x
  69. class StarNet(nn.Module):
  70. def __init__(
  71. self,
  72. base_dim: int = 32,
  73. depths: List[int] = [3, 3, 12, 5],
  74. mlp_ratio: int = 4,
  75. drop_rate: float = 0.,
  76. drop_path_rate: float = 0.,
  77. act_layer: Type[nn.Module] = nn.ReLU6,
  78. num_classes: int = 1000,
  79. in_chans: int = 3,
  80. global_pool: str = 'avg',
  81. output_stride: int = 32,
  82. device=None,
  83. dtype=None,
  84. **kwargs,
  85. ):
  86. dd = {'device': device, 'dtype': dtype}
  87. super().__init__()
  88. assert output_stride == 32
  89. self.num_classes = num_classes
  90. self.in_chans = in_chans
  91. self.drop_rate = drop_rate
  92. self.grad_checkpointing = False
  93. self.feature_info = []
  94. stem_chs = 32
  95. # stem layer
  96. self.stem = nn.Sequential(
  97. ConvBN(in_chans, stem_chs, kernel_size=3, stride=2, padding=1, **dd),
  98. act_layer(),
  99. )
  100. prev_chs = stem_chs
  101. # build stages
  102. dpr = calculate_drop_path_rates(drop_path_rate, sum(depths)) # stochastic depth
  103. stages = []
  104. cur = 0
  105. for i_layer in range(len(depths)):
  106. embed_dim = base_dim * 2 ** i_layer
  107. down_sampler = ConvBN(prev_chs, embed_dim, 3, stride=2, padding=1, **dd)
  108. blocks = [Block(embed_dim, mlp_ratio, dpr[cur + i], act_layer, **dd) for i in range(depths[i_layer])]
  109. cur += depths[i_layer]
  110. prev_chs = embed_dim
  111. stages.append(nn.Sequential(down_sampler, *blocks))
  112. self.feature_info.append(dict(
  113. num_chs=prev_chs, reduction=2**(i_layer+2), module=f'stages.{i_layer}'))
  114. self.stages = nn.Sequential(*stages)
  115. # head
  116. self.num_features = self.head_hidden_size = prev_chs
  117. self.norm = nn.BatchNorm2d(self.num_features, **dd)
  118. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  119. self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
  120. self.head = Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()
  121. self.apply(self._init_weights)
  122. def _init_weights(self, m):
  123. if isinstance(m, (nn.Linear, nn.Conv2d)):
  124. trunc_normal_(m.weight, std=.02)
  125. if isinstance(m, nn.Linear) and m.bias is not None:
  126. nn.init.constant_(m.bias, 0)
  127. elif isinstance(m, nn.BatchNorm2d):
  128. nn.init.constant_(m.bias, 0)
  129. nn.init.constant_(m.weight, 1.0)
  130. @torch.jit.ignore
  131. def no_weight_decay(self) -> Set:
  132. return set()
  133. @torch.jit.ignore
  134. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  135. matcher = dict(
  136. stem=r'^stem\.\d+',
  137. blocks=[
  138. (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
  139. (r'norm', (99999,))
  140. ]
  141. )
  142. return matcher
  143. @torch.jit.ignore
  144. def set_grad_checkpointing(self, enable: bool = True):
  145. self.grad_checkpointing = enable
  146. @torch.jit.ignore
  147. def get_classifier(self) -> nn.Module:
  148. return self.head
  149. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  150. self.num_classes = num_classes
  151. if global_pool is not None:
  152. # NOTE: cannot meaningfully change pooling of efficient head after creation
  153. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  154. self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
  155. self.head = Linear(
  156. self.head_hidden_size, num_classes,
  157. device=self.head.weight.device if isinstance(self.head, nn.Linear) else None,
  158. dtype=self.head.weight.dtype if isinstance(self.head, nn.Linear) else None,
  159. ) if num_classes > 0 else nn.Identity()
  160. def forward_intermediates(
  161. self,
  162. x: torch.Tensor,
  163. indices: Optional[Union[int, List[int]]] = None,
  164. norm: bool = False,
  165. stop_early: bool = False,
  166. output_fmt: str = 'NCHW',
  167. intermediates_only: bool = False,
  168. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  169. """ Forward features that returns intermediates.
  170. Args:
  171. x: Input image tensor
  172. indices: Take last n blocks if int, all if None, select matching indices if sequence
  173. norm: Apply norm layer to compatible intermediates
  174. stop_early: Stop iterating over blocks when last desired intermediate hit
  175. output_fmt: Shape of intermediate feature outputs
  176. intermediates_only: Only return intermediate features
  177. Returns:
  178. """
  179. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  180. intermediates = []
  181. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  182. last_idx = len(self.stages) - 1
  183. # forward pass
  184. x = self.stem(x)
  185. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  186. stages = self.stages
  187. else:
  188. stages = self.stages[:max_index + 1]
  189. for feat_idx, stage in enumerate(stages):
  190. if self.grad_checkpointing and not torch.jit.is_scripting():
  191. x = checkpoint_seq(stage, x)
  192. else:
  193. x = stage(x)
  194. if feat_idx in take_indices:
  195. if norm and feat_idx == last_idx:
  196. x_inter = self.norm(x) # applying final norm last intermediate
  197. else:
  198. x_inter = x
  199. intermediates.append(x_inter)
  200. if intermediates_only:
  201. return intermediates
  202. if feat_idx == last_idx:
  203. x = self.norm(x)
  204. return x, intermediates
  205. def prune_intermediate_layers(
  206. self,
  207. indices: Union[int, List[int]] = 1,
  208. prune_norm: bool = False,
  209. prune_head: bool = True,
  210. ):
  211. """ Prune layers not required for specified intermediates.
  212. """
  213. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  214. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  215. if prune_norm:
  216. self.norm = nn.Identity()
  217. if prune_head:
  218. self.reset_classifier(0, '')
  219. return take_indices
  220. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  221. x = self.stem(x)
  222. if self.grad_checkpointing and not torch.jit.is_scripting():
  223. x = checkpoint_seq(self.stages, x)
  224. else:
  225. x = self.stages(x)
  226. x = self.norm(x)
  227. return x
  228. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  229. x = self.global_pool(x)
  230. x = self.flatten(x)
  231. if self.drop_rate > 0.:
  232. x = F.dropout(x, p=self.drop_rate, training=self.training)
  233. return x if pre_logits else self.head(x)
  234. def forward(self, x: torch.Tensor) -> torch.Tensor:
  235. x = self.forward_features(x)
  236. x = self.forward_head(x)
  237. return x
  238. def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
  239. return state_dict.get('state_dict', state_dict)
  240. def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
  241. return {
  242. 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  243. 'crop_pct': 0.875, 'interpolation': 'bicubic',
  244. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  245. 'first_conv': 'stem.0.conv', 'classifier': 'head',
  246. 'paper_ids': 'arXiv:2403.19967',
  247. 'paper_name': 'Rewrite the Stars',
  248. 'origin_url': 'https://github.com/ma-xu/Rewrite-the-Stars', 'license': 'apache-2.0',
  249. **kwargs
  250. }
  251. default_cfgs = generate_default_cfgs({
  252. 'starnet_s1.in1k': _cfg(
  253. hf_hub_id='timm/',
  254. #url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s1.pth.tar',
  255. ),
  256. 'starnet_s2.in1k': _cfg(
  257. hf_hub_id='timm/',
  258. #url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s2.pth.tar',
  259. ),
  260. 'starnet_s3.in1k': _cfg(
  261. hf_hub_id='timm/',
  262. #url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s3.pth.tar',
  263. ),
  264. 'starnet_s4.in1k': _cfg(
  265. hf_hub_id='timm/',
  266. #url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s4.pth.tar',
  267. ),
  268. 'starnet_s050.untrained': _cfg(),
  269. 'starnet_s100.untrained': _cfg(),
  270. 'starnet_s150.untrained': _cfg(),
  271. })
  272. def _create_starnet(variant: str, pretrained: bool = False, **kwargs: Any) -> StarNet:
  273. model = build_model_with_cfg(
  274. StarNet, variant, pretrained,
  275. pretrained_filter_fn=checkpoint_filter_fn,
  276. feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
  277. **kwargs,
  278. )
  279. return model
  280. @register_model
  281. def starnet_s1(pretrained: bool = False, **kwargs: Any) -> StarNet:
  282. model_args = dict(base_dim=24, depths=[2, 2, 8, 3])
  283. return _create_starnet('starnet_s1', pretrained=pretrained, **dict(model_args, **kwargs))
  284. @register_model
  285. def starnet_s2(pretrained: bool = False, **kwargs: Any) -> StarNet:
  286. model_args = dict(base_dim=32, depths=[1, 2, 6, 2])
  287. return _create_starnet('starnet_s2', pretrained=pretrained, **dict(model_args, **kwargs))
  288. @register_model
  289. def starnet_s3(pretrained: bool = False, **kwargs: Any) -> StarNet:
  290. model_args = dict(base_dim=32, depths=[2, 2, 8, 4])
  291. return _create_starnet('starnet_s3', pretrained=pretrained, **dict(model_args, **kwargs))
  292. @register_model
  293. def starnet_s4(pretrained: bool = False, **kwargs: Any) -> StarNet:
  294. model_args = dict(base_dim=32, depths=[3, 3, 12, 5])
  295. return _create_starnet('starnet_s4', pretrained=pretrained, **dict(model_args, **kwargs))
  296. # very small networks #
  297. @register_model
  298. def starnet_s050(pretrained: bool = False, **kwargs: Any) -> StarNet:
  299. model_args = dict(base_dim=16, depths=[1, 1, 3, 1], mlp_ratio=3)
  300. return _create_starnet('starnet_s050', pretrained=pretrained, **dict(model_args, **kwargs))
  301. @register_model
  302. def starnet_s100(pretrained: bool = False, **kwargs: Any) -> StarNet:
  303. model_args = dict(base_dim=20, depths=[1, 2, 4, 1], mlp_ratio=4)
  304. return _create_starnet('starnet_s100', pretrained=pretrained, **dict(model_args, **kwargs))
  305. @register_model
  306. def starnet_s150(pretrained: bool = False, **kwargs: Any) -> StarNet:
  307. model_args = dict(base_dim=24, depths=[1, 2, 4, 2], mlp_ratio=3)
  308. return _create_starnet('starnet_s150', pretrained=pretrained, **dict(model_args, **kwargs))