| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740 |
- """ FocalNet
- As described in `Focal Modulation Networks` - https://arxiv.org/abs/2203.11926
- Significant modifications and refactoring from the original impl at https://github.com/microsoft/FocalNet
- This impl is/has:
- * fully convolutional, NCHW tensor layout throughout, seemed to have minimal performance impact but more flexible
- * re-ordered downsample / layer so that striding always at beginning of layer (stage)
- * no input size constraints or input resolution/H/W tracking through the model
- * torchscript fixed and a number of quirks cleaned up
- * feature extraction support via `features_only=True`
- """
- # --------------------------------------------------------
- # FocalNets -- Focal Modulation Networks
- # Copyright (c) 2022 Microsoft
- # Licensed under The MIT License [see LICENSE for details]
- # Written by Jianwei Yang (jianwyan@microsoft.com)
- # --------------------------------------------------------
- from functools import partial
- from typing import Callable, List, Optional, Tuple, Type, Union
- import torch
- import torch.nn as nn
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.layers import (
- Mlp,
- DropPath,
- LayerNorm2d,
- LayerScale2d,
- trunc_normal_,
- ClassifierHead,
- NormMlpClassifierHead,
- calculate_drop_path_rates,
- )
- from ._builder import build_model_with_cfg
- from ._features import feature_take_indices
- from ._manipulate import named_apply, checkpoint
- from ._registry import generate_default_cfgs, register_model
# Public API of this module: `from <module> import *` exports only the FocalNet model class.
__all__ = ['FocalNet']
class FocalModulation(nn.Module):
    """ Focal modulation layer operating on NCHW tensors.

    A fused 1x1 conv (``f``) splits the input into a query, an initial context and
    ``focal_level + 1`` gate channels. The context is refined by ``focal_level`` depthwise
    conv + GELU layers of increasing kernel size. Each level's output, plus a globally pooled
    context, is gated and summed. The sum is projected by ``h`` into a modulator that
    multiplies the query, followed by an optional norm and the output projection.
    """

    def __init__(
            self,
            dim: int,
            focal_window: int,
            focal_level: int,
            focal_factor: int = 2,
            bias: bool = True,
            use_post_norm: bool = False,
            normalize_modulator: bool = False,
            proj_drop: float = 0.,
            norm_layer: Type[nn.Module] = LayerNorm2d,
            device=None,
            dtype=None,
    ):
        """
        Args:
            dim: Number of input (and output) channels.
            focal_window: Kernel size of the first focal level.
            focal_level: Number of depthwise-conv focal levels.
            focal_factor: Kernel size increment between consecutive focal levels.
            bias: Use bias in the ``f`` and ``h`` 1x1 projections.
            use_post_norm: Apply ``norm_layer`` to the modulated output before ``proj``.
            normalize_modulator: Divide the aggregated context by ``focal_level + 1``.
            proj_drop: Dropout rate applied after the output projection.
            norm_layer: Normalization layer used when ``use_post_norm`` is set.
        """
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        self.dim = dim
        self.focal_window = focal_window
        self.focal_level = focal_level
        self.focal_factor = focal_factor
        self.use_post_norm = use_post_norm
        self.normalize_modulator = normalize_modulator

        # Channel split of the fused input projection: query, context, one gate per level + a global gate.
        num_gates = focal_level + 1
        self.input_split = [dim, dim, num_gates]

        self.f = nn.Conv2d(dim, 2 * dim + num_gates, kernel_size=1, bias=bias, **factory_kwargs)
        self.h = nn.Conv2d(dim, dim, kernel_size=1, bias=bias, **factory_kwargs)
        self.act = nn.GELU()
        self.proj = nn.Conv2d(dim, dim, kernel_size=1, **factory_kwargs)
        self.proj_drop = nn.Dropout(proj_drop)

        # Kernel size grows linearly with the level. Padding of kernel_size // 2 preserves H, W for odd kernels.
        self.kernel_sizes = [focal_factor * level + focal_window for level in range(focal_level)]
        self.focal_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size=ks, groups=dim, padding=ks // 2, bias=False, **factory_kwargs),
                nn.GELU(),
            )
            for ks in self.kernel_sizes
        ])
        self.norm = norm_layer(dim, **factory_kwargs) if use_post_norm else nn.Identity()

    def forward(self, x):
        q, ctx, gates = torch.split(self.f(x), self.input_split, 1)

        # Hierarchical context: each level refines the previous level's output, weighted by its own gate.
        ctx_all = 0
        for level, focal_layer in enumerate(self.focal_layers):
            ctx = focal_layer(ctx)
            ctx_all = ctx_all + ctx * gates[:, level:level + 1]

        # Global context: spatial mean of the last level's context, gated by the final gate channel(s).
        ctx_global = self.act(ctx.mean((2, 3), keepdim=True))
        ctx_all = ctx_all + ctx_global * gates[:, self.focal_level:]

        if self.normalize_modulator:
            ctx_all = ctx_all / (self.focal_level + 1)

        modulator = self.h(ctx_all)
        out = self.norm(q * modulator)
        return self.proj_drop(self.proj(out))
class FocalNetBlock(nn.Module):
    """ Focal Modulation Network Block.

    Residual block with two branches: focal modulation (token mixing) and a conv MLP
    (channel mixing). Each branch uses either a pre-norm or a post-norm (``use_post_norm``),
    plus optional layer scale and drop path.
    """

    def __init__(
            self,
            dim: int,
            mlp_ratio: float = 4.,
            focal_level: int = 1,
            focal_window: int = 3,
            use_post_norm: bool = False,
            use_post_norm_in_modulation: bool = False,
            normalize_modulator: bool = False,
            layerscale_value: Optional[float] = 1e-4,
            proj_drop: float = 0.,
            drop_path: float = 0.,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = LayerNorm2d,
            device=None,
            dtype=None,
    ):
        """
        Args:
            dim: Number of input channels.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            focal_level: Number of focal levels.
            focal_window: Focal window size at first focal level.
            use_post_norm: Normalize after each branch instead of before it.
            use_post_norm_in_modulation: Whether to use a norm inside the focal modulation.
            normalize_modulator: Whether to normalize the aggregated focal context.
            layerscale_value: Initial layerscale value; ``None`` disables layer scale.
            proj_drop: Dropout rate.
            drop_path: Stochastic depth rate.
            act_layer: Activation layer used in the MLP.
            norm_layer: Normalization layer.
        """
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.focal_window = focal_window
        self.focal_level = focal_level
        self.use_post_norm = use_post_norm

        def _pre_norm():
            return nn.Identity() if use_post_norm else norm_layer(dim, **factory_kwargs)

        def _post_norm():
            return norm_layer(dim, **factory_kwargs) if use_post_norm else nn.Identity()

        def _layer_scale():
            if layerscale_value is None:
                return nn.Identity()
            return LayerScale2d(dim, layerscale_value, **factory_kwargs)

        def _drop_path():
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # Modulation branch. Submodules are created in a fixed order (keeps state_dict / init order stable).
        self.norm1 = _pre_norm()
        self.modulation = FocalModulation(
            dim,
            focal_window=focal_window,
            focal_level=self.focal_level,
            use_post_norm=use_post_norm_in_modulation,
            normalize_modulator=normalize_modulator,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
            **factory_kwargs,
        )
        self.norm1_post = _post_norm()
        self.ls1 = _layer_scale()
        self.drop_path1 = _drop_path()

        # MLP branch.
        self.norm2 = _pre_norm()
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
            use_conv=True,
            **factory_kwargs,
        )
        self.norm2_post = _post_norm()
        self.ls2 = _layer_scale()
        self.drop_path2 = _drop_path()

    def forward(self, x):
        # Token mixing via focal modulation, added back onto the residual stream.
        mixed = self.norm1_post(self.modulation(self.norm1(x)))
        x = x + self.drop_path1(self.ls1(mixed))

        # Channel mixing via the conv MLP.
        mixed = self.norm2_post(self.mlp(self.norm2(x)))
        x = x + self.drop_path2(self.ls2(mixed))
        return x
class FocalNetStage(nn.Module):
    """ A basic Focal Transformer layer for one stage.

    Optional stride-2 ``Downsample`` at the start of the stage, followed by ``depth``
    ``FocalNetBlock`` modules at ``out_dim`` channels.
    """

    def __init__(
            self,
            dim: int,
            out_dim: int,
            depth: int,
            mlp_ratio: float = 4.,
            downsample: bool = True,
            focal_level: int = 1,
            focal_window: int = 1,
            use_overlap_down: bool = False,
            use_post_norm: bool = False,
            use_post_norm_in_modulation: bool = False,
            normalize_modulator: bool = False,
            layerscale_value: Optional[float] = 1e-4,
            proj_drop: float = 0.,
            drop_path: Union[float, List[float]] = 0.,
            norm_layer: Type[nn.Module] = LayerNorm2d,
            device=None,
            dtype=None,
    ):
        """
        Args:
            dim: Number of input channels.
            out_dim: Number of output channels.
            depth: Number of blocks.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            downsample: Add a stride-2 downsample layer at the start of the stage.
            focal_level: Number of focal levels.
            focal_window: Focal window size at first focal level.
            use_overlap_down: Use overlapped convolution in the downsample layer.
            use_post_norm: Whether to use layer norm after modulation.
            use_post_norm_in_modulation: Whether to use layer norm in modulation.
            normalize_modulator: Whether to normalize the aggregated focal context.
            layerscale_value: Initial layerscale value.
            proj_drop: Dropout rate for projections.
            drop_path: Stochastic depth rate, either shared by all blocks or one value per block.
            norm_layer: Normalization layer.
        """
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.grad_checkpointing = False

        self.downsample = Downsample(
            in_chs=dim,
            out_chs=out_dim,
            stride=2,
            overlap=use_overlap_down,
            norm_layer=norm_layer,
            **factory_kwargs,
        ) if downsample else nn.Identity()

        # Per-block stochastic depth: use the list as given, or broadcast a scalar rate.
        block_drop_paths = drop_path if isinstance(drop_path, list) else [drop_path] * depth
        self.blocks = nn.ModuleList([
            FocalNetBlock(
                dim=out_dim,
                mlp_ratio=mlp_ratio,
                focal_level=focal_level,
                focal_window=focal_window,
                use_post_norm=use_post_norm,
                use_post_norm_in_modulation=use_post_norm_in_modulation,
                normalize_modulator=normalize_modulator,
                layerscale_value=layerscale_value,
                proj_drop=proj_drop,
                drop_path=block_drop_paths[idx],
                norm_layer=norm_layer,
                **factory_kwargs,
            )
            for idx in range(depth)
        ])

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    def forward(self, x):
        x = self.downsample(x)
        for block in self.blocks:
            # Keep the is_scripting() check inside the `if` so TorchScript never compiles the checkpoint call.
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(block, x)
            else:
                x = block(x)
        return x
class Downsample(nn.Module):
    """ Strided conv projection followed by an optional norm.

    Used as the model stem (default stride 4) and as the stride-2 downsample at the start of
    later stages. With ``overlap=True`` the kernel is larger than the stride (7/pad 2 for
    stride 4, 3/pad 1 for stride 2). Otherwise kernel == stride and padding is 0.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int = 4,
            overlap: bool = False,
            norm_layer: Optional[Type[nn.Module]] = None,
            device=None,
            dtype=None,
    ):
        """
        Args:
            in_chs: Number of input image channels.
            out_chs: Number of linear projection output channels.
            stride: Downsample stride.
            overlap: Use overlapping convolutions if True (only stride 2 or 4 supported).
            norm_layer: Normalization layer applied after the projection; ``None`` for no norm.

        Raises:
            ValueError: If ``overlap`` is True and ``stride`` is not 2 or 4.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.stride = stride
        padding = 0
        kernel_size = stride
        if overlap:
            # Explicit check rather than `assert`: assertions are stripped under `python -O`,
            # which would silently build a non-overlapping conv for unsupported strides.
            if stride == 4:
                kernel_size, padding = 7, 2
            elif stride == 2:
                kernel_size, padding = 3, 1
            else:
                raise ValueError(f'Overlapping downsample only supports stride 2 or 4, got {stride}.')
        self.proj = nn.Conv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, padding=padding, **dd)
        self.norm = norm_layer(out_chs, **dd) if norm_layer is not None else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        return x
class FocalNet(nn.Module):
    """ Focal Modulation Networks (FocalNets)

    Hierarchical NCHW backbone with these parts:
      * a ``Downsample`` stem (default stride 4);
      * ``len(depths)`` ``FocalNetStage`` stages. Every stage after the first starts with a
        stride-2 downsample, and stage ``i`` has ``embed_dim * 2 ** i`` channels;
      * a final norm;
      * a classifier head.
    """

    def __init__(
            self,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            embed_dim: int = 96,
            depths: Tuple[int, ...] = (2, 2, 6, 2),
            mlp_ratio: float = 4.,
            focal_levels: Tuple[int, ...] = (2, 2, 2, 2),
            focal_windows: Tuple[int, ...] = (3, 3, 3, 3),
            use_overlap_down: bool = False,
            use_post_norm: bool = False,
            use_post_norm_in_modulation: bool = False,
            normalize_modulator: bool = False,
            head_hidden_size: Optional[int] = None,
            head_init_scale: float = 1.0,
            layerscale_value: Optional[float] = None,
            drop_rate: float = 0.,
            proj_drop_rate: float = 0.,
            drop_path_rate: float = 0.1,
            norm_layer: Type[nn.Module] = partial(LayerNorm2d, eps=1e-5),
            device=None,
            dtype=None,
    ):
        """
        Args:
            in_chans: Number of input image channels.
            num_classes: Number of classes for classification head.
            global_pool: Global pooling type passed to the classifier head.
            embed_dim: Channel width of the first stage. Stage ``i`` uses ``embed_dim * 2 ** i``.
            depths: Number of blocks in each stage.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            focal_levels: How many focal levels at all stages. Note that this excludes the finest-grain level.
            focal_windows: The focal window (first focal level kernel size) at all stages.
            use_overlap_down: Use overlapping convolutions in the stem and the per-stage downsample layers.
            use_post_norm: Whether to use layernorm after modulation (it helps stabilize training of large models).
            use_post_norm_in_modulation: Whether to apply a norm inside focal modulation, before its output projection.
            normalize_modulator: Whether to divide the aggregated focal context by ``focal_level + 1``.
            head_hidden_size: If set, use a ``NormMlpClassifierHead`` with this hidden size and no separate final norm.
            head_init_scale: Scale applied to the initial ``head.fc`` Linear weights and bias.
            layerscale_value: Value for layer scale; ``None`` disables layer scale.
            drop_rate: Dropout rate in the classifier head.
            proj_drop_rate: Dropout rate for projections inside blocks.
            drop_path_rate: Stochastic depth rate.
            norm_layer: Normalization layer.
            device: Device on which parameters are created.
            dtype: Dtype of created parameters.
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        self.num_layers = len(depths)
        embed_dim = [embed_dim * (2 ** i) for i in range(self.num_layers)]

        self.num_classes = num_classes
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.num_features = self.head_hidden_size = embed_dim[-1]
        self.feature_info = []
        # Initialise the flag here (as FocalNetStage does) so it exists before set_grad_checkpointing() is called.
        self.grad_checkpointing = False

        self.stem = Downsample(
            in_chs=in_chans,
            out_chs=embed_dim[0],
            overlap=use_overlap_down,
            norm_layer=norm_layer,
            **dd,
        )
        in_dim = embed_dim[0]

        dpr = calculate_drop_path_rates(drop_path_rate, sum(depths))  # stochastic depth decay rule
        layers = []
        for i_layer in range(self.num_layers):
            out_dim = embed_dim[i_layer]
            layer = FocalNetStage(
                dim=in_dim,
                out_dim=out_dim,
                depth=depths[i_layer],
                mlp_ratio=mlp_ratio,
                downsample=i_layer > 0,  # the stem already downsamples in front of the first stage
                focal_level=focal_levels[i_layer],
                focal_window=focal_windows[i_layer],
                use_overlap_down=use_overlap_down,
                use_post_norm=use_post_norm,
                use_post_norm_in_modulation=use_post_norm_in_modulation,
                normalize_modulator=normalize_modulator,
                layerscale_value=layerscale_value,
                proj_drop=proj_drop_rate,
                # slice of the global per-block rate list that belongs to this stage
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                **dd,
            )
            in_dim = out_dim
            layers += [layer]
            self.feature_info += [dict(num_chs=out_dim, reduction=4 * 2 ** i_layer, module=f'layers.{i_layer}')]

        self.layers = nn.Sequential(*layers)

        if head_hidden_size:
            # The NormMlp head applies its own norm, so no separate final norm is used.
            self.norm = nn.Identity()
            self.head_hidden_size = head_hidden_size
            self.head = NormMlpClassifierHead(
                self.num_features,
                num_classes,
                hidden_size=head_hidden_size,
                pool_type=global_pool,
                drop_rate=drop_rate,
                norm_layer=norm_layer,
                **dd,
            )
        else:
            self.norm = norm_layer(self.num_features, **dd)
            self.head = ClassifierHead(
                self.num_features,
                num_classes,
                pool_type=global_pool,
                drop_rate=drop_rate,
                **dd,
            )

        named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {''}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """ Regex groups of parameter names (stem / per-stage blocks / final norm) for layer-wise LR decay etc. """
        return dict(
            stem=r'^stem',
            blocks=[
                (r'^layers\.(\d+)', None),
                (r'^norm', (99999,))
            ] if coarse else [
                (r'^layers\.(\d+).downsample', (0,)),
                (r'^layers\.(\d+)\.\w+\.(\d+)', None),
                (r'^norm', (99999,)),
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """ Enable / disable activation checkpointing in every stage. """
        self.grad_checkpointing = enable
        for l in self.layers:
            l.set_grad_checkpointing(enable=enable)

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        self.num_classes = num_classes
        self.head.reset(num_classes, pool_type=global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features

        Returns:
            If ``intermediates_only``, the list of selected stage outputs. Otherwise a tuple of
            (final features, list of selected stage outputs). The final features are passed
            through ``self.norm`` only when the last stage was run.
        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.layers), indices)

        # forward pass
        x = self.stem(x)
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.layers
        else:
            stages = self.layers[:max_index + 1]

        last_idx = len(self.layers) - 1
        for feat_idx, stage in enumerate(stages):
            x = stage(x)
            if feat_idx in take_indices:
                if norm and feat_idx == last_idx:
                    x_inter = self.norm(x)  # applying final norm to last intermediate
                else:
                    x_inter = x
                intermediates.append(x_inter)

        if intermediates_only:
            return intermediates

        if feat_idx == last_idx:
            x = self.norm(x)

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.

        Truncates ``self.layers`` after the last needed stage. Optionally replaces the final
        norm with Identity and resets the head to 0 classes. Returns the resolved take indices.
        """
        take_indices, max_index = feature_take_indices(len(self.layers), indices)
        self.layers = self.layers[:max_index + 1]  # truncate blocks w/ stem as idx 0
        if prune_norm:
            self.norm = nn.Identity()
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        x = self.stem(x)
        x = self.layers(x)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
def _init_weights(module, name=None, head_init_scale=1.0):
    """ Weight init callback for ``named_apply``.

    Conv2d and Linear layers get truncated-normal weights (std .02) and zero bias. A Linear
    whose qualified ``name`` contains 'head.fc' (the classifier) is additionally scaled by
    ``head_init_scale``. Bias-free layers are handled; previously a bias-free classifier
    raised AttributeError during scaling.

    Args:
        module: Module visited by ``named_apply``.
        name: Qualified module name, or None.
        head_init_scale: Multiplier for the classifier Linear's initial weight and bias.
    """
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        trunc_normal_(module.weight, std=.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        # Only Linear classifiers are rescaled (a conv 'head.fc' is left as-is, as before).
        if isinstance(module, nn.Linear) and name and 'head.fc' in name:
            module.weight.data.mul_(head_init_scale)
            if module.bias is not None:
                module.bias.data.mul_(head_init_scale)
def _cfg(url='', **kwargs):
    """ Build a default pretrained-config dict. Entries in ``kwargs`` are added or override defaults. """
    cfg = {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': (7, 7),
        'crop_pct': .9,
        'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.proj',
        'classifier': 'head.fc',
        'license': 'mit',
    }
    cfg.update(kwargs)
    return cfg
# Pretrained configs for each registered model / tag pair.
default_cfgs = generate_default_cfgs({
    # 224x224 configs with the default 1000-class head.
    'focalnet_tiny_srf.ms_in1k': _cfg(hf_hub_id='timm/'),
    'focalnet_small_srf.ms_in1k': _cfg(hf_hub_id='timm/'),
    'focalnet_base_srf.ms_in1k': _cfg(hf_hub_id='timm/'),
    'focalnet_tiny_lrf.ms_in1k': _cfg(hf_hub_id='timm/'),
    'focalnet_small_lrf.ms_in1k': _cfg(hf_hub_id='timm/'),
    'focalnet_base_lrf.ms_in1k': _cfg(hf_hub_id='timm/'),

    # 384x384 configs with a 21842-class head.
    'focalnet_large_fl3.ms_in22k': _cfg(
        hf_hub_id='timm/', input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
    'focalnet_large_fl4.ms_in22k': _cfg(
        hf_hub_id='timm/', input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
    'focalnet_xlarge_fl3.ms_in22k': _cfg(
        hf_hub_id='timm/', input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
    'focalnet_xlarge_fl4.ms_in22k': _cfg(
        hf_hub_id='timm/', input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),

    # Huge models at the default 224x224 input.
    'focalnet_huge_fl3.ms_in22k': _cfg(hf_hub_id='timm/', num_classes=21842),
    # NOTE(review): num_classes=0 presumably means this checkpoint has no classifier head — confirm.
    'focalnet_huge_fl4.ms_in22k': _cfg(hf_hub_id='timm/', num_classes=0),
})
def checkpoint_filter_fn(state_dict, model: FocalNet):
    """ Remap a checkpoint in the original (non-timm) key layout onto ``model``'s state_dict keys.

    If the dict is wrapped under a 'model' key, it is unwrapped first. Checkpoints that already
    contain 'stem.proj.weight' are returned unchanged. Otherwise each key is rewritten with
    these rules:
      * gamma_N -> lsN.gamma;
      * patch_embed -> stem;
      * layers.i.downsample -> layers.(i+1).downsample, because downsampling now starts the next stage;
      * normN -> normN_post when the key is unknown to the model;
      * ln. -> norm.;
      * head -> head.fc.
    Tensors whose element count matches the target but whose shape differs are reshaped.
    """
    state_dict = state_dict.get('model', state_dict)
    if 'stem.proj.weight' in state_dict:
        return state_dict

    import re

    target = model.state_dict()
    remapped = {}
    for key, tensor in state_dict.items():
        key = re.sub(r'gamma_([0-9])', r'ls\1.gamma', key)
        key = key.replace('patch_embed', 'stem')
        key = re.sub(r'layers.(\d+).downsample', lambda m: f'layers.{int(m.group(1)) + 1}.downsample', key)
        if 'norm' in key and key not in target:
            key = re.sub(r'norm([0-9])', r'norm\1_post', key)
        key = key.replace('ln.', 'norm.').replace('head', 'head.fc')

        ref = target.get(key)
        if ref is not None and ref.numel() == tensor.numel() and ref.shape != tensor.shape:
            tensor = tensor.reshape(ref.shape)
        remapped[key] = tensor
    return remapped
def _create_focalnet(variant, pretrained=False, **kwargs):
    """ Build a FocalNet through ``build_model_with_cfg``.

    Uses ``checkpoint_filter_fn`` for pretrained weights. ``out_indices`` (feature extraction)
    defaults to one index per entry of ``depths``, falling back to four stages.
    """
    stage_depths = kwargs.get('depths', (1, 1, 3, 1))
    out_indices = kwargs.pop('out_indices', tuple(stage_idx for stage_idx, _ in enumerate(stage_depths)))
    return build_model_with_cfg(
        FocalNet,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs,
    )
@register_model
def focalnet_tiny_srf(pretrained=False, **kwargs) -> FocalNet:
    """ FocalNet tiny (srf): depths (2, 2, 6, 2), embed_dim 96, default focal levels / windows. """
    return _create_focalnet(
        'focalnet_tiny_srf',
        pretrained=pretrained,
        **dict(depths=[2, 2, 6, 2], embed_dim=96, **kwargs),
    )
@register_model
def focalnet_small_srf(pretrained=False, **kwargs) -> FocalNet:
    """ FocalNet small (srf): depths (2, 2, 18, 2), embed_dim 96, default focal levels / windows. """
    return _create_focalnet(
        'focalnet_small_srf',
        pretrained=pretrained,
        **dict(depths=[2, 2, 18, 2], embed_dim=96, **kwargs),
    )
@register_model
def focalnet_base_srf(pretrained=False, **kwargs) -> FocalNet:
    """ FocalNet base (srf): depths (2, 2, 18, 2), embed_dim 128, default focal levels / windows. """
    return _create_focalnet(
        'focalnet_base_srf',
        pretrained=pretrained,
        **dict(depths=[2, 2, 18, 2], embed_dim=128, **kwargs),
    )
@register_model
def focalnet_tiny_lrf(pretrained=False, **kwargs) -> FocalNet:
    """ FocalNet tiny (lrf): depths (2, 2, 6, 2), embed_dim 96, 3 focal levels per stage. """
    return _create_focalnet(
        'focalnet_tiny_lrf',
        pretrained=pretrained,
        **dict(depths=[2, 2, 6, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs),
    )
@register_model
def focalnet_small_lrf(pretrained=False, **kwargs) -> FocalNet:
    """ FocalNet small (lrf): depths (2, 2, 18, 2), embed_dim 96, 3 focal levels per stage. """
    return _create_focalnet(
        'focalnet_small_lrf',
        pretrained=pretrained,
        **dict(depths=[2, 2, 18, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs),
    )
@register_model
def focalnet_base_lrf(pretrained=False, **kwargs) -> FocalNet:
    """ FocalNet base (lrf): depths (2, 2, 18, 2), embed_dim 128, 3 focal levels per stage. """
    return _create_focalnet(
        'focalnet_base_lrf',
        pretrained=pretrained,
        **dict(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3], **kwargs),
    )
- # FocalNet large+ models
@register_model
def focalnet_large_fl3(pretrained=False, **kwargs) -> FocalNet:
    """ FocalNet large, 3 focal levels with window 5; post-norm, overlapped downsampling, layer scale 1e-4. """
    return _create_focalnet(
        'focalnet_large_fl3',
        pretrained=pretrained,
        **dict(
            depths=[2, 2, 18, 2],
            embed_dim=192,
            focal_levels=[3, 3, 3, 3],
            focal_windows=[5, 5, 5, 5],
            use_post_norm=True,
            use_overlap_down=True,
            layerscale_value=1e-4,
            **kwargs,
        ),
    )
@register_model
def focalnet_large_fl4(pretrained=False, **kwargs) -> FocalNet:
    """ FocalNet large, 4 focal levels (default windows); post-norm, overlapped downsampling, layer scale 1e-4. """
    return _create_focalnet(
        'focalnet_large_fl4',
        pretrained=pretrained,
        **dict(
            depths=[2, 2, 18, 2],
            embed_dim=192,
            focal_levels=[4, 4, 4, 4],
            use_post_norm=True,
            use_overlap_down=True,
            layerscale_value=1e-4,
            **kwargs,
        ),
    )
@register_model
def focalnet_xlarge_fl3(pretrained=False, **kwargs) -> FocalNet:
    """ FocalNet xlarge, 3 focal levels with window 5; post-norm, overlapped downsampling, layer scale 1e-4. """
    return _create_focalnet(
        'focalnet_xlarge_fl3',
        pretrained=pretrained,
        **dict(
            depths=[2, 2, 18, 2],
            embed_dim=256,
            focal_levels=[3, 3, 3, 3],
            focal_windows=[5, 5, 5, 5],
            use_post_norm=True,
            use_overlap_down=True,
            layerscale_value=1e-4,
            **kwargs,
        ),
    )
@register_model
def focalnet_xlarge_fl4(pretrained=False, **kwargs) -> FocalNet:
    """ FocalNet xlarge, 4 focal levels (default windows); post-norm, overlapped downsampling, layer scale 1e-4. """
    return _create_focalnet(
        'focalnet_xlarge_fl4',
        pretrained=pretrained,
        **dict(
            depths=[2, 2, 18, 2],
            embed_dim=256,
            focal_levels=[4, 4, 4, 4],
            use_post_norm=True,
            use_overlap_down=True,
            layerscale_value=1e-4,
            **kwargs,
        ),
    )
@register_model
def focalnet_huge_fl3(pretrained=False, **kwargs) -> FocalNet:
    """ FocalNet huge, 3 focal levels with window 3. Post-norm in both the block and the modulation,
    overlapped downsampling, layer scale 1e-4. """
    return _create_focalnet(
        'focalnet_huge_fl3',
        pretrained=pretrained,
        **dict(
            depths=[2, 2, 18, 2],
            embed_dim=352,
            focal_levels=[3, 3, 3, 3],
            focal_windows=[3, 3, 3, 3],
            use_post_norm=True,
            use_post_norm_in_modulation=True,
            use_overlap_down=True,
            layerscale_value=1e-4,
            **kwargs,
        ),
    )
@register_model
def focalnet_huge_fl4(pretrained=False, **kwargs) -> FocalNet:
    """ FocalNet huge, 4 focal levels (default windows). Post-norm in both the block and the
    modulation, overlapped downsampling, layer scale 1e-4. """
    return _create_focalnet(
        'focalnet_huge_fl4',
        pretrained=pretrained,
        **dict(
            depths=[2, 2, 18, 2],
            embed_dim=352,
            focal_levels=[4, 4, 4, 4],
            use_post_norm=True,
            use_post_norm_in_modulation=True,
            use_overlap_down=True,
            layerscale_value=1e-4,
            **kwargs,
        ),
    )
|