gather_excite.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
""" Gather-Excite Attention Block

Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348

Official code is here, but it's only a partial impl in Caffe: https://github.com/hujie-frank/GENet

I've tried to support all of the extent settings, both with and without extra params. I don't believe
I've seen another impl that covers all of the cases.

NOTE: extent=0 + extra_params=False (with add_maxpool=False) is equivalent to Squeeze-and-Excitation

Hacked together by / Copyright 2021 Ross Wightman
"""
  9. from typing import Optional, Tuple, Type, Union
  10. import math
  11. from torch import nn as nn
  12. import torch.nn.functional as F
  13. from .create_act import create_act_layer, get_act_layer
  14. from .create_conv2d import create_conv2d
  15. from .helpers import make_divisible
  16. from .mlp import ConvMlp
  17. class GatherExcite(nn.Module):
  18. """ Gather-Excite Attention Module
  19. """
  20. def __init__(
  21. self,
  22. channels: int,
  23. feat_size: Optional[Tuple[int, int]] = None,
  24. extra_params: bool = False,
  25. extent: int = 0,
  26. use_mlp: bool = True,
  27. rd_ratio: float = 1./16,
  28. rd_channels: Optional[int] = None,
  29. rd_divisor: int = 1,
  30. add_maxpool: bool = False,
  31. act_layer: Type[nn.Module] = nn.ReLU,
  32. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  33. gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
  34. device=None,
  35. dtype=None,
  36. ):
  37. dd = {'device': device, 'dtype': dtype}
  38. super().__init__()
  39. self.add_maxpool = add_maxpool
  40. act_layer = get_act_layer(act_layer)
  41. self.extent = extent
  42. if extra_params:
  43. self.gather = nn.Sequential()
  44. if extent == 0:
  45. assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params'
  46. self.gather.add_module(
  47. 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True, *dd))
  48. if norm_layer:
  49. self.gather.add_module(f'norm1', nn.BatchNorm2d(channels, *dd))
  50. else:
  51. assert extent % 2 == 0
  52. num_conv = int(math.log2(extent))
  53. for i in range(num_conv):
  54. self.gather.add_module(
  55. f'conv{i + 1}',
  56. create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True, *dd))
  57. if norm_layer:
  58. self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels, *dd))
  59. if i != num_conv - 1:
  60. self.gather.add_module(f'act{i + 1}', act_layer(inplace=True))
  61. else:
  62. self.gather = None
  63. if self.extent == 0:
  64. self.gk = 0
  65. self.gs = 0
  66. else:
  67. assert extent % 2 == 0
  68. self.gk = self.extent * 2 - 1
  69. self.gs = self.extent
  70. if not rd_channels:
  71. rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
  72. self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer, *dd) if use_mlp else nn.Identity()
  73. self.gate = create_act_layer(gate_layer)
  74. def forward(self, x):
  75. size = x.shape[-2:]
  76. if self.gather is not None:
  77. x_ge = self.gather(x)
  78. else:
  79. if self.extent == 0:
  80. # global extent
  81. x_ge = x.mean(dim=(2, 3), keepdims=True)
  82. if self.add_maxpool:
  83. # experimental codepath, may remove or change
  84. x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True)
  85. else:
  86. x_ge = F.avg_pool2d(
  87. x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False)
  88. if self.add_maxpool:
  89. # experimental codepath, may remove or change
  90. x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2)
  91. x_ge = self.mlp(x_ge)
  92. if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1:
  93. x_ge = F.interpolate(x_ge, size=size)
  94. return x * self.gate(x_ge)