evo_norm.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. """ EvoNorm in PyTorch
  2. Based on `Evolving Normalization-Activation Layers` - https://arxiv.org/abs/2004.02967
  3. @inproceedings{NEURIPS2020,
  4. author = {Liu, Hanxiao and Brock, Andy and Simonyan, Karen and Le, Quoc},
  5. booktitle = {Advances in Neural Information Processing Systems},
  6. editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
  7. pages = {13539--13550},
  8. publisher = {Curran Associates, Inc.},
  9. title = {Evolving Normalization-Activation Layers},
  10. url = {https://proceedings.neurips.cc/paper/2020/file/9d4c03631b8b0c85ae08bf05eda37d0f-Paper.pdf},
  11. volume = {33},
  12. year = {2020}
  13. }
  14. An attempt at getting decent performing EvoNorms running in PyTorch.
  15. While faster than other PyTorch impl, still quite a ways off the built-in BatchNorm
  16. in terms of memory usage and throughput on GPUs.
  17. I'm testing these modules on TPU w/ PyTorch XLA. Promising start but
  18. currently working around some issues with builtin torch/tensor.var/std. Unlike
  19. GPU, similar train speeds for EvoNormS variants and BatchNorm.
  20. Hacked together by / Copyright 2020 Ross Wightman
  21. """
  22. from typing import Optional, Sequence, Type, Union
  23. import torch
  24. import torch.nn as nn
  25. import torch.nn.functional as F
  26. from .create_act import create_act_layer
  27. from .trace_utils import _assert
  28. def instance_std(x, eps: float = 1e-5):
  29. std = x.float().var(dim=(2, 3), unbiased=False, keepdim=True).add(eps).sqrt().to(x.dtype)
  30. return std.expand(x.shape)
  31. def instance_std_tpu(x, eps: float = 1e-5):
  32. std = manual_var(x, dim=(2, 3)).add(eps).sqrt()
  33. return std.expand(x.shape)
  34. # instance_std = instance_std_tpu
  35. def instance_rms(x, eps: float = 1e-5):
  36. rms = x.float().square().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(x.dtype)
  37. return rms.expand(x.shape)
  38. def manual_var(x, dim: Union[int, Sequence[int]], diff_sqm: bool = False):
  39. xm = x.mean(dim=dim, keepdim=True)
  40. if diff_sqm:
  41. # difference of squared mean and mean squared, faster on TPU can be less stable
  42. var = ((x * x).mean(dim=dim, keepdim=True) - (xm * xm)).clamp(0)
  43. else:
  44. var = ((x - xm) * (x - xm)).mean(dim=dim, keepdim=True)
  45. return var
  46. def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False):
  47. B, C, H, W = x.shape
  48. x_dtype = x.dtype
  49. _assert(C % groups == 0, '')
  50. if flatten:
  51. x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues
  52. std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
  53. else:
  54. x = x.reshape(B, groups, C // groups, H, W)
  55. std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
  56. return std.expand(x.shape).reshape(B, C, H, W)
  57. def group_std_tpu(x, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False, flatten: bool = False):
  58. # This is a workaround for some stability / odd behaviour of .var and .std
  59. # running on PyTorch XLA w/ TPUs. These manual var impl are producing much better results
  60. B, C, H, W = x.shape
  61. _assert(C % groups == 0, '')
  62. if flatten:
  63. x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues
  64. var = manual_var(x, dim=-1, diff_sqm=diff_sqm)
  65. else:
  66. x = x.reshape(B, groups, C // groups, H, W)
  67. var = manual_var(x, dim=(2, 3, 4), diff_sqm=diff_sqm)
  68. return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W)
  69. #group_std = group_std_tpu # FIXME TPU temporary
  70. def group_rms(x, groups: int = 32, eps: float = 1e-5):
  71. B, C, H, W = x.shape
  72. _assert(C % groups == 0, '')
  73. x_dtype = x.dtype
  74. x = x.reshape(B, groups, C // groups, H, W)
  75. rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(x_dtype)
  76. return rms.expand(x.shape).reshape(B, C, H, W)
  77. class EvoNorm2dB0(nn.Module):
  78. def __init__(
  79. self,
  80. num_features: int,
  81. apply_act: bool = True,
  82. momentum: float = 0.1,
  83. eps: float = 1e-3,
  84. device=None,
  85. dtype=None,
  86. **_
  87. ):
  88. dd = {'device': device, 'dtype': dtype}
  89. super().__init__()
  90. self.apply_act = apply_act # apply activation (non-linearity)
  91. self.momentum = momentum
  92. self.eps = eps
  93. self.weight = nn.Parameter(torch.empty(num_features, **dd))
  94. self.bias = nn.Parameter(torch.empty(num_features, **dd))
  95. self.v = nn.Parameter(torch.empty(num_features, **dd)) if apply_act else None
  96. self.register_buffer('running_var', torch.ones(num_features, **dd))
  97. self.reset_parameters()
  98. def reset_parameters(self):
  99. nn.init.ones_(self.weight)
  100. nn.init.zeros_(self.bias)
  101. if self.v is not None:
  102. nn.init.ones_(self.v)
  103. def forward(self, x):
  104. _assert(x.dim() == 4, 'expected 4D input')
  105. x_dtype = x.dtype
  106. v_shape = (1, -1, 1, 1)
  107. if self.v is not None:
  108. if self.training:
  109. var = x.float().var(dim=(0, 2, 3), unbiased=False)
  110. # var = manual_var(x, dim=(0, 2, 3)).squeeze()
  111. n = x.numel() / x.shape[1]
  112. self.running_var.copy_(
  113. self.running_var * (1 - self.momentum) +
  114. var.detach() * self.momentum * (n / (n - 1)))
  115. else:
  116. var = self.running_var
  117. left = var.add(self.eps).sqrt_().to(x_dtype).view(v_shape).expand_as(x)
  118. v = self.v.to(x_dtype).view(v_shape)
  119. right = x * v + instance_std(x, self.eps)
  120. x = x / left.max(right)
  121. return x * self.weight.to(x_dtype).view(v_shape) + self.bias.to(x_dtype).view(v_shape)
  122. class EvoNorm2dB1(nn.Module):
  123. def __init__(
  124. self,
  125. num_features: int,
  126. apply_act: bool = True,
  127. momentum: float = 0.1,
  128. eps: float = 1e-5,
  129. device=None,
  130. dtype=None,
  131. **_
  132. ):
  133. dd = {'device': device, 'dtype': dtype}
  134. super().__init__()
  135. self.apply_act = apply_act # apply activation (non-linearity)
  136. self.momentum = momentum
  137. self.eps = eps
  138. self.weight = nn.Parameter(torch.empty(num_features, **dd))
  139. self.bias = nn.Parameter(torch.empty(num_features, **dd))
  140. self.register_buffer('running_var', torch.ones(num_features, **dd))
  141. self.reset_parameters()
  142. def reset_parameters(self):
  143. nn.init.ones_(self.weight)
  144. nn.init.zeros_(self.bias)
  145. def forward(self, x):
  146. _assert(x.dim() == 4, 'expected 4D input')
  147. x_dtype = x.dtype
  148. v_shape = (1, -1, 1, 1)
  149. if self.apply_act:
  150. if self.training:
  151. var = x.float().var(dim=(0, 2, 3), unbiased=False)
  152. n = x.numel() / x.shape[1]
  153. self.running_var.copy_(
  154. self.running_var * (1 - self.momentum) +
  155. var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1)))
  156. else:
  157. var = self.running_var
  158. var = var.to(x_dtype).view(v_shape)
  159. left = var.add(self.eps).sqrt_()
  160. right = (x + 1) * instance_rms(x, self.eps)
  161. x = x / left.max(right)
  162. return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
  163. class EvoNorm2dB2(nn.Module):
  164. def __init__(
  165. self,
  166. num_features: int,
  167. apply_act: bool = True,
  168. momentum: float = 0.1,
  169. eps: float = 1e-5,
  170. device=None,
  171. dtype=None,
  172. **_
  173. ):
  174. dd = {'device': device, 'dtype': dtype}
  175. super().__init__()
  176. self.apply_act = apply_act # apply activation (non-linearity)
  177. self.momentum = momentum
  178. self.eps = eps
  179. self.weight = nn.Parameter(torch.empty(num_features, **dd))
  180. self.bias = nn.Parameter(torch.empty(num_features, **dd))
  181. self.register_buffer('running_var', torch.ones(num_features, **dd))
  182. self.reset_parameters()
  183. def reset_parameters(self):
  184. nn.init.ones_(self.weight)
  185. nn.init.zeros_(self.bias)
  186. def forward(self, x):
  187. _assert(x.dim() == 4, 'expected 4D input')
  188. x_dtype = x.dtype
  189. v_shape = (1, -1, 1, 1)
  190. if self.apply_act:
  191. if self.training:
  192. var = x.float().var(dim=(0, 2, 3), unbiased=False)
  193. n = x.numel() / x.shape[1]
  194. self.running_var.copy_(
  195. self.running_var * (1 - self.momentum) +
  196. var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1)))
  197. else:
  198. var = self.running_var
  199. var = var.to(x_dtype).view(v_shape)
  200. left = var.add(self.eps).sqrt_()
  201. right = instance_rms(x, self.eps) - x
  202. x = x / left.max(right)
  203. return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
  204. class EvoNorm2dS0(nn.Module):
  205. def __init__(
  206. self,
  207. num_features: int,
  208. groups: int = 32,
  209. group_size: Optional[int] = None,
  210. apply_act: bool = True,
  211. eps: float = 1e-5,
  212. device=None,
  213. dtype=None,
  214. **_
  215. ):
  216. dd = {'device': device, 'dtype': dtype}
  217. super().__init__()
  218. self.apply_act = apply_act # apply activation (non-linearity)
  219. if group_size:
  220. assert num_features % group_size == 0
  221. self.groups = num_features // group_size
  222. else:
  223. self.groups = groups
  224. self.eps = eps
  225. self.weight = nn.Parameter(torch.empty(num_features, **dd))
  226. self.bias = nn.Parameter(torch.empty(num_features, **dd))
  227. self.v = nn.Parameter(torch.empty(num_features, **dd)) if apply_act else None
  228. self.reset_parameters()
  229. def reset_parameters(self):
  230. nn.init.ones_(self.weight)
  231. nn.init.zeros_(self.bias)
  232. if self.v is not None:
  233. nn.init.ones_(self.v)
  234. def forward(self, x):
  235. _assert(x.dim() == 4, 'expected 4D input')
  236. x_dtype = x.dtype
  237. v_shape = (1, -1, 1, 1)
  238. if self.v is not None:
  239. v = self.v.view(v_shape).to(x_dtype)
  240. x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps)
  241. return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
  242. class EvoNorm2dS0a(EvoNorm2dS0):
  243. def __init__(
  244. self,
  245. num_features: int,
  246. groups: int = 32,
  247. group_size: Optional[int] = None,
  248. apply_act: bool = True,
  249. eps: float = 1e-3,
  250. device=None,
  251. dtype=None,
  252. **_
  253. ):
  254. super().__init__(
  255. num_features,
  256. groups=groups,
  257. group_size=group_size,
  258. apply_act=apply_act,
  259. eps=eps,
  260. device=device,
  261. dtype=dtype,
  262. )
  263. def forward(self, x):
  264. _assert(x.dim() == 4, 'expected 4D input')
  265. x_dtype = x.dtype
  266. v_shape = (1, -1, 1, 1)
  267. d = group_std(x, self.groups, self.eps)
  268. if self.v is not None:
  269. v = self.v.view(v_shape).to(x_dtype)
  270. x = x * (x * v).sigmoid()
  271. x = x / d
  272. return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
  273. class EvoNorm2dS1(nn.Module):
  274. def __init__(
  275. self,
  276. num_features: int,
  277. groups: int = 32,
  278. group_size: Optional[int] = None,
  279. apply_act: bool = True,
  280. act_layer: Optional[Type[nn.Module]] = None,
  281. eps: float = 1e-5,
  282. device=None,
  283. dtype=None,
  284. **_
  285. ):
  286. dd = {'device': device, 'dtype': dtype}
  287. super().__init__()
  288. act_layer = act_layer or nn.SiLU
  289. self.apply_act = apply_act # apply activation (non-linearity)
  290. if act_layer is not None and apply_act:
  291. self.act = create_act_layer(act_layer)
  292. else:
  293. self.act = nn.Identity()
  294. if group_size:
  295. assert num_features % group_size == 0
  296. self.groups = num_features // group_size
  297. else:
  298. self.groups = groups
  299. self.eps = eps
  300. self.pre_act_norm = False
  301. self.weight = nn.Parameter(torch.empty(num_features, **dd))
  302. self.bias = nn.Parameter(torch.empty(num_features, **dd))
  303. self.reset_parameters()
  304. def reset_parameters(self):
  305. nn.init.ones_(self.weight)
  306. nn.init.zeros_(self.bias)
  307. def forward(self, x):
  308. _assert(x.dim() == 4, 'expected 4D input')
  309. x_dtype = x.dtype
  310. v_shape = (1, -1, 1, 1)
  311. if self.apply_act:
  312. x = self.act(x) / group_std(x, self.groups, self.eps)
  313. return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
  314. class EvoNorm2dS1a(EvoNorm2dS1):
  315. def __init__(
  316. self,
  317. num_features: int,
  318. groups: int = 32,
  319. group_size: Optional[int] = None,
  320. apply_act: bool = True,
  321. act_layer: Optional[Type[nn.Module]] = None,
  322. eps: float = 1e-3,
  323. device=None,
  324. dtype=None,
  325. **_
  326. ):
  327. super().__init__(
  328. num_features,
  329. groups=groups,
  330. group_size=group_size,
  331. apply_act=apply_act,
  332. act_layer=act_layer,
  333. eps=eps,
  334. device=device,
  335. dtype=dtype,
  336. )
  337. def forward(self, x):
  338. _assert(x.dim() == 4, 'expected 4D input')
  339. x_dtype = x.dtype
  340. v_shape = (1, -1, 1, 1)
  341. x = self.act(x) / group_std(x, self.groups, self.eps)
  342. return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
  343. class EvoNorm2dS2(nn.Module):
  344. def __init__(
  345. self,
  346. num_features: int,
  347. groups: int = 32,
  348. group_size: Optional[int] = None,
  349. apply_act: bool = True,
  350. act_layer: Optional[Type[nn.Module]] = None,
  351. eps: float = 1e-5,
  352. device=None,
  353. dtype=None,
  354. **_
  355. ):
  356. dd = {'device': device, 'dtype': dtype}
  357. super().__init__()
  358. act_layer = act_layer or nn.SiLU
  359. self.apply_act = apply_act # apply activation (non-linearity)
  360. if act_layer is not None and apply_act:
  361. self.act = create_act_layer(act_layer)
  362. else:
  363. self.act = nn.Identity()
  364. if group_size:
  365. assert num_features % group_size == 0
  366. self.groups = num_features // group_size
  367. else:
  368. self.groups = groups
  369. self.eps = eps
  370. self.weight = nn.Parameter(torch.empty(num_features, **dd))
  371. self.bias = nn.Parameter(torch.empty(num_features, **dd))
  372. self.reset_parameters()
  373. def reset_parameters(self):
  374. nn.init.ones_(self.weight)
  375. nn.init.zeros_(self.bias)
  376. def forward(self, x):
  377. _assert(x.dim() == 4, 'expected 4D input')
  378. x_dtype = x.dtype
  379. v_shape = (1, -1, 1, 1)
  380. if self.apply_act:
  381. x = self.act(x) / group_rms(x, self.groups, self.eps)
  382. return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
  383. class EvoNorm2dS2a(EvoNorm2dS2):
  384. def __init__(
  385. self,
  386. num_features: int,
  387. groups: int = 32,
  388. group_size: Optional[int] = None,
  389. apply_act: bool = True,
  390. act_layer: Optional[Type[nn.Module]] = None,
  391. eps: float = 1e-3,
  392. device=None,
  393. dtype=None,
  394. **_
  395. ):
  396. super().__init__(
  397. num_features,
  398. groups=groups,
  399. group_size=group_size,
  400. apply_act=apply_act,
  401. act_layer=act_layer,
  402. eps=eps,
  403. device=device,
  404. dtype=dtype,
  405. )
  406. def forward(self, x):
  407. _assert(x.dim() == 4, 'expected 4D input')
  408. x_dtype = x.dtype
  409. v_shape = (1, -1, 1, 1)
  410. x = self.act(x) / group_rms(x, self.groups, self.eps)
  411. return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)