| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112 |
- """ Split Attention Conv2d (for ResNeSt Models)
- Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955
- Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt
- Modified for torchscript compat, performance, and consistency with timm by Ross Wightman
- """
- from typing import Optional, Type, Union
- import torch
- import torch.nn.functional as F
- from torch import nn
- from .helpers import make_divisible
class RadixSoftmax(nn.Module):
    """Turn split-attention logits into weights.

    When ``radix > 1`` the input is viewed as
    ``(batch, cardinality, radix, -1)``, the cardinality and radix axes are
    swapped, a softmax is taken over the radix axis and the result is
    flattened to ``(batch, -1)``. When ``radix <= 1`` an element-wise sigmoid
    is applied and the input shape is returned unchanged.
    """

    def __init__(self, radix: int, cardinality: int):
        super().__init__()
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        n = x.size(0)

        # A single split has no sibling to be normalised against: gate it.
        if self.radix <= 1:
            return torch.sigmoid(x)

        # (n, cardinality, radix, rest) -> (n, radix, cardinality, rest)
        grouped = x.view(n, self.cardinality, self.radix, -1)
        per_split = grouped.transpose(1, 2)
        weights = F.softmax(per_split, dim=1)
        return weights.reshape(n, -1)
class SplitAttn(nn.Module):
    """Split-Attention (aka Splat) convolution.

    A grouped convolution produces ``out_channels * radix`` channels that are
    treated as ``radix`` splits. The splits are summed, averaged over the
    spatial dimensions and passed through two 1x1 convolutions to obtain
    attention weights, which re-weight the splits before they are summed
    into the output.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: Optional[int] = None,
            kernel_size: int = 3,
            stride: int = 1,
            padding: Optional[int] = None,
            dilation: int = 1,
            groups: int = 1,
            bias: bool = False,
            radix: int = 2,
            rd_ratio: float = 0.25,
            rd_channels: Optional[int] = None,
            rd_divisor: int = 8,
            act_layer: Type[nn.Module] = nn.ReLU,
            norm_layer: Optional[Type[nn.Module]] = None,
            drop_layer: Optional[Type[nn.Module]] = None,
            **kwargs,
    ):
        """
        Args:
            in_channels: Channels of the input to the main convolution.
            out_channels: Output channels; falls back to ``in_channels``
                when not given.
            kernel_size: Kernel size of the main convolution.
            stride: Stride of the main convolution.
            padding: Padding of the main convolution; ``kernel_size // 2``
                when ``None``.
            dilation: Dilation of the main convolution.
            groups: Group count of the 1x1 attention convolutions; the main
                convolution uses ``groups * radix`` groups.
            bias: Whether the main convolution has a bias.
            radix: Number of splits.
            rd_ratio: Ratio used to derive the attention width when
                ``rd_channels`` is ``None``.
            rd_channels: Explicit attention width per split; the attention
                branch then has ``rd_channels * radix`` channels.
            rd_divisor: Divisor passed to ``make_divisible`` when the
                attention width is derived from ``rd_ratio``.
            act_layer: Activation class, instantiated with ``inplace=True``.
            norm_layer: Optional norm class applied after the main
                convolution and after the first attention convolution.
            drop_layer: Optional class instantiated without arguments and
                applied between the first norm and the first activation.
            **kwargs: ``device`` and ``dtype`` are extracted and handed to
                every convolution / norm layer; anything left over is
                forwarded to the main convolution.
        """
        factory_kwargs = {'device': kwargs.pop('device', None), 'dtype': kwargs.pop('dtype', None)}
        super().__init__()

        if not out_channels:
            out_channels = in_channels
        self.radix = radix
        mid_chs = out_channels * radix

        if rd_channels is not None:
            attn_chs = rd_channels * radix
        else:
            attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)

        if padding is None:
            padding = kernel_size // 2

        # NOTE: submodules are created in the same order as attribute names
        # appear below so that parameter ordering stays stable.
        self.conv = nn.Conv2d(
            in_channels,
            mid_chs,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups * radix,
            bias=bias,
            **kwargs,
            **factory_kwargs,
        )
        if norm_layer:
            self.bn0 = norm_layer(mid_chs, **factory_kwargs)
        else:
            self.bn0 = nn.Identity()
        if drop_layer is None:
            self.drop = nn.Identity()
        else:
            self.drop = drop_layer()
        self.act0 = act_layer(inplace=True)

        self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups, **factory_kwargs)
        if norm_layer:
            self.bn1 = norm_layer(attn_chs, **factory_kwargs)
        else:
            self.bn1 = nn.Identity()
        self.act1 = act_layer(inplace=True)
        self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups, **factory_kwargs)
        self.rsoftmax = RadixSoftmax(radix, groups)

    def forward(self, x):
        # Main branch: conv -> norm -> drop -> activation.
        x = self.act0(self.drop(self.bn0(self.conv(x))))

        B, RC, H, W = x.shape
        multi_split = self.radix > 1
        if multi_split:
            # Expose the splits on their own axis and fuse them by summation.
            x = x.reshape((B, self.radix, RC // self.radix, H, W))
            pooled = x.sum(dim=1)
        else:
            pooled = x

        # Spatial average, then the two-layer attention branch.
        pooled = pooled.mean((2, 3), keepdim=True)
        hidden = self.act1(self.bn1(self.fc1(pooled)))
        attn = self.rsoftmax(self.fc2(hidden)).view(B, -1, 1, 1)

        if multi_split:
            attn = attn.reshape((B, self.radix, RC // self.radix, 1, 1))
            out = (x * attn).sum(dim=1)
        else:
            out = x * attn
        return out.contiguous()
|