nadam.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import math
  2. import torch
  3. from torch.optim.optimizer import Optimizer
  4. class NAdamLegacy(Optimizer):
  5. """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).
  6. NOTE: This impl has been deprecated in favour of torch.optim.NAdam and remains as a reference
  7. It has been proposed in `Incorporating Nesterov Momentum into Adam`__.
  8. Arguments:
  9. params (iterable): iterable of parameters to optimize or dicts defining
  10. parameter groups
  11. lr (float, optional): learning rate (default: 2e-3)
  12. betas (Tuple[float, float], optional): coefficients used for computing
  13. running averages of gradient and its square
  14. eps (float, optional): term added to the denominator to improve
  15. numerical stability (default: 1e-8)
  16. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  17. schedule_decay (float, optional): momentum schedule decay (default: 4e-3)
  18. __ http://cs229.stanford.edu/proj2015/054_report.pdf
  19. __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf
  20. Originally taken from: https://github.com/pytorch/pytorch/pull/1408
  21. NOTE: Has potential issues but does work well on some problems.
  22. """
  23. def __init__(
  24. self,
  25. params,
  26. lr=2e-3,
  27. betas=(0.9, 0.999),
  28. eps=1e-8,
  29. weight_decay=0,
  30. schedule_decay=4e-3,
  31. ):
  32. if not 0.0 <= lr:
  33. raise ValueError("Invalid learning rate: {}".format(lr))
  34. defaults = dict(
  35. lr=lr,
  36. betas=betas,
  37. eps=eps,
  38. weight_decay=weight_decay,
  39. schedule_decay=schedule_decay,
  40. )
  41. super(NAdamLegacy, self).__init__(params, defaults)
  42. @torch.no_grad()
  43. def step(self, closure=None):
  44. """Performs a single optimization step.
  45. Arguments:
  46. closure (callable, optional): A closure that reevaluates the model
  47. and returns the loss.
  48. """
  49. loss = None
  50. if closure is not None:
  51. with torch.enable_grad():
  52. loss = closure()
  53. for group in self.param_groups:
  54. for p in group['params']:
  55. if p.grad is None:
  56. continue
  57. grad = p.grad
  58. state = self.state[p]
  59. # State initialization
  60. if len(state) == 0:
  61. state['step'] = 0
  62. state['m_schedule'] = 1.
  63. state['exp_avg'] = torch.zeros_like(p)
  64. state['exp_avg_sq'] = torch.zeros_like(p)
  65. # Warming momentum schedule
  66. m_schedule = state['m_schedule']
  67. schedule_decay = group['schedule_decay']
  68. exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
  69. beta1, beta2 = group['betas']
  70. eps = group['eps']
  71. state['step'] += 1
  72. t = state['step']
  73. bias_correction2 = 1 - beta2 ** t
  74. if group['weight_decay'] != 0:
  75. grad = grad.add(p, alpha=group['weight_decay'])
  76. momentum_cache_t = beta1 * (1. - 0.5 * (0.96 ** (t * schedule_decay)))
  77. momentum_cache_t_1 = beta1 * (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay)))
  78. m_schedule_new = m_schedule * momentum_cache_t
  79. m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1
  80. state['m_schedule'] = m_schedule_new
  81. # Decay the first and second moment running average coefficient
  82. exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1)
  83. exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2)
  84. denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
  85. p.addcdiv_(grad, denom, value=-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new))
  86. p.addcdiv_(exp_avg, denom, value=-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next))
  87. return loss