# convmixer.py (4.8 KB)
""" ConvMixer

ConvMixer model definition, default pretrained configs and registered model variants.
"""
  3. from typing import Optional, Type
  4. import torch
  5. import torch.nn as nn
  6. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  7. from timm.layers import SelectAdaptivePool2d
  8. from ._registry import register_model, generate_default_cfgs
  9. from ._builder import build_model_with_cfg
  10. from ._manipulate import checkpoint_seq
  11. __all__ = ['ConvMixer']
  12. class Residual(nn.Module):
  13. def __init__(self, fn: nn.Module):
  14. super().__init__()
  15. self.fn = fn
  16. def forward(self, x):
  17. return self.fn(x) + x
  18. class ConvMixer(nn.Module):
  19. def __init__(
  20. self,
  21. dim: int,
  22. depth: int,
  23. kernel_size: int = 9,
  24. patch_size: int = 7,
  25. in_chans: int = 3,
  26. num_classes: int = 1000,
  27. global_pool: str = 'avg',
  28. drop_rate: float = 0.,
  29. act_layer: Type[nn.Module] = nn.GELU,
  30. device=None,
  31. dtype=None,
  32. **kwargs,
  33. ):
  34. super().__init__()
  35. dd = {'device': device, 'dtype': dtype}
  36. self.num_classes = num_classes
  37. self.in_chans = in_chans
  38. self.num_features = self.head_hidden_size = dim
  39. self.grad_checkpointing = False
  40. self.stem = nn.Sequential(
  41. nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size, **dd),
  42. act_layer(),
  43. nn.BatchNorm2d(dim, **dd)
  44. )
  45. self.blocks = nn.Sequential(
  46. *[nn.Sequential(
  47. Residual(nn.Sequential(
  48. nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same", **dd),
  49. act_layer(),
  50. nn.BatchNorm2d(dim, **dd)
  51. )),
  52. nn.Conv2d(dim, dim, kernel_size=1, **dd),
  53. act_layer(),
  54. nn.BatchNorm2d(dim, **dd)
  55. ) for i in range(depth)]
  56. )
  57. self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
  58. self.head_drop = nn.Dropout(drop_rate)
  59. self.head = nn.Linear(dim, num_classes, **dd) if num_classes > 0 else nn.Identity()
  60. @torch.jit.ignore
  61. def group_matcher(self, coarse=False):
  62. matcher = dict(stem=r'^stem', blocks=r'^blocks\.(\d+)')
  63. return matcher
  64. @torch.jit.ignore
  65. def set_grad_checkpointing(self, enable=True):
  66. self.grad_checkpointing = enable
  67. @torch.jit.ignore
  68. def get_classifier(self) -> nn.Module:
  69. return self.head
  70. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  71. self.num_classes = num_classes
  72. if global_pool is not None:
  73. self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
  74. self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  75. def forward_features(self, x):
  76. x = self.stem(x)
  77. if self.grad_checkpointing and not torch.jit.is_scripting():
  78. x = checkpoint_seq(self.blocks, x)
  79. else:
  80. x = self.blocks(x)
  81. return x
  82. def forward_head(self, x, pre_logits: bool = False):
  83. x = self.pooling(x)
  84. x = self.head_drop(x)
  85. return x if pre_logits else self.head(x)
  86. def forward(self, x):
  87. x = self.forward_features(x)
  88. x = self.forward_head(x)
  89. return x
  90. def _create_convmixer(variant, pretrained=False, **kwargs):
  91. if kwargs.get('features_only', None):
  92. raise RuntimeError('features_only not implemented for ConvMixer models.')
  93. return build_model_with_cfg(ConvMixer, variant, pretrained, **kwargs)
  94. def _cfg(url='', **kwargs):
  95. return {
  96. 'url': url,
  97. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  98. 'crop_pct': .96, 'interpolation': 'bicubic',
  99. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head',
  100. 'first_conv': 'stem.0', 'license': 'mit',
  101. **kwargs
  102. }
  103. default_cfgs = generate_default_cfgs({
  104. 'convmixer_1536_20.in1k': _cfg(hf_hub_id='timm/'),
  105. 'convmixer_768_32.in1k': _cfg(hf_hub_id='timm/'),
  106. 'convmixer_1024_20_ks9_p14.in1k': _cfg(hf_hub_id='timm/')
  107. })
  108. @register_model
  109. def convmixer_1536_20(pretrained=False, **kwargs) -> ConvMixer:
  110. model_args = dict(dim=1536, depth=20, kernel_size=9, patch_size=7, **kwargs)
  111. return _create_convmixer('convmixer_1536_20', pretrained, **model_args)
  112. @register_model
  113. def convmixer_768_32(pretrained=False, **kwargs) -> ConvMixer:
  114. model_args = dict(dim=768, depth=32, kernel_size=7, patch_size=7, act_layer=nn.ReLU, **kwargs)
  115. return _create_convmixer('convmixer_768_32', pretrained, **model_args)
  116. @register_model
  117. def convmixer_1024_20_ks9_p14(pretrained=False, **kwargs) -> ConvMixer:
  118. model_args = dict(dim=1024, depth=20, kernel_size=9, patch_size=14, **kwargs)
  119. return _create_convmixer('convmixer_1024_20_ks9_p14', pretrained, **model_args)