| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575 |
- """ Normalization layers and wrappers
- Norm layer definitions that support fast norm and consistent channel arg order (always first arg).
- Hacked together by / Copyright 2022 Ross Wightman
- """
- import numbers
- from typing import Tuple
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from .fast_norm import (
- is_fast_norm,
- fast_group_norm,
- fast_layer_norm,
- fast_rms_norm,
- rms_norm2d,
- fast_rms_norm2d,
- fast_simple_norm,
- simple_norm,
- )
- try:
- from torch.nn.functional import rms_norm
- except ImportError:
- from .fast_norm import rms_norm
class GroupNorm(nn.GroupNorm):
    """ GroupNorm w/ channel count as the first constructor argument and an optional fast norm path.

    `nn.GroupNorm` takes `(num_groups, num_channels)`; this wrapper takes `num_channels` first so it
    can be swapped with other norm layers (e.g. BatchNorm) that are constructed from a channel count.
    """
    _fast_norm: torch.jit.Final[bool]

    def __init__(
            self,
            num_channels: int,
            num_groups: int = 32,
            eps: float = 1e-5,
            affine: bool = True,
            **kwargs,
    ):
        """
        Args:
            num_channels: Number of channels expected in the input.
            num_groups: Number of groups the channels are split into.
            eps: Value forwarded to the norm function for numerical stability.
            affine: Forwarded to `nn.GroupNorm` as `affine`.
            **kwargs: Extra keyword arguments forwarded to `nn.GroupNorm`.
        """
        # Argument order is swapped relative to nn.GroupNorm, so pass by keyword.
        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=affine, **kwargs)
        # Captured once as a Final attribute: scripting can't consult a module-level global flag.
        self._fast_norm = is_fast_norm()

    def forward(self, x):
        """ Apply group norm to `x`, via `fast_group_norm` when the fast norm flag was set. """
        if not self._fast_norm:
            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
class GroupNorm1(nn.GroupNorm):
    """ Group Normalization fixed to a single group.

    Input: tensor in shape [B, C, *]
    """
    _fast_norm: torch.jit.Final[bool]

    def __init__(self, num_channels: int, **kwargs):
        """
        Args:
            num_channels: Number of channels expected in the input.
            **kwargs: Extra keyword arguments forwarded to `nn.GroupNorm`.
        """
        # The group count is pinned to 1; only the channel count is configurable.
        super().__init__(1, num_channels, **kwargs)
        # Captured once as a Final attribute: scripting can't consult a module-level global flag.
        self._fast_norm = is_fast_norm()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Apply group norm to `x`, via `fast_group_norm` when the fast norm flag was set. """
        if not self._fast_norm:
            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
class LayerNorm(nn.LayerNorm):
    """ LayerNorm w/ an optional fast norm path.
    """
    _fast_norm: torch.jit.Final[bool]

    def __init__(
            self,
            num_channels: int,
            eps: float = 1e-6,
            affine: bool = True,
            **kwargs,
    ):
        """
        Args:
            num_channels: Passed to `nn.LayerNorm` as the normalized shape.
            eps: Value forwarded to the norm function for numerical stability.
            affine: Forwarded to `nn.LayerNorm` as `elementwise_affine`.
            **kwargs: Extra keyword arguments forwarded to `nn.LayerNorm`.
        """
        super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs)
        # Captured once as a Final attribute: scripting can't consult a module-level global flag.
        self._fast_norm = is_fast_norm()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Normalize `x` over `self.normalized_shape`, using `fast_layer_norm` if the flag was set. """
        if not self._fast_norm:
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
class LayerNormFp32(nn.LayerNorm):
    """ LayerNorm that always computes in float32.

    The input and the affine parameters (when present) are cast to float32 for the norm,
    and the result is cast back to the input dtype.
    """

    def __init__(
            self,
            num_channels: int,
            eps: float = 1e-6,
            affine: bool = True,
            **kwargs,
    ):
        """
        Args:
            num_channels: Passed to `nn.LayerNorm` as the normalized shape.
            eps: Value forwarded to the norm function for numerical stability.
            affine: Forwarded to `nn.LayerNorm` as `elementwise_affine`.
            **kwargs: Extra keyword arguments forwarded to `nn.LayerNorm`.
        """
        super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Normalize `x` in float32 and return a tensor in the dtype of `x`. """
        # weight / bias are None when the layer was built without affine parameters.
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        out = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps)
        return out.to(x.dtype)
class LayerNorm2d(nn.LayerNorm):
    """ LayerNorm over the channel dim of '2D' spatial NCHW tensors """
    _fast_norm: torch.jit.Final[bool]

    def __init__(
            self,
            num_channels: int,
            eps: float = 1e-6,
            affine: bool = True,
            **kwargs,
    ):
        """
        Args:
            num_channels: Passed to `nn.LayerNorm` as the normalized shape.
            eps: Value forwarded to the norm function for numerical stability.
            affine: Forwarded to `nn.LayerNorm` as `elementwise_affine`.
            **kwargs: Extra keyword arguments forwarded to `nn.LayerNorm`.
        """
        super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs)
        # Captured once as a Final attribute: scripting can't consult a module-level global flag.
        self._fast_norm = is_fast_norm()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Normalize a 4-dim input over dim 1 and return it in the original dim order. """
        # Move dim 1 (channels) last so the layer norm reduces over it.
        x_nhwc = x.permute(0, 2, 3, 1)
        if self._fast_norm:
            y = fast_layer_norm(x_nhwc, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            y = F.layer_norm(x_nhwc, self.normalized_shape, self.weight, self.bias, self.eps)
        # Restore the channels-first dim order.
        return y.permute(0, 3, 1, 2)
class LayerNorm2dFp32(nn.LayerNorm):
    """ LayerNorm over the channel dim of '2D' spatial NCHW tensors, always computed in float32 """

    def __init__(
            self,
            num_channels: int,
            eps: float = 1e-6,
            affine: bool = True,
            **kwargs,
    ):
        """
        Args:
            num_channels: Passed to `nn.LayerNorm` as the normalized shape.
            eps: Value forwarded to the norm function for numerical stability.
            affine: Forwarded to `nn.LayerNorm` as `elementwise_affine`.
            **kwargs: Extra keyword arguments forwarded to `nn.LayerNorm`.
        """
        super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Normalize a 4-dim input over dim 1 in float32; output keeps the dtype and dim order of `x`. """
        in_dtype = x.dtype
        # weight / bias are None when the layer was built without affine parameters.
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        # Move dim 1 (channels) last and upcast before normalizing.
        x_nhwc = x.permute(0, 2, 3, 1).float()
        y = F.layer_norm(x_nhwc, self.normalized_shape, weight, bias, self.eps)
        # Cast back to the input dtype, then restore the channels-first dim order.
        return y.to(in_dtype).permute(0, 3, 1, 2)
- def _is_contiguous(tensor: torch.Tensor) -> bool:
- # jit is oh so lovely :/
- if torch.jit.is_scripting():
- return tensor.is_contiguous()
- else:
- return tensor.is_contiguous(memory_format=torch.contiguous_format)
- def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
- s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
- x = (x - u) * torch.rsqrt(s + eps)
- x = x * weight[:, None, None] + bias[:, None, None]
- return x
- def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
- u = x.mean(dim=1, keepdim=True)
- s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0)
- x = (x - u) * torch.rsqrt(s + eps)
- x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)
- return x
class LayerNormExp2d(nn.LayerNorm):
    """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).

    Experimental implementation w/ a manual norm for non-contiguous tensors.

    This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last
    layout. However, benefits are not always clear and can perform worse on other GPUs.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6):
        """
        Args:
            num_channels: Passed to `nn.LayerNorm` as the normalized shape.
            eps: Value forwarded to the norm function for numerical stability.
        """
        super().__init__(num_channels, eps=eps)

    def forward(self, x) -> torch.Tensor:
        """ Normalize `x` over dim 1, choosing the implementation from its memory layout. """
        if not _is_contiguous(x):
            # Non-contiguous input: manual norm reducing directly over dim 1.
            return _layer_norm_cf(x, self.weight, self.bias, self.eps)
        # Contiguous input: permute channels last, use the built-in layer norm, permute back.
        x_nhwc = x.permute(0, 2, 3, 1)
        y = F.layer_norm(x_nhwc, self.normalized_shape, self.weight, self.bias, self.eps)
        return y.permute(0, 3, 1, 2)
- class RmsNorm(nn.Module):
- """ RmsNorm w/ fast (apex) norm if available
- """
- __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
- normalized_shape: Tuple[int, ...]
- eps: float
- elementwise_affine: bool
- _fast_norm: bool
- def __init__(
- self,
- channels: int,
- eps: float = 1e-6,
- affine: bool = True,
- device=None,
- dtype=None,
- ) -> None:
- dd = {'device': device, 'dtype': dtype}
- super().__init__()
- normalized_shape = channels
- if isinstance(normalized_shape, numbers.Integral):
- # mypy error: incompatible types in assignment
- normalized_shape = (normalized_shape,) # type: ignore[assignment]
- self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
- self.eps = eps
- self.elementwise_affine = affine
- self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
- if self.elementwise_affine:
- self.weight = nn.Parameter(torch.empty(self.normalized_shape, **dd))
- else:
- self.register_parameter('weight', None)
- self.reset_parameters()
- def reset_parameters(self) -> None:
- if self.elementwise_affine:
- nn.init.ones_(self.weight)
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
- # Since there is no built-in PyTorch impl, always uses APEX RmsNorm if installed.
- if self._fast_norm:
- x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
- else:
- x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
- return x
- class RmsNormFp32(nn.Module):
- """ RmsNorm w/ fast (apex) norm if available
- """
- __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
- normalized_shape: Tuple[int, ...]
- eps: float
- elementwise_affine: bool
- def __init__(
- self,
- channels: int,
- eps: float = 1e-6,
- affine: bool = True,
- device=None,
- dtype=None,
- ) -> None:
- dd = {'device': device, 'dtype': dtype}
- super().__init__()
- normalized_shape = channels
- if isinstance(normalized_shape, numbers.Integral):
- # mypy error: incompatible types in assignment
- normalized_shape = (normalized_shape,) # type: ignore[assignment]
- self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
- self.eps = eps
- self.elementwise_affine = affine
- if self.elementwise_affine:
- self.weight = nn.Parameter(torch.empty(self.normalized_shape, **dd))
- else:
- self.register_parameter('weight', None)
- self.reset_parameters()
- def reset_parameters(self) -> None:
- if self.elementwise_affine:
- nn.init.ones_(self.weight)
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- weight = self.weight.float() if self.weight is not None else None
- x = rms_norm(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
- return x
class RmsNorm2d(nn.Module):
    """ RmsNorm2D for NCHW tensors, w/ an optional fast norm path.

    The fast path (`fast_rms_norm2d`) is used when `is_fast_norm()` is true at construction
    time, otherwise `rms_norm2d` is called.

    NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction
    on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something
    like https://github.com/pytorch/pytorch/pull/150576 lands.
    """
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool
    _fast_norm: bool

    def __init__(
            self,
            channels: int,
            eps: float = 1e-6,
            affine: bool = True,
            device=None,
            dtype=None,
    ) -> None:
        """
        Args:
            channels: Size of the normalized dim; a sequence of ints is also accepted as the shape.
            eps: Value forwarded to the norm function for numerical stability.
            affine: If True, a learnable `weight` of shape `normalized_shape` is created.
            device: Device for the weight parameter.
            dtype: Dtype for the weight parameter.
        """
        super().__init__()
        # A bare integer becomes a 1-tuple; anything else is treated as an iterable of dims.
        shape = (channels,) if isinstance(channels, numbers.Integral) else channels
        self.normalized_shape = tuple(shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = affine
        # Captured once as a constant attribute: scripting can't consult a module-level global flag.
        self._fast_norm = is_fast_norm()

        if affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape, device=device, dtype=dtype))
        else:
            self.register_parameter('weight', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """ Fill the weight with ones; a no-op when the layer has no affine weight. """
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Apply the 2d RMS norm to `x`. """
        # Dispatch on the flag captured at construction.
        if self._fast_norm:
            return fast_rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
        return rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
class RmsNorm2dFp32(nn.Module):
    """ RmsNorm2D for NCHW tensors that always computes in float32.

    The input and the weight (when present) are cast to float32 before calling `rms_norm2d`,
    and the result is cast back to the input dtype. No fast norm path is used.

    NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction
    on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something
    like https://github.com/pytorch/pytorch/pull/150576 lands.
    """
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
            self,
            channels: int,
            eps: float = 1e-6,
            affine: bool = True,
            device=None,
            dtype=None,
    ) -> None:
        """
        Args:
            channels: Size of the normalized dim; a sequence of ints is also accepted as the shape.
            eps: Value forwarded to the norm function for numerical stability.
            affine: If True, a learnable `weight` of shape `normalized_shape` is created.
            device: Device for the weight parameter.
            dtype: Dtype for the weight parameter.
        """
        super().__init__()
        # A bare integer becomes a 1-tuple; anything else is treated as an iterable of dims.
        shape = (channels,) if isinstance(channels, numbers.Integral) else channels
        self.normalized_shape = tuple(shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = affine

        if affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape, device=device, dtype=dtype))
        else:
            self.register_parameter('weight', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """ Fill the weight with ones; a no-op when the layer has no affine weight. """
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Apply the 2d RMS norm to `x` in float32 and return a tensor in the dtype of `x`. """
        # weight is None when the layer was built without an affine parameter.
        weight = None if self.weight is None else self.weight.float()
        out = rms_norm2d(x.float(), self.normalized_shape, weight, self.eps)
        return out.to(x.dtype)
class SimpleNorm(nn.Module):
    """ SimpleNorm (x / std(x)) w/ an optional fast norm path.

    The fast path (`fast_simple_norm`) is used when `is_fast_norm()` is true at construction
    time, otherwise `simple_norm` is called.
    """
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool
    _fast_norm: bool

    def __init__(
            self,
            channels: int,
            eps: float = 1e-6,
            affine: bool = True,
            device=None,
            dtype=None,
    ) -> None:
        """
        Args:
            channels: Size of the normalized dim; a sequence of ints is also accepted as the shape.
            eps: Value forwarded to the norm function for numerical stability.
            affine: If True, a learnable `weight` of shape `normalized_shape` is created.
            device: Device for the weight parameter.
            dtype: Dtype for the weight parameter.
        """
        super().__init__()
        # A bare integer becomes a 1-tuple; anything else is treated as an iterable of dims.
        shape = (channels,) if isinstance(channels, numbers.Integral) else channels
        self.normalized_shape = tuple(shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = affine
        # Captured once as a constant attribute: scripting can't consult a module-level global flag.
        self._fast_norm = is_fast_norm()

        if affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape, device=device, dtype=dtype))
        else:
            self.register_parameter('weight', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """ Fill the weight with ones; a no-op when the layer has no affine weight. """
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Apply simple norm to `x` over `self.normalized_shape`. """
        # Dispatch on the flag captured at construction.
        if self._fast_norm:
            return fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
        return simple_norm(x, self.normalized_shape, self.weight, self.eps)
class SimpleNormFp32(nn.Module):
    """ SimpleNorm (x / std(x)) that always computes in float32.

    The input and the weight (when present) are cast to float32 before calling `simple_norm`,
    and the result is cast back to the input dtype.
    """
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
            self,
            channels: int,
            eps: float = 1e-6,
            affine: bool = True,
            device=None,
            dtype=None,
    ) -> None:
        """
        Args:
            channels: Size of the normalized dim; a sequence of ints is also accepted as the shape.
            eps: Value forwarded to the norm function for numerical stability.
            affine: If True, a learnable `weight` of shape `normalized_shape` is created.
            device: Device for the weight parameter.
            dtype: Dtype for the weight parameter.
        """
        super().__init__()
        # A bare integer becomes a 1-tuple; anything else is treated as an iterable of dims.
        shape = (channels,) if isinstance(channels, numbers.Integral) else channels
        self.normalized_shape = tuple(shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = affine

        if affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape, device=device, dtype=dtype))
        else:
            self.register_parameter('weight', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """ Fill the weight with ones; a no-op when the layer has no affine weight. """
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Apply simple norm to `x` in float32 and return a tensor in the dtype of `x`. """
        # weight is None when the layer was built without an affine parameter.
        weight = None if self.weight is None else self.weight.float()
        out = simple_norm(x.float(), self.normalized_shape, weight, self.eps)
        return out.to(x.dtype)
class SimpleNorm2d(nn.Module):
    """ SimpleNorm for NCHW tensors, w/ an optional fast norm path.

    The input is permuted so dim 1 is last, normalized with `fast_simple_norm` or
    `simple_norm`, then permuted back.
    """
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool
    _fast_norm: bool

    def __init__(
            self,
            channels: int,
            eps: float = 1e-6,
            affine: bool = True,
            device=None,
            dtype=None,
    ) -> None:
        """
        Args:
            channels: Size of the normalized dim; a sequence of ints is also accepted as the shape.
            eps: Value forwarded to the norm function for numerical stability.
            affine: If True, a learnable `weight` of shape `normalized_shape` is created.
            device: Device for the weight parameter.
            dtype: Dtype for the weight parameter.
        """
        super().__init__()
        # A bare integer becomes a 1-tuple; anything else is treated as an iterable of dims.
        shape = (channels,) if isinstance(channels, numbers.Integral) else channels
        self.normalized_shape = tuple(shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = affine
        # Captured once as a constant attribute: scripting can't consult a module-level global flag.
        self._fast_norm = is_fast_norm()

        if affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape, device=device, dtype=dtype))
        else:
            self.register_parameter('weight', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """ Fill the weight with ones; a no-op when the layer has no affine weight. """
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Apply simple norm to a 4-dim input and return it in the original dim order. """
        # Move dim 1 (channels) last before normalizing.
        x_nhwc = x.permute(0, 2, 3, 1)
        if self._fast_norm:
            y = fast_simple_norm(x_nhwc, self.normalized_shape, self.weight, self.eps)
        else:
            y = simple_norm(x_nhwc, self.normalized_shape, self.weight, self.eps)
        # Restore the channels-first dim order.
        return y.permute(0, 3, 1, 2)
class SimpleNorm2dFp32(nn.Module):
    """ SimpleNorm for NCHW tensors that always computes in float32.

    The input is permuted so dim 1 is last and cast to float32, normalized with `simple_norm`,
    cast back to the input dtype and permuted back.
    """
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
            self,
            channels: int,
            eps: float = 1e-6,
            affine: bool = True,
            device=None,
            dtype=None,
    ) -> None:
        """
        Args:
            channels: Size of the normalized dim; a sequence of ints is also accepted as the shape.
            eps: Value forwarded to the norm function for numerical stability.
            affine: If True, a learnable `weight` of shape `normalized_shape` is created.
            device: Device for the weight parameter.
            dtype: Dtype for the weight parameter.
        """
        super().__init__()
        # A bare integer becomes a 1-tuple; anything else is treated as an iterable of dims.
        shape = (channels,) if isinstance(channels, numbers.Integral) else channels
        self.normalized_shape = tuple(shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = affine

        if affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape, device=device, dtype=dtype))
        else:
            self.register_parameter('weight', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """ Fill the weight with ones; a no-op when the layer has no affine weight. """
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Apply simple norm to a 4-dim input in float32; output keeps the dtype and dim order of `x`. """
        in_dtype = x.dtype
        # weight is None when the layer was built without an affine parameter.
        weight = None if self.weight is None else self.weight.float()
        # Move dim 1 (channels) last and upcast before normalizing.
        x_nhwc = x.permute(0, 2, 3, 1).float()
        y = simple_norm(x_nhwc, self.normalized_shape, weight, self.eps)
        # Cast back to the input dtype, then restore the channels-first dim order.
        return y.to(in_dtype).permute(0, 3, 1, 2)
|