agc.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
""" Adaptive Gradient Clipping

An implementation of AGC, as described in (https://arxiv.org/abs/2102.06171):

@article{brock2021high,
  author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
  title={High-Performance Large-Scale Image Recognition Without Normalization},
  journal={arXiv preprint arXiv:2102.06171},
  year={2021}
}

Code references:
  * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets
  * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c

Hacked together by / Copyright 2021 Ross Wightman
"""
  14. import torch
  15. def unitwise_norm(x, norm_type=2.0):
  16. if x.ndim <= 1:
  17. return x.norm(norm_type)
  18. else:
  19. # works for nn.ConvNd and nn,Linear where output dim is first in the kernel/weight tensor
  20. # might need special cases for other weights (possibly MHA) where this may not be true
  21. return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True)
  22. def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
  23. if isinstance(parameters, torch.Tensor):
  24. parameters = [parameters]
  25. for p in parameters:
  26. if p.grad is None:
  27. continue
  28. p_data = p.detach()
  29. g_data = p.grad.detach()
  30. max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
  31. grad_norm = unitwise_norm(g_data, norm_type=norm_type)
  32. clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6))
  33. new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad)
  34. p.grad.detach().copy_(new_grads)