| 1234567891011121314151617181920212223 |
- import torch
- from timm.utils.agc import adaptive_clip_grad
def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
    """ Dispatch to gradient clipping method

    Gradients are modified in place; nothing is returned.

    Args:
        parameters (Iterable): model parameters to clip
        value (float): clipping value/factor/norm, mode-dependent:
            - 'norm': max total p-norm of the gradients (torch.nn.utils.clip_grad_norm_)
            - 'value': each gradient element is clamped to [-value, value]
              (torch.nn.utils.clip_grad_value_)
            - 'agc': clipping factor forwarded to adaptive_clip_grad
        mode (str): clipping mode, one of 'norm', 'value', 'agc'
        norm_type (float): p-norm, default 2.0 (used by 'norm' and 'agc' modes only)

    Raises:
        ValueError: if ``mode`` is not one of the supported modes.
    """
    if mode == 'norm':
        torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
    elif mode == 'value':
        torch.nn.utils.clip_grad_value_(parameters, value)
    elif mode == 'agc':
        adaptive_clip_grad(parameters, value, norm_type=norm_type)
    else:
        # Explicit raise rather than `assert False`: asserts are stripped under
        # `python -O`, which would make an invalid mode silently skip clipping.
        raise ValueError(f"Unknown clip mode ({mode}).")
|