cspnet.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208
  1. """PyTorch CspNet
  2. A PyTorch implementation of Cross Stage Partial Networks including:
  3. * CSPResNet50
  4. * CSPResNeXt50
  5. * CSPDarkNet53
  6. * and DarkNet53 for good measure
  7. Based on paper `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929
  8. Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStagePartialNetworks
  9. Hacked together by / Copyright 2020 Ross Wightman
  10. """
  11. from dataclasses import dataclass, asdict, replace
  12. from functools import partial
  13. from typing import Any, Dict, List, Optional, Tuple, Type, Union
  14. import torch
  15. import torch.nn as nn
  16. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  17. from timm.layers import ClassifierHead, ConvNormAct, DropPath, calculate_drop_path_rates, get_attn, create_act_layer, make_divisible
  18. from ._builder import build_model_with_cfg
  19. from ._manipulate import named_apply, MATCH_PREV_GROUP
  20. from ._registry import register_model, generate_default_cfgs
  21. __all__ = ['CspNet'] # model_registry will add each entrypoint fn to this
  22. @dataclass
  23. class CspStemCfg:
  24. out_chs: Union[int, Tuple[int, ...]] = 32
  25. stride: Union[int, Tuple[int, ...]] = 2
  26. kernel_size: int = 3
  27. padding: Union[int, str] = ''
  28. pool: Optional[str] = ''
  29. def _pad_arg(x, n):
  30. # pads an argument tuple to specified n by padding with last value
  31. if not isinstance(x, (tuple, list)):
  32. x = (x,)
  33. curr_n = len(x)
  34. pad_n = n - curr_n
  35. if pad_n <= 0:
  36. return x[:n]
  37. return tuple(x + (x[-1],) * pad_n)
  38. @dataclass
  39. class CspStagesCfg:
  40. depth: Tuple[int, ...] = (3, 3, 5, 2) # block depth (number of block repeats in stages)
  41. out_chs: Tuple[int, ...] = (128, 256, 512, 1024) # number of output channels for blocks in stage
  42. stride: Union[int, Tuple[int, ...]] = 2 # stride of stage
  43. groups: Union[int, Tuple[int, ...]] = 1 # num kxk conv groups
  44. block_ratio: Union[float, Tuple[float, ...]] = 1.0
  45. bottle_ratio: Union[float, Tuple[float, ...]] = 1. # bottleneck-ratio of blocks in stage
  46. avg_down: Union[bool, Tuple[bool, ...]] = False
  47. attn_layer: Optional[Union[str, Tuple[str, ...]]] = None
  48. attn_kwargs: Optional[Union[Dict, Tuple[Dict]]] = None
  49. stage_type: Union[str, Tuple[str]] = 'csp' # stage type ('csp', 'cs2', 'dark')
  50. block_type: Union[str, Tuple[str]] = 'bottle' # blocks type for stages ('bottle', 'dark')
  51. # cross-stage only
  52. expand_ratio: Union[float, Tuple[float, ...]] = 1.0
  53. cross_linear: Union[bool, Tuple[bool, ...]] = False
  54. down_growth: Union[bool, Tuple[bool, ...]] = False
  55. def __post_init__(self):
  56. n = len(self.depth)
  57. assert len(self.out_chs) == n
  58. self.stride = _pad_arg(self.stride, n)
  59. self.groups = _pad_arg(self.groups, n)
  60. self.block_ratio = _pad_arg(self.block_ratio, n)
  61. self.bottle_ratio = _pad_arg(self.bottle_ratio, n)
  62. self.avg_down = _pad_arg(self.avg_down, n)
  63. self.attn_layer = _pad_arg(self.attn_layer, n)
  64. self.attn_kwargs = _pad_arg(self.attn_kwargs, n)
  65. self.stage_type = _pad_arg(self.stage_type, n)
  66. self.block_type = _pad_arg(self.block_type, n)
  67. self.expand_ratio = _pad_arg(self.expand_ratio, n)
  68. self.cross_linear = _pad_arg(self.cross_linear, n)
  69. self.down_growth = _pad_arg(self.down_growth, n)
  70. @dataclass
  71. class CspModelCfg:
  72. stem: CspStemCfg
  73. stages: CspStagesCfg
  74. zero_init_last: bool = True # zero init last weight (usually bn) in residual path
  75. act_layer: str = 'leaky_relu'
  76. norm_layer: str = 'batchnorm'
  77. aa_layer: Optional[str] = None # FIXME support string factory for this
  78. def _cs3_cfg(
  79. width_multiplier=1.0,
  80. depth_multiplier=1.0,
  81. avg_down=False,
  82. act_layer='silu',
  83. focus=False,
  84. attn_layer=None,
  85. attn_kwargs=None,
  86. bottle_ratio=1.0,
  87. block_type='dark',
  88. ):
  89. if focus:
  90. stem_cfg = CspStemCfg(
  91. out_chs=make_divisible(64 * width_multiplier),
  92. kernel_size=6, stride=2, padding=2, pool='')
  93. else:
  94. stem_cfg = CspStemCfg(
  95. out_chs=tuple([make_divisible(c * width_multiplier) for c in (32, 64)]),
  96. kernel_size=3, stride=2, pool='')
  97. return CspModelCfg(
  98. stem=stem_cfg,
  99. stages=CspStagesCfg(
  100. out_chs=tuple([make_divisible(c * width_multiplier) for c in (128, 256, 512, 1024)]),
  101. depth=tuple([int(d * depth_multiplier) for d in (3, 6, 9, 3)]),
  102. stride=2,
  103. bottle_ratio=bottle_ratio,
  104. block_ratio=0.5,
  105. avg_down=avg_down,
  106. attn_layer=attn_layer,
  107. attn_kwargs=attn_kwargs,
  108. stage_type='cs3',
  109. block_type=block_type,
  110. ),
  111. act_layer=act_layer,
  112. )
  113. class BottleneckBlock(nn.Module):
  114. """ ResNe(X)t Bottleneck Block
  115. """
  116. def __init__(
  117. self,
  118. in_chs: int,
  119. out_chs: int,
  120. dilation: int = 1,
  121. bottle_ratio: float = 0.25,
  122. groups: int = 1,
  123. act_layer: Type[nn.Module] = nn.ReLU,
  124. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  125. attn_last: bool = False,
  126. attn_layer: Optional[Type[nn.Module]] = None,
  127. drop_block: Optional[Type[nn.Module]] = None,
  128. drop_path: float = 0.,
  129. device=None,
  130. dtype=None,
  131. ):
  132. dd = {'device': device, 'dtype': dtype}
  133. super().__init__()
  134. mid_chs = int(round(out_chs * bottle_ratio))
  135. ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
  136. attn_last = attn_layer is not None and attn_last
  137. attn_first = attn_layer is not None and not attn_last
  138. self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **ckwargs, **dd)
  139. self.conv2 = ConvNormAct(
  140. mid_chs,
  141. mid_chs,
  142. kernel_size=3,
  143. dilation=dilation,
  144. groups=groups,
  145. drop_layer=drop_block,
  146. **ckwargs,
  147. **dd,
  148. )
  149. self.attn2 = attn_layer(mid_chs, act_layer=act_layer, **dd) if attn_first else nn.Identity()
  150. self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs, **dd)
  151. self.attn3 = attn_layer(out_chs, act_layer=act_layer, **dd) if attn_last else nn.Identity()
  152. self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
  153. self.act3 = create_act_layer(act_layer)
  154. def zero_init_last(self):
  155. nn.init.zeros_(self.conv3.bn.weight)
  156. def forward(self, x):
  157. shortcut = x
  158. x = self.conv1(x)
  159. x = self.conv2(x)
  160. x = self.attn2(x)
  161. x = self.conv3(x)
  162. x = self.attn3(x)
  163. x = self.drop_path(x) + shortcut
  164. # FIXME partial shortcut needed if first block handled as per original, not used for my current impl
  165. #x[:, :shortcut.size(1)] += shortcut
  166. x = self.act3(x)
  167. return x
  168. class DarkBlock(nn.Module):
  169. """ DarkNet Block
  170. """
  171. def __init__(
  172. self,
  173. in_chs: int,
  174. out_chs: int,
  175. dilation: int = 1,
  176. bottle_ratio: float = 0.5,
  177. groups: int = 1,
  178. act_layer: Type[nn.Module] = nn.ReLU,
  179. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  180. attn_layer: Optional[Type[nn.Module]] = None,
  181. drop_block: Optional[Type[nn.Module]] = None,
  182. drop_path: float = 0.,
  183. device=None,
  184. dtype=None,
  185. ):
  186. dd = {'device': device, 'dtype': dtype}
  187. super().__init__()
  188. mid_chs = int(round(out_chs * bottle_ratio))
  189. ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
  190. self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **ckwargs, **dd)
  191. self.attn = attn_layer(mid_chs, act_layer=act_layer, **dd) if attn_layer is not None else nn.Identity()
  192. self.conv2 = ConvNormAct(
  193. mid_chs,
  194. out_chs,
  195. kernel_size=3,
  196. dilation=dilation,
  197. groups=groups,
  198. drop_layer=drop_block,
  199. **ckwargs,
  200. **dd,
  201. )
  202. self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
  203. def zero_init_last(self):
  204. nn.init.zeros_(self.conv2.bn.weight)
  205. def forward(self, x):
  206. shortcut = x
  207. x = self.conv1(x)
  208. x = self.attn(x)
  209. x = self.conv2(x)
  210. x = self.drop_path(x) + shortcut
  211. return x
  212. class EdgeBlock(nn.Module):
  213. """ EdgeResidual / Fused-MBConv / MobileNetV1-like 3x3 + 1x1 block (w/ activated output)
  214. """
  215. def __init__(
  216. self,
  217. in_chs: int,
  218. out_chs: int,
  219. dilation: int = 1,
  220. bottle_ratio: float = 0.5,
  221. groups: int = 1,
  222. act_layer: Type[nn.Module] = nn.ReLU,
  223. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  224. attn_layer: Optional[Type[nn.Module]] = None,
  225. drop_block: Optional[Type[nn.Module]] = None,
  226. drop_path: float = 0.,
  227. device=None,
  228. dtype=None,
  229. ):
  230. dd = {'device': device, 'dtype': dtype}
  231. super().__init__()
  232. mid_chs = int(round(out_chs * bottle_ratio))
  233. ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
  234. self.conv1 = ConvNormAct(
  235. in_chs,
  236. mid_chs,
  237. kernel_size=3,
  238. dilation=dilation,
  239. groups=groups,
  240. drop_layer=drop_block,
  241. **ckwargs,
  242. **dd,
  243. )
  244. self.attn = attn_layer(mid_chs, act_layer=act_layer, **dd) if attn_layer is not None else nn.Identity()
  245. self.conv2 = ConvNormAct(mid_chs, out_chs, kernel_size=1, **ckwargs, **dd)
  246. self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
  247. def zero_init_last(self):
  248. nn.init.zeros_(self.conv2.bn.weight)
  249. def forward(self, x):
  250. shortcut = x
  251. x = self.conv1(x)
  252. x = self.attn(x)
  253. x = self.conv2(x)
  254. x = self.drop_path(x) + shortcut
  255. return x
  256. class CrossStage(nn.Module):
  257. """Cross Stage."""
  258. def __init__(
  259. self,
  260. in_chs: int,
  261. out_chs: int,
  262. stride: int,
  263. dilation: int,
  264. depth: int,
  265. block_ratio: float = 1.,
  266. bottle_ratio: float = 1.,
  267. expand_ratio: float = 1.,
  268. groups: int = 1,
  269. first_dilation: Optional[int] = None,
  270. avg_down: bool = False,
  271. down_growth: bool = False,
  272. cross_linear: bool = False,
  273. block_dpr: Optional[List[float]] = None,
  274. block_fn: Type[nn.Module] = BottleneckBlock,
  275. device=None,
  276. dtype=None,
  277. **block_kwargs,
  278. ):
  279. dd = {'device': device, 'dtype': dtype}
  280. super().__init__()
  281. first_dilation = first_dilation or dilation
  282. down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels
  283. self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
  284. block_out_chs = int(round(out_chs * block_ratio))
  285. conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
  286. aa_layer = block_kwargs.pop('aa_layer', None)
  287. if stride != 1 or first_dilation != dilation:
  288. if avg_down:
  289. self.conv_down = nn.Sequential(
  290. nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
  291. ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs, **dd)
  292. )
  293. else:
  294. self.conv_down = ConvNormAct(
  295. in_chs,
  296. down_chs,
  297. kernel_size=3,
  298. stride=stride,
  299. dilation=first_dilation,
  300. groups=groups,
  301. aa_layer=aa_layer,
  302. **conv_kwargs,
  303. **dd,
  304. )
  305. prev_chs = down_chs
  306. else:
  307. self.conv_down = nn.Identity()
  308. prev_chs = in_chs
  309. # FIXME this 1x1 expansion is pushed down into the cross and block paths in the darknet cfgs. Also,
  310. # there is also special case for the first stage for some of the model that results in uneven split
  311. # across the two paths. I did it this way for simplicity for now.
  312. self.conv_exp = ConvNormAct(
  313. prev_chs,
  314. exp_chs,
  315. kernel_size=1,
  316. apply_act=not cross_linear,
  317. **conv_kwargs,
  318. **dd,
  319. )
  320. prev_chs = exp_chs // 2 # output of conv_exp is always split in two
  321. self.blocks = nn.Sequential()
  322. for i in range(depth):
  323. self.blocks.add_module(str(i), block_fn(
  324. in_chs=prev_chs,
  325. out_chs=block_out_chs,
  326. dilation=dilation,
  327. bottle_ratio=bottle_ratio,
  328. groups=groups,
  329. drop_path=block_dpr[i] if block_dpr is not None else 0.,
  330. **block_kwargs,
  331. **dd,
  332. ))
  333. prev_chs = block_out_chs
  334. # transition convs
  335. self.conv_transition_b = ConvNormAct(prev_chs, exp_chs // 2, kernel_size=1, **conv_kwargs, **dd)
  336. self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs, **dd)
  337. def forward(self, x):
  338. x = self.conv_down(x)
  339. x = self.conv_exp(x)
  340. xs, xb = x.split(self.expand_chs // 2, dim=1)
  341. xb = self.blocks(xb)
  342. xb = self.conv_transition_b(xb).contiguous()
  343. out = self.conv_transition(torch.cat([xs, xb], dim=1))
  344. return out
  345. class CrossStage3(nn.Module):
  346. """Cross Stage 3.
  347. Similar to CrossStage, but with only one transition conv for the output.
  348. """
  349. def __init__(
  350. self,
  351. in_chs: int,
  352. out_chs: int,
  353. stride: int,
  354. dilation: int,
  355. depth: int,
  356. block_ratio: float = 1.,
  357. bottle_ratio: float = 1.,
  358. expand_ratio: float = 1.,
  359. groups: int = 1,
  360. first_dilation: Optional[int] = None,
  361. avg_down: bool = False,
  362. down_growth: bool = False,
  363. cross_linear: bool = False,
  364. block_dpr: Optional[List[float]] = None,
  365. block_fn: Type[nn.Module] = BottleneckBlock,
  366. device=None,
  367. dtype=None,
  368. **block_kwargs,
  369. ):
  370. dd = {'device': device, 'dtype': dtype}
  371. super().__init__()
  372. first_dilation = first_dilation or dilation
  373. down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels
  374. self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
  375. block_out_chs = int(round(out_chs * block_ratio))
  376. conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
  377. aa_layer = block_kwargs.pop('aa_layer', None)
  378. if stride != 1 or first_dilation != dilation:
  379. if avg_down:
  380. self.conv_down = nn.Sequential(
  381. nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
  382. ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs, **dd)
  383. )
  384. else:
  385. self.conv_down = ConvNormAct(
  386. in_chs,
  387. down_chs,
  388. kernel_size=3,
  389. stride=stride,
  390. dilation=first_dilation,
  391. groups=groups,
  392. aa_layer=aa_layer,
  393. **conv_kwargs,
  394. **dd,
  395. )
  396. prev_chs = down_chs
  397. else:
  398. self.conv_down = None
  399. prev_chs = in_chs
  400. # expansion conv
  401. self.conv_exp = ConvNormAct(
  402. prev_chs,
  403. exp_chs,
  404. kernel_size=1,
  405. apply_act=not cross_linear,
  406. **conv_kwargs,
  407. **dd,
  408. )
  409. prev_chs = exp_chs // 2 # expanded output is split in 2 for blocks and cross stage
  410. self.blocks = nn.Sequential()
  411. for i in range(depth):
  412. self.blocks.add_module(str(i), block_fn(
  413. in_chs=prev_chs,
  414. out_chs=block_out_chs,
  415. dilation=dilation,
  416. bottle_ratio=bottle_ratio,
  417. groups=groups,
  418. drop_path=block_dpr[i] if block_dpr is not None else 0.,
  419. **block_kwargs,
  420. **dd,
  421. ))
  422. prev_chs = block_out_chs
  423. # transition convs
  424. self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs, **dd)
  425. def forward(self, x):
  426. x = self.conv_down(x)
  427. x = self.conv_exp(x)
  428. x1, x2 = x.split(self.expand_chs // 2, dim=1)
  429. x1 = self.blocks(x1)
  430. out = self.conv_transition(torch.cat([x1, x2], dim=1))
  431. return out
  432. class DarkStage(nn.Module):
  433. """DarkNet stage."""
  434. def __init__(
  435. self,
  436. in_chs: int,
  437. out_chs: int,
  438. stride: int,
  439. dilation: int,
  440. depth: int,
  441. block_ratio: float = 1.,
  442. bottle_ratio: float = 1.,
  443. groups: int = 1,
  444. first_dilation: Optional[int] = None,
  445. avg_down: bool = False,
  446. block_fn: Type[nn.Module] = BottleneckBlock,
  447. block_dpr: Optional[List[float]] = None,
  448. device=None,
  449. dtype=None,
  450. **block_kwargs,
  451. ):
  452. dd = {'device': device, 'dtype': dtype}
  453. super().__init__()
  454. first_dilation = first_dilation or dilation
  455. conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
  456. aa_layer = block_kwargs.pop('aa_layer', None)
  457. if avg_down:
  458. self.conv_down = nn.Sequential(
  459. nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
  460. ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs, **dd)
  461. )
  462. else:
  463. self.conv_down = ConvNormAct(
  464. in_chs,
  465. out_chs,
  466. kernel_size=3,
  467. stride=stride,
  468. dilation=first_dilation,
  469. groups=groups,
  470. aa_layer=aa_layer,
  471. **conv_kwargs,
  472. **dd,
  473. )
  474. prev_chs = out_chs
  475. block_out_chs = int(round(out_chs * block_ratio))
  476. self.blocks = nn.Sequential()
  477. for i in range(depth):
  478. self.blocks.add_module(str(i), block_fn(
  479. in_chs=prev_chs,
  480. out_chs=block_out_chs,
  481. dilation=dilation,
  482. bottle_ratio=bottle_ratio,
  483. groups=groups,
  484. drop_path=block_dpr[i] if block_dpr is not None else 0.,
  485. **block_kwargs,
  486. **dd,
  487. ))
  488. prev_chs = block_out_chs
  489. def forward(self, x):
  490. x = self.conv_down(x)
  491. x = self.blocks(x)
  492. return x
  493. def create_csp_stem(
  494. in_chans: int = 3,
  495. out_chs: int = 32,
  496. kernel_size: int = 3,
  497. stride: int = 2,
  498. pool: str = '',
  499. padding: str = '',
  500. act_layer: Type[nn.Module] = nn.ReLU,
  501. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  502. aa_layer: Optional[Type[nn.Module]] = None,
  503. device=None,
  504. dtype=None,
  505. ):
  506. dd = {'device': device, 'dtype': dtype}
  507. stem = nn.Sequential()
  508. feature_info = []
  509. if not isinstance(out_chs, (tuple, list)):
  510. out_chs = [out_chs]
  511. stem_depth = len(out_chs)
  512. assert stem_depth
  513. assert stride in (1, 2, 4)
  514. prev_feat = None
  515. prev_chs = in_chans
  516. last_idx = stem_depth - 1
  517. stem_stride = 1
  518. for i, chs in enumerate(out_chs):
  519. conv_name = f'conv{i + 1}'
  520. conv_stride = 2 if (i == 0 and stride > 1) or (i == last_idx and stride > 2 and not pool) else 1
  521. if conv_stride > 1 and prev_feat is not None:
  522. feature_info.append(prev_feat)
  523. stem.add_module(conv_name, ConvNormAct(
  524. prev_chs, chs, kernel_size,
  525. stride=conv_stride,
  526. padding=padding if i == 0 else '',
  527. act_layer=act_layer,
  528. norm_layer=norm_layer,
  529. **dd,
  530. ))
  531. stem_stride *= conv_stride
  532. prev_chs = chs
  533. prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', conv_name]))
  534. if pool:
  535. assert stride > 2
  536. if prev_feat is not None:
  537. feature_info.append(prev_feat)
  538. if aa_layer is not None:
  539. stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
  540. stem.add_module('aa', aa_layer(channels=prev_chs, stride=2, **dd))
  541. pool_name = 'aa'
  542. else:
  543. stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
  544. pool_name = 'pool'
  545. stem_stride *= 2
  546. prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', pool_name]))
  547. feature_info.append(prev_feat)
  548. return stem, feature_info
  549. def _get_stage_fn(stage_args):
  550. stage_type = stage_args.pop('stage_type')
  551. assert stage_type in ('dark', 'csp', 'cs3')
  552. if stage_type == 'dark':
  553. stage_args.pop('expand_ratio', None)
  554. stage_args.pop('cross_linear', None)
  555. stage_args.pop('down_growth', None)
  556. stage_fn = DarkStage
  557. elif stage_type == 'csp':
  558. stage_fn = CrossStage
  559. else:
  560. stage_fn = CrossStage3
  561. return stage_fn, stage_args
  562. def _get_block_fn(stage_args):
  563. block_type = stage_args.pop('block_type')
  564. assert block_type in ('dark', 'edge', 'bottle')
  565. if block_type == 'dark':
  566. return DarkBlock, stage_args
  567. elif block_type == 'edge':
  568. return EdgeBlock, stage_args
  569. else:
  570. return BottleneckBlock, stage_args
  571. def _get_attn_fn(stage_args):
  572. attn_layer = stage_args.pop('attn_layer')
  573. attn_kwargs = stage_args.pop('attn_kwargs', None) or {}
  574. if attn_layer is not None:
  575. attn_layer = get_attn(attn_layer)
  576. if attn_kwargs:
  577. attn_layer = partial(attn_layer, **attn_kwargs)
  578. return attn_layer, stage_args
  579. def create_csp_stages(
  580. cfg: CspModelCfg,
  581. drop_path_rate: float,
  582. output_stride: int,
  583. stem_feat: Dict[str, Any],
  584. device=None,
  585. dtype=None,
  586. ):
  587. dd = {'device': device, 'dtype': dtype}
  588. cfg_dict = asdict(cfg.stages)
  589. num_stages = len(cfg.stages.depth)
  590. cfg_dict['block_dpr'] = [None] * num_stages if not drop_path_rate else \
  591. calculate_drop_path_rates(drop_path_rate, cfg.stages.depth, stagewise=True)
  592. stage_args = [dict(zip(cfg_dict.keys(), values)) for values in zip(*cfg_dict.values())]
  593. block_kwargs = dict(
  594. act_layer=cfg.act_layer,
  595. norm_layer=cfg.norm_layer,
  596. )
  597. dilation = 1
  598. net_stride = stem_feat['reduction']
  599. prev_chs = stem_feat['num_chs']
  600. prev_feat = stem_feat
  601. feature_info = []
  602. stages = []
  603. for stage_idx, stage_args in enumerate(stage_args):
  604. stage_fn, stage_args = _get_stage_fn(stage_args)
  605. block_fn, stage_args = _get_block_fn(stage_args)
  606. attn_fn, stage_args = _get_attn_fn(stage_args)
  607. stride = stage_args.pop('stride')
  608. if stride != 1 and prev_feat:
  609. feature_info.append(prev_feat)
  610. if net_stride >= output_stride and stride > 1:
  611. dilation *= stride
  612. stride = 1
  613. net_stride *= stride
  614. first_dilation = 1 if dilation in (1, 2) else 2
  615. stages += [stage_fn(
  616. prev_chs,
  617. **stage_args,
  618. stride=stride,
  619. first_dilation=first_dilation,
  620. dilation=dilation,
  621. block_fn=block_fn,
  622. aa_layer=cfg.aa_layer,
  623. attn_layer=attn_fn, # will be passed through stage as block_kwargs
  624. **block_kwargs,
  625. **dd,
  626. )]
  627. prev_chs = stage_args['out_chs']
  628. prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')
  629. feature_info.append(prev_feat)
  630. return nn.Sequential(*stages), feature_info
  631. class CspNet(nn.Module):
  632. """Cross Stage Partial base model.
  633. Paper: `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929
  634. Ref Impl: https://github.com/WongKinYiu/CrossStagePartialNetworks
  635. NOTE: There are differences in the way I handle the 1x1 'expansion' conv in this impl vs the
  636. darknet impl. I did it this way for simplicity and less special cases.
  637. """
  638. def __init__(
  639. self,
  640. cfg: CspModelCfg,
  641. in_chans: int = 3,
  642. num_classes: int = 1000,
  643. output_stride: int = 32,
  644. global_pool: str = 'avg',
  645. drop_rate: float = 0.,
  646. drop_path_rate: float = 0.,
  647. zero_init_last: bool = True,
  648. device=None,
  649. dtype=None,
  650. **kwargs,
  651. ):
  652. """
  653. Args:
  654. cfg (CspModelCfg): Model architecture configuration
  655. in_chans (int): Number of input channels (default: 3)
  656. num_classes (int): Number of classifier classes (default: 1000)
  657. output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
  658. global_pool (str): Global pooling type (default: 'avg')
  659. drop_rate (float): Dropout rate (default: 0.)
  660. drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
  661. zero_init_last (bool): Zero-init last weight of residual path
  662. kwargs (dict): Extra kwargs overlayed onto cfg
  663. """
  664. super().__init__()
  665. dd = {'device': device, 'dtype': dtype}
  666. self.num_classes = num_classes
  667. self.in_chans = in_chans
  668. self.drop_rate = drop_rate
  669. assert output_stride in (8, 16, 32)
  670. cfg = replace(cfg, **kwargs) # overlay kwargs onto cfg
  671. layer_args = dict(
  672. act_layer=cfg.act_layer,
  673. norm_layer=cfg.norm_layer,
  674. aa_layer=cfg.aa_layer
  675. )
  676. self.feature_info = []
  677. # Construct the stem
  678. self.stem, stem_feat_info = create_csp_stem(in_chans, **asdict(cfg.stem), **layer_args, **dd)
  679. self.feature_info.extend(stem_feat_info[:-1])
  680. # Construct the stages
  681. self.stages, stage_feat_info = create_csp_stages(
  682. cfg,
  683. drop_path_rate=drop_path_rate,
  684. output_stride=output_stride,
  685. stem_feat=stem_feat_info[-1],
  686. **dd,
  687. )
  688. prev_chs = stage_feat_info[-1]['num_chs']
  689. self.feature_info.extend(stage_feat_info)
  690. # Construct the head
  691. self.num_features = self.head_hidden_size = prev_chs
  692. self.head = ClassifierHead(
  693. in_features=prev_chs,
  694. num_classes=num_classes,
  695. pool_type=global_pool,
  696. drop_rate=drop_rate,
  697. **dd,
  698. )
  699. named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
  700. @torch.jit.ignore
  701. def group_matcher(self, coarse=False):
  702. matcher = dict(
  703. stem=r'^stem',
  704. blocks=r'^stages\.(\d+)' if coarse else [
  705. (r'^stages\.(\d+)\.blocks\.(\d+)', None),
  706. (r'^stages\.(\d+)\..*transition', MATCH_PREV_GROUP), # map to last block in stage
  707. (r'^stages\.(\d+)', (0,)),
  708. ]
  709. )
  710. return matcher
  711. @torch.jit.ignore
  712. def set_grad_checkpointing(self, enable=True):
  713. assert not enable, 'gradient checkpointing not supported'
  714. @torch.jit.ignore
  715. def get_classifier(self) -> nn.Module:
  716. return self.head.fc
  717. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  718. self.num_classes = num_classes
  719. self.head.reset(num_classes, global_pool)
  720. def forward_features(self, x):
  721. x = self.stem(x)
  722. x = self.stages(x)
  723. return x
  724. def forward_head(self, x, pre_logits: bool = False):
  725. return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  726. def forward(self, x):
  727. x = self.forward_features(x)
  728. x = self.forward_head(x)
  729. return x
  730. def _init_weights(module, name, zero_init_last=False):
  731. if isinstance(module, nn.Conv2d):
  732. nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
  733. if module.bias is not None:
  734. nn.init.zeros_(module.bias)
  735. elif isinstance(module, nn.Linear):
  736. nn.init.normal_(module.weight, mean=0.0, std=0.01)
  737. if module.bias is not None:
  738. nn.init.zeros_(module.bias)
  739. elif zero_init_last and hasattr(module, 'zero_init_last'):
  740. module.zero_init_last()
  # Architecture table keyed by variant name; `_create_cspnet` looks entries up
  # with `model_cfgs[variant]`.
  model_cfgs = {
      # --- CSP ResNet / ResNeXt variants ---
      'cspresnet50': CspModelCfg(
          stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'),
          stages=CspStagesCfg(
              depth=(3, 3, 5, 2),
              out_chs=(128, 256, 512, 1024),
              stride=(1, 2),
              expand_ratio=2.0,
              bottle_ratio=0.5,
              cross_linear=True,
          ),
      ),
      'cspresnet50d': CspModelCfg(
          stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'),
          stages=CspStagesCfg(
              depth=(3, 3, 5, 2),
              out_chs=(128, 256, 512, 1024),
              stride=(1, 2),
              expand_ratio=2.0,
              bottle_ratio=0.5,
              block_ratio=1.0,
              cross_linear=True,
          ),
      ),
      'cspresnet50w': CspModelCfg(
          stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'),
          stages=CspStagesCfg(
              depth=(3, 3, 5, 2),
              out_chs=(256, 512, 1024, 2048),
              stride=(1, 2),
              expand_ratio=1.0,
              bottle_ratio=0.25,
              block_ratio=0.5,
              cross_linear=True,
          ),
      ),
      'cspresnext50': CspModelCfg(
          stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'),
          stages=CspStagesCfg(
              depth=(3, 3, 5, 2),
              out_chs=(256, 512, 1024, 2048),
              stride=(1, 2),
              groups=32,
              expand_ratio=1.0,
              bottle_ratio=1.0,
              block_ratio=0.5,
              cross_linear=True,
          ),
      ),

      # --- CSP DarkNet ---
      'cspdarknet53': CspModelCfg(
          stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
          stages=CspStagesCfg(
              depth=(1, 2, 8, 8, 4),
              out_chs=(64, 128, 256, 512, 1024),
              stride=2,
              expand_ratio=(2.0, 1.0),
              bottle_ratio=(0.5, 1.0),
              block_ratio=(1.0, 0.5),
              down_growth=True,
              block_type='dark',
          ),
      ),

      # --- DarkNet variants (stage_type='dark') ---
      'darknet17': CspModelCfg(
          stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
          stages=CspStagesCfg(
              depth=(1, 1, 1, 1, 1),
              out_chs=(64, 128, 256, 512, 1024),
              stride=(2,),
              bottle_ratio=(0.5,),
              block_ratio=(1.0,),
              stage_type='dark',
              block_type='dark',
          ),
      ),
      'darknet21': CspModelCfg(
          stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
          stages=CspStagesCfg(
              depth=(1, 1, 1, 2, 2),
              out_chs=(64, 128, 256, 512, 1024),
              stride=(2,),
              bottle_ratio=(0.5,),
              block_ratio=(1.0,),
              stage_type='dark',
              block_type='dark',
          ),
      ),
      'sedarknet21': CspModelCfg(
          stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
          stages=CspStagesCfg(
              depth=(1, 1, 1, 2, 2),
              out_chs=(64, 128, 256, 512, 1024),
              stride=2,
              bottle_ratio=0.5,
              block_ratio=1.0,
              attn_layer='se',
              stage_type='dark',
              block_type='dark',
          ),
      ),
      'darknet53': CspModelCfg(
          stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
          stages=CspStagesCfg(
              depth=(1, 2, 8, 8, 4),
              out_chs=(64, 128, 256, 512, 1024),
              stride=2,
              bottle_ratio=0.5,
              block_ratio=1.0,
              stage_type='dark',
              block_type='dark',
          ),
      ),
      'darknetaa53': CspModelCfg(
          stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
          stages=CspStagesCfg(
              depth=(1, 2, 8, 8, 4),
              out_chs=(64, 128, 256, 512, 1024),
              stride=2,
              bottle_ratio=0.5,
              block_ratio=1.0,
              avg_down=True,
              stage_type='dark',
              block_type='dark',
          ),
      ),

      # --- cs3 variants produced by the `_cs3_cfg` helper ---
      'cs3darknet_s': _cs3_cfg(width_multiplier=0.5, depth_multiplier=0.5),
      'cs3darknet_m': _cs3_cfg(width_multiplier=0.75, depth_multiplier=0.67),
      'cs3darknet_l': _cs3_cfg(),
      'cs3darknet_x': _cs3_cfg(width_multiplier=1.25, depth_multiplier=1.33),

      'cs3darknet_focus_s': _cs3_cfg(width_multiplier=0.5, depth_multiplier=0.5, focus=True),
      'cs3darknet_focus_m': _cs3_cfg(width_multiplier=0.75, depth_multiplier=0.67, focus=True),
      'cs3darknet_focus_l': _cs3_cfg(focus=True),
      'cs3darknet_focus_x': _cs3_cfg(width_multiplier=1.25, depth_multiplier=1.33, focus=True),

      'cs3sedarknet_l': _cs3_cfg(attn_layer='se', attn_kwargs={'rd_ratio': 0.25}),
      'cs3sedarknet_x': _cs3_cfg(attn_layer='se', width_multiplier=1.25, depth_multiplier=1.33),

      'cs3sedarknet_xdw': CspModelCfg(
          stem=CspStemCfg(out_chs=(32, 64), kernel_size=3, stride=2, pool=''),
          stages=CspStagesCfg(
              depth=(3, 6, 12, 4),
              out_chs=(256, 512, 1024, 2048),
              stride=2,
              groups=(1, 1, 256, 512),
              bottle_ratio=0.5,
              block_ratio=0.5,
              attn_layer='se',
          ),
          act_layer='silu',
      ),

      'cs3edgenet_x': _cs3_cfg(
          width_multiplier=1.25,
          depth_multiplier=1.33,
          bottle_ratio=1.5,
          block_type='edge',
      ),
      'cs3se_edgenet_x': _cs3_cfg(
          width_multiplier=1.25,
          depth_multiplier=1.33,
          bottle_ratio=1.5,
          block_type='edge',
          attn_layer='se',
          attn_kwargs={'rd_ratio': 0.25},
      ),
  }
  def _create_cspnet(variant, pretrained=False, **kwargs):
      """Build a ``CspNet`` for ``variant`` through ``build_model_with_cfg``.

      Args:
          variant: Key into ``model_cfgs``; also passed on as the variant name.
          pretrained: Forwarded positionally to ``build_model_with_cfg``.
          **kwargs: Forwarded to ``build_model_with_cfg``. An ``out_indices``
              entry, if given, is removed and placed in ``feature_cfg`` instead.

      Returns:
          Whatever ``build_model_with_cfg`` returns.
      """
      # NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5]
      is_darknet = variant.startswith(('darknet', 'cspdarknet'))
      default_out_indices = (0, 1, 2, 3, 4, 5) if is_darknet else (0, 1, 2, 3, 4)

      feature_cfg = dict(
          flatten_sequential=True,
          out_indices=kwargs.pop('out_indices', default_out_indices),
      )
      return build_model_with_cfg(
          CspNet,
          variant,
          pretrained,
          model_cfg=model_cfgs[variant],
          feature_cfg=feature_cfg,
          **kwargs,
      )
  def _cfg(url='', **kwargs):
      """Return a pretrained-config dict: the defaults below, overlaid with ``kwargs``.

      Args:
          url: Stored under the ``'url'`` key.
          **kwargs: Extra or overriding entries; a key that matches a default
              replaces its value, any other key is appended.
      """
      cfg = dict(
          url=url,
          num_classes=1000,
          input_size=(3, 256, 256),
          pool_size=(8, 8),
          crop_pct=0.887,
          interpolation='bilinear',
          mean=IMAGENET_DEFAULT_MEAN,
          std=IMAGENET_DEFAULT_STD,
          first_conv='stem.conv1.conv',
          classifier='head.fc',
          license='apache-2.0',
      )
      cfg.update(kwargs)
      return cfg
  # Pretrained-config table keyed by 'variant.tag'; each value comes from `_cfg`
  # and the whole mapping is passed through `generate_default_cfgs`.
  default_cfgs = generate_default_cfgs({
      # --- CSP ResNet / ResNeXt ---
      'cspresnet50.ra_in1k': _cfg(
          url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnet50_ra-d3e8d487.pth',
          hf_hub_id='timm/',
      ),
      'cspresnet50d.untrained': _cfg(),
      'cspresnet50w.untrained': _cfg(),
      'cspresnext50.ra_in1k': _cfg(
          url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnext50_ra_224-648b4713.pth',
          hf_hub_id='timm/',
      ),

      # --- CSP DarkNet / DarkNet ---
      'cspdarknet53.ra_in1k': _cfg(
          url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspdarknet53_ra_256-d05c7c21.pth',
          hf_hub_id='timm/',
      ),
      'darknet17.untrained': _cfg(),
      'darknet21.untrained': _cfg(),
      'sedarknet21.untrained': _cfg(),
      'darknet53.c2ns_in1k': _cfg(
          url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknet53_256_c2ns-3aeff817.pth',
          hf_hub_id='timm/',
          interpolation='bicubic',
          test_input_size=(3, 288, 288),
          test_crop_pct=1.0,
      ),
      'darknetaa53.c2ns_in1k': _cfg(
          url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknetaa53_c2ns-5c28ec8a.pth',
          hf_hub_id='timm/',
          test_input_size=(3, 288, 288),
          test_crop_pct=1.0,
      ),

      # --- cs3 DarkNet ---
      'cs3darknet_s.untrained': _cfg(interpolation='bicubic'),
      'cs3darknet_m.c2ns_in1k': _cfg(
          url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_m_c2ns-43f06604.pth',
          hf_hub_id='timm/',
          interpolation='bicubic',
          test_input_size=(3, 288, 288),
          test_crop_pct=0.95,
      ),
      'cs3darknet_l.c2ns_in1k': _cfg(
          url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_l_c2ns-16220c5d.pth',
          hf_hub_id='timm/',
          interpolation='bicubic',
          test_input_size=(3, 288, 288),
          test_crop_pct=0.95,
      ),
      'cs3darknet_x.c2ns_in1k': _cfg(
          url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_x_c2ns-4e4490aa.pth',
          hf_hub_id='timm/',
          interpolation='bicubic',
          crop_pct=0.95,
          test_input_size=(3, 288, 288),
          test_crop_pct=1.0,
      ),

      # --- cs3 DarkNet, focus stem ---
      'cs3darknet_focus_s.ra4_e3600_r256_in1k': _cfg(
          hf_hub_id='timm/',
          mean=(0.5, 0.5, 0.5),
          std=(0.5, 0.5, 0.5),
          interpolation='bicubic',
          test_input_size=(3, 320, 320),
          test_crop_pct=1.0,
      ),
      'cs3darknet_focus_m.c2ns_in1k': _cfg(
          url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_m_c2ns-e23bed41.pth',
          hf_hub_id='timm/',
          interpolation='bicubic',
          test_input_size=(3, 288, 288),
          test_crop_pct=0.95,
      ),
      'cs3darknet_focus_l.c2ns_in1k': _cfg(
          url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_l_c2ns-65ef8888.pth',
          hf_hub_id='timm/',
          interpolation='bicubic',
          test_input_size=(3, 288, 288),
          test_crop_pct=0.95,
      ),
      'cs3darknet_focus_x.untrained': _cfg(interpolation='bicubic'),

      # --- cs3 SE-DarkNet ---
      'cs3sedarknet_l.c2ns_in1k': _cfg(
          url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3sedarknet_l_c2ns-e8d1dc13.pth',
          hf_hub_id='timm/',
          interpolation='bicubic',
          test_input_size=(3, 288, 288),
          test_crop_pct=0.95,
      ),
      'cs3sedarknet_x.c2ns_in1k': _cfg(
          url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3sedarknet_x_c2ns-b4d0abc0.pth',
          hf_hub_id='timm/',
          interpolation='bicubic',
          test_input_size=(3, 288, 288),
          test_crop_pct=1.0,
      ),
      'cs3sedarknet_xdw.untrained': _cfg(interpolation='bicubic'),

      # --- cs3 EdgeNet ---
      'cs3edgenet_x.c2_in1k': _cfg(
          url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3edgenet_x_c2-2e1610a9.pth',
          hf_hub_id='timm/',
          interpolation='bicubic',
          test_input_size=(3, 288, 288),
          test_crop_pct=1.0,
      ),
      'cs3se_edgenet_x.c2ns_in1k': _cfg(
          url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3se_edgenet_x_c2ns-76f8e3ac.pth',
          hf_hub_id='timm/',
          interpolation='bicubic',
          crop_pct=0.95,
          test_input_size=(3, 320, 320),
          test_crop_pct=1.0,
      ),
  })
  @register_model
  def cspresnet50(pretrained=False, **kwargs) -> CspNet:
      """Create the 'cspresnet50' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'cspresnet50'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def cspresnet50d(pretrained=False, **kwargs) -> CspNet:
      """Create the 'cspresnet50d' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'cspresnet50d'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def cspresnet50w(pretrained=False, **kwargs) -> CspNet:
      """Create the 'cspresnet50w' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'cspresnet50w'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def cspresnext50(pretrained=False, **kwargs) -> CspNet:
      """Create the 'cspresnext50' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'cspresnext50'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def cspdarknet53(pretrained=False, **kwargs) -> CspNet:
      """Create the 'cspdarknet53' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'cspdarknet53'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def darknet17(pretrained=False, **kwargs) -> CspNet:
      """Create the 'darknet17' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'darknet17'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def darknet21(pretrained=False, **kwargs) -> CspNet:
      """Create the 'darknet21' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'darknet21'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def sedarknet21(pretrained=False, **kwargs) -> CspNet:
      """Create the 'sedarknet21' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'sedarknet21'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def darknet53(pretrained=False, **kwargs) -> CspNet:
      """Create the 'darknet53' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'darknet53'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def darknetaa53(pretrained=False, **kwargs) -> CspNet:
      """Create the 'darknetaa53' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'darknetaa53'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def cs3darknet_s(pretrained=False, **kwargs) -> CspNet:
      """Create the 'cs3darknet_s' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'cs3darknet_s'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def cs3darknet_m(pretrained=False, **kwargs) -> CspNet:
      """Create the 'cs3darknet_m' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'cs3darknet_m'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def cs3darknet_l(pretrained=False, **kwargs) -> CspNet:
      """Create the 'cs3darknet_l' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'cs3darknet_l'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def cs3darknet_x(pretrained=False, **kwargs) -> CspNet:
      """Create the 'cs3darknet_x' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'cs3darknet_x'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def cs3darknet_focus_s(pretrained=False, **kwargs) -> CspNet:
      """Create the 'cs3darknet_focus_s' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'cs3darknet_focus_s'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def cs3darknet_focus_m(pretrained=False, **kwargs) -> CspNet:
      """Create the 'cs3darknet_focus_m' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'cs3darknet_focus_m'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def cs3darknet_focus_l(pretrained=False, **kwargs) -> CspNet:
      """Create the 'cs3darknet_focus_l' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'cs3darknet_focus_l'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def cs3darknet_focus_x(pretrained=False, **kwargs) -> CspNet:
      """Create the 'cs3darknet_focus_x' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'cs3darknet_focus_x'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def cs3sedarknet_l(pretrained=False, **kwargs) -> CspNet:
      """Create the 'cs3sedarknet_l' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'cs3sedarknet_l'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def cs3sedarknet_x(pretrained=False, **kwargs) -> CspNet:
      """Create the 'cs3sedarknet_x' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'cs3sedarknet_x'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def cs3sedarknet_xdw(pretrained=False, **kwargs) -> CspNet:
      """Create the 'cs3sedarknet_xdw' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'cs3sedarknet_xdw'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def cs3edgenet_x(pretrained=False, **kwargs) -> CspNet:
      """Create the 'cs3edgenet_x' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'cs3edgenet_x'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)
  @register_model
  def cs3se_edgenet_x(pretrained=False, **kwargs) -> CspNet:
      """Create the 'cs3se_edgenet_x' variant via ``_create_cspnet``, forwarding ``pretrained`` and ``kwargs``."""
      variant = 'cs3se_edgenet_x'
      return _create_cspnet(variant, pretrained=pretrained, **kwargs)