resnet.py 102 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227222822292230223122322233223422352236223722382239224022412242224322442245224622472248224922502251225222532254225522562257225822592260226122622263226422652266
  1. """PyTorch ResNet
  2. This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with
  3. additional dropout and dynamic global avg/max pool.
  4. ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman
  5. Copyright 2019, Ross Wightman
  6. """
  7. import math
  8. from functools import partial
  9. from typing import Any, Dict, List, Optional, Tuple, Type, Union
  10. import torch
  11. import torch.nn as nn
  12. import torch.nn.functional as F
  13. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  14. from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, LayerType, create_attn, \
  15. get_attn, get_act_layer, get_norm_layer, create_classifier, create_aa, to_ntuple
  16. from ._builder import build_model_with_cfg
  17. from ._features import feature_take_indices
  18. from ._manipulate import checkpoint_seq
  19. from ._registry import register_model, generate_default_cfgs, register_model_deprecations
  20. __all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this
  21. def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int:
  22. padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
  23. return padding
  24. class BasicBlock(nn.Module):
  25. """Basic residual block for ResNet.
  26. This is the standard residual block used in ResNet-18 and ResNet-34.
  27. """
  28. expansion = 1
  29. def __init__(
  30. self,
  31. inplanes: int,
  32. planes: int,
  33. stride: int = 1,
  34. downsample: Optional[nn.Module] = None,
  35. cardinality: int = 1,
  36. base_width: int = 64,
  37. reduce_first: int = 1,
  38. dilation: int = 1,
  39. first_dilation: Optional[int] = None,
  40. act_layer: Type[nn.Module] = nn.ReLU,
  41. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  42. attn_layer: Optional[Type[nn.Module]] = None,
  43. aa_layer: Optional[Type[nn.Module]] = None,
  44. drop_block: Optional[Type[nn.Module]] = None,
  45. drop_path: Optional[nn.Module] = None,
  46. device=None,
  47. dtype=None,
  48. ) -> None:
  49. """
  50. Args:
  51. inplanes: Input channel dimensionality.
  52. planes: Used to determine output channel dimensionalities.
  53. stride: Stride used in convolution layers.
  54. downsample: Optional downsample layer for residual path.
  55. cardinality: Number of convolution groups.
  56. base_width: Base width used to determine output channel dimensionality.
  57. reduce_first: Reduction factor for first convolution output width of residual blocks.
  58. dilation: Dilation rate for convolution layers.
  59. first_dilation: Dilation rate for first convolution layer.
  60. act_layer: Activation layer class.
  61. norm_layer: Normalization layer class.
  62. attn_layer: Attention layer class.
  63. aa_layer: Anti-aliasing layer class.
  64. drop_block: DropBlock layer class.
  65. drop_path: Optional DropPath layer instance.
  66. """
  67. dd = {'device': device, 'dtype': dtype}
  68. super().__init__()
  69. assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
  70. assert base_width == 64, 'BasicBlock does not support changing base width'
  71. first_planes = planes // reduce_first
  72. outplanes = planes * self.expansion
  73. first_dilation = first_dilation or dilation
  74. use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)
  75. self.conv1 = nn.Conv2d(
  76. inplanes,
  77. first_planes,
  78. kernel_size=3,
  79. stride=1 if use_aa else stride,
  80. padding=first_dilation,
  81. dilation=first_dilation,
  82. bias=False,
  83. **dd,
  84. )
  85. self.bn1 = norm_layer(first_planes, **dd)
  86. self.drop_block = drop_block() if drop_block is not None else nn.Identity()
  87. self.act1 = act_layer(inplace=True)
  88. self.aa = create_aa(aa_layer, channels=first_planes, stride=stride, enable=use_aa, **dd)
  89. self.conv2 = nn.Conv2d(
  90. first_planes,
  91. outplanes,
  92. kernel_size=3,
  93. padding=dilation,
  94. dilation=dilation,
  95. bias=False,
  96. **dd,
  97. )
  98. self.bn2 = norm_layer(outplanes, **dd)
  99. self.se = create_attn(attn_layer, outplanes, **dd)
  100. self.act2 = act_layer(inplace=True)
  101. self.downsample = downsample
  102. self.stride = stride
  103. self.dilation = dilation
  104. self.drop_path = drop_path
  105. def zero_init_last(self) -> None:
  106. """Initialize the last batch norm layer weights to zero for better convergence."""
  107. if getattr(self.bn2, 'weight', None) is not None:
  108. nn.init.zeros_(self.bn2.weight)
  109. def forward(self, x: torch.Tensor) -> torch.Tensor:
  110. shortcut = x
  111. x = self.conv1(x)
  112. x = self.bn1(x)
  113. x = self.drop_block(x)
  114. x = self.act1(x)
  115. x = self.aa(x)
  116. x = self.conv2(x)
  117. x = self.bn2(x)
  118. if self.se is not None:
  119. x = self.se(x)
  120. if self.drop_path is not None:
  121. x = self.drop_path(x)
  122. if self.downsample is not None:
  123. shortcut = self.downsample(shortcut)
  124. x += shortcut
  125. x = self.act2(x)
  126. return x
  127. class Bottleneck(nn.Module):
  128. """Bottleneck residual block for ResNet.
  129. This is the bottleneck block used in ResNet-50, ResNet-101, and ResNet-152.
  130. """
  131. expansion = 4
  132. def __init__(
  133. self,
  134. inplanes: int,
  135. planes: int,
  136. stride: int = 1,
  137. downsample: Optional[nn.Module] = None,
  138. cardinality: int = 1,
  139. base_width: int = 64,
  140. reduce_first: int = 1,
  141. dilation: int = 1,
  142. first_dilation: Optional[int] = None,
  143. act_layer: Type[nn.Module] = nn.ReLU,
  144. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  145. attn_layer: Optional[Type[nn.Module]] = None,
  146. aa_layer: Optional[Type[nn.Module]] = None,
  147. drop_block: Optional[Type[nn.Module]] = None,
  148. drop_path: Optional[nn.Module] = None,
  149. device=None,
  150. dtype=None,
  151. ) -> None:
  152. """
  153. Args:
  154. inplanes: Input channel dimensionality.
  155. planes: Used to determine output channel dimensionalities.
  156. stride: Stride used in convolution layers.
  157. downsample: Optional downsample layer for residual path.
  158. cardinality: Number of convolution groups.
  159. base_width: Base width used to determine output channel dimensionality.
  160. reduce_first: Reduction factor for first convolution output width of residual blocks.
  161. dilation: Dilation rate for convolution layers.
  162. first_dilation: Dilation rate for first convolution layer.
  163. act_layer: Activation layer class.
  164. norm_layer: Normalization layer class.
  165. attn_layer: Attention layer class.
  166. aa_layer: Anti-aliasing layer class.
  167. drop_block: DropBlock layer class.
  168. drop_path: Optional DropPath layer instance.
  169. """
  170. dd = {'device': device, 'dtype': dtype}
  171. super().__init__()
  172. width = int(math.floor(planes * (base_width / 64)) * cardinality)
  173. first_planes = width // reduce_first
  174. outplanes = planes * self.expansion
  175. first_dilation = first_dilation or dilation
  176. use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)
  177. self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False, **dd)
  178. self.bn1 = norm_layer(first_planes, **dd)
  179. self.act1 = act_layer(inplace=True)
  180. self.conv2 = nn.Conv2d(
  181. first_planes,
  182. width,
  183. kernel_size=3,
  184. stride=1 if use_aa else stride,
  185. padding=first_dilation,
  186. dilation=first_dilation,
  187. groups=cardinality,
  188. bias=False,
  189. **dd,
  190. )
  191. self.bn2 = norm_layer(width, **dd)
  192. self.drop_block = drop_block() if drop_block is not None else nn.Identity()
  193. self.act2 = act_layer(inplace=True)
  194. self.aa = create_aa(aa_layer, channels=width, stride=stride, enable=use_aa, **dd)
  195. self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False, **dd)
  196. self.bn3 = norm_layer(outplanes, **dd)
  197. self.se = create_attn(attn_layer, outplanes, **dd)
  198. self.act3 = act_layer(inplace=True)
  199. self.downsample = downsample
  200. self.stride = stride
  201. self.dilation = dilation
  202. self.drop_path = drop_path
  203. def zero_init_last(self) -> None:
  204. """Initialize the last batch norm layer weights to zero for better convergence."""
  205. if getattr(self.bn3, 'weight', None) is not None:
  206. nn.init.zeros_(self.bn3.weight)
  207. def forward(self, x: torch.Tensor) -> torch.Tensor:
  208. shortcut = x
  209. x = self.conv1(x)
  210. x = self.bn1(x)
  211. x = self.act1(x)
  212. x = self.conv2(x)
  213. x = self.bn2(x)
  214. x = self.drop_block(x)
  215. x = self.act2(x)
  216. x = self.aa(x)
  217. x = self.conv3(x)
  218. x = self.bn3(x)
  219. if self.se is not None:
  220. x = self.se(x)
  221. if self.drop_path is not None:
  222. x = self.drop_path(x)
  223. if self.downsample is not None:
  224. shortcut = self.downsample(shortcut)
  225. x += shortcut
  226. x = self.act3(x)
  227. return x
  228. def downsample_conv(
  229. in_channels: int,
  230. out_channels: int,
  231. kernel_size: int,
  232. stride: int = 1,
  233. dilation: int = 1,
  234. first_dilation: Optional[int] = None,
  235. norm_layer: Optional[Type[nn.Module]] = None,
  236. device=None,
  237. dtype=None,
  238. ) -> nn.Module:
  239. dd = {'device': device, 'dtype': dtype}
  240. norm_layer = norm_layer or nn.BatchNorm2d
  241. kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
  242. first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1
  243. p = get_padding(kernel_size, stride, first_dilation)
  244. return nn.Sequential(*[
  245. nn.Conv2d(
  246. in_channels,
  247. out_channels,
  248. kernel_size,
  249. stride=stride,
  250. padding=p,
  251. dilation=first_dilation,
  252. bias=False,
  253. **dd
  254. ),
  255. norm_layer(out_channels, **dd)
  256. ])
  257. def downsample_avg(
  258. in_channels: int,
  259. out_channels: int,
  260. kernel_size: int,
  261. stride: int = 1,
  262. dilation: int = 1,
  263. first_dilation: Optional[int] = None,
  264. norm_layer: Optional[Type[nn.Module]] = None,
  265. device=None,
  266. dtype=None,
  267. ) -> nn.Module:
  268. dd = {'device': device, 'dtype': dtype}
  269. norm_layer = norm_layer or nn.BatchNorm2d
  270. avg_stride = stride if dilation == 1 else 1
  271. if stride == 1 and dilation == 1:
  272. pool = nn.Identity()
  273. else:
  274. avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
  275. pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
  276. return nn.Sequential(*[
  277. pool,
  278. nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False, **dd),
  279. norm_layer(out_channels, **dd)
  280. ])
  281. def drop_blocks(drop_prob: float = 0.) -> List[Optional[partial]]:
  282. """Create DropBlock layer instances for each stage.
  283. Args:
  284. drop_prob: Drop probability for DropBlock.
  285. Returns:
  286. List of DropBlock partial instances or None for each stage.
  287. """
  288. return [
  289. None, None,
  290. partial(DropBlock2d, drop_prob=drop_prob, block_size=5, gamma_scale=0.25) if drop_prob else None,
  291. partial(DropBlock2d, drop_prob=drop_prob, block_size=3, gamma_scale=1.00) if drop_prob else None]
  292. def make_blocks(
  293. block_fns: Tuple[Union[Type[BasicBlock], Type[Bottleneck]], ...],
  294. channels: Tuple[int, ...],
  295. block_repeats: Tuple[int, ...],
  296. inplanes: int,
  297. reduce_first: int = 1,
  298. output_stride: int = 32,
  299. down_kernel_size: int = 1,
  300. avg_down: bool = False,
  301. drop_block_rate: float = 0.,
  302. drop_path_rate: float = 0.,
  303. device=None,
  304. dtype=None,
  305. **kwargs,
  306. ) -> Tuple[List[Tuple[str, nn.Module]], List[Dict[str, Any]]]:
  307. """Create ResNet stages with specified block configurations.
  308. Args:
  309. block_fns: Block class to use for each stage.
  310. channels: Number of channels for each stage.
  311. block_repeats: Number of blocks to repeat for each stage.
  312. inplanes: Number of input channels.
  313. reduce_first: Reduction factor for first convolution in each stage.
  314. output_stride: Target output stride of network.
  315. down_kernel_size: Kernel size for downsample layers.
  316. avg_down: Use average pooling for downsample.
  317. drop_block_rate: DropBlock drop rate.
  318. drop_path_rate: Drop path rate for stochastic depth.
  319. **kwargs: Additional arguments passed to block constructors.
  320. Returns:
  321. Tuple of stage modules list and feature info list.
  322. """
  323. dd = {'device': device, 'dtype': dtype}
  324. stages = []
  325. feature_info = []
  326. net_num_blocks = sum(block_repeats)
  327. net_block_idx = 0
  328. net_stride = 4
  329. dilation = prev_dilation = 1
  330. for stage_idx, (block_fn, planes, num_blocks, db) in enumerate(zip(block_fns, channels, block_repeats, drop_blocks(drop_block_rate))):
  331. stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it
  332. stride = 1 if stage_idx == 0 else 2
  333. if net_stride >= output_stride:
  334. dilation *= stride
  335. stride = 1
  336. else:
  337. net_stride *= stride
  338. downsample = None
  339. if stride != 1 or inplanes != planes * block_fn.expansion:
  340. down_kwargs = dict(
  341. in_channels=inplanes,
  342. out_channels=planes * block_fn.expansion,
  343. kernel_size=down_kernel_size,
  344. stride=stride,
  345. dilation=dilation,
  346. first_dilation=prev_dilation,
  347. norm_layer=kwargs.get('norm_layer'),
  348. **dd,
  349. )
  350. downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs)
  351. block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, drop_block=db, **kwargs)
  352. blocks = []
  353. for block_idx in range(num_blocks):
  354. downsample = downsample if block_idx == 0 else None
  355. stride = stride if block_idx == 0 else 1
  356. block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1) # stochastic depth linear decay rule
  357. blocks.append(block_fn(
  358. inplanes,
  359. planes,
  360. stride,
  361. downsample,
  362. first_dilation=prev_dilation,
  363. drop_path=DropPath(block_dpr) if block_dpr > 0. else None,
  364. **block_kwargs,
  365. **dd,
  366. ))
  367. prev_dilation = dilation
  368. inplanes = planes * block_fn.expansion
  369. net_block_idx += 1
  370. stages.append((stage_name, nn.Sequential(*blocks)))
  371. feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name))
  372. return stages, feature_info
  373. class ResNet(nn.Module):
  374. """ResNet / ResNeXt / SE-ResNeXt / SE-Net
  375. This class implements all variants of ResNet, ResNeXt, SE-ResNeXt, and SENet that
  376. * have > 1 stride in the 3x3 conv layer of bottleneck
  377. * have conv-bn-act ordering
  378. This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s
  379. variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the
  380. 'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to torchvision default.
  381. ResNet variants (the same modifications can be used in SE/ResNeXt models as well):
  382. * normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b
  383. * c - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64)
  384. * d - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64), average pool in downsample
  385. * e - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128), average pool in downsample
  386. * s - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128)
  387. * t - 3 layer deep 3x3 stem, stem width = 32 (24, 48, 64), average pool in downsample
  388. * tn - 3 layer deep 3x3 stem, stem width = 32 (24, 32, 64), average pool in downsample
  389. ResNeXt
  390. * normal - 7x7 stem, stem_width = 64, standard cardinality and base widths
  391. * same c,d, e, s variants as ResNet can be enabled
  392. SE-ResNeXt
  393. * normal - 7x7 stem, stem_width = 64
  394. * same c, d, e, s variants as ResNet can be enabled
  395. SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64,
  396. reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block
  397. """
  398. def __init__(
  399. self,
  400. block: Union[BasicBlock, Bottleneck],
  401. layers: Tuple[int, ...],
  402. num_classes: int = 1000,
  403. in_chans: int = 3,
  404. output_stride: int = 32,
  405. global_pool: str = 'avg',
  406. cardinality: int = 1,
  407. base_width: int = 64,
  408. stem_width: int = 64,
  409. stem_type: str = '',
  410. replace_stem_pool: bool = False,
  411. block_reduce_first: int = 1,
  412. down_kernel_size: int = 1,
  413. avg_down: bool = False,
  414. channels: Optional[Tuple[int, ...]] = (64, 128, 256, 512),
  415. act_layer: LayerType = nn.ReLU,
  416. norm_layer: LayerType = nn.BatchNorm2d,
  417. aa_layer: Optional[Type[nn.Module]] = None,
  418. drop_rate: float = 0.0,
  419. drop_path_rate: float = 0.,
  420. drop_block_rate: float = 0.,
  421. zero_init_last: bool = True,
  422. block_args: Optional[Dict[str, Any]] = None,
  423. device=None,
  424. dtype=None,
  425. ):
  426. """
  427. Args:
  428. block (nn.Module): class for the residual block. Options are BasicBlock, Bottleneck.
  429. layers (List[int]) : number of layers in each block
  430. num_classes (int): number of classification classes (default 1000)
  431. in_chans (int): number of input (color) channels. (default 3)
  432. output_stride (int): output stride of the network, 32, 16, or 8. (default 32)
  433. global_pool (str): Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' (default 'avg')
  434. cardinality (int): number of convolution groups for 3x3 conv in Bottleneck. (default 1)
  435. base_width (int): bottleneck channels factor. `planes * base_width / 64 * cardinality` (default 64)
  436. stem_width (int): number of channels in stem convolutions (default 64)
  437. stem_type (str): The type of stem (default ''):
  438. * '', default - a single 7x7 conv with a width of stem_width
  439. * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
  440. * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
  441. replace_stem_pool (bool): replace stem max-pooling layer with a 3x3 stride-2 convolution
  442. block_reduce_first (int): Reduction factor for first convolution output width of residual blocks,
  443. 1 for all archs except senets, where 2 (default 1)
  444. down_kernel_size (int): kernel size of residual block downsample path,
  445. 1x1 for most, 3x3 for senets (default: 1)
  446. avg_down (bool): use avg pooling for projection skip connection between stages/downsample (default False)
  447. act_layer (str, nn.Module): activation layer
  448. norm_layer (str, nn.Module): normalization layer
  449. aa_layer (nn.Module): anti-aliasing layer
  450. drop_rate (float): Dropout probability before classifier, for training (default 0.)
  451. drop_path_rate (float): Stochastic depth drop-path rate (default 0.)
  452. drop_block_rate (float): Drop block rate (default 0.)
  453. zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight)
  454. block_args (dict): Extra kwargs to pass through to block module
  455. """
  456. super().__init__()
  457. dd = {'device': device, 'dtype': dtype}
  458. block_args = block_args or dict()
  459. assert output_stride in (8, 16, 32)
  460. self.num_classes = num_classes
  461. self.in_chans = in_chans
  462. self.drop_rate = drop_rate
  463. self.grad_checkpointing = False
  464. act_layer = get_act_layer(act_layer)
  465. norm_layer = get_norm_layer(norm_layer)
  466. # Stem
  467. deep_stem = 'deep' in stem_type
  468. inplanes = stem_width * 2 if deep_stem else 64
  469. if deep_stem:
  470. stem_chs = (stem_width, stem_width)
  471. if 'tiered' in stem_type:
  472. stem_chs = (3 * (stem_width // 4), stem_width)
  473. self.conv1 = nn.Sequential(*[
  474. nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False, **dd),
  475. norm_layer(stem_chs[0], **dd),
  476. act_layer(inplace=True),
  477. nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False, **dd),
  478. norm_layer(stem_chs[1], **dd),
  479. act_layer(inplace=True),
  480. nn.Conv2d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False, **dd)])
  481. else:
  482. self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False, **dd)
  483. self.bn1 = norm_layer(inplanes, **dd)
  484. self.act1 = act_layer(inplace=True)
  485. self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]
  486. # Stem pooling. The name 'maxpool' remains for weight compatibility.
  487. if replace_stem_pool:
  488. self.maxpool = nn.Sequential(*filter(None, [
  489. nn.Conv2d(inplanes, inplanes, 3, stride=1 if aa_layer else 2, padding=1, bias=False, **dd),
  490. create_aa(aa_layer, channels=inplanes, stride=2, **dd) if aa_layer is not None else None,
  491. norm_layer(inplanes, **dd),
  492. act_layer(inplace=True),
  493. ]))
  494. else:
  495. if aa_layer is not None:
  496. if issubclass(aa_layer, nn.AvgPool2d):
  497. self.maxpool = aa_layer(2)
  498. else:
  499. self.maxpool = nn.Sequential(*[
  500. nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
  501. aa_layer(channels=inplanes, stride=2, **dd)])
  502. else:
  503. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  504. # Feature Blocks
  505. block_fns = to_ntuple(len(channels))(block)
  506. stage_modules, stage_feature_info = make_blocks(
  507. block_fns,
  508. channels,
  509. layers,
  510. inplanes,
  511. cardinality=cardinality,
  512. base_width=base_width,
  513. output_stride=output_stride,
  514. reduce_first=block_reduce_first,
  515. avg_down=avg_down,
  516. down_kernel_size=down_kernel_size,
  517. act_layer=act_layer,
  518. norm_layer=norm_layer,
  519. aa_layer=aa_layer,
  520. drop_block_rate=drop_block_rate,
  521. drop_path_rate=drop_path_rate,
  522. **block_args,
  523. **dd,
  524. )
  525. for stage in stage_modules:
  526. self.add_module(*stage) # layer1, layer2, etc
  527. self.feature_info.extend(stage_feature_info)
  528. # Head (Pooling and Classifier)
  529. self.num_features = self.head_hidden_size = channels[-1] * block_fns[-1].expansion
  530. self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool, **dd)
  531. self.init_weights(zero_init_last=zero_init_last)
  532. @torch.jit.ignore
  533. def init_weights(self, zero_init_last: bool = True) -> None:
  534. """Initialize model weights.
  535. Args:
  536. zero_init_last: Zero-initialize the last BN in each residual branch.
  537. """
  538. for n, m in self.named_modules():
  539. if isinstance(m, nn.Conv2d):
  540. nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  541. if zero_init_last:
  542. for m in self.modules():
  543. if hasattr(m, 'zero_init_last'):
  544. m.zero_init_last()
  545. @torch.jit.ignore
  546. def group_matcher(self, coarse: bool = False) -> Dict[str, str]:
  547. """Create regex patterns for parameter grouping.
  548. Args:
  549. coarse: Use coarse (stage-level) or fine (block-level) grouping.
  550. Returns:
  551. Dictionary mapping group names to regex patterns.
  552. """
  553. matcher = dict(stem=r'^conv1|bn1|maxpool', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)')
  554. return matcher
  555. @torch.jit.ignore
  556. def set_grad_checkpointing(self, enable: bool = True) -> None:
  557. """Enable or disable gradient checkpointing.
  558. Args:
  559. enable: Whether to enable gradient checkpointing.
  560. """
  561. self.grad_checkpointing = enable
  562. @torch.jit.ignore
  563. def get_classifier(self, name_only: bool = False) -> Union[str, nn.Module]:
  564. """Get the classifier module.
  565. Args:
  566. name_only: Return classifier module name instead of module.
  567. Returns:
  568. Classifier module or name.
  569. """
  570. return 'fc' if name_only else self.fc
  571. def reset_classifier(self, num_classes: int, global_pool: str = 'avg') -> None:
  572. """Reset the classifier head.
  573. Args:
  574. num_classes: Number of classes for new classifier.
  575. global_pool: Global pooling type.
  576. """
  577. self.num_classes = num_classes
  578. self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
  579. def forward_intermediates(
  580. self,
  581. x: torch.Tensor,
  582. indices: Optional[Union[int, List[int]]] = None,
  583. norm: bool = False,
  584. stop_early: bool = False,
  585. output_fmt: str = 'NCHW',
  586. intermediates_only: bool = False,
  587. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  588. """Forward features that returns intermediates.
  589. Args:
  590. x: Input image tensor.
  591. indices: Take last n blocks if int, all if None, select matching indices if sequence.
  592. norm: Apply norm layer to compatible intermediates.
  593. stop_early: Stop iterating over blocks when last desired intermediate hit.
  594. output_fmt: Shape of intermediate feature outputs.
  595. intermediates_only: Only return intermediate features.
  596. Returns:
  597. Features and list of intermediate features or just intermediate features.
  598. """
  599. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  600. intermediates = []
  601. take_indices, max_index = feature_take_indices(5, indices)
  602. # forward pass
  603. feat_idx = 0
  604. x = self.conv1(x)
  605. x = self.bn1(x)
  606. x = self.act1(x)
  607. if feat_idx in take_indices:
  608. intermediates.append(x)
  609. x = self.maxpool(x)
  610. layer_names = ('layer1', 'layer2', 'layer3', 'layer4')
  611. if stop_early:
  612. layer_names = layer_names[:max_index]
  613. for n in layer_names:
  614. feat_idx += 1
  615. x = getattr(self, n)(x) # won't work with torchscript, but keeps code reasonable, FML
  616. if feat_idx in take_indices:
  617. intermediates.append(x)
  618. if intermediates_only:
  619. return intermediates
  620. return x, intermediates
  621. def prune_intermediate_layers(
  622. self,
  623. indices: Union[int, List[int]] = 1,
  624. prune_norm: bool = False,
  625. prune_head: bool = True,
  626. ) -> List[int]:
  627. """Prune layers not required for specified intermediates.
  628. Args:
  629. indices: Indices of intermediate layers to keep.
  630. prune_norm: Whether to prune normalization layers.
  631. prune_head: Whether to prune the classifier head.
  632. Returns:
  633. List of indices that were kept.
  634. """
  635. take_indices, max_index = feature_take_indices(5, indices)
  636. layer_names = ('layer1', 'layer2', 'layer3', 'layer4')
  637. layer_names = layer_names[max_index:]
  638. for n in layer_names:
  639. setattr(self, n, nn.Identity())
  640. if prune_head:
  641. self.reset_classifier(0, '')
  642. return take_indices
  643. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  644. """Forward pass through feature extraction layers."""
  645. x = self.conv1(x)
  646. x = self.bn1(x)
  647. x = self.act1(x)
  648. x = self.maxpool(x)
  649. if self.grad_checkpointing and not torch.jit.is_scripting():
  650. x = checkpoint_seq([self.layer1, self.layer2, self.layer3, self.layer4], x, flatten=True)
  651. else:
  652. x = self.layer1(x)
  653. x = self.layer2(x)
  654. x = self.layer3(x)
  655. x = self.layer4(x)
  656. return x
  657. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  658. """Forward pass through classifier head.
  659. Args:
  660. x: Feature tensor.
  661. pre_logits: Return features before final classifier layer.
  662. Returns:
  663. Output tensor.
  664. """
  665. x = self.global_pool(x)
  666. if self.drop_rate:
  667. x = F.dropout(x, p=float(self.drop_rate), training=self.training)
  668. return x if pre_logits else self.fc(x)
  669. def forward(self, x: torch.Tensor) -> torch.Tensor:
  670. """Forward pass."""
  671. x = self.forward_features(x)
  672. x = self.forward_head(x)
  673. return x
  674. def _create_resnet(variant: str, pretrained: bool = False, **kwargs) -> ResNet:
  675. """Create a ResNet model.
  676. Args:
  677. variant: Model variant name.
  678. pretrained: Load pretrained weights.
  679. **kwargs: Additional model arguments.
  680. Returns:
  681. ResNet model instance.
  682. """
  683. return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)
  684. def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
  685. """Create a default configuration for ResNet models."""
  686. return {
  687. 'url': url,
  688. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  689. 'crop_pct': 0.875, 'interpolation': 'bilinear',
  690. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  691. 'first_conv': 'conv1', 'classifier': 'fc',
  692. 'license': 'apache-2.0',
  693. **kwargs
  694. }
  695. def _tcfg(url: str = '', **kwargs) -> Dict[str, Any]:
  696. """Create a configuration with bicubic interpolation."""
  697. return _cfg(url=url, **dict({'interpolation': 'bicubic'}, **kwargs))
  698. def _ttcfg(url: str = '', **kwargs) -> Dict[str, Any]:
  699. """Create a configuration for models trained with timm."""
  700. return _cfg(url=url, **dict({
  701. 'interpolation': 'bicubic', 'test_input_size': (3, 288, 288), 'test_crop_pct': 0.95,
  702. 'origin_url': 'https://github.com/huggingface/pytorch-image-models',
  703. }, **kwargs))
  704. def _rcfg(url: str = '', **kwargs) -> Dict[str, Any]:
  705. """Create a configuration for ResNet-RS models."""
  706. return _cfg(url=url, **dict({
  707. 'interpolation': 'bicubic', 'crop_pct': 0.95, 'test_input_size': (3, 288, 288), 'test_crop_pct': 1.0,
  708. 'origin_url': 'https://github.com/huggingface/pytorch-image-models', 'paper_ids': 'arXiv:2110.00476'
  709. }, **kwargs))
  710. def _r3cfg(url: str = '', **kwargs) -> Dict[str, Any]:
  711. """Create a configuration for ResNet-RS models with 160x160 input."""
  712. return _cfg(url=url, **dict({
  713. 'interpolation': 'bicubic', 'input_size': (3, 160, 160), 'pool_size': (5, 5),
  714. 'crop_pct': 0.95, 'test_input_size': (3, 224, 224), 'test_crop_pct': 0.95,
  715. 'origin_url': 'https://github.com/huggingface/pytorch-image-models', 'paper_ids': 'arXiv:2110.00476',
  716. }, **kwargs))
  717. def _gcfg(url: str = '', **kwargs) -> Dict[str, Any]:
  718. """Create a configuration for Gluon pretrained models."""
  719. return _cfg(url=url, **dict({
  720. 'interpolation': 'bicubic',
  721. 'origin_url': 'https://cv.gluon.ai/model_zoo/classification.html',
  722. }, **kwargs))
  723. default_cfgs = generate_default_cfgs({
  724. # ResNet and Wide ResNet trained w/ timm (RSB paper and others)
  725. 'resnet10t.c3_in1k': _ttcfg(
  726. hf_hub_id='timm/',
  727. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet10t_176_c3-f3215ab1.pth',
  728. input_size=(3, 176, 176), pool_size=(6, 6), test_crop_pct=0.95, test_input_size=(3, 224, 224),
  729. first_conv='conv1.0'),
  730. 'resnet14t.c3_in1k': _ttcfg(
  731. hf_hub_id='timm/',
  732. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet14t_176_c3-c4ed2c37.pth',
  733. input_size=(3, 176, 176), pool_size=(6, 6), test_crop_pct=0.95, test_input_size=(3, 224, 224),
  734. first_conv='conv1.0'),
  735. 'resnet18.a1_in1k': _rcfg(
  736. hf_hub_id='timm/',
  737. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet18_a1_0-d63eafa0.pth'),
  738. 'resnet18.a2_in1k': _rcfg(
  739. hf_hub_id='timm/',
  740. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet18_a2_0-b61bd467.pth'),
  741. 'resnet18.a3_in1k': _r3cfg(
  742. hf_hub_id='timm/',
  743. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet18_a3_0-40c531c8.pth'),
  744. 'resnet18d.ra2_in1k': _ttcfg(
  745. hf_hub_id='timm/',
  746. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet18d_ra2-48a79e06.pth',
  747. first_conv='conv1.0'),
  748. 'resnet18d.ra4_e3600_r224_in1k': _rcfg(
  749. hf_hub_id='timm/',
  750. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='conv1.0'),
  751. 'resnet34.a1_in1k': _rcfg(
  752. hf_hub_id='timm/',
  753. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet34_a1_0-46f8f793.pth'),
  754. 'resnet34.a2_in1k': _rcfg(
  755. hf_hub_id='timm/',
  756. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet34_a2_0-82d47d71.pth'),
  757. 'resnet34.a3_in1k': _r3cfg(
  758. hf_hub_id='timm/',
  759. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet34_a3_0-a20cabb6.pth',
  760. crop_pct=0.95),
  761. 'resnet34.bt_in1k': _ttcfg(
  762. hf_hub_id='timm/',
  763. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'),
  764. 'resnet34.ra4_e3600_r224_in1k': _rcfg(
  765. hf_hub_id='timm/',
  766. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9),
  767. 'resnet34d.ra2_in1k': _ttcfg(
  768. hf_hub_id='timm/',
  769. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34d_ra2-f8dcfcaf.pth',
  770. first_conv='conv1.0'),
  771. 'resnet26.bt_in1k': _ttcfg(
  772. hf_hub_id='timm/',
  773. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26-9aa10e23.pth'),
  774. 'resnet26d.bt_in1k': _ttcfg(
  775. hf_hub_id='timm/',
  776. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26d-69e92c46.pth',
  777. first_conv='conv1.0'),
  778. 'resnet26t.ra2_in1k': _ttcfg(
  779. hf_hub_id='timm/',
  780. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet26t_256_ra2-6f6fa748.pth',
  781. first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
  782. crop_pct=0.94, test_input_size=(3, 320, 320), test_crop_pct=1.0),
  783. 'resnet50.a1_in1k': _rcfg(
  784. hf_hub_id='timm/',
  785. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth'),
  786. 'resnet50.a1h_in1k': _rcfg(
  787. hf_hub_id='timm/',
  788. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1h2_176-001a1197.pth',
  789. input_size=(3, 176, 176), pool_size=(6, 6), crop_pct=0.9, test_input_size=(3, 224, 224), test_crop_pct=1.0),
  790. 'resnet50.a2_in1k': _rcfg(
  791. hf_hub_id='timm/',
  792. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a2_0-a2746f79.pth'),
  793. 'resnet50.a3_in1k': _r3cfg(
  794. hf_hub_id='timm/',
  795. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a3_0-59cae1ef.pth'),
  796. 'resnet50.b1k_in1k': _rcfg(
  797. hf_hub_id='timm/',
  798. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_b1k-532a802a.pth'),
  799. 'resnet50.b2k_in1k': _rcfg(
  800. hf_hub_id='timm/',
  801. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_b2k-1ba180c1.pth'),
  802. 'resnet50.c1_in1k': _rcfg(
  803. hf_hub_id='timm/',
  804. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_c1-5ba5e060.pth'),
  805. 'resnet50.c2_in1k': _rcfg(
  806. hf_hub_id='timm/',
  807. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_c2-d01e05b2.pth'),
  808. 'resnet50.d_in1k': _rcfg(
  809. hf_hub_id='timm/',
  810. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_d-f39db8af.pth'),
  811. 'resnet50.ram_in1k': _ttcfg(
  812. hf_hub_id='timm/',
  813. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth'),
  814. 'resnet50.am_in1k': _tcfg(
  815. hf_hub_id='timm/',
  816. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/resnet50_am-6c502b37.pth'),
  817. 'resnet50.ra_in1k': _ttcfg(
  818. hf_hub_id='timm/',
  819. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/resnet50_ra-85ebb6e5.pth'),
  820. 'resnet50.bt_in1k': _ttcfg(
  821. hf_hub_id='timm/',
  822. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/rw_resnet50-86acaeed.pth'),
  823. 'resnet50d.ra2_in1k': _ttcfg(
  824. hf_hub_id='timm/',
  825. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth',
  826. first_conv='conv1.0'),
  827. 'resnet50d.ra4_e3600_r224_in1k': _rcfg(
  828. hf_hub_id='timm/',
  829. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  830. crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0,
  831. first_conv='conv1.0'),
  832. 'resnet50d.a1_in1k': _rcfg(
  833. hf_hub_id='timm/',
  834. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50d_a1_0-e20cff14.pth',
  835. first_conv='conv1.0'),
  836. 'resnet50d.a2_in1k': _rcfg(
  837. hf_hub_id='timm/',
  838. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50d_a2_0-a3adc64d.pth',
  839. first_conv='conv1.0'),
  840. 'resnet50d.a3_in1k': _r3cfg(
  841. hf_hub_id='timm/',
  842. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50d_a3_0-403fdfad.pth',
  843. first_conv='conv1.0'),
  844. 'resnet50t.untrained': _ttcfg(first_conv='conv1.0'),
  845. 'resnet101.a1h_in1k': _rcfg(
  846. hf_hub_id='timm/',
  847. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth'),
  848. 'resnet101.a1_in1k': _rcfg(
  849. hf_hub_id='timm/',
  850. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1_0-cdcb52a9.pth'),
  851. 'resnet101.a2_in1k': _rcfg(
  852. hf_hub_id='timm/',
  853. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a2_0-6edb36c7.pth'),
  854. 'resnet101.a3_in1k': _r3cfg(
  855. hf_hub_id='timm/',
  856. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a3_0-1db14157.pth'),
  857. 'resnet101d.ra2_in1k': _ttcfg(
  858. hf_hub_id='timm/',
  859. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet101d_ra2-2803ffab.pth',
  860. first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95,
  861. test_crop_pct=1.0, test_input_size=(3, 320, 320)),
  862. 'resnet152.a1h_in1k': _rcfg(
  863. hf_hub_id='timm/',
  864. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet152_a1h-dc400468.pth'),
  865. 'resnet152.a1_in1k': _rcfg(
  866. hf_hub_id='timm/',
  867. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet152_a1_0-2eee8a7a.pth'),
  868. 'resnet152.a2_in1k': _rcfg(
  869. hf_hub_id='timm/',
  870. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet152_a2_0-b4c6978f.pth'),
  871. 'resnet152.a3_in1k': _r3cfg(
  872. hf_hub_id='timm/',
  873. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet152_a3_0-134d4688.pth'),
  874. 'resnet152d.ra2_in1k': _ttcfg(
  875. hf_hub_id='timm/',
  876. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet152d_ra2-5cac0439.pth',
  877. first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95,
  878. test_crop_pct=1.0, test_input_size=(3, 320, 320)),
  879. 'resnet200.untrained': _ttcfg(),
  880. 'resnet200d.ra2_in1k': _ttcfg(
  881. hf_hub_id='timm/',
  882. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet200d_ra2-bdba9bf9.pth',
  883. first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95,
  884. test_crop_pct=1.0, test_input_size=(3, 320, 320)),
  885. 'wide_resnet50_2.racm_in1k': _ttcfg(
  886. hf_hub_id='timm/',
  887. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/wide_resnet50_racm-8234f177.pth'),
  888. # torchvision resnet weights
  889. 'resnet18.tv_in1k': _cfg(
  890. hf_hub_id='timm/',
  891. url='https://download.pytorch.org/models/resnet18-f37072fd.pth',
  892. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  893. 'resnet34.tv_in1k': _cfg(
  894. hf_hub_id='timm/',
  895. url='https://download.pytorch.org/models/resnet34-b627a593.pth',
  896. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  897. 'resnet50.tv_in1k': _cfg(
  898. hf_hub_id='timm/',
  899. url='https://download.pytorch.org/models/resnet50-0676ba61.pth',
  900. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  901. 'resnet50.tv2_in1k': _cfg(
  902. hf_hub_id='timm/',
  903. url='https://download.pytorch.org/models/resnet50-11ad3fa6.pth',
  904. input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965,
  905. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  906. 'resnet101.tv_in1k': _cfg(
  907. hf_hub_id='timm/',
  908. url='https://download.pytorch.org/models/resnet101-63fe2227.pth',
  909. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  910. 'resnet101.tv2_in1k': _cfg(
  911. hf_hub_id='timm/',
  912. url='https://download.pytorch.org/models/resnet101-cd907fc2.pth',
  913. input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965,
  914. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  915. 'resnet152.tv_in1k': _cfg(
  916. hf_hub_id='timm/',
  917. url='https://download.pytorch.org/models/resnet152-394f9c45.pth',
  918. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  919. 'resnet152.tv2_in1k': _cfg(
  920. hf_hub_id='timm/',
  921. url='https://download.pytorch.org/models/resnet152-f82ba261.pth',
  922. input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965,
  923. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  924. 'wide_resnet50_2.tv_in1k': _cfg(
  925. hf_hub_id='timm/',
  926. url='https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
  927. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  928. 'wide_resnet50_2.tv2_in1k': _cfg(
  929. hf_hub_id='timm/',
  930. url='https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth',
  931. input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965,
  932. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  933. 'wide_resnet101_2.tv_in1k': _cfg(
  934. hf_hub_id='timm/',
  935. url='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
  936. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  937. 'wide_resnet101_2.tv2_in1k': _cfg(
  938. hf_hub_id='timm/',
  939. url='https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth',
  940. input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965,
  941. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  942. # ResNets w/ alternative norm layers
  943. 'resnet50_gn.a1h_in1k': _ttcfg(
  944. hf_hub_id='timm/',
  945. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_gn_a1h2-8fe6c4d0.pth',
  946. crop_pct=0.94),
  947. # ResNeXt trained in timm (RSB paper and others)
  948. 'resnext50_32x4d.a1h_in1k': _rcfg(
  949. hf_hub_id='timm/',
  950. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnext50_32x4d_a1h-0146ab0a.pth'),
  951. 'resnext50_32x4d.a1_in1k': _rcfg(
  952. hf_hub_id='timm/',
  953. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnext50_32x4d_a1_0-b5a91a1d.pth'),
  954. 'resnext50_32x4d.a2_in1k': _rcfg(
  955. hf_hub_id='timm/',
  956. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnext50_32x4d_a2_0-efc76add.pth'),
  957. 'resnext50_32x4d.a3_in1k': _r3cfg(
  958. hf_hub_id='timm/',
  959. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnext50_32x4d_a3_0-3e450271.pth'),
  960. 'resnext50_32x4d.ra_in1k': _ttcfg(
  961. hf_hub_id='timm/',
  962. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d_ra-d733960d.pth'),
  963. 'resnext50d_32x4d.bt_in1k': _ttcfg(
  964. hf_hub_id='timm/',
  965. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50d_32x4d-103e99f8.pth',
  966. first_conv='conv1.0'),
  967. 'resnext101_32x4d.untrained': _ttcfg(),
  968. 'resnext101_64x4d.c1_in1k': _rcfg(
  969. hf_hub_id='timm/',
  970. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/resnext101_64x4d_c-0d0e0cc0.pth'),
  971. # torchvision ResNeXt weights
  972. 'resnext50_32x4d.tv_in1k': _cfg(
  973. hf_hub_id='timm/',
  974. url='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
  975. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  976. 'resnext101_32x8d.tv_in1k': _cfg(
  977. hf_hub_id='timm/',
  978. url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
  979. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  980. 'resnext101_64x4d.tv_in1k': _cfg(
  981. hf_hub_id='timm/',
  982. url='https://download.pytorch.org/models/resnext101_64x4d-173b62eb.pth',
  983. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  984. 'resnext50_32x4d.tv2_in1k': _cfg(
  985. hf_hub_id='timm/',
  986. url='https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth',
  987. input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965,
  988. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  989. 'resnext101_32x8d.tv2_in1k': _cfg(
  990. hf_hub_id='timm/',
  991. url='https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth',
  992. input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965,
  993. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  994. # ResNeXt models - Weakly Supervised Pretraining on Instagram Hashtags
  995. # from https://github.com/facebookresearch/WSL-Images
  996. # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only.
  997. 'resnext101_32x8d.fb_wsl_ig1b_ft_in1k': _cfg(
  998. hf_hub_id='timm/',
  999. url='https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth',
  1000. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/WSL-Images'),
  1001. 'resnext101_32x16d.fb_wsl_ig1b_ft_in1k': _cfg(
  1002. hf_hub_id='timm/',
  1003. url='https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth',
  1004. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/WSL-Images'),
  1005. 'resnext101_32x32d.fb_wsl_ig1b_ft_in1k': _cfg(
  1006. hf_hub_id='timm/',
  1007. url='https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth',
  1008. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/WSL-Images'),
  1009. 'resnext101_32x48d.fb_wsl_ig1b_ft_in1k': _cfg(
  1010. hf_hub_id='timm/',
  1011. url='https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth',
  1012. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/WSL-Images'),
  1013. # Semi-Supervised ResNe*t models from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models
  1014. # Please note the CC-BY-NC 4.0 license on theses weights, non-commercial use only.
  1015. 'resnet18.fb_ssl_yfcc100m_ft_in1k': _cfg(
  1016. hf_hub_id='timm/',
  1017. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth',
  1018. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1019. 'resnet50.fb_ssl_yfcc100m_ft_in1k': _cfg(
  1020. hf_hub_id='timm/',
  1021. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth',
  1022. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1023. 'resnext50_32x4d.fb_ssl_yfcc100m_ft_in1k': _cfg(
  1024. hf_hub_id='timm/',
  1025. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth',
  1026. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1027. 'resnext101_32x4d.fb_ssl_yfcc100m_ft_in1k': _cfg(
  1028. hf_hub_id='timm/',
  1029. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth',
  1030. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1031. 'resnext101_32x8d.fb_ssl_yfcc100m_ft_in1k': _cfg(
  1032. hf_hub_id='timm/',
  1033. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth',
  1034. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1035. 'resnext101_32x16d.fb_ssl_yfcc100m_ft_in1k': _cfg(
  1036. hf_hub_id='timm/',
  1037. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth',
  1038. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1039. # Semi-Weakly Supervised ResNe*t models from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models
  1040. # Please note the CC-BY-NC 4.0 license on theses weights, non-commercial use only.
  1041. 'resnet18.fb_swsl_ig1b_ft_in1k': _cfg(
  1042. hf_hub_id='timm/',
  1043. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth',
  1044. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1045. 'resnet50.fb_swsl_ig1b_ft_in1k': _cfg(
  1046. hf_hub_id='timm/',
  1047. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth',
  1048. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1049. 'resnext50_32x4d.fb_swsl_ig1b_ft_in1k': _cfg(
  1050. hf_hub_id='timm/',
  1051. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth',
  1052. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1053. 'resnext101_32x4d.fb_swsl_ig1b_ft_in1k': _cfg(
  1054. hf_hub_id='timm/',
  1055. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth',
  1056. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1057. 'resnext101_32x8d.fb_swsl_ig1b_ft_in1k': _cfg(
  1058. hf_hub_id='timm/',
  1059. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth',
  1060. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1061. 'resnext101_32x16d.fb_swsl_ig1b_ft_in1k': _cfg(
  1062. hf_hub_id='timm/',
  1063. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth',
  1064. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1065. # Efficient Channel Attention ResNets
  1066. 'ecaresnet26t.ra2_in1k': _ttcfg(
  1067. hf_hub_id='timm/',
  1068. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet26t_ra2-46609757.pth',
  1069. first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
  1070. test_crop_pct=0.95, test_input_size=(3, 320, 320)),
  1071. 'ecaresnetlight.miil_in1k': _tcfg(
  1072. hf_hub_id='timm/',
  1073. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/ecaresnetlight-75a9c627.pth',
  1074. test_crop_pct=0.95, test_input_size=(3, 288, 288)),
  1075. 'ecaresnet50d.miil_in1k': _tcfg(
  1076. hf_hub_id='timm/',
  1077. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/ecaresnet50d-93c81e3b.pth',
  1078. first_conv='conv1.0', test_crop_pct=0.95, test_input_size=(3, 288, 288)),
  1079. 'ecaresnet50d_pruned.miil_in1k': _tcfg(
  1080. hf_hub_id='timm/',
  1081. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/ecaresnet50d_p-e4fa23c2.pth',
  1082. first_conv='conv1.0', test_crop_pct=0.95, test_input_size=(3, 288, 288)),
  1083. 'ecaresnet50t.ra2_in1k': _tcfg(
  1084. hf_hub_id='timm/',
  1085. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet50t_ra2-f7ac63c4.pth',
  1086. first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
  1087. test_crop_pct=0.95, test_input_size=(3, 320, 320)),
  1088. 'ecaresnet50t.a1_in1k': _rcfg(
  1089. hf_hub_id='timm/',
  1090. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/ecaresnet50t_a1_0-99bd76a8.pth',
  1091. first_conv='conv1.0'),
  1092. 'ecaresnet50t.a2_in1k': _rcfg(
  1093. hf_hub_id='timm/',
  1094. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/ecaresnet50t_a2_0-b1c7b745.pth',
  1095. first_conv='conv1.0'),
  1096. 'ecaresnet50t.a3_in1k': _r3cfg(
  1097. hf_hub_id='timm/',
  1098. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/ecaresnet50t_a3_0-8cc311f1.pth',
  1099. first_conv='conv1.0'),
  1100. 'ecaresnet101d.miil_in1k': _tcfg(
  1101. hf_hub_id='timm/',
  1102. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/ecaresnet101d-153dad65.pth',
  1103. first_conv='conv1.0', test_crop_pct=0.95, test_input_size=(3, 288, 288)),
  1104. 'ecaresnet101d_pruned.miil_in1k': _tcfg(
  1105. hf_hub_id='timm/',
  1106. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/ecaresnet101d_p-9e74cb91.pth',
  1107. first_conv='conv1.0', test_crop_pct=0.95, test_input_size=(3, 288, 288)),
  1108. 'ecaresnet200d.untrained': _ttcfg(
  1109. first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.95, pool_size=(8, 8)),
  1110. 'ecaresnet269d.ra2_in1k': _ttcfg(
  1111. hf_hub_id='timm/',
  1112. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet269d_320_ra2-7baa55cb.pth',
  1113. first_conv='conv1.0', input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=0.95,
  1114. test_crop_pct=1.0, test_input_size=(3, 352, 352)),
  1115. # Efficient Channel Attention ResNeXts
  1116. 'ecaresnext26t_32x4d.untrained': _tcfg(first_conv='conv1.0'),
  1117. 'ecaresnext50t_32x4d.untrained': _tcfg(first_conv='conv1.0'),
  1118. # Squeeze-Excitation ResNets, to eventually replace the models in senet.py
  1119. 'seresnet18.untrained': _ttcfg(),
  1120. 'seresnet34.untrained': _ttcfg(),
  1121. 'seresnet50.a1_in1k': _rcfg(
  1122. hf_hub_id='timm/',
  1123. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/seresnet50_a1_0-ffa00869.pth',
  1124. crop_pct=0.95),
  1125. 'seresnet50.a2_in1k': _rcfg(
  1126. hf_hub_id='timm/',
  1127. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/seresnet50_a2_0-850de0d9.pth',
  1128. crop_pct=0.95),
  1129. 'seresnet50.a3_in1k': _r3cfg(
  1130. hf_hub_id='timm/',
  1131. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/seresnet50_a3_0-317ecd56.pth',
  1132. crop_pct=0.95),
  1133. 'seresnet50.ra2_in1k': _ttcfg(
  1134. hf_hub_id='timm/',
  1135. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet50_ra_224-8efdb4bb.pth'),
  1136. 'seresnet50t.untrained': _ttcfg(
  1137. first_conv='conv1.0'),
  1138. 'seresnet101.untrained': _ttcfg(),
  1139. 'seresnet152.untrained': _ttcfg(),
  1140. 'seresnet152d.ra2_in1k': _ttcfg(
  1141. hf_hub_id='timm/',
  1142. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet152d_ra2-04464dd2.pth',
  1143. first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95,
  1144. test_crop_pct=1.0, test_input_size=(3, 320, 320)
  1145. ),
  1146. 'seresnet200d.untrained': _ttcfg(
  1147. first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)),
  1148. 'seresnet269d.untrained': _ttcfg(
  1149. first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)),
  1150. # Squeeze-Excitation ResNeXts, to eventually replace the models in senet.py
  1151. 'seresnext26d_32x4d.bt_in1k': _ttcfg(
  1152. hf_hub_id='timm/',
  1153. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26d_32x4d-80fa48a3.pth',
  1154. first_conv='conv1.0'),
  1155. 'seresnext26t_32x4d.bt_in1k': _ttcfg(
  1156. hf_hub_id='timm/',
  1157. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth',
  1158. first_conv='conv1.0'),
  1159. 'seresnext50_32x4d.racm_in1k': _ttcfg(
  1160. hf_hub_id='timm/',
  1161. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext50_32x4d_racm-a304a460.pth'),
  1162. 'seresnext101_32x4d.untrained': _ttcfg(),
  1163. 'seresnext101_32x8d.ah_in1k': _rcfg(
  1164. hf_hub_id='timm/',
  1165. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnext101_32x8d_ah-e6bc4c0a.pth'),
  1166. 'seresnext101d_32x8d.ah_in1k': _rcfg(
  1167. hf_hub_id='timm/',
  1168. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnext101d_32x8d_ah-191d7b94.pth',
  1169. first_conv='conv1.0'),
  1170. # ResNets with anti-aliasing / blur pool
  1171. 'resnetaa50d.sw_in12k_ft_in1k': _ttcfg(
  1172. hf_hub_id='timm/',
  1173. first_conv='conv1.0', crop_pct=0.95, test_crop_pct=1.0),
  1174. 'resnetaa101d.sw_in12k_ft_in1k': _ttcfg(
  1175. hf_hub_id='timm/',
  1176. first_conv='conv1.0', crop_pct=0.95, test_crop_pct=1.0),
  1177. 'seresnextaa101d_32x8d.sw_in12k_ft_in1k_288': _ttcfg(
  1178. hf_hub_id='timm/',
  1179. crop_pct=0.95, input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), test_crop_pct=1.0,
  1180. first_conv='conv1.0'),
  1181. 'seresnextaa101d_32x8d.sw_in12k_ft_in1k': _ttcfg(
  1182. hf_hub_id='timm/',
  1183. first_conv='conv1.0', test_crop_pct=1.0),
  1184. 'seresnextaa201d_32x8d.sw_in12k_ft_in1k_384': _cfg(
  1185. hf_hub_id='timm/',
  1186. interpolation='bicubic', first_conv='conv1.0', pool_size=(12, 12), input_size=(3, 384, 384), crop_pct=1.0),
  1187. 'seresnextaa201d_32x8d.sw_in12k': _cfg(
  1188. hf_hub_id='timm/',
  1189. num_classes=11821, interpolation='bicubic', first_conv='conv1.0',
  1190. crop_pct=0.95, input_size=(3, 320, 320), pool_size=(10, 10), test_input_size=(3, 384, 384), test_crop_pct=1.0),
  1191. 'resnetaa50d.sw_in12k': _ttcfg(
  1192. hf_hub_id='timm/',
  1193. num_classes=11821, first_conv='conv1.0', crop_pct=0.95, test_crop_pct=1.0),
  1194. 'resnetaa50d.d_in12k': _ttcfg(
  1195. hf_hub_id='timm/',
  1196. num_classes=11821, first_conv='conv1.0', crop_pct=0.95, test_crop_pct=1.0),
  1197. 'resnetaa101d.sw_in12k': _ttcfg(
  1198. hf_hub_id='timm/',
  1199. num_classes=11821, first_conv='conv1.0', crop_pct=0.95, test_crop_pct=1.0),
  1200. 'seresnextaa101d_32x8d.sw_in12k': _ttcfg(
  1201. hf_hub_id='timm/',
  1202. num_classes=11821, first_conv='conv1.0', crop_pct=0.95, test_crop_pct=1.0),
  1203. 'resnetblur18.untrained': _ttcfg(),
  1204. 'resnetblur50.bt_in1k': _ttcfg(
  1205. hf_hub_id='timm/',
  1206. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth'),
  1207. 'resnetblur50d.untrained': _ttcfg(first_conv='conv1.0'),
  1208. 'resnetblur101d.untrained': _ttcfg(first_conv='conv1.0'),
  1209. 'resnetaa34d.untrained': _ttcfg(first_conv='conv1.0'),
  1210. 'resnetaa50.a1h_in1k': _rcfg(
  1211. hf_hub_id='timm/',
  1212. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetaa50_a1h-4cf422b3.pth'),
  1213. 'seresnetaa50d.untrained': _ttcfg(first_conv='conv1.0'),
  1214. 'seresnextaa101d_32x8d.ah_in1k': _rcfg(
  1215. hf_hub_id='timm/',
  1216. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnextaa101d_32x8d_ah-83c8ae12.pth',
  1217. first_conv='conv1.0'),
  1218. # ResNet-RS models
  1219. 'resnetrs50.tf_in1k': _cfg(
  1220. hf_hub_id='timm/',
  1221. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs50_ema-6b53758b.pth',
  1222. input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.91, test_input_size=(3, 224, 224),
  1223. interpolation='bicubic', first_conv='conv1.0'),
  1224. 'resnetrs101.tf_in1k': _cfg(
  1225. hf_hub_id='timm/',
  1226. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs101_i192_ema-1509bbf6.pth',
  1227. input_size=(3, 192, 192), pool_size=(6, 6), crop_pct=0.94, test_input_size=(3, 288, 288),
  1228. interpolation='bicubic', first_conv='conv1.0'),
  1229. 'resnetrs152.tf_in1k': _cfg(
  1230. hf_hub_id='timm/',
  1231. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs152_i256_ema-a9aff7f9.pth',
  1232. input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
  1233. interpolation='bicubic', first_conv='conv1.0'),
  1234. 'resnetrs200.tf_in1k': _cfg(
  1235. hf_hub_id='timm/',
  1236. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/resnetrs200_c-6b698b88.pth',
  1237. input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
  1238. interpolation='bicubic', first_conv='conv1.0'),
  1239. 'resnetrs270.tf_in1k': _cfg(
  1240. hf_hub_id='timm/',
  1241. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs270_ema-b40e674c.pth',
  1242. input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 352, 352),
  1243. interpolation='bicubic', first_conv='conv1.0'),
  1244. 'resnetrs350.tf_in1k': _cfg(
  1245. hf_hub_id='timm/',
  1246. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs350_i256_ema-5a1aa8f1.pth',
  1247. input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0, test_input_size=(3, 384, 384),
  1248. interpolation='bicubic', first_conv='conv1.0'),
  1249. 'resnetrs420.tf_in1k': _cfg(
  1250. hf_hub_id='timm/',
  1251. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs420_ema-972dee69.pth',
  1252. input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, test_input_size=(3, 416, 416),
  1253. interpolation='bicubic', first_conv='conv1.0'),
  1254. # gluon resnet weights
  1255. 'resnet18.gluon_in1k': _gcfg(
  1256. hf_hub_id='timm/',
  1257. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet18_v1b-0757602b.pth'),
  1258. 'resnet34.gluon_in1k': _gcfg(
  1259. hf_hub_id='timm/',
  1260. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet34_v1b-c6d82d59.pth'),
  1261. 'resnet50.gluon_in1k': _gcfg(
  1262. hf_hub_id='timm/',
  1263. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1b-0ebe02e2.pth'),
  1264. 'resnet101.gluon_in1k': _gcfg(
  1265. hf_hub_id='timm/',
  1266. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1b-3b017079.pth'),
  1267. 'resnet152.gluon_in1k': _gcfg(
  1268. hf_hub_id='timm/',
  1269. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1b-c1edb0dd.pth'),
  1270. 'resnet50c.gluon_in1k': _gcfg(
  1271. hf_hub_id='timm/',
  1272. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1c-48092f55.pth',
  1273. first_conv='conv1.0'),
  1274. 'resnet101c.gluon_in1k': _gcfg(
  1275. hf_hub_id='timm/',
  1276. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1c-1f26822a.pth',
  1277. first_conv='conv1.0'),
  1278. 'resnet152c.gluon_in1k': _gcfg(
  1279. hf_hub_id='timm/',
  1280. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1c-a3bb0b98.pth',
  1281. first_conv='conv1.0'),
  1282. 'resnet50d.gluon_in1k': _gcfg(
  1283. hf_hub_id='timm/',
  1284. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1d-818a1b1b.pth',
  1285. first_conv='conv1.0'),
  1286. 'resnet101d.gluon_in1k': _gcfg(
  1287. hf_hub_id='timm/',
  1288. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1d-0f9c8644.pth',
  1289. first_conv='conv1.0'),
  1290. 'resnet152d.gluon_in1k': _gcfg(
  1291. hf_hub_id='timm/',
  1292. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1d-bd354e12.pth',
  1293. first_conv='conv1.0'),
  1294. 'resnet50s.gluon_in1k': _gcfg(
  1295. hf_hub_id='timm/',
  1296. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1s-1762acc0.pth',
  1297. first_conv='conv1.0'),
  1298. 'resnet101s.gluon_in1k': _gcfg(
  1299. hf_hub_id='timm/',
  1300. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1s-60fe0cc1.pth',
  1301. first_conv='conv1.0'),
  1302. 'resnet152s.gluon_in1k': _gcfg(
  1303. hf_hub_id='timm/',
  1304. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1s-dcc41b81.pth',
  1305. first_conv='conv1.0'),
  1306. 'resnext50_32x4d.gluon_in1k': _gcfg(
  1307. hf_hub_id='timm/',
  1308. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext50_32x4d-e6a097c1.pth'),
  1309. 'resnext101_32x4d.gluon_in1k': _gcfg(
  1310. hf_hub_id='timm/',
  1311. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_32x4d-b253c8c4.pth'),
  1312. 'resnext101_64x4d.gluon_in1k': _gcfg(
  1313. hf_hub_id='timm/',
  1314. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_64x4d-f9a8e184.pth'),
  1315. 'seresnext50_32x4d.gluon_in1k': _gcfg(
  1316. hf_hub_id='timm/',
  1317. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth'),
  1318. 'seresnext101_32x4d.gluon_in1k': _gcfg(
  1319. hf_hub_id='timm/',
  1320. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth'),
  1321. 'seresnext101_64x4d.gluon_in1k': _gcfg(
  1322. hf_hub_id='timm/',
  1323. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_64x4d-f9926f93.pth'),
  1324. 'senet154.gluon_in1k': _gcfg(
  1325. hf_hub_id='timm/',
  1326. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth',
  1327. first_conv='conv1.0'),
  1328. 'test_resnet.r160_in1k': _cfg(
  1329. hf_hub_id='timm/',
  1330. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.95,
  1331. input_size=(3, 160, 160), pool_size=(5, 5), first_conv='conv1.0'),
  1332. })
  1333. @register_model
  1334. def resnet10t(pretrained: bool = False, **kwargs) -> ResNet:
  1335. """Constructs a ResNet-10-T model.
  1336. """
  1337. model_args = dict(block=BasicBlock, layers=(1, 1, 1, 1), stem_width=32, stem_type='deep_tiered', avg_down=True)
  1338. return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs))
  1339. @register_model
  1340. def resnet14t(pretrained: bool = False, **kwargs) -> ResNet:
  1341. """Constructs a ResNet-14-T model.
  1342. """
  1343. model_args = dict(block=Bottleneck, layers=(1, 1, 1, 1), stem_width=32, stem_type='deep_tiered', avg_down=True)
  1344. return _create_resnet('resnet14t', pretrained, **dict(model_args, **kwargs))
  1345. @register_model
  1346. def resnet18(pretrained: bool = False, **kwargs) -> ResNet:
  1347. """Constructs a ResNet-18 model.
  1348. """
  1349. model_args = dict(block=BasicBlock, layers=(2, 2, 2, 2))
  1350. return _create_resnet('resnet18', pretrained, **dict(model_args, **kwargs))
  1351. @register_model
  1352. def resnet18d(pretrained: bool = False, **kwargs) -> ResNet:
  1353. """Constructs a ResNet-18-D model.
  1354. """
  1355. model_args = dict(block=BasicBlock, layers=(2, 2, 2, 2), stem_width=32, stem_type='deep', avg_down=True)
  1356. return _create_resnet('resnet18d', pretrained, **dict(model_args, **kwargs))
  1357. @register_model
  1358. def resnet34(pretrained: bool = False, **kwargs) -> ResNet:
  1359. """Constructs a ResNet-34 model.
  1360. """
  1361. model_args = dict(block=BasicBlock, layers=(3, 4, 6, 3))
  1362. return _create_resnet('resnet34', pretrained, **dict(model_args, **kwargs))
  1363. @register_model
  1364. def resnet34d(pretrained: bool = False, **kwargs) -> ResNet:
  1365. """Constructs a ResNet-34-D model.
  1366. """
  1367. model_args = dict(block=BasicBlock, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', avg_down=True)
  1368. return _create_resnet('resnet34d', pretrained, **dict(model_args, **kwargs))
  1369. @register_model
  1370. def resnet26(pretrained: bool = False, **kwargs) -> ResNet:
  1371. """Constructs a ResNet-26 model.
  1372. """
  1373. model_args = dict(block=Bottleneck, layers=(2, 2, 2, 2))
  1374. return _create_resnet('resnet26', pretrained, **dict(model_args, **kwargs))
  1375. @register_model
  1376. def resnet26t(pretrained: bool = False, **kwargs) -> ResNet:
  1377. """Constructs a ResNet-26-T model.
  1378. """
  1379. model_args = dict(block=Bottleneck, layers=(2, 2, 2, 2), stem_width=32, stem_type='deep_tiered', avg_down=True)
  1380. return _create_resnet('resnet26t', pretrained, **dict(model_args, **kwargs))
  1381. @register_model
  1382. def resnet26d(pretrained: bool = False, **kwargs) -> ResNet:
  1383. """Constructs a ResNet-26-D model.
  1384. """
  1385. model_args = dict(block=Bottleneck, layers=(2, 2, 2, 2), stem_width=32, stem_type='deep', avg_down=True)
  1386. return _create_resnet('resnet26d', pretrained, **dict(model_args, **kwargs))
  1387. @register_model
  1388. def resnet50(pretrained: bool = False, **kwargs) -> ResNet:
  1389. """Constructs a ResNet-50 model.
  1390. """
  1391. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3))
  1392. return _create_resnet('resnet50', pretrained, **dict(model_args, **kwargs))
  1393. @register_model
  1394. def resnet50c(pretrained: bool = False, **kwargs) -> ResNet:
  1395. """Constructs a ResNet-50-C model.
  1396. """
  1397. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep')
  1398. return _create_resnet('resnet50c', pretrained, **dict(model_args, **kwargs))
  1399. @register_model
  1400. def resnet50d(pretrained: bool = False, **kwargs) -> ResNet:
  1401. """Constructs a ResNet-50-D model.
  1402. """
  1403. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', avg_down=True)
  1404. return _create_resnet('resnet50d', pretrained, **dict(model_args, **kwargs))
  1405. @register_model
  1406. def resnet50s(pretrained: bool = False, **kwargs) -> ResNet:
  1407. """Constructs a ResNet-50-S model.
  1408. """
  1409. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), stem_width=64, stem_type='deep')
  1410. return _create_resnet('resnet50s', pretrained, **dict(model_args, **kwargs))
  1411. @register_model
  1412. def resnet50t(pretrained: bool = False, **kwargs) -> ResNet:
  1413. """Constructs a ResNet-50-T model.
  1414. """
  1415. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep_tiered', avg_down=True)
  1416. return _create_resnet('resnet50t', pretrained, **dict(model_args, **kwargs))
  1417. @register_model
  1418. def resnet101(pretrained: bool = False, **kwargs) -> ResNet:
  1419. """Constructs a ResNet-101 model.
  1420. """
  1421. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3))
  1422. return _create_resnet('resnet101', pretrained, **dict(model_args, **kwargs))
  1423. @register_model
  1424. def resnet101c(pretrained: bool = False, **kwargs) -> ResNet:
  1425. """Constructs a ResNet-101-C model.
  1426. """
  1427. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep')
  1428. return _create_resnet('resnet101c', pretrained, **dict(model_args, **kwargs))
  1429. @register_model
  1430. def resnet101d(pretrained: bool = False, **kwargs) -> ResNet:
  1431. """Constructs a ResNet-101-D model.
  1432. """
  1433. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep', avg_down=True)
  1434. return _create_resnet('resnet101d', pretrained, **dict(model_args, **kwargs))
  1435. @register_model
  1436. def resnet101s(pretrained: bool = False, **kwargs) -> ResNet:
  1437. """Constructs a ResNet-101-S model.
  1438. """
  1439. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), stem_width=64, stem_type='deep')
  1440. return _create_resnet('resnet101s', pretrained, **dict(model_args, **kwargs))
  1441. @register_model
  1442. def resnet152(pretrained: bool = False, **kwargs) -> ResNet:
  1443. """Constructs a ResNet-152 model.
  1444. """
  1445. model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3))
  1446. return _create_resnet('resnet152', pretrained, **dict(model_args, **kwargs))
  1447. @register_model
  1448. def resnet152c(pretrained: bool = False, **kwargs) -> ResNet:
  1449. """Constructs a ResNet-152-C model.
  1450. """
  1451. model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3), stem_width=32, stem_type='deep')
  1452. return _create_resnet('resnet152c', pretrained, **dict(model_args, **kwargs))
  1453. @register_model
  1454. def resnet152d(pretrained: bool = False, **kwargs) -> ResNet:
  1455. """Constructs a ResNet-152-D model.
  1456. """
  1457. model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3), stem_width=32, stem_type='deep', avg_down=True)
  1458. return _create_resnet('resnet152d', pretrained, **dict(model_args, **kwargs))
  1459. @register_model
  1460. def resnet152s(pretrained: bool = False, **kwargs) -> ResNet:
  1461. """Constructs a ResNet-152-S model.
  1462. """
  1463. model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3), stem_width=64, stem_type='deep')
  1464. return _create_resnet('resnet152s', pretrained, **dict(model_args, **kwargs))
  1465. @register_model
  1466. def resnet200(pretrained: bool = False, **kwargs) -> ResNet:
  1467. """Constructs a ResNet-200 model.
  1468. """
  1469. model_args = dict(block=Bottleneck, layers=(3, 24, 36, 3))
  1470. return _create_resnet('resnet200', pretrained, **dict(model_args, **kwargs))
  1471. @register_model
  1472. def resnet200d(pretrained: bool = False, **kwargs) -> ResNet:
  1473. """Constructs a ResNet-200-D model.
  1474. """
  1475. model_args = dict(block=Bottleneck, layers=(3, 24, 36, 3), stem_width=32, stem_type='deep', avg_down=True)
  1476. return _create_resnet('resnet200d', pretrained, **dict(model_args, **kwargs))
  1477. @register_model
  1478. def wide_resnet50_2(pretrained: bool = False, **kwargs) -> ResNet:
  1479. """Constructs a Wide ResNet-50-2 model.
  1480. The model is the same as ResNet except for the bottleneck number of channels
  1481. which is twice larger in every block. The number of channels in outer 1x1
  1482. convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
  1483. channels, and in Wide ResNet-50-2 has 2048-1024-2048.
  1484. """
  1485. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), base_width=128)
  1486. return _create_resnet('wide_resnet50_2', pretrained, **dict(model_args, **kwargs))
  1487. @register_model
  1488. def wide_resnet101_2(pretrained: bool = False, **kwargs) -> ResNet:
  1489. """Constructs a Wide ResNet-101-2 model.
  1490. The model is the same as ResNet except for the bottleneck number of channels
  1491. which is twice larger in every block. The number of channels in outer 1x1
  1492. convolutions is the same.
  1493. """
  1494. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), base_width=128)
  1495. return _create_resnet('wide_resnet101_2', pretrained, **dict(model_args, **kwargs))
  1496. @register_model
  1497. def resnet50_gn(pretrained: bool = False, **kwargs) -> ResNet:
  1498. """Constructs a ResNet-50 model w/ GroupNorm
  1499. """
  1500. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), norm_layer='groupnorm')
  1501. return _create_resnet('resnet50_gn', pretrained, **dict(model_args, **kwargs))
  1502. @register_model
  1503. def resnext50_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1504. """Constructs a ResNeXt50-32x4d model.
  1505. """
  1506. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), cardinality=32, base_width=4)
  1507. return _create_resnet('resnext50_32x4d', pretrained, **dict(model_args, **kwargs))
  1508. @register_model
  1509. def resnext50d_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1510. """Constructs a ResNeXt50d-32x4d model. ResNext50 w/ deep stem & avg pool downsample
  1511. """
  1512. model_args = dict(
  1513. block=Bottleneck, layers=(3, 4, 6, 3), cardinality=32, base_width=4,
  1514. stem_width=32, stem_type='deep', avg_down=True)
  1515. return _create_resnet('resnext50d_32x4d', pretrained, **dict(model_args, **kwargs))
  1516. @register_model
  1517. def resnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1518. """Constructs a ResNeXt-101 32x4d model.
  1519. """
  1520. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=4)
  1521. return _create_resnet('resnext101_32x4d', pretrained, **dict(model_args, **kwargs))
  1522. @register_model
  1523. def resnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
  1524. """Constructs a ResNeXt-101 32x8d model.
  1525. """
  1526. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=8)
  1527. return _create_resnet('resnext101_32x8d', pretrained, **dict(model_args, **kwargs))
  1528. @register_model
  1529. def resnext101_32x16d(pretrained: bool = False, **kwargs) -> ResNet:
  1530. """Constructs a ResNeXt-101 32x16d model
  1531. """
  1532. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=16)
  1533. return _create_resnet('resnext101_32x16d', pretrained, **dict(model_args, **kwargs))
  1534. @register_model
  1535. def resnext101_32x32d(pretrained: bool = False, **kwargs) -> ResNet:
  1536. """Constructs a ResNeXt-101 32x32d model
  1537. """
  1538. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=32)
  1539. return _create_resnet('resnext101_32x32d', pretrained, **dict(model_args, **kwargs))
  1540. @register_model
  1541. def resnext101_64x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1542. """Constructs a ResNeXt101-64x4d model.
  1543. """
  1544. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=64, base_width=4)
  1545. return _create_resnet('resnext101_64x4d', pretrained, **dict(model_args, **kwargs))
  1546. @register_model
  1547. def ecaresnet26t(pretrained: bool = False, **kwargs) -> ResNet:
  1548. """Constructs an ECA-ResNeXt-26-T model.
  1549. This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
  1550. in the deep stem and ECA attn.
  1551. """
  1552. model_args = dict(
  1553. block=Bottleneck, layers=(2, 2, 2, 2), stem_width=32,
  1554. stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
  1555. return _create_resnet('ecaresnet26t', pretrained, **dict(model_args, **kwargs))
  1556. @register_model
  1557. def ecaresnet50d(pretrained: bool = False, **kwargs) -> ResNet:
  1558. """Constructs a ResNet-50-D model with eca.
  1559. """
  1560. model_args = dict(
  1561. block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', avg_down=True,
  1562. block_args=dict(attn_layer='eca'))
  1563. return _create_resnet('ecaresnet50d', pretrained, **dict(model_args, **kwargs))
  1564. @register_model
  1565. def ecaresnet50d_pruned(pretrained: bool = False, **kwargs) -> ResNet:
  1566. """Constructs a ResNet-50-D model pruned with eca.
  1567. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
  1568. """
  1569. model_args = dict(
  1570. block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', avg_down=True,
  1571. block_args=dict(attn_layer='eca'))
  1572. return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **dict(model_args, **kwargs))
  1573. @register_model
  1574. def ecaresnet50t(pretrained: bool = False, **kwargs) -> ResNet:
  1575. """Constructs an ECA-ResNet-50-T model.
  1576. Like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels in the deep stem and ECA attn.
  1577. """
  1578. model_args = dict(
  1579. block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32,
  1580. stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
  1581. return _create_resnet('ecaresnet50t', pretrained, **dict(model_args, **kwargs))
  1582. @register_model
  1583. def ecaresnetlight(pretrained: bool = False, **kwargs) -> ResNet:
  1584. """Constructs a ResNet-50-D light model with eca.
  1585. """
  1586. model_args = dict(
  1587. block=Bottleneck, layers=(1, 1, 11, 3), stem_width=32, avg_down=True,
  1588. block_args=dict(attn_layer='eca'))
  1589. return _create_resnet('ecaresnetlight', pretrained, **dict(model_args, **kwargs))
  1590. @register_model
  1591. def ecaresnet101d(pretrained: bool = False, **kwargs) -> ResNet:
  1592. """Constructs a ResNet-101-D model with eca.
  1593. """
  1594. model_args = dict(
  1595. block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep', avg_down=True,
  1596. block_args=dict(attn_layer='eca'))
  1597. return _create_resnet('ecaresnet101d', pretrained, **dict(model_args, **kwargs))
  1598. @register_model
  1599. def ecaresnet101d_pruned(pretrained: bool = False, **kwargs) -> ResNet:
  1600. """Constructs a ResNet-101-D model pruned with eca.
  1601. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
  1602. """
  1603. model_args = dict(
  1604. block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep', avg_down=True,
  1605. block_args=dict(attn_layer='eca'))
  1606. return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **dict(model_args, **kwargs))
  1607. @register_model
  1608. def ecaresnet200d(pretrained: bool = False, **kwargs) -> ResNet:
  1609. """Constructs a ResNet-200-D model with ECA.
  1610. """
  1611. model_args = dict(
  1612. block=Bottleneck, layers=(3, 24, 36, 3), stem_width=32, stem_type='deep', avg_down=True,
  1613. block_args=dict(attn_layer='eca'))
  1614. return _create_resnet('ecaresnet200d', pretrained, **dict(model_args, **kwargs))
  1615. @register_model
  1616. def ecaresnet269d(pretrained: bool = False, **kwargs) -> ResNet:
  1617. """Constructs a ResNet-269-D model with ECA.
  1618. """
  1619. model_args = dict(
  1620. block=Bottleneck, layers=(3, 30, 48, 8), stem_width=32, stem_type='deep', avg_down=True,
  1621. block_args=dict(attn_layer='eca'))
  1622. return _create_resnet('ecaresnet269d', pretrained, **dict(model_args, **kwargs))
  1623. @register_model
  1624. def ecaresnext26t_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1625. """Constructs an ECA-ResNeXt-26-T model.
  1626. This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
  1627. in the deep stem. This model replaces SE module with the ECA module
  1628. """
  1629. model_args = dict(
  1630. block=Bottleneck, layers=(2, 2, 2, 2), cardinality=32, base_width=4, stem_width=32,
  1631. stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
  1632. return _create_resnet('ecaresnext26t_32x4d', pretrained, **dict(model_args, **kwargs))
  1633. @register_model
  1634. def ecaresnext50t_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1635. """Constructs an ECA-ResNeXt-50-T model.
  1636. This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
  1637. in the deep stem. This model replaces SE module with the ECA module
  1638. """
  1639. model_args = dict(
  1640. block=Bottleneck, layers=(2, 2, 2, 2), cardinality=32, base_width=4, stem_width=32,
  1641. stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
  1642. return _create_resnet('ecaresnext50t_32x4d', pretrained, **dict(model_args, **kwargs))
  1643. @register_model
  1644. def seresnet18(pretrained: bool = False, **kwargs) -> ResNet:
  1645. model_args = dict(block=BasicBlock, layers=(2, 2, 2, 2), block_args=dict(attn_layer='se'))
  1646. return _create_resnet('seresnet18', pretrained, **dict(model_args, **kwargs))
  1647. @register_model
  1648. def seresnet34(pretrained: bool = False, **kwargs) -> ResNet:
  1649. model_args = dict(block=BasicBlock, layers=(3, 4, 6, 3), block_args=dict(attn_layer='se'))
  1650. return _create_resnet('seresnet34', pretrained, **dict(model_args, **kwargs))
  1651. @register_model
  1652. def seresnet50(pretrained: bool = False, **kwargs) -> ResNet:
  1653. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), block_args=dict(attn_layer='se'))
  1654. return _create_resnet('seresnet50', pretrained, **dict(model_args, **kwargs))
  1655. @register_model
  1656. def seresnet50t(pretrained: bool = False, **kwargs) -> ResNet:
  1657. model_args = dict(
  1658. block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep_tiered',
  1659. avg_down=True, block_args=dict(attn_layer='se'))
  1660. return _create_resnet('seresnet50t', pretrained, **dict(model_args, **kwargs))
  1661. @register_model
  1662. def seresnet101(pretrained: bool = False, **kwargs) -> ResNet:
  1663. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), block_args=dict(attn_layer='se'))
  1664. return _create_resnet('seresnet101', pretrained, **dict(model_args, **kwargs))
  1665. @register_model
  1666. def seresnet152(pretrained: bool = False, **kwargs) -> ResNet:
  1667. model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3), block_args=dict(attn_layer='se'))
  1668. return _create_resnet('seresnet152', pretrained, **dict(model_args, **kwargs))
  1669. @register_model
  1670. def seresnet152d(pretrained: bool = False, **kwargs) -> ResNet:
  1671. model_args = dict(
  1672. block=Bottleneck, layers=(3, 8, 36, 3), stem_width=32, stem_type='deep',
  1673. avg_down=True, block_args=dict(attn_layer='se'))
  1674. return _create_resnet('seresnet152d', pretrained, **dict(model_args, **kwargs))
  1675. @register_model
  1676. def seresnet200d(pretrained: bool = False, **kwargs) -> ResNet:
  1677. """Constructs a ResNet-200-D model with SE attn.
  1678. """
  1679. model_args = dict(
  1680. block=Bottleneck, layers=(3, 24, 36, 3), stem_width=32, stem_type='deep',
  1681. avg_down=True, block_args=dict(attn_layer='se'))
  1682. return _create_resnet('seresnet200d', pretrained, **dict(model_args, **kwargs))
  1683. @register_model
  1684. def seresnet269d(pretrained: bool = False, **kwargs) -> ResNet:
  1685. """Constructs a ResNet-269-D model with SE attn.
  1686. """
  1687. model_args = dict(
  1688. block=Bottleneck, layers=(3, 30, 48, 8), stem_width=32, stem_type='deep',
  1689. avg_down=True, block_args=dict(attn_layer='se'))
  1690. return _create_resnet('seresnet269d', pretrained, **dict(model_args, **kwargs))
  1691. @register_model
  1692. def seresnext26d_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1693. """Constructs a SE-ResNeXt-26-D model.`
  1694. This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for
  1695. combination of deep stem and avg_pool in downsample.
  1696. """
  1697. model_args = dict(
  1698. block=Bottleneck, layers=(2, 2, 2, 2), cardinality=32, base_width=4, stem_width=32,
  1699. stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'))
  1700. return _create_resnet('seresnext26d_32x4d', pretrained, **dict(model_args, **kwargs))
  1701. @register_model
  1702. def seresnext26t_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1703. """Constructs a SE-ResNet-26-T model.
  1704. This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
  1705. in the deep stem.
  1706. """
  1707. model_args = dict(
  1708. block=Bottleneck, layers=(2, 2, 2, 2), cardinality=32, base_width=4, stem_width=32,
  1709. stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'))
  1710. return _create_resnet('seresnext26t_32x4d', pretrained, **dict(model_args, **kwargs))
  1711. @register_model
  1712. def seresnext50_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1713. model_args = dict(
  1714. block=Bottleneck, layers=(3, 4, 6, 3), cardinality=32, base_width=4,
  1715. block_args=dict(attn_layer='se'))
  1716. return _create_resnet('seresnext50_32x4d', pretrained, **dict(model_args, **kwargs))
  1717. @register_model
  1718. def seresnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1719. model_args = dict(
  1720. block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=4,
  1721. block_args=dict(attn_layer='se'))
  1722. return _create_resnet('seresnext101_32x4d', pretrained, **dict(model_args, **kwargs))
  1723. @register_model
  1724. def seresnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
  1725. model_args = dict(
  1726. block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=8,
  1727. block_args=dict(attn_layer='se'))
  1728. return _create_resnet('seresnext101_32x8d', pretrained, **dict(model_args, **kwargs))
  1729. @register_model
  1730. def seresnext101d_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
  1731. model_args = dict(
  1732. block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=8,
  1733. stem_width=32, stem_type='deep', avg_down=True,
  1734. block_args=dict(attn_layer='se'))
  1735. return _create_resnet('seresnext101d_32x8d', pretrained, **dict(model_args, **kwargs))
  1736. @register_model
  1737. def seresnext101_64x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1738. model_args = dict(
  1739. block=Bottleneck, layers=(3, 4, 23, 3), cardinality=64, base_width=4,
  1740. block_args=dict(attn_layer='se'))
  1741. return _create_resnet('seresnext101_64x4d', pretrained, **dict(model_args, **kwargs))
  1742. @register_model
  1743. def senet154(pretrained: bool = False, **kwargs) -> ResNet:
  1744. model_args = dict(
  1745. block=Bottleneck, layers=(3, 8, 36, 3), cardinality=64, base_width=4, stem_type='deep',
  1746. down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'))
  1747. return _create_resnet('senet154', pretrained, **dict(model_args, **kwargs))
  1748. @register_model
  1749. def resnetblur18(pretrained: bool = False, **kwargs) -> ResNet:
  1750. """Constructs a ResNet-18 model with blur anti-aliasing
  1751. """
  1752. model_args = dict(block=BasicBlock, layers=(2, 2, 2, 2), aa_layer=BlurPool2d)
  1753. return _create_resnet('resnetblur18', pretrained, **dict(model_args, **kwargs))
  1754. @register_model
  1755. def resnetblur50(pretrained: bool = False, **kwargs) -> ResNet:
  1756. """Constructs a ResNet-50 model with blur anti-aliasing
  1757. """
  1758. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=BlurPool2d)
  1759. return _create_resnet('resnetblur50', pretrained, **dict(model_args, **kwargs))
  1760. @register_model
  1761. def resnetblur50d(pretrained: bool = False, **kwargs) -> ResNet:
  1762. """Constructs a ResNet-50-D model with blur anti-aliasing
  1763. """
  1764. model_args = dict(
  1765. block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=BlurPool2d,
  1766. stem_width=32, stem_type='deep', avg_down=True)
  1767. return _create_resnet('resnetblur50d', pretrained, **dict(model_args, **kwargs))
  1768. @register_model
  1769. def resnetblur101d(pretrained: bool = False, **kwargs) -> ResNet:
  1770. """Constructs a ResNet-101-D model with blur anti-aliasing
  1771. """
  1772. model_args = dict(
  1773. block=Bottleneck, layers=(3, 4, 23, 3), aa_layer=BlurPool2d,
  1774. stem_width=32, stem_type='deep', avg_down=True)
  1775. return _create_resnet('resnetblur101d', pretrained, **dict(model_args, **kwargs))
  1776. @register_model
  1777. def resnetaa34d(pretrained: bool = False, **kwargs) -> ResNet:
  1778. """Constructs a ResNet-34-D model w/ avgpool anti-aliasing
  1779. """
  1780. model_args = dict(
  1781. block=BasicBlock, layers=(3, 4, 6, 3), aa_layer=nn.AvgPool2d, stem_width=32, stem_type='deep', avg_down=True)
  1782. return _create_resnet('resnetaa34d', pretrained, **dict(model_args, **kwargs))
  1783. @register_model
  1784. def resnetaa50(pretrained: bool = False, **kwargs) -> ResNet:
  1785. """Constructs a ResNet-50 model with avgpool anti-aliasing
  1786. """
  1787. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=nn.AvgPool2d)
  1788. return _create_resnet('resnetaa50', pretrained, **dict(model_args, **kwargs))
  1789. @register_model
  1790. def resnetaa50d(pretrained: bool = False, **kwargs) -> ResNet:
  1791. """Constructs a ResNet-50-D model with avgpool anti-aliasing
  1792. """
  1793. model_args = dict(
  1794. block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=nn.AvgPool2d,
  1795. stem_width=32, stem_type='deep', avg_down=True)
  1796. return _create_resnet('resnetaa50d', pretrained, **dict(model_args, **kwargs))
  1797. @register_model
  1798. def resnetaa101d(pretrained: bool = False, **kwargs) -> ResNet:
  1799. """Constructs a ResNet-101-D model with avgpool anti-aliasing
  1800. """
  1801. model_args = dict(
  1802. block=Bottleneck, layers=(3, 4, 23, 3), aa_layer=nn.AvgPool2d,
  1803. stem_width=32, stem_type='deep', avg_down=True)
  1804. return _create_resnet('resnetaa101d', pretrained, **dict(model_args, **kwargs))
  1805. @register_model
  1806. def seresnetaa50d(pretrained: bool = False, **kwargs) -> ResNet:
  1807. """Constructs a SE=ResNet-50-D model with avgpool anti-aliasing
  1808. """
  1809. model_args = dict(
  1810. block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=nn.AvgPool2d,
  1811. stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'))
  1812. return _create_resnet('seresnetaa50d', pretrained, **dict(model_args, **kwargs))
  1813. @register_model
  1814. def seresnextaa101d_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
  1815. """Constructs a SE=ResNeXt-101-D 32x8d model with avgpool anti-aliasing
  1816. """
  1817. model_args = dict(
  1818. block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=8,
  1819. stem_width=32, stem_type='deep', avg_down=True, aa_layer=nn.AvgPool2d,
  1820. block_args=dict(attn_layer='se'))
  1821. return _create_resnet('seresnextaa101d_32x8d', pretrained, **dict(model_args, **kwargs))
  1822. @register_model
  1823. def seresnextaa201d_32x8d(pretrained: bool = False, **kwargs):
  1824. """Constructs a SE=ResNeXt-101-D 32x8d model with avgpool anti-aliasing
  1825. """
  1826. model_args = dict(
  1827. block=Bottleneck, layers=(3, 24, 36, 4), cardinality=32, base_width=8,
  1828. stem_width=64, stem_type='deep', avg_down=True, aa_layer=nn.AvgPool2d,
  1829. block_args=dict(attn_layer='se'))
  1830. return _create_resnet('seresnextaa201d_32x8d', pretrained, **dict(model_args, **kwargs))
  1831. @register_model
  1832. def resnetrs50(pretrained: bool = False, **kwargs) -> ResNet:
  1833. """Constructs a ResNet-RS-50 model.
  1834. Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
  1835. Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
  1836. """
  1837. attn_layer = partial(get_attn('se'), rd_ratio=0.25)
  1838. model_args = dict(
  1839. block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', replace_stem_pool=True,
  1840. avg_down=True, block_args=dict(attn_layer=attn_layer))
  1841. return _create_resnet('resnetrs50', pretrained, **dict(model_args, **kwargs))
  1842. @register_model
  1843. def resnetrs101(pretrained: bool = False, **kwargs) -> ResNet:
  1844. """Constructs a ResNet-RS-101 model.
  1845. Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
  1846. Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
  1847. """
  1848. attn_layer = partial(get_attn('se'), rd_ratio=0.25)
  1849. model_args = dict(
  1850. block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep', replace_stem_pool=True,
  1851. avg_down=True, block_args=dict(attn_layer=attn_layer))
  1852. return _create_resnet('resnetrs101', pretrained, **dict(model_args, **kwargs))
  1853. @register_model
  1854. def resnetrs152(pretrained: bool = False, **kwargs) -> ResNet:
  1855. """Constructs a ResNet-RS-152 model.
  1856. Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
  1857. Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
  1858. """
  1859. attn_layer = partial(get_attn('se'), rd_ratio=0.25)
  1860. model_args = dict(
  1861. block=Bottleneck, layers=(3, 8, 36, 3), stem_width=32, stem_type='deep', replace_stem_pool=True,
  1862. avg_down=True, block_args=dict(attn_layer=attn_layer))
  1863. return _create_resnet('resnetrs152', pretrained, **dict(model_args, **kwargs))
  1864. @register_model
  1865. def resnetrs200(pretrained: bool = False, **kwargs) -> ResNet:
  1866. """Constructs a ResNet-RS-200 model.
  1867. Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
  1868. Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
  1869. """
  1870. attn_layer = partial(get_attn('se'), rd_ratio=0.25)
  1871. model_args = dict(
  1872. block=Bottleneck, layers=(3, 24, 36, 3), stem_width=32, stem_type='deep', replace_stem_pool=True,
  1873. avg_down=True, block_args=dict(attn_layer=attn_layer))
  1874. return _create_resnet('resnetrs200', pretrained, **dict(model_args, **kwargs))
  1875. @register_model
  1876. def resnetrs270(pretrained: bool = False, **kwargs) -> ResNet:
  1877. """Constructs a ResNet-RS-270 model.
  1878. Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
  1879. Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
  1880. """
  1881. attn_layer = partial(get_attn('se'), rd_ratio=0.25)
  1882. model_args = dict(
  1883. block=Bottleneck, layers=(4, 29, 53, 4), stem_width=32, stem_type='deep', replace_stem_pool=True,
  1884. avg_down=True, block_args=dict(attn_layer=attn_layer))
  1885. return _create_resnet('resnetrs270', pretrained, **dict(model_args, **kwargs))
  1886. @register_model
  1887. def resnetrs350(pretrained: bool = False, **kwargs) -> ResNet:
  1888. """Constructs a ResNet-RS-350 model.
  1889. Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
  1890. Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
  1891. """
  1892. attn_layer = partial(get_attn('se'), rd_ratio=0.25)
  1893. model_args = dict(
  1894. block=Bottleneck, layers=(4, 36, 72, 4), stem_width=32, stem_type='deep', replace_stem_pool=True,
  1895. avg_down=True, block_args=dict(attn_layer=attn_layer))
  1896. return _create_resnet('resnetrs350', pretrained, **dict(model_args, **kwargs))
  1897. @register_model
  1898. def resnetrs420(pretrained: bool = False, **kwargs) -> ResNet:
  1899. """Constructs a ResNet-RS-420 model
  1900. Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
  1901. Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
  1902. """
  1903. attn_layer = partial(get_attn('se'), rd_ratio=0.25)
  1904. model_args = dict(
  1905. block=Bottleneck, layers=(4, 44, 87, 4), stem_width=32, stem_type='deep', replace_stem_pool=True,
  1906. avg_down=True, block_args=dict(attn_layer=attn_layer))
  1907. return _create_resnet('resnetrs420', pretrained, **dict(model_args, **kwargs))
  1908. @register_model
  1909. def test_resnet(pretrained: bool = False, **kwargs) -> ResNet:
  1910. """Constructs a tiny ResNet test model.
  1911. """
  1912. model_args = dict(
  1913. block=[BasicBlock, BasicBlock, Bottleneck, BasicBlock], layers=(1, 1, 1, 1),
  1914. stem_width=16, stem_type='deep', avg_down=True, channels=(32, 48, 48, 96))
  1915. return _create_resnet('test_resnet', pretrained, **dict(model_args, **kwargs))
  1916. register_model_deprecations(__name__, {
  1917. 'tv_resnet34': 'resnet34.tv_in1k',
  1918. 'tv_resnet50': 'resnet50.tv_in1k',
  1919. 'tv_resnet101': 'resnet101.tv_in1k',
  1920. 'tv_resnet152': 'resnet152.tv_in1k',
  1921. 'tv_resnext50_32x4d' : 'resnext50_32x4d.tv_in1k',
  1922. 'ig_resnext101_32x8d': 'resnext101_32x8d.fb_wsl_ig1b_ft_in1k',
  1923. 'ig_resnext101_32x16d': 'resnext101_32x8d.fb_wsl_ig1b_ft_in1k',
  1924. 'ig_resnext101_32x32d': 'resnext101_32x8d.fb_wsl_ig1b_ft_in1k',
  1925. 'ig_resnext101_32x48d': 'resnext101_32x8d.fb_wsl_ig1b_ft_in1k',
  1926. 'ssl_resnet18': 'resnet18.fb_ssl_yfcc100m_ft_in1k',
  1927. 'ssl_resnet50': 'resnet50.fb_ssl_yfcc100m_ft_in1k',
  1928. 'ssl_resnext50_32x4d': 'resnext50_32x4d.fb_ssl_yfcc100m_ft_in1k',
  1929. 'ssl_resnext101_32x4d': 'resnext101_32x4d.fb_ssl_yfcc100m_ft_in1k',
  1930. 'ssl_resnext101_32x8d': 'resnext101_32x8d.fb_ssl_yfcc100m_ft_in1k',
  1931. 'ssl_resnext101_32x16d': 'resnext101_32x16d.fb_ssl_yfcc100m_ft_in1k',
  1932. 'swsl_resnet18': 'resnet18.fb_swsl_ig1b_ft_in1k',
  1933. 'swsl_resnet50': 'resnet50.fb_swsl_ig1b_ft_in1k',
  1934. 'swsl_resnext50_32x4d': 'resnext50_32x4d.fb_swsl_ig1b_ft_in1k',
  1935. 'swsl_resnext101_32x4d': 'resnext101_32x4d.fb_swsl_ig1b_ft_in1k',
  1936. 'swsl_resnext101_32x8d': 'resnext101_32x8d.fb_swsl_ig1b_ft_in1k',
  1937. 'swsl_resnext101_32x16d': 'resnext101_32x16d.fb_swsl_ig1b_ft_in1k',
  1938. 'gluon_resnet18_v1b': 'resnet18.gluon_in1k',
  1939. 'gluon_resnet34_v1b': 'resnet34.gluon_in1k',
  1940. 'gluon_resnet50_v1b': 'resnet50.gluon_in1k',
  1941. 'gluon_resnet101_v1b': 'resnet101.gluon_in1k',
  1942. 'gluon_resnet152_v1b': 'resnet152.gluon_in1k',
  1943. 'gluon_resnet50_v1c': 'resnet50c.gluon_in1k',
  1944. 'gluon_resnet101_v1c': 'resnet101c.gluon_in1k',
  1945. 'gluon_resnet152_v1c': 'resnet152c.gluon_in1k',
  1946. 'gluon_resnet50_v1d': 'resnet50d.gluon_in1k',
  1947. 'gluon_resnet101_v1d': 'resnet101d.gluon_in1k',
  1948. 'gluon_resnet152_v1d': 'resnet152d.gluon_in1k',
  1949. 'gluon_resnet50_v1s': 'resnet50s.gluon_in1k',
  1950. 'gluon_resnet101_v1s': 'resnet101s.gluon_in1k',
  1951. 'gluon_resnet152_v1s': 'resnet152s.gluon_in1k',
  1952. 'gluon_resnext50_32x4d': 'resnext50_32x4d.gluon_in1k',
  1953. 'gluon_resnext101_32x4d': 'resnext101_32x4d.gluon_in1k',
  1954. 'gluon_resnext101_64x4d': 'resnext101_64x4d.gluon_in1k',
  1955. 'gluon_seresnext50_32x4d': 'seresnext50_32x4d.gluon_in1k',
  1956. 'gluon_seresnext101_32x4d': 'seresnext101_32x4d.gluon_in1k',
  1957. 'gluon_seresnext101_64x4d': 'seresnext101_64x4d.gluon_in1k',
  1958. 'gluon_senet154': 'senet154.gluon_in1k',
  1959. 'seresnext26tn_32x4d': 'seresnext26t_32x4d',
  1960. })