| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341 |
- """ Coordinate Attention and Variants
- Coordinate Attention decomposes channel attention into two 1D feature encoding processes
- to capture long-range dependencies with precise positional information. This module includes
- the original implementation along with simplified and other variants.
- Papers / References:
- - Coordinate Attention: `Coordinate Attention for Efficient Mobile Network Design` - https://arxiv.org/abs/2103.02907
- - Efficient Local Attention: `Rethinking Local Perception in Lightweight Vision Transformer` - https://arxiv.org/abs/2403.01123
- Hacked together by / Copyright 2025 Ross Wightman
- """
- from typing import Optional, Type, Union
- import torch
- from torch import nn
- from .create_act import create_act_layer
- from .helpers import make_divisible
- from .norm import GroupNorm1
class CoordAttn(nn.Module):
    """Coordinate Attention ("Coordinate Attention for Efficient Mobile Network Design", CVPR 2021).

    The input is average-pooled along each spatial axis. Both pooled strips go
    through one shared 1x1 bottleneck (conv -> norm -> act). Two separate 1x1
    expansions then turn them into a per-row and a per-column attention map,
    which rescale the input multiplicatively.
    """

    def __init__(
            self,
            channels: int,
            rd_ratio: float = 1. / 16,
            rd_channels: Optional[int] = None,
            rd_divisor: int = 8,
            se_factor: float = 2/3,
            bias: bool = False,
            act_layer: Type[nn.Module] = nn.Hardswish,
            norm_layer: Optional[Type[nn.Module]] = nn.BatchNorm2d,
            gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
            has_skip: bool = False,
            device=None,
            dtype=None,
    ):
        """
        Args:
            channels: Channel count of the input (and of the output).
            rd_ratio: Fraction of ``channels`` used for the bottleneck width.
            rd_channels: Bottleneck width to use directly; when falsy it is derived
                from ``channels * rd_ratio * se_factor``.
            rd_divisor: The derived bottleneck width is rounded to a multiple of this.
            se_factor: Extra multiplier on ``rd_ratio`` (keeps parameter count close to SE).
            bias: Add a bias term to the 1x1 convolutions.
            act_layer: Activation class instantiated for the bottleneck.
            norm_layer: Norm class applied after the bottleneck conv; ``None`` disables it.
            gate_layer: Gate activation name ('sigmoid', 'hardsigmoid') or module class.
            has_skip: When True, the input is added back to the attended output.
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        self.has_skip = has_skip

        if not rd_channels:
            rd_channels = make_divisible(channels * rd_ratio * se_factor, rd_divisor, round_limit=0.)

        # Shared bottleneck: applied once to the H and W strips concatenated together,
        # so the norm layer sees both strips in the same call.
        self.conv1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=bias, **factory_kwargs)
        self.bn1 = nn.Identity() if norm_layer is None else norm_layer(rd_channels, **factory_kwargs)
        self.act = act_layer()

        # Axis-specific expansions back to the full channel count.
        self.conv_h = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias, **factory_kwargs)
        self.conv_w = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias, **factory_kwargs)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        _, _, height, width = x.size()

        # Pool over width -> (N, C, H, 1); pool over height -> (N, C, 1, W), then
        # transpose that one to (N, C, W, 1) so both strips stack along dim 2.
        pooled_h = x.mean(3, keepdim=True)
        pooled_w = x.mean(2, keepdim=True).transpose(-1, -2)
        strips = torch.cat([pooled_h, pooled_w], dim=2)  # (N, C, H + W, 1)

        strips = self.act(self.bn1(self.conv1(strips)))

        feat_h, feat_w = torch.split(strips, [height, width], dim=2)
        attn_h = self.gate(self.conv_h(feat_h))                    # (N, C, H, 1)
        attn_w = self.gate(self.conv_w(feat_w.transpose(-1, -2)))  # (N, C, 1, W)

        out = x * attn_w * attn_h
        return out + x if self.has_skip else out
class SimpleCoordAttn(nn.Module):
    """Lightweight Coordinate Attention variant.

    Differences from the original formulation:
      * ``nn.Linear`` projections over the channel dim instead of 1x1 convs,
      * no normalization layer,
      * height and width logits are summed before a single gate.

    This keeps separate height/width attention while reducing complexity.
    """

    def __init__(
            self,
            channels: int,
            rd_ratio: float = 0.25,
            rd_channels: Optional[int] = None,
            rd_divisor: int = 8,
            se_factor: float = 2 / 3,
            bias: bool = True,
            act_layer: Type[nn.Module] = nn.SiLU,
            gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
            has_skip: bool = False,
            device=None,
            dtype=None,
    ):
        """
        Args:
            channels: Channel count of the input (and of the output).
            rd_ratio: Fraction of ``channels`` used for the bottleneck width.
            rd_channels: Bottleneck width to use directly; when falsy it is derived
                from ``channels * rd_ratio * se_factor``.
            rd_divisor: The derived bottleneck width is rounded to a multiple of this.
            se_factor: Extra multiplier on ``rd_ratio`` (keeps parameter count close to SE).
            bias: Add a bias term to the linear layers.
            act_layer: Activation class instantiated for the bottleneck.
            gate_layer: Gate activation name ('sigmoid', 'hardsigmoid') or module class.
            has_skip: When True, the input is added back to the attended output.
            device: Device for created parameters.
            dtype: Dtype for created parameters.
        """
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        self.has_skip = has_skip

        if not rd_channels:
            rd_channels = make_divisible(channels * rd_ratio * se_factor, rd_divisor, round_limit=0.)

        # One bottleneck projection reused for both the row and the column strips.
        self.fc1 = nn.Linear(channels, rd_channels, bias=bias, **factory_kwargs)
        self.act = act_layer()
        # Separate expansions produce the height and width logits.
        self.fc_h = nn.Linear(rd_channels, channels, bias=bias, **factory_kwargs)
        self.fc_w = nn.Linear(rd_channels, channels, bias=bias, **factory_kwargs)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        # Average out one spatial axis, then move channels last so nn.Linear
        # operates on them.
        rows = x.mean(dim=3).transpose(1, 2)  # (N, H, C)
        cols = x.mean(dim=2).transpose(1, 2)  # (N, W, C)

        rows = self.act(self.fc1(rows))  # (N, H, rd_c)
        cols = self.act(self.fc1(cols))  # (N, W, rd_c)

        # Back to channels-first, with singleton axes inserted so the two maps broadcast.
        logit_h = self.fc_h(rows).transpose(1, 2).unsqueeze(-1)  # (N, C, H, 1)
        logit_w = self.fc_w(cols).transpose(1, 2).unsqueeze(-2)  # (N, C, 1, W)

        out = x * self.gate(logit_h + logit_w)
        if self.has_skip:
            out = out + x
        return out
class EfficientLocalAttn(nn.Module):
    """Efficient Local Attention.

    Lightweight alternative to Coordinate Attention that preserves spatial
    information without channel reduction. Uses 1D depthwise convolutions
    and GroupNorm (by default) for better generalization.

    Paper: https://arxiv.org/abs/2403.01123
    """

    def __init__(
            self,
            channels: int,
            kernel_size: int = 7,
            bias: bool = False,
            act_layer: Type[nn.Module] = nn.SiLU,
            gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
            norm_layer: Optional[Type[nn.Module]] = GroupNorm1,
            has_skip: bool = False,
            device=None,
            dtype=None,
    ):
        """
        Args:
            channels: Number of input channels.
            kernel_size: Kernel size for the 1D depthwise convolutions. Must be a
                positive odd integer so that ``kernel_size // 2`` padding keeps the
                strip length equal to the input's H / W.
            bias: Whether to use bias in convolution layers.
            act_layer: Activation module class applied after normalization.
            gate_layer: Gate activation, either 'sigmoid', 'hardsigmoid', or a module class.
            norm_layer: Normalization module class, None for no normalization.
            has_skip: Whether to add residual skip connection to output.
            device: Device to place tensors on.
            dtype: Data type for tensors.

        Raises:
            ValueError: If ``kernel_size`` is not a positive odd integer.
        """
        if kernel_size < 1 or kernel_size % 2 == 0:
            # With padding = kernel_size // 2, an even kernel yields strips one element
            # longer than the input (H + 1 / W + 1), which breaks the broadcast against
            # the input in forward() (or silently enlarges the output when H or W is 1).
            raise ValueError(f'kernel_size must be a positive odd integer, got {kernel_size}')

        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.has_skip = has_skip

        # Depthwise (groups=channels) convs sliding along H only / W only.
        self.conv_h = nn.Conv2d(
            channels, channels,
            kernel_size=(kernel_size, 1),
            stride=1,
            padding=(kernel_size // 2, 0),
            groups=channels,
            bias=bias,
            **dd
        )
        self.conv_w = nn.Conv2d(
            channels, channels,
            kernel_size=(1, kernel_size),
            stride=1,
            padding=(0, kernel_size // 2),
            groups=channels,
            bias=bias,
            **dd
        )
        if norm_layer is not None:
            self.norm_h = norm_layer(channels, **dd)
            self.norm_w = norm_layer(channels, **dd)
        else:
            self.norm_h = nn.Identity()
            self.norm_w = nn.Identity()
        self.act = act_layer()
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        identity = x

        # Strip pooling (keepdim): (N, C, H, W) -> (N, C, H, 1) and (N, C, 1, W)
        x_h = x.mean(dim=3, keepdim=True)
        x_w = x.mean(dim=2, keepdim=True)

        # 1D depthwise conv + norm + act along each strip; odd kernels keep the shape.
        x_h = self.act(self.norm_h(self.conv_h(x_h)))  # (N, C, H, 1)
        x_w = self.act(self.norm_w(self.conv_w(x_w)))  # (N, C, 1, W)

        # Generate attention maps.
        # NOTE(review): `act` runs before `gate`. With the defaults (SiLU, then sigmoid),
        # the gate input is bounded below by SiLU's minimum (about -0.28), so attention
        # values stay above roughly 0.43. Confirm this is intended.
        a_h = self.gate(x_h)  # (N, C, H, 1)
        a_w = self.gate(x_w)  # (N, C, 1, W)

        out = identity * a_h * a_w
        if self.has_skip:
            out = out + identity
        return out
class StripAttn(nn.Module):
    """Minimal Strip Attention.

    Lightweight spatial attention using strip pooling with optional learned refinement.
    The row and column strip means are (optionally) filtered by 1D depthwise convs,
    summed with broadcasting, and gated to rescale the input.
    """

    def __init__(
            self,
            channels: int,
            use_conv: bool = True,
            kernel_size: int = 3,
            bias: bool = False,
            gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
            has_skip: bool = False,
            device=None,
            dtype=None,
            **_,
    ):
        """
        Args:
            channels: Number of input channels.
            use_conv: Whether to apply depthwise convolutions for learned spatial refinement.
            kernel_size: Kernel size for the 1D depthwise convolutions when use_conv is True.
                Must be a positive odd integer so ``kernel_size // 2`` padding keeps the
                strip length equal to the input's H / W. Ignored when use_conv is False.
            bias: Whether to use bias in convolution layers.
            gate_layer: Gate activation, either 'sigmoid', 'hardsigmoid', or a module class.
            has_skip: Whether to add residual skip connection to output.
            device: Device to place tensors on.
            dtype: Data type for tensors.
            **_: Any other keyword arguments are accepted and ignored.

        Raises:
            ValueError: If use_conv is True and ``kernel_size`` is not a positive odd integer.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.has_skip = has_skip
        self.use_conv = use_conv
        if use_conv:
            if kernel_size < 1 or kernel_size % 2 == 0:
                # An even kernel with kernel_size // 2 padding produces strips of length
                # H + 1 / W + 1, which no longer broadcast correctly against the input.
                raise ValueError(f'kernel_size must be a positive odd integer, got {kernel_size}')
            # Depthwise (groups=channels) convs sliding along H only / W only.
            self.conv_h = nn.Conv2d(
                channels, channels,
                kernel_size=(kernel_size, 1),
                stride=1,
                padding=(kernel_size // 2, 0),
                groups=channels,
                bias=bias,
                **dd
            )
            self.conv_w = nn.Conv2d(
                channels, channels,
                kernel_size=(1, kernel_size),
                stride=1,
                padding=(0, kernel_size // 2),
                groups=channels,
                bias=bias,
                **dd
            )
        else:
            self.conv_h = nn.Identity()
            self.conv_w = nn.Identity()
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        identity = x

        # Strip pooling
        x_h = x.mean(dim=3, keepdim=True)  # (N, C, H, 1)
        x_w = x.mean(dim=2, keepdim=True)  # (N, C, 1, W)

        # Optional learned refinement (identity when use_conv is False)
        x_h = self.conv_h(x_h)
        x_w = self.conv_w(x_w)

        # Combine and gate
        a_hw = self.gate(x_h + x_w)  # broadcasts to (N, C, H, W)
        out = identity * a_hw
        if self.has_skip:
            out = out + identity
        return out
|