coord_attn.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. """ Coordinate Attention and Variants
  2. Coordinate Attention decomposes channel attention into two 1D feature encoding processes
  3. to capture long-range dependencies with precise positional information. This module includes
  4. the original implementation along with simplified and other variants.
  5. Papers / References:
  6. - Coordinate Attention: `Coordinate Attention for Efficient Mobile Network Design` - https://arxiv.org/abs/2103.02907
  7. - Efficient Local Attention: `ELA: Efficient Local Attention for Deep Convolutional Neural Networks` - https://arxiv.org/abs/2403.01123
  8. Hacked together by / Copyright 2025 Ross Wightman
  9. """
  10. from typing import Optional, Type, Union
  11. import torch
  12. from torch import nn
  13. from .create_act import create_act_layer
  14. from .helpers import make_divisible
  15. from .norm import GroupNorm1
  16. class CoordAttn(nn.Module):
  17. def __init__(
  18. self,
  19. channels: int,
  20. rd_ratio: float = 1. / 16,
  21. rd_channels: Optional[int] = None,
  22. rd_divisor: int = 8,
  23. se_factor: float = 2/3,
  24. bias: bool = False,
  25. act_layer: Type[nn.Module] = nn.Hardswish,
  26. norm_layer: Optional[Type[nn.Module]] = nn.BatchNorm2d,
  27. gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
  28. has_skip: bool = False,
  29. device=None,
  30. dtype=None,
  31. ):
  32. """Coordinate Attention module for spatial feature recalibration.
  33. Introduced in "Coordinate Attention for Efficient Mobile Network Design" (CVPR 2021).
  34. Decomposes channel attention into two 1D feature encoding processes along the height and
  35. width axes to capture long-range dependencies with precise positional information.
  36. Args:
  37. channels: Number of input channels.
  38. rd_ratio: Reduction ratio for bottleneck channel calculation.
  39. rd_channels: Explicit number of bottleneck channels, overrides rd_ratio if set.
  40. rd_divisor: Divisor for making bottleneck channels divisible.
  41. se_factor: Applied to rd_ratio for final channel count (keeps params similar to SE).
  42. bias: Whether to use bias in convolution layers.
  43. act_layer: Activation module class for bottleneck.
  44. norm_layer: Normalization module class, None for no normalization.
  45. gate_layer: Gate activation, either 'sigmoid', 'hardsigmoid', or a module class.
  46. has_skip: Whether to add residual skip connection to output.
  47. device: Device to place tensors on.
  48. dtype: Data type for tensors.
  49. """
  50. dd = {'device': device, 'dtype': dtype}
  51. super().__init__()
  52. self.has_skip = has_skip
  53. if not rd_channels:
  54. rd_channels = make_divisible(channels * rd_ratio * se_factor, rd_divisor, round_limit=0.)
  55. self.conv1 = nn.Conv2d(channels, rd_channels, kernel_size=1, stride=1, padding=0, bias=bias, **dd)
  56. self.bn1 = norm_layer(rd_channels, **dd) if norm_layer is not None else nn.Identity()
  57. self.act = act_layer()
  58. self.conv_h = nn.Conv2d(rd_channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, **dd)
  59. self.conv_w = nn.Conv2d(rd_channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, **dd)
  60. self.gate = create_act_layer(gate_layer)
  61. def forward(self, x):
  62. identity = x
  63. N, C, H, W = x.size()
  64. # Strip pooling
  65. x_h = x.mean(3, keepdim=True)
  66. x_w = x.mean(2, keepdim=True)
  67. x_w = x_w.transpose(-1, -2)
  68. y = torch.cat([x_h, x_w], dim=2)
  69. y = self.conv1(y)
  70. y = self.bn1(y)
  71. y = self.act(y)
  72. x_h, x_w = torch.split(y, [H, W], dim=2)
  73. x_w = x_w.transpose(-1, -2)
  74. a_h = self.gate(self.conv_h(x_h))
  75. a_w = self.gate(self.conv_w(x_w))
  76. out = identity * a_w * a_h
  77. if self.has_skip:
  78. out = out + identity
  79. return out
  80. class SimpleCoordAttn(nn.Module):
  81. """Simplified Coordinate Attention variant.
  82. Uses
  83. * linear layers instead of convolutions
  84. * no norm
  85. * additive pre-gating re-combination
  86. for reduced complexity while maintaining the core coordinate attention mechanism
  87. of separate height and width attention.
  88. """
  89. def __init__(
  90. self,
  91. channels: int,
  92. rd_ratio: float = 0.25,
  93. rd_channels: Optional[int] = None,
  94. rd_divisor: int = 8,
  95. se_factor: float = 2 / 3,
  96. bias: bool = True,
  97. act_layer: Type[nn.Module] = nn.SiLU,
  98. gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
  99. has_skip: bool = False,
  100. device=None,
  101. dtype=None,
  102. ):
  103. """
  104. Args:
  105. channels: Number of input channels.
  106. rd_ratio: Reduction ratio for bottleneck channel calculation.
  107. rd_channels: Explicit number of bottleneck channels, overrides rd_ratio if set.
  108. rd_divisor: Divisor for making bottleneck channels divisible.
  109. se_factor: Applied to rd_ratio for final channel count (keeps param similar to SE)
  110. bias: Whether to use bias in linear layers.
  111. act_layer: Activation module class for bottleneck.
  112. gate_layer: Gate activation, either 'sigmoid', 'hardsigmoid', or a module class.
  113. has_skip: Whether to add residual skip connection to output.
  114. device: Device to place tensors on.
  115. dtype: Data type for tensors.
  116. """
  117. dd = {'device': device, 'dtype': dtype}
  118. super().__init__()
  119. self.has_skip = has_skip
  120. if not rd_channels:
  121. rd_channels = make_divisible(channels * rd_ratio * se_factor, rd_divisor, round_limit=0.)
  122. self.fc1 = nn.Linear(channels, rd_channels, bias=bias, **dd)
  123. self.act = act_layer()
  124. self.fc_h = nn.Linear(rd_channels, channels, bias=bias, **dd)
  125. self.fc_w = nn.Linear(rd_channels, channels, bias=bias, **dd)
  126. self.gate = create_act_layer(gate_layer)
  127. def forward(self, x):
  128. identity = x
  129. # Strip pooling
  130. x_h = x.mean(dim=3) # (N, C, H)
  131. x_w = x.mean(dim=2) # (N, C, W)
  132. # Shared bottleneck projection
  133. x_h = self.act(self.fc1(x_h.transpose(1, 2))) # (N, H, rd_c)
  134. x_w = self.act(self.fc1(x_w.transpose(1, 2))) # (N, W, rd_c)
  135. # Separate attention heads
  136. a_h = self.fc_h(x_h).transpose(1, 2).unsqueeze(-1) # (N, C, H, 1)
  137. a_w = self.fc_w(x_w).transpose(1, 2).unsqueeze(-2) # (N, C, 1, W)
  138. out = identity * self.gate(a_h + a_w)
  139. if self.has_skip:
  140. out = out + identity
  141. return out
  142. class EfficientLocalAttn(nn.Module):
  143. """Efficient Local Attention.
  144. Lightweight alternative to Coordinate Attention that preserves spatial
  145. information without channel reduction. Uses 1D depthwise convolutions
  146. and GroupNorm for better generalization.
  147. Paper: https://arxiv.org/abs/2403.01123
  148. """
  149. def __init__(
  150. self,
  151. channels: int,
  152. kernel_size: int = 7,
  153. bias: bool = False,
  154. act_layer: Type[nn.Module] = nn.SiLU,
  155. gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
  156. norm_layer: Optional[Type[nn.Module]] = GroupNorm1,
  157. has_skip: bool = False,
  158. device=None,
  159. dtype=None,
  160. ):
  161. """
  162. Args:
  163. channels: Number of input channels.
  164. kernel_size: Kernel size for 1D depthwise convolutions.
  165. bias: Whether to use bias in convolution layers.
  166. act_layer: Activation module class applied after normalization.
  167. gate_layer: Gate activation, either 'sigmoid', 'hardsigmoid', or a module class.
  168. norm_layer: Normalization module class, None for no normalization.
  169. has_skip: Whether to add residual skip connection to output.
  170. device: Device to place tensors on.
  171. dtype: Data type for tensors.
  172. """
  173. dd = {'device': device, 'dtype': dtype}
  174. super().__init__()
  175. self.has_skip = has_skip
  176. self.conv_h = nn.Conv2d(
  177. channels, channels,
  178. kernel_size=(kernel_size, 1),
  179. stride=1,
  180. padding=(kernel_size // 2, 0),
  181. groups=channels,
  182. bias=bias,
  183. **dd
  184. )
  185. self.conv_w = nn.Conv2d(
  186. channels, channels,
  187. kernel_size=(1, kernel_size),
  188. stride=1,
  189. padding=(0, kernel_size // 2),
  190. groups=channels,
  191. bias=bias,
  192. **dd
  193. )
  194. if norm_layer is not None:
  195. self.norm_h = norm_layer(channels, **dd)
  196. self.norm_w = norm_layer(channels, **dd)
  197. else:
  198. self.norm_h = nn.Identity()
  199. self.norm_w = nn.Identity()
  200. self.act = act_layer()
  201. self.gate = create_act_layer(gate_layer)
  202. def forward(self, x):
  203. identity = x
  204. # Strip pooling: (N, C, H, W) -> (N, C, H) and (N, C, W)
  205. x_h = x.mean(dim=3, keepdim=True)
  206. x_w = x.mean(dim=2, keepdim=True)
  207. # 1D conv + norm + act
  208. x_h = self.act(self.norm_h(self.conv_h(x_h))) # (N, C, H, 1)
  209. x_w = self.act(self.norm_w(self.conv_w(x_w))) # (N, C, 1, W)
  210. # Generate attention maps
  211. a_h = self.gate(x_h) # (N, C, H, 1)
  212. a_w = self.gate(x_w) # (N, C, 1, W)
  213. out = identity * a_h * a_w
  214. if self.has_skip:
  215. out = out + identity
  216. return out
  217. class StripAttn(nn.Module):
  218. """Minimal Strip Attention.
  219. Lightweight spatial attention using strip pooling with optional learned refinement.
  220. """
  221. def __init__(
  222. self,
  223. channels: int,
  224. use_conv: bool = True,
  225. kernel_size: int = 3,
  226. bias: bool = False,
  227. gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
  228. has_skip: bool = False,
  229. device=None,
  230. dtype=None,
  231. **_,
  232. ):
  233. """
  234. Args:
  235. channels: Number of input channels.
  236. use_conv: Whether to apply depthwise convolutions for learned spatial refinement.
  237. kernel_size: Kernel size for 1D depthwise convolutions when use_conv is True.
  238. bias: Whether to use bias in convolution layers.
  239. gate_layer: Gate activation, either 'sigmoid', 'hardsigmoid', or a module class.
  240. has_skip: Whether to add residual skip connection to output.
  241. device: Device to place tensors on.
  242. dtype: Data type for tensors.
  243. """
  244. dd = {'device': device, 'dtype': dtype}
  245. super().__init__()
  246. self.has_skip = has_skip
  247. self.use_conv = use_conv
  248. if use_conv:
  249. self.conv_h = nn.Conv2d(
  250. channels, channels,
  251. kernel_size=(kernel_size, 1),
  252. stride=1,
  253. padding=(kernel_size // 2, 0),
  254. groups=channels,
  255. bias=bias,
  256. **dd
  257. )
  258. self.conv_w = nn.Conv2d(
  259. channels, channels,
  260. kernel_size=(1, kernel_size),
  261. stride=1,
  262. padding=(0, kernel_size // 2),
  263. groups=channels,
  264. bias=bias,
  265. **dd
  266. )
  267. else:
  268. self.conv_h = nn.Identity()
  269. self.conv_w = nn.Identity()
  270. self.gate = create_act_layer(gate_layer)
  271. def forward(self, x):
  272. identity = x
  273. # Strip pooling
  274. x_h = x.mean(dim=3, keepdim=True) # (N, C, H, 1)
  275. x_w = x.mean(dim=2, keepdim=True) # (N, C, 1, W)
  276. # Optional learned refinement
  277. x_h = self.conv_h(x_h)
  278. x_w = self.conv_w(x_w)
  279. # Combine and gate
  280. a_hw = self.gate(x_h + x_w) # broadcasts to (N, C, H, W)
  281. out = identity * a_hw
  282. if self.has_skip:
  283. out = out + identity
  284. return out