split_attn.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
""" Split Attention Conv2d (for ResNeSt Models)

Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955

Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt

Modified for torchscript compat, performance, and consistency with timm by Ross Wightman
"""
  6. from typing import Optional, Type, Union
  7. import torch
  8. import torch.nn.functional as F
  9. from torch import nn
  10. from .helpers import make_divisible
  11. class RadixSoftmax(nn.Module):
  12. def __init__(self, radix: int, cardinality: int):
  13. super().__init__()
  14. self.radix = radix
  15. self.cardinality = cardinality
  16. def forward(self, x):
  17. batch = x.size(0)
  18. if self.radix > 1:
  19. x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
  20. x = F.softmax(x, dim=1)
  21. x = x.reshape(batch, -1)
  22. else:
  23. x = torch.sigmoid(x)
  24. return x
  25. class SplitAttn(nn.Module):
  26. """Split-Attention (aka Splat)
  27. """
  28. def __init__(
  29. self,
  30. in_channels: int,
  31. out_channels: Optional[int] = None,
  32. kernel_size: int = 3,
  33. stride: int = 1,
  34. padding: Optional[int] = None,
  35. dilation: int = 1,
  36. groups: int = 1,
  37. bias: bool = False,
  38. radix: int = 2,
  39. rd_ratio: float = 0.25,
  40. rd_channels: Optional[int] = None,
  41. rd_divisor: int = 8,
  42. act_layer: Type[nn.Module] = nn.ReLU,
  43. norm_layer: Optional[Type[nn.Module]] = None,
  44. drop_layer: Optional[Type[nn.Module]] = None,
  45. **kwargs,
  46. ):
  47. dd = {'device': kwargs.pop('device', None), 'dtype': kwargs.pop('dtype', None)}
  48. super().__init__()
  49. out_channels = out_channels or in_channels
  50. self.radix = radix
  51. mid_chs = out_channels * radix
  52. if rd_channels is None:
  53. attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)
  54. else:
  55. attn_chs = rd_channels * radix
  56. padding = kernel_size // 2 if padding is None else padding
  57. self.conv = nn.Conv2d(
  58. in_channels,
  59. mid_chs,
  60. kernel_size,
  61. stride,
  62. padding,
  63. dilation,
  64. groups=groups * radix,
  65. bias=bias,
  66. **kwargs,
  67. **dd,
  68. )
  69. self.bn0 = norm_layer(mid_chs, **dd) if norm_layer else nn.Identity()
  70. self.drop = drop_layer() if drop_layer is not None else nn.Identity()
  71. self.act0 = act_layer(inplace=True)
  72. self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups, **dd)
  73. self.bn1 = norm_layer(attn_chs, **dd) if norm_layer else nn.Identity()
  74. self.act1 = act_layer(inplace=True)
  75. self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups, **dd)
  76. self.rsoftmax = RadixSoftmax(radix, groups)
  77. def forward(self, x):
  78. x = self.conv(x)
  79. x = self.bn0(x)
  80. x = self.drop(x)
  81. x = self.act0(x)
  82. B, RC, H, W = x.shape
  83. if self.radix > 1:
  84. x = x.reshape((B, self.radix, RC // self.radix, H, W))
  85. x_gap = x.sum(dim=1)
  86. else:
  87. x_gap = x
  88. x_gap = x_gap.mean((2, 3), keepdim=True)
  89. x_gap = self.fc1(x_gap)
  90. x_gap = self.bn1(x_gap)
  91. x_gap = self.act1(x_gap)
  92. x_attn = self.fc2(x_gap)
  93. x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
  94. if self.radix > 1:
  95. out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)
  96. else:
  97. out = x * x_attn
  98. return out.contiguous()