| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572 |
- """ Sequencer
- Paper: `Sequencer: Deep LSTM for Image Classification` - https://arxiv.org/pdf/2205.01972.pdf
- """
- # Copyright (c) 2022. Yuki Tatsunami
- # Licensed under the Apache License, Version 2.0 (the "License");
- import math
- from functools import partial
- from itertools import accumulate
- from typing import List, Optional, Tuple, Type, Union
- import torch
- import torch.nn as nn
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
- from timm.layers import lecun_normal_, DropPath, Mlp, PatchEmbed, ClassifierHead
- from ._builder import build_model_with_cfg
- from ._manipulate import named_apply
- from ._registry import register_model, generate_default_cfgs
# Explicit public API of this module. Per the model-registry convention, each
# registered model entrypoint fn is added to this list as well.
__all__ = [
    'Sequencer2d',
]
def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False):
    """Initialize the parameters of a single submodule in place.

    Meant to be applied to every named submodule of a model (e.g. via ``named_apply``).

    Args:
        module: Submodule to initialize.
        name: Dotted name of ``module`` within the model. Names starting with ``'head'``
            select the classifier init; names containing ``'mlp'`` get a tiny random bias.
        head_bias: Constant fill value for the classifier bias.
        flax: If True, non-head ``nn.Linear`` layers use LeCun-normal weights and zero
            bias instead of Xavier-uniform weights.

    Optional parameters that are ``None`` (``nn.Linear(bias=False)``, norms built
    without affine parameters) are skipped instead of being dereferenced.
    """
    if isinstance(module, nn.Linear):
        if name.startswith('head'):
            nn.init.zeros_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, head_bias)
        elif flax:
            # Flax defaults
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        else:
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                if 'mlp' in name:
                    nn.init.normal_(module.bias, std=1e-6)
                else:
                    nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
        # weight / bias are None when the norm is created without affine params
        # (elementwise_affine=False / affine=False / bias=False).
        if module.weight is not None:
            nn.init.ones_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.RNN, nn.GRU, nn.LSTM)):
        # Every RNN weight and bias drawn from U(-1/sqrt(hidden), 1/sqrt(hidden)).
        stdv = 1.0 / math.sqrt(module.hidden_size)
        for weight in module.parameters():
            nn.init.uniform_(weight, -stdv, stdv)
    elif hasattr(module, 'init_weights'):
        module.init_weights()
class RNNIdentity(nn.Module):
    """Pass-through placeholder for an RNN layer.

    Any constructor arguments are accepted and ignored. ``forward`` returns a
    two-tuple ``(input, None)`` so callers can unpack it like an RNN's
    ``(output, state)`` result.
    """

    def __init__(self, *args, **kwargs):
        super(RNNIdentity, self).__init__()

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None]:
        no_state = None
        return x, no_state
class RNN2dBase(nn.Module):
    """Base for 2D (vertical + horizontal) RNN token mixers over (B, H, W, C) inputs.

    The vertical branch runs ``self.rnn_v`` over each column (a length-H sequence) and
    the horizontal branch runs ``self.rnn_h`` over each row (a length-W sequence).
    Branch outputs are merged according to ``union`` and optionally projected back to
    ``input_size`` channels by ``self.fc``. This base installs ``RNNIdentity``
    placeholders; subclasses replace them with real RNN layers.

    Args:
        input_size: Channel count C of the input.
        hidden_size: Hidden size of each directional RNN.
        num_layers: Not used by this base; accepted for subclasses.
        bias: Not used by this base; accepted for subclasses.
        bidirectional: If True, each branch yields ``2 * hidden_size`` channels.
        union: Merge mode: ``'cat'`` (concat on channels), ``'add'``, ``'vertical'``
            (vertical branch only) or ``'horizontal'`` (horizontal branch only).
        with_fc: If True, project the merged output to ``input_size`` channels. If
            False, the merged output must already have ``input_size`` channels.

    Raises:
        ValueError: If ``union`` is not one of the modes above (including non-string
            values), or if ``with_fc`` is False and the merged channel count differs
            from ``input_size``.
    """

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            num_layers: int = 1,
            bias: bool = True,
            bidirectional: bool = True,
            union: str = "cat",
            with_fc: bool = True,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        # f-string so a non-str union still yields ValueError (string concat raised TypeError).
        if union not in ("cat", "add", "vertical", "horizontal"):
            raise ValueError(f"Unrecognized union: {union}")

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = 2 * hidden_size if bidirectional else hidden_size
        self.union = union

        # 'vertical' / 'horizontal' unions disable the opposite branch.
        self.with_vertical = union != "horizontal"
        self.with_horizontal = union != "vertical"
        self.with_fc = with_fc

        # Channel count after merging the branches; only 'cat' doubles it.
        merged_size = 2 * self.output_size if union == "cat" else self.output_size

        self.fc = None
        if with_fc:
            self.fc = nn.Linear(merged_size, input_size, **dd)
        elif merged_size != input_size:
            raise ValueError(f"The output channel {merged_size} is different from the input channel {input_size}.")

        self.rnn_v = RNNIdentity()
        self.rnn_h = RNNIdentity()

    def forward(self, x):
        """Mix tokens along H and W of ``x`` (unpacked as B, H, W, C) and merge the branches."""
        B, H, W, C = x.shape

        v = None
        if self.with_vertical:
            # Columns as sequences: (B, W, H, C) -> (B*W, H, C), then back to (B, H, W, -1).
            v = x.permute(0, 2, 1, 3).reshape(-1, H, C)
            v, _ = self.rnn_v(v)
            v = v.reshape(B, W, H, -1).permute(0, 2, 1, 3)

        h = None
        if self.with_horizontal:
            # Rows as sequences: (B*H, W, C), then back to (B, H, W, -1).
            h = x.reshape(-1, W, C)
            h, _ = self.rnn_h(h)
            h = h.reshape(B, H, W, -1)

        if v is not None and h is not None:
            x = torch.cat([v, h], dim=-1) if self.union == "cat" else v + h
        elif v is not None:
            x = v
        elif h is not None:
            x = h

        if self.fc is not None:
            x = self.fc(x)
        return x
class LSTM2d(RNN2dBase):
    """``RNN2dBase`` whose vertical and horizontal sequence mixers are ``nn.LSTM`` layers.

    An LSTM is built only for the branches enabled by ``union``. A disabled branch
    keeps the base-class ``RNNIdentity`` placeholder. Both LSTMs share one
    configuration: ``batch_first=True`` and the given ``num_layers``, ``bias`` and
    ``bidirectional``.
    """

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            num_layers: int = 1,
            bias: bool = True,
            bidirectional: bool = True,
            union: str = "cat",
            with_fc: bool = True,
            device=None,
            dtype=None,
    ):
        super().__init__(
            input_size,
            hidden_size,
            num_layers,
            bias,
            bidirectional,
            union,
            with_fc,
            device,
            dtype,
        )
        lstm_options = {'device': device, 'dtype': dtype}
        lstm_options.update(batch_first=True, bias=bias, bidirectional=bidirectional)

        # Vertical LSTM first, then horizontal.
        if self.with_vertical:
            self.rnn_v = nn.LSTM(input_size, hidden_size, num_layers, **lstm_options)
        if self.with_horizontal:
            self.rnn_h = nn.LSTM(input_size, hidden_size, num_layers, **lstm_options)
class Sequencer2dBlock(nn.Module):
    """Sequencer block: pre-norm 2D-RNN token mixing, then a pre-norm channel MLP.

    Both sub-layers are residual. The same ``drop_path`` module (DropPath when
    ``drop_path > 0``, otherwise identity) wraps both residual branches. The MLP hidden
    width is ``int(mlp_ratio * dim)``.
    """

    def __init__(
            self,
            dim: int,
            hidden_size: int,
            mlp_ratio: float = 3.0,
            rnn_layer: Type[nn.Module] = LSTM2d,
            mlp_layer: Type[nn.Module] = Mlp,
            norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
            act_layer: Type[nn.Module] = nn.GELU,
            num_layers: int = 1,
            bidirectional: bool = True,
            union: str = "cat",
            with_fc: bool = True,
            drop: float = 0.,
            drop_path: float = 0.,
            device=None,
            dtype=None,
    ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        mlp_hidden = int(mlp_ratio * dim)

        # Token-mixing branch.
        self.norm1 = norm_layer(dim, **factory_kwargs)
        self.rnn_tokens = rnn_layer(
            dim,
            hidden_size,
            num_layers=num_layers,
            bidirectional=bidirectional,
            union=union,
            with_fc=with_fc,
            **factory_kwargs,
        )
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # Channel-mixing branch.
        self.norm2 = norm_layer(dim, **factory_kwargs)
        self.mlp_channels = mlp_layer(dim, mlp_hidden, act_layer=act_layer, drop=drop, **factory_kwargs)

    def forward(self, x):
        token_mixed = self.rnn_tokens(self.norm1(x))
        x = x + self.drop_path(token_mixed)
        channel_mixed = self.mlp_channels(self.norm2(x))
        return x + self.drop_path(channel_mixed)
class Shuffle(nn.Module):
    """Randomly permute the spatial token positions of a (B, H, W, C) tensor in training.

    One permutation of the H*W positions is drawn per call and applied to every sample
    in the batch. In eval mode the input is returned unchanged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        if not self.training:
            return x
        B, H, W, C = x.shape
        # Draw the index on the input's device so indexing needs no host->device index copy.
        perm = torch.randperm(H * W, device=x.device)
        x = x.reshape(B, -1, C)
        return x[:, perm, :].reshape(B, H, W, -1)
class Downsample2d(nn.Module):
    """Strided-conv downsampling for channels-last feature maps.

    ``self.down`` is an ``nn.Conv2d(input_dim, output_dim)`` whose kernel size and stride
    both equal ``patch_size``. ``forward`` permutes the input from (B, H, W, C) to
    (B, C, H, W) for the convolution and permutes the result back.
    """

    def __init__(
            self,
            input_dim: int,
            output_dim: int,
            patch_size: int,
            device=None,
            dtype=None,
    ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.down = nn.Conv2d(
            input_dim,
            output_dim,
            kernel_size=patch_size,
            stride=patch_size,
            **factory_kwargs,
        )

    def forward(self, x):
        channels_first = x.permute(0, 3, 1, 2)
        return self.down(channels_first).permute(0, 2, 3, 1)
class Sequencer2dStage(nn.Module):
    """One Sequencer stage: an optional ``Downsample2d`` followed by ``depth`` blocks.

    Args:
        dim: Input channel count.
        dim_out: Output channel count (the width of every block in the stage).
        depth: Number of blocks.
        patch_size: Kernel/stride of the downsample conv when ``downsample`` is True.
        downsample: If True, start with ``Downsample2d(dim, dim_out, patch_size)``.
            Otherwise ``dim`` must equal ``dim_out`` and the input passes straight through.
        drop_path: One rate shared by all blocks, or a list/tuple indexed per block.
        Remaining arguments are forwarded unchanged to each ``block_layer``.
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            depth: int,
            patch_size: int,
            hidden_size: int,
            mlp_ratio: float,
            downsample: bool = False,
            block_layer: Type[nn.Module] = Sequencer2dBlock,
            rnn_layer: Type[nn.Module] = LSTM2d,
            mlp_layer: Type[nn.Module] = Mlp,
            norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
            act_layer: Type[nn.Module] = nn.GELU,
            num_layers: int = 1,
            bidirectional: bool = True,
            union: str = "cat",
            with_fc: bool = True,
            drop: float = 0.,
            drop_path: Union[float, List[float]] = 0.,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        if not downsample:
            assert dim == dim_out
            self.downsample = nn.Identity()
        else:
            self.downsample = Downsample2d(dim, dim_out, patch_size, **factory_kwargs)

        per_block_rates = isinstance(drop_path, (list, tuple))
        self.blocks = nn.Sequential(*[
            block_layer(
                dim_out,
                hidden_size,
                mlp_ratio=mlp_ratio,
                rnn_layer=rnn_layer,
                mlp_layer=mlp_layer,
                norm_layer=norm_layer,
                act_layer=act_layer,
                num_layers=num_layers,
                bidirectional=bidirectional,
                union=union,
                with_fc=with_fc,
                drop=drop,
                drop_path=drop_path[idx] if per_block_rates else drop_path,
                **factory_kwargs,
            )
            for idx in range(depth)
        ])

    def forward(self, x):
        return self.blocks(self.downsample(x))
class Sequencer2d(nn.Module):
    """Sequencer2D image classifier.

    The model consists of:
      * a patch-embedding stem producing a (B, H, W, C) map;
      * one ``Sequencer2dStage`` per entry of ``embed_dims``, where stages after the
        first begin with a downsample;
      * a final norm;
      * a ``ClassifierHead``.

    Args:
        num_classes: Number of classifier outputs.
        img_size: Not used; the stem is built with ``img_size=None``.
        in_chans: Number of input image channels.
        global_pool: Head pooling type, ``'avg'`` or ``''``.
        layers: Blocks per stage.
        patch_sizes: Stem patch size (index 0), then each later stage's downsample factor.
        embed_dims: Channel width per stage.
        hidden_sizes: RNN hidden size per stage.
        mlp_ratios: Channel-MLP expansion ratio per stage.
        block_layer, rnn_layer, mlp_layer, norm_layer, act_layer: Layer constructors
            forwarded to every stage.
        num_rnn_layers: Stacked layers inside each RNN.
        bidirectional: Use bidirectional RNNs.
        union: How the vertical/horizontal RNN outputs are merged.
        with_fc: Project the merged RNN output back to the block width.
        drop_rate: Dropout rate for the blocks and the classifier head.
        drop_path_rate: Stochastic-depth rate, passed unchanged to every block.
        nlhb: Initialize the classifier bias to ``-log(num_classes)``.
        stem_norm: Apply ``norm_layer`` inside the stem.
    """

    def __init__(
            self,
            num_classes: int = 1000,
            img_size: int = 224,
            in_chans: int = 3,
            global_pool: str = 'avg',
            layers: Tuple[int, ...] = (4, 3, 8, 3),
            patch_sizes: Tuple[int, ...] = (7, 2, 2, 1),
            embed_dims: Tuple[int, ...] = (192, 384, 384, 384),
            hidden_sizes: Tuple[int, ...] = (48, 96, 96, 96),
            mlp_ratios: Tuple[float, ...] = (3.0, 3.0, 3.0, 3.0),
            block_layer: Type[nn.Module] = Sequencer2dBlock,
            rnn_layer: Type[nn.Module] = LSTM2d,
            mlp_layer: Type[nn.Module] = Mlp,
            norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
            act_layer: Type[nn.Module] = nn.GELU,
            num_rnn_layers: int = 1,
            bidirectional: bool = True,
            union: str = "cat",
            with_fc: bool = True,
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            nlhb: bool = False,
            stem_norm: bool = False,
            device=None,
            dtype=None,
    ):
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        assert global_pool in ('', 'avg')
        # Validate the per-stage config before building any module.
        assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios)
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.global_pool = global_pool
        self.num_features = self.head_hidden_size = embed_dims[-1]  # for consistency with other models
        self.feature_dim = -1  # channel dim index for feature outputs (rank 4, NHWC)
        self.output_fmt = 'NHWC'
        self.feature_info = []

        self.stem = PatchEmbed(
            img_size=None,
            patch_size=patch_sizes[0],
            in_chans=in_chans,
            embed_dim=embed_dims[0],
            norm_layer=norm_layer if stem_norm else None,
            flatten=False,
            output_fmt='NHWC',
            **dd,
        )

        # Cumulative product of patch sizes = total stride of each stage's output.
        reductions = list(accumulate(patch_sizes, lambda x, y: x * y))
        stages = []
        prev_dim = embed_dims[0]
        for i, stage_dim in enumerate(embed_dims):
            stages.append(Sequencer2dStage(
                prev_dim,
                stage_dim,
                depth=layers[i],
                downsample=i > 0,  # stage 0 consumes the stem output directly
                patch_size=patch_sizes[i],
                hidden_size=hidden_sizes[i],
                mlp_ratio=mlp_ratios[i],
                block_layer=block_layer,
                rnn_layer=rnn_layer,
                mlp_layer=mlp_layer,
                norm_layer=norm_layer,
                act_layer=act_layer,
                num_layers=num_rnn_layers,
                bidirectional=bidirectional,
                union=union,
                with_fc=with_fc,
                drop=drop_rate,
                drop_path=drop_path_rate,  # scalar: same rate for every block
                **dd,
            ))
            prev_dim = stage_dim
            self.feature_info.append(dict(num_chs=prev_dim, reduction=reductions[i], module=f'stages.{i}'))
        self.stages = nn.Sequential(*stages)

        self.norm = norm_layer(embed_dims[-1], **dd)
        self.head = ClassifierHead(
            self.num_features,
            num_classes,
            pool_type=global_pool,
            drop_rate=drop_rate,
            input_fmt=self.output_fmt,
            **dd,
        )

        self.init_weights(nlhb=nlhb)

    def init_weights(self, nlhb=False):
        """(Re)initialize every submodule with ``_init_weights``.

        Args:
            nlhb: If True, the classifier bias is set to ``-log(num_classes)``. When
                ``num_classes`` is 0, ``math.log(0)`` would raise, so 0 is used instead.
        """
        head_bias = -math.log(self.num_classes) if nlhb and self.num_classes > 0 else 0.
        named_apply(partial(_init_weights, head_bias=head_bias), module=self)  # depth-first

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Regex groups of parameter names (stem / stages or blocks / final norm)."""
        return dict(
            stem=r'^stem',
            blocks=[
                (r'^stages\.(\d+)', None),
                (r'^norm', (99999,))
            ] if coarse else [
                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
                (r'^stages\.(\d+)\.downsample', (0,)),
                (r'^norm', (99999,))
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Gradient checkpointing is not implemented; only ``enable=False`` is accepted."""
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        """Replace the classifier with one producing ``num_classes`` outputs."""
        self.num_classes = num_classes
        self.head.reset(num_classes, pool_type=global_pool)

    def forward_features(self, x):
        """Stem -> stages -> norm; returns the final feature map (``output_fmt`` NHWC)."""
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=True) if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
def checkpoint_filter_fn(state_dict, model):
    """ Remap original checkpoints -> timm

    Key renames applied to checkpoints from the original Sequencer code:
      * ``blocks.{i}.{j}.down...`` -> ``stages.{i + 1}.downsample.down...``
      * ``blocks.{i}.{j}...``      -> ``stages.{i}.blocks.{j}...``
      * ``head.<param>``           -> ``head.fc.<param>`` (prefix only)

    A checkpoint nested under a ``'model'`` key is unwrapped *before* the
    already-translated check. A wrapped timm checkpoint is therefore returned
    untouched instead of being renamed a second time.

    Args:
        state_dict: Checkpoint state dict, optionally nested under ``'model'``.
        model: Target model (not used).

    Returns:
        A state dict using timm key names.
    """
    import re

    if 'model' in state_dict:
        state_dict = state_dict['model']
    if 'stages.0.blocks.0.norm1.weight' in state_dict:
        return state_dict  # already translated checkpoint

    # Escaped dots: only literal '.' separators may match.
    down_re = re.compile(r'blocks\.([0-9]+)\.([0-9]+)\.down')
    block_re = re.compile(r'blocks\.([0-9]+)\.([0-9]+)')
    out_dict = {}
    for k, v in state_dict.items():
        k = down_re.sub(lambda m: f'stages.{int(m.group(1)) + 1}.downsample.down', k)
        k = block_re.sub(r'stages.\1.blocks.\2', k)
        if k.startswith('head.'):
            k = 'head.fc.' + k[len('head.'):]
        out_dict[k] = v
    return out_dict
def _create_sequencer2d(variant, pretrained=False, **kwargs):
    """Build a ``Sequencer2d`` variant through ``build_model_with_cfg``.

    ``out_indices`` (default ``(0, 1, 2)``) is removed from ``kwargs`` and sent to the
    feature config. Pretrained weights go through ``checkpoint_filter_fn``. All other
    ``kwargs`` are forwarded to the builder.
    """
    feature_out_indices = kwargs.pop('out_indices', (0, 1, 2))
    return build_model_with_cfg(
        Sequencer2d,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=feature_out_indices),
        **kwargs,
    )
def _cfg(url='', **kwargs):
    """Return the default pretrained-config dict for a Sequencer2d variant.

    Keyword arguments override matching defaults and append any new keys.
    """
    cfg = {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': None,
        'crop_pct': DEFAULT_CROP_PCT,
        'interpolation': 'bicubic',
        'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.proj',
        'classifier': 'head.fc',
        'license': 'apache-2.0',
    }
    cfg.update(kwargs)
    return cfg
# One pretrained config per registered variant tag. Each uses the _cfg defaults with
# hf_hub_id='timm/'.
default_cfgs = generate_default_cfgs({
    variant_tag: _cfg(hf_hub_id='timm/')
    for variant_tag in (
        'sequencer2d_s.in1k',
        'sequencer2d_m.in1k',
        'sequencer2d_l.in1k',
    )
})
@register_model
def sequencer2d_s(pretrained=False, **kwargs) -> Sequencer2d:
    """Sequencer2D-S: 4/3/8/3 blocks per stage with bidirectional LSTM2d mixing.

    ``kwargs`` override the defaults below before the model is built.
    """
    model_args = dict(
        layers=[4, 3, 8, 3],
        patch_sizes=[7, 2, 1, 1],
        embed_dims=[192, 384, 384, 384],
        hidden_sizes=[48, 96, 96, 96],
        mlp_ratios=[3.0, 3.0, 3.0, 3.0],
        rnn_layer=LSTM2d,
        bidirectional=True,
        union="cat",
        with_fc=True,
    )
    model_args.update(kwargs)
    return _create_sequencer2d('sequencer2d_s', pretrained=pretrained, **model_args)
@register_model
def sequencer2d_m(pretrained=False, **kwargs) -> Sequencer2d:
    """Sequencer2D-M: 4/3/14/3 blocks per stage with bidirectional LSTM2d mixing.

    ``kwargs`` are merged over the defaults below exactly once. Any default (e.g.
    ``layers`` or ``with_fc``) can therefore be overridden without a
    duplicate-keyword ``TypeError``.
    """
    model_args = dict(
        layers=[4, 3, 14, 3],
        patch_sizes=[7, 2, 1, 1],
        embed_dims=[192, 384, 384, 384],
        hidden_sizes=[48, 96, 96, 96],
        mlp_ratios=[3.0, 3.0, 3.0, 3.0],
        rnn_layer=LSTM2d,
        bidirectional=True,
        union="cat",
        with_fc=True,
    )
    model = _create_sequencer2d('sequencer2d_m', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
@register_model
def sequencer2d_l(pretrained=False, **kwargs) -> Sequencer2d:
    """Sequencer2D-L: 8/8/16/4 blocks per stage with bidirectional LSTM2d mixing.

    ``kwargs`` are merged over the defaults below exactly once. Any default (e.g.
    ``layers`` or ``with_fc``) can therefore be overridden without a
    duplicate-keyword ``TypeError``.
    """
    model_args = dict(
        layers=[8, 8, 16, 4],
        patch_sizes=[7, 2, 1, 1],
        embed_dims=[192, 384, 384, 384],
        hidden_sizes=[48, 96, 96, 96],
        mlp_ratios=[3.0, 3.0, 3.0, 3.0],
        rnn_layer=LSTM2d,
        bidirectional=True,
        union="cat",
        with_fc=True,
    )
    model = _create_sequencer2d('sequencer2d_l', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
|