create_attn.py 3.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. """ Attention Factory
  2. Hacked together by / Copyright 2021 Ross Wightman
  3. """
  4. import torch
  5. from functools import partial
  6. from .bottleneck_attn import BottleneckAttn
  7. from .cbam import CbamModule, LightCbamModule
  8. from .coord_attn import CoordAttn, EfficientLocalAttn, StripAttn, SimpleCoordAttn
  9. from .eca import EcaModule, CecaModule
  10. from .gather_excite import GatherExcite
  11. from .global_context import GlobalContext
  12. from .halo_attn import HaloAttn
  13. from .lambda_layer import LambdaLayer
  14. from .non_local_attn import NonLocalAttn, BatNonLocalAttn
  15. from .selective_kernel import SelectiveKernel
  16. from .split_attn import SplitAttn
  17. from .squeeze_excite import SEModule, EffectiveSEModule
  18. def get_attn(attn_type):
  19. if isinstance(attn_type, torch.nn.Module):
  20. return attn_type
  21. module_cls = None
  22. if attn_type:
  23. if isinstance(attn_type, str):
  24. attn_type = attn_type.lower()
  25. # Lightweight attention modules (channel and/or coarse spatial).
  26. # Typically added to existing network architecture blocks in addition to existing convolutions.
  27. if attn_type == 'se':
  28. module_cls = SEModule
  29. elif attn_type == 'ese':
  30. module_cls = EffectiveSEModule
  31. elif attn_type == 'eca':
  32. module_cls = EcaModule
  33. elif attn_type == 'ecam':
  34. module_cls = partial(EcaModule, use_mlp=True)
  35. elif attn_type == 'ceca':
  36. module_cls = CecaModule
  37. elif attn_type == 'ge':
  38. module_cls = GatherExcite
  39. elif attn_type == 'gc':
  40. module_cls = GlobalContext
  41. elif attn_type == 'gca':
  42. module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False)
  43. elif attn_type == 'cbam':
  44. module_cls = CbamModule
  45. elif attn_type == 'lcbam':
  46. module_cls = LightCbamModule
  47. elif attn_type == 'coord':
  48. module_cls = CoordAttn
  49. elif attn_type == 'scoord':
  50. module_cls = SimpleCoordAttn
  51. elif attn_type == 'ela':
  52. module_cls = EfficientLocalAttn
  53. elif attn_type == 'strip':
  54. module_cls = StripAttn
  55. # Attention / attention-like modules w/ significant params
  56. # Typically replace some of the existing workhorse convs in a network architecture.
  57. # All of these accept a stride argument and can spatially downsample the input.
  58. elif attn_type == 'sk':
  59. module_cls = SelectiveKernel
  60. elif attn_type == 'splat':
  61. module_cls = SplitAttn
  62. # Self-attention / attention-like modules w/ significant compute and/or params
  63. # Typically replace some of the existing workhorse convs in a network architecture.
  64. # All of these accept a stride argument and can spatially downsample the input.
  65. elif attn_type == 'lambda':
  66. return LambdaLayer
  67. elif attn_type == 'bottleneck':
  68. return BottleneckAttn
  69. elif attn_type == 'halo':
  70. return HaloAttn
  71. elif attn_type == 'nl':
  72. module_cls = NonLocalAttn
  73. elif attn_type == 'bat':
  74. module_cls = BatNonLocalAttn
  75. # Woops!
  76. else:
  77. assert False, "Invalid attn module (%s)" % attn_type
  78. elif isinstance(attn_type, bool):
  79. if attn_type:
  80. module_cls = SEModule
  81. else:
  82. module_cls = attn_type
  83. return module_cls
  84. def create_attn(attn_type, channels, **kwargs):
  85. module_cls = get_attn(attn_type)
  86. if module_cls is not None:
  87. # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels
  88. return module_cls(channels, **kwargs)
  89. return None