adamp.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
"""
AdamP Optimizer Implementation adapted from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py

Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
Code: https://github.com/clovaai/AdamP

References for added functionality:
    Cautious Optimizers: https://arxiv.org/abs/2411.16085
    Spherical Cautious Optimizers: https://openreview.net/forum?id=OyT2CJ4fh7

Copyright (c) 2020-present NAVER Corp.
MIT license
"""
  11. import torch
  12. import torch.nn.functional as F
  13. from torch.optim.optimizer import Optimizer
  14. import math
  15. def _channel_view(x) -> torch.Tensor:
  16. return x.reshape(x.size(0), -1)
  17. def _layer_view(x) -> torch.Tensor:
  18. return x.reshape(1, -1)
  19. def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float, caution: bool = False):
  20. wd = 1.
  21. expand_size = (-1,) + (1,) * (len(p.shape) - 1)
  22. for view_func in [_channel_view, _layer_view]:
  23. param_view = view_func(p)
  24. grad_view = view_func(grad)
  25. cosine_sim = F.cosine_similarity(grad_view, param_view, dim=1, eps=eps).abs_()
  26. # FIXME this is a problem for PyTorch XLA
  27. if cosine_sim.max() < delta / math.sqrt(param_view.size(1)):
  28. p_n = p / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size)
  29. perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size)
  30. if caution:
  31. # Spherical Cautious Optimizer Logic
  32. grad_radial = p_n * view_func(p_n * grad).sum(dim=1).reshape(expand_size)
  33. grad_perp = grad - grad_radial
  34. mask = (perturb * grad_perp > 0).to(grad.dtype)
  35. mask.div_(mask.mean().clamp_(min=1e-3))
  36. perturb.mul_(mask)
  37. # Enhance the numerical stability of the Cautious Optimizer
  38. perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size)
  39. wd = wd_ratio
  40. return perturb, wd
  41. if caution:
  42. # Standard Cautious Optimizer Logic for non-projected parameters
  43. mask = (perturb * grad > 0).to(grad.dtype)
  44. mask.div_(mask.mean().clamp_(min=1e-3))
  45. perturb.mul_(mask)
  46. return perturb, wd
  47. class AdamP(Optimizer):
  48. def __init__(
  49. self,
  50. params,
  51. lr=1e-3,
  52. betas=(0.9, 0.999),
  53. eps=1e-8,
  54. weight_decay=0,
  55. delta=0.1,
  56. wd_ratio=0.1,
  57. nesterov=False,
  58. caution=False,
  59. ):
  60. defaults = dict(
  61. lr=lr,
  62. betas=betas,
  63. eps=eps,
  64. weight_decay=weight_decay,
  65. delta=delta,
  66. wd_ratio=wd_ratio,
  67. nesterov=nesterov,
  68. caution=caution,
  69. )
  70. super(AdamP, self).__init__(params, defaults)
  71. @torch.no_grad()
  72. def step(self, closure=None):
  73. loss = None
  74. if closure is not None:
  75. with torch.enable_grad():
  76. loss = closure()
  77. for group in self.param_groups:
  78. for p in group['params']:
  79. if p.grad is None:
  80. continue
  81. grad = p.grad
  82. beta1, beta2 = group['betas']
  83. nesterov = group['nesterov']
  84. caution = group.get('caution', False)
  85. state = self.state[p]
  86. # State initialization
  87. if len(state) == 0:
  88. state['step'] = 0
  89. state['exp_avg'] = torch.zeros_like(p)
  90. state['exp_avg_sq'] = torch.zeros_like(p)
  91. # Adam
  92. exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
  93. state['step'] += 1
  94. bias_correction1 = 1 - beta1 ** state['step']
  95. bias_correction2 = 1 - beta2 ** state['step']
  96. exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
  97. exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
  98. denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
  99. step_size = group['lr'] / bias_correction1
  100. if nesterov:
  101. perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom
  102. else:
  103. perturb = exp_avg / denom
  104. # Projection
  105. wd_ratio = 1.
  106. if len(p.shape) > 1:
  107. perturb, wd_ratio = projection(
  108. p, grad, perturb, group['delta'], group['wd_ratio'], group['eps'], caution
  109. )
  110. elif caution:
  111. # Apply standard caution for scalars/1D tensors if needed
  112. mask = (perturb * grad > 0).to(grad.dtype)
  113. mask.div_(mask.mean().clamp_(min=1e-3))
  114. perturb.mul_(mask)
  115. # Weight decay
  116. if group['weight_decay'] > 0:
  117. p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio)
  118. # Step
  119. p.add_(perturb, alpha=-step_size)
  120. return loss