laprop.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
""" PyTorch impl of LaProp optimizer

Code simplified from https://github.com/Z-T-WANG/LaProp-Optimizer, MIT License

Paper: LaProp: Separating Momentum and Adaptivity in Adam, https://arxiv.org/abs/2002.04839

@article{ziyin2020laprop,
  title={LaProp: a Better Way to Combine Momentum with Adaptive Gradient},
  author={Ziyin, Liu and Wang, Zhikang T and Ueda, Masahito},
  journal={arXiv preprint arXiv:2002.04839},
  year={2020}
}

References for added functionality:
    Cautious Optimizers: https://arxiv.org/abs/2411.16085
    Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
"""
  14. from typing import Tuple
  15. from torch.optim import Optimizer
  16. import torch
  17. from ._types import ParamsT
  18. class LaProp(Optimizer):
  19. """ LaProp Optimizer
  20. Paper: LaProp: Separating Momentum and Adaptivity in Adam, https://arxiv.org/abs/2002.04839
  21. """
  22. def __init__(
  23. self,
  24. params: ParamsT,
  25. lr: float = 4e-4,
  26. betas: Tuple[float, float] = (0.9, 0.999),
  27. eps: float = 1e-15,
  28. weight_decay: float = 0.,
  29. caution: bool = False,
  30. corrected_weight_decay: bool = False,
  31. ):
  32. if not 0.0 <= lr:
  33. raise ValueError("Invalid learning rate: {}".format(lr))
  34. if not 0.0 <= eps:
  35. raise ValueError("Invalid epsilon value: {}".format(eps))
  36. if not 0.0 <= betas[0] < 1.0:
  37. raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
  38. if not 0.0 <= betas[1] < 1.0:
  39. raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
  40. defaults = dict(
  41. lr=lr,
  42. betas=betas,
  43. eps=eps,
  44. weight_decay=weight_decay,
  45. caution=caution,
  46. corrected_weight_decay=corrected_weight_decay,
  47. )
  48. super(LaProp, self).__init__(params, defaults)
  49. def __setstate__(self, state):
  50. super().__setstate__(state)
  51. for group in self.param_groups:
  52. group.setdefault('caution', False)
  53. group.setdefault('corrected_weight_decay', False)
  54. @torch.no_grad()
  55. def step(self, closure=None):
  56. """Performs a single optimization step.
  57. Arguments:
  58. closure (callable, optional): A closure that reevaluates the model
  59. and returns the loss.
  60. """
  61. loss = None
  62. if closure is not None:
  63. with torch.enable_grad():
  64. loss = closure()
  65. for group in self.param_groups:
  66. for p in group['params']:
  67. if p.grad is None:
  68. continue
  69. grad = p.grad
  70. if grad.is_sparse:
  71. raise RuntimeError('LaProp does not support sparse gradients')
  72. state = self.state[p]
  73. # State initialization
  74. if len(state) == 0:
  75. state['step'] = 0
  76. # Exponential moving average of gradient values
  77. state['exp_avg'] = torch.zeros_like(p)
  78. # Exponential moving average of learning rates
  79. state['exp_avg_lr_1'] = 0.
  80. state['exp_avg_lr_2'] = 0.
  81. # Exponential moving average of squared gradient values
  82. state['exp_avg_sq'] = torch.zeros_like(p)
  83. exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
  84. beta1, beta2 = group['betas']
  85. state['step'] += 1
  86. one_minus_beta2 = 1 - beta2
  87. one_minus_beta1 = 1 - beta1
  88. # Decay the first and second moment running average coefficient
  89. exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=one_minus_beta2)
  90. state['exp_avg_lr_1'] = state['exp_avg_lr_1'] * beta1 + one_minus_beta1 * group['lr']
  91. state['exp_avg_lr_2'] = state['exp_avg_lr_2'] * beta2 + one_minus_beta2
  92. # 1 - beta1 ** state['step']
  93. bias_correction1 = state['exp_avg_lr_1'] / group['lr'] if group['lr'] != 0. else 1.
  94. bias_correction2 = state['exp_avg_lr_2']
  95. step_size = 1 / bias_correction1
  96. denom = exp_avg_sq.div(bias_correction2).sqrt_().add_(group['eps'])
  97. step_of_this_grad = grad / denom
  98. exp_avg.mul_(beta1).add_(step_of_this_grad, alpha=group['lr'] * one_minus_beta1)
  99. if group['caution']:
  100. # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
  101. mask = (exp_avg * grad > 0).to(grad.dtype)
  102. mask.div_(mask.mean().clamp_(min=1e-3))
  103. exp_avg = exp_avg * mask
  104. p.add_(exp_avg, alpha=-step_size)
  105. if group['weight_decay'] != 0:
  106. if group['corrected_weight_decay']:
  107. wd_scale = group['lr'] ** 2 / self.defaults['lr']
  108. else:
  109. wd_scale = group['lr']
  110. p.add_(p, alpha=-wd_scale * group['weight_decay'])
  111. return loss