| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448 |
- """
- TResNet: High Performance GPU-Dedicated Architecture
- https://arxiv.org/pdf/2003.13630.pdf
- Original model: https://github.com/mrT23/TResNet
- """
- from collections import OrderedDict
- from functools import partial
- from typing import List, Optional, Tuple, Union, Type
- import torch
- import torch.nn as nn
- from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath, calculate_drop_path_rates
- from ._builder import build_model_with_cfg
- from ._features import feature_take_indices
- from ._manipulate import checkpoint, checkpoint_seq
- from ._registry import register_model, generate_default_cfgs, register_model_deprecations
# Exported names; per the registry convention, each @register_model entrypoint is appended here too.
__all__ = ['TResNet']
class BasicBlock(nn.Module):
    """TResNet basic residual block.

    Two 3x3 ``ConvNormAct`` layers (the first optionally strided / anti-aliased, the second
    without activation), an optional SE module, stochastic depth on the residual branch,
    then a ReLU after the shortcut addition.
    """
    expansion = 1

    def __init__(
            self,
            inplanes: int,
            planes: int,
            stride: int = 1,
            downsample: Optional[nn.Module] = None,
            use_se: bool = True,
            aa_layer: Optional[Type[nn.Module]] = None,
            drop_path_rate: float = 0.,
            device=None,
            dtype=None,
    ) -> None:
        factory = dict(device=device, dtype=dtype)
        super().__init__()
        self.downsample = downsample
        self.stride = stride
        # Leaky activation with a very small negative slope inside the block.
        leaky = partial(nn.LeakyReLU, negative_slope=1e-3)

        self.conv1 = ConvNormAct(
            inplanes, planes, kernel_size=3, stride=stride, act_layer=leaky, aa_layer=aa_layer, **factory)
        self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False, **factory)
        self.act = nn.ReLU(inplace=True)

        out_chs = planes * self.expansion
        if use_se:
            self.se = SEModule(out_chs, rd_channels=max(out_chs // 4, 64), **factory)
        else:
            self.se = None
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

    def forward(self, x):
        """Residual forward: ``act(drop_path(se(conv2(conv1(x)))) + shortcut(x))``."""
        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample(x)

        y = self.conv2(self.conv1(x))
        if self.se is not None:
            y = self.se(y)
        y = self.drop_path(y) + shortcut
        return self.act(y)
class Bottleneck(nn.Module):
    """TResNet bottleneck residual block.

    1x1 reduce -> 3x3 (optionally strided / anti-aliased) -> optional SE -> 1x1 expand
    (no activation), stochastic depth on the residual branch, then ReLU after the shortcut
    addition. Note the SE module operates on the ``planes``-wide 3x3 output, before expansion.
    """
    expansion = 4

    def __init__(
            self,
            inplanes: int,
            planes: int,
            stride: int = 1,
            downsample: Optional[nn.Module] = None,
            use_se: bool = True,
            act_layer: Optional[Type[nn.Module]] = None,
            aa_layer: Optional[Type[nn.Module]] = None,
            drop_path_rate: float = 0.,
            device=None,
            dtype=None,
    ) -> None:
        factory = dict(device=device, dtype=dtype)
        super().__init__()
        self.downsample = downsample
        self.stride = stride
        if not act_layer:
            # Default: leaky activation with a very small negative slope.
            act_layer = partial(nn.LeakyReLU, negative_slope=1e-3)

        out_chs = planes * self.expansion
        self.conv1 = ConvNormAct(inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer, **factory)
        self.conv2 = ConvNormAct(
            planes, planes, kernel_size=3, stride=stride, act_layer=act_layer, aa_layer=aa_layer, **factory)
        if use_se:
            self.se = SEModule(planes, rd_channels=max(out_chs // 8, 64), **factory)
        else:
            self.se = None
        self.conv3 = ConvNormAct(planes, out_chs, kernel_size=1, stride=1, apply_act=False, **factory)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Residual forward: ``act(drop_path(conv3(se(conv2(conv1(x))))) + shortcut(x))``."""
        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample(x)

        y = self.conv2(self.conv1(x))
        if self.se is not None:
            y = self.se(y)
        y = self.drop_path(self.conv3(y)) + shortcut
        return self.act(y)
class TResNet(nn.Module):
    """TResNet image classification model.

    ``body`` is ``SpaceToDepth`` -> stem ``conv1`` (3x3 ``ConvNormAct`` over ``in_chans * 16``
    channels) -> four residual stages ``layer1``..``layer4``. The v1 layout uses ``BasicBlock``
    for the first two stages and ``Bottleneck`` for the last two; v2 uses ``Bottleneck``
    throughout and rounds the base width down to a multiple of 8. ``layer4`` has no SE.

    Args:
        layers: Number of blocks in each of the four stages.
        in_chans: Number of input image channels.
        num_classes: Number of classifier outputs.
        width_factor: Multiplier on the base width of 64 channels.
        v2: Build the TResNet-v2 layout.
        global_pool: Pooling type passed to ``ClassifierHead``.
        drop_rate: Passed to ``ClassifierHead`` as its ``drop_rate``.
        drop_path_rate: Stochastic-depth rate, split across stages by ``calculate_drop_path_rates``.
        device: Device used when creating parameters.
        dtype: Dtype used when creating parameters.
    """

    def __init__(
            self,
            layers: List[int],
            in_chans: int = 3,
            num_classes: int = 1000,
            width_factor: float = 1.0,
            v2: bool = False,
            global_pool: str = 'fast',
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            device=None,
            dtype=None,
    ) -> None:
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.drop_rate = drop_rate
        self.grad_checkpointing = False
        aa_layer = BlurPool2d
        # Stem uses nn.LeakyReLU with its default slope; the blocks pick their own slope (1e-3).
        act_layer = nn.LeakyReLU

        # TResnet stages
        self.inplanes = int(64 * width_factor)
        self.planes = int(64 * width_factor)
        if v2:
            # v2 keeps the base width a multiple of 8.
            self.inplanes = self.inplanes // 8 * 8
            self.planes = self.planes // 8 * 8

        dpr = calculate_drop_path_rates(drop_path_rate, layers, stagewise=True)
        # SpaceToDepth folds 4x4 pixel blocks into channels, hence in_chans * 16 stem inputs.
        conv1 = ConvNormAct(in_chans * 16, self.planes, stride=1, kernel_size=3, act_layer=act_layer, **dd)
        early_block = Bottleneck if v2 else BasicBlock
        layer1 = self._make_layer(
            early_block,
            self.planes, layers[0], stride=1, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[0], **dd)
        layer2 = self._make_layer(
            early_block,
            self.planes * 2, layers[1], stride=2, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[1], **dd)
        layer3 = self._make_layer(
            Bottleneck,
            self.planes * 4, layers[2], stride=2, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[2], **dd)
        layer4 = self._make_layer(
            Bottleneck,
            self.planes * 8, layers[3], stride=2, use_se=False, aa_layer=aa_layer, drop_path_rate=dpr[3], **dd)

        # body
        self.body = nn.Sequential(OrderedDict([
            ('s2d', SpaceToDepth()),
            ('conv1', conv1),
            ('layer1', layer1),
            ('layer2', layer2),
            ('layer3', layer3),
            ('layer4', layer4),
        ]))

        early_expansion = early_block.expansion
        self.feature_info = [
            # The stem conv runs after the 4x4 SpaceToDepth, so its output is already at stride 4
            # (layer1 keeps stride 1, and the 224 -> 7x7 default pool implies a total stride of 32).
            dict(num_chs=self.planes, reduction=4, module='body.conv1'),
            dict(num_chs=self.planes * early_expansion, reduction=4, module='body.layer1'),
            dict(num_chs=self.planes * 2 * early_expansion, reduction=8, module='body.layer2'),
            dict(num_chs=self.planes * 4 * Bottleneck.expansion, reduction=16, module='body.layer3'),
            dict(num_chs=self.planes * 8 * Bottleneck.expansion, reduction=32, module='body.layer4'),
        ]

        # head
        self.num_features = self.head_hidden_size = (self.planes * 8) * Bottleneck.expansion
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate, **dd)

        # model initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0., 0.01)

        # residual connections special initialization: zero the last BN scale of every block so
        # each residual branch starts out contributing nothing.
        for m in self.modules():
            if isinstance(m, BasicBlock):
                nn.init.zeros_(m.conv2.bn.weight)
            if isinstance(m, Bottleneck):
                nn.init.zeros_(m.conv3.bn.weight)

    def _make_layer(
            self,
            block,
            planes,
            blocks,
            stride=1,
            use_se=True,
            aa_layer=None,
            drop_path_rate=0.,
            device=None,
            dtype=None,
    ):
        """Build one residual stage and advance ``self.inplanes`` to its output width.

        Args:
            block: Block class (``BasicBlock`` or ``Bottleneck``).
            planes: Base width of the stage; output width is ``planes * block.expansion``.
            blocks: Number of blocks in the stage.
            stride: Stride of the first block; later blocks use stride 1.
            use_se: Whether the blocks include an SE module.
            aa_layer: Anti-aliasing layer forwarded to the blocks.
            drop_path_rate: A single rate for every block, or a list with one rate per block.
            device: Device used when creating parameters.
            dtype: Dtype used when creating parameters.

        Returns:
            ``nn.Sequential`` of the stage's blocks.
        """
        dd = {'device': device, 'dtype': dtype}
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            layers = []
            if stride == 2:
                # avg pooling before 1x1 conv
                layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
            layers += [ConvNormAct(
                self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False, **dd)]
            downsample = nn.Sequential(*layers)

        layers = []
        for i in range(blocks):
            layers.append(block(
                self.inplanes,
                planes,
                stride=stride if i == 0 else 1,
                downsample=downsample if i == 0 else None,
                use_se=use_se,
                aa_layer=aa_layer,
                drop_path_rate=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate,
                **dd,
            ))
            self.inplanes = planes * block.expansion
        return nn.Sequential(*layers)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return regexes grouping parameters into stem / blocks (per stage if ``coarse``)."""
        matcher = dict(stem=r'^body\.conv1', blocks=r'^body\.layer(\d+)' if coarse else r'^body\.layer(\d+)\.(\d+)')
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable or disable activation checkpointing in the forward passes."""
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        """Return the final classifier module (``head.fc``)."""
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        """Replace the classifier with one producing ``num_classes`` outputs."""
        self.num_classes = num_classes
        self.head.reset(num_classes, pool_type=global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Feature index 0 is the stem (``body.conv1``) output; indices 1-4 are ``layer1``..``layer4``.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Unused; TResNet has no final norm layer
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs (only 'NCHW' is supported)
            intermediates_only: Only return intermediate features

        Returns:
            The list of selected intermediate feature maps if ``intermediates_only``, otherwise a
            tuple ``(x, intermediates)`` where ``x`` is the output of the last stage executed.
        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        # body index of each feature: 0 is s2d (not a feature), 1 is conv1, 2-5 are layer1-4
        stage_ends = [1, 2, 3, 4, 5]
        take_indices, max_index = feature_take_indices(len(stage_ends), indices)
        take_indices = [stage_ends[i] for i in take_indices]
        max_index = stage_ends[max_index]

        # forward pass
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.body
        else:
            stages = self.body[:max_index + 1]

        for feat_idx, stage in enumerate(stages):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(stage, x)
            else:
                x = stage(x)
            if feat_idx in take_indices:
                intermediates.append(x)

        if intermediates_only:
            return intermediates

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.

        Truncates ``body`` after the deepest requested feature and, if ``prune_head``, removes
        the classifier. ``prune_norm`` is unused (there is no final norm).

        Returns:
            The selected feature indices.
        """
        stage_ends = [1, 2, 3, 4, 5]
        take_indices, max_index = feature_take_indices(len(stage_ends), indices)
        max_index = stage_ends[max_index]
        self.body = self.body[:max_index + 1]  # truncate blocks w/ stem as idx 0
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        """Run ``body`` (stem + four stages) and return the final feature map."""
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = self.body.s2d(x)
            x = self.body.conv1(x)
            x = checkpoint_seq([
                self.body.layer1,
                self.body.layer2,
                self.body.layer3,
                self.body.layer4],
                x, flatten=True)
        else:
            x = self.body(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool (and classify unless ``pre_logits``) the final feature map."""
        return self.head(x, pre_logits=pre_logits)

    def forward(self, x):
        """Full forward pass: features followed by the classifier head."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
def checkpoint_filter_fn(state_dict, model):
    """Remap an original TResNet checkpoint onto this implementation's parameter names.

    The original checkpoints index conv/norm pairs positionally (``convN.0`` / ``convN.1``,
    ``convN.0.0`` / ``convN.0.1`` and ``downsample.K.0`` / ``downsample.K.1``) and use InPlaceABN
    norms; here they become ``.conv`` / ``.bn`` of ``ConvNormAct``.

    Args:
        state_dict: Loaded checkpoint, optionally wrapped under ``'model'`` and/or ``'state_dict'``.
        model: Target model (unused; kept for the ``pretrained_filter_fn`` signature).

    Returns:
        The (unwrapped) state dict, converted unless it is already in this module's format.
    """
    import re

    # Unwrap BEFORE testing for the converted format: otherwise a wrapped checkpoint that is
    # already converted gets re-processed and every BN weight is rewritten as |w| + 1e-5.
    state_dict = state_dict.get('model', state_dict)
    state_dict = state_dict.get('state_dict', state_dict)
    if 'body.conv1.conv.weight' in state_dict:
        return state_dict

    out_dict = {}
    for k, v in state_dict.items():
        # Two-level (conv + anti-alias) forms must be rewritten before the single-level ones.
        k = re.sub(r'conv(\d+)\.0\.0', lambda x: f'conv{int(x.group(1))}.conv', k)
        k = re.sub(r'conv(\d+)\.0\.1', lambda x: f'conv{int(x.group(1))}.bn', k)
        k = re.sub(r'conv(\d+)\.0', lambda x: f'conv{int(x.group(1))}.conv', k)
        k = re.sub(r'conv(\d+)\.1', lambda x: f'conv{int(x.group(1))}.bn', k)
        k = re.sub(r'downsample\.(\d+)\.0', lambda x: f'downsample.{int(x.group(1))}.conv', k)
        k = re.sub(r'downsample\.(\d+)\.1', lambda x: f'downsample.{int(x.group(1))}.bn', k)
        if k.endswith('bn.weight'):
            # convert weight from inplace_abn to batchnorm: bake |w| + eps into a plain BN scale
            v = v.abs().add(1e-5)
        out_dict[k] = v
    return out_dict
def _create_tresnet(variant, pretrained=False, **kwargs):
    """Instantiate a ``TResNet`` variant through ``build_model_with_cfg``.

    Pretrained weights are passed through ``checkpoint_filter_fn``; feature extraction
    defaults to the four residual stages on a flattened ``body``.
    """
    feature_cfg = dict(out_indices=(1, 2, 3, 4), flatten_sequential=True)
    return build_model_with_cfg(
        TResNet, variant, pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=feature_cfg,
        **kwargs,
    )
- def _cfg(url='', **kwargs):
- return {
- 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
- 'crop_pct': 0.875, 'interpolation': 'bilinear',
- 'mean': (0., 0., 0.), 'std': (1., 1., 1.),
- 'first_conv': 'body.conv1.conv', 'classifier': 'head.fc',
- 'license': 'apache-2.0',
- **kwargs
- }
# Entry order matters: the first tag listed for an architecture is its default pretrained tag.
default_cfgs = generate_default_cfgs({
    # 224x224 weights
    'tresnet_m.miil_in21k_ft_in1k': _cfg(hf_hub_id='timm/'),
    'tresnet_m.miil_in21k': _cfg(hf_hub_id='timm/', num_classes=11221),  # in21k head: 11221 classes
    'tresnet_m.miil_in1k': _cfg(hf_hub_id='timm/'),
    'tresnet_l.miil_in1k': _cfg(hf_hub_id='timm/'),
    'tresnet_xl.miil_in1k': _cfg(hf_hub_id='timm/'),

    # 448x448 weights (pool grid scales to 14x14)
    'tresnet_m.miil_in1k_448': _cfg(input_size=(3, 448, 448), pool_size=(14, 14), hf_hub_id='timm/'),
    'tresnet_l.miil_in1k_448': _cfg(input_size=(3, 448, 448), pool_size=(14, 14), hf_hub_id='timm/'),
    'tresnet_xl.miil_in1k_448': _cfg(input_size=(3, 448, 448), pool_size=(14, 14), hf_hub_id='timm/'),

    # TResNet-v2
    'tresnet_v2_l.miil_in21k_ft_in1k': _cfg(hf_hub_id='timm/'),
    'tresnet_v2_l.miil_in21k': _cfg(hf_hub_id='timm/', num_classes=11221),
})
@register_model
def tresnet_m(pretrained=False, **kwargs) -> TResNet:
    """TResNet-M: stage depths [3, 4, 11, 3] at the default width."""
    model_kwargs = {'layers': [3, 4, 11, 3], **kwargs}
    return _create_tresnet('tresnet_m', pretrained=pretrained, **model_kwargs)
@register_model
def tresnet_l(pretrained=False, **kwargs) -> TResNet:
    """TResNet-L: stage depths [4, 5, 18, 3], width factor 1.2."""
    model_kwargs = {'layers': [4, 5, 18, 3], 'width_factor': 1.2, **kwargs}
    return _create_tresnet('tresnet_l', pretrained=pretrained, **model_kwargs)
@register_model
def tresnet_xl(pretrained=False, **kwargs) -> TResNet:
    """TResNet-XL: stage depths [4, 5, 24, 3], width factor 1.3."""
    model_kwargs = {'layers': [4, 5, 24, 3], 'width_factor': 1.3, **kwargs}
    return _create_tresnet('tresnet_xl', pretrained=pretrained, **model_kwargs)
@register_model
def tresnet_v2_l(pretrained=False, **kwargs) -> TResNet:
    """TResNet-v2-L: Bottleneck-only v2 layout, stage depths [3, 4, 23, 3], width factor 1.0."""
    model_kwargs = {'layers': [3, 4, 23, 3], 'width_factor': 1.0, 'v2': True, **kwargs}
    return _create_tresnet('tresnet_v2_l', pretrained=pretrained, **model_kwargs)
# Map legacy flat model names to their `architecture.pretrained_tag` replacements.
register_model_deprecations(
    __name__,
    dict(
        tresnet_m_miil_in21k='tresnet_m.miil_in21k',
        tresnet_m_448='tresnet_m.miil_in1k_448',
        tresnet_l_448='tresnet_l.miil_in1k_448',
        tresnet_xl_448='tresnet_xl.miil_in1k_448',
    ),
)
|