fast_norm.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
""" 'Fast' Normalization Functions

For GroupNorm, LayerNorm, RMSNorm (including the 2D variant) and SimpleNorm, these functions bypass the typical AMP upcast to float32 and compute in the autocast dtype instead.

Additionally, for LayerNorm and RMSNorm, the APEX fused kernels are used if available (which also do not upcast).

Hacked together by / Copyright 2022 Ross Wightman
"""
  6. from typing import List, Optional
  7. import torch
  8. from torch.nn import functional as F
  9. try:
  10. from apex.normalization.fused_layer_norm import fused_layer_norm_affine
  11. has_apex = True
  12. except ImportError:
  13. has_apex = False
  14. try:
  15. from apex.normalization.fused_layer_norm import fused_rms_norm_affine, fused_rms_norm
  16. has_apex_rmsnorm = True
  17. except ImportError:
  18. has_apex_rmsnorm = False
  19. has_torch_rms_norm = hasattr(F, 'rms_norm')
  20. # fast (ie lower precision LN) can be disabled with this flag if issues crop up
  21. _USE_FAST_NORM = False # defaulting to False for now
  22. def get_autocast_dtype(device: str = 'cuda'):
  23. try:
  24. return torch.get_autocast_dtype(device)
  25. except (AttributeError, TypeError):
  26. # dispatch to older device specific fns, only covering cuda/cpu devices here
  27. if device == 'cpu':
  28. return torch.get_autocast_cpu_dtype()
  29. else:
  30. assert device == 'cuda'
  31. return torch.get_autocast_gpu_dtype()
  32. def is_autocast_enabled(device: str = 'cuda'):
  33. try:
  34. return torch.is_autocast_enabled(device)
  35. except TypeError:
  36. # dispatch to older device specific fns, only covering cuda/cpu devices here
  37. if device == 'cpu':
  38. return torch.is_autocast_cpu_enabled()
  39. else:
  40. assert device == 'cuda'
  41. return torch.is_autocast_enabled() # defaults cuda (only cuda on older pytorch)
  42. def is_fast_norm():
  43. return _USE_FAST_NORM
  44. def set_fast_norm(enable=True):
  45. global _USE_FAST_NORM
  46. _USE_FAST_NORM = enable
  47. def fast_group_norm(
  48. x: torch.Tensor,
  49. num_groups: int,
  50. weight: Optional[torch.Tensor] = None,
  51. bias: Optional[torch.Tensor] = None,
  52. eps: float = 1e-5
  53. ) -> torch.Tensor:
  54. if torch.jit.is_scripting():
  55. # currently cannot use is_autocast_enabled within torchscript
  56. return F.group_norm(x, num_groups, weight, bias, eps)
  57. if is_autocast_enabled(x.device.type):
  58. # normally native AMP casts GN inputs to float32
  59. # here we use the low precision autocast dtype
  60. dt = get_autocast_dtype(x.device.type)
  61. x, weight, bias = (
  62. x.to(dt),
  63. weight.to(dt) if weight is not None else None,
  64. bias.to(dt) if bias is not None else None,
  65. )
  66. with torch.amp.autocast(device_type=x.device.type, enabled=False):
  67. return F.group_norm(x, num_groups, weight, bias, eps)
  68. def fast_layer_norm(
  69. x: torch.Tensor,
  70. normalized_shape: List[int],
  71. weight: Optional[torch.Tensor] = None,
  72. bias: Optional[torch.Tensor] = None,
  73. eps: float = 1e-5
  74. ) -> torch.Tensor:
  75. if torch.jit.is_scripting():
  76. # currently cannot use is_autocast_enabled within torchscript
  77. return F.layer_norm(x, normalized_shape, weight, bias, eps)
  78. if has_apex:
  79. return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)
  80. if is_autocast_enabled(x.device.type):
  81. # normally native AMP casts LN inputs to float32
  82. # apex LN does not, this is behaving like Apex
  83. dt = get_autocast_dtype(x.device.type)
  84. x, weight, bias = (
  85. x.to(dt),
  86. weight.to(dt) if weight is not None else None,
  87. bias.to(dt) if bias is not None else None,
  88. )
  89. with torch.amp.autocast(device_type=x.device.type, enabled=False):
  90. return F.layer_norm(x, normalized_shape, weight, bias, eps)
  91. def rms_norm(
  92. x: torch.Tensor,
  93. normalized_shape: List[int],
  94. weight: Optional[torch.Tensor] = None,
  95. eps: float = 1e-5,
  96. ):
  97. norm_ndim = len(normalized_shape)
  98. v = x.pow(2)
  99. if torch.jit.is_scripting():
  100. # ndim = len(x.shape)
  101. # dims = list(range(ndim - norm_ndim, ndim)) # this doesn't work on pytorch <= 1.13.x
  102. # NOTE -ve dims cause torchscript to crash in some cases, out of options to work around
  103. assert norm_ndim == 1
  104. v = torch.mean(v, dim=-1).unsqueeze(-1) # ts crashes with -ve dim + keepdim=True
  105. else:
  106. dims = tuple(range(-1, -norm_ndim - 1, -1))
  107. v = torch.mean(v, dim=dims, keepdim=True)
  108. x = x * torch.rsqrt(v + eps)
  109. if weight is not None:
  110. x = x * weight
  111. return x
  112. def fast_rms_norm(
  113. x: torch.Tensor,
  114. normalized_shape: List[int],
  115. weight: Optional[torch.Tensor] = None,
  116. eps: float = 1e-5,
  117. ) -> torch.Tensor:
  118. if torch.jit.is_scripting():
  119. # this must be by itself, cannot merge with has_apex_rmsnorm
  120. return rms_norm(x, normalized_shape, weight, eps)
  121. if has_apex_rmsnorm:
  122. if weight is None:
  123. return fused_rms_norm(x, normalized_shape, eps)
  124. else:
  125. return fused_rms_norm_affine(x, weight, normalized_shape, eps)
  126. if is_autocast_enabled(x.device.type):
  127. # normally native AMP casts LN inputs to float32 and leaves the output as float32
  128. # apex LN does not, this is behaving like Apex
  129. dt = get_autocast_dtype(x.device.type)
  130. x, weight = x.to(dt), weight.to(dt) if weight is not None else None
  131. with torch.amp.autocast(device_type=x.device.type, enabled=False):
  132. if has_torch_rms_norm:
  133. x = F.rms_norm(x, normalized_shape, weight, eps)
  134. else:
  135. x = rms_norm(x, normalized_shape, weight, eps)
  136. return x
  137. def rms_norm2d(
  138. x: torch.Tensor,
  139. normalized_shape: List[int],
  140. weight: Optional[torch.Tensor] = None,
  141. eps: float = 1e-5,
  142. ):
  143. assert len(normalized_shape) == 1
  144. v = x.pow(2)
  145. v = torch.mean(v, dim=1, keepdim=True)
  146. x = x * torch.rsqrt(v + eps)
  147. if weight is not None:
  148. x = x * weight.reshape(1, -1, 1, 1)
  149. return x
  150. def fast_rms_norm2d(
  151. x: torch.Tensor,
  152. normalized_shape: List[int],
  153. weight: Optional[torch.Tensor] = None,
  154. eps: float = 1e-5,
  155. ) -> torch.Tensor:
  156. if torch.jit.is_scripting():
  157. # this must be by itself, cannot merge with has_apex_rmsnorm
  158. return rms_norm2d(x, normalized_shape, weight, eps)
  159. if has_apex_rmsnorm:
  160. x = x.permute(0, 2, 3, 1)
  161. if weight is None:
  162. x = fused_rms_norm(x, normalized_shape, eps)
  163. else:
  164. x = fused_rms_norm_affine(x, weight, normalized_shape, eps)
  165. x = x.permute(0, 3, 1, 2)
  166. if is_autocast_enabled(x.device.type):
  167. # normally native AMP casts norm inputs to float32 and leaves the output as float32
  168. # apex does not, this is behaving like Apex
  169. dt = get_autocast_dtype(x.device.type)
  170. x, weight = x.to(dt), weight.to(dt) if weight is not None else None
  171. with torch.amp.autocast(device_type=x.device.type, enabled=False):
  172. x = rms_norm2d(x, normalized_shape, weight, eps)
  173. return x
  174. def simple_norm(
  175. x: torch.Tensor,
  176. normalized_shape: List[int],
  177. weight: Optional[torch.Tensor] = None,
  178. eps: float = 1e-5,
  179. ):
  180. norm_ndim = len(normalized_shape)
  181. if torch.jit.is_scripting():
  182. # ndim = len(x.shape)
  183. # dims = list(range(ndim - norm_ndim, ndim)) # this doesn't work on pytorch <= 1.13.x
  184. # NOTE -ve dims cause torchscript to crash in some cases, out of options to work around
  185. assert norm_ndim == 1
  186. v = torch.var(x, dim=-1).unsqueeze(-1) # ts crashes with -ve dim + keepdim=True
  187. else:
  188. dims = tuple(range(-1, -norm_ndim - 1, -1))
  189. v = torch.var(x, dim=dims, keepdim=True)
  190. x = x * torch.rsqrt(v + eps)
  191. if weight is not None:
  192. x = x * weight
  193. return x
  194. def fast_simple_norm(
  195. x: torch.Tensor,
  196. normalized_shape: List[int],
  197. weight: Optional[torch.Tensor] = None,
  198. eps: float = 1e-5,
  199. ) -> torch.Tensor:
  200. if torch.jit.is_scripting():
  201. # this must be by itself, cannot merge with has_apex_rmsnorm
  202. return simple_norm(x, normalized_shape, weight, eps)
  203. if is_autocast_enabled(x.device.type):
  204. # normally native AMP casts LN inputs to float32
  205. # apex LN does not, this is behaving like Apex
  206. dt = get_autocast_dtype(x.device.type)
  207. x, weight = x.to(dt), weight.to(dt) if weight is not None else None
  208. with torch.amp.autocast(device_type=x.device.type, enabled=False):
  209. x = simple_norm(x, normalized_shape, weight, eps)
  210. return x