| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559 |
- """ VoVNet (V1 & V2)
- Papers:
- * `An Energy and GPU-Computation Efficient Backbone Network` - https://arxiv.org/abs/1904.09730
- * `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
- Looked at https://github.com/youngwanLEE/vovnet-detectron2 &
- https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py
- for some reference, rewrote most of the code.
- Hacked together by / Copyright 2020 Ross Wightman
- """
- from typing import List, Optional, Tuple, Union, Type
- import torch
- import torch.nn as nn
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath, \
- create_attn, create_norm_act_layer, calculate_drop_path_rates
- from ._builder import build_model_with_cfg
- from ._features import feature_take_indices
- from ._manipulate import checkpoint_seq
- from ._registry import register_model, generate_default_cfgs
- __all__ = ['VovNet'] # model_registry will add each entrypoint fn to this
class SequentialAppendList(nn.Sequential):
    """Sequential container that chains its children and concatenates every output.

    The first child consumes ``x``; each later child consumes the output of the child
    before it. Every child output is appended, in place, to the caller-supplied
    ``concat_list`` (which the caller may pre-seed, e.g. with the block input). The
    whole list is then concatenated along dim 1.
    """

    def __init__(self, *args, **kwargs):
        # Keyword arguments are accepted for signature compatibility but ignored.
        super().__init__(*args)

    def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor:
        feat = x
        for layer in self:
            feat = layer(feat)
            # Mutates the caller's list on purpose: it accumulates all features to concat.
            concat_list.append(feat)
        return torch.cat(concat_list, dim=1)
class OsaBlock(nn.Module):
    """One-Shot Aggregation (OSA) block.

    The input is (optionally) reduced and then passed through a chain of
    ``layer_per_block`` conv layers. The original block input and every chain output
    are concatenated and fused by ``conv_concat``. That is optionally followed by
    attention and drop-path, then an identity residual add.

    Args:
        in_chs: Input channels.
        mid_chs: Output channels of every conv in the chain.
        out_chs: Output channels of the aggregation conv.
        layer_per_block: Number of convs in the chain.
        residual: Add the block input to the output (caller must ensure shapes match).
        depthwise: Use ``SeparableConvNormAct`` in the chain. When ``in_chs != mid_chs``
            a 1x1 ``ConvNormAct`` reduction is inserted first, which is not allowed
            together with ``residual``.
        attn: Attention type passed to ``create_attn``; '' disables attention.
        norm_layer: Norm(+act) layer passed to every conv.
        act_layer: Activation layer passed to every conv.
        drop_path: Optional module applied to the output before the residual add.
        device: Device for created parameters.
        dtype: Dtype for created parameters.
    """

    def __init__(
            self,
            in_chs: int,
            mid_chs: int,
            out_chs: int,
            layer_per_block: int,
            residual: bool = False,
            depthwise: bool = False,
            attn: str = '',
            norm_layer: Type[nn.Module] = BatchNormAct2d,
            act_layer: Type[nn.Module] = nn.ReLU,
            drop_path: Optional[nn.Module] = None,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.residual = residual
        self.depthwise = depthwise
        conv_kwargs = dict(norm_layer=norm_layer, act_layer=act_layer, **dd)

        # Separable convs keep mid_chs -> mid_chs, so the input must be projected first.
        needs_reduction = depthwise and in_chs != mid_chs
        if needs_reduction:
            assert not residual
        self.conv_reduction = ConvNormAct(in_chs, mid_chs, 1, **conv_kwargs) if needs_reduction else None

        chain = []
        chain_in_chs = in_chs
        for _ in range(layer_per_block):
            if depthwise:
                chain.append(SeparableConvNormAct(mid_chs, mid_chs, **conv_kwargs))
            else:
                chain.append(ConvNormAct(chain_in_chs, mid_chs, 3, **conv_kwargs))
            chain_in_chs = mid_chs
        self.conv_mid = SequentialAppendList(*chain)

        # feature aggregation: the un-reduced block input plus one mid_chs map per chain conv
        concat_chs = in_chs + layer_per_block * mid_chs
        self.conv_concat = ConvNormAct(concat_chs, out_chs, **conv_kwargs)
        self.attn = create_attn(attn, out_chs, **dd) if attn else None
        self.drop_path = drop_path

    def forward(self, x):
        shortcut = x
        # Seed the concat list with the block input (before any reduction).
        feats = [shortcut]
        if self.conv_reduction is not None:
            x = self.conv_reduction(x)
        out = self.conv_concat(self.conv_mid(x, feats))
        if self.attn is not None:
            out = self.attn(out)
        if self.drop_path is not None:
            out = self.drop_path(out)
        if self.residual:
            out = out + shortcut
        return out
class OsaStage(nn.Module):
    """A stage of ``block_per_stage`` OSA blocks, optionally preceded by a max-pool.

    Args:
        in_chs: Input channels of the first block (later blocks take ``out_chs``).
        mid_chs: Chain conv channels inside each block.
        out_chs: Output channels of every block.
        block_per_stage: Number of ``OsaBlock`` modules.
        layer_per_block: Chain length inside each block.
        downsample: Prepend a 3x3, stride-2, ceil-mode ``nn.MaxPool2d``.
        residual: Enable the identity residual on every block except the first.
        depthwise: Use separable convs inside blocks.
        attn: Attention type, applied only on the last block of the stage.
        norm_layer: Norm(+act) layer for all convs.
        act_layer: Activation layer for all convs.
        drop_path_rates: Optional per-block drop-path rates. A ``DropPath`` is created
            only for rates > 0.
        device: Device for created parameters.
        dtype: Dtype for created parameters.
    """

    def __init__(
            self,
            in_chs: int,
            mid_chs: int,
            out_chs: int,
            block_per_stage: int,
            layer_per_block: int,
            downsample: bool = True,
            residual: bool = True,
            depthwise: bool = False,
            attn: str = 'ese',
            norm_layer: Type[nn.Module] = BatchNormAct2d,
            act_layer: Type[nn.Module] = nn.ReLU,
            drop_path_rates: Optional[List[float]] = None,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.grad_checkpointing = False
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) if downsample else None

        blocks = []
        block_in_chs = in_chs
        for idx in range(block_per_stage):
            rate = drop_path_rates[idx] if drop_path_rates is not None else 0.
            blocks.append(OsaBlock(
                block_in_chs,
                mid_chs,
                out_chs,
                layer_per_block,
                # first block changes channels, so it can never be residual
                residual=residual and idx > 0,
                depthwise=depthwise,
                attn=attn if idx == block_per_stage - 1 else '',
                norm_layer=norm_layer,
                act_layer=act_layer,
                drop_path=DropPath(rate) if rate > 0. else None,
                **dd,
            ))
            block_in_chs = out_chs
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        if self.pool is not None:
            x = self.pool(x)
        # Keep is_scripting() inside the `if` so TorchScript can skip the checkpoint branch.
        if self.grad_checkpointing and not torch.jit.is_scripting():
            return checkpoint_seq(self.blocks, x)
        return self.blocks(x)
class VovNet(nn.Module):
    """VoVNet (V1 & V2) image classification backbone.

    Structure: a three-conv stem, four ``OsaStage`` stages and a ``ClassifierHead``.
    Feature index 0 is the reduction-2 stem feature; indices 1-4 are the outputs of
    the four stages (see ``feature_info`` and ``forward_intermediates``).
    """

    def __init__(
            self,
            cfg: dict,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            output_stride: int = 32,
            norm_layer: Type[nn.Module] = BatchNormAct2d,
            act_layer: Type[nn.Module] = nn.ReLU,
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            device=None,
            dtype=None,
            **kwargs,
    ):
        """
        Args:
            cfg (dict): Model architecture configuration. Reads ``stem_chs``,
                ``stage_conv_chs``, ``stage_out_chs``, ``block_per_stage``,
                ``layer_per_block``, ``residual``, ``depthwise``, ``attn`` and the optional
                ``stem_stride`` (2 or 4, default 4).
            in_chans (int): Number of input channels (default: 3)
            num_classes (int): Number of classifier classes (default: 1000)
            global_pool (str): Global pooling type (default: 'avg')
            output_stride (int): Output stride of network; only 32 is currently supported
            norm_layer (Union[type, Callable]): norm(+act) layer class or factory passed to every conv
            act_layer (Union[str, nn.Module]): activation layer
            drop_rate (float): Dropout rate (default: 0.)
            drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
            kwargs (dict): Extra kwargs overlayed onto cfg
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.drop_rate = drop_rate
        assert output_stride == 32  # FIXME support dilation

        cfg = dict(cfg, **kwargs)
        stem_stride = cfg.get("stem_stride", 4)
        # The first stem conv always has stride 2 and the last has stride stem_stride // 2,
        # so only 2 and 4 yield a valid stem and a correct feature_info below.
        assert stem_stride in (2, 4), f'stem_stride must be 2 or 4, got {stem_stride}'
        stem_chs = cfg["stem_chs"]
        stage_conv_chs = cfg["stage_conv_chs"]
        stage_out_chs = cfg["stage_out_chs"]
        block_per_stage = cfg["block_per_stage"]
        layer_per_block = cfg["layer_per_block"]
        conv_kwargs = dict(norm_layer=norm_layer, act_layer=act_layer, **dd)

        # Stem module
        last_stem_stride = stem_stride // 2
        conv_type = SeparableConvNormAct if cfg["depthwise"] else ConvNormAct
        self.stem = nn.Sequential(*[
            ConvNormAct(in_chans, stem_chs[0], 3, stride=2, **conv_kwargs),
            conv_type(stem_chs[0], stem_chs[1], 3, stride=1, **conv_kwargs),
            conv_type(stem_chs[1], stem_chs[2], 3, stride=last_stem_stride, **conv_kwargs),
        ])
        # Deepest stem module whose output is still at reduction 2: stem.1 when the last stem
        # conv downsamples (stem_stride == 4), otherwise the full stem (stem.2). The reported
        # channel count must be that module's output channels, stem_chs[idx].
        self._stem_feat_idx = 1 if stem_stride == 4 else 2
        self.feature_info = [dict(
            num_chs=stem_chs[self._stem_feat_idx],
            reduction=2,
            module=f'stem.{self._stem_feat_idx}',
        )]
        current_stride = stem_stride

        # OSA stages
        stage_dpr = calculate_drop_path_rates(drop_path_rate, block_per_stage, stagewise=True)
        in_ch_list = stem_chs[-1:] + stage_out_chs[:-1]
        stage_args = dict(residual=cfg["residual"], depthwise=cfg["depthwise"], attn=cfg["attn"], **conv_kwargs)
        stages = []
        for i in range(4):  # num_stages
            # first stage has no stride/downsample if stem_stride is 4
            downsample = stem_stride == 2 or i > 0
            stages += [OsaStage(
                in_ch_list[i],
                stage_conv_chs[i],
                stage_out_chs[i],
                block_per_stage[i],
                layer_per_block,
                downsample=downsample,
                drop_path_rates=stage_dpr[i],
                **stage_args,
            )]
            self.num_features = stage_out_chs[i]
            current_stride *= 2 if downsample else 1
            self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]
        self.stages = nn.Sequential(*stages)

        self.head_hidden_size = self.num_features
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate, **dd)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.zeros_(m.bias)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Regexes grouping parameters into the stem and per-stage (coarse) or per-block groups."""
        return dict(
            stem=r'^stem',
            blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.blocks\.(\d+)',
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on every stage."""
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool: Optional[str] = None):
        """Replace the classifier with one for ``num_classes`` (and optionally a new pool type)."""
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Feature index 0 is the reduction-2 stem feature (the module named in
        ``feature_info[0]``); indices 1-4 are the stage outputs.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates (unused; this model has no final norm)
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:
            The list of selected intermediates if ``intermediates_only``, otherwise a tuple of
            (final feature map, list of selected intermediates).
        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(5, indices)

        # forward pass
        feat_idx = 0
        # Split the stem right after the module reported as feature 0 in feature_info.
        stem_split = self._stem_feat_idx + 1
        x = self.stem[:stem_split](x)
        if feat_idx in take_indices:
            intermediates.append(x)
        x = self.stem[stem_split:](x)  # empty Sequential (identity) when the split is at the end

        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.stages
        else:
            stages = self.stages[:max_index]

        for feat_idx, stage in enumerate(stages, start=1):
            x = stage(x)
            if feat_idx in take_indices:
                intermediates.append(x)

        if intermediates_only:
            return intermediates

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.

        Returns the resolved list of feature indices to take.
        """
        take_indices, max_index = feature_take_indices(5, indices)
        self.stages = self.stages[:max_index]  # truncate blocks w/ stem as idx 0
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        x = self.stem(x)
        return self.stages(x)

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=pre_logits)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
# Model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 &
# https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py
#
# Keys of every variant:
#   stem_chs        -- output channels of the three stem convs
#   stage_conv_chs  -- per-conv (chain) channels inside blocks of each of the 4 stages
#   stage_out_chs   -- output channels of each stage
#   layer_per_block -- number of chain convs per OSA block
#   block_per_stage -- number of OSA blocks in each stage
#   residual        -- identity shortcut on every block except a stage's first
#   depthwise       -- separable convs in the stem (after the first conv) and in blocks
#   attn            -- attention type on the last block of each stage ('' = none)
model_cfgs = {
    'vovnet39a': {
        'stem_chs': [64, 64, 128],
        'stage_conv_chs': [128, 160, 192, 224],
        'stage_out_chs': [256, 512, 768, 1024],
        'layer_per_block': 5,
        'block_per_stage': [1, 1, 2, 2],
        'residual': False,
        'depthwise': False,
        'attn': '',
    },
    'vovnet57a': {
        'stem_chs': [64, 64, 128],
        'stage_conv_chs': [128, 160, 192, 224],
        'stage_out_chs': [256, 512, 768, 1024],
        'layer_per_block': 5,
        'block_per_stage': [1, 1, 4, 3],
        'residual': False,
        'depthwise': False,
        'attn': '',
    },
    'ese_vovnet19b_slim_dw': {
        'stem_chs': [64, 64, 64],
        'stage_conv_chs': [64, 80, 96, 112],
        'stage_out_chs': [112, 256, 384, 512],
        'layer_per_block': 3,
        'block_per_stage': [1, 1, 1, 1],
        'residual': True,
        'depthwise': True,
        'attn': 'ese',
    },
    'ese_vovnet19b_dw': {
        'stem_chs': [64, 64, 64],
        'stage_conv_chs': [128, 160, 192, 224],
        'stage_out_chs': [256, 512, 768, 1024],
        'layer_per_block': 3,
        'block_per_stage': [1, 1, 1, 1],
        'residual': True,
        'depthwise': True,
        'attn': 'ese',
    },
    'ese_vovnet19b_slim': {
        'stem_chs': [64, 64, 128],
        'stage_conv_chs': [64, 80, 96, 112],
        'stage_out_chs': [112, 256, 384, 512],
        'layer_per_block': 3,
        'block_per_stage': [1, 1, 1, 1],
        'residual': True,
        'depthwise': False,
        'attn': 'ese',
    },
    'ese_vovnet19b': {
        'stem_chs': [64, 64, 128],
        'stage_conv_chs': [128, 160, 192, 224],
        'stage_out_chs': [256, 512, 768, 1024],
        'layer_per_block': 3,
        'block_per_stage': [1, 1, 1, 1],
        'residual': True,
        'depthwise': False,
        'attn': 'ese',
    },
    'ese_vovnet39b': {
        'stem_chs': [64, 64, 128],
        'stage_conv_chs': [128, 160, 192, 224],
        'stage_out_chs': [256, 512, 768, 1024],
        'layer_per_block': 5,
        'block_per_stage': [1, 1, 2, 2],
        'residual': True,
        'depthwise': False,
        'attn': 'ese',
    },
    'ese_vovnet57b': {
        'stem_chs': [64, 64, 128],
        'stage_conv_chs': [128, 160, 192, 224],
        'stage_out_chs': [256, 512, 768, 1024],
        'layer_per_block': 5,
        'block_per_stage': [1, 1, 4, 3],
        'residual': True,
        'depthwise': False,
        'attn': 'ese',
    },
    'ese_vovnet99b': {
        'stem_chs': [64, 64, 128],
        'stage_conv_chs': [128, 160, 192, 224],
        'stage_out_chs': [256, 512, 768, 1024],
        'layer_per_block': 5,
        'block_per_stage': [1, 3, 9, 3],
        'residual': True,
        'depthwise': False,
        'attn': 'ese',
    },
    'eca_vovnet39b': {
        'stem_chs': [64, 64, 128],
        'stage_conv_chs': [128, 160, 192, 224],
        'stage_out_chs': [256, 512, 768, 1024],
        'layer_per_block': 5,
        'block_per_stage': [1, 1, 2, 2],
        'residual': True,
        'depthwise': False,
        'attn': 'eca',
    },
}
# The EvoNorm experiment shares the exact same architecture dict (same object) as ese_vovnet39b.
model_cfgs['ese_vovnet39b_evos'] = model_cfgs['ese_vovnet39b']
def _create_vovnet(variant, pretrained=False, **kwargs):
    """Build a ``VovNet`` for ``variant`` through ``build_model_with_cfg``.

    The architecture dict is taken from ``model_cfgs[variant]`` (an unknown variant raises
    ``KeyError``). ``flatten_sequential=True`` is passed as the feature config.
    """
    arch_cfg = model_cfgs[variant]
    features = dict(flatten_sequential=True)
    return build_model_with_cfg(
        VovNet, variant, pretrained,
        model_cfg=arch_cfg,
        feature_cfg=features,
        **kwargs,
    )
def _cfg(url='', **kwargs):
    """Return the default pretrained-config dict shared by all VoVNet variants.

    Base values describe a 1000-class, 3x224x224 input setup with ImageNet default
    mean/std. Entries in ``kwargs`` override base keys or are appended after them.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=(7, 7),
        crop_pct=0.875,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='stem.0.conv',
        classifier='head.fc',
        license='apache-2.0',
    )
    cfg.update(kwargs)
    return cfg
# Pretrained configs keyed as '<architecture>.<tag>'. The '.untrained' entries carry no
# weights URL; entries with hf_hub_id='timm/' point at the HF hub.
default_cfgs = generate_default_cfgs({
    'vovnet39a.untrained': _cfg(),
    'vovnet57a.untrained': _cfg(),
    'ese_vovnet19b_slim_dw.untrained': _cfg(),
    'ese_vovnet19b_dw.ra_in1k': _cfg(
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288),
        test_crop_pct=0.95,
    ),
    'ese_vovnet19b_slim.untrained': _cfg(),
    'ese_vovnet39b.ra_in1k': _cfg(
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288),
        test_crop_pct=0.95,
    ),
    'ese_vovnet57b.ra4_e3600_r256_in1k': _cfg(
        hf_hub_id='timm/',
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
        crop_pct=0.95,
        input_size=(3, 256, 256),
        pool_size=(8, 8),
        test_input_size=(3, 320, 320),
        test_crop_pct=1.0,
    ),
    'ese_vovnet99b.untrained': _cfg(),
    'eca_vovnet39b.untrained': _cfg(),
    'ese_vovnet39b_evos.untrained': _cfg(),
})
@register_model
def vovnet39a(pretrained=False, **kwargs) -> VovNet:
    """``vovnet39a``: 5 convs/block, blocks per stage [1, 1, 2, 2], no residual, no attention."""
    build_kwargs = dict(pretrained=pretrained, **kwargs)
    return _create_vovnet('vovnet39a', **build_kwargs)
@register_model
def vovnet57a(pretrained=False, **kwargs) -> VovNet:
    """``vovnet57a``: 5 convs/block, blocks per stage [1, 1, 4, 3], no residual, no attention."""
    build_kwargs = dict(pretrained=pretrained, **kwargs)
    return _create_vovnet('vovnet57a', **build_kwargs)
@register_model
def ese_vovnet19b_slim_dw(pretrained=False, **kwargs) -> VovNet:
    """``ese_vovnet19b_slim_dw``: slim channels, separable convs, residual, 'ese' attention."""
    build_kwargs = dict(pretrained=pretrained, **kwargs)
    return _create_vovnet('ese_vovnet19b_slim_dw', **build_kwargs)
@register_model
def ese_vovnet19b_dw(pretrained=False, **kwargs) -> VovNet:
    """``ese_vovnet19b_dw``: separable convs, 3 convs/block, residual, 'ese' attention."""
    build_kwargs = dict(pretrained=pretrained, **kwargs)
    return _create_vovnet('ese_vovnet19b_dw', **build_kwargs)
@register_model
def ese_vovnet19b_slim(pretrained=False, **kwargs) -> VovNet:
    """``ese_vovnet19b_slim``: slim channels, regular convs, residual, 'ese' attention."""
    build_kwargs = dict(pretrained=pretrained, **kwargs)
    return _create_vovnet('ese_vovnet19b_slim', **build_kwargs)
@register_model
def ese_vovnet39b(pretrained=False, **kwargs) -> VovNet:
    """``ese_vovnet39b``: 5 convs/block, blocks per stage [1, 1, 2, 2], residual, 'ese' attention."""
    build_kwargs = dict(pretrained=pretrained, **kwargs)
    return _create_vovnet('ese_vovnet39b', **build_kwargs)
@register_model
def ese_vovnet57b(pretrained=False, **kwargs) -> VovNet:
    """``ese_vovnet57b``: 5 convs/block, blocks per stage [1, 1, 4, 3], residual, 'ese' attention."""
    build_kwargs = dict(pretrained=pretrained, **kwargs)
    return _create_vovnet('ese_vovnet57b', **build_kwargs)
@register_model
def ese_vovnet99b(pretrained=False, **kwargs) -> VovNet:
    """``ese_vovnet99b``: 5 convs/block, blocks per stage [1, 3, 9, 3], residual, 'ese' attention."""
    build_kwargs = dict(pretrained=pretrained, **kwargs)
    return _create_vovnet('ese_vovnet99b', **build_kwargs)
@register_model
def eca_vovnet39b(pretrained=False, **kwargs) -> VovNet:
    """``eca_vovnet39b``: same layout as ``ese_vovnet39b`` but with 'eca' attention."""
    build_kwargs = dict(pretrained=pretrained, **kwargs)
    return _create_vovnet('eca_vovnet39b', **build_kwargs)
# Experimental Models

@register_model
def ese_vovnet39b_evos(pretrained=False, **kwargs) -> VovNet:
    """``ese_vovnet39b`` architecture with the norm layer replaced by 'evonorms0' layers.

    Every conv in the model receives this factory as its ``norm_layer``.
    """
    def _evonorm_s0(num_features, **norm_kwargs):
        # Built through create_norm_act_layer with jit=False for each requested layer.
        return create_norm_act_layer('evonorms0', num_features, jit=False, **norm_kwargs)

    return _create_vovnet(
        'ese_vovnet39b_evos',
        pretrained=pretrained,
        norm_layer=_evonorm_s0,
        **kwargs,
    )
|