hardcorenas.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. from functools import partial
  2. import torch.nn as nn
  3. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  4. from ._builder import build_model_with_cfg
  5. from ._builder import pretrained_cfg_for_features
  6. from ._efficientnet_blocks import SqueezeExcite
  7. from ._efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels
  8. from ._registry import register_model, generate_default_cfgs
  9. from .mobilenetv3 import MobileNetV3, MobileNetV3Features
# Public API list. It intentionally starts empty: the model registry
# (`register_model`) is expected to add each entrypoint fn to it.
__all__ = []
  11. def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
  12. """Creates a hardcorenas model
  13. Ref impl: https://github.com/Alibaba-MIIL/HardCoReNAS
  14. Paper: https://arxiv.org/abs/2102.11646
  15. """
  16. num_features = 1280
  17. se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
  18. model_kwargs = dict(
  19. block_args=decode_arch_def(arch_def),
  20. num_features=num_features,
  21. stem_size=32,
  22. norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  23. act_layer=resolve_act_layer(kwargs, 'hard_swish'),
  24. se_layer=se_layer,
  25. **kwargs,
  26. )
  27. features_only = False
  28. model_cls = MobileNetV3
  29. kwargs_filter = None
  30. if model_kwargs.pop('features_only', False):
  31. features_only = True
  32. kwargs_filter = ('num_classes', 'num_features', 'global_pool', 'head_conv', 'head_bias', 'global_pool')
  33. model_cls = MobileNetV3Features
  34. model = build_model_with_cfg(
  35. model_cls,
  36. variant,
  37. pretrained,
  38. pretrained_strict=not features_only,
  39. kwargs_filter=kwargs_filter,
  40. **model_kwargs,
  41. )
  42. if features_only:
  43. model.default_cfg = pretrained_cfg_for_features(model.default_cfg)
  44. return model
  45. def _cfg(url='', **kwargs):
  46. return {
  47. 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  48. 'crop_pct': 0.875, 'interpolation': 'bilinear',
  49. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  50. 'first_conv': 'conv_stem', 'classifier': 'classifier',
  51. 'license': 'apache-2.0',
  52. **kwargs
  53. }
  54. default_cfgs = generate_default_cfgs({
  55. 'hardcorenas_a.miil_green_in1k': _cfg(hf_hub_id='timm/'),
  56. 'hardcorenas_b.miil_green_in1k': _cfg(hf_hub_id='timm/'),
  57. 'hardcorenas_c.miil_green_in1k': _cfg(hf_hub_id='timm/'),
  58. 'hardcorenas_d.miil_green_in1k': _cfg(hf_hub_id='timm/'),
  59. 'hardcorenas_e.miil_green_in1k': _cfg(hf_hub_id='timm/'),
  60. 'hardcorenas_f.miil_green_in1k': _cfg(hf_hub_id='timm/'),
  61. })
  62. @register_model
  63. def hardcorenas_a(pretrained=False, **kwargs) -> MobileNetV3:
  64. """ hardcorenas_A """
  65. arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
  66. ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e6_c40_nre_se0.25'],
  67. ['ir_r1_k5_s2_e6_c80_se0.25', 'ir_r1_k5_s1_e6_c80_se0.25'],
  68. ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25'],
  69. ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']]
  70. model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_a', arch_def=arch_def, **kwargs)
  71. return model
  72. @register_model
  73. def hardcorenas_b(pretrained=False, **kwargs) -> MobileNetV3:
  74. """ hardcorenas_B """
  75. arch_def = [['ds_r1_k3_s1_e1_c16_nre'],
  76. ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25', 'ir_r1_k3_s1_e3_c24_nre'],
  77. ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre'],
  78. ['ir_r1_k5_s2_e3_c80', 'ir_r1_k5_s1_e3_c80', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'],
  79. ['ir_r1_k5_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'],
  80. ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'],
  81. ['cn_r1_k1_s1_c960']]
  82. model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_b', arch_def=arch_def, **kwargs)
  83. return model
  84. @register_model
  85. def hardcorenas_c(pretrained=False, **kwargs) -> MobileNetV3:
  86. """ hardcorenas_C """
  87. arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
  88. ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre',
  89. 'ir_r1_k5_s1_e3_c40_nre'],
  90. ['ir_r1_k5_s2_e4_c80', 'ir_r1_k5_s1_e6_c80_se0.25', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'],
  91. ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'],
  92. ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'],
  93. ['cn_r1_k1_s1_c960']]
  94. model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_c', arch_def=arch_def, **kwargs)
  95. return model
  96. @register_model
  97. def hardcorenas_d(pretrained=False, **kwargs) -> MobileNetV3:
  98. """ hardcorenas_D """
  99. arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
  100. ['ir_r1_k5_s2_e3_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k3_s1_e3_c40_nre_se0.25'],
  101. ['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25',
  102. 'ir_r1_k3_s1_e3_c80_se0.25'],
  103. ['ir_r1_k3_s1_e4_c112_se0.25', 'ir_r1_k5_s1_e4_c112_se0.25', 'ir_r1_k3_s1_e3_c112_se0.25',
  104. 'ir_r1_k5_s1_e3_c112_se0.25'],
  105. ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25',
  106. 'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']]
  107. model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_d', arch_def=arch_def, **kwargs)
  108. return model
  109. @register_model
  110. def hardcorenas_e(pretrained=False, **kwargs) -> MobileNetV3:
  111. """ hardcorenas_E """
  112. arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
  113. ['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25',
  114. 'ir_r1_k3_s1_e3_c40_nre_se0.25'], ['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e6_c80_se0.25'],
  115. ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25',
  116. 'ir_r1_k5_s1_e3_c112_se0.25'],
  117. ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25',
  118. 'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']]
  119. model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_e', arch_def=arch_def, **kwargs)
  120. return model
  121. @register_model
  122. def hardcorenas_f(pretrained=False, **kwargs) -> MobileNetV3:
  123. """ hardcorenas_F """
  124. arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
  125. ['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e6_c40_nre_se0.25'],
  126. ['ir_r1_k5_s2_e6_c80_se0.25', 'ir_r1_k5_s1_e6_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25',
  127. 'ir_r1_k3_s1_e3_c80_se0.25'],
  128. ['ir_r1_k3_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25',
  129. 'ir_r1_k3_s1_e3_c112_se0.25'],
  130. ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e6_c192_se0.25',
  131. 'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']]
  132. model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_f', arch_def=arch_def, **kwargs)
  133. return model