clip_grad.py 796 B

1234567891011121314151617181920212223
  1. import torch
  2. from timm.utils.agc import adaptive_clip_grad
  3. def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
  4. """ Dispatch to gradient clipping method
  5. Args:
  6. parameters (Iterable): model parameters to clip
  7. value (float): clipping value/factor/norm, mode dependant
  8. mode (str): clipping mode, one of 'norm', 'value', 'agc'
  9. norm_type (float): p-norm, default 2.0
  10. """
  11. if mode == 'norm':
  12. torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
  13. elif mode == 'value':
  14. torch.nn.utils.clip_grad_value_(parameters, value)
  15. elif mode == 'agc':
  16. adaptive_clip_grad(parameters, value, norm_type=norm_type)
  17. else:
  18. assert False, f"Unknown clip mode ({mode})."