| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150 |
- """ ConvMixer
- """
- from typing import Optional, Type
- import torch
- import torch.nn as nn
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.layers import SelectAdaptivePool2d
- from ._registry import register_model, generate_default_cfgs
- from ._builder import build_model_with_cfg
- from ._manipulate import checkpoint_seq
# Public API of this module: `from ... import *` exports only the ConvMixer class.
__all__ = ['ConvMixer']
class Residual(nn.Module):
    """Skip-connection wrapper: returns ``fn(x) + x`` for the wrapped module ``fn``."""

    def __init__(self, fn: nn.Module):
        super().__init__()
        # Attribute name `fn` determines the state_dict key prefix ('fn.*'); keep it.
        self.fn = fn

    def forward(self, x):
        branch = self.fn(x)
        return branch + x
class ConvMixer(nn.Module):
    """ConvMixer image classifier.

    A patch-embedding stem (strided conv -> activation -> BatchNorm) is followed by
    ``depth`` mixer blocks. Each block applies a residual depthwise conv
    (conv -> activation -> BatchNorm) and then a pointwise 1x1 conv
    (conv -> activation -> BatchNorm). The feature map is pooled, passed through
    dropout and fed to a linear classifier head.

    Args:
        dim: Channel width used throughout the network (also ``num_features``).
        depth: Number of mixer blocks.
        kernel_size: Kernel size of the depthwise convs.
        patch_size: Kernel size and stride of the stem conv.
        in_chans: Number of input image channels.
        num_classes: Number of output classes; a value <= 0 replaces the head with
            ``nn.Identity``.
        global_pool: Pool type passed to ``SelectAdaptivePool2d``.
        drop_rate: Dropout probability applied before the classifier head.
        act_layer: Activation module class, instantiated after every conv.
        device: Device for created parameters.
        dtype: Dtype for created parameters.
        **kwargs: Accepted and ignored.
    """

    def __init__(
            self,
            dim: int,
            depth: int,
            kernel_size: int = 9,
            patch_size: int = 7,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            drop_rate: float = 0.,
            act_layer: Type[nn.Module] = nn.GELU,
            device=None,
            dtype=None,
            **kwargs,
    ):
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.num_features = self.head_hidden_size = dim
        self.grad_checkpointing = False

        # Patch embedding: kernel == stride, so patches do not overlap.
        self.stem = nn.Sequential(
            nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size, **dd),
            act_layer(),
            nn.BatchNorm2d(dim, **dd),
        )
        # Each block: residual depthwise conv (groups=dim; 'same' padding at stride 1
        # keeps the spatial size so the residual add lines up), then a pointwise 1x1 conv.
        # The module nesting determines state_dict keys, so it must stay as-is for
        # pretrained checkpoints to load.
        self.blocks = nn.Sequential(*[
            nn.Sequential(
                Residual(nn.Sequential(
                    nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same", **dd),
                    act_layer(),
                    nn.BatchNorm2d(dim, **dd),
                )),
                nn.Conv2d(dim, dim, kernel_size=1, **dd),
                act_layer(),
                nn.BatchNorm2d(dim, **dd),
            )
            for _ in range(depth)
        ])
        self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(dim, num_classes, **dd) if num_classes > 0 else nn.Identity()

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return regexes grouping parameters into the stem and per-block groups.

        ``coarse`` is accepted for interface compatibility but unused.
        """
        matcher = dict(stem=r'^stem', blocks=r'^blocks\.(\d+)')
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable/disable gradient checkpointing over ``self.blocks`` in ``forward_features``."""
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        """Return the classifier head module."""
        return self.head

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        """Replace the classifier head and, if ``global_pool`` is given, the pooling layer.

        The new head is created on the same device and with the same dtype as the
        stem conv weight. A model that was built with, or moved to, another
        device/dtype therefore still runs after the reset instead of failing with a
        device/dtype mismatch in ``forward_head``.

        Args:
            num_classes: New number of output classes; <= 0 installs ``nn.Identity``.
            global_pool: Optional new pool type for ``SelectAdaptivePool2d``.
        """
        self.num_classes = num_classes
        if global_pool is not None:
            self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
        if num_classes > 0:
            ref_weight = self.stem[0].weight
            self.head = nn.Linear(
                self.num_features,
                num_classes,
                device=ref_weight.device,
                dtype=ref_weight.dtype,
            )
        else:
            self.head = nn.Identity()

    def forward_features(self, x):
        """Run the stem and mixer blocks, returning the unpooled feature map."""
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool and apply dropout; return those features if ``pre_logits``, else the head output."""
        x = self.pooling(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        """Full forward pass: features followed by the classifier head."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
- def _create_convmixer(variant, pretrained=False, **kwargs):
- if kwargs.get('features_only', None):
- raise RuntimeError('features_only not implemented for ConvMixer models.')
- return build_model_with_cfg(ConvMixer, variant, pretrained, **kwargs)
def _cfg(url='', **kwargs):
    """Return the default pretrained-config dict; keyword arguments override or extend it."""
    cfg = {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': None,
        'crop_pct': .96,
        'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'classifier': 'head',
        'first_conv': 'stem.0',
        'license': 'mit',
    }
    # Later keys win, same as unpacking **kwargs at the end of the literal.
    cfg.update(kwargs)
    return cfg
# Pretrained configs for each ConvMixer weight tag, all built with hf_hub_id='timm/'.
default_cfgs = generate_default_cfgs({
    tag: _cfg(hf_hub_id='timm/')
    for tag in (
        'convmixer_1536_20.in1k',
        'convmixer_768_32.in1k',
        'convmixer_1024_20_ks9_p14.in1k',
    )
})
@register_model
def convmixer_1536_20(pretrained=False, **kwargs) -> ConvMixer:
    """ConvMixer with dim=1536, depth=20, kernel_size=9, patch_size=7."""
    arch = dict(dim=1536, depth=20, kernel_size=9, patch_size=7, **kwargs)
    return _create_convmixer('convmixer_1536_20', pretrained=pretrained, **arch)
@register_model
def convmixer_768_32(pretrained=False, **kwargs) -> ConvMixer:
    """ConvMixer with dim=768, depth=32, kernel_size=7, patch_size=7 and ReLU activations."""
    arch = dict(dim=768, depth=32, kernel_size=7, patch_size=7, act_layer=nn.ReLU, **kwargs)
    return _create_convmixer('convmixer_768_32', pretrained=pretrained, **arch)
@register_model
def convmixer_1024_20_ks9_p14(pretrained=False, **kwargs) -> ConvMixer:
    """ConvMixer with dim=1024, depth=20, kernel_size=9, patch_size=14."""
    arch = dict(dim=1024, depth=20, kernel_size=9, patch_size=14, **kwargs)
    return _create_convmixer('convmixer_1024_20_ks9_p14', pretrained=pretrained, **arch)
|