grn.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. """ Global Response Normalization Module
  2. Based on the GRN layer presented in
  3. `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808
  4. This implementation
  5. * works for both NCHW and NHWC tensor layouts
  6. * uses affine param names matching existing torch norm layers
  7. * slightly improves eager mode performance via fused addcmul
  8. Hacked together by / Copyright 2023 Ross Wightman
  9. """
  10. import torch
  11. from torch import nn as nn
  12. class GlobalResponseNorm(nn.Module):
  13. """ Global Response Normalization layer
  14. """
  15. def __init__(
  16. self,
  17. dim: int,
  18. eps: float = 1e-6,
  19. channels_last: bool = True,
  20. device=None,
  21. dtype=None,
  22. ):
  23. dd = {'device': device, 'dtype': dtype}
  24. super().__init__()
  25. self.eps = eps
  26. if channels_last:
  27. self.spatial_dim = (1, 2)
  28. self.channel_dim = -1
  29. self.wb_shape = (1, 1, 1, -1)
  30. else:
  31. self.spatial_dim = (2, 3)
  32. self.channel_dim = 1
  33. self.wb_shape = (1, -1, 1, 1)
  34. self.weight = nn.Parameter(torch.zeros(dim, **dd))
  35. self.bias = nn.Parameter(torch.zeros(dim, **dd))
  36. def forward(self, x):
  37. x_g = x.norm(p=2, dim=self.spatial_dim, keepdim=True)
  38. x_n = x_g / (x_g.mean(dim=self.channel_dim, keepdim=True) + self.eps)
  39. return x + torch.addcmul(self.bias.view(self.wb_shape), self.weight.view(self.wb_shape), x * x_n)