clip_grad.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. # mypy: allow-untyped-decorators
  2. # mypy: allow-untyped-defs
  3. import functools
  4. import types
  5. import typing
  6. import warnings
  7. from collections.abc import Callable
  8. from typing import cast, TypeAlias, TypeVar
  9. from typing_extensions import deprecated, ParamSpec
  10. import torch
  11. from torch import Tensor
  12. from torch.utils._foreach_utils import (
  13. _device_has_foreach_support,
  14. _group_tensors_by_device_and_dtype,
  15. _has_foreach_support,
  16. )
  17. __all__: list[str] = [
  18. "clip_grad_norm",
  19. "clip_grad_norm_",
  20. "clip_grad_value_",
  21. ]
  22. _tensor_or_tensors: TypeAlias = torch.Tensor | typing.Iterable[torch.Tensor] # noqa: PYI042
  23. _P = ParamSpec("_P")
  24. _R = TypeVar("_R")
  25. def _no_grad(func: Callable[_P, _R]) -> Callable[_P, _R]:
  26. """
  27. This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions
  28. clip_grad_norm_ and clip_grad_value_ themselves.
  29. """
  30. def _no_grad_wrapper(*args, **kwargs):
  31. with torch.no_grad():
  32. # pyrefly: ignore [invalid-param-spec]
  33. return func(*args, **kwargs)
  34. functools.update_wrapper(_no_grad_wrapper, func)
  35. # pyrefly: ignore [bad-return]
  36. return _no_grad_wrapper
  37. @_no_grad
  38. def _get_total_norm(
  39. tensors: _tensor_or_tensors,
  40. norm_type: float = 2.0,
  41. error_if_nonfinite: bool = False,
  42. foreach: bool | None = None,
  43. ) -> torch.Tensor:
  44. r"""Compute the norm of an iterable of tensors.
  45. The norm is computed over the norms of the individual tensors, as if the norms of
  46. the individual tensors were concatenated into a single vector.
  47. Args:
  48. tensors (Iterable[Tensor] or Tensor): an iterable of Tensors or a
  49. single Tensor that will be normalized
  50. norm_type (float): type of the used p-norm. Can be ``'inf'`` for
  51. infinity norm.
  52. error_if_nonfinite (bool): if True, an error is thrown if the total
  53. norm of :attr:`tensors` is ``nan``, ``inf``, or ``-inf``.
  54. Default: ``False``
  55. foreach (bool): use the faster foreach-based implementation.
  56. If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
  57. fall back to the slow implementation for other device types.
  58. Default: ``None``
  59. Returns:
  60. Total norm of the tensors (viewed as a single vector).
  61. """
  62. if isinstance(tensors, torch.Tensor):
  63. tensors = [tensors]
  64. else:
  65. tensors = list(tensors)
  66. norm_type = float(norm_type)
  67. if len(tensors) == 0:
  68. return torch.tensor(0.0)
  69. first_device = tensors[0].device
  70. grouped_tensors: dict[
  71. tuple[torch.device, torch.dtype], tuple[list[list[Tensor]], list[int]]
  72. ] = _group_tensors_by_device_and_dtype( # pyrefly: ignore [bad-assignment]
  73. [tensors] # type: ignore[list-item]
  74. ) # type: ignore[assignment]
  75. norms: list[Tensor] = []
  76. for (device, _), ([device_tensors], _) in grouped_tensors.items():
  77. if (foreach is None and _has_foreach_support(device_tensors, device)) or (
  78. foreach and _device_has_foreach_support(device)
  79. ):
  80. norms.extend(torch._foreach_norm(device_tensors, norm_type))
  81. elif foreach:
  82. raise RuntimeError(
  83. f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
  84. )
  85. else:
  86. norms.extend(
  87. [torch.linalg.vector_norm(g, norm_type) for g in device_tensors]
  88. )
  89. total_norm = torch.linalg.vector_norm(
  90. torch.stack([norm.to(first_device) for norm in norms]), norm_type
  91. )
  92. if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
  93. raise RuntimeError(
  94. f"The total norm of order {norm_type} for gradients from "
  95. "`parameters` is non-finite, so it cannot be clipped. To disable "
  96. "this error and scale the gradients by the non-finite norm anyway, "
  97. "set `error_if_nonfinite=False`"
  98. )
  99. return total_norm
  100. @_no_grad
  101. def _clip_grads_with_norm_(
  102. parameters: _tensor_or_tensors,
  103. max_norm: float,
  104. total_norm: torch.Tensor,
  105. foreach: bool | None = None,
  106. ) -> None:
  107. r"""Scale the gradients of an iterable of parameters given a pre-calculated total norm and desired max norm.
  108. The gradients will be scaled by the following calculation
  109. .. math::
  110. grad = grad * \min(\frac{max\_norm}{total\_norm + 1e-6}, 1)
  111. Gradients are modified in-place.
  112. Note: The scale coefficient is clamped to a maximum of 1.0 to prevent gradient amplification.
  113. This ensures that gradients are only scaled down when the total norm exceeds max_norm.
  114. This function is equivalent to :func:`torch.nn.utils.clip_grad_norm_` with a pre-calculated
  115. total norm.
  116. Args:
  117. parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
  118. single Tensor that will have gradients normalized
  119. max_norm (float): max norm of the gradients
  120. total_norm (Tensor): total norm of the gradients to use for clipping
  121. foreach (bool): use the faster foreach-based implementation.
  122. If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
  123. fall back to the slow implementation for other device types.
  124. Default: ``None``
  125. Returns:
  126. None
  127. """
  128. if isinstance(parameters, torch.Tensor):
  129. parameters = [parameters]
  130. grads = [p.grad for p in parameters if p.grad is not None]
  131. max_norm = float(max_norm)
  132. if len(grads) == 0:
  133. return
  134. grouped_grads: dict[
  135. tuple[torch.device, torch.dtype], tuple[list[list[Tensor]], list[int]]
  136. ] = _group_tensors_by_device_and_dtype([grads]) # type: ignore[assignment]
  137. clip_coef = max_norm / (total_norm + 1e-6)
  138. # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
  139. # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
  140. # when the gradients do not reside in CPU memory.
  141. clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
  142. for (device, _), ([device_grads], _) in grouped_grads.items():
  143. if (foreach is None and _has_foreach_support(device_grads, device)) or (
  144. foreach and _device_has_foreach_support(device)
  145. ):
  146. torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
  147. elif foreach:
  148. raise RuntimeError(
  149. f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
  150. )
  151. else:
  152. clip_coef_clamped_device = clip_coef_clamped.to(device)
  153. for g in device_grads:
  154. g.mul_(clip_coef_clamped_device)
  155. @_no_grad
  156. def clip_grad_norm_(
  157. parameters: _tensor_or_tensors,
  158. max_norm: float,
  159. norm_type: float = 2.0,
  160. error_if_nonfinite: bool = False,
  161. foreach: bool | None = None,
  162. ) -> torch.Tensor:
  163. r"""Clip the gradient norm of an iterable of parameters.
  164. The norm is computed over the norms of the individual gradients of all parameters,
  165. as if the norms of the individual gradients were concatenated into a single vector.
  166. Gradients are modified in-place.
  167. This function is equivalent to :func:`torch.nn.utils.get_total_norm` followed by
  168. :func:`torch.nn.utils.clip_grads_with_norm_` with the ``total_norm`` returned by ``get_total_norm``.
  169. Args:
  170. parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
  171. single Tensor that will have gradients normalized
  172. max_norm (float): max norm of the gradients
  173. norm_type (float, optional): type of the used p-norm. Can be ``'inf'`` for
  174. infinity norm. Default: 2.0
  175. error_if_nonfinite (bool, optional): if True, an error is thrown if the total
  176. norm of the gradients from :attr:`parameters` is ``nan``,
  177. ``inf``, or ``-inf``. Default: False
  178. foreach (bool, optional): use the faster foreach-based implementation.
  179. If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
  180. fall back to the slow implementation for other device types.
  181. Default: ``None``
  182. Returns:
  183. Total norm of the parameter gradients (viewed as a single vector).
  184. """
  185. if isinstance(parameters, torch.Tensor):
  186. parameters = [parameters]
  187. else:
  188. is_generator = isinstance(parameters, types.GeneratorType)
  189. # prevent generators from being exhausted
  190. parameters = list(parameters)
  191. if is_generator and len(parameters) == 0:
  192. warnings.warn(
  193. "`parameters` is an empty generator, no gradient clipping will occur.",
  194. stacklevel=3,
  195. )
  196. grads = [p.grad for p in parameters if p.grad is not None]
  197. total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
  198. _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
  199. return total_norm
  200. @deprecated(
  201. "`torch.nn.utils.clip_grad_norm` is now deprecated "
  202. "in favor of `torch.nn.utils.clip_grad_norm_`.",
  203. category=FutureWarning,
  204. )
  205. def clip_grad_norm(
  206. parameters: _tensor_or_tensors,
  207. max_norm: float,
  208. norm_type: float = 2.0,
  209. error_if_nonfinite: bool = False,
  210. foreach: bool | None = None,
  211. ) -> torch.Tensor:
  212. r"""Clip the gradient norm of an iterable of parameters.
  213. .. warning::
  214. This method is now deprecated in favor of
  215. :func:`torch.nn.utils.clip_grad_norm_`.
  216. """
  217. return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach)
  218. @_no_grad
  219. def clip_grad_value_(
  220. parameters: _tensor_or_tensors,
  221. clip_value: float,
  222. foreach: bool | None = None,
  223. ) -> None:
  224. r"""Clip the gradients of an iterable of parameters at specified value.
  225. Gradients are modified in-place.
  226. Args:
  227. parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
  228. single Tensor that will have gradients normalized
  229. clip_value (float): maximum allowed value of the gradients.
  230. The gradients are clipped in the range
  231. :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]`
  232. foreach (bool, optional): use the faster foreach-based implementation
  233. If ``None``, use the foreach implementation for CUDA and CPU native tensors and
  234. silently fall back to the slow implementation for other device types.
  235. Default: ``None``
  236. """
  237. if isinstance(parameters, torch.Tensor):
  238. parameters = [parameters]
  239. clip_value = float(clip_value)
  240. grads = [p.grad for p in parameters if p.grad is not None]
  241. # pyrefly: ignore [bad-argument-type]
  242. grouped_grads = _group_tensors_by_device_and_dtype([grads])
  243. for (device, _), ([grads], _) in grouped_grads.items():
  244. if (
  245. foreach is None
  246. and _has_foreach_support(cast(list[Tensor], grads), device=device)
  247. ) or (foreach and _device_has_foreach_support(device)):
  248. torch._foreach_clamp_min_(cast(list[Tensor], grads), -clip_value)
  249. torch._foreach_clamp_max_(cast(list[Tensor], grads), clip_value)
  250. elif foreach:
  251. raise RuntimeError(
  252. f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
  253. )
  254. else:
  255. for grad in grads:
  256. cast(Tensor, grad).clamp_(min=-clip_value, max=clip_value)