| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208 |
- """PyTorch CspNet
- A PyTorch implementation of Cross Stage Partial Networks including:
- * CSPResNet50
- * CSPResNeXt50
- * CSPDarkNet53
- * and DarkNet53 for good measure
- Based on paper `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929
- Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStagePartialNetworks
- Hacked together by / Copyright 2020 Ross Wightman
- """
- from dataclasses import dataclass, asdict, replace
- from functools import partial
- from typing import Any, Dict, List, Optional, Tuple, Type, Union
- import torch
- import torch.nn as nn
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.layers import ClassifierHead, ConvNormAct, DropPath, calculate_drop_path_rates, get_attn, create_act_layer, make_divisible
- from ._builder import build_model_with_cfg
- from ._manipulate import named_apply, MATCH_PREV_GROUP
- from ._registry import register_model, generate_default_cfgs
- __all__ = ['CspNet'] # model_registry will add each entrypoint fn to this
@dataclass
class CspStemCfg:
    """Configuration of the network stem.

    The fields are expanded as keyword arguments into ``create_csp_stem``.
    """
    # output channels; a tuple creates one stem conv per entry
    out_chs: Union[int, Tuple[int, ...]] = 32
    # overall stride of the stem (the stem builder accepts 1, 2 or 4)
    stride: Union[int, Tuple[int, ...]] = 2
    # kernel size shared by the stem convs
    kernel_size: int = 3
    # padding for the first stem conv only, later stem convs use ''
    padding: Union[int, str] = ''
    # any non-empty value appends a pooling layer to the stem, '' disables it
    pool: Optional[str] = ''
- def _pad_arg(x, n):
- # pads an argument tuple to specified n by padding with last value
- if not isinstance(x, (tuple, list)):
- x = (x,)
- curr_n = len(x)
- pad_n = n - curr_n
- if pad_n <= 0:
- return x[:n]
- return tuple(x + (x[-1],) * pad_n)
@dataclass
class CspStagesCfg:
    """Per-stage configuration.

    ``depth`` and ``out_chs`` must have one entry per stage. Every other field may
    be given as a single value or a shorter tuple and is expanded to one entry per
    stage in ``__post_init__`` (the last value is repeated).
    """
    depth: Tuple[int, ...] = (3, 3, 5, 2)  # block depth (number of block repeats in stages)
    out_chs: Tuple[int, ...] = (128, 256, 512, 1024)  # number of output channels for blocks in stage
    stride: Union[int, Tuple[int, ...]] = 2  # stride of stage
    groups: Union[int, Tuple[int, ...]] = 1  # num kxk conv groups
    block_ratio: Union[float, Tuple[float, ...]] = 1.0  # block output channels as a ratio of stage out_chs
    bottle_ratio: Union[float, Tuple[float, ...]] = 1.  # bottleneck-ratio of blocks in stage
    avg_down: Union[bool, Tuple[bool, ...]] = False  # downsample with avg-pool + 1x1 conv instead of 3x3 conv
    attn_layer: Optional[Union[str, Tuple[str, ...]]] = None  # attention layer name, resolved via get_attn
    attn_kwargs: Optional[Union[Dict, Tuple[Dict]]] = None  # extra kwargs bound to the attention layer
    stage_type: Union[str, Tuple[str]] = 'csp'  # stage type ('csp', 'cs3', 'dark')
    block_type: Union[str, Tuple[str]] = 'bottle'  # blocks type for stages ('bottle', 'dark', 'edge')

    # cross-stage only
    expand_ratio: Union[float, Tuple[float, ...]] = 1.0  # expansion conv channels as a ratio of stage out_chs
    cross_linear: Union[bool, Tuple[bool, ...]] = False  # no activation on the expansion conv when True
    down_growth: Union[bool, Tuple[bool, ...]] = False  # grow downsample channels to output channels

    def __post_init__(self):
        """Expand every per-stage field to exactly ``len(depth)`` entries."""
        n = len(self.depth)
        assert len(self.out_chs) == n
        per_stage_fields = (
            'stride',
            'groups',
            'block_ratio',
            'bottle_ratio',
            'avg_down',
            'attn_layer',
            'attn_kwargs',
            'stage_type',
            'block_type',
            'expand_ratio',
            'cross_linear',
            'down_growth',
        )
        for field_name in per_stage_fields:
            setattr(self, field_name, _pad_arg(getattr(self, field_name), n))
@dataclass
class CspModelCfg:
    """Top-level model configuration: stem, stages and the shared layer choices."""
    stem: CspStemCfg
    stages: CspStagesCfg
    # zero init last weight (usually bn) in residual path
    zero_init_last: bool = True
    # activation layer name, shared by stem and stages
    act_layer: str = 'leaky_relu'
    # normalization layer name, shared by stem and stages
    norm_layer: str = 'batchnorm'
    # FIXME support string factory for this
    aa_layer: Optional[str] = None
def _cs3_cfg(
    width_multiplier=1.0,
    depth_multiplier=1.0,
    avg_down=False,
    act_layer='silu',
    focus=False,
    attn_layer=None,
    attn_kwargs=None,
    bottle_ratio=1.0,
    block_type='dark',
):
    """Build a ``CspModelCfg`` with 'cs3' stages, scaled in width and depth.

    Args:
        width_multiplier: Scale applied to the base channel counts (result passed
            through ``make_divisible``).
        depth_multiplier: Scale applied to the base stage depths (3, 6, 9, 3);
            the product is truncated with ``int``.
        avg_down: Forwarded to the stages config.
        act_layer: Activation layer name for the model.
        focus: Use a single 6x6 stride-2 stem conv instead of two 3x3 stem convs.
        attn_layer: Forwarded to the stages config.
        attn_kwargs: Forwarded to the stages config.
        bottle_ratio: Forwarded to the stages config.
        block_type: Forwarded to the stages config.

    Returns:
        The assembled ``CspModelCfg``.
    """
    def _scale_chs(base_chs):
        return make_divisible(base_chs * width_multiplier)

    if focus:
        stem_cfg = CspStemCfg(out_chs=_scale_chs(64), kernel_size=6, stride=2, padding=2, pool='')
    else:
        stem_cfg = CspStemCfg(
            out_chs=tuple(_scale_chs(c) for c in (32, 64)),
            kernel_size=3,
            stride=2,
            pool='',
        )

    stages_cfg = CspStagesCfg(
        out_chs=tuple(_scale_chs(c) for c in (128, 256, 512, 1024)),
        depth=tuple(int(d * depth_multiplier) for d in (3, 6, 9, 3)),
        stride=2,
        bottle_ratio=bottle_ratio,
        block_ratio=0.5,
        avg_down=avg_down,
        attn_layer=attn_layer,
        attn_kwargs=attn_kwargs,
        stage_type='cs3',
        block_type=block_type,
    )
    return CspModelCfg(stem=stem_cfg, stages=stages_cfg, act_layer=act_layer)
class BottleneckBlock(nn.Module):
    """ResNe(X)t Bottleneck Block.

    Runs a 1x1 conv, a 3x3 conv and a non-activated 1x1 conv, adds the input as a
    shortcut and applies a final activation. An optional attention layer sits
    after the 3x3 conv, or after the last 1x1 conv when ``attn_last`` is set.
    """

    def __init__(
        self,
        in_chs: int,
        out_chs: int,
        dilation: int = 1,
        bottle_ratio: float = 0.25,
        groups: int = 1,
        act_layer: Type[nn.Module] = nn.ReLU,
        norm_layer: Type[nn.Module] = nn.BatchNorm2d,
        attn_last: bool = False,
        attn_layer: Optional[Type[nn.Module]] = None,
        drop_block: Optional[Type[nn.Module]] = None,
        drop_path: float = 0.,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        conv_kwargs = {'act_layer': act_layer, 'norm_layer': norm_layer}
        mid_chs = int(round(out_chs * bottle_ratio))

        # attention is placed in exactly one of the two slots (or neither)
        has_attn = attn_layer is not None
        attn_after_conv3 = has_attn and attn_last
        attn_after_conv2 = has_attn and not attn_last

        self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **conv_kwargs, **factory_kwargs)
        self.conv2 = ConvNormAct(
            mid_chs,
            mid_chs,
            kernel_size=3,
            dilation=dilation,
            groups=groups,
            drop_layer=drop_block,
            **conv_kwargs,
            **factory_kwargs,
        )
        if attn_after_conv2:
            self.attn2 = attn_layer(mid_chs, act_layer=act_layer, **factory_kwargs)
        else:
            self.attn2 = nn.Identity()
        self.conv3 = ConvNormAct(
            mid_chs, out_chs, kernel_size=1, apply_act=False, **conv_kwargs, **factory_kwargs)
        if attn_after_conv3:
            self.attn3 = attn_layer(out_chs, act_layer=act_layer, **factory_kwargs)
        else:
            self.attn3 = nn.Identity()
        if drop_path:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.act3 = create_act_layer(act_layer)

    def zero_init_last(self):
        """Zero the norm weight of the last conv in the residual path."""
        nn.init.zeros_(self.conv3.bn.weight)

    def forward(self, x):
        residual = x
        x = self.attn2(self.conv2(self.conv1(x)))
        x = self.attn3(self.conv3(x))
        # FIXME a partial shortcut (adding only to the first shortcut.size(1) channels) would be
        # needed if the first block were handled as per the original; not used by this impl.
        x = self.drop_path(x) + residual
        return self.act3(x)
class DarkBlock(nn.Module):
    """DarkNet Block.

    A 1x1 conv followed by an optional attention layer and a 3x3 conv, with the
    block input added back as a shortcut.
    """

    def __init__(
        self,
        in_chs: int,
        out_chs: int,
        dilation: int = 1,
        bottle_ratio: float = 0.5,
        groups: int = 1,
        act_layer: Type[nn.Module] = nn.ReLU,
        norm_layer: Type[nn.Module] = nn.BatchNorm2d,
        attn_layer: Optional[Type[nn.Module]] = None,
        drop_block: Optional[Type[nn.Module]] = None,
        drop_path: float = 0.,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        conv_kwargs = {'act_layer': act_layer, 'norm_layer': norm_layer}
        mid_chs = int(round(out_chs * bottle_ratio))

        self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **conv_kwargs, **factory_kwargs)
        if attn_layer is None:
            self.attn = nn.Identity()
        else:
            self.attn = attn_layer(mid_chs, act_layer=act_layer, **factory_kwargs)
        self.conv2 = ConvNormAct(
            mid_chs,
            out_chs,
            kernel_size=3,
            dilation=dilation,
            groups=groups,
            drop_layer=drop_block,
            **conv_kwargs,
            **factory_kwargs,
        )
        if drop_path:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def zero_init_last(self):
        """Zero the norm weight of the last conv in the residual path."""
        nn.init.zeros_(self.conv2.bn.weight)

    def forward(self, x):
        residual = x
        x = self.conv2(self.attn(self.conv1(x)))
        return self.drop_path(x) + residual
class EdgeBlock(nn.Module):
    """EdgeResidual / Fused-MBConv / MobileNetV1-like 3x3 + 1x1 block (w/ activated output).

    A 3x3 conv followed by an optional attention layer and a 1x1 conv, with the
    block input added back as a shortcut.
    """

    def __init__(
        self,
        in_chs: int,
        out_chs: int,
        dilation: int = 1,
        bottle_ratio: float = 0.5,
        groups: int = 1,
        act_layer: Type[nn.Module] = nn.ReLU,
        norm_layer: Type[nn.Module] = nn.BatchNorm2d,
        attn_layer: Optional[Type[nn.Module]] = None,
        drop_block: Optional[Type[nn.Module]] = None,
        drop_path: float = 0.,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        conv_kwargs = {'act_layer': act_layer, 'norm_layer': norm_layer}
        mid_chs = int(round(out_chs * bottle_ratio))

        self.conv1 = ConvNormAct(
            in_chs,
            mid_chs,
            kernel_size=3,
            dilation=dilation,
            groups=groups,
            drop_layer=drop_block,
            **conv_kwargs,
            **factory_kwargs,
        )
        if attn_layer is None:
            self.attn = nn.Identity()
        else:
            self.attn = attn_layer(mid_chs, act_layer=act_layer, **factory_kwargs)
        self.conv2 = ConvNormAct(mid_chs, out_chs, kernel_size=1, **conv_kwargs, **factory_kwargs)
        if drop_path:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def zero_init_last(self):
        """Zero the norm weight of the last conv in the residual path."""
        nn.init.zeros_(self.conv2.bn.weight)

    def forward(self, x):
        residual = x
        x = self.conv2(self.attn(self.conv1(x)))
        return self.drop_path(x) + residual
class CrossStage(nn.Module):
    """Cross Stage.

    The input is optionally downsampled, expanded by a 1x1 conv and split in two
    along the channel dim. The second half runs through the block sequence and a
    1x1 transition conv while the first half bypasses them; both halves are then
    concatenated and fused by a final 1x1 transition conv.
    """

    def __init__(
        self,
        in_chs: int,
        out_chs: int,
        stride: int,
        dilation: int,
        depth: int,
        block_ratio: float = 1.,
        bottle_ratio: float = 1.,
        expand_ratio: float = 1.,
        groups: int = 1,
        first_dilation: Optional[int] = None,
        avg_down: bool = False,
        down_growth: bool = False,
        cross_linear: bool = False,
        block_dpr: Optional[List[float]] = None,
        block_fn: Type[nn.Module] = BottleneckBlock,
        device=None,
        dtype=None,
        **block_kwargs,
    ):
        """
        Args:
            in_chs: Number of input channels.
            out_chs: Number of output channels of the stage.
            stride: Stride of the downsample conv.
            dilation: Dilation used by the blocks.
            depth: Number of blocks in the stage.
            block_ratio: Block output channels as a ratio of ``out_chs``.
            bottle_ratio: Bottleneck ratio passed to each block.
            expand_ratio: Expansion conv channels as a ratio of ``out_chs``.
            groups: Conv groups for the downsample conv and the blocks.
            first_dilation: Dilation of the downsample conv (defaults to ``dilation``).
            avg_down: Downsample with avg-pool + 1x1 conv instead of a strided 3x3 conv.
            down_growth: Grow downsample channels to output channels.
            cross_linear: Disable the activation on the expansion conv.
            block_dpr: Per-block drop-path rates, or None for no drop-path.
            block_fn: Block class used to build the stage.
            block_kwargs: Extra kwargs passed to every block ('aa_layer' is consumed here).
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        first_dilation = first_dilation or dilation
        down_chs = out_chs if down_growth else in_chs  # grow downsample channels to output channels
        self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
        block_out_chs = int(round(out_chs * block_ratio))
        conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
        # aa_layer is only used by the downsample conv, it must not reach the blocks
        aa_layer = block_kwargs.pop('aa_layer', None)

        if stride != 1 or first_dilation != dilation:
            if avg_down:
                # The 1x1 conv must emit `down_chs` channels, as that is the input width
                # `conv_exp` is built with below (same as the strided conv path).
                self.conv_down = nn.Sequential(
                    nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
                    ConvNormAct(in_chs, down_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs, **dd)
                )
            else:
                self.conv_down = ConvNormAct(
                    in_chs,
                    down_chs,
                    kernel_size=3,
                    stride=stride,
                    dilation=first_dilation,
                    groups=groups,
                    aa_layer=aa_layer,
                    **conv_kwargs,
                    **dd,
                )
            prev_chs = down_chs
        else:
            self.conv_down = nn.Identity()
            prev_chs = in_chs

        # FIXME this 1x1 expansion is pushed down into the cross and block paths in the darknet cfgs. Also,
        # there is also special case for the first stage for some of the model that results in uneven split
        # across the two paths. I did it this way for simplicity for now.
        self.conv_exp = ConvNormAct(
            prev_chs,
            exp_chs,
            kernel_size=1,
            apply_act=not cross_linear,
            **conv_kwargs,
            **dd,
        )
        prev_chs = exp_chs // 2  # output of conv_exp is always split in two

        self.blocks = nn.Sequential()
        for i in range(depth):
            self.blocks.add_module(str(i), block_fn(
                in_chs=prev_chs,
                out_chs=block_out_chs,
                dilation=dilation,
                bottle_ratio=bottle_ratio,
                groups=groups,
                drop_path=block_dpr[i] if block_dpr is not None else 0.,
                **block_kwargs,
                **dd,
            ))
            prev_chs = block_out_chs

        # transition convs
        self.conv_transition_b = ConvNormAct(prev_chs, exp_chs // 2, kernel_size=1, **conv_kwargs, **dd)
        self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs, **dd)

    def forward(self, x):
        x = self.conv_down(x)
        x = self.conv_exp(x)
        # xs bypasses the blocks, xb goes through them
        xs, xb = x.split(self.expand_chs // 2, dim=1)
        xb = self.blocks(xb)
        xb = self.conv_transition_b(xb).contiguous()
        out = self.conv_transition(torch.cat([xs, xb], dim=1))
        return out
class CrossStage3(nn.Module):
    """Cross Stage 3.

    Similar to CrossStage, but with only one transition conv for the output.

    The input is optionally downsampled, expanded by a 1x1 conv and split in two
    along the channel dim. The first half runs through the block sequence, the
    second half bypasses it; both are concatenated and fused by a 1x1 conv.
    """

    def __init__(
        self,
        in_chs: int,
        out_chs: int,
        stride: int,
        dilation: int,
        depth: int,
        block_ratio: float = 1.,
        bottle_ratio: float = 1.,
        expand_ratio: float = 1.,
        groups: int = 1,
        first_dilation: Optional[int] = None,
        avg_down: bool = False,
        down_growth: bool = False,
        cross_linear: bool = False,
        block_dpr: Optional[List[float]] = None,
        block_fn: Type[nn.Module] = BottleneckBlock,
        device=None,
        dtype=None,
        **block_kwargs,
    ):
        """
        Args:
            in_chs: Number of input channels.
            out_chs: Number of output channels of the stage.
            stride: Stride of the downsample conv.
            dilation: Dilation used by the blocks.
            depth: Number of blocks in the stage.
            block_ratio: Block output channels as a ratio of ``out_chs``.
            bottle_ratio: Bottleneck ratio passed to each block.
            expand_ratio: Expansion conv channels as a ratio of ``out_chs``.
            groups: Conv groups for the downsample conv and the blocks.
            first_dilation: Dilation of the downsample conv (defaults to ``dilation``).
            avg_down: Downsample with avg-pool + 1x1 conv instead of a strided 3x3 conv.
            down_growth: Grow downsample channels to output channels.
            cross_linear: Disable the activation on the expansion conv.
            block_dpr: Per-block drop-path rates, or None for no drop-path.
            block_fn: Block class used to build the stage.
            block_kwargs: Extra kwargs passed to every block ('aa_layer' is consumed here).
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        first_dilation = first_dilation or dilation
        down_chs = out_chs if down_growth else in_chs  # grow downsample channels to output channels
        self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
        block_out_chs = int(round(out_chs * block_ratio))
        conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
        # aa_layer is only used by the downsample conv, it must not reach the blocks
        aa_layer = block_kwargs.pop('aa_layer', None)

        if stride != 1 or first_dilation != dilation:
            if avg_down:
                # The 1x1 conv must emit `down_chs` channels, as that is the input width
                # `conv_exp` is built with below (same as the strided conv path).
                self.conv_down = nn.Sequential(
                    nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
                    ConvNormAct(in_chs, down_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs, **dd)
                )
            else:
                self.conv_down = ConvNormAct(
                    in_chs,
                    down_chs,
                    kernel_size=3,
                    stride=stride,
                    dilation=first_dilation,
                    groups=groups,
                    aa_layer=aa_layer,
                    **conv_kwargs,
                    **dd,
                )
            prev_chs = down_chs
        else:
            # forward() calls conv_down unconditionally, so the no-downsample case
            # needs a callable pass-through rather than None
            self.conv_down = nn.Identity()
            prev_chs = in_chs

        # expansion conv
        self.conv_exp = ConvNormAct(
            prev_chs,
            exp_chs,
            kernel_size=1,
            apply_act=not cross_linear,
            **conv_kwargs,
            **dd,
        )
        prev_chs = exp_chs // 2  # expanded output is split in 2 for blocks and cross stage

        self.blocks = nn.Sequential()
        for i in range(depth):
            self.blocks.add_module(str(i), block_fn(
                in_chs=prev_chs,
                out_chs=block_out_chs,
                dilation=dilation,
                bottle_ratio=bottle_ratio,
                groups=groups,
                drop_path=block_dpr[i] if block_dpr is not None else 0.,
                **block_kwargs,
                **dd,
            ))
            prev_chs = block_out_chs

        # transition convs
        self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs, **dd)

    def forward(self, x):
        x = self.conv_down(x)
        x = self.conv_exp(x)
        # x1 goes through the blocks, x2 bypasses them
        x1, x2 = x.split(self.expand_chs // 2, dim=1)
        x1 = self.blocks(x1)
        out = self.conv_transition(torch.cat([x1, x2], dim=1))
        return out
class DarkStage(nn.Module):
    """DarkNet stage: a downsample conv followed by a sequence of blocks."""

    def __init__(
        self,
        in_chs: int,
        out_chs: int,
        stride: int,
        dilation: int,
        depth: int,
        block_ratio: float = 1.,
        bottle_ratio: float = 1.,
        groups: int = 1,
        first_dilation: Optional[int] = None,
        avg_down: bool = False,
        block_fn: Type[nn.Module] = BottleneckBlock,
        block_dpr: Optional[List[float]] = None,
        device=None,
        dtype=None,
        **block_kwargs,
    ):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        if not first_dilation:
            first_dilation = dilation
        # aa_layer is only used by the downsample conv, it must not reach the blocks
        aa_layer = block_kwargs.pop('aa_layer', None)
        conv_kwargs = {
            'act_layer': block_kwargs.get('act_layer'),
            'norm_layer': block_kwargs.get('norm_layer'),
        }

        if avg_down:
            pool = nn.AvgPool2d(2) if stride == 2 else nn.Identity()  # FIXME dilation handling
            self.conv_down = nn.Sequential(
                pool,
                ConvNormAct(
                    in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs, **factory_kwargs),
            )
        else:
            self.conv_down = ConvNormAct(
                in_chs,
                out_chs,
                kernel_size=3,
                stride=stride,
                dilation=first_dilation,
                groups=groups,
                aa_layer=aa_layer,
                **conv_kwargs,
                **factory_kwargs,
            )

        block_out_chs = int(round(out_chs * block_ratio))
        self.blocks = nn.Sequential()
        block_in_chs = out_chs  # the first block consumes the downsample conv output
        for block_idx in range(depth):
            block_drop_path = 0. if block_dpr is None else block_dpr[block_idx]
            block = block_fn(
                in_chs=block_in_chs,
                out_chs=block_out_chs,
                dilation=dilation,
                bottle_ratio=bottle_ratio,
                groups=groups,
                drop_path=block_drop_path,
                **block_kwargs,
                **factory_kwargs,
            )
            self.blocks.add_module(str(block_idx), block)
            block_in_chs = block_out_chs

    def forward(self, x):
        return self.blocks(self.conv_down(x))
def create_csp_stem(
    in_chans: int = 3,
    out_chs: int = 32,
    kernel_size: int = 3,
    stride: int = 2,
    pool: str = '',
    padding: str = '',
    act_layer: Type[nn.Module] = nn.ReLU,
    norm_layer: Type[nn.Module] = nn.BatchNorm2d,
    aa_layer: Optional[Type[nn.Module]] = None,
    device=None,
    dtype=None,
):
    """Build the stem as a sequence of convs with an optional trailing pool.

    Args:
        in_chans: Number of input channels.
        out_chs: Output channels; a tuple/list creates one conv per entry.
        kernel_size: Kernel size of every stem conv.
        stride: Overall stem stride, one of (1, 2, 4).
        pool: Any non-empty value appends a stride-2 max-pool (requires stride > 2).
        padding: Padding of the first conv; later convs use ''.
        act_layer: Activation layer for the convs.
        norm_layer: Normalization layer for the convs.
        aa_layer: If set together with ``pool``, the max-pool runs at stride 1 and
            this layer performs the stride-2 step.

    Returns:
        Tuple of (stem module, list of feature-info dicts with 'num_chs',
        'reduction' and 'module' keys). The last entry describes the stem output.
    """
    factory_kwargs = {'device': device, 'dtype': dtype}
    if not isinstance(out_chs, (tuple, list)):
        out_chs = [out_chs]
    num_convs = len(out_chs)
    assert num_convs
    assert stride in (1, 2, 4)

    stem = nn.Sequential()
    feature_info = []
    last_feat = None
    cur_chs = in_chans
    cur_stride = 1
    for idx, chs in enumerate(out_chs):
        name = f'conv{idx + 1}'
        is_first = idx == 0
        is_last = idx == num_convs - 1
        # the first conv takes the first stride-2 step, the last conv takes the second
        # one unless a pooling layer does it
        if (is_first and stride > 1) or (is_last and stride > 2 and not pool):
            conv_stride = 2
        else:
            conv_stride = 1
        if conv_stride > 1 and last_feat is not None:
            # record the feature at the resolution about to be reduced
            feature_info.append(last_feat)
        conv = ConvNormAct(
            cur_chs, chs, kernel_size,
            stride=conv_stride,
            padding=padding if is_first else '',
            act_layer=act_layer,
            norm_layer=norm_layer,
            **factory_kwargs,
        )
        stem.add_module(name, conv)
        cur_stride *= conv_stride
        cur_chs = chs
        last_feat = dict(num_chs=cur_chs, reduction=cur_stride, module=f'stem.{name}')

    if pool:
        assert stride > 2
        if last_feat is not None:
            feature_info.append(last_feat)
        if aa_layer is None:
            stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
            pool_name = 'pool'
        else:
            stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
            stem.add_module('aa', aa_layer(channels=cur_chs, stride=2, **factory_kwargs))
            pool_name = 'aa'
        cur_stride *= 2
        last_feat = dict(num_chs=cur_chs, reduction=cur_stride, module=f'stem.{pool_name}')

    feature_info.append(last_feat)
    return stem, feature_info
- def _get_stage_fn(stage_args):
- stage_type = stage_args.pop('stage_type')
- assert stage_type in ('dark', 'csp', 'cs3')
- if stage_type == 'dark':
- stage_args.pop('expand_ratio', None)
- stage_args.pop('cross_linear', None)
- stage_args.pop('down_growth', None)
- stage_fn = DarkStage
- elif stage_type == 'csp':
- stage_fn = CrossStage
- else:
- stage_fn = CrossStage3
- return stage_fn, stage_args
- def _get_block_fn(stage_args):
- block_type = stage_args.pop('block_type')
- assert block_type in ('dark', 'edge', 'bottle')
- if block_type == 'dark':
- return DarkBlock, stage_args
- elif block_type == 'edge':
- return EdgeBlock, stage_args
- else:
- return BottleneckBlock, stage_args
- def _get_attn_fn(stage_args):
- attn_layer = stage_args.pop('attn_layer')
- attn_kwargs = stage_args.pop('attn_kwargs', None) or {}
- if attn_layer is not None:
- attn_layer = get_attn(attn_layer)
- if attn_kwargs:
- attn_layer = partial(attn_layer, **attn_kwargs)
- return attn_layer, stage_args
def create_csp_stages(
    cfg: CspModelCfg,
    drop_path_rate: float,
    output_stride: int,
    stem_feat: Dict[str, Any],
    device=None,
    dtype=None,
):
    """Build all stages described by ``cfg.stages``.

    Args:
        cfg: Model configuration.
        drop_path_rate: Stochastic depth rate; falsy disables drop-path.
        output_stride: Once the running stride reaches this value, further stage
            strides are converted to dilation.
        stem_feat: Feature-info dict of the stem output ('num_chs', 'reduction', ...).

    Returns:
        Tuple of (``nn.Sequential`` of stages, list of feature-info dicts).
    """
    factory_kwargs = {'device': device, 'dtype': dtype}
    stages_cfg = asdict(cfg.stages)
    num_stages = len(cfg.stages.depth)
    if drop_path_rate:
        stages_cfg['block_dpr'] = calculate_drop_path_rates(drop_path_rate, cfg.stages.depth, stagewise=True)
    else:
        stages_cfg['block_dpr'] = [None] * num_stages
    # transpose {field: per-stage values} into one kwargs dict per stage
    field_names = list(stages_cfg.keys())
    per_stage_args = [dict(zip(field_names, values)) for values in zip(*stages_cfg.values())]

    layer_kwargs = {'act_layer': cfg.act_layer, 'norm_layer': cfg.norm_layer}
    net_stride = stem_feat['reduction']
    prev_chs = stem_feat['num_chs']
    prev_feat = stem_feat
    dilation = 1
    stages = []
    feature_info = []
    for stage_idx, args in enumerate(per_stage_args):
        stage_fn, args = _get_stage_fn(args)
        block_fn, args = _get_block_fn(args)
        attn_fn, args = _get_attn_fn(args)
        stride = args.pop('stride')
        if stride != 1 and prev_feat:
            # resolution is about to change, record the last feature at the current one
            feature_info.append(prev_feat)
        if net_stride >= output_stride and stride > 1:
            # target output stride reached, trade stride for dilation
            dilation *= stride
            stride = 1
        net_stride *= stride
        first_dilation = 1 if dilation in (1, 2) else 2

        stage = stage_fn(
            prev_chs,
            **args,
            stride=stride,
            first_dilation=first_dilation,
            dilation=dilation,
            block_fn=block_fn,
            aa_layer=cfg.aa_layer,
            attn_layer=attn_fn,  # will be passed through stage as block_kwargs
            **layer_kwargs,
            **factory_kwargs,
        )
        stages.append(stage)
        prev_chs = args['out_chs']
        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')

    feature_info.append(prev_feat)
    return nn.Sequential(*stages), feature_info
class CspNet(nn.Module):
    """Cross Stage Partial base model.

    Paper: `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929
    Ref Impl: https://github.com/WongKinYiu/CrossStagePartialNetworks

    NOTE: There are differences in the way I handle the 1x1 'expansion' conv in this impl vs the
    darknet impl. I did it this way for simplicity and less special cases.
    """

    def __init__(
        self,
        cfg: CspModelCfg,
        in_chans: int = 3,
        num_classes: int = 1000,
        output_stride: int = 32,
        global_pool: str = 'avg',
        drop_rate: float = 0.,
        drop_path_rate: float = 0.,
        zero_init_last: bool = True,
        device=None,
        dtype=None,
        **kwargs,
    ):
        """
        Args:
            cfg (CspModelCfg): Model architecture configuration
            in_chans (int): Number of input channels (default: 3)
            num_classes (int): Number of classifier classes (default: 1000)
            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
            global_pool (str): Global pooling type (default: 'avg')
            drop_rate (float): Dropout rate (default: 0.)
            drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
            zero_init_last (bool): Zero-init last weight of residual path
            kwargs (dict): Extra kwargs overlayed onto cfg
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.drop_rate = drop_rate
        assert output_stride in (8, 16, 32)

        cfg = replace(cfg, **kwargs)  # overlay kwargs onto cfg
        stem_layer_args = {
            'act_layer': cfg.act_layer,
            'norm_layer': cfg.norm_layer,
            'aa_layer': cfg.aa_layer,
        }
        self.feature_info = []

        # Stem; its last feature entry is handed to the stage builder instead of being recorded here
        self.stem, stem_feats = create_csp_stem(
            in_chans, **asdict(cfg.stem), **stem_layer_args, **factory_kwargs)
        self.feature_info.extend(stem_feats[:-1])

        # Stages
        self.stages, stage_feats = create_csp_stages(
            cfg,
            drop_path_rate=drop_path_rate,
            output_stride=output_stride,
            stem_feat=stem_feats[-1],
            **factory_kwargs,
        )
        self.feature_info.extend(stage_feats)
        final_chs = stage_feats[-1]['num_chs']

        # Head
        self.num_features = self.head_hidden_size = final_chs
        self.head = ClassifierHead(
            in_features=final_chs,
            num_classes=num_classes,
            pool_type=global_pool,
            drop_rate=drop_rate,
            **factory_kwargs,
        )

        named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return the regex spec that groups parameters into stem and stage/block groups."""
        if coarse:
            blocks = r'^stages\.(\d+)'
        else:
            blocks = [
                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
                (r'^stages\.(\d+)\..*transition', MATCH_PREV_GROUP),  # map to last block in stage
                (r'^stages\.(\d+)', (0,)),
            ]
        return dict(stem=r'^stem', blocks=blocks)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Gradient checkpointing is not supported; only ``enable=False`` is accepted."""
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        """Return the classifier layer of the head."""
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        """Reset the head for a new number of classes and, optionally, pooling type."""
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_features(self, x):
        """Run the stem and all stages."""
        return self.stages(self.stem(x))

    def forward_head(self, x, pre_logits: bool = False):
        """Run the head; ``pre_logits`` is only forwarded to the head when set."""
        if pre_logits:
            return self.head(x, pre_logits=pre_logits)
        return self.head(x)

    def forward(self, x):
        features = self.forward_features(x)
        return self.forward_head(features)
- def _init_weights(module, name, zero_init_last=False):
- if isinstance(module, nn.Conv2d):
- nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
- if module.bias is not None:
- nn.init.zeros_(module.bias)
- elif isinstance(module, nn.Linear):
- nn.init.normal_(module.weight, mean=0.0, std=0.01)
- if module.bias is not None:
- nn.init.zeros_(module.bias)
- elif zero_init_last and hasattr(module, 'zero_init_last'):
- module.zero_init_last()
# Architecture configs keyed by model variant name (looked up by `_create_cspnet`).
model_cfgs = {
    # CSP-ResNet / CSP-ResNeXt variants (default 'csp' stages with 'bottle' blocks)
    'cspresnet50': CspModelCfg(
        stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'),
        stages=CspStagesCfg(
            depth=(3, 3, 5, 2),
            out_chs=(128, 256, 512, 1024),
            stride=(1, 2),
            expand_ratio=2.,
            bottle_ratio=0.5,
            cross_linear=True,
        ),
    ),
    'cspresnet50d': CspModelCfg(
        stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'),
        stages=CspStagesCfg(
            depth=(3, 3, 5, 2),
            out_chs=(128, 256, 512, 1024),
            stride=(1, 2),
            expand_ratio=2.,
            bottle_ratio=0.5,
            block_ratio=1.,
            cross_linear=True,
        ),
    ),
    'cspresnet50w': CspModelCfg(
        stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'),
        stages=CspStagesCfg(
            depth=(3, 3, 5, 2),
            out_chs=(256, 512, 1024, 2048),
            stride=(1, 2),
            expand_ratio=1.,
            bottle_ratio=0.25,
            block_ratio=0.5,
            cross_linear=True,
        ),
    ),
    'cspresnext50': CspModelCfg(
        stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'),
        stages=CspStagesCfg(
            depth=(3, 3, 5, 2),
            out_chs=(256, 512, 1024, 2048),
            stride=(1, 2),
            groups=32,
            expand_ratio=1.,
            bottle_ratio=1.,
            block_ratio=0.5,
            cross_linear=True,
        ),
    ),

    # CSP-DarkNet ('csp' stages with 'dark' blocks)
    'cspdarknet53': CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 2, 8, 8, 4),
            out_chs=(64, 128, 256, 512, 1024),
            stride=2,
            expand_ratio=(2., 1.),
            bottle_ratio=(0.5, 1.),
            block_ratio=(1., 0.5),
            down_growth=True,
            block_type='dark',
        ),
    ),

    # Plain DarkNet variants ('dark' stages with 'dark' blocks)
    'darknet17': CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 1, 1, 1, 1),
            out_chs=(64, 128, 256, 512, 1024),
            stride=(2,),
            bottle_ratio=(0.5,),
            block_ratio=(1.,),
            stage_type='dark',
            block_type='dark',
        ),
    ),
    'darknet21': CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 1, 1, 2, 2),
            out_chs=(64, 128, 256, 512, 1024),
            stride=(2,),
            bottle_ratio=(0.5,),
            block_ratio=(1.,),
            stage_type='dark',
            block_type='dark',
        ),
    ),
    'sedarknet21': CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 1, 1, 2, 2),
            out_chs=(64, 128, 256, 512, 1024),
            stride=2,
            bottle_ratio=0.5,
            block_ratio=1.,
            attn_layer='se',
            stage_type='dark',
            block_type='dark',
        ),
    ),
    'darknet53': CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 2, 8, 8, 4),
            out_chs=(64, 128, 256, 512, 1024),
            stride=2,
            bottle_ratio=0.5,
            block_ratio=1.,
            stage_type='dark',
            block_type='dark',
        ),
    ),
    'darknetaa53': CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 2, 8, 8, 4),
            out_chs=(64, 128, 256, 512, 1024),
            stride=2,
            bottle_ratio=0.5,
            block_ratio=1.,
            avg_down=True,
            stage_type='dark',
            block_type='dark',
        ),
    ),

    # CS3 variants built by `_cs3_cfg`
    'cs3darknet_s': _cs3_cfg(width_multiplier=0.5, depth_multiplier=0.5),
    'cs3darknet_m': _cs3_cfg(width_multiplier=0.75, depth_multiplier=0.67),
    'cs3darknet_l': _cs3_cfg(),
    'cs3darknet_x': _cs3_cfg(width_multiplier=1.25, depth_multiplier=1.33),

    'cs3darknet_focus_s': _cs3_cfg(width_multiplier=0.5, depth_multiplier=0.5, focus=True),
    'cs3darknet_focus_m': _cs3_cfg(width_multiplier=0.75, depth_multiplier=0.67, focus=True),
    'cs3darknet_focus_l': _cs3_cfg(focus=True),
    'cs3darknet_focus_x': _cs3_cfg(width_multiplier=1.25, depth_multiplier=1.33, focus=True),

    'cs3sedarknet_l': _cs3_cfg(attn_layer='se', attn_kwargs=dict(rd_ratio=.25)),
    'cs3sedarknet_x': _cs3_cfg(attn_layer='se', width_multiplier=1.25, depth_multiplier=1.33),

    'cs3sedarknet_xdw': CspModelCfg(
        stem=CspStemCfg(out_chs=(32, 64), kernel_size=3, stride=2, pool=''),
        stages=CspStagesCfg(
            depth=(3, 6, 12, 4),
            out_chs=(256, 512, 1024, 2048),
            stride=2,
            groups=(1, 1, 256, 512),
            bottle_ratio=0.5,
            block_ratio=0.5,
            attn_layer='se',
        ),
        act_layer='silu',
    ),

    'cs3edgenet_x': _cs3_cfg(
        width_multiplier=1.25, depth_multiplier=1.33, bottle_ratio=1.5, block_type='edge'),
    'cs3se_edgenet_x': _cs3_cfg(
        width_multiplier=1.25, depth_multiplier=1.33, bottle_ratio=1.5, block_type='edge',
        attn_layer='se', attn_kwargs=dict(rd_ratio=.25)),
}
def _create_cspnet(variant, pretrained=False, **kwargs):
    """Build a ``CspNet`` for ``variant`` via ``build_model_with_cfg``.

    'out_indices' may be passed in ``kwargs`` to override the default feature
    indices; all other kwargs are forwarded to the builder.
    """
    # NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5]
    if variant.startswith(('darknet', 'cspdarknet')):
        default_out_indices = (0, 1, 2, 3, 4, 5)
    else:
        default_out_indices = (0, 1, 2, 3, 4)
    out_indices = kwargs.pop('out_indices', default_out_indices)
    feature_cfg = dict(flatten_sequential=True, out_indices=out_indices)
    return build_model_with_cfg(
        CspNet,
        variant,
        pretrained,
        model_cfg=model_cfgs[variant],
        feature_cfg=feature_cfg,
        **kwargs,
    )
def _cfg(url='', **kwargs):
    """Return a default-config dict for one set of weights.

    Args:
        url: Value stored under the ``'url'`` key.
        **kwargs: Extra entries; any key that matches a default below
            replaces it, other keys are appended.

    Returns:
        A new ``dict`` holding the defaults merged with ``kwargs``.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 256, 256),
        pool_size=(8, 8),
        crop_pct=0.887,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='stem.conv1.conv',
        classifier='head.fc',
        license='apache-2.0',
    )
    # Caller-supplied values win over the defaults above.
    cfg.update(kwargs)
    return cfg
# Per-weights metadata, keyed as `<model_name>.<tag>`. Entries tagged
# `untrained` specify neither `url` nor `hf_hub_id`.
default_cfgs = generate_default_cfgs({
    # --- cspresnet / cspresnext ---
    'cspresnet50.ra_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnet50_ra-d3e8d487.pth',
    ),
    'cspresnet50d.untrained': _cfg(),
    'cspresnet50w.untrained': _cfg(),
    'cspresnext50.ra_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnext50_ra_224-648b4713.pth',
    ),

    # --- cspdarknet / darknet ---
    'cspdarknet53.ra_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspdarknet53_ra_256-d05c7c21.pth',
    ),
    'darknet17.untrained': _cfg(),
    'darknet21.untrained': _cfg(),
    'sedarknet21.untrained': _cfg(),
    'darknet53.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknet53_256_c2ns-3aeff817.pth',
        interpolation='bicubic',
        test_input_size=(3, 288, 288),
        test_crop_pct=1.0,
    ),
    'darknetaa53.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknetaa53_c2ns-5c28ec8a.pth',
        test_input_size=(3, 288, 288),
        test_crop_pct=1.0,
    ),

    # --- cs3darknet ---
    'cs3darknet_s.untrained': _cfg(interpolation='bicubic'),
    'cs3darknet_m.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_m_c2ns-43f06604.pth',
        interpolation='bicubic',
        test_input_size=(3, 288, 288),
        test_crop_pct=0.95,
    ),
    'cs3darknet_l.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_l_c2ns-16220c5d.pth',
        interpolation='bicubic',
        test_input_size=(3, 288, 288),
        test_crop_pct=0.95,
    ),
    'cs3darknet_x.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_x_c2ns-4e4490aa.pth',
        interpolation='bicubic',
        crop_pct=0.95,
        test_input_size=(3, 288, 288),
        test_crop_pct=1.0,
    ),

    # --- cs3darknet_focus ---
    'cs3darknet_focus_s.ra4_e3600_r256_in1k': _cfg(
        hf_hub_id='timm/',
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
        interpolation='bicubic',
        test_input_size=(3, 320, 320),
        test_crop_pct=1.0,
    ),
    'cs3darknet_focus_m.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_m_c2ns-e23bed41.pth',
        interpolation='bicubic',
        test_input_size=(3, 288, 288),
        test_crop_pct=0.95,
    ),
    'cs3darknet_focus_l.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_l_c2ns-65ef8888.pth',
        interpolation='bicubic',
        test_input_size=(3, 288, 288),
        test_crop_pct=0.95,
    ),
    'cs3darknet_focus_x.untrained': _cfg(interpolation='bicubic'),

    # --- cs3sedarknet ---
    'cs3sedarknet_l.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3sedarknet_l_c2ns-e8d1dc13.pth',
        interpolation='bicubic',
        test_input_size=(3, 288, 288),
        test_crop_pct=0.95,
    ),
    'cs3sedarknet_x.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3sedarknet_x_c2ns-b4d0abc0.pth',
        interpolation='bicubic',
        test_input_size=(3, 288, 288),
        test_crop_pct=1.0,
    ),
    'cs3sedarknet_xdw.untrained': _cfg(interpolation='bicubic'),

    # --- cs3edgenet ---
    'cs3edgenet_x.c2_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3edgenet_x_c2-2e1610a9.pth',
        interpolation='bicubic',
        test_input_size=(3, 288, 288),
        test_crop_pct=1.0,
    ),
    'cs3se_edgenet_x.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3se_edgenet_x_c2ns-76f8e3ac.pth',
        interpolation='bicubic',
        crop_pct=0.95,
        test_input_size=(3, 320, 320),
        test_crop_pct=1.0,
    ),
})
@register_model
def cspresnet50(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cspresnet50`` variant; all arguments are forwarded to ``_create_cspnet``."""
    variant = 'cspresnet50'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def cspresnet50d(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cspresnet50d`` variant; all arguments are forwarded to ``_create_cspnet``."""
    variant = 'cspresnet50d'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def cspresnet50w(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cspresnet50w`` variant; all arguments are forwarded to ``_create_cspnet``."""
    variant = 'cspresnet50w'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def cspresnext50(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cspresnext50`` variant; all arguments are forwarded to ``_create_cspnet``."""
    variant = 'cspresnext50'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def cspdarknet53(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cspdarknet53`` variant; all arguments are forwarded to ``_create_cspnet``."""
    variant = 'cspdarknet53'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def darknet17(pretrained=False, **kwargs) -> CspNet:
    """Create the ``darknet17`` variant; all arguments are forwarded to ``_create_cspnet``."""
    variant = 'darknet17'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def darknet21(pretrained=False, **kwargs) -> CspNet:
    """Create the ``darknet21`` variant; all arguments are forwarded to ``_create_cspnet``."""
    variant = 'darknet21'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def sedarknet21(pretrained=False, **kwargs) -> CspNet:
    """Create the ``sedarknet21`` variant; all arguments are forwarded to ``_create_cspnet``."""
    variant = 'sedarknet21'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def darknet53(pretrained=False, **kwargs) -> CspNet:
    """Create the ``darknet53`` variant ('dark' stage and block types, depths (1, 2, 8, 8, 4)).

    All arguments are forwarded to ``_create_cspnet``.
    """
    variant = 'darknet53'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def darknetaa53(pretrained=False, **kwargs) -> CspNet:
    """Create the ``darknetaa53`` variant (the ``darknet53`` config with ``avg_down=True``).

    All arguments are forwarded to ``_create_cspnet``.
    """
    variant = 'darknetaa53'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def cs3darknet_s(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cs3darknet_s`` variant (``_cs3_cfg`` with width x0.5, depth x0.5).

    All arguments are forwarded to ``_create_cspnet``.
    """
    variant = 'cs3darknet_s'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def cs3darknet_m(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cs3darknet_m`` variant (``_cs3_cfg`` with width x0.75, depth x0.67).

    All arguments are forwarded to ``_create_cspnet``.
    """
    variant = 'cs3darknet_m'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def cs3darknet_l(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cs3darknet_l`` variant (``_cs3_cfg`` called with its defaults).

    All arguments are forwarded to ``_create_cspnet``.
    """
    variant = 'cs3darknet_l'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def cs3darknet_x(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cs3darknet_x`` variant (``_cs3_cfg`` with width x1.25, depth x1.33).

    All arguments are forwarded to ``_create_cspnet``.
    """
    variant = 'cs3darknet_x'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def cs3darknet_focus_s(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cs3darknet_focus_s`` variant (``_cs3_cfg`` with width x0.5, depth x0.5, ``focus=True``).

    All arguments are forwarded to ``_create_cspnet``.
    """
    variant = 'cs3darknet_focus_s'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def cs3darknet_focus_m(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cs3darknet_focus_m`` variant (``_cs3_cfg`` with width x0.75, depth x0.67, ``focus=True``).

    All arguments are forwarded to ``_create_cspnet``.
    """
    variant = 'cs3darknet_focus_m'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def cs3darknet_focus_l(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cs3darknet_focus_l`` variant (``_cs3_cfg`` with ``focus=True``).

    All arguments are forwarded to ``_create_cspnet``.
    """
    variant = 'cs3darknet_focus_l'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def cs3darknet_focus_x(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cs3darknet_focus_x`` variant (``_cs3_cfg`` with width x1.25, depth x1.33, ``focus=True``).

    All arguments are forwarded to ``_create_cspnet``.
    """
    variant = 'cs3darknet_focus_x'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def cs3sedarknet_l(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cs3sedarknet_l`` variant (``_cs3_cfg`` with ``attn_layer='se'``, ``rd_ratio=.25``).

    All arguments are forwarded to ``_create_cspnet``.
    """
    variant = 'cs3sedarknet_l'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def cs3sedarknet_x(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cs3sedarknet_x`` variant (``_cs3_cfg`` with ``attn_layer='se'``, width x1.25, depth x1.33).

    All arguments are forwarded to ``_create_cspnet``.
    """
    variant = 'cs3sedarknet_x'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def cs3sedarknet_xdw(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cs3sedarknet_xdw`` variant (explicit config: 'se' attention, groups (1, 1, 256, 512), 'silu' act).

    All arguments are forwarded to ``_create_cspnet``.
    """
    variant = 'cs3sedarknet_xdw'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def cs3edgenet_x(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cs3edgenet_x`` variant (``_cs3_cfg`` with ``block_type='edge'``, ``bottle_ratio=1.5``, width x1.25, depth x1.33).

    All arguments are forwarded to ``_create_cspnet``.
    """
    variant = 'cs3edgenet_x'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
@register_model
def cs3se_edgenet_x(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cs3se_edgenet_x`` variant (the 'edge' block config plus ``attn_layer='se'``, ``rd_ratio=.25``).

    All arguments are forwarded to ``_create_cspnet``.
    """
    variant = 'cs3se_edgenet_x'
    return _create_cspnet(variant, pretrained=pretrained, **kwargs)
|