| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690 |
- """ Normalization + Activation Layers
- Provides Norm+Act fns for standard PyTorch norm layers such as
- * BatchNorm
- * GroupNorm
- * LayerNorm
- This allows swapping with alternative layers that are natively both norm + act such as
- * EvoNorm (evo_norm.py)
- * FilterResponseNorm (filter_response_norm.py)
- * InplaceABN (inplace_abn.py)
- Hacked together by / Copyright 2022 Ross Wightman
- """
- from typing import Any, Dict, List, Optional, Type, Union
- import torch
- from torch import nn as nn
- from torch.nn import functional as F
- from torchvision.ops.misc import FrozenBatchNorm2d
- from ._fx import register_notrace_module
- from .create_act import create_act_layer
- from .fast_norm import (
- is_fast_norm,
- fast_group_norm,
- fast_layer_norm,
- fast_rms_norm,
- rms_norm2d,
- fast_rms_norm2d,
- )
- from .norm import RmsNorm, RmsNorm2d
- from .trace_utils import _assert
- from .typing import LayerType
- try:
- from torch.nn.functional import rms_norm
- except ImportError:
- from .fast_norm import rms_norm
def _create_act(
        act_layer: LayerType,
        act_kwargs: Optional[Dict[str, Any]] = None,
        inplace: Optional[bool] = False,
        apply_act: bool = True,
) -> nn.Module:
    """Build the activation module used by the norm + act layers in this file.

    Args:
        act_layer: Activation spec, forwarded unchanged to ``create_act_layer``.
        act_kwargs: Extra keyword arguments for the activation. The mapping is copied
            and never modified, so one dict can safely be shared between layers.
        inplace: Default for the ``inplace`` keyword; an explicit ``'inplace'`` entry
            in ``act_kwargs`` takes precedence.
        apply_act: When False, no activation is created.

    Returns:
        The module returned by ``create_act_layer``, or ``nn.Identity()`` when
        ``apply_act`` is False or ``create_act_layer`` returns ``None``.
    """
    if not apply_act:
        return nn.Identity()

    # Work on a copy: writing the 'inplace' default into the caller's dict would make
    # the first layer's setting stick for every later layer built from the same dict.
    act_kwargs = dict(act_kwargs) if act_kwargs else {}
    act_kwargs.setdefault('inplace', inplace)
    act = create_act_layer(act_layer, **act_kwargs)
    return nn.Identity() if act is None else act
@register_notrace_module
class BatchNormAct2d(nn.BatchNorm2d):
    """BatchNorm2d followed by an optional drop layer and an activation.

    The class derives from ``nn.BatchNorm2d`` instead of wrapping one in a ``.bn``
    member so that it stays backwards compatible with weights trained using separate
    bn and act layers.
    """

    def __init__(
            self,
            num_features: int,
            eps: float = 1e-5,
            momentum: float = 0.1,
            affine: bool = True,
            track_running_stats: bool = True,
            apply_act: bool = True,
            act_layer: LayerType = nn.ReLU,
            act_kwargs: Dict[str, Any] = None,
            inplace: bool = True,
            drop_layer: Optional[Type[nn.Module]] = None,
            device=None,
            dtype=None,
    ):
        bn_kwargs = dict(
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )
        try:
            super().__init__(num_features, device=device, dtype=dtype, **bn_kwargs)
        except TypeError:
            # NOTE for backwards compat with old PyTorch w/o factory device/dtype support
            super().__init__(num_features, **bn_kwargs)

        self.drop = nn.Identity() if drop_layer is None else drop_layer()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

    def forward(self, x):
        """Apply batch norm, then the drop layer, then the activation."""
        # cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing
        _assert(x.ndim == 4, f'expected 4D input (got {x.ndim}D input)')

        # The factor starts out as self.momentum (when it is available) only so that it
        # gets updated in the ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # The None check is only here to tell the jit to skip emitting this when it is None.
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked.add_(1)  # type: ignore[has-type]
                if self.momentum is None:
                    # cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:
                    # exponential moving average
                    exponential_average_factor = self.momentum

        # Mini-batch stats are used for normalization in training mode, and in eval mode
        # when both running-stat buffers are None.
        use_batch_stats = self.training or (self.running_mean is None and self.running_var is None)

        # Buffers are handed to batch_norm only when they should be updated (training with
        # tracking enabled) or when they are the normalization source (eval mode); otherwise
        # None is passed so they cannot be updated.
        pass_buffers = (not self.training) or self.track_running_stats
        x = F.batch_norm(
            x,
            self.running_mean if pass_buffers else None,
            self.running_var if pass_buffers else None,
            self.weight,
            self.bias,
            use_batch_stats,
            exponential_average_factor,
            self.eps,
        )
        return self.act(self.drop(x))
@register_notrace_module
class SyncBatchNormAct(nn.SyncBatchNorm):
    """SyncBatchNorm that additionally applies ``drop`` and ``act`` when they are attached.

    Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
    This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
    but ONLY when used in conjunction with the timm conversion function in this file.
    Do not create this module directly or use the PyTorch conversion function.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SyncBN doesn't work with torchscript anyways, so calling super().forward is fine
        x = super().forward(x)
        # drop / act are attached from outside after construction, so each is applied
        # only if it is actually present on this instance.
        for extra in ("drop", "act"):
            if hasattr(self, extra):
                x = getattr(self, extra)(x)
        return x
def convert_sync_batchnorm(module, process_group=None):
    """Recursively replace BatchNorm layers with their synchronized counterparts.

    ``BatchNormAct2d`` instances become ``SyncBatchNormAct`` (re-using the original
    ``act`` and ``drop`` submodules); every other ``_BatchNorm`` instance becomes a
    plain ``torch.nn.SyncBatchNorm``. Affine parameters, running statistics, the
    training flag and an optional ``qconfig`` are carried over from the source layer.

    Args:
        module: Module to convert; its children are visited recursively.
        process_group: Passed through to the created sync batch norm layers.

    Returns:
        The replacement layer when ``module`` is a batch norm, otherwise ``module``
        itself with its children re-registered after conversion.
    """
    if not isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        converted = module
    else:
        norm_args = (
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if isinstance(module, BatchNormAct2d):
            # timm norm + act layer: keep the activation and drop modules of the original
            converted = SyncBatchNormAct(*norm_args, process_group=process_group)
            converted.act = module.act
            converted.drop = module.drop
        else:
            # standard BatchNorm layer
            converted = torch.nn.SyncBatchNorm(*norm_args, process_group)

        if module.affine:
            with torch.no_grad():
                converted.weight = module.weight
                converted.bias = module.bias
        converted.running_mean = module.running_mean
        converted.running_var = module.running_var
        converted.num_batches_tracked = module.num_batches_tracked
        converted.training = module.training
        if hasattr(module, "qconfig"):
            converted.qconfig = module.qconfig

    for child_name, child in module.named_children():
        converted.add_module(child_name, convert_sync_batchnorm(child, process_group))
    return converted
@register_notrace_module
class FrozenBatchNormAct2d(torch.nn.Module):
    """
    BatchNormAct2d where the batch statistics and the affine parameters are fixed

    ``weight``, ``bias``, ``running_mean`` and ``running_var`` are registered as buffers,
    not parameters.

    Args:
        num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
        eps (float): a value added to the denominator for numerical stability. Default: 1e-5
    """

    def __init__(
            self,
            num_features: int,
            eps: float = 1e-5,
            apply_act: bool = True,
            act_layer: LayerType = nn.ReLU,
            act_kwargs: Dict[str, Any] = None,
            inplace: bool = True,
            drop_layer: Optional[Type[nn.Module]] = None,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory = {'device': device, 'dtype': dtype}
        self.eps = eps
        # Identity affine transform and unit statistics until real values are loaded.
        self.register_buffer("weight", torch.ones(num_features, **factory))
        self.register_buffer("bias", torch.zeros(num_features, **factory))
        self.register_buffer("running_mean", torch.zeros(num_features, **factory))
        self.register_buffer("running_var", torch.ones(num_features, **factory))

        self.drop = nn.Identity() if drop_layer is None else drop_layer()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

    def _load_from_state_dict(
            self,
            state_dict: dict,
            prefix: str,
            local_metadata: dict,
            strict: bool,
            missing_keys: List[str],
            unexpected_keys: List[str],
            error_msgs: List[str],
    ):
        # This module has no ``num_batches_tracked`` buffer, so that entry is removed from
        # the incoming state dict before delegating to the default loader.
        state_dict.pop(prefix + "num_batches_tracked", None)
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # move reshapes to the beginning
        # to make it fuser-friendly
        weight = self.weight.reshape(1, -1, 1, 1)
        offset = self.bias.reshape(1, -1, 1, 1)
        var = self.running_var.reshape(1, -1, 1, 1)
        mean = self.running_mean.reshape(1, -1, 1, 1)
        # Fold the normalization into a single per-channel scale and shift.
        scale = weight * (var + self.eps).rsqrt()
        shift = offset - mean * scale
        x = x * scale + shift
        x = self.drop(x)
        return self.act(x)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps}, act={self.act})"
def freeze_batch_norm_2d(module):
    """
    Converts all `BatchNorm2d` and `SyncBatchNorm` or `BatchNormAct2d` and `SyncBatchNormAct` layers
    of provided module into `FrozenBatchNorm2d` or `FrozenBatchNormAct2d` respectively.

    If `module` is itself one of those layers, the frozen replacement is returned. Otherwise the
    module is walked recursively, converted children are re-registered on it, and `module` is returned.

    Args:
        module (torch.nn.Module): Any PyTorch module.

    Returns:
        torch.nn.Module: Resulting module

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    is_norm_act = isinstance(module, (BatchNormAct2d, SyncBatchNormAct))
    is_plain_bn = not is_norm_act and isinstance(
        module,
        (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm),
    )

    if not (is_norm_act or is_plain_bn):
        # Not a batch norm layer: convert children and swap in the ones that changed.
        for child_name, child in module.named_children():
            frozen_child = freeze_batch_norm_2d(child)
            if frozen_child is not child:
                module.add_module(child_name, frozen_child)
        return module

    frozen_cls = FrozenBatchNormAct2d if is_norm_act else FrozenBatchNorm2d
    frozen = frozen_cls(module.num_features)
    # Recorded on the frozen layer so the original configuration can be read back later.
    frozen.num_features = module.num_features
    frozen.affine = module.affine
    if module.affine:
        frozen.weight.data = module.weight.data.clone().detach()
        frozen.bias.data = module.bias.data.clone().detach()
    frozen.running_mean.data = module.running_mean.data
    frozen.running_var.data = module.running_var.data
    frozen.eps = module.eps
    if is_norm_act:
        # norm + act layers keep the original drop and activation modules
        frozen.drop = module.drop
        frozen.act = module.act
    return frozen
def unfreeze_batch_norm_2d(module):
    """
    Converts all `FrozenBatchNorm2d` / `FrozenBatchNormAct2d` layers of provided module into `BatchNorm2d` /
    `BatchNormAct2d` respectively. If `module` is itself an instance of one of the frozen layers, it is
    converted and the new layer is returned. Otherwise, the module is walked recursively and submodules
    are converted in place.

    Frozen layers that were not produced by `freeze_batch_norm_2d` carry no `num_features` / `affine`
    attributes; for those the feature count is taken from the `weight` buffer and the layer is treated
    as affine, so its stored weight and bias are copied over.

    Args:
        module (torch.nn.Module): Any PyTorch module.

    Returns:
        torch.nn.Module: Resulting module

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    res = module
    is_norm_act = isinstance(module, FrozenBatchNormAct2d)
    if is_norm_act or isinstance(module, FrozenBatchNorm2d):
        # `num_features` / `affine` are only attached by `freeze_batch_norm_2d`; fall back to what
        # the frozen layer itself provides instead of raising AttributeError.
        num_features = getattr(module, 'num_features', None)
        if num_features is None:
            num_features = module.weight.shape[0]
        affine = getattr(module, 'affine', True)

        res = BatchNormAct2d(num_features) if is_norm_act else torch.nn.BatchNorm2d(num_features)
        if affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
        if is_norm_act:
            # keep the drop and activation modules of the frozen layer
            res.drop = module.drop
            res.act = module.act
    else:
        for name, child in module.named_children():
            new_child = unfreeze_batch_norm_2d(child)
            if new_child is not child:
                res.add_module(name, new_child)
    return res
def _num_groups(num_channels: int, num_groups: int, group_size: int):
    """Return the group count to use for a group norm layer.

    When ``group_size`` is truthy it takes priority: the result is
    ``num_channels // group_size`` and ``num_channels`` must be divisible by it.
    Otherwise ``num_groups`` is returned unchanged.
    """
    if not group_size:
        return num_groups
    group_count, leftover = divmod(num_channels, group_size)
    assert leftover == 0
    return group_count
class GroupNormAct(nn.GroupNorm):
    """GroupNorm followed by an optional drop layer and an activation.

    When ``group_size`` is given, the number of groups is derived from it via
    ``_num_groups`` and ``num_groups`` is ignored.
    """
    _fast_norm: torch.jit.Final[bool]

    # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
    def __init__(
            self,
            num_channels: int,
            num_groups: int = 32,
            eps: float = 1e-5,
            affine: bool = True,
            group_size: Optional[int] = None,
            apply_act: bool = True,
            act_layer: LayerType = nn.ReLU,
            act_kwargs: Dict[str, Any] = None,
            inplace: bool = True,
            drop_layer: Optional[Type[nn.Module]] = None,
            device=None,
            dtype=None,
    ):
        groups = _num_groups(num_channels, num_groups, group_size)
        super().__init__(groups, num_channels, eps=eps, affine=affine, device=device, dtype=dtype)

        self.drop = nn.Identity() if drop_layer is None else drop_layer()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
        # Captured once at construction time; forward branches on this flag.
        self._fast_norm = is_fast_norm()

    def forward(self, x):
        """Apply group norm (fast or standard path), then drop, then the activation."""
        if not self._fast_norm:
            x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        return self.act(self.drop(x))
class GroupNorm1Act(nn.GroupNorm):
    """GroupNorm with a single group, followed by an optional drop layer and an activation."""
    _fast_norm: torch.jit.Final[bool]

    def __init__(
            self,
            num_channels: int,
            eps: float = 1e-5,
            affine: bool = True,
            apply_act: bool = True,
            act_layer: LayerType = nn.ReLU,
            act_kwargs: Dict[str, Any] = None,
            inplace: bool = True,
            drop_layer: Optional[Type[nn.Module]] = None,
            device=None,
            dtype=None,
    ):
        # The group count is fixed to 1 for this variant.
        super().__init__(
            1,
            num_channels,
            eps=eps,
            affine=affine,
            device=device,
            dtype=dtype,
        )
        self.drop = nn.Identity() if drop_layer is None else drop_layer()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
        # Captured once at construction time; forward branches on this flag.
        self._fast_norm = is_fast_norm()

    def forward(self, x):
        """Apply group norm (fast or standard path), then drop, then the activation."""
        if not self._fast_norm:
            x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        return self.act(self.drop(x))
class LayerNormAct(nn.LayerNorm):
    """LayerNorm followed by an optional drop layer and an activation.

    ``affine`` is forwarded to ``nn.LayerNorm`` as ``elementwise_affine``; any extra
    keyword arguments are passed to the ``nn.LayerNorm`` constructor unchanged.
    """
    _fast_norm: torch.jit.Final[bool]

    def __init__(
            self,
            normalization_shape: Union[int, List[int], torch.Size],
            eps: float = 1e-5,
            affine: bool = True,
            apply_act: bool = True,
            act_layer: LayerType = nn.ReLU,
            act_kwargs: Dict[str, Any] = None,
            inplace: bool = True,
            drop_layer: Optional[Type[nn.Module]] = None,
            **kwargs,
    ):
        super().__init__(
            normalization_shape,
            eps=eps,
            elementwise_affine=affine,
            **kwargs,
        )
        self.drop = nn.Identity() if drop_layer is None else drop_layer()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
        # Captured once at construction time; forward branches on this flag.
        self._fast_norm = is_fast_norm()

    def forward(self, x):
        """Apply layer norm (fast or standard path), then drop, then the activation."""
        if not self._fast_norm:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return self.act(self.drop(x))
class LayerNormActFp32(nn.LayerNorm):
    """LayerNorm computed in float32, followed by an optional drop layer and an activation.

    The input and the affine parameters are cast to float for the normalization and the
    result is cast back to the input dtype before drop / activation are applied.
    """

    def __init__(
            self,
            normalization_shape: Union[int, List[int], torch.Size],
            eps: float = 1e-5,
            affine: bool = True,
            apply_act: bool = True,
            act_layer: LayerType = nn.ReLU,
            act_kwargs: Dict[str, Any] = None,
            inplace: bool = True,
            drop_layer: Optional[Type[nn.Module]] = None,
            **kwargs,
    ):
        super().__init__(
            normalization_shape,
            eps=eps,
            elementwise_affine=affine,
            **kwargs,
        )
        self.drop = nn.Identity() if drop_layer is None else drop_layer()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

    def forward(self, x):
        """Normalize in float32, restore the input dtype, then apply drop and the activation."""
        in_dtype = x.dtype
        # weight / bias may be None (no affine), in which case None is passed through
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        x = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps)
        x = x.to(in_dtype)
        return self.act(self.drop(x))
class LayerNormAct2d(nn.LayerNorm):
    """LayerNorm over dim 1 of a 4D input, followed by an optional drop layer and an activation.

    The input is permuted so that dim 1 becomes the last dim, normalized over
    ``num_channels``, and permuted back to its original layout.
    """
    _fast_norm: torch.jit.Final[bool]

    def __init__(
            self,
            num_channels: int,
            eps: float = 1e-5,
            affine: bool = True,
            apply_act: bool = True,
            act_layer: LayerType = nn.ReLU,
            act_kwargs: Dict[str, Any] = None,
            inplace: bool = True,
            drop_layer: Optional[Type[nn.Module]] = None,
            **kwargs,
    ):
        super().__init__(
            num_channels,
            eps=eps,
            elementwise_affine=affine,
            **kwargs,
        )
        self.drop = nn.Identity() if drop_layer is None else drop_layer()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
        # Captured once at construction time; forward branches on this flag.
        self._fast_norm = is_fast_norm()

    def forward(self, x):
        """Permute dim 1 last, layer-normalize, permute back, then apply drop and the activation."""
        x = x.permute(0, 2, 3, 1)
        if not self._fast_norm:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)
        return self.act(self.drop(x))
class LayerNormAct2dFp32(nn.LayerNorm):
    """Float32 LayerNorm over dim 1 of a 4D input, then an optional drop layer and an activation.

    The input is permuted so that dim 1 becomes the last dim, normalized in float32
    (input and affine parameters are cast to float), cast back to the input dtype and
    permuted back to its original layout.
    """

    def __init__(
            self,
            num_channels: int,
            eps: float = 1e-5,
            affine: bool = True,
            apply_act: bool = True,
            act_layer: LayerType = nn.ReLU,
            act_kwargs: Dict[str, Any] = None,
            inplace: bool = True,
            drop_layer: Optional[Type[nn.Module]] = None,
            **kwargs,
    ):
        super().__init__(
            num_channels,
            eps=eps,
            elementwise_affine=affine,
            **kwargs,
        )
        self.drop = nn.Identity() if drop_layer is None else drop_layer()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

    def forward(self, x):
        """Permute dim 1 last, normalize in float32, permute back, then apply drop and the activation."""
        x = x.permute(0, 2, 3, 1)
        in_dtype = x.dtype
        # weight / bias may be None (no affine), in which case None is passed through
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        x = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps)
        x = x.to(in_dtype)
        x = x.permute(0, 3, 1, 2)
        return self.act(self.drop(x))
class RmsNormAct(RmsNorm):
    """ RMSNorm + Activation

    Normalizes with ``rms_norm`` (or ``fast_rms_norm`` when fast norm is enabled) over
    ``self.normalized_shape``, then applies an optional drop layer and the activation.
    Extra keyword arguments are forwarded to the ``RmsNorm`` constructor.
    """

    def __init__(
            self,
            num_channels: int,
            eps: float = 1e-6,
            affine: bool = True,
            apply_act: bool = True,
            act_layer: LayerType = nn.ReLU,
            act_kwargs: Dict[str, Any] = None,
            inplace: bool = True,
            drop_layer: Optional[Type[nn.Module]] = None,
            **kwargs,
    ):
        super().__init__(
            channels=num_channels,
            eps=eps,
            affine=affine,
            **kwargs,
        )
        self.drop = nn.Identity() if drop_layer is None else drop_layer()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
        # Captured once at construction time; forward branches on this flag.
        self._fast_norm = is_fast_norm()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self._fast_norm:
            x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
        else:
            x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
        return self.act(self.drop(x))
class RmsNormActFp32(RmsNorm):
    """ RMSNorm + Activation, with the normalization computed in float32

    The input and the weight (when present) are cast to float for ``rms_norm`` over
    ``self.normalized_shape``; the result is cast back to the input dtype before the
    optional drop layer and the activation are applied.
    Extra keyword arguments are forwarded to the ``RmsNorm`` constructor.
    """

    def __init__(
            self,
            num_channels: int,
            eps: float = 1e-6,
            affine: bool = True,
            apply_act: bool = True,
            act_layer: LayerType = nn.ReLU,
            act_kwargs: Dict[str, Any] = None,
            inplace: bool = True,
            drop_layer: Optional[Type[nn.Module]] = None,
            **kwargs,
    ):
        super().__init__(
            channels=num_channels,
            eps=eps,
            affine=affine,
            **kwargs,
        )
        self.drop = nn.Identity() if drop_layer is None else drop_layer()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        # weight may be None (no affine), in which case None is passed through
        weight = None if self.weight is None else self.weight.float()
        x = rms_norm(x.float(), self.normalized_shape, weight, self.eps)
        x = x.to(in_dtype)
        return self.act(self.drop(x))
class RmsNormAct2d(RmsNorm2d):
    """ RMSNorm + Activation for '2D' NCHW tensors

    Normalizes with ``rms_norm2d`` (or ``fast_rms_norm2d`` when fast norm is enabled),
    then applies an optional drop layer and the activation.

    NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction
    on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something
    like https://github.com/pytorch/pytorch/pull/150576 lands.
    """

    def __init__(
            self,
            num_channels: int,
            eps: float = 1e-6,
            affine: bool = True,
            apply_act: bool = True,
            act_layer: LayerType = nn.ReLU,
            act_kwargs: Dict[str, Any] = None,
            inplace: bool = True,
            drop_layer: Optional[Type[nn.Module]] = None,
            **kwargs,
    ):
        super().__init__(
            channels=num_channels,
            eps=eps,
            affine=affine,
            **kwargs,
        )
        self.drop = nn.Identity() if drop_layer is None else drop_layer()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
        # Captured once at construction time; forward branches on this flag.
        self._fast_norm = is_fast_norm()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self._fast_norm:
            x = rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
        else:
            x = fast_rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
        return self.act(self.drop(x))
class RmsNormAct2dFp32(RmsNorm2d):
    """ RMSNorm + Activation for '2D' NCHW tensors, with the normalization computed in float32

    The input and the weight (when present) are cast to float for ``rms_norm2d``; the
    result is cast back to the input dtype before the optional drop layer and the
    activation are applied.

    NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction
    on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something
    like https://github.com/pytorch/pytorch/pull/150576 lands.
    """

    def __init__(
            self,
            num_channels: int,
            eps: float = 1e-6,
            affine: bool = True,
            apply_act: bool = True,
            act_layer: LayerType = nn.ReLU,
            act_kwargs: Dict[str, Any] = None,
            inplace: bool = True,
            drop_layer: Optional[Type[nn.Module]] = None,
            **kwargs,
    ):
        super().__init__(
            channels=num_channels,
            eps=eps,
            affine=affine,
            **kwargs,
        )
        self.drop = nn.Identity() if drop_layer is None else drop_layer()
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        # weight may be None (no affine), in which case None is passed through
        weight = None if self.weight is None else self.weight.float()
        x = rms_norm2d(x.float(), self.normalized_shape, weight, self.eps)
        x = x.to(in_dtype)
        return self.act(self.drop(x))
|