| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508 |
- """ Inception-V3
- Originally from torchvision Inception3 model
- Licensed BSD-Clause 3 https://github.com/pytorch/vision/blob/master/LICENSE
- """
- from functools import partial
- from typing import Optional, Type
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
- from timm.layers import trunc_normal_, create_classifier, Linear, ConvNormAct
- from ._builder import build_model_with_cfg
- from ._builder import resolve_pretrained_cfg
- from ._manipulate import flatten_modules
- from ._registry import register_model, generate_default_cfgs, register_model_deprecations
# Public API of this module; the model registry also adds each registered
# entrypoint function (e.g. `inception_v3` below) to this list.
__all__ = ['InceptionV3']
class InceptionA(nn.Module):
    """Inception block with 1x1, 5x5, double-3x3 and average-pool branches.

    The branch outputs are concatenated on dim 1, giving
    ``64 + 64 + 96 + pool_features`` channels. The 5x5 and 3x3 convs are given
    "same"-style padding and the pool is stride 1 / padding 1, so spatial size
    is preserved as long as ``conv_block`` honours the padding it is passed.
    """

    def __init__(
            self,
            in_channels: int,
            pool_features: int,
            conv_block: Optional[Type[nn.Module]] = None,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory = dict(device=device, dtype=dtype)
        block = conv_block if conv_block else ConvNormAct

        # 1x1 branch
        self.branch1x1 = block(in_channels, 64, kernel_size=1, **factory)

        # 1x1 -> 5x5 branch
        self.branch5x5_1 = block(in_channels, 48, kernel_size=1, **factory)
        self.branch5x5_2 = block(48, 64, kernel_size=5, padding=2, **factory)

        # 1x1 -> 3x3 -> 3x3 branch
        self.branch3x3dbl_1 = block(in_channels, 64, kernel_size=1, **factory)
        self.branch3x3dbl_2 = block(64, 96, kernel_size=3, padding=1, **factory)
        self.branch3x3dbl_3 = block(96, 96, kernel_size=3, padding=1, **factory)

        # avg-pool -> 1x1 branch
        self.branch_pool = block(in_channels, pool_features, kernel_size=1, **factory)

    def _forward(self, x):
        """Return the four branch outputs as a list, in concatenation order."""
        out_1x1 = self.branch1x1(x)
        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))
        out_3x3dbl = self.branch3x3dbl_3(self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.branch_pool(pooled)
        return [out_1x1, out_5x5, out_3x3dbl, out_pool]

    def forward(self, x):
        """Concatenate the branch outputs along dim 1."""
        return torch.cat(self._forward(x), 1)
class InceptionB(nn.Module):
    """Stride-2 block with 3x3, double-3x3 and max-pool branches.

    The branch outputs are concatenated on dim 1, giving
    ``384 + 96 + in_channels`` channels. The stride-2 convs are not given an
    explicit padding, so their output size depends on ``conv_block``'s default
    padding. The max-pool branch is unpadded.
    """

    def __init__(
            self,
            in_channels: int,
            conv_block: Optional[Type[nn.Module]] = None,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory = dict(device=device, dtype=dtype)
        block = conv_block if conv_block else ConvNormAct

        # single strided 3x3 branch
        self.branch3x3 = block(in_channels, 384, kernel_size=3, stride=2, **factory)

        # 1x1 -> 3x3 -> strided 3x3 branch
        self.branch3x3dbl_1 = block(in_channels, 64, kernel_size=1, **factory)
        self.branch3x3dbl_2 = block(64, 96, kernel_size=3, padding=1, **factory)
        self.branch3x3dbl_3 = block(96, 96, kernel_size=3, stride=2, **factory)

    def _forward(self, x):
        """Return the three branch outputs as a list, in concatenation order."""
        out_3x3 = self.branch3x3(x)
        out_3x3dbl = self.branch3x3dbl_3(self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
        out_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        return [out_3x3, out_3x3dbl, out_pool]

    def forward(self, x):
        """Concatenate the branch outputs along dim 1."""
        return torch.cat(self._forward(x), 1)
class InceptionC(nn.Module):
    """Inception block with factorized 7x7 convolutions.

    The branches are:

    - a 1x1 conv;
    - 1x1 -> 1x7 -> 7x1;
    - 1x1 -> 7x1 -> 1x7 -> 7x1 -> 1x7;
    - 3x3 avg-pool -> 1x1.

    Each branch ends with 192 channels (768 after concatenation on dim 1). The
    factorized branches use ``channels_7x7`` internally.
    """

    def __init__(
            self,
            in_channels: int,
            channels_7x7: int,
            conv_block: Optional[Type[nn.Module]] = None,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory = dict(device=device, dtype=dtype)
        block = conv_block if conv_block else ConvNormAct
        mid = channels_7x7
        # Kernel/padding pairs for the two 1-D factorizations of a 7x7 conv.
        row = dict(kernel_size=(1, 7), padding=(0, 3))
        col = dict(kernel_size=(7, 1), padding=(3, 0))

        self.branch1x1 = block(in_channels, 192, kernel_size=1, **factory)

        self.branch7x7_1 = block(in_channels, mid, kernel_size=1, **factory)
        self.branch7x7_2 = block(mid, mid, **row, **factory)
        self.branch7x7_3 = block(mid, 192, **col, **factory)

        self.branch7x7dbl_1 = block(in_channels, mid, kernel_size=1, **factory)
        self.branch7x7dbl_2 = block(mid, mid, **col, **factory)
        self.branch7x7dbl_3 = block(mid, mid, **row, **factory)
        self.branch7x7dbl_4 = block(mid, mid, **col, **factory)
        self.branch7x7dbl_5 = block(mid, 192, **row, **factory)

        self.branch_pool = block(in_channels, 192, kernel_size=1, **factory)

    def _forward(self, x):
        """Return the four branch outputs as a list, in concatenation order."""
        out_1x1 = self.branch1x1(x)

        out_7x7 = self.branch7x7_3(self.branch7x7_2(self.branch7x7_1(x)))

        out_7x7dbl = self.branch7x7dbl_1(x)
        out_7x7dbl = self.branch7x7dbl_2(out_7x7dbl)
        out_7x7dbl = self.branch7x7dbl_3(out_7x7dbl)
        out_7x7dbl = self.branch7x7dbl_4(out_7x7dbl)
        out_7x7dbl = self.branch7x7dbl_5(out_7x7dbl)

        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.branch_pool(pooled)
        return [out_1x1, out_7x7, out_7x7dbl, out_pool]

    def forward(self, x):
        """Concatenate the branch outputs along dim 1."""
        return torch.cat(self._forward(x), 1)
class InceptionD(nn.Module):
    """Stride-2 block with a 3x3 branch, a factorized-7x7 + 3x3 branch and a max-pool branch.

    The branch outputs are concatenated on dim 1, giving
    ``320 + 192 + in_channels`` channels. The final convs of the two conv
    branches are stride 2 with no explicit padding. The max-pool branch is
    unpadded.
    """

    def __init__(
            self,
            in_channels: int,
            conv_block: Optional[Type[nn.Module]] = None,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory = dict(device=device, dtype=dtype)
        block = conv_block if conv_block else ConvNormAct
        row = dict(kernel_size=(1, 7), padding=(0, 3))
        col = dict(kernel_size=(7, 1), padding=(3, 0))

        # 1x1 -> strided 3x3
        self.branch3x3_1 = block(in_channels, 192, kernel_size=1, **factory)
        self.branch3x3_2 = block(192, 320, kernel_size=3, stride=2, **factory)

        # 1x1 -> 1x7 -> 7x1 -> strided 3x3
        self.branch7x7x3_1 = block(in_channels, 192, kernel_size=1, **factory)
        self.branch7x7x3_2 = block(192, 192, **row, **factory)
        self.branch7x7x3_3 = block(192, 192, **col, **factory)
        self.branch7x7x3_4 = block(192, 192, kernel_size=3, stride=2, **factory)

    def _forward(self, x):
        """Return the three branch outputs as a list, in concatenation order."""
        out_3x3 = self.branch3x3_2(self.branch3x3_1(x))

        out_7x7x3 = self.branch7x7x3_1(x)
        out_7x7x3 = self.branch7x7x3_2(out_7x7x3)
        out_7x7x3 = self.branch7x7x3_3(out_7x7x3)
        out_7x7x3 = self.branch7x7x3_4(out_7x7x3)

        out_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        return [out_3x3, out_7x7x3, out_pool]

    def forward(self, x):
        """Concatenate the branch outputs along dim 1."""
        return torch.cat(self._forward(x), 1)
class InceptionE(nn.Module):
    """Inception block whose 3x3 branches fan out into parallel 1x3 and 3x1 convs.

    The outputs are concatenated on dim 1, giving
    ``320 + (384 + 384) + (384 + 384) + 192 = 2048`` channels. All convs are
    padded to keep the spatial size, and the pool is stride 1 / padding 1.
    """

    def __init__(
            self,
            in_channels: int,
            conv_block: Optional[Type[nn.Module]] = None,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory = dict(device=device, dtype=dtype)
        block = conv_block if conv_block else ConvNormAct
        # Kernel/padding pairs for the two halves of each split 3x3.
        wide = dict(kernel_size=(1, 3), padding=(0, 1))
        tall = dict(kernel_size=(3, 1), padding=(1, 0))

        self.branch1x1 = block(in_channels, 320, kernel_size=1, **factory)

        self.branch3x3_1 = block(in_channels, 384, kernel_size=1, **factory)
        self.branch3x3_2a = block(384, 384, **wide, **factory)
        self.branch3x3_2b = block(384, 384, **tall, **factory)

        self.branch3x3dbl_1 = block(in_channels, 448, kernel_size=1, **factory)
        self.branch3x3dbl_2 = block(448, 384, kernel_size=3, padding=1, **factory)
        self.branch3x3dbl_3a = block(384, 384, **wide, **factory)
        self.branch3x3dbl_3b = block(384, 384, **tall, **factory)

        self.branch_pool = block(in_channels, 192, kernel_size=1, **factory)

    def _forward(self, x):
        """Return the four branch outputs as a list, in concatenation order."""
        out_1x1 = self.branch1x1(x)

        stem_3x3 = self.branch3x3_1(x)
        out_3x3 = torch.cat([self.branch3x3_2a(stem_3x3), self.branch3x3_2b(stem_3x3)], 1)

        stem_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_3x3dbl = torch.cat([self.branch3x3dbl_3a(stem_dbl), self.branch3x3dbl_3b(stem_dbl)], 1)

        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.branch_pool(pooled)
        return [out_1x1, out_3x3, out_3x3dbl, out_pool]

    def forward(self, x):
        """Concatenate the branch outputs along dim 1."""
        return torch.cat(self._forward(x), 1)
class InceptionAux(nn.Module):
    """Auxiliary classifier head for an intermediate feature map.

    The input goes through:

    1. a 5x5 / stride-3 average pool;
    2. a 1x1 conv to 128 channels, then a 5x5 conv to 768 channels;
    3. a global average pool and flatten;
    4. a linear projection to ``num_classes`` logits.
    """

    def __init__(
            self,
            in_channels: int,
            num_classes: int,
            conv_block: Optional[Type[nn.Module]] = None,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        conv_block = conv_block or ConvNormAct
        self.conv0 = conv_block(in_channels, 128, kernel_size=1, **dd)
        self.conv1 = conv_block(128, 768, kernel_size=5, **dd)
        # `stddev` is an init hint read by InceptionV3's weight-init loop, which only
        # inspects nn.Conv2d / nn.Linear modules. A conv block such as ConvNormAct
        # wraps its conv in `.conv`, so a hint left only on the wrapper is never
        # applied. Keep it on the wrapper (backward compat) and also put it on the
        # inner conv when there is one.
        self.conv1.stddev = 0.01
        inner_conv = getattr(self.conv1, 'conv', None)
        if isinstance(inner_conv, nn.Conv2d):
            inner_conv.stddev = 0.01
        self.fc = Linear(768, num_classes, **dd)
        self.fc.stddev = 0.001

    def forward(self, x):
        """Compute auxiliary logits from an intermediate feature map.

        Shape comments assume an N x 768 x 17 x 17 input and a conv_block that
        defaults to no padding (as InceptionV3 configures it).
        """
        # N x 768 x 17 x 17
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # N x 768 x 5 x 5
        x = self.conv0(x)
        # N x 128 x 5 x 5
        x = self.conv1(x)
        # N x 768 x 1 x 1
        # Adaptive pooling to 1x1 so other input sizes still reach the fc as 768 features.
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 768 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 768
        x = self.fc(x)
        # N x num_classes
        return x
class InceptionV3(nn.Module):
    """Inception-V3 classification network.

    Layout, in order:

    - a convolutional stem (``Conv2d_1a_3x3`` .. ``Pool2``);
    - InceptionA blocks ``Mixed_5b``-``Mixed_5d``;
    - InceptionB reduction ``Mixed_6a``;
    - InceptionC blocks ``Mixed_6b``-``Mixed_6e``;
    - an optional auxiliary head (``AuxLogits``) fed by the ``Mixed_6e`` output;
    - InceptionD reduction ``Mixed_7a``;
    - InceptionE blocks ``Mixed_7b`` / ``Mixed_7c``;
    - a pool -> dropout -> linear classifier head.

    With ``aux_logits=True``, ``forward_features`` and ``forward`` return a
    ``(main, aux)`` tuple instead of a single tensor.
    """
    aux_logits: torch.jit.Final[bool]

    def __init__(
            self,
            num_classes: int = 1000,
            in_chans: int = 3,
            drop_rate: float = 0.,
            global_pool: str = 'avg',
            aux_logits: bool = False,
            norm_layer: str = 'batchnorm2d',
            norm_eps: float = 1e-3,
            act_layer: str = 'relu',
            device=None,
            dtype=None,
    ):
        """
        Args:
            num_classes: Output width of the classifier (and of the aux head, if built).
            in_chans: Number of input image channels.
            drop_rate: Dropout rate applied before the final classifier.
            global_pool: Pooling type passed to ``create_classifier``.
            aux_logits: Build the auxiliary classifier on the ``Mixed_6e`` output.
            norm_layer: Normalization layer used in every conv block.
            norm_eps: Epsilon for those normalization layers.
            act_layer: Activation used in every conv block.
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.aux_logits = aux_logits

        # Shared conv -> norm -> act factory. padding=0 is the default ("valid" conv);
        # the Inception blocks pass an explicit padding wherever they need one.
        conv_block = partial(
            ConvNormAct,
            padding=0,
            norm_layer=norm_layer,
            act_layer=act_layer,
            norm_kwargs=dict(eps=norm_eps),
            act_kwargs=dict(inplace=True),
        )

        # Stem
        self.Conv2d_1a_3x3 = conv_block(in_chans, 32, kernel_size=3, stride=2, **dd)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3, **dd)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1, **dd)
        self.Pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1, **dd)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3, **dd)
        self.Pool2 = nn.MaxPool2d(kernel_size=3, stride=2)

        # Inception stages
        self.Mixed_5b = InceptionA(192, pool_features=32, conv_block=conv_block, **dd)
        self.Mixed_5c = InceptionA(256, pool_features=64, conv_block=conv_block, **dd)
        self.Mixed_5d = InceptionA(288, pool_features=64, conv_block=conv_block, **dd)
        self.Mixed_6a = InceptionB(288, conv_block=conv_block, **dd)
        self.Mixed_6b = InceptionC(768, channels_7x7=128, conv_block=conv_block, **dd)
        self.Mixed_6c = InceptionC(768, channels_7x7=160, conv_block=conv_block, **dd)
        self.Mixed_6d = InceptionC(768, channels_7x7=160, conv_block=conv_block, **dd)
        self.Mixed_6e = InceptionC(768, channels_7x7=192, conv_block=conv_block, **dd)
        if aux_logits:
            self.AuxLogits = InceptionAux(768, num_classes, conv_block=conv_block, **dd)
        else:
            self.AuxLogits = None
        self.Mixed_7a = InceptionD(768, conv_block=conv_block, **dd)
        self.Mixed_7b = InceptionE(1280, conv_block=conv_block, **dd)
        self.Mixed_7c = InceptionE(2048, conv_block=conv_block, **dd)
        self.feature_info = [
            dict(num_chs=64, reduction=2, module='Conv2d_2b_3x3'),
            dict(num_chs=192, reduction=4, module='Conv2d_4a_3x3'),
            dict(num_chs=288, reduction=8, module='Mixed_5d'),
            dict(num_chs=768, reduction=16, module='Mixed_6e'),
            dict(num_chs=2048, reduction=32, module='Mixed_7c'),
        ]

        # Classifier head
        self.num_features = self.head_hidden_size = 2048
        self.global_pool, self.head_drop, self.fc = create_classifier(
            self.num_features,
            self.num_classes,
            pool_type=global_pool,
            drop_rate=drop_rate,
            **dd,
        )

        # Weight init. Conv/linear weights get a truncated normal whose std comes from
        # the module's `stddev` attribute when present, else 0.1. BatchNorm gets
        # scale 1 and shift 0.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                trunc_normal_(m.weight, std=getattr(m, 'stddev', 0.1))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return a callable mapping a parameter name to a layer-group index.

        The mapping is:

        - names starting with ``Conv2d_1``/``Conv2d_2`` -> group 0;
        - names starting with ``Conv2d_3``/``Conv2d_4`` -> group 1;
        - other names -> the index of the first entry in the flattened child-module
          list that prefixes the name;
        - no match -> ``inf``.

        ``('fc',)`` is removed from that list. ``coarse`` is accepted but unused.
        """
        module_map = {k: i for i, (k, _) in enumerate(flatten_modules(self.named_children(), prefix=()))}
        module_map.pop(('fc',))

        def _matcher(name):
            if name.startswith(('Conv2d_1', 'Conv2d_2')):
                return 0
            if name.startswith(('Conv2d_3', 'Conv2d_4')):
                return 1
            parts = tuple(name.split('.'))
            for k in module_map:
                if k == parts[:len(k)]:
                    return module_map[k]
            return float('inf')
        return _matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Gradient checkpointing is not implemented; only ``enable=False`` is accepted."""
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        """Return the main classifier module."""
        return self.fc

    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
        """Replace the classifier for a new number of classes.

        When the auxiliary head exists, its final linear layer is rebuilt as
        well. This keeps both outputs of ``forward`` the same width.
        """
        self.num_classes = num_classes
        # Only pool + fc are rebuilt; the head_drop created in __init__ is kept.
        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
        if self.AuxLogits is not None:
            old_fc = self.AuxLogits.fc
            # Keep the new aux layer on the same device/dtype as the one it replaces.
            self.AuxLogits.fc = Linear(
                old_fc.in_features,
                num_classes,
                device=old_fc.weight.device,
                dtype=old_fc.weight.dtype,
            )

    def forward_preaux(self, x):
        """Run the stem and Mixed_5*/Mixed_6* stages (everything before the aux head).

        Shape comments assume a 299x299 input.
        """
        x = self.Conv2d_1a_3x3(x)  # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)  # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)  # N x 64 x 147 x 147
        x = self.Pool1(x)  # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)  # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)  # N x 192 x 71 x 71
        x = self.Pool2(x)  # N x 192 x 35 x 35
        x = self.Mixed_5b(x)  # N x 256 x 35 x 35
        x = self.Mixed_5c(x)  # N x 288 x 35 x 35
        x = self.Mixed_5d(x)  # N x 288 x 35 x 35
        x = self.Mixed_6a(x)  # N x 768 x 17 x 17
        x = self.Mixed_6b(x)  # N x 768 x 17 x 17
        x = self.Mixed_6c(x)  # N x 768 x 17 x 17
        x = self.Mixed_6d(x)  # N x 768 x 17 x 17
        x = self.Mixed_6e(x)  # N x 768 x 17 x 17
        return x

    def forward_postaux(self, x):
        """Run the Mixed_7* stages (everything after the aux head)."""
        x = self.Mixed_7a(x)  # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)  # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)  # N x 2048 x 8 x 8
        return x

    def forward_features(self, x):
        """Return the final feature map, or ``(features, aux_logits)`` when aux_logits is enabled."""
        x = self.forward_preaux(x)
        if self.aux_logits:
            aux = self.AuxLogits(x)
            x = self.forward_postaux(x)
            return x, aux
        x = self.forward_postaux(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool, apply dropout and classify.

        With ``pre_logits=True``, return the pooled features before ``fc``.
        """
        x = self.global_pool(x)
        x = self.head_drop(x)
        if pre_logits:
            return x
        x = self.fc(x)
        return x

    def forward(self, x):
        """Return the logits, or ``(logits, aux_logits)`` when aux_logits is enabled."""
        if self.aux_logits:
            x, aux = self.forward_features(x)
            x = self.forward_head(x)
            return x, aux
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
def _create_inception_v3(variant, pretrained=False, **kwargs):
    """Build an InceptionV3 ``variant``, optionally loading pretrained weights.

    Only the torchvision (``tv_in1k``) checkpoint contains aux-logit weights, so
    pretrained loading is strict only when the model and the checkpoint agree on
    whether the aux head exists.

    Raises:
        ValueError: if ``aux_logits`` and ``features_only`` are both requested.
    """
    pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
    aux_logits = kwargs.get('aux_logits', False)
    has_aux_logits = False
    if pretrained_cfg:
        # only torchvision pretrained weights have aux logits
        has_aux_logits = pretrained_cfg.tag == 'tv_in1k'
    if aux_logits:
        # Explicit check instead of `assert`: an assert (and the pop inside it) would be
        # stripped under `python -O`, silently letting features_only through.
        if kwargs.pop('features_only', False):
            raise ValueError('features_only is not supported when aux_logits=True')
        load_strict = has_aux_logits
    else:
        load_strict = not has_aux_logits
    return build_model_with_cfg(
        InceptionV3,
        variant,
        pretrained,
        pretrained_cfg=pretrained_cfg,
        pretrained_strict=load_strict,
        **kwargs,
    )
def _cfg(url='', **kwargs):
    """Return a pretrained-config dict with Inception-V3 defaults.

    The defaults are a 299x299 input, 8x8 final pool, bicubic interpolation and
    Inception mean/std. Any ``kwargs`` override or extend these entries.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 299, 299),
        pool_size=(8, 8),
        crop_pct=0.875,
        interpolation='bicubic',
        mean=IMAGENET_INCEPTION_MEAN,
        std=IMAGENET_INCEPTION_STD,
        first_conv='Conv2d_1a_3x3.conv',
        classifier='fc',
        license='apache-2.0',
    )
    cfg.update(kwargs)
    return cfg
# Pretrained weight configs keyed '<model>.<tag>'. Each entry starts from the
# _cfg() defaults and overrides only the fields that differ.
default_cfgs = generate_default_cfgs({
    # Original PyTorch (torchvision) weights, ported from TensorFlow but modified.
    'inception_v3.tv_in1k': _cfg(
        # NOTE: this checkpoint includes the aux-logit layer weights; the 'tv_in1k'
        # tag is what _create_inception_v3 checks when choosing strict loading.
        hf_hub_id='timm/',
        url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'),
    # Port of the TensorFlow-Slim weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
    'inception_v3.tf_in1k': _cfg(hf_hub_id='timm/'),
    # Port of the TensorFlow adversarially trained Inception-V3 from
    # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
    'inception_v3.tf_adv_in1k': _cfg(hf_hub_id='timm/'),
    # From the Gluon pretrained models; best performing in terms of accuracy/loss metrics.
    # https://gluon-cv.mxnet.io/model_zoo/classification.html
    'inception_v3.gluon_in1k': _cfg(
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN,  # also works well with inception defaults
        std=IMAGENET_DEFAULT_STD,  # also works well with inception defaults
    )
})
@register_model
def inception_v3(pretrained=False, **kwargs) -> InceptionV3:
    """Inception-V3 model entrypoint.

    ``kwargs`` are forwarded to the InceptionV3 constructor via the model builder.
    """
    return _create_inception_v3('inception_v3', pretrained=pretrained, **kwargs)
# Map legacy model names to their current '<model>.<tag>' equivalents.
register_model_deprecations(__name__, {
    'tf_inception_v3': 'inception_v3.tf_in1k',
    'adv_inception_v3': 'inception_v3.tf_adv_in1k',
    'gluon_inception_v3': 'inception_v3.gluon_in1k',
})
|