non_local_attn.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
""" Bilinear-Attention-Transform and Non-Local Attention

Paper: `Non-Local Neural Networks With Grouped Bilinear Attentional Transforms`
    - https://openaccess.thecvf.com/content_CVPR_2020/html/Chi_Non-Local_Neural_Networks_With_Grouped_Bilinear_Attentional_Transforms_CVPR_2020_paper.html

Adapted from original code: https://github.com/BA-Transform/BAT-Image-Classification
"""
  6. from typing import Optional, Type
  7. import torch
  8. from torch import nn
  9. from torch.nn import functional as F
  10. from ._fx import register_notrace_module
  11. from .conv_bn_act import ConvNormAct
  12. from .helpers import make_divisible
  13. from .trace_utils import _assert
  14. class NonLocalAttn(nn.Module):
  15. """Spatial NL block for image classification.
  16. This was adapted from https://github.com/BA-Transform/BAT-Image-Classification
  17. Their NonLocal impl inspired by https://github.com/facebookresearch/video-nonlocal-net.
  18. """
  19. def __init__(
  20. self,
  21. in_channels,
  22. use_scale=True,
  23. rd_ratio=1/8,
  24. rd_channels=None,
  25. rd_divisor=8,
  26. device=None,
  27. dtype=None,
  28. **_,
  29. ):
  30. dd = {'device': device, 'dtype': dtype}
  31. super().__init__()
  32. if rd_channels is None:
  33. rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
  34. self.scale = in_channels ** -0.5 if use_scale else 1.0
  35. self.t = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True, **dd)
  36. self.p = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True, **dd)
  37. self.g = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True, **dd)
  38. self.z = nn.Conv2d(rd_channels, in_channels, kernel_size=1, stride=1, bias=True, **dd)
  39. self.norm = nn.BatchNorm2d(in_channels, **dd)
  40. self.reset_parameters()
  41. def forward(self, x):
  42. shortcut = x
  43. t = self.t(x)
  44. p = self.p(x)
  45. g = self.g(x)
  46. B, C, H, W = t.size()
  47. t = t.view(B, C, -1).permute(0, 2, 1)
  48. p = p.view(B, C, -1)
  49. g = g.view(B, C, -1).permute(0, 2, 1)
  50. att = torch.bmm(t, p) * self.scale
  51. att = F.softmax(att, dim=2)
  52. x = torch.bmm(att, g)
  53. x = x.permute(0, 2, 1).reshape(B, C, H, W)
  54. x = self.z(x)
  55. x = self.norm(x) + shortcut
  56. return x
  57. def reset_parameters(self):
  58. for name, m in self.named_modules():
  59. if isinstance(m, nn.Conv2d):
  60. nn.init.kaiming_normal_(
  61. m.weight, mode='fan_out', nonlinearity='relu')
  62. if len(list(m.parameters())) > 1:
  63. nn.init.constant_(m.bias, 0.0)
  64. elif isinstance(m, nn.BatchNorm2d):
  65. nn.init.constant_(m.weight, 0)
  66. nn.init.constant_(m.bias, 0)
  67. elif isinstance(m, nn.GroupNorm):
  68. nn.init.constant_(m.weight, 0)
  69. nn.init.constant_(m.bias, 0)
  70. @register_notrace_module
  71. class BilinearAttnTransform(nn.Module):
  72. def __init__(
  73. self,
  74. in_channels: int,
  75. block_size: int,
  76. groups: int,
  77. act_layer: Type[nn.Module] = nn.ReLU,
  78. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  79. device=None,
  80. dtype=None,
  81. ):
  82. dd = {'device': device, 'dtype': dtype}
  83. super().__init__()
  84. self.conv1 = ConvNormAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer, **dd)
  85. self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1), **dd)
  86. self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size), **dd)
  87. self.conv2 = ConvNormAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer, **dd)
  88. self.block_size = block_size
  89. self.groups = groups
  90. self.in_channels = in_channels
  91. def resize_mat(self, x, t: int):
  92. B, C, block_size, block_size1 = x.shape
  93. _assert(block_size == block_size1, '')
  94. if t <= 1:
  95. return x
  96. x = x.view(B * C, -1, 1, 1)
  97. x = x * torch.eye(t, t, dtype=x.dtype, device=x.device)
  98. x = x.view(B * C, block_size, block_size, t, t)
  99. x = torch.cat(torch.split(x, 1, dim=1), dim=3)
  100. x = torch.cat(torch.split(x, 1, dim=2), dim=4)
  101. x = x.view(B, C, block_size * t, block_size * t)
  102. return x
  103. def forward(self, x):
  104. _assert(x.shape[-1] % self.block_size == 0, '')
  105. _assert(x.shape[-2] % self.block_size == 0, '')
  106. B, C, H, W = x.shape
  107. out = self.conv1(x)
  108. rp = F.adaptive_max_pool2d(out, (self.block_size, 1))
  109. cp = F.adaptive_max_pool2d(out, (1, self.block_size))
  110. p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size).sigmoid()
  111. q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size).sigmoid()
  112. p = p / p.sum(dim=3, keepdim=True)
  113. q = q / q.sum(dim=2, keepdim=True)
  114. p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(
  115. 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous()
  116. p = p.view(B, C, self.block_size, self.block_size)
  117. q = q.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(
  118. 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous()
  119. q = q.view(B, C, self.block_size, self.block_size)
  120. p = self.resize_mat(p, H // self.block_size)
  121. q = self.resize_mat(q, W // self.block_size)
  122. y = p.matmul(x)
  123. y = y.matmul(q)
  124. y = self.conv2(y)
  125. return y
  126. class BatNonLocalAttn(nn.Module):
  127. """ BAT
  128. Adapted from: https://github.com/BA-Transform/BAT-Image-Classification
  129. """
  130. def __init__(
  131. self,
  132. in_channels: int,
  133. block_size: int = 7,
  134. groups: int = 2,
  135. rd_ratio: float = 0.25,
  136. rd_channels: Optional[int] = None,
  137. rd_divisor: int = 8,
  138. drop_rate: float = 0.2,
  139. act_layer: Type[nn.Module] = nn.ReLU,
  140. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  141. device=None,
  142. dtype=None,
  143. **_,
  144. ):
  145. dd = {'device': device, 'dtype': dtype}
  146. super().__init__()
  147. if rd_channels is None:
  148. rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
  149. self.conv1 = ConvNormAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer, **dd)
  150. self.ba = BilinearAttnTransform(
  151. rd_channels,
  152. block_size,
  153. groups,
  154. act_layer=act_layer,
  155. norm_layer=norm_layer,
  156. **dd,
  157. )
  158. self.conv2 = ConvNormAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer, **dd)
  159. self.dropout = nn.Dropout2d(p=drop_rate)
  160. def forward(self, x):
  161. xl = self.conv1(x)
  162. y = self.ba(xl)
  163. y = self.conv2(y)
  164. y = self.dropout(y)
  165. return y + x