| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464 |
- """
- pnasnet5large implementation grabbed from Cadene's pretrained models
- Additional credit to https://github.com/creafz
- https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/pnasnet.py
- """
- from collections import OrderedDict
- from functools import partial
- from typing import Type
- import torch
- import torch.nn as nn
- from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier
- from ._builder import build_model_with_cfg
- from ._registry import register_model, generate_default_cfgs
# Names exported by `from <this module> import *`: only the model class.
__all__ = ['PNASNet5Large']
class SeparableConv2d(nn.Module):
    """Depthwise-separable convolution: per-channel spatial conv, then a 1x1 channel-mixing conv.

    Args:
        in_channels: Input channels; also the width of the depthwise conv's output.
        out_channels: Output channels produced by the pointwise (1x1) conv.
        kernel_size: Spatial kernel size of the depthwise conv.
        stride: Stride of the depthwise conv (no stride is passed to the pointwise conv).
        padding: Padding spec forwarded to ``create_conv2d`` for both convs.
        device: Optional device for parameter allocation.
        dtype: Optional dtype for parameter allocation.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int,
            padding: str = '',
            device=None,
            dtype=None,
    ):
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        # groups == in_channels gives one spatial filter per input channel (depthwise).
        self.depthwise_conv2d = create_conv2d(
            in_channels, in_channels, kernel_size=kernel_size, stride=stride,
            padding=padding, groups=in_channels, **factory_kwargs)
        # The 1x1 conv mixes channels and maps to the requested output width.
        self.pointwise_conv2d = create_conv2d(
            in_channels, out_channels, kernel_size=1, padding=padding, **factory_kwargs)

    def forward(self, x):
        return self.pointwise_conv2d(self.depthwise_conv2d(x))
class BranchSeparables(nn.Module):
    """Two stacked ReLU -> SeparableConv2d -> BatchNorm units.

    Only the first unit applies ``stride``; the second always runs at stride 1.

    Args:
        in_channels: Input channels.
        out_channels: Output channels of the second unit.
        kernel_size: Kernel size shared by both separable convs.
        stride: Stride of the first separable conv.
        stem_cell: If True the intermediate width is ``out_channels``, otherwise ``in_channels``.
        padding: Padding spec forwarded to both separable convs.
        device: Optional device for parameter allocation.
        dtype: Optional dtype for parameter allocation.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int = 1,
            stem_cell: bool = False,
            padding: str = '',
            device=None,
            dtype=None,
    ):
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        if stem_cell:
            middle_channels = out_channels
        else:
            middle_channels = in_channels

        # First unit (the only one that may downsample).
        self.act_1 = nn.ReLU()
        self.separable_1 = SeparableConv2d(
            in_channels, middle_channels, kernel_size,
            stride=stride, padding=padding, **factory_kwargs)
        self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001, **factory_kwargs)

        # Second unit, always stride 1.
        self.act_2 = nn.ReLU()
        self.separable_2 = SeparableConv2d(
            middle_channels, out_channels, kernel_size,
            stride=1, padding=padding, **factory_kwargs)
        self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, **factory_kwargs)

    def forward(self, x):
        x = self.bn_sep_1(self.separable_1(self.act_1(x)))
        return self.bn_sep_2(self.separable_2(self.act_2(x)))
class ActConvBn(nn.Module):
    """Pre-activation unit: ReLU, then a conv, then BatchNorm.

    Args:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Conv kernel size.
        stride: Conv stride.
        padding: Padding spec forwarded to ``create_conv2d``.
        device: Optional device for parameter allocation.
        dtype: Optional dtype for parameter allocation.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int = 1,
            padding: str = '',
            device=None,
            dtype=None,
    ):
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        self.act = nn.ReLU()
        self.conv = create_conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride,
            padding=padding, **factory_kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001, **factory_kwargs)

    def forward(self, x):
        return self.bn(self.conv(self.act(x)))
class FactorizedReduction(nn.Module):
    """Halve spatial resolution with two offset strided 1x1 paths, then concat + BatchNorm.

    Each path yields ``out_channels // 2`` channels. Path 2 first shifts the input by one
    pixel, so the two stride-2 subsamplings read different pixels.

    Args:
        in_channels: Input channels.
        out_channels: Channels after concatenating both paths (expected to be even).
        padding: Padding spec forwarded to the 1x1 convs.
        device: Optional device for parameter allocation.
        dtype: Optional dtype for parameter allocation.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            padding: str = '',
            device=None,
            dtype=None,
    ):
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        half_channels = out_channels // 2
        self.act = nn.ReLU()
        # A 1x1 avg pool with stride 2 simply keeps every other pixel.
        self.path_1 = nn.Sequential(OrderedDict([
            ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)),
            ('conv', create_conv2d(
                in_channels, half_channels, kernel_size=1, padding=padding, **factory_kwargs)),
        ]))
        # ZeroPad2d takes (left, right, top, bottom). The negative left/top amounts crop the
        # first column/row and the positive right/bottom amounts pad the last ones, which
        # shifts the image by one pixel before subsampling.
        self.path_2 = nn.Sequential(OrderedDict([
            ('pad', nn.ZeroPad2d((-1, 1, -1, 1))),
            ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)),
            ('conv', create_conv2d(
                in_channels, half_channels, kernel_size=1, padding=padding, **factory_kwargs)),
        ]))
        self.final_path_bn = nn.BatchNorm2d(out_channels, eps=0.001, **factory_kwargs)

    def forward(self, x):
        activated = self.act(x)
        merged = torch.cat([self.path_1(activated), self.path_2(activated)], 1)
        return self.final_path_bn(merged)
class CellBase(nn.Module):
    """Shared forward logic for PNASNet cells.

    Subclasses must set ``comb_iter_{0..4}_left`` and ``comb_iter_{0..3}_right`` to
    callables. ``comb_iter_4_right`` may be None, in which case that side is the identity.
    """

    def cell_forward(self, x_left, x_right):
        """Combine the two cell inputs through five summed branch pairs.

        Args:
            x_left: Left cell input.
            x_right: Right cell input.

        Returns:
            The five pair sums concatenated along dim 1, in order 0..4.
        """
        # Pair 0 reads only the left input.
        comb_0 = self.comb_iter_0_left(x_left) + self.comb_iter_0_right(x_left)
        # Pairs 1 and 2 read only the right input.
        comb_1 = self.comb_iter_1_left(x_right) + self.comb_iter_1_right(x_right)
        comb_2 = self.comb_iter_2_left(x_right) + self.comb_iter_2_right(x_right)
        # Pair 3 chains off pair 2's result on its left side.
        comb_3 = self.comb_iter_3_left(comb_2) + self.comb_iter_3_right(x_right)
        # Pair 4: its right side falls back to identity when no op is configured.
        left_4 = self.comb_iter_4_left(x_left)
        if self.comb_iter_4_right is None:
            right_4 = x_right
        else:
            right_4 = self.comb_iter_4_right(x_right)
        comb_4 = left_4 + right_4
        return torch.cat([comb_0, comb_1, comb_2, comb_3, comb_4], 1)
class CellStem0(CellBase):
    """First stem cell: derives both ``cell_forward`` inputs from a single tensor.

    The left input is the incoming tensor itself. The right input is its 1x1 ``ActConvBn``
    projection. Every branch op except ``comb_iter_3_left`` uses stride 2.

    Args:
        in_chs_left: Channels of the incoming tensor as seen by the left branches.
        out_chs_left: Output channels of the left-input branches.
        in_chs_right: Channels fed to the right-side 1x1 projection.
        out_chs_right: Output channels of the right-input branches.
        pad_type: Padding spec forwarded to every conv and pooling layer.
        device: Optional device for parameter allocation.
        dtype: Optional dtype for parameter allocation.
    """

    def __init__(
            self,
            in_chs_left: int,
            out_chs_left: int,
            in_chs_right: int,
            out_chs_right: int,
            pad_type: str = '',
            device=None,
            dtype=None,
    ):
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        conv_kwargs = dict(padding=pad_type, **factory_kwargs)

        self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, **conv_kwargs)

        # Pair 0 (left input).
        self.comb_iter_0_left = BranchSeparables(
            in_chs_left, out_chs_left, kernel_size=5, stride=2, stem_cell=True, **conv_kwargs)
        self.comb_iter_0_right = nn.Sequential(OrderedDict([
            ('max_pool', create_pool2d('max', 3, stride=2, padding=pad_type)),
            ('conv', create_conv2d(in_chs_left, out_chs_left, kernel_size=1, **conv_kwargs)),
            ('bn', nn.BatchNorm2d(out_chs_left, eps=0.001, **factory_kwargs)),
        ]))

        # Pair 1 (right input).
        self.comb_iter_1_left = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=7, stride=2, **conv_kwargs)
        self.comb_iter_1_right = create_pool2d('max', 3, stride=2, padding=pad_type)

        # Pair 2 (right input).
        self.comb_iter_2_left = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=5, stride=2, **conv_kwargs)
        self.comb_iter_2_right = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=3, stride=2, **conv_kwargs)

        # Pair 3: the left op consumes pair 2's already-strided output, so it keeps stride 1.
        self.comb_iter_3_left = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=3, **conv_kwargs)
        self.comb_iter_3_right = create_pool2d('max', 3, stride=2, padding=pad_type)

        # Pair 4.
        self.comb_iter_4_left = BranchSeparables(
            in_chs_right, out_chs_right, kernel_size=3, stride=2, stem_cell=True, **conv_kwargs)
        self.comb_iter_4_right = ActConvBn(
            out_chs_right, out_chs_right, kernel_size=1, stride=2, **conv_kwargs)

    def forward(self, x_left):
        return self.cell_forward(x_left, self.conv_1x1(x_left))
class Cell(CellBase):
    """Regular or reduction PNASNet cell combining two incoming feature maps.

    Args:
        in_chs_left: Channels of the left input.
        out_chs_left: Channels the left input is projected to.
        in_chs_right: Channels of the right input.
        out_chs_right: Channels the right input is projected to.
        pad_type: Padding spec forwarded to every conv and pooling layer in the cell.
        is_reduction: If True, branch ops use stride 2 and ``comb_iter_4_right`` is a strided
            1x1 ``ActConvBn``. Otherwise ``comb_iter_4_right`` is None, and
            ``cell_forward`` uses the right input directly.
        match_prev_layer_dims: If True, the left input is projected with
            ``FactorizedReduction``, which halves its spatial size, instead of a 1x1 ``ActConvBn``.
        device: Optional device for parameter allocation.
        dtype: Optional dtype for parameter allocation.
    """

    def __init__(
            self,
            in_chs_left: int,
            out_chs_left: int,
            in_chs_right: int,
            out_chs_right: int,
            pad_type: str = '',
            is_reduction: bool = False,
            match_prev_layer_dims: bool = False,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()

        # If `is_reduction` is True, stride 2 is used for the convolution and pooling
        # layers, reducing the cell output's spatial size by roughly a factor of 2.
        stride = 2 if is_reduction else 1

        # If `match_prev_layer_dims` is True, `FactorizedReduction` halves the spatial
        # size of the left input (roughly a factor of 2) so it lines up with the right input.
        # The attribute keeps its historical name for external compatibility.
        self.match_prev_layer_dimensions = match_prev_layer_dims
        if match_prev_layer_dims:
            self.conv_prev_1x1 = FactorizedReduction(in_chs_left, out_chs_left, padding=pad_type, **dd)
        else:
            self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type, **dd)
        self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type, **dd)

        self.comb_iter_0_left = BranchSeparables(
            out_chs_left, out_chs_left, kernel_size=5, stride=stride, padding=pad_type, **dd)
        self.comb_iter_0_right = create_pool2d('max', 3, stride=stride, padding=pad_type)

        self.comb_iter_1_left = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=7, stride=stride, padding=pad_type, **dd)
        self.comb_iter_1_right = create_pool2d('max', 3, stride=stride, padding=pad_type)

        self.comb_iter_2_left = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=5, stride=stride, padding=pad_type, **dd)
        self.comb_iter_2_right = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=3, stride=stride, padding=pad_type, **dd)

        # Stride 1: this op consumes comb_iter_2's output, which is already strided.
        # pad_type is forwarded like every other op in this cell and in CellStem0.
        self.comb_iter_3_left = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=3, padding=pad_type, **dd)
        self.comb_iter_3_right = create_pool2d('max', 3, stride=stride, padding=pad_type)

        self.comb_iter_4_left = BranchSeparables(
            out_chs_left, out_chs_left, kernel_size=3, stride=stride, padding=pad_type, **dd)
        if is_reduction:
            self.comb_iter_4_right = ActConvBn(
                out_chs_right, out_chs_right, kernel_size=1, stride=stride, padding=pad_type, **dd)
        else:
            self.comb_iter_4_right = None

    def forward(self, x_left, x_right):
        """Project both inputs to the cell widths, then combine them via ``cell_forward``."""
        x_left = self.conv_prev_1x1(x_left)
        x_right = self.conv_1x1(x_right)
        x_out = self.cell_forward(x_left, x_right)
        return x_out
class PNASNet5Large(nn.Module):
    """PNASNet-5 Large image classifier.

    A stride-2 3x3 stem conv (``conv_0``) is followed by two stem cells and twelve cells.
    Each cell takes two inputs: the output from two steps back and the previous output.
    ``cell_stem_0`` and the reduction cells (``cell_stem_1``, ``cell_4``, ``cell_8``) use
    stride 2. The final 4320-channel feature map is passed through a ReLU before the
    pooling / dropout / linear head.

    Args:
        num_classes: Number of classifier outputs.
        in_chans: Number of input image channels.
        output_stride: Overall feature stride; only 32 is accepted.
        drop_rate: Drop rate passed to ``create_classifier`` (applied via ``head_drop``).
        global_pool: Global pooling type passed to ``create_classifier``.
        pad_type: Padding spec forwarded to every cell's convs and pools.
        device: Optional device for parameter allocation.
        dtype: Optional dtype for parameter allocation.
    """

    def __init__(
            self,
            num_classes: int = 1000,
            in_chans: int = 3,
            output_stride: int = 32,
            drop_rate: float = 0.,
            global_pool: str = 'avg',
            pad_type: str = '',
            device=None,
            dtype=None,
    ):
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        self.num_classes = num_classes
        self.in_chans = in_chans
        # Channel width of cell_11's output (5 concatenated branches x 864 channels).
        self.num_features = self.head_hidden_size = 4320
        # Only the stride-32 configuration is supported.
        assert output_stride == 32

        # Stem: stride-2 3x3 conv + BatchNorm, no activation (apply_act=False).
        self.conv_0 = ConvNormAct(
            in_chans, 96, kernel_size=3, stride=2, padding=0,
            norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False, **dd)

        # Each cell concatenates five branch outputs. out_chs_left == out_chs_right in every
        # cell here, so each cell's output width is 5 * out_chs_right:
        # 54 -> 270, 108 -> 540, 216 -> 1080, 432 -> 2160, 864 -> 4320.
        # in_chs_left is the width of the output two steps back; in_chs_right is the
        # previous output's width. match_prev_layer_dims=True is set on the cell that follows
        # a stride-2 cell, where the left input is at a higher resolution than the right.
        # Construction order is significant: it fixes parameter-initialization order and
        # state_dict key order.
        self.cell_stem_0 = CellStem0(
            in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, pad_type=pad_type, **dd)
        self.cell_stem_1 = Cell(
            in_chs_left=96, out_chs_left=108, in_chs_right=270, out_chs_right=108, pad_type=pad_type,
            match_prev_layer_dims=True, is_reduction=True, **dd)
        self.cell_0 = Cell(
            in_chs_left=270, out_chs_left=216, in_chs_right=540, out_chs_right=216, pad_type=pad_type,
            match_prev_layer_dims=True, **dd)
        self.cell_1 = Cell(
            in_chs_left=540, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type, **dd)
        self.cell_2 = Cell(
            in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type, **dd)
        self.cell_3 = Cell(
            in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type, **dd)
        self.cell_4 = Cell(
            in_chs_left=1080, out_chs_left=432, in_chs_right=1080, out_chs_right=432, pad_type=pad_type,
            is_reduction=True, **dd)
        self.cell_5 = Cell(
            in_chs_left=1080, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type,
            match_prev_layer_dims=True, **dd)
        self.cell_6 = Cell(
            in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type, **dd)
        self.cell_7 = Cell(
            in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type, **dd)
        self.cell_8 = Cell(
            in_chs_left=2160, out_chs_left=864, in_chs_right=2160, out_chs_right=864, pad_type=pad_type,
            is_reduction=True, **dd)
        self.cell_9 = Cell(
            in_chs_left=2160, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type,
            match_prev_layer_dims=True, **dd)
        self.cell_10 = Cell(
            in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type, **dd)
        self.cell_11 = Cell(
            in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type, **dd)
        # Final activation applied to cell_11's output in forward_features.
        self.act = nn.ReLU()

        # Feature taps: channel width and cumulative stride ("reduction") at each named module.
        # The module paths must match attribute names above. _create_pnasnet builds
        # features with feature_cls='hook'.
        self.feature_info = [
            dict(num_chs=96, reduction=2, module='conv_0'),
            dict(num_chs=270, reduction=4, module='cell_stem_1.conv_1x1.act'),
            dict(num_chs=1080, reduction=8, module='cell_4.conv_1x1.act'),
            dict(num_chs=2160, reduction=16, module='cell_8.conv_1x1.act'),
            dict(num_chs=4320, reduction=32, module='act'),
        ]
        # With drop_rate given, create_classifier returns (pool, dropout, linear) here.
        self.global_pool, self.head_drop, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate, **dd)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        # Regexes mapping parameter names to a stem group and per-cell block groups
        # (presumably consumed by timm's parameter-grouping helpers). `coarse` is unused here.
        return dict(stem=r'^conv_0|cell_stem_[01]', blocks=r'^cell_(\d+)')

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        # Gradient checkpointing is not implemented; only enable=False is accepted.
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        """Return the final linear classifier module."""
        return self.last_linear

    def reset_classifier(self, num_classes: int, global_pool: str = 'avg', device=None, dtype=None):
        """Rebuild the pooling layer and linear classifier for ``num_classes`` outputs.

        ``head_drop`` is left unchanged. No drop_rate is passed, so create_classifier
        returns a (pool, linear) pair here.
        """
        dd = {'device': device, 'dtype': dtype}
        self.num_classes = num_classes
        self.global_pool, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool, **dd)

    def forward_features(self, x):
        """Run the stem and all cells, returning the ReLU-activated final feature map.

        Each cell receives (output from two steps back, previous output).
        """
        x_conv_0 = self.conv_0(x)
        x_stem_0 = self.cell_stem_0(x_conv_0)
        x_stem_1 = self.cell_stem_1(x_conv_0, x_stem_0)
        x_cell_0 = self.cell_0(x_stem_0, x_stem_1)
        x_cell_1 = self.cell_1(x_stem_1, x_cell_0)
        x_cell_2 = self.cell_2(x_cell_0, x_cell_1)
        x_cell_3 = self.cell_3(x_cell_1, x_cell_2)
        x_cell_4 = self.cell_4(x_cell_2, x_cell_3)
        x_cell_5 = self.cell_5(x_cell_3, x_cell_4)
        x_cell_6 = self.cell_6(x_cell_4, x_cell_5)
        x_cell_7 = self.cell_7(x_cell_5, x_cell_6)
        x_cell_8 = self.cell_8(x_cell_6, x_cell_7)
        x_cell_9 = self.cell_9(x_cell_7, x_cell_8)
        x_cell_10 = self.cell_10(x_cell_8, x_cell_9)
        x_cell_11 = self.cell_11(x_cell_9, x_cell_10)
        x = self.act(x_cell_11)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool and apply head dropout; return logits unless ``pre_logits`` is True."""
        x = self.global_pool(x)
        x = self.head_drop(x)
        return x if pre_logits else self.last_linear(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
def _create_pnasnet(variant, pretrained=False, **kwargs):
    """Build a ``PNASNet5Large`` for ``variant`` through ``build_model_with_cfg``.

    Feature extraction is configured to use forward hooks (``feature_cls='hook'``) with
    ``no_rewrite=True``, because this model cannot be rewritten into a feature-extraction form.
    Remaining keyword arguments are forwarded unchanged.
    """
    hook_feature_cfg = dict(feature_cls='hook', no_rewrite=True)
    return build_model_with_cfg(
        PNASNet5Large, variant, pretrained, feature_cfg=hook_feature_cfg, **kwargs)
# Pretrained configuration for the 'pnasnet5large.tf_in1k' weights.
default_cfgs = generate_default_cfgs({
    'pnasnet5large.tf_in1k': dict(
        hf_hub_id='timm/',
        input_size=(3, 331, 331),
        pool_size=(11, 11),
        crop_pct=0.911,
        interpolation='bicubic',
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
        num_classes=1000,
        first_conv='conv_0.conv',
        classifier='last_linear',
        license='apache-2.0',
    ),
})
@register_model
def pnasnet5large(pretrained=False, **kwargs) -> PNASNet5Large:
    r"""PNASNet-5 model architecture from the
    `"Progressive Neural Architecture Search"
    <https://arxiv.org/abs/1712.00559>`_ paper.

    ``pad_type`` defaults to ``'same'`` (presumably to match the TF-ported ``tf_in1k``
    weights). A caller-supplied ``pad_type`` overrides that default.
    """
    # Dict merge lets kwargs override the default. dict(pad_type=..., **kwargs) would raise
    # TypeError for a duplicate 'pad_type' keyword.
    model_kwargs = {'pad_type': 'same', **kwargs}
    return _create_pnasnet('pnasnet5large', pretrained, **model_kwargs)
|