| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259 |
- """ 'Fast' Normalization Functions
- For GroupNorm and LayerNorm these functions bypass typical AMP upcast to float32.
- Additionally, for LayerNorm, the APEX fused LN is used if available (which also does not upcast)
- Hacked together by / Copyright 2022 Ross Wightman
- """
- from typing import List, Optional
- import torch
- from torch.nn import functional as F
# NVIDIA Apex is optional: probe for its fused LayerNorm kernel and record the result.
try:
    from apex.normalization.fused_layer_norm import fused_layer_norm_affine
except ImportError:
    has_apex = False
else:
    has_apex = True

# The fused RMSNorm kernels are probed separately from the LayerNorm one.
try:
    from apex.normalization.fused_layer_norm import fused_rms_norm_affine, fused_rms_norm
except ImportError:
    has_apex_rmsnorm = False
else:
    has_apex_rmsnorm = True

# True when the installed torch.nn.functional exposes a native `rms_norm`.
has_torch_rms_norm = hasattr(F, 'rms_norm')

# fast (ie lower precision LN) can be disabled with this flag if issues crop up
_USE_FAST_NORM = False  # defaulting to False for now
def get_autocast_dtype(device: str = 'cuda'):
    """Return the dtype autocast uses for the given device type.

    Tries the device-generic `torch.get_autocast_dtype(device)` first. If that
    call raises AttributeError or TypeError, falls back to the device-specific
    getters, which only cover 'cpu' and 'cuda'.

    Args:
        device: Device type string, e.g. 'cuda' or 'cpu'.

    Returns:
        The autocast dtype reported by torch for that device type.
    """
    try:
        return torch.get_autocast_dtype(device)
    except (AttributeError, TypeError):
        # dispatch to older device specific fns, only covering cuda/cpu devices here
        if device == 'cpu':
            return torch.get_autocast_cpu_dtype()
        assert device == 'cuda'
        return torch.get_autocast_gpu_dtype()
def is_autocast_enabled(device: str = 'cuda'):
    """Report whether autocast is currently enabled for the given device type.

    Tries `torch.is_autocast_enabled(device)` first. If that call raises
    TypeError, falls back to the device-specific queries, which only cover
    'cpu' and 'cuda'.

    Args:
        device: Device type string, e.g. 'cuda' or 'cpu'.

    Returns:
        Whatever torch reports for the autocast state of that device type.
    """
    try:
        return torch.is_autocast_enabled(device)
    except TypeError:
        # dispatch to older device specific fns, only covering cuda/cpu devices here
        if device == 'cpu':
            return torch.is_autocast_cpu_enabled()
        assert device == 'cuda'
        # the no-arg form defaults to cuda (only cuda on older pytorch)
        return torch.is_autocast_enabled()
def is_fast_norm():
    """Return the current value of the module-level `_USE_FAST_NORM` flag."""
    return _USE_FAST_NORM
def set_fast_norm(enable=True):
    """Set the module-level `_USE_FAST_NORM` flag.

    Args:
        enable: New value for the flag (defaults to True).
    """
    global _USE_FAST_NORM
    _USE_FAST_NORM = enable
def fast_group_norm(
    x: torch.Tensor,
    num_groups: int,
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-5
) -> torch.Tensor:
    """GroupNorm that runs in the autocast dtype instead of being upcast.

    Under TorchScript this is a plain `F.group_norm` call. Otherwise, when
    autocast is enabled for the input's device type, the input and any affine
    parameters are cast to the autocast dtype, and `F.group_norm` is invoked
    with autocast disabled so that dtype is kept.

    Args:
        x: Input tensor.
        num_groups: Number of groups passed to `F.group_norm`.
        weight: Optional affine scale.
        bias: Optional affine shift.
        eps: Value added for numerical stability.

    Returns:
        The group-normalized tensor.
    """
    if torch.jit.is_scripting():
        # currently cannot use is_autocast_enabled within torchscript
        return F.group_norm(x, num_groups, weight, bias, eps)

    device_type = x.device.type
    if is_autocast_enabled(device_type):
        # normally native AMP casts GN inputs to float32
        # here we use the low precision autocast dtype
        low_dtype = get_autocast_dtype(device_type)
        x = x.to(low_dtype)
        if weight is not None:
            weight = weight.to(low_dtype)
        if bias is not None:
            bias = bias.to(low_dtype)

    with torch.amp.autocast(device_type=device_type, enabled=False):
        return F.group_norm(x, num_groups, weight, bias, eps)
def fast_layer_norm(
    x: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-5
) -> torch.Tensor:
    """LayerNorm that avoids the usual AMP upcast to float32.

    Dispatch order:
      1. Under TorchScript, plain `F.layer_norm`.
      2. If Apex is available and both `weight` and `bias` are given, the Apex
         fused affine LayerNorm.
      3. Otherwise `F.layer_norm` with autocast disabled, after casting the
         input and any affine parameters to the autocast dtype when autocast
         is enabled for the input's device type.

    Args:
        x: Input tensor.
        normalized_shape: Trailing dimensions to normalize over.
        weight: Optional affine scale.
        bias: Optional affine shift.
        eps: Value added for numerical stability.

    Returns:
        The layer-normalized tensor.
    """
    if torch.jit.is_scripting():
        # currently cannot use is_autocast_enabled within torchscript
        return F.layer_norm(x, normalized_shape, weight, bias, eps)

    # The fused *affine* kernel needs real weight and bias tensors. When either
    # is missing, fall through to the native path, which accepts None, instead
    # of handing None to Apex.
    if has_apex and weight is not None and bias is not None:
        return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)

    if is_autocast_enabled(x.device.type):
        # normally native AMP casts LN inputs to float32
        # apex LN does not, this is behaving like Apex
        dt = get_autocast_dtype(x.device.type)
        x, weight, bias = (
            x.to(dt),
            weight.to(dt) if weight is not None else None,
            bias.to(dt) if bias is not None else None,
        )

    with torch.amp.autocast(device_type=x.device.type, enabled=False):
        return F.layer_norm(x, normalized_shape, weight, bias, eps)
def rms_norm(
    x: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    eps: float = 1e-5,
):
    """Root-mean-square normalization over the trailing dimensions.

    Computes `x * rsqrt(mean(x ** 2) + eps)`, where the mean is taken over the
    last `len(normalized_shape)` dimensions, then multiplies by `weight` if
    one is given.

    Args:
        x: Input tensor.
        normalized_shape: Only its length is used, to pick how many trailing
            dims are reduced. Under TorchScript it must have length 1.
        weight: Optional scale multiplied into the result.
        eps: Value added to the mean square before the reciprocal sqrt.

    Returns:
        The normalized (and optionally scaled) tensor.
    """
    num_norm_dims = len(normalized_shape)
    squared = x.pow(2)
    if torch.jit.is_scripting():
        # Building the dims as list(range(ndim - n, ndim)) doesn't work on pytorch <= 1.13.x
        # NOTE -ve dims cause torchscript to crash in some cases, out of options to work around
        assert num_norm_dims == 1
        # ts crashes with -ve dim + keepdim=True, so restore the reduced dim by hand
        mean_sq = torch.mean(squared, dim=-1).unsqueeze(-1)
    else:
        reduce_dims = tuple(range(-1, -num_norm_dims - 1, -1))
        mean_sq = torch.mean(squared, dim=reduce_dims, keepdim=True)
    out = x * torch.rsqrt(mean_sq + eps)
    if weight is not None:
        out = out * weight
    return out
def fast_rms_norm(
    x: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    eps: float = 1e-5,
) -> torch.Tensor:
    """RMSNorm that avoids the usual AMP upcast to float32.

    Dispatch order:
      1. Under TorchScript, the pure-torch `rms_norm` from this module.
      2. If the Apex RMSNorm kernels are available, the fused affine or
         non-affine kernel depending on whether `weight` is given.
      3. Otherwise `F.rms_norm` when torch provides it (else `rms_norm`), run
         with autocast disabled after casting inputs to the autocast dtype
         when autocast is enabled for the input's device type.

    Args:
        x: Input tensor.
        normalized_shape: Trailing dimensions to normalize over.
        weight: Optional scale.
        eps: Value added for numerical stability.

    Returns:
        The RMS-normalized tensor.
    """
    if torch.jit.is_scripting():
        # this must be by itself, cannot merge with has_apex_rmsnorm
        return rms_norm(x, normalized_shape, weight, eps)

    if has_apex_rmsnorm:
        if weight is not None:
            return fused_rms_norm_affine(x, weight, normalized_shape, eps)
        return fused_rms_norm(x, normalized_shape, eps)

    device_type = x.device.type
    if is_autocast_enabled(device_type):
        # normally native AMP casts LN inputs to float32 and leaves the output as float32
        # apex LN does not, this is behaving like Apex
        low_dtype = get_autocast_dtype(device_type)
        x = x.to(low_dtype)
        if weight is not None:
            weight = weight.to(low_dtype)

    with torch.amp.autocast(device_type=device_type, enabled=False):
        if has_torch_rms_norm:
            return F.rms_norm(x, normalized_shape, weight, eps)
        return rms_norm(x, normalized_shape, weight, eps)
def rms_norm2d(
    x: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    eps: float = 1e-5,
):
    """RMS normalization over dimension 1 of the input.

    Computes `x * rsqrt(mean(x ** 2, dim=1) + eps)` and, when `weight` is
    given, multiplies by it reshaped to `(1, -1, 1, 1)`.

    Args:
        x: Input tensor; the weight reshape implies a 4-D layout with the
            normalized dimension at index 1.
        normalized_shape: Must have length 1; its value is not otherwise used.
        weight: Optional per-channel scale.
        eps: Value added to the mean square before the reciprocal sqrt.

    Returns:
        The normalized (and optionally scaled) tensor.
    """
    assert len(normalized_shape) == 1
    mean_sq = torch.mean(x.pow(2), dim=1, keepdim=True)
    out = x * torch.rsqrt(mean_sq + eps)
    if weight is None:
        return out
    return out * weight.reshape(1, -1, 1, 1)
def fast_rms_norm2d(
    x: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    eps: float = 1e-5,
) -> torch.Tensor:
    """RMSNorm over dimension 1 of a 4-D input, avoiding the AMP float32 upcast.

    Dispatch order:
      1. Under TorchScript, the pure-torch `rms_norm2d` from this module.
      2. If the Apex RMSNorm kernels are available, permute dim 1 to the end,
         run the fused kernel (affine or not, depending on `weight`), permute
         back and return.
      3. Otherwise `rms_norm2d` with autocast disabled, after casting inputs
         to the autocast dtype when autocast is enabled for the input's
         device type.

    Args:
        x: 4-D input tensor with the normalized dimension at index 1.
        normalized_shape: Shape of the normalized dimension (length 1).
        weight: Optional per-channel scale.
        eps: Value added for numerical stability.

    Returns:
        The RMS-normalized tensor in the same layout as the input.
    """
    if torch.jit.is_scripting():
        # this must be by itself, cannot merge with has_apex_rmsnorm
        return rms_norm2d(x, normalized_shape, weight, eps)

    if has_apex_rmsnorm:
        # move dim 1 last for the fused kernel, then restore the layout afterwards
        x = x.permute(0, 2, 3, 1)
        if weight is None:
            x = fused_rms_norm(x, normalized_shape, eps)
        else:
            x = fused_rms_norm_affine(x, weight, normalized_shape, eps)
        # Return here: falling through would run rms_norm2d on the already
        # normalized tensor and apply `weight` a second time.
        return x.permute(0, 3, 1, 2)

    if is_autocast_enabled(x.device.type):
        # normally native AMP casts norm inputs to float32 and leaves the output as float32
        # apex does not, this is behaving like Apex
        dt = get_autocast_dtype(x.device.type)
        x, weight = x.to(dt), weight.to(dt) if weight is not None else None

    with torch.amp.autocast(device_type=x.device.type, enabled=False):
        x = rms_norm2d(x, normalized_shape, weight, eps)
    return x
def simple_norm(
    x: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    eps: float = 1e-5,
):
    """Scale the input by the reciprocal sqrt of its variance over trailing dims.

    Computes `x * rsqrt(var(x) + eps)`, with the variance (via `torch.var`
    defaults) taken over the last `len(normalized_shape)` dimensions. The mean
    is not subtracted from `x`. The result is multiplied by `weight` if given.

    Args:
        x: Input tensor.
        normalized_shape: Only its length is used, to pick how many trailing
            dims are reduced. Under TorchScript it must have length 1.
        weight: Optional scale multiplied into the result.
        eps: Value added to the variance before the reciprocal sqrt.

    Returns:
        The scaled tensor.
    """
    num_norm_dims = len(normalized_shape)
    if torch.jit.is_scripting():
        # Building the dims as list(range(ndim - n, ndim)) doesn't work on pytorch <= 1.13.x
        # NOTE -ve dims cause torchscript to crash in some cases, out of options to work around
        assert num_norm_dims == 1
        # ts crashes with -ve dim + keepdim=True, so restore the reduced dim by hand
        variance = torch.var(x, dim=-1).unsqueeze(-1)
    else:
        reduce_dims = tuple(range(-1, -num_norm_dims - 1, -1))
        variance = torch.var(x, dim=reduce_dims, keepdim=True)
    out = x * torch.rsqrt(variance + eps)
    if weight is not None:
        out = out * weight
    return out
def fast_simple_norm(
    x: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    eps: float = 1e-5,
) -> torch.Tensor:
    """`simple_norm` that runs in the autocast dtype instead of being upcast.

    Under TorchScript this is a plain `simple_norm` call. Otherwise, when
    autocast is enabled for the input's device type, the input and `weight`
    are cast to the autocast dtype, and `simple_norm` runs with autocast
    disabled so that dtype is kept.

    Args:
        x: Input tensor.
        normalized_shape: Trailing dimensions to normalize over.
        weight: Optional scale.
        eps: Value added for numerical stability.

    Returns:
        The normalized tensor.
    """
    if torch.jit.is_scripting():
        # this must be by itself, cannot merge with has_apex_rmsnorm
        return simple_norm(x, normalized_shape, weight, eps)

    device_type = x.device.type
    if is_autocast_enabled(device_type):
        # normally native AMP casts LN inputs to float32
        # apex LN does not, this is behaving like Apex
        low_dtype = get_autocast_dtype(device_type)
        x = x.to(low_dtype)
        if weight is not None:
            weight = weight.to(low_dtype)

    with torch.amp.autocast(device_type=device_type, enabled=False):
        return simple_norm(x, normalized_shape, weight, eps)
|