cbam.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. """ CBAM (sort-of) Attention
  2. Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521
  3. WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on
  4. some tasks, especially fine-grained it seems. I may end up removing this impl.
  5. Hacked together by / Copyright 2020 Ross Wightman
  6. """
  7. from typing import Optional, Tuple, Type, Union
  8. import torch
  9. from torch import nn as nn
  10. import torch.nn.functional as F
  11. from .conv_bn_act import ConvNormAct
  12. from .create_act import create_act_layer, get_act_layer
  13. from .helpers import make_divisible
  14. class ChannelAttn(nn.Module):
  15. """ Original CBAM channel attention module, currently avg + max pool variant only.
  16. """
  17. def __init__(
  18. self,
  19. channels: int,
  20. rd_ratio: float = 1. / 16,
  21. rd_channels: Optional[int] = None,
  22. rd_divisor: int = 1,
  23. act_layer: Type[nn.Module] = nn.ReLU,
  24. gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
  25. mlp_bias=False,
  26. device=None,
  27. dtype=None,
  28. ):
  29. dd = {'device': device, 'dtype': dtype}
  30. super().__init__()
  31. if not rd_channels:
  32. rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
  33. self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias, **dd)
  34. self.act = act_layer(inplace=True)
  35. self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias, **dd)
  36. self.gate = create_act_layer(gate_layer)
  37. def forward(self, x):
  38. x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True))))
  39. x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True))))
  40. return x * self.gate(x_avg + x_max)
  41. class LightChannelAttn(ChannelAttn):
  42. """An experimental 'lightweight' that sums avg + max pool first
  43. """
  44. def __init__(
  45. self,
  46. channels: int,
  47. rd_ratio: float = 1./16,
  48. rd_channels: Optional[int] = None,
  49. rd_divisor: int = 1,
  50. act_layer: Type[nn.Module] = nn.ReLU,
  51. gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
  52. mlp_bias: bool = False,
  53. device=None,
  54. dtype=None
  55. ):
  56. super().__init__(
  57. channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias, device=device, dtype=dtype)
  58. def forward(self, x):
  59. x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True)
  60. x_attn = self.fc2(self.act(self.fc1(x_pool)))
  61. return x * F.sigmoid(x_attn)
  62. class SpatialAttn(nn.Module):
  63. """ Original CBAM spatial attention module
  64. """
  65. def __init__(
  66. self,
  67. kernel_size: int = 7,
  68. gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
  69. device=None,
  70. dtype=None,
  71. ):
  72. super().__init__()
  73. self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False, device=device, dtype=dtype)
  74. self.gate = create_act_layer(gate_layer)
  75. def forward(self, x):
  76. x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1)
  77. x_attn = self.conv(x_attn)
  78. return x * self.gate(x_attn)
  79. class LightSpatialAttn(nn.Module):
  80. """An experimental 'lightweight' variant that sums avg_pool and max_pool results.
  81. """
  82. def __init__(
  83. self,
  84. kernel_size: int = 7,
  85. gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
  86. device=None,
  87. dtype=None,
  88. ):
  89. super().__init__()
  90. self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False, device=device, dtype=dtype)
  91. self.gate = create_act_layer(gate_layer)
  92. def forward(self, x):
  93. x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True)
  94. x_attn = self.conv(x_attn)
  95. return x * self.gate(x_attn)
  96. class CbamModule(nn.Module):
  97. def __init__(
  98. self,
  99. channels: int,
  100. rd_ratio: float = 1./16,
  101. rd_channels: Optional[int] = None,
  102. rd_divisor: int = 1,
  103. spatial_kernel_size: int = 7,
  104. act_layer: Type[nn.Module] = nn.ReLU,
  105. gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
  106. mlp_bias: bool = False,
  107. device=None,
  108. dtype=None,
  109. ):
  110. dd = {'device': device, 'dtype': dtype}
  111. super().__init__()
  112. self.channel = ChannelAttn(
  113. channels,
  114. rd_ratio=rd_ratio,
  115. rd_channels=rd_channels,
  116. rd_divisor=rd_divisor,
  117. act_layer=act_layer,
  118. gate_layer=gate_layer,
  119. mlp_bias=mlp_bias,
  120. **dd,
  121. )
  122. self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer, **dd)
  123. def forward(self, x):
  124. x = self.channel(x)
  125. x = self.spatial(x)
  126. return x
  127. class LightCbamModule(nn.Module):
  128. def __init__(
  129. self,
  130. channels: int,
  131. rd_ratio: float = 1./16,
  132. rd_channels: Optional[int] = None,
  133. rd_divisor: int = 1,
  134. spatial_kernel_size: int = 7,
  135. act_layer: Type[nn.Module] = nn.ReLU,
  136. gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
  137. mlp_bias: bool = False,
  138. device=None,
  139. dtype=None,
  140. ):
  141. dd = {'device': device, 'dtype': dtype}
  142. super().__init__()
  143. self.channel = LightChannelAttn(
  144. channels,
  145. rd_ratio=rd_ratio,
  146. rd_channels=rd_channels,
  147. rd_divisor=rd_divisor,
  148. act_layer=act_layer,
  149. gate_layer=gate_layer,
  150. mlp_bias=mlp_bias,
  151. **dd,
  152. )
  153. self.spatial = LightSpatialAttn(spatial_kernel_size, **dd)
  154. def forward(self, x):
  155. x = self.channel(x)
  156. x = self.spatial(x)
  157. return x