| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
- """ Attention Factory
- Hacked together by / Copyright 2021 Ross Wightman
- """
- import torch
- from functools import partial
- from .bottleneck_attn import BottleneckAttn
- from .cbam import CbamModule, LightCbamModule
- from .coord_attn import CoordAttn, EfficientLocalAttn, StripAttn, SimpleCoordAttn
- from .eca import EcaModule, CecaModule
- from .gather_excite import GatherExcite
- from .global_context import GlobalContext
- from .halo_attn import HaloAttn
- from .lambda_layer import LambdaLayer
- from .non_local_attn import NonLocalAttn, BatNonLocalAttn
- from .selective_kernel import SelectiveKernel
- from .split_attn import SplitAttn
- from .squeeze_excite import SEModule, EffectiveSEModule
def get_attn(attn_type):
    """Resolve an attention specifier to a layer class (or pass it through).

    Args:
        attn_type: One of
            * a ``torch.nn.Module`` instance - returned unchanged;
            * any falsy value (``None``, ``False``, ``''``, ...) - yields ``None``;
            * ``True`` - yields ``SEModule``;
            * a string name (case-insensitive), e.g. ``'se'``, ``'eca'``, ``'cbam'``,
              ``'sk'``, ``'halo'`` - yields the matching class, or a
              ``functools.partial`` of it for the pre-configured variants
              (``'ecam'``, ``'gca'``);
            * any other truthy object (e.g. a class or callable) - returned unchanged.

    Returns:
        The resolved class / partial / object, or ``None`` when no attention is requested.

    Raises:
        AssertionError: if ``attn_type`` is a non-empty string that names no known module.
    """
    # Module instances are handed back as-is, before any truthiness test is applied.
    if isinstance(attn_type, torch.nn.Module):
        return attn_type

    # Nothing requested.
    if not attn_type:
        return None

    # Only `True` can reach this point as a bool; it selects squeeze-and-excite.
    if isinstance(attn_type, bool):
        return SEModule

    # Non-string, truthy specifiers (e.g. a class or callable) are passed through untouched.
    if not isinstance(attn_type, str):
        return attn_type

    name = attn_type.lower()

    # Lightweight attention modules (channel and/or coarse spatial).
    # Typically added to existing network architecture blocks in addition to existing convolutions.
    if name == 'se':
        return SEModule
    if name == 'ese':
        return EffectiveSEModule
    if name == 'eca':
        return EcaModule
    if name == 'ecam':
        return partial(EcaModule, use_mlp=True)
    if name == 'ceca':
        return CecaModule
    if name == 'ge':
        return GatherExcite
    if name == 'gc':
        return GlobalContext
    if name == 'gca':
        return partial(GlobalContext, fuse_add=True, fuse_scale=False)
    if name == 'cbam':
        return CbamModule
    if name == 'lcbam':
        return LightCbamModule
    if name == 'coord':
        return CoordAttn
    if name == 'scoord':
        return SimpleCoordAttn
    if name == 'ela':
        return EfficientLocalAttn
    if name == 'strip':
        return StripAttn

    # Attention / attention-like modules w/ significant params
    # Typically replace some of the existing workhorse convs in a network architecture.
    # All of these accept a stride argument and can spatially downsample the input.
    if name == 'sk':
        return SelectiveKernel
    if name == 'splat':
        return SplitAttn

    # Self-attention / attention-like modules w/ significant compute and/or params
    # Typically replace some of the existing workhorse convs in a network architecture.
    # All of these accept a stride argument and can spatially downsample the input.
    if name == 'lambda':
        return LambdaLayer
    if name == 'bottleneck':
        return BottleneckAttn
    if name == 'halo':
        return HaloAttn
    if name == 'nl':
        return NonLocalAttn
    if name == 'bat':
        return BatNonLocalAttn

    # Woops! Raised explicitly (instead of `assert False`) so the check is not stripped
    # when Python runs with -O; the exception type and message are unchanged.
    raise AssertionError("Invalid attn module (%s)" % name)
def create_attn(attn_type, channels, **kwargs):
    """Build the attention layer selected by ``attn_type``.

    Args:
        attn_type: Attention specifier, resolved via ``get_attn``.
        channels: Number of input channels, passed as the first positional
            argument to the resolved layer constructor.
        **kwargs: Extra keyword arguments forwarded to the constructor.

    Returns:
        The constructed layer, the given module if ``attn_type`` resolved to an
        already-built ``torch.nn.Module`` instance (``channels`` and ``kwargs``
        are not used in that case), or ``None`` if no attention was selected.
    """
    module_cls = get_attn(attn_type)
    if module_cls is None:
        return None

    if isinstance(module_cls, torch.nn.Module):
        # get_attn hands module instances back unchanged; calling one here would run
        # its forward pass on `channels` instead of constructing a layer.
        return module_cls

    # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels
    return module_cls(channels, **kwargs)
|