sgdp.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. """
  2. SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py
  3. Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
  4. Code: https://github.com/clovaai/AdamP
  5. References for added functionality:
  6. Cautious Optimizers: https://arxiv.org/abs/2411.16085
  7. Spherical Cautious Optimizers: https://openreview.net/forum?id=OyT2CJ4fh7
  8. Copyright (c) 2020-present NAVER Corp.
  9. MIT license
  10. """
  11. import torch
  12. import torch.nn.functional as F
  13. from torch.optim.optimizer import Optimizer, required
  14. import math
  15. from .adamp import projection
  16. class SGDP(Optimizer):
  17. def __init__(
  18. self,
  19. params,
  20. lr=required,
  21. momentum=0,
  22. dampening=0,
  23. weight_decay=0,
  24. nesterov=False,
  25. eps=1e-8,
  26. delta=0.1,
  27. wd_ratio=0.1,
  28. caution=False
  29. ):
  30. defaults = dict(
  31. lr=lr,
  32. momentum=momentum,
  33. dampening=dampening,
  34. weight_decay=weight_decay,
  35. nesterov=nesterov,
  36. eps=eps,
  37. delta=delta,
  38. wd_ratio=wd_ratio,
  39. caution=caution,
  40. )
  41. super(SGDP, self).__init__(params, defaults)
  42. @torch.no_grad()
  43. def step(self, closure=None):
  44. loss = None
  45. if closure is not None:
  46. with torch.enable_grad():
  47. loss = closure()
  48. for group in self.param_groups:
  49. weight_decay = group['weight_decay']
  50. momentum = group['momentum']
  51. dampening = group['dampening']
  52. nesterov = group['nesterov']
  53. caution = group.get('caution', False)
  54. for p in group['params']:
  55. if p.grad is None:
  56. continue
  57. grad = p.grad
  58. state = self.state[p]
  59. # State initialization
  60. if len(state) == 0:
  61. state['momentum'] = torch.zeros_like(p)
  62. # SGD
  63. buf = state['momentum']
  64. buf.mul_(momentum).add_(grad, alpha=1. - dampening)
  65. if nesterov:
  66. d_p = grad + momentum * buf
  67. else:
  68. d_p = buf.clone()
  69. # Projection
  70. wd_ratio = 1.
  71. if len(p.shape) > 1:
  72. d_p, wd_ratio = projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps'], caution)
  73. elif caution:
  74. mask = (d_p * grad > 0).to(grad.dtype)
  75. mask.div_(mask.mean().clamp_(min=1e-3))
  76. d_p.mul_(mask)
  77. # Weight decay
  78. if weight_decay != 0:
  79. p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))
  80. # Step
  81. p.add_(d_p, alpha=-group['lr'])
  82. return loss