norm_act.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690
  1. """ Normalization + Activation Layers
  2. Provides Norm+Act fns for standard PyTorch norm layers such as
  3. * BatchNorm
  4. * GroupNorm
  5. * LayerNorm
  6. This allows swapping with alternative layers that are natively both norm + act such as
  7. * EvoNorm (evo_norm.py)
  8. * FilterResponseNorm (filter_response_norm.py)
  9. * InplaceABN (inplace_abn.py)
  10. Hacked together by / Copyright 2022 Ross Wightman
  11. """
  12. from typing import Any, Dict, List, Optional, Type, Union
  13. import torch
  14. from torch import nn as nn
  15. from torch.nn import functional as F
  16. from torchvision.ops.misc import FrozenBatchNorm2d
  17. from ._fx import register_notrace_module
  18. from .create_act import create_act_layer
  19. from .fast_norm import (
  20. is_fast_norm,
  21. fast_group_norm,
  22. fast_layer_norm,
  23. fast_rms_norm,
  24. rms_norm2d,
  25. fast_rms_norm2d,
  26. )
  27. from .norm import RmsNorm, RmsNorm2d
  28. from .trace_utils import _assert
  29. from .typing import LayerType
  30. try:
  31. from torch.nn.functional import rms_norm
  32. except ImportError:
  33. from .fast_norm import rms_norm
  34. def _create_act(
  35. act_layer: LayerType,
  36. act_kwargs: Dict[str, Any] = None,
  37. inplace: Optional[bool] = False,
  38. apply_act: bool = True,
  39. ) -> nn.Module:
  40. act_kwargs = act_kwargs or {}
  41. act_kwargs.setdefault('inplace', inplace)
  42. act = None
  43. if apply_act:
  44. act = create_act_layer(act_layer, **act_kwargs)
  45. return nn.Identity() if act is None else act
@register_notrace_module
class BatchNormAct2d(nn.BatchNorm2d):
    """BatchNorm + Activation

    This module performs BatchNorm + Activation in a manner that will remain backwards
    compatible with weights trained with separate bn, act. This is why we inherit from BN
    instead of composing it as a .bn member.

    The forward pass applies batch norm, then the `drop` module, then the `act` module.
    """
    def __init__(
            self,
            num_features: int,
            eps: float = 1e-5,
            momentum: float = 0.1,
            affine: bool = True,
            track_running_stats: bool = True,
            apply_act: bool = True,
            act_layer: LayerType = nn.ReLU,
            act_kwargs: Dict[str, Any] = None,
            inplace: bool = True,
            drop_layer: Optional[Type[nn.Module]] = None,
            device=None,
            dtype=None,
    ):
        """
        Args:
            num_features: Forwarded to `nn.BatchNorm2d`.
            eps: Forwarded to `nn.BatchNorm2d`.
            momentum: Forwarded to `nn.BatchNorm2d`; `forward` also handles `None`
                (cumulative moving average).
            affine: Forwarded to `nn.BatchNorm2d`.
            track_running_stats: Forwarded to `nn.BatchNorm2d`.
            apply_act: When False the activation is replaced by `nn.Identity`.
            act_layer: Activation spec handed to `_create_act`.
            act_kwargs: Extra keyword arguments for the activation.
            inplace: Default `inplace` setting handed to `_create_act`.
            drop_layer: Optional module class, instantiated with no arguments and applied
                between the norm and the activation.
            device: Factory kwarg forwarded to `nn.BatchNorm2d` when supported.
            dtype: Factory kwarg forwarded to `nn.BatchNorm2d` when supported.
        """
        try:
            factory_kwargs = {'device': device, 'dtype': dtype}
            super().__init__(
                num_features,
                eps=eps,
                momentum=momentum,
                affine=affine,
                track_running_stats=track_running_stats,
                **factory_kwargs,
            )
        except TypeError:
            # NOTE for backwards compat with old PyTorch w/o factory device/dtype support
            # (device / dtype are silently ignored on this path)
            super().__init__(
                num_features,
                eps=eps,
                momentum=momentum,
                affine=affine,
                track_running_stats=track_running_stats,
            )
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

    def forward(self, x):
        """Apply batch norm, then `self.drop`, then `self.act` to a 4D input."""
        # cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing
        _assert(x.ndim == 4, f'expected 4D input (got {x.ndim}D input)')

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked.add_(1)  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        x = F.batch_norm(
            x,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean if not self.training or self.track_running_stats else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )
        # drop (Identity unless a drop_layer was given) runs before the activation
        x = self.drop(x)
        x = self.act(x)
        return x
  134. @register_notrace_module
  135. class SyncBatchNormAct(nn.SyncBatchNorm):
  136. # Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
  137. # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
  138. # but ONLY when used in conjunction with the timm conversion function below.
  139. # Do not create this module directly or use the PyTorch conversion function.
  140. def forward(self, x: torch.Tensor) -> torch.Tensor:
  141. x = super().forward(x) # SyncBN doesn't work with torchscript anyways, so this is fine
  142. if hasattr(self, "drop"):
  143. x = self.drop(x)
  144. if hasattr(self, "act"):
  145. x = self.act(x)
  146. return x
  147. def convert_sync_batchnorm(module, process_group=None):
  148. # convert both BatchNorm and BatchNormAct layers to Synchronized variants
  149. module_output = module
  150. if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
  151. if isinstance(module, BatchNormAct2d):
  152. # convert timm norm + act layer
  153. module_output = SyncBatchNormAct(
  154. module.num_features,
  155. module.eps,
  156. module.momentum,
  157. module.affine,
  158. module.track_running_stats,
  159. process_group=process_group,
  160. )
  161. # set act and drop attr from the original module
  162. module_output.act = module.act
  163. module_output.drop = module.drop
  164. else:
  165. # convert standard BatchNorm layers
  166. module_output = torch.nn.SyncBatchNorm(
  167. module.num_features,
  168. module.eps,
  169. module.momentum,
  170. module.affine,
  171. module.track_running_stats,
  172. process_group,
  173. )
  174. if module.affine:
  175. with torch.no_grad():
  176. module_output.weight = module.weight
  177. module_output.bias = module.bias
  178. module_output.running_mean = module.running_mean
  179. module_output.running_var = module.running_var
  180. module_output.num_batches_tracked = module.num_batches_tracked
  181. module_output.training = module.training
  182. if hasattr(module, "qconfig"):
  183. module_output.qconfig = module.qconfig
  184. for name, child in module.named_children():
  185. module_output.add_module(name, convert_sync_batchnorm(child, process_group))
  186. del module
  187. return module_output
  188. @register_notrace_module
  189. class FrozenBatchNormAct2d(torch.nn.Module):
  190. """
  191. BatchNormAct2d where the batch statistics and the affine parameters are fixed
  192. Args:
  193. num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
  194. eps (float): a value added to the denominator for numerical stability. Default: 1e-5
  195. """
  196. def __init__(
  197. self,
  198. num_features: int,
  199. eps: float = 1e-5,
  200. apply_act: bool = True,
  201. act_layer: LayerType = nn.ReLU,
  202. act_kwargs: Dict[str, Any] = None,
  203. inplace: bool = True,
  204. drop_layer: Optional[Type[nn.Module]] = None,
  205. device=None,
  206. dtype=None,
  207. ):
  208. dd = {'device': device, 'dtype': dtype}
  209. super().__init__()
  210. self.eps = eps
  211. self.register_buffer("weight", torch.ones(num_features, **dd))
  212. self.register_buffer("bias", torch.zeros(num_features, **dd))
  213. self.register_buffer("running_mean", torch.zeros(num_features, **dd))
  214. self.register_buffer("running_var", torch.ones(num_features, **dd))
  215. self.drop = drop_layer() if drop_layer is not None else nn.Identity()
  216. self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
  217. def _load_from_state_dict(
  218. self,
  219. state_dict: dict,
  220. prefix: str,
  221. local_metadata: dict,
  222. strict: bool,
  223. missing_keys: List[str],
  224. unexpected_keys: List[str],
  225. error_msgs: List[str],
  226. ):
  227. num_batches_tracked_key = prefix + "num_batches_tracked"
  228. if num_batches_tracked_key in state_dict:
  229. del state_dict[num_batches_tracked_key]
  230. super()._load_from_state_dict(
  231. state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
  232. )
  233. def forward(self, x: torch.Tensor) -> torch.Tensor:
  234. # move reshapes to the beginning
  235. # to make it fuser-friendly
  236. w = self.weight.reshape(1, -1, 1, 1)
  237. b = self.bias.reshape(1, -1, 1, 1)
  238. rv = self.running_var.reshape(1, -1, 1, 1)
  239. rm = self.running_mean.reshape(1, -1, 1, 1)
  240. scale = w * (rv + self.eps).rsqrt()
  241. bias = b - rm * scale
  242. x = x * scale + bias
  243. x = self.act(self.drop(x))
  244. return x
  245. def __repr__(self) -> str:
  246. return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps}, act={self.act})"
  247. def freeze_batch_norm_2d(module):
  248. """
  249. Converts all `BatchNorm2d` and `SyncBatchNorm` or `BatchNormAct2d` and `SyncBatchNormAct2d` layers
  250. of provided module into `FrozenBatchNorm2d` or `FrozenBatchNormAct2d` respectively.
  251. Args:
  252. module (torch.nn.Module): Any PyTorch module.
  253. Returns:
  254. torch.nn.Module: Resulting module
  255. Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
  256. """
  257. res = module
  258. if isinstance(module, (BatchNormAct2d, SyncBatchNormAct)):
  259. res = FrozenBatchNormAct2d(module.num_features)
  260. res.num_features = module.num_features
  261. res.affine = module.affine
  262. if module.affine:
  263. res.weight.data = module.weight.data.clone().detach()
  264. res.bias.data = module.bias.data.clone().detach()
  265. res.running_mean.data = module.running_mean.data
  266. res.running_var.data = module.running_var.data
  267. res.eps = module.eps
  268. res.drop = module.drop
  269. res.act = module.act
  270. elif isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
  271. res = FrozenBatchNorm2d(module.num_features)
  272. res.num_features = module.num_features
  273. res.affine = module.affine
  274. if module.affine:
  275. res.weight.data = module.weight.data.clone().detach()
  276. res.bias.data = module.bias.data.clone().detach()
  277. res.running_mean.data = module.running_mean.data
  278. res.running_var.data = module.running_var.data
  279. res.eps = module.eps
  280. else:
  281. for name, child in module.named_children():
  282. new_child = freeze_batch_norm_2d(child)
  283. if new_child is not child:
  284. res.add_module(name, new_child)
  285. return res
  286. def unfreeze_batch_norm_2d(module):
  287. """
  288. Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself and instance
  289. of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked
  290. recursively and submodules are converted in place.
  291. Args:
  292. module (torch.nn.Module): Any PyTorch module.
  293. Returns:
  294. torch.nn.Module: Resulting module
  295. Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
  296. """
  297. res = module
  298. if isinstance(module, FrozenBatchNormAct2d):
  299. res = BatchNormAct2d(module.num_features)
  300. if module.affine:
  301. res.weight.data = module.weight.data.clone().detach()
  302. res.bias.data = module.bias.data.clone().detach()
  303. res.running_mean.data = module.running_mean.data
  304. res.running_var.data = module.running_var.data
  305. res.eps = module.eps
  306. res.drop = module.drop
  307. res.act = module.act
  308. elif isinstance(module, FrozenBatchNorm2d):
  309. res = torch.nn.BatchNorm2d(module.num_features)
  310. if module.affine:
  311. res.weight.data = module.weight.data.clone().detach()
  312. res.bias.data = module.bias.data.clone().detach()
  313. res.running_mean.data = module.running_mean.data
  314. res.running_var.data = module.running_var.data
  315. res.eps = module.eps
  316. else:
  317. for name, child in module.named_children():
  318. new_child = unfreeze_batch_norm_2d(child)
  319. if new_child is not child:
  320. res.add_module(name, new_child)
  321. return res
  322. def _num_groups(num_channels: int, num_groups: int, group_size: int):
  323. if group_size:
  324. assert num_channels % group_size == 0
  325. return num_channels // group_size
  326. return num_groups
  327. class GroupNormAct(nn.GroupNorm):
  328. _fast_norm: torch.jit.Final[bool]
  329. # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
  330. def __init__(
  331. self,
  332. num_channels: int,
  333. num_groups: int = 32,
  334. eps: float = 1e-5,
  335. affine: bool = True,
  336. group_size: Optional[int] = None,
  337. apply_act: bool = True,
  338. act_layer: LayerType = nn.ReLU,
  339. act_kwargs: Dict[str, Any] = None,
  340. inplace: bool = True,
  341. drop_layer: Optional[Type[nn.Module]] = None,
  342. device=None,
  343. dtype=None,
  344. ):
  345. super().__init__(
  346. _num_groups(num_channels, num_groups, group_size),
  347. num_channels,
  348. eps=eps,
  349. affine=affine,
  350. device=device,
  351. dtype=dtype,
  352. )
  353. self.drop = drop_layer() if drop_layer is not None else nn.Identity()
  354. self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
  355. self._fast_norm = is_fast_norm()
  356. def forward(self, x):
  357. if self._fast_norm:
  358. x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
  359. else:
  360. x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
  361. x = self.drop(x)
  362. x = self.act(x)
  363. return x
  364. class GroupNorm1Act(nn.GroupNorm):
  365. _fast_norm: torch.jit.Final[bool]
  366. def __init__(
  367. self,
  368. num_channels: int,
  369. eps: float = 1e-5,
  370. affine: bool = True,
  371. apply_act: bool = True,
  372. act_layer: LayerType = nn.ReLU,
  373. act_kwargs: Dict[str, Any] = None,
  374. inplace: bool = True,
  375. drop_layer: Optional[Type[nn.Module]] = None,
  376. device=None,
  377. dtype=None,
  378. ):
  379. super().__init__(1, num_channels, eps=eps, affine=affine, device=device, dtype=dtype)
  380. self.drop = drop_layer() if drop_layer is not None else nn.Identity()
  381. self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
  382. self._fast_norm = is_fast_norm()
  383. def forward(self, x):
  384. if self._fast_norm:
  385. x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
  386. else:
  387. x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
  388. x = self.drop(x)
  389. x = self.act(x)
  390. return x
  391. class LayerNormAct(nn.LayerNorm):
  392. _fast_norm: torch.jit.Final[bool]
  393. def __init__(
  394. self,
  395. normalization_shape: Union[int, List[int], torch.Size],
  396. eps: float = 1e-5,
  397. affine: bool = True,
  398. apply_act: bool = True,
  399. act_layer: LayerType = nn.ReLU,
  400. act_kwargs: Dict[str, Any] = None,
  401. inplace: bool = True,
  402. drop_layer: Optional[Type[nn.Module]] = None,
  403. **kwargs,
  404. ):
  405. super().__init__(normalization_shape, eps=eps, elementwise_affine=affine, **kwargs)
  406. self.drop = drop_layer() if drop_layer is not None else nn.Identity()
  407. self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
  408. self._fast_norm = is_fast_norm()
  409. def forward(self, x):
  410. if self._fast_norm:
  411. x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
  412. else:
  413. x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
  414. x = self.drop(x)
  415. x = self.act(x)
  416. return x
  417. class LayerNormActFp32(nn.LayerNorm):
  418. def __init__(
  419. self,
  420. normalization_shape: Union[int, List[int], torch.Size],
  421. eps: float = 1e-5,
  422. affine: bool = True,
  423. apply_act: bool = True,
  424. act_layer: LayerType = nn.ReLU,
  425. act_kwargs: Dict[str, Any] = None,
  426. inplace: bool = True,
  427. drop_layer: Optional[Type[nn.Module]] = None,
  428. **kwargs,
  429. ):
  430. super().__init__(normalization_shape, eps=eps, elementwise_affine=affine, **kwargs)
  431. self.drop = drop_layer() if drop_layer is not None else nn.Identity()
  432. self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
  433. def forward(self, x):
  434. weight = self.weight.float() if self.weight is not None else None
  435. bias = self.bias.float() if self.bias is not None else None
  436. x = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).to(x.dtype)
  437. x = self.drop(x)
  438. x = self.act(x)
  439. return x
  440. class LayerNormAct2d(nn.LayerNorm):
  441. _fast_norm: torch.jit.Final[bool]
  442. def __init__(
  443. self,
  444. num_channels: int,
  445. eps: float = 1e-5,
  446. affine: bool = True,
  447. apply_act: bool = True,
  448. act_layer: LayerType = nn.ReLU,
  449. act_kwargs: Dict[str, Any] = None,
  450. inplace: bool = True,
  451. drop_layer: Optional[Type[nn.Module]] = None,
  452. **kwargs,
  453. ):
  454. super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs)
  455. self.drop = drop_layer() if drop_layer is not None else nn.Identity()
  456. self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
  457. self._fast_norm = is_fast_norm()
  458. def forward(self, x):
  459. x = x.permute(0, 2, 3, 1)
  460. if self._fast_norm:
  461. x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
  462. else:
  463. x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
  464. x = x.permute(0, 3, 1, 2)
  465. x = self.drop(x)
  466. x = self.act(x)
  467. return x
  468. class LayerNormAct2dFp32(nn.LayerNorm):
  469. def __init__(
  470. self,
  471. num_channels: int,
  472. eps: float = 1e-5,
  473. affine: bool = True,
  474. apply_act: bool = True,
  475. act_layer: LayerType = nn.ReLU,
  476. act_kwargs: Dict[str, Any] = None,
  477. inplace: bool = True,
  478. drop_layer: Optional[Type[nn.Module]] = None,
  479. **kwargs,
  480. ):
  481. super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs)
  482. self.drop = drop_layer() if drop_layer is not None else nn.Identity()
  483. self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
  484. def forward(self, x):
  485. x = x.permute(0, 2, 3, 1)
  486. weight = self.weight.float() if self.weight is not None else None
  487. bias = self.bias.float() if self.bias is not None else None
  488. x = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).to(x.dtype)
  489. x = x.permute(0, 3, 1, 2)
  490. x = self.drop(x)
  491. x = self.act(x)
  492. return x
  493. class RmsNormAct(RmsNorm):
  494. """ RMSNorm + Activation for '2D' NCHW tensors
  495. NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction
  496. on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something
  497. like https://github.com/pytorch/pytorch/pull/150576 lands.
  498. """
  499. def __init__(
  500. self,
  501. num_channels: int,
  502. eps: float = 1e-6,
  503. affine: bool = True,
  504. apply_act: bool = True,
  505. act_layer: LayerType = nn.ReLU,
  506. act_kwargs: Dict[str, Any] = None,
  507. inplace: bool = True,
  508. drop_layer: Optional[Type[nn.Module]] = None,
  509. **kwargs,
  510. ):
  511. super().__init__(channels=num_channels, eps=eps, affine=affine, **kwargs)
  512. self.drop = drop_layer() if drop_layer is not None else nn.Identity()
  513. self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
  514. self._fast_norm = is_fast_norm()
  515. def forward(self, x: torch.Tensor) -> torch.Tensor:
  516. if self._fast_norm:
  517. x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
  518. else:
  519. x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
  520. x = self.drop(x)
  521. x = self.act(x)
  522. return x
  523. class RmsNormActFp32(RmsNorm):
  524. """ RMSNorm + Activation for '2D' NCHW tensors
  525. NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction
  526. on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something
  527. like https://github.com/pytorch/pytorch/pull/150576 lands.
  528. """
  529. def __init__(
  530. self,
  531. num_channels: int,
  532. eps: float = 1e-6,
  533. affine: bool = True,
  534. apply_act: bool = True,
  535. act_layer: LayerType = nn.ReLU,
  536. act_kwargs: Dict[str, Any] = None,
  537. inplace: bool = True,
  538. drop_layer: Optional[Type[nn.Module]] = None,
  539. **kwargs,
  540. ):
  541. super().__init__(channels=num_channels, eps=eps, affine=affine, **kwargs)
  542. self.drop = drop_layer() if drop_layer is not None else nn.Identity()
  543. self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
  544. def forward(self, x: torch.Tensor) -> torch.Tensor:
  545. weight = self.weight.float() if self.weight is not None else None
  546. x = rms_norm(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
  547. x = self.drop(x)
  548. x = self.act(x)
  549. return x
  550. class RmsNormAct2d(RmsNorm2d):
  551. """ RMSNorm + Activation for '2D' NCHW tensors
  552. NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction
  553. on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something
  554. like https://github.com/pytorch/pytorch/pull/150576 lands.
  555. """
  556. def __init__(
  557. self,
  558. num_channels: int,
  559. eps: float = 1e-6,
  560. affine: bool = True,
  561. apply_act: bool = True,
  562. act_layer: LayerType = nn.ReLU,
  563. act_kwargs: Dict[str, Any] = None,
  564. inplace: bool = True,
  565. drop_layer: Optional[Type[nn.Module]] = None,
  566. **kwargs,
  567. ):
  568. super().__init__(channels=num_channels, eps=eps, affine=affine, **kwargs)
  569. self.drop = drop_layer() if drop_layer is not None else nn.Identity()
  570. self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
  571. self._fast_norm = is_fast_norm()
  572. def forward(self, x: torch.Tensor) -> torch.Tensor:
  573. if self._fast_norm:
  574. x = fast_rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
  575. else:
  576. x = rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
  577. x = self.drop(x)
  578. x = self.act(x)
  579. return x
  580. class RmsNormAct2dFp32(RmsNorm2d):
  581. """ RMSNorm + Activation for '2D' NCHW tensors
  582. NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction
  583. on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something
  584. like https://github.com/pytorch/pytorch/pull/150576 lands.
  585. """
  586. def __init__(
  587. self,
  588. num_channels: int,
  589. eps: float = 1e-6,
  590. affine: bool = True,
  591. apply_act: bool = True,
  592. act_layer: LayerType = nn.ReLU,
  593. act_kwargs: Dict[str, Any] = None,
  594. inplace: bool = True,
  595. drop_layer: Optional[Type[nn.Module]] = None,
  596. **kwargs,
  597. ):
  598. super().__init__(channels=num_channels, eps=eps, affine=affine, **kwargs)
  599. self.drop = drop_layer() if drop_layer is not None else nn.Identity()
  600. self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
  601. def forward(self, x: torch.Tensor) -> torch.Tensor:
  602. weight = self.weight.float() if self.weight is not None else None
  603. x = rms_norm2d(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
  604. x = self.drop(x)
  605. x = self.act(x)
  606. return x