byoanet.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. """ Bring-Your-Own-Attention Network
  2. A flexible network w/ dataclass based config for stacking NN blocks including
  3. self-attention (or similar) layers.
  4. Currently used to implement experimental variants of:
  5. * Bottleneck Transformers
  6. * Lambda ResNets
  7. * HaloNets
  8. Consider all of the model definitions here as experimental WIP and likely to change.
  9. Hacked together by / copyright Ross Wightman, 2021.
  10. """
  11. from typing import Any, Dict, Optional
  12. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  13. from ._builder import build_model_with_cfg
  14. from ._registry import register_model, generate_default_cfgs
  15. from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks
  16. __all__ = []
  17. model_cfgs = dict(
  18. botnet26t=ByoModelCfg(
  19. blocks=(
  20. ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
  21. ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
  22. interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
  23. ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
  24. ),
  25. stem_chs=64,
  26. stem_type='tiered',
  27. stem_pool='maxpool',
  28. fixed_input_size=True,
  29. self_attn_layer='bottleneck',
  30. self_attn_kwargs=dict()
  31. ),
  32. sebotnet33ts=ByoModelCfg(
  33. blocks=(
  34. ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
  35. interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=512, s=2, gs=0, br=0.25),
  36. interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=1024, s=2, gs=0, br=0.25),
  37. ByoBlockCfg('self_attn', d=2, c=1536, s=2, gs=0, br=0.333),
  38. ),
  39. stem_chs=64,
  40. stem_type='tiered',
  41. stem_pool='',
  42. act_layer='silu',
  43. num_features=1280,
  44. attn_layer='se',
  45. self_attn_layer='bottleneck',
  46. self_attn_kwargs=dict()
  47. ),
  48. botnet50ts=ByoModelCfg(
  49. blocks=(
  50. ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
  51. interleave_blocks(types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25),
  52. interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
  53. interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
  54. ),
  55. stem_chs=64,
  56. stem_type='tiered',
  57. stem_pool='maxpool',
  58. act_layer='silu',
  59. fixed_input_size=True,
  60. self_attn_layer='bottleneck',
  61. self_attn_kwargs=dict()
  62. ),
  63. eca_botnext26ts=ByoModelCfg(
  64. blocks=(
  65. ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
  66. ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
  67. interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25),
  68. ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
  69. ),
  70. stem_chs=64,
  71. stem_type='tiered',
  72. stem_pool='maxpool',
  73. fixed_input_size=True,
  74. act_layer='silu',
  75. attn_layer='eca',
  76. self_attn_layer='bottleneck',
  77. self_attn_kwargs=dict(dim_head=16)
  78. ),
  79. halonet_h1=ByoModelCfg(
  80. blocks=(
  81. ByoBlockCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0),
  82. ByoBlockCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0),
  83. ByoBlockCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
  84. ByoBlockCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
  85. ),
  86. stem_chs=64,
  87. stem_type='7x7',
  88. stem_pool='maxpool',
  89. self_attn_layer='halo',
  90. self_attn_kwargs=dict(block_size=8, halo_size=3),
  91. ),
  92. halonet26t=ByoModelCfg(
  93. blocks=(
  94. ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
  95. ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
  96. interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
  97. ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
  98. ),
  99. stem_chs=64,
  100. stem_type='tiered',
  101. stem_pool='maxpool',
  102. self_attn_layer='halo',
  103. self_attn_kwargs=dict(block_size=8, halo_size=2)
  104. ),
  105. sehalonet33ts=ByoModelCfg(
  106. blocks=(
  107. ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
  108. interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=512, s=2, gs=0, br=0.25),
  109. interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=1024, s=2, gs=0, br=0.25),
  110. ByoBlockCfg('self_attn', d=2, c=1536, s=2, gs=0, br=0.333),
  111. ),
  112. stem_chs=64,
  113. stem_type='tiered',
  114. stem_pool='',
  115. act_layer='silu',
  116. num_features=1280,
  117. attn_layer='se',
  118. self_attn_layer='halo',
  119. self_attn_kwargs=dict(block_size=8, halo_size=3)
  120. ),
  121. halonet50ts=ByoModelCfg(
  122. blocks=(
  123. ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
  124. interleave_blocks(
  125. types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25,
  126. self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=3, num_heads=4)),
  127. interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
  128. interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
  129. ),
  130. stem_chs=64,
  131. stem_type='tiered',
  132. stem_pool='maxpool',
  133. act_layer='silu',
  134. self_attn_layer='halo',
  135. self_attn_kwargs=dict(block_size=8, halo_size=3)
  136. ),
  137. eca_halonext26ts=ByoModelCfg(
  138. blocks=(
  139. ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
  140. ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
  141. interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25),
  142. ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
  143. ),
  144. stem_chs=64,
  145. stem_type='tiered',
  146. stem_pool='maxpool',
  147. act_layer='silu',
  148. attn_layer='eca',
  149. self_attn_layer='halo',
  150. self_attn_kwargs=dict(block_size=8, halo_size=2, dim_head=16)
  151. ),
  152. lambda_resnet26t=ByoModelCfg(
  153. blocks=(
  154. ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
  155. ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
  156. interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
  157. ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
  158. ),
  159. stem_chs=64,
  160. stem_type='tiered',
  161. stem_pool='maxpool',
  162. self_attn_layer='lambda',
  163. self_attn_kwargs=dict(r=9)
  164. ),
  165. lambda_resnet50ts=ByoModelCfg(
  166. blocks=(
  167. ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
  168. interleave_blocks(types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25),
  169. interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
  170. interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
  171. ),
  172. stem_chs=64,
  173. stem_type='tiered',
  174. stem_pool='maxpool',
  175. act_layer='silu',
  176. self_attn_layer='lambda',
  177. self_attn_kwargs=dict(r=9)
  178. ),
  179. lambda_resnet26rpt_256=ByoModelCfg(
  180. blocks=(
  181. ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
  182. ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
  183. interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
  184. ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
  185. ),
  186. stem_chs=64,
  187. stem_type='tiered',
  188. stem_pool='maxpool',
  189. self_attn_layer='lambda',
  190. self_attn_kwargs=dict(r=None)
  191. ),
  192. # experimental
  193. haloregnetz_b=ByoModelCfg(
  194. blocks=(
  195. ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
  196. ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
  197. interleave_blocks(types=('bottle', 'self_attn'), every=3, d=12, c=192, s=2, gs=16, br=3),
  198. ByoBlockCfg('self_attn', d=2, c=288, s=2, gs=16, br=3),
  199. ),
  200. stem_chs=32,
  201. stem_pool='',
  202. downsample='',
  203. num_features=1536,
  204. act_layer='silu',
  205. attn_layer='se',
  206. attn_kwargs=dict(rd_ratio=0.25),
  207. block_kwargs=dict(bottle_in=True, linear_out=True),
  208. self_attn_layer='halo',
  209. self_attn_kwargs=dict(block_size=7, halo_size=2, qk_ratio=0.33)
  210. ),
  211. # experimental
  212. lamhalobotnet50ts=ByoModelCfg(
  213. blocks=(
  214. ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
  215. interleave_blocks(
  216. types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25,
  217. self_attn_layer='lambda', self_attn_kwargs=dict(r=13)),
  218. interleave_blocks(
  219. types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25,
  220. self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)),
  221. interleave_blocks(
  222. types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25,
  223. self_attn_layer='bottleneck', self_attn_kwargs=dict()),
  224. ),
  225. stem_chs=64,
  226. stem_type='tiered',
  227. stem_pool='',
  228. act_layer='silu',
  229. ),
  230. halo2botnet50ts=ByoModelCfg(
  231. blocks=(
  232. ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
  233. interleave_blocks(
  234. types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25,
  235. self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)),
  236. interleave_blocks(
  237. types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25,
  238. self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)),
  239. interleave_blocks(
  240. types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25,
  241. self_attn_layer='bottleneck', self_attn_kwargs=dict()),
  242. ),
  243. stem_chs=64,
  244. stem_type='tiered',
  245. stem_pool='',
  246. act_layer='silu',
  247. ),
  248. )
  249. def _create_byoanet(variant: str, cfg_variant: Optional[str] = None, pretrained: bool = False, **kwargs) -> ByobNet:
  250. """Create a Bring-Your-Own-Attention network model.
  251. Args:
  252. variant: Model variant name.
  253. cfg_variant: Config variant name if different from model variant.
  254. pretrained: Load pretrained weights.
  255. **kwargs: Additional model arguments.
  256. Returns:
  257. Instantiated ByobNet model.
  258. """
  259. return build_model_with_cfg(
  260. ByobNet, variant, pretrained,
  261. model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
  262. feature_cfg=dict(flatten_sequential=True),
  263. **kwargs,
  264. )
  265. def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
  266. """Generate default model configuration.
  267. Args:
  268. url: URL for pretrained weights.
  269. **kwargs: Override default configuration values.
  270. Returns:
  271. Model configuration dictionary.
  272. """
  273. return {
  274. 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  275. 'crop_pct': 0.95, 'interpolation': 'bicubic',
  276. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  277. 'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
  278. 'fixed_input_size': False, 'min_input_size': (3, 224, 224), 'license': 'apache-2.0',
  279. **kwargs
  280. }
  281. default_cfgs = generate_default_cfgs({
  282. # GPU-Efficient (ResNet) weights
  283. 'botnet26t_256.c1_in1k': _cfg(
  284. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth',
  285. hf_hub_id='timm/',
  286. fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
  287. 'sebotnet33ts_256.a1h_in1k': _cfg(
  288. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sebotnet33ts_a1h2_256-957e3c3e.pth',
  289. hf_hub_id='timm/',
  290. fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
  291. 'botnet50ts_256.untrained': _cfg(
  292. fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
  293. 'eca_botnext26ts_256.c1_in1k': _cfg(
  294. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_botnext26ts_c_256-95a898f6.pth',
  295. hf_hub_id='timm/',
  296. fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
  297. 'halonet_h1.untrained': _cfg(input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
  298. 'halonet26t.a1h_in1k': _cfg(
  299. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_a1h_256-3083328c.pth',
  300. hf_hub_id='timm/',
  301. input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
  302. 'sehalonet33ts.ra2_in1k': _cfg(
  303. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth',
  304. hf_hub_id='timm/',
  305. input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
  306. 'halonet50ts.a1h_in1k': _cfg(
  307. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_a1h2_256-f3a3daee.pth',
  308. hf_hub_id='timm/',
  309. input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
  310. 'eca_halonext26ts.c1_in1k': _cfg(
  311. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_c_256-06906299.pth',
  312. hf_hub_id='timm/',
  313. input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
  314. 'lambda_resnet26t.c1_in1k': _cfg(
  315. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_c_256-e5a5c857.pth',
  316. hf_hub_id='timm/',
  317. min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
  318. 'lambda_resnet50ts.a1h_in1k': _cfg(
  319. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet50ts_a1h_256-b87370f7.pth',
  320. hf_hub_id='timm/',
  321. min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
  322. 'lambda_resnet26rpt_256.c1_in1k': _cfg(
  323. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_c_256-ab00292d.pth',
  324. hf_hub_id='timm/',
  325. fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
  326. 'haloregnetz_b.ra3_in1k': _cfg(
  327. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/haloregnetz_c_raa_256-c8ad7616.pth',
  328. hf_hub_id='timm/',
  329. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  330. first_conv='stem.conv', input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94),
  331. 'lamhalobotnet50ts_256.a1h_in1k': _cfg(
  332. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lamhalobotnet50ts_a1h2_256-fe3d9445.pth',
  333. hf_hub_id='timm/',
  334. fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
  335. 'halo2botnet50ts_256.a1h_in1k': _cfg(
  336. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halo2botnet50ts_a1h2_256-fd9c11a3.pth',
  337. hf_hub_id='timm/',
  338. fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
  339. })
  340. @register_model
  341. def botnet26t_256(pretrained: bool = False, **kwargs) -> ByobNet:
  342. """ Bottleneck Transformer w/ ResNet26-T backbone.
  343. """
  344. kwargs.setdefault('img_size', 256)
  345. return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs)
  346. @register_model
  347. def sebotnet33ts_256(pretrained: bool = False, **kwargs) -> ByobNet:
  348. """ Bottleneck Transformer w/ a ResNet33-t backbone, SE attn for non Halo blocks, SiLU,
  349. """
  350. return _create_byoanet('sebotnet33ts_256', 'sebotnet33ts', pretrained=pretrained, **kwargs)
  351. @register_model
  352. def botnet50ts_256(pretrained: bool = False, **kwargs) -> ByobNet:
  353. """ Bottleneck Transformer w/ ResNet50-T backbone, silu act.
  354. """
  355. kwargs.setdefault('img_size', 256)
  356. return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs)
  357. @register_model
  358. def eca_botnext26ts_256(pretrained: bool = False, **kwargs) -> ByobNet:
  359. """ Bottleneck Transformer w/ ResNet26-T backbone, silu act.
  360. """
  361. kwargs.setdefault('img_size', 256)
  362. return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs)
  363. @register_model
  364. def halonet_h1(pretrained: bool = False, **kwargs) -> ByobNet:
  365. """ HaloNet-H1. Halo attention in all stages as per the paper.
  366. NOTE: This runs very slowly!
  367. """
  368. return _create_byoanet('halonet_h1', pretrained=pretrained, **kwargs)
  369. @register_model
  370. def halonet26t(pretrained: bool = False, **kwargs) -> ByobNet:
  371. """ HaloNet w/ a ResNet26-t backbone. Halo attention in final two stages
  372. """
  373. return _create_byoanet('halonet26t', pretrained=pretrained, **kwargs)
  374. @register_model
  375. def sehalonet33ts(pretrained: bool = False, **kwargs) -> ByobNet:
  376. """ HaloNet w/ a ResNet33-t backbone, SE attn for non Halo blocks, SiLU, 1-2 Halo in stage 2,3,4.
  377. """
  378. return _create_byoanet('sehalonet33ts', pretrained=pretrained, **kwargs)
  379. @register_model
  380. def halonet50ts(pretrained: bool = False, **kwargs) -> ByobNet:
  381. """ HaloNet w/ a ResNet50-t backbone, silu act. Halo attention in final two stages
  382. """
  383. return _create_byoanet('halonet50ts', pretrained=pretrained, **kwargs)
  384. @register_model
  385. def eca_halonext26ts(pretrained: bool = False, **kwargs) -> ByobNet:
  386. """ HaloNet w/ a ResNet26-t backbone, silu act. Halo attention in final two stages
  387. """
  388. return _create_byoanet('eca_halonext26ts', pretrained=pretrained, **kwargs)
  389. @register_model
  390. def lambda_resnet26t(pretrained: bool = False, **kwargs) -> ByobNet:
  391. """ Lambda-ResNet-26-T. Lambda layers w/ conv pos in last two stages.
  392. """
  393. return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs)
  394. @register_model
  395. def lambda_resnet50ts(pretrained: bool = False, **kwargs) -> ByobNet:
  396. """ Lambda-ResNet-50-TS. SiLU act. Lambda layers w/ conv pos in last two stages.
  397. """
  398. return _create_byoanet('lambda_resnet50ts', pretrained=pretrained, **kwargs)
  399. @register_model
  400. def lambda_resnet26rpt_256(pretrained: bool = False, **kwargs) -> ByobNet:
  401. """ Lambda-ResNet-26-R-T. Lambda layers w/ rel pos embed in last two stages.
  402. """
  403. kwargs.setdefault('img_size', 256)
  404. return _create_byoanet('lambda_resnet26rpt_256', pretrained=pretrained, **kwargs)
  405. @register_model
  406. def haloregnetz_b(pretrained: bool = False, **kwargs) -> ByobNet:
  407. """ Halo + RegNetZ
  408. """
  409. return _create_byoanet('haloregnetz_b', pretrained=pretrained, **kwargs)
  410. @register_model
  411. def lamhalobotnet50ts_256(pretrained: bool = False, **kwargs) -> ByobNet:
  412. """ Combo Attention (Lambda + Halo + Bot) Network
  413. """
  414. return _create_byoanet('lamhalobotnet50ts_256', 'lamhalobotnet50ts', pretrained=pretrained, **kwargs)
  415. @register_model
  416. def halo2botnet50ts_256(pretrained: bool = False, **kwargs) -> ByobNet:
  417. """ Combo Attention (Halo + Halo + Bot) Network
  418. """
  419. return _create_byoanet('halo2botnet50ts_256', 'halo2botnet50ts', pretrained=pretrained, **kwargs)