std_conv.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. """ Convolution with Weight Standardization (StdConv and ScaledStdConv)
  2. StdConv:
  3. @article{weightstandardization,
  4. author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille},
  5. title = {Weight Standardization},
  6. journal = {arXiv preprint arXiv:1903.10520},
  7. year = {2019},
  8. }
  9. Code: https://github.com/joe-siyuan-qiao/WeightStandardization
  10. ScaledStdConv:
  11. Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
  12. - https://arxiv.org/abs/2101.08692
  13. Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
  14. Hacked together by / copyright Ross Wightman, 2021.
  15. """
  16. from typing import Optional, Tuple, Union
  17. import torch
  18. import torch.nn as nn
  19. import torch.nn.functional as F
  20. from ._fx import register_notrace_module
  21. from .padding import get_padding, get_padding_value, pad_same
  22. class StdConv2d(nn.Conv2d):
  23. """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.
  24. Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
  25. https://arxiv.org/abs/1903.10520v2
  26. """
  27. def __init__(
  28. self,
  29. in_channel: int,
  30. out_channels: int,
  31. kernel_size: Union[int, Tuple[int, int]],
  32. stride: Union[int, Tuple[int, int]] = 1,
  33. padding: Optional[Union[int, Tuple[int, int]]] = None,
  34. dilation: Union[int, Tuple[int, int]] = 1,
  35. groups: int = 1,
  36. bias: bool = False,
  37. eps: float = 1e-6,
  38. device=None,
  39. dtype=None,
  40. ):
  41. if padding is None:
  42. padding = get_padding(kernel_size, stride, dilation)
  43. super().__init__(
  44. in_channel, out_channels, kernel_size, stride=stride,
  45. padding=padding, dilation=dilation, groups=groups, bias=bias, device=device, dtype=dtype)
  46. self.eps = eps
  47. def forward(self, x):
  48. weight = F.batch_norm(
  49. self.weight.reshape(1, self.out_channels, -1),
  50. None, # running_mean
  51. None, # running_var
  52. training=True,
  53. momentum=0.,
  54. eps=self.eps,
  55. ).reshape_as(self.weight)
  56. x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
  57. return x
  58. @register_notrace_module
  59. class StdConv2dSame(nn.Conv2d):
  60. """Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model.
  61. Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
  62. https://arxiv.org/abs/1903.10520v2
  63. """
  64. def __init__(
  65. self,
  66. in_channel: int,
  67. out_channels: int,
  68. kernel_size: Union[int, Tuple[int, int]],
  69. stride: Union[int, Tuple[int, int]] = 1,
  70. padding: str = 'SAME',
  71. dilation: Union[int, Tuple[int, int]] = 1,
  72. groups: int = 1,
  73. bias: bool = False,
  74. eps: float = 1e-6,
  75. device=None,
  76. dtype=None,
  77. ):
  78. padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
  79. super().__init__(
  80. in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
  81. groups=groups, bias=bias, device=device, dtype=dtype)
  82. self.same_pad = is_dynamic
  83. self.eps = eps
  84. def forward(self, x):
  85. if self.same_pad:
  86. x = pad_same(x, self.kernel_size, self.stride, self.dilation)
  87. weight = F.batch_norm(
  88. self.weight.reshape(1, self.out_channels, -1),
  89. None, # running_mean
  90. None, # running_var
  91. training=True,
  92. momentum=0.,
  93. eps=self.eps,
  94. ).reshape_as(self.weight)
  95. x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
  96. return x
  97. class ScaledStdConv2d(nn.Conv2d):
  98. """Conv2d layer with Scaled Weight Standardization.
  99. Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
  100. https://arxiv.org/abs/2101.08692
  101. NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor.
  102. """
  103. def __init__(
  104. self,
  105. in_channels: int,
  106. out_channels: int,
  107. kernel_size: Union[int, Tuple[int, int]],
  108. stride: Union[int, Tuple[int, int]] = 1,
  109. padding: Optional[Union[int, Tuple[int, int], str]] = None,
  110. dilation: Union[int, Tuple[int, int]] = 1,
  111. groups: int = 1,
  112. bias: bool = True,
  113. gamma: float = 1.0,
  114. eps: float = 1e-6,
  115. gain_init: float = 1.0,
  116. device=None,
  117. dtype=None,
  118. ):
  119. dd = {'device': device, 'dtype': dtype}
  120. if padding is None:
  121. padding = get_padding(kernel_size, stride, dilation)
  122. super().__init__(
  123. in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
  124. groups=groups, bias=bias, **dd)
  125. self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in)
  126. self.eps = eps
  127. self.gain_init = gain_init
  128. self.gain = nn.Parameter(torch.empty((self.out_channels, 1, 1, 1), **dd))
  129. self.reset_parameters()
  130. def reset_parameters(self) -> None:
  131. # Only initialize gain if it exists (for the second call)
  132. if hasattr(self, 'gain'):
  133. torch.nn.init.constant_(self.gain, self.gain_init)
  134. # Also reset parent parameters if needed
  135. super().reset_parameters()
  136. # On first call (from super().__init__), do nothing
  137. def forward(self, x):
  138. weight = F.batch_norm(
  139. self.weight.reshape(1, self.out_channels, -1),
  140. None, # running_mean
  141. None, # running_var
  142. weight=(self.gain * self.scale).view(-1),
  143. training=True,
  144. momentum=0.,
  145. eps=self.eps,
  146. ).reshape_as(self.weight)
  147. return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
  148. @register_notrace_module
  149. class ScaledStdConv2dSame(nn.Conv2d):
  150. """Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support
  151. Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
  152. https://arxiv.org/abs/2101.08692
  153. NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor.
  154. """
  155. def __init__(
  156. self,
  157. in_channels: int,
  158. out_channels: int,
  159. kernel_size: Union[int, Tuple[int, int]],
  160. stride: Union[int, Tuple[int, int]] = 1,
  161. padding: str = 'SAME',
  162. dilation: Union[int, Tuple[int, int]] = 1,
  163. groups: int = 1,
  164. bias: bool = True,
  165. gamma: float = 1.0,
  166. eps: float = 1e-6,
  167. gain_init: float = 1.0,
  168. device=None,
  169. dtype=None,
  170. ):
  171. dd = {'device': device, 'dtype': dtype}
  172. padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
  173. super().__init__(
  174. in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
  175. groups=groups, bias=bias, **dd)
  176. self.scale = gamma * self.weight[0].numel() ** -0.5
  177. self.same_pad = is_dynamic
  178. self.eps = eps
  179. self.gain_init = gain_init
  180. self.gain = nn.Parameter(torch.empty((self.out_channels, 1, 1, 1), **dd))
  181. self.reset_parameters()
  182. def reset_parameters(self) -> None:
  183. # Only initialize gain if it exists (for the second call)
  184. if hasattr(self, 'gain'):
  185. torch.nn.init.constant_(self.gain, self.gain_init)
  186. # Also reset parent parameters if needed
  187. super().reset_parameters()
  188. # On first call (from super().__init__), do nothing
  189. def forward(self, x):
  190. if self.same_pad:
  191. x = pad_same(x, self.kernel_size, self.stride, self.dilation)
  192. weight = F.batch_norm(
  193. self.weight.reshape(1, self.out_channels, -1),
  194. None, # running_mean
  195. None, # running_var
  196. weight=(self.gain * self.scale).view(-1),
  197. training=True,
  198. momentum=0.,
  199. eps=self.eps,
  200. ).reshape_as(self.weight)
  201. return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)