norm.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575
  1. """ Normalization layers and wrappers
  2. Norm layer definitions that support fast norm and consistent channel arg order (always first arg).
  3. Hacked together by / Copyright 2022 Ross Wightman
  4. """
  5. import numbers
  6. from typing import Tuple
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. from .fast_norm import (
  11. is_fast_norm,
  12. fast_group_norm,
  13. fast_layer_norm,
  14. fast_rms_norm,
  15. rms_norm2d,
  16. fast_rms_norm2d,
  17. fast_simple_norm,
  18. simple_norm,
  19. )
  20. try:
  21. from torch.nn.functional import rms_norm
  22. except ImportError:
  23. from .fast_norm import rms_norm
  24. class GroupNorm(nn.GroupNorm):
  25. _fast_norm: torch.jit.Final[bool]
  26. def __init__(
  27. self,
  28. num_channels: int,
  29. num_groups: int = 32,
  30. eps: float = 1e-5,
  31. affine: bool = True,
  32. **kwargs,
  33. ):
  34. # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
  35. super().__init__(num_groups, num_channels, eps=eps, affine=affine, **kwargs)
  36. self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
  37. def forward(self, x):
  38. if self._fast_norm:
  39. return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
  40. else:
  41. return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
  42. class GroupNorm1(nn.GroupNorm):
  43. """ Group Normalization with 1 group.
  44. Input: tensor in shape [B, C, *]
  45. """
  46. _fast_norm: torch.jit.Final[bool]
  47. def __init__(self, num_channels: int, **kwargs):
  48. super().__init__(1, num_channels, **kwargs)
  49. self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
  50. def forward(self, x: torch.Tensor) -> torch.Tensor:
  51. if self._fast_norm:
  52. return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
  53. else:
  54. return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
  55. class LayerNorm(nn.LayerNorm):
  56. """ LayerNorm w/ fast norm option
  57. """
  58. _fast_norm: torch.jit.Final[bool]
  59. def __init__(
  60. self,
  61. num_channels: int,
  62. eps: float = 1e-6,
  63. affine: bool = True,
  64. **kwargs,
  65. ):
  66. super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs)
  67. self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
  68. def forward(self, x: torch.Tensor) -> torch.Tensor:
  69. if self._fast_norm:
  70. x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
  71. else:
  72. x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
  73. return x
  74. class LayerNormFp32(nn.LayerNorm):
  75. """ LayerNorm
  76. """
  77. def __init__(
  78. self,
  79. num_channels: int,
  80. eps: float = 1e-6,
  81. affine: bool = True,
  82. **kwargs,
  83. ):
  84. super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs)
  85. def forward(self, x: torch.Tensor) -> torch.Tensor:
  86. weight = self.weight.float() if self.weight is not None else None
  87. bias = self.bias.float() if self.bias is not None else None
  88. x = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).to(x.dtype)
  89. return x
  90. class LayerNorm2d(nn.LayerNorm):
  91. """ LayerNorm for channels of '2D' spatial NCHW tensors """
  92. _fast_norm: torch.jit.Final[bool]
  93. def __init__(
  94. self,
  95. num_channels: int,
  96. eps: float = 1e-6,
  97. affine: bool = True,
  98. **kwargs,
  99. ):
  100. super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs)
  101. self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
  102. def forward(self, x: torch.Tensor) -> torch.Tensor:
  103. x = x.permute(0, 2, 3, 1)
  104. if self._fast_norm:
  105. x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
  106. else:
  107. x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
  108. x = x.permute(0, 3, 1, 2)
  109. return x
  110. class LayerNorm2dFp32(nn.LayerNorm):
  111. """ LayerNorm for channels of '2D' spatial NCHW tensors """
  112. def __init__(
  113. self,
  114. num_channels: int,
  115. eps: float = 1e-6,
  116. affine: bool = True,
  117. **kwargs,
  118. ):
  119. super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs)
  120. def forward(self, x: torch.Tensor) -> torch.Tensor:
  121. x = x.permute(0, 2, 3, 1)
  122. weight = self.weight.float() if self.weight is not None else None
  123. bias = self.bias.float() if self.bias is not None else None
  124. x = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).to(x.dtype)
  125. x = x.permute(0, 3, 1, 2)
  126. return x
  127. def _is_contiguous(tensor: torch.Tensor) -> bool:
  128. # jit is oh so lovely :/
  129. if torch.jit.is_scripting():
  130. return tensor.is_contiguous()
  131. else:
  132. return tensor.is_contiguous(memory_format=torch.contiguous_format)
  133. def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
  134. s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
  135. x = (x - u) * torch.rsqrt(s + eps)
  136. x = x * weight[:, None, None] + bias[:, None, None]
  137. return x
  138. def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
  139. u = x.mean(dim=1, keepdim=True)
  140. s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0)
  141. x = (x - u) * torch.rsqrt(s + eps)
  142. x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)
  143. return x
  144. class LayerNormExp2d(nn.LayerNorm):
  145. """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
  146. Experimental implementation w/ manual norm for tensors non-contiguous tensors.
  147. This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last
  148. layout. However, benefits are not always clear and can perform worse on other GPUs.
  149. """
  150. def __init__(self, num_channels: int, eps: float = 1e-6):
  151. super().__init__(num_channels, eps=eps)
  152. def forward(self, x) -> torch.Tensor:
  153. if _is_contiguous(x):
  154. x = F.layer_norm(
  155. x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
  156. else:
  157. x = _layer_norm_cf(x, self.weight, self.bias, self.eps)
  158. return x
  159. class RmsNorm(nn.Module):
  160. """ RmsNorm w/ fast (apex) norm if available
  161. """
  162. __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
  163. normalized_shape: Tuple[int, ...]
  164. eps: float
  165. elementwise_affine: bool
  166. _fast_norm: bool
  167. def __init__(
  168. self,
  169. channels: int,
  170. eps: float = 1e-6,
  171. affine: bool = True,
  172. device=None,
  173. dtype=None,
  174. ) -> None:
  175. dd = {'device': device, 'dtype': dtype}
  176. super().__init__()
  177. normalized_shape = channels
  178. if isinstance(normalized_shape, numbers.Integral):
  179. # mypy error: incompatible types in assignment
  180. normalized_shape = (normalized_shape,) # type: ignore[assignment]
  181. self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
  182. self.eps = eps
  183. self.elementwise_affine = affine
  184. self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
  185. if self.elementwise_affine:
  186. self.weight = nn.Parameter(torch.empty(self.normalized_shape, **dd))
  187. else:
  188. self.register_parameter('weight', None)
  189. self.reset_parameters()
  190. def reset_parameters(self) -> None:
  191. if self.elementwise_affine:
  192. nn.init.ones_(self.weight)
  193. def forward(self, x: torch.Tensor) -> torch.Tensor:
  194. # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
  195. # Since there is no built-in PyTorch impl, always uses APEX RmsNorm if installed.
  196. if self._fast_norm:
  197. x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
  198. else:
  199. x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
  200. return x
  201. class RmsNormFp32(nn.Module):
  202. """ RmsNorm w/ fast (apex) norm if available
  203. """
  204. __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
  205. normalized_shape: Tuple[int, ...]
  206. eps: float
  207. elementwise_affine: bool
  208. def __init__(
  209. self,
  210. channels: int,
  211. eps: float = 1e-6,
  212. affine: bool = True,
  213. device=None,
  214. dtype=None,
  215. ) -> None:
  216. dd = {'device': device, 'dtype': dtype}
  217. super().__init__()
  218. normalized_shape = channels
  219. if isinstance(normalized_shape, numbers.Integral):
  220. # mypy error: incompatible types in assignment
  221. normalized_shape = (normalized_shape,) # type: ignore[assignment]
  222. self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
  223. self.eps = eps
  224. self.elementwise_affine = affine
  225. if self.elementwise_affine:
  226. self.weight = nn.Parameter(torch.empty(self.normalized_shape, **dd))
  227. else:
  228. self.register_parameter('weight', None)
  229. self.reset_parameters()
  230. def reset_parameters(self) -> None:
  231. if self.elementwise_affine:
  232. nn.init.ones_(self.weight)
  233. def forward(self, x: torch.Tensor) -> torch.Tensor:
  234. weight = self.weight.float() if self.weight is not None else None
  235. x = rms_norm(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
  236. return x
  237. class RmsNorm2d(nn.Module):
  238. """ RmsNorm2D for NCHW tensors, w/ fast apex or cast norm if available
  239. NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction
  240. on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something
  241. like https://github.com/pytorch/pytorch/pull/150576 lands.
  242. """
  243. __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
  244. normalized_shape: Tuple[int, ...]
  245. eps: float
  246. elementwise_affine: bool
  247. _fast_norm: bool
  248. def __init__(
  249. self,
  250. channels: int,
  251. eps: float = 1e-6,
  252. affine: bool = True,
  253. device=None,
  254. dtype=None,
  255. ) -> None:
  256. dd = {'device': device, 'dtype': dtype}
  257. super().__init__()
  258. normalized_shape = channels
  259. if isinstance(normalized_shape, numbers.Integral):
  260. # mypy error: incompatible types in assignment
  261. normalized_shape = (normalized_shape,) # type: ignore[assignment]
  262. self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
  263. self.eps = eps
  264. self.elementwise_affine = affine
  265. self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
  266. if self.elementwise_affine:
  267. self.weight = nn.Parameter(torch.empty(self.normalized_shape, **dd))
  268. else:
  269. self.register_parameter('weight', None)
  270. self.reset_parameters()
  271. def reset_parameters(self) -> None:
  272. if self.elementwise_affine:
  273. nn.init.ones_(self.weight)
  274. def forward(self, x: torch.Tensor) -> torch.Tensor:
  275. # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
  276. # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
  277. if self._fast_norm:
  278. x = fast_rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
  279. else:
  280. x = rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
  281. return x
  282. class RmsNorm2dFp32(nn.Module):
  283. """ RmsNorm2D for NCHW tensors, w/ fast apex or cast norm if available
  284. NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction
  285. on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something
  286. like https://github.com/pytorch/pytorch/pull/150576 lands.
  287. """
  288. __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
  289. normalized_shape: Tuple[int, ...]
  290. eps: float
  291. elementwise_affine: bool
  292. def __init__(
  293. self,
  294. channels: int,
  295. eps: float = 1e-6,
  296. affine: bool = True,
  297. device=None,
  298. dtype=None,
  299. ) -> None:
  300. dd = {'device': device, 'dtype': dtype}
  301. super().__init__()
  302. normalized_shape = channels
  303. if isinstance(normalized_shape, numbers.Integral):
  304. # mypy error: incompatible types in assignment
  305. normalized_shape = (normalized_shape,) # type: ignore[assignment]
  306. self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
  307. self.eps = eps
  308. self.elementwise_affine = affine
  309. if self.elementwise_affine:
  310. self.weight = nn.Parameter(torch.empty(self.normalized_shape, **dd))
  311. else:
  312. self.register_parameter('weight', None)
  313. self.reset_parameters()
  314. def reset_parameters(self) -> None:
  315. if self.elementwise_affine:
  316. nn.init.ones_(self.weight)
  317. def forward(self, x: torch.Tensor) -> torch.Tensor:
  318. weight = self.weight.float() if self.weight is not None else None
  319. x = rms_norm2d(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
  320. return x
  321. class SimpleNorm(nn.Module):
  322. """ SimpleNorm (x / std(x))
  323. """
  324. __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
  325. normalized_shape: Tuple[int, ...]
  326. eps: float
  327. elementwise_affine: bool
  328. _fast_norm: bool
  329. def __init__(
  330. self,
  331. channels: int,
  332. eps: float = 1e-6,
  333. affine: bool = True,
  334. device=None,
  335. dtype=None,
  336. ) -> None:
  337. dd = {'device': device, 'dtype': dtype}
  338. super().__init__()
  339. normalized_shape = channels
  340. if isinstance(normalized_shape, numbers.Integral):
  341. # mypy error: incompatible types in assignment
  342. normalized_shape = (normalized_shape,) # type: ignore[assignment]
  343. self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
  344. self.eps = eps
  345. self.elementwise_affine = affine
  346. self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
  347. if self.elementwise_affine:
  348. self.weight = nn.Parameter(torch.empty(self.normalized_shape, **dd))
  349. else:
  350. self.register_parameter('weight', None)
  351. self.reset_parameters()
  352. def reset_parameters(self) -> None:
  353. if self.elementwise_affine:
  354. nn.init.ones_(self.weight)
  355. def forward(self, x: torch.Tensor) -> torch.Tensor:
  356. if self._fast_norm:
  357. x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
  358. else:
  359. x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
  360. return x
  361. class SimpleNormFp32(nn.Module):
  362. """ SimpleNorm (x / std(x))
  363. """
  364. __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
  365. normalized_shape: Tuple[int, ...]
  366. eps: float
  367. elementwise_affine: bool
  368. def __init__(
  369. self,
  370. channels: int,
  371. eps: float = 1e-6,
  372. affine: bool = True,
  373. device=None,
  374. dtype=None,
  375. ) -> None:
  376. dd = {'device': device, 'dtype': dtype}
  377. super().__init__()
  378. normalized_shape = channels
  379. if isinstance(normalized_shape, numbers.Integral):
  380. # mypy error: incompatible types in assignment
  381. normalized_shape = (normalized_shape,) # type: ignore[assignment]
  382. self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
  383. self.eps = eps
  384. self.elementwise_affine = affine
  385. if self.elementwise_affine:
  386. self.weight = nn.Parameter(torch.empty(self.normalized_shape, **dd))
  387. else:
  388. self.register_parameter('weight', None)
  389. self.reset_parameters()
  390. def reset_parameters(self) -> None:
  391. if self.elementwise_affine:
  392. nn.init.ones_(self.weight)
  393. def forward(self, x: torch.Tensor) -> torch.Tensor:
  394. weight = self.weight.float() if self.weight is not None else None
  395. x = simple_norm(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
  396. return x
  397. class SimpleNorm2d(nn.Module):
  398. """ SimpleNorm for NCHW tensors
  399. """
  400. __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
  401. normalized_shape: Tuple[int, ...]
  402. eps: float
  403. elementwise_affine: bool
  404. _fast_norm: bool
  405. def __init__(
  406. self,
  407. channels: int,
  408. eps: float = 1e-6,
  409. affine: bool = True,
  410. device=None,
  411. dtype=None,
  412. ) -> None:
  413. dd = {'device': device, 'dtype': dtype}
  414. super().__init__()
  415. normalized_shape = channels
  416. if isinstance(normalized_shape, numbers.Integral):
  417. # mypy error: incompatible types in assignment
  418. normalized_shape = (normalized_shape,) # type: ignore[assignment]
  419. self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
  420. self.eps = eps
  421. self.elementwise_affine = affine
  422. self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
  423. if self.elementwise_affine:
  424. self.weight = nn.Parameter(torch.empty(self.normalized_shape, **dd))
  425. else:
  426. self.register_parameter('weight', None)
  427. self.reset_parameters()
  428. def reset_parameters(self) -> None:
  429. if self.elementwise_affine:
  430. nn.init.ones_(self.weight)
  431. def forward(self, x: torch.Tensor) -> torch.Tensor:
  432. x = x.permute(0, 2, 3, 1)
  433. if self._fast_norm:
  434. x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
  435. else:
  436. x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
  437. x = x.permute(0, 3, 1, 2)
  438. return x
  439. class SimpleNorm2dFp32(nn.Module):
  440. """ SimpleNorm for NCHW tensors
  441. """
  442. __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
  443. normalized_shape: Tuple[int, ...]
  444. eps: float
  445. elementwise_affine: bool
  446. def __init__(
  447. self,
  448. channels: int,
  449. eps: float = 1e-6,
  450. affine: bool = True,
  451. device=None,
  452. dtype=None,
  453. ) -> None:
  454. dd = {'device': device, 'dtype': dtype}
  455. super().__init__()
  456. normalized_shape = channels
  457. if isinstance(normalized_shape, numbers.Integral):
  458. # mypy error: incompatible types in assignment
  459. normalized_shape = (normalized_shape,) # type: ignore[assignment]
  460. self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
  461. self.eps = eps
  462. self.elementwise_affine = affine
  463. if self.elementwise_affine:
  464. self.weight = nn.Parameter(torch.empty(self.normalized_shape, **dd))
  465. else:
  466. self.register_parameter('weight', None)
  467. self.reset_parameters()
  468. def reset_parameters(self) -> None:
  469. if self.elementwise_affine:
  470. nn.init.ones_(self.weight)
  471. def forward(self, x: torch.Tensor) -> torch.Tensor:
  472. x = x.permute(0, 2, 3, 1)
  473. weight = self.weight.float() if self.weight is not None else None
  474. x = simple_norm(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
  475. x = x.permute(0, 3, 1, 2)
  476. return x