| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189 |
- """ Normalization Free Nets. NFNet, NF-RegNet, NF-ResNet (pre-activation) Models
- Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
- - https://arxiv.org/abs/2101.08692
- Paper: `High-Performance Large-Scale Image Recognition Without Normalization`
- - https://arxiv.org/abs/2102.06171
- Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
- Status:
- * These models are a work in progress, experiments ongoing.
- * Pretrained weights for two models so far, more to come.
- * Model details updated to closer match official JAX code now that it's released
- * NF-ResNet, NF-RegNet-B, and NFNet-F models supported
- Hacked together by / copyright Ross Wightman, 2021.
- """
- from collections import OrderedDict
- from dataclasses import dataclass, replace
- from functools import partial
- from typing import Any, Callable, Dict, Optional, Tuple
- import torch
- import torch.nn as nn
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.layers import ClassifierHead, DropPath, calculate_drop_path_rates, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame, \
- get_act_layer, get_act_fn, get_attn, make_divisible
- from ._builder import build_model_with_cfg
- from ._features_fx import register_notrace_module
- from ._manipulate import checkpoint_seq
- from ._registry import generate_default_cfgs, register_model
- __all__ = ['NormFreeNet', 'NfCfg'] # model_registry will add each entrypoint fn to this
- @dataclass
- class NfCfg:
- """Configuration for Normalization-Free Networks."""
- depths: Tuple[int, int, int, int]
- channels: Tuple[int, int, int, int]
- alpha: float = 0.2
- stem_type: str = '3x3'
- stem_chs: Optional[int] = None
- group_size: Optional[int] = None
- attn_layer: Optional[str] = None
- attn_kwargs: Optional[Dict[str, Any]] = None
- attn_gain: float = 2.0 # NF correction gain to apply if attn layer is used
- width_factor: float = 1.0
- bottle_ratio: float = 0.5
- num_features: int = 0 # num out_channels for final conv, no final_conv if 0
- ch_div: int = 8 # round channels % 8 == 0 to keep tensor-core use optimal
- reg: bool = False # enables EfficientNet-like options used in RegNet variants, expand from in_chs, se in middle
- extra_conv: bool = False # extra 3x3 bottleneck convolution for NFNet models
- gamma_in_act: bool = False
- same_padding: bool = False
- std_conv_eps: float = 1e-5
- skipinit: bool = False # disabled by default, non-trivial performance impact
- zero_init_fc: bool = False
- act_layer: str = 'silu'
- class GammaAct(nn.Module):
- """Activation function with gamma scaling factor."""
- def __init__(self, act_type: str = 'relu', gamma: float = 1.0, inplace: bool = False):
- """Initialize GammaAct.
- Args:
- act_type: Type of activation function.
- gamma: Scaling factor for activation output.
- inplace: Whether to perform activation in-place.
- """
- super().__init__()
- self.act_fn = get_act_fn(act_type)
- self.gamma = gamma
- self.inplace = inplace
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Forward pass.
- Args:
- x: Input tensor.
- Returns:
- Scaled activation output.
- """
- return self.act_fn(x, inplace=self.inplace).mul_(self.gamma)
- def act_with_gamma(act_type: str, gamma: float = 1.) -> Callable:
- """Create activation function factory with gamma scaling.
- Args:
- act_type: Type of activation function.
- gamma: Scaling factor for activation output.
- Returns:
- Activation function factory.
- """
- def _create(inplace: bool = False) -> GammaAct:
- return GammaAct(act_type, gamma=gamma, inplace=inplace)
- return _create
- class DownsampleAvg(nn.Module):
- """AvgPool downsampling as in 'D' ResNet variants with dilation support."""
- def __init__(
- self,
- in_chs: int,
- out_chs: int,
- stride: int = 1,
- dilation: int = 1,
- first_dilation: Optional[int] = None,
- conv_layer: Callable = ScaledStdConv2d,
- device=None,
- dtype=None,
- ):
- """Initialize DownsampleAvg.
- Args:
- in_chs: Input channels.
- out_chs: Output channels.
- stride: Stride for downsampling.
- dilation: Dilation rate.
- first_dilation: First dilation rate (unused).
- conv_layer: Convolution layer type.
- """
- super().__init__()
- avg_stride = stride if dilation == 1 else 1
- if stride > 1 or dilation > 1:
- avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
- self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
- else:
- self.pool = nn.Identity()
- self.conv = conv_layer(in_chs, out_chs, 1, stride=1, device=device, dtype=dtype)
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Forward pass.
- Args:
- x: Input tensor.
- Returns:
- Downsampled tensor.
- """
- return self.conv(self.pool(x))
- @register_notrace_module # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301
- class NormFreeBlock(nn.Module):
- """Normalization-Free pre-activation block.
- """
- def __init__(
- self,
- in_chs: int,
- out_chs: Optional[int] = None,
- stride: int = 1,
- dilation: int = 1,
- first_dilation: Optional[int] = None,
- alpha: float = 1.0,
- beta: float = 1.0,
- bottle_ratio: float = 0.25,
- group_size: Optional[int] = None,
- ch_div: int = 1,
- reg: bool = True,
- extra_conv: bool = False,
- skipinit: bool = False,
- attn_layer: Optional[Callable] = None,
- attn_gain: float = 2.0,
- act_layer: Optional[Callable] = None,
- conv_layer: Callable = ScaledStdConv2d,
- drop_path_rate: float = 0.,
- device=None,
- dtype=None,
- ):
- """Initialize NormFreeBlock.
- Args:
- in_chs: Input channels.
- out_chs: Output channels.
- stride: Stride for convolution.
- dilation: Dilation rate.
- first_dilation: First dilation rate.
- alpha: Alpha scaling factor for residual.
- beta: Beta scaling factor for pre-activation.
- bottle_ratio: Bottleneck ratio.
- group_size: Group convolution size.
- ch_div: Channel divisor for rounding.
- reg: Use RegNet-style configuration.
- extra_conv: Add extra 3x3 convolution.
- skipinit: Use skipinit initialization.
- attn_layer: Attention layer type.
- attn_gain: Attention gain factor.
- act_layer: Activation layer type.
- conv_layer: Convolution layer type.
- drop_path_rate: Stochastic depth drop rate.
- """
- dd = {'device': device, 'dtype': dtype}
- super().__init__()
- first_dilation = first_dilation or dilation
- out_chs = out_chs or in_chs
- # RegNet variants scale bottleneck from in_chs, otherwise scale from out_chs like ResNet
- mid_chs = make_divisible(in_chs * bottle_ratio if reg else out_chs * bottle_ratio, ch_div)
- groups = 1 if not group_size else mid_chs // group_size
- if group_size and group_size % ch_div == 0:
- mid_chs = group_size * groups # correct mid_chs if group_size divisible by ch_div, otherwise error
- self.alpha = alpha
- self.beta = beta
- self.attn_gain = attn_gain
- if in_chs != out_chs or stride != 1 or dilation != first_dilation:
- self.downsample = DownsampleAvg(
- in_chs,
- out_chs,
- stride=stride,
- dilation=dilation,
- first_dilation=first_dilation,
- conv_layer=conv_layer,
- **dd,
- )
- else:
- self.downsample = None
- self.act1 = act_layer()
- self.conv1 = conv_layer(in_chs, mid_chs, 1, **dd)
- self.act2 = act_layer(inplace=True)
- self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups, **dd)
- if extra_conv:
- self.act2b = act_layer(inplace=True)
- self.conv2b = conv_layer(mid_chs, mid_chs, 3, stride=1, dilation=dilation, groups=groups, **dd)
- else:
- self.act2b = None
- self.conv2b = None
- if reg and attn_layer is not None:
- self.attn = attn_layer(mid_chs, **dd) # RegNet blocks apply attn btw conv2 & 3
- else:
- self.attn = None
- self.act3 = act_layer()
- self.conv3 = conv_layer(mid_chs, out_chs, 1, gain_init=1. if skipinit else 0., **dd)
- if not reg and attn_layer is not None:
- self.attn_last = attn_layer(out_chs, **dd) # ResNet blocks apply attn after conv3
- else:
- self.attn_last = None
- self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
- self.skipinit_gain = nn.Parameter(torch.tensor(0., **dd)) if skipinit else None
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Forward pass.
- Args:
- x: Input tensor.
- Returns:
- Output tensor.
- """
- out = self.act1(x) * self.beta
- # shortcut branch
- shortcut = x
- if self.downsample is not None:
- shortcut = self.downsample(out)
- # residual branch
- out = self.conv1(out)
- out = self.conv2(self.act2(out))
- if self.conv2b is not None:
- out = self.conv2b(self.act2b(out))
- if self.attn is not None:
- out = self.attn_gain * self.attn(out)
- out = self.conv3(self.act3(out))
- if self.attn_last is not None:
- out = self.attn_gain * self.attn_last(out)
- out = self.drop_path(out)
- if self.skipinit_gain is not None:
- out.mul_(self.skipinit_gain)
- out = out * self.alpha + shortcut
- return out
- def create_stem(
- in_chs: int,
- out_chs: int,
- stem_type: str = '',
- conv_layer: Optional[Callable] = None,
- act_layer: Optional[Callable] = None,
- preact_feature: bool = True,
- device=None,
- dtype=None,
- ) -> Tuple[nn.Sequential, int, Dict[str, Any]]:
- """Create stem module for NFNet models.
- Args:
- in_chs: Input channels.
- out_chs: Output channels.
- stem_type: Type of stem ('', 'deep', 'deep_tiered', 'deep_quad', '3x3', '7x7', etc.).
- conv_layer: Convolution layer type.
- act_layer: Activation layer type.
- preact_feature: Use pre-activation feature.
- Returns:
- Tuple of (stem_module, stem_stride, stem_feature_info).
- """
- dd = {'device': device, 'dtype': dtype}
- stem_stride = 2
- stem_feature = dict(num_chs=out_chs, reduction=2, module='stem.conv')
- stem = OrderedDict()
- assert stem_type in ('', 'deep', 'deep_tiered', 'deep_quad', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
- if 'deep' in stem_type:
- if 'quad' in stem_type:
- # 4 deep conv stack as in NFNet-F models
- assert 'pool' not in stem_type
- stem_chs = (out_chs // 8, out_chs // 4, out_chs // 2, out_chs)
- strides = (2, 1, 1, 2)
- stem_stride = 4
- stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.conv3')
- else:
- if 'tiered' in stem_type:
- stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs) # 'T' resnets in resnet.py
- else:
- stem_chs = (out_chs // 2, out_chs // 2, out_chs) # 'D' ResNets
- strides = (2, 1, 1)
- stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.conv2')
- last_idx = len(stem_chs) - 1
- for i, (c, s) in enumerate(zip(stem_chs, strides)):
- stem[f'conv{i + 1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s, **dd)
- if i != last_idx:
- stem[f'act{i + 2}'] = act_layer(inplace=True)
- in_chs = c
- elif '3x3' in stem_type:
- # 3x3 stem conv as in RegNet
- stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2, **dd)
- else:
- # 7x7 stem conv as in ResNet
- stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2, **dd)
- if 'pool' in stem_type:
- stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1)
- stem_stride = 4
- return nn.Sequential(stem), stem_stride, stem_feature
- # from https://github.com/deepmind/deepmind-research/tree/master/nfnets
- _nonlin_gamma = dict(
- identity=1.0,
- celu=1.270926833152771,
- elu=1.2716004848480225,
- gelu=1.7015043497085571,
- leaky_relu=1.70590341091156,
- log_sigmoid=1.9193484783172607,
- log_softmax=1.0002083778381348,
- relu=1.7139588594436646,
- relu6=1.7131484746932983,
- selu=1.0008515119552612,
- sigmoid=4.803835391998291,
- silu=1.7881293296813965,
- softsign=2.338853120803833,
- softplus=1.9203323125839233,
- tanh=1.5939117670059204,
- )
- class NormFreeNet(nn.Module):
- """ Normalization-Free Network
- As described in :
- `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
- - https://arxiv.org/abs/2101.08692
- and
- `High-Performance Large-Scale Image Recognition Without Normalization` - https://arxiv.org/abs/2102.06171
- This model aims to cover both the NFRegNet-Bx models as detailed in the paper's code snippets and
- the (preact) ResNet models described earlier in the paper.
- There are a few differences:
- * channels are rounded to be divisible by 8 by default (keep tensor core kernels happy),
- this changes channel dim and param counts slightly from the paper models
- * activation correcting gamma constants are moved into the ScaledStdConv as it has less performance
- impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl.
- * a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but
- apply it in each activation. This is slightly slower, numerically different, but matches official impl.
- * skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput
- for what it is/does. Approx 8-10% throughput loss.
- """
- def __init__(
- self,
- cfg: NfCfg,
- num_classes: int = 1000,
- in_chans: int = 3,
- global_pool: str = 'avg',
- output_stride: int = 32,
- drop_rate: float = 0.,
- drop_path_rate: float = 0.,
- device=None,
- dtype=None,
- **kwargs: Any,
- ):
- """
- Args:
- cfg: Model architecture configuration.
- num_classes: Number of classifier classes.
- in_chans: Number of input channels.
- global_pool: Global pooling type.
- output_stride: Output stride of network, one of (8, 16, 32).
- drop_rate: Dropout rate.
- drop_path_rate: Stochastic depth drop-path rate.
- **kwargs: Extra kwargs overlayed onto cfg.
- """
- super().__init__()
- dd = {'device': device, 'dtype': dtype}
- self.num_classes = num_classes
- self.in_chans = in_chans
- self.drop_rate = drop_rate
- self.grad_checkpointing = False
- cfg = replace(cfg, **kwargs)
- assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
- conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d
- if cfg.gamma_in_act:
- act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer])
- conv_layer = partial(conv_layer, eps=cfg.std_conv_eps)
- else:
- act_layer = get_act_layer(cfg.act_layer)
- conv_layer = partial(conv_layer, gamma=_nonlin_gamma[cfg.act_layer], eps=cfg.std_conv_eps)
- attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
- stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div)
- self.stem, stem_stride, stem_feat = create_stem(
- in_chans,
- stem_chs,
- cfg.stem_type,
- conv_layer=conv_layer,
- act_layer=act_layer,
- **dd,
- )
- self.feature_info = [stem_feat]
- drop_path_rates = calculate_drop_path_rates(drop_path_rate, cfg.depths, stagewise=True)
- prev_chs = stem_chs
- net_stride = stem_stride
- dilation = 1
- expected_var = 1.0
- stages = []
- for stage_idx, stage_depth in enumerate(cfg.depths):
- stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
- if net_stride >= output_stride and stride > 1:
- dilation *= stride
- stride = 1
- net_stride *= stride
- first_dilation = 1 if dilation in (1, 2) else 2
- blocks = []
- for block_idx in range(cfg.depths[stage_idx]):
- first_block = block_idx == 0 and stage_idx == 0
- out_chs = make_divisible(cfg.channels[stage_idx] * cfg.width_factor, cfg.ch_div)
- blocks += [NormFreeBlock(
- in_chs=prev_chs, out_chs=out_chs,
- alpha=cfg.alpha,
- beta=1. / expected_var ** 0.5,
- stride=stride if block_idx == 0 else 1,
- dilation=dilation,
- first_dilation=first_dilation,
- group_size=cfg.group_size,
- bottle_ratio=1. if cfg.reg and first_block else cfg.bottle_ratio,
- ch_div=cfg.ch_div,
- reg=cfg.reg,
- extra_conv=cfg.extra_conv,
- skipinit=cfg.skipinit,
- attn_layer=attn_layer,
- attn_gain=cfg.attn_gain,
- act_layer=act_layer,
- conv_layer=conv_layer,
- drop_path_rate=drop_path_rates[stage_idx][block_idx],
- **dd,
- )]
- if block_idx == 0:
- expected_var = 1. # expected var is reset after first block of each stage
- expected_var += cfg.alpha ** 2 # Even if reset occurs, increment expected variance
- first_dilation = dilation
- prev_chs = out_chs
- self.feature_info += [dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')]
- stages += [nn.Sequential(*blocks)]
- self.stages = nn.Sequential(*stages)
- if cfg.num_features:
- # The paper NFRegNet models have an EfficientNet-like final head convolution.
- self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div)
- self.final_conv = conv_layer(prev_chs, self.num_features, 1, **dd)
- self.feature_info[-1] = dict(num_chs=self.num_features, reduction=net_stride, module=f'final_conv')
- else:
- self.num_features = prev_chs
- self.final_conv = nn.Identity()
- self.final_act = act_layer(inplace=cfg.num_features > 0)
- self.head_hidden_size = self.num_features
- self.head = ClassifierHead(
- self.num_features,
- num_classes,
- pool_type=global_pool,
- drop_rate=self.drop_rate,
- **dd,
- )
- for n, m in self.named_modules():
- if 'fc' in n and isinstance(m, nn.Linear):
- if cfg.zero_init_fc:
- nn.init.zeros_(m.weight)
- else:
- nn.init.normal_(m.weight, 0., .01)
- if m.bias is not None:
- nn.init.zeros_(m.bias)
- elif isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='linear')
- if m.bias is not None:
- nn.init.zeros_(m.bias)
- @torch.jit.ignore
- def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
- """Group parameters for optimization."""
- matcher = dict(
- stem=r'^stem',
- blocks=[
- (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
- (r'^final_conv', (99999,))
- ]
- )
- return matcher
- @torch.jit.ignore
- def set_grad_checkpointing(self, enable: bool = True) -> None:
- """Enable or disable gradient checkpointing."""
- self.grad_checkpointing = enable
- @torch.jit.ignore
- def get_classifier(self) -> nn.Module:
- """Get the classifier head."""
- return self.head.fc
- def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
- """Reset the classifier head.
- Args:
- num_classes: Number of classes for new classifier.
- global_pool: Global pooling type.
- """
- self.num_classes = num_classes
- self.head.reset(num_classes, global_pool)
- def forward_features(self, x: torch.Tensor) -> torch.Tensor:
- """Forward pass through feature extraction layers.
- Args:
- x: Input tensor.
- Returns:
- Feature tensor.
- """
- x = self.stem(x)
- if self.grad_checkpointing and not torch.jit.is_scripting():
- x = checkpoint_seq(self.stages, x)
- else:
- x = self.stages(x)
- x = self.final_conv(x)
- x = self.final_act(x)
- return x
- def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
- """Forward pass through classifier head.
- Args:
- x: Input features.
- pre_logits: Return features before final linear layer.
- Returns:
- Classification logits or features.
- """
- return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Forward pass.
- Args:
- x: Input tensor.
- Returns:
- Output logits.
- """
- x = self.forward_features(x)
- x = self.forward_head(x)
- return x
- def _nfres_cfg(
- depths: Tuple[int, ...],
- channels: Tuple[int, ...] = (256, 512, 1024, 2048),
- group_size: Optional[int] = None,
- act_layer: str = 'relu',
- attn_layer: Optional[str] = None,
- attn_kwargs: Optional[Dict[str, Any]] = None,
- ) -> NfCfg:
- """Create NFNet ResNet configuration.
- Args:
- depths: Number of blocks in each stage.
- channels: Channel dimensions for each stage.
- group_size: Group convolution size.
- act_layer: Activation layer type.
- attn_layer: Attention layer type.
- attn_kwargs: Attention layer arguments.
- Returns:
- NFNet configuration.
- """
- attn_kwargs = attn_kwargs or {}
- cfg = NfCfg(
- depths=depths,
- channels=channels,
- stem_type='7x7_pool',
- stem_chs=64,
- bottle_ratio=0.25,
- group_size=group_size,
- act_layer=act_layer,
- attn_layer=attn_layer,
- attn_kwargs=attn_kwargs,
- )
- return cfg
- def _nfreg_cfg(depths: Tuple[int, ...], channels: Tuple[int, ...] = (48, 104, 208, 440)) -> NfCfg:
- """Create NFNet RegNet configuration.
- Args:
- depths: Number of blocks in each stage.
- channels: Channel dimensions for each stage.
- Returns:
- NFNet configuration.
- """
- num_features = 1280 * channels[-1] // 440
- attn_kwargs = dict(rd_ratio=0.5)
- cfg = NfCfg(
- depths=depths,
- channels=channels,
- stem_type='3x3',
- group_size=8,
- width_factor=0.75,
- bottle_ratio=2.25,
- num_features=num_features,
- reg=True,
- attn_layer='se',
- attn_kwargs=attn_kwargs,
- )
- return cfg
- def _nfnet_cfg(
- depths: Tuple[int, ...],
- channels: Tuple[int, ...] = (256, 512, 1536, 1536),
- group_size: int = 128,
- bottle_ratio: float = 0.5,
- feat_mult: float = 2.,
- act_layer: str = 'gelu',
- attn_layer: str = 'se',
- attn_kwargs: Optional[Dict[str, Any]] = None,
- ) -> NfCfg:
- """Create NFNet configuration.
- Args:
- depths: Number of blocks in each stage.
- channels: Channel dimensions for each stage.
- group_size: Group convolution size.
- bottle_ratio: Bottleneck ratio.
- feat_mult: Feature multiplier for final layer.
- act_layer: Activation layer type.
- attn_layer: Attention layer type.
- attn_kwargs: Attention layer arguments.
- Returns:
- NFNet configuration.
- """
- num_features = int(channels[-1] * feat_mult)
- attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(rd_ratio=0.5)
- cfg = NfCfg(
- depths=depths,
- channels=channels,
- stem_type='deep_quad',
- stem_chs=128,
- group_size=group_size,
- bottle_ratio=bottle_ratio,
- extra_conv=True,
- num_features=num_features,
- act_layer=act_layer,
- attn_layer=attn_layer,
- attn_kwargs=attn_kwargs,
- )
- return cfg
- def _dm_nfnet_cfg(
- depths: Tuple[int, ...],
- channels: Tuple[int, ...] = (256, 512, 1536, 1536),
- act_layer: str = 'gelu',
- skipinit: bool = True,
- ) -> NfCfg:
- """Create DeepMind NFNet configuration.
- Args:
- depths: Number of blocks in each stage.
- channels: Channel dimensions for each stage.
- act_layer: Activation layer type.
- skipinit: Use skipinit initialization.
- Returns:
- NFNet configuration.
- """
- cfg = NfCfg(
- depths=depths,
- channels=channels,
- stem_type='deep_quad',
- stem_chs=128,
- group_size=128,
- bottle_ratio=0.5,
- extra_conv=True,
- gamma_in_act=True,
- same_padding=True,
- skipinit=skipinit,
- num_features=int(channels[-1] * 2.0),
- act_layer=act_layer,
- attn_layer='se',
- attn_kwargs=dict(rd_ratio=0.5),
- )
- return cfg
- model_cfgs = dict(
- # NFNet-F models w/ GELU compatible with DeepMind weights
- dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)),
- dm_nfnet_f1=_dm_nfnet_cfg(depths=(2, 4, 12, 6)),
- dm_nfnet_f2=_dm_nfnet_cfg(depths=(3, 6, 18, 9)),
- dm_nfnet_f3=_dm_nfnet_cfg(depths=(4, 8, 24, 12)),
- dm_nfnet_f4=_dm_nfnet_cfg(depths=(5, 10, 30, 15)),
- dm_nfnet_f5=_dm_nfnet_cfg(depths=(6, 12, 36, 18)),
- dm_nfnet_f6=_dm_nfnet_cfg(depths=(7, 14, 42, 21)),
- # NFNet-F models w/ GELU
- nfnet_f0=_nfnet_cfg(depths=(1, 2, 6, 3)),
- nfnet_f1=_nfnet_cfg(depths=(2, 4, 12, 6)),
- nfnet_f2=_nfnet_cfg(depths=(3, 6, 18, 9)),
- nfnet_f3=_nfnet_cfg(depths=(4, 8, 24, 12)),
- nfnet_f4=_nfnet_cfg(depths=(5, 10, 30, 15)),
- nfnet_f5=_nfnet_cfg(depths=(6, 12, 36, 18)),
- nfnet_f6=_nfnet_cfg(depths=(7, 14, 42, 21)),
- nfnet_f7=_nfnet_cfg(depths=(8, 16, 48, 24)),
- # Experimental 'light' versions of NFNet-F that are little leaner, w/ SiLU act
- nfnet_l0=_nfnet_cfg(
- depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
- attn_kwargs=dict(rd_ratio=0.25, rd_divisor=8), act_layer='silu'),
- eca_nfnet_l0=_nfnet_cfg(
- depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
- attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
- eca_nfnet_l1=_nfnet_cfg(
- depths=(2, 4, 12, 6), feat_mult=2, group_size=64, bottle_ratio=0.25,
- attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
- eca_nfnet_l2=_nfnet_cfg(
- depths=(3, 6, 18, 9), feat_mult=2, group_size=64, bottle_ratio=0.25,
- attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
- eca_nfnet_l3=_nfnet_cfg(
- depths=(4, 8, 24, 12), feat_mult=2, group_size=64, bottle_ratio=0.25,
- attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
- # EffNet influenced RegNet defs.
- # NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8.
- nf_regnet_b0=_nfreg_cfg(depths=(1, 3, 6, 6)),
- nf_regnet_b1=_nfreg_cfg(depths=(2, 4, 7, 7)),
- nf_regnet_b2=_nfreg_cfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488)),
- nf_regnet_b3=_nfreg_cfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528)),
- nf_regnet_b4=_nfreg_cfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616)),
- nf_regnet_b5=_nfreg_cfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704)),
- # ResNet (preact, D style deep stem/avg down) defs
- nf_resnet26=_nfres_cfg(depths=(2, 2, 2, 2)),
- nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)),
- nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)),
- nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)),
- nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)),
- nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)),
- nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()),
- nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()),
- nf_ecaresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='eca', attn_kwargs=dict()),
- test_nfnet=_nfnet_cfg(
- depths=(1, 1, 1, 1), channels=(32, 64, 96, 128), feat_mult=1.5, group_size=8, bottle_ratio=0.25,
- attn_kwargs=dict(rd_ratio=0.25, rd_divisor=8), act_layer='silu'),
- )
- def _create_normfreenet(variant: str, pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """Create a NormFreeNet model.
- Args:
- variant: Model variant name.
- pretrained: Load pretrained weights.
- **kwargs: Additional model arguments.
- Returns:
- NormFreeNet model instance.
- """
- model_cfg = model_cfgs[variant]
- feature_cfg = dict(flatten_sequential=True)
- return build_model_with_cfg(
- NormFreeNet,
- variant,
- pretrained,
- model_cfg=model_cfg,
- feature_cfg=feature_cfg,
- **kwargs,
- )
- def _dcfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
- """Create default configuration dictionary.
- Args:
- url: Model weight URL.
- **kwargs: Additional configuration options.
- Returns:
- Configuration dictionary.
- """
- return {
- 'url': url,
- 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
- 'crop_pct': 0.9, 'interpolation': 'bicubic',
- 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
- 'first_conv': 'stem.conv1', 'classifier': 'head.fc', 'license': 'apache-2.0',
- **kwargs
- }
- default_cfgs = generate_default_cfgs({
- 'dm_nfnet_f0.dm_in1k': _dcfg(
- hf_hub_id='timm/',
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f0-604f9c3a.pth',
- pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), crop_pct=.9, crop_mode='squash'),
- 'dm_nfnet_f1.dm_in1k': _dcfg(
- hf_hub_id='timm/',
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f1-fc540f82.pth',
- pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), crop_pct=0.91, crop_mode='squash'),
- 'dm_nfnet_f2.dm_in1k': _dcfg(
- hf_hub_id='timm/',
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f2-89875923.pth',
- pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), crop_pct=0.92, crop_mode='squash'),
- 'dm_nfnet_f3.dm_in1k': _dcfg(
- hf_hub_id='timm/',
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f3-d74ab3aa.pth',
- pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), crop_pct=0.94, crop_mode='squash'),
- 'dm_nfnet_f4.dm_in1k': _dcfg(
- hf_hub_id='timm/',
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f4-0ac5b10b.pth',
- pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), crop_pct=0.951, crop_mode='squash'),
- 'dm_nfnet_f5.dm_in1k': _dcfg(
- hf_hub_id='timm/',
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f5-ecb20ab1.pth',
- pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), crop_pct=0.954, crop_mode='squash'),
- 'dm_nfnet_f6.dm_in1k': _dcfg(
- hf_hub_id='timm/',
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f6-e0f12116.pth',
- pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), crop_pct=0.956, crop_mode='squash'),
- 'nfnet_f0': _dcfg(
- url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
- 'nfnet_f1': _dcfg(
- url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320)),
- 'nfnet_f2': _dcfg(
- url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352)),
- 'nfnet_f3': _dcfg(
- url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416)),
- 'nfnet_f4': _dcfg(
- url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512)),
- 'nfnet_f5': _dcfg(
- url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544)),
- 'nfnet_f6': _dcfg(
- url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576)),
- 'nfnet_f7': _dcfg(
- url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)),
- 'nfnet_l0.ra2_in1k': _dcfg(
- hf_hub_id='timm/',
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nfnet_l0_ra2-45c6688d.pth',
- pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'eca_nfnet_l0.ra2_in1k': _dcfg(
- hf_hub_id='timm/',
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l0_ra2-e3e9ac50.pth',
- pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), test_crop_pct=1.0),
- 'eca_nfnet_l1.ra2_in1k': _dcfg(
- hf_hub_id='timm/',
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l1_ra2-7dce93cd.pth',
- pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 320, 320), test_crop_pct=1.0),
- 'eca_nfnet_l2.ra3_in1k': _dcfg(
- hf_hub_id='timm/',
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l2_ra3-da781a61.pth',
- pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), test_crop_pct=1.0),
- 'eca_nfnet_l3': _dcfg(
- url='',
- pool_size=(11, 11), input_size=(3, 352, 352), test_input_size=(3, 448, 448), test_crop_pct=1.0),
- 'nf_regnet_b0': _dcfg(
- url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'),
- 'nf_regnet_b1.ra2_in1k': _dcfg(
- hf_hub_id='timm/',
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_regnet_b1_256_ra2-ad85cfef.pth',
- pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), first_conv='stem.conv'), # NOT to paper spec
- 'nf_regnet_b2': _dcfg(
- url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272), first_conv='stem.conv'),
- 'nf_regnet_b3': _dcfg(
- url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320), first_conv='stem.conv'),
- 'nf_regnet_b4': _dcfg(
- url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), first_conv='stem.conv'),
- 'nf_regnet_b5': _dcfg(
- url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456), first_conv='stem.conv'),
- 'nf_resnet26': _dcfg(url='', first_conv='stem.conv'),
- 'nf_resnet50.ra2_in1k': _dcfg(
- hf_hub_id='timm/',
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_resnet50_ra2-9f236009.pth',
- pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94, first_conv='stem.conv'),
- 'nf_resnet101': _dcfg(url='', first_conv='stem.conv'),
- 'nf_seresnet26': _dcfg(url='', first_conv='stem.conv'),
- 'nf_seresnet50': _dcfg(url='', first_conv='stem.conv'),
- 'nf_seresnet101': _dcfg(url='', first_conv='stem.conv'),
- 'nf_ecaresnet26': _dcfg(url='', first_conv='stem.conv'),
- 'nf_ecaresnet50': _dcfg(url='', first_conv='stem.conv'),
- 'nf_ecaresnet101': _dcfg(url='', first_conv='stem.conv'),
- 'test_nfnet.r160_in1k': _dcfg(
- hf_hub_id='timm/',
- mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
- crop_pct=0.95, input_size=(3, 160, 160), pool_size=(5, 5)),
- })
- @register_model
- def dm_nfnet_f0(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """NFNet-F0 (DeepMind weight compatible)."""
- return _create_normfreenet('dm_nfnet_f0', pretrained=pretrained, **kwargs)
- @register_model
- def dm_nfnet_f1(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """NFNet-F1 (DeepMind weight compatible)."""
- return _create_normfreenet('dm_nfnet_f1', pretrained=pretrained, **kwargs)
- @register_model
- def dm_nfnet_f2(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """NFNet-F2 (DeepMind weight compatible)."""
- return _create_normfreenet('dm_nfnet_f2', pretrained=pretrained, **kwargs)
- @register_model
- def dm_nfnet_f3(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """NFNet-F3 (DeepMind weight compatible)."""
- return _create_normfreenet('dm_nfnet_f3', pretrained=pretrained, **kwargs)
- @register_model
- def dm_nfnet_f4(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """NFNet-F4 (DeepMind weight compatible)."""
- return _create_normfreenet('dm_nfnet_f4', pretrained=pretrained, **kwargs)
- @register_model
- def dm_nfnet_f5(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """NFNet-F5 (DeepMind weight compatible)."""
- return _create_normfreenet('dm_nfnet_f5', pretrained=pretrained, **kwargs)
- @register_model
- def dm_nfnet_f6(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """NFNet-F6 (DeepMind weight compatible)."""
- return _create_normfreenet('dm_nfnet_f6', pretrained=pretrained, **kwargs)
- @register_model
- def nfnet_f0(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """NFNet-F0."""
- return _create_normfreenet('nfnet_f0', pretrained=pretrained, **kwargs)
- @register_model
- def nfnet_f1(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """NFNet-F1."""
- return _create_normfreenet('nfnet_f1', pretrained=pretrained, **kwargs)
- @register_model
- def nfnet_f2(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """NFNet-F2."""
- return _create_normfreenet('nfnet_f2', pretrained=pretrained, **kwargs)
- @register_model
- def nfnet_f3(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """NFNet-F3."""
- return _create_normfreenet('nfnet_f3', pretrained=pretrained, **kwargs)
- @register_model
- def nfnet_f4(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """NFNet-F4."""
- return _create_normfreenet('nfnet_f4', pretrained=pretrained, **kwargs)
- @register_model
- def nfnet_f5(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """NFNet-F5."""
- return _create_normfreenet('nfnet_f5', pretrained=pretrained, **kwargs)
- @register_model
- def nfnet_f6(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """NFNet-F6."""
- return _create_normfreenet('nfnet_f6', pretrained=pretrained, **kwargs)
- @register_model
- def nfnet_f7(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """NFNet-F7."""
- return _create_normfreenet('nfnet_f7', pretrained=pretrained, **kwargs)
- @register_model
- def nfnet_l0(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """NFNet-L0b w/ SiLU.
- My experimental 'light' model w/ F0 repeats, 1.5x final_conv mult, 64 group_size, .25 bottleneck & SE ratio
- """
- return _create_normfreenet('nfnet_l0', pretrained=pretrained, **kwargs)
- @register_model
- def eca_nfnet_l0(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """ECA-NFNet-L0 w/ SiLU.
- My experimental 'light' model w/ F0 repeats, 1.5x final_conv mult, 64 group_size, .25 bottleneck & ECA attn
- """
- return _create_normfreenet('eca_nfnet_l0', pretrained=pretrained, **kwargs)
- @register_model
- def eca_nfnet_l1(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """ECA-NFNet-L1 w/ SiLU.
- My experimental 'light' model w/ F1 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn
- """
- return _create_normfreenet('eca_nfnet_l1', pretrained=pretrained, **kwargs)
- @register_model
- def eca_nfnet_l2(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """ECA-NFNet-L2 w/ SiLU.
- My experimental 'light' model w/ F2 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn
- """
- return _create_normfreenet('eca_nfnet_l2', pretrained=pretrained, **kwargs)
- @register_model
- def eca_nfnet_l3(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """ECA-NFNet-L3 w/ SiLU.
- My experimental 'light' model w/ F3 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn
- """
- return _create_normfreenet('eca_nfnet_l3', pretrained=pretrained, **kwargs)
- @register_model
- def nf_regnet_b0(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """Normalization-Free RegNet-B0.
- """
- return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs)
- @register_model
- def nf_regnet_b1(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """Normalization-Free RegNet-B1.
- """
- return _create_normfreenet('nf_regnet_b1', pretrained=pretrained, **kwargs)
- @register_model
- def nf_regnet_b2(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """Normalization-Free RegNet-B2.
- """
- return _create_normfreenet('nf_regnet_b2', pretrained=pretrained, **kwargs)
- @register_model
- def nf_regnet_b3(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """Normalization-Free RegNet-B3.
- """
- return _create_normfreenet('nf_regnet_b3', pretrained=pretrained, **kwargs)
- @register_model
- def nf_regnet_b4(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """Normalization-Free RegNet-B4.
- """
- return _create_normfreenet('nf_regnet_b4', pretrained=pretrained, **kwargs)
- @register_model
- def nf_regnet_b5(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """Normalization-Free RegNet-B5.
- """
- return _create_normfreenet('nf_regnet_b5', pretrained=pretrained, **kwargs)
- @register_model
- def nf_resnet26(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """Normalization-Free ResNet-26.
- """
- return _create_normfreenet('nf_resnet26', pretrained=pretrained, **kwargs)
- @register_model
- def nf_resnet50(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """Normalization-Free ResNet-50.
- """
- return _create_normfreenet('nf_resnet50', pretrained=pretrained, **kwargs)
- @register_model
- def nf_resnet101(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """Normalization-Free ResNet-101.
- """
- return _create_normfreenet('nf_resnet101', pretrained=pretrained, **kwargs)
- @register_model
- def nf_seresnet26(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """Normalization-Free SE-ResNet26."""
- return _create_normfreenet('nf_seresnet26', pretrained=pretrained, **kwargs)
- @register_model
- def nf_seresnet50(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """Normalization-Free SE-ResNet50."""
- return _create_normfreenet('nf_seresnet50', pretrained=pretrained, **kwargs)
- @register_model
- def nf_seresnet101(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """Normalization-Free SE-ResNet101."""
- return _create_normfreenet('nf_seresnet101', pretrained=pretrained, **kwargs)
- @register_model
- def nf_ecaresnet26(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """Normalization-Free ECA-ResNet26."""
- return _create_normfreenet('nf_ecaresnet26', pretrained=pretrained, **kwargs)
- @register_model
- def nf_ecaresnet50(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """Normalization-Free ECA-ResNet50."""
- return _create_normfreenet('nf_ecaresnet50', pretrained=pretrained, **kwargs)
- @register_model
- def nf_ecaresnet101(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """Normalization-Free ECA-ResNet101."""
- return _create_normfreenet('nf_ecaresnet101', pretrained=pretrained, **kwargs)
- @register_model
- def test_nfnet(pretrained: bool = False, **kwargs: Any) -> NormFreeNet:
- """Test NFNet model for experimentation."""
- return _create_normfreenet('test_nfnet', pretrained=pretrained, **kwargs)
|