| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106 |
- import math
- import torch
- from torch.optim.optimizer import Optimizer
class NAdamLegacy(Optimizer):
    """Implements the Nadam algorithm (Adam with Nesterov momentum).

    NOTE: This impl has been deprecated in favour of torch.optim.NAdam and
    remains as a reference.

    The algorithm was proposed in "Incorporating Nesterov Momentum into Adam".

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        schedule_decay (float, optional): momentum schedule decay (default: 4e-3)

    Raises:
        ValueError: if ``lr``, ``eps``, ``weight_decay`` or ``schedule_decay``
            is negative, or if either entry of ``betas`` lies outside [0, 1).

    References:
        http://cs229.stanford.edu/proj2015/054_report.pdf
        http://www.cs.toronto.edu/~fritz/absps/momentum.pdf

    Originally taken from: https://github.com/pytorch/pytorch/pull/1408
    NOTE: Has potential issues but does work well on some problems.
    """

    def __init__(
            self,
            params,
            lr=2e-3,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=0,
            schedule_decay=4e-3,
    ):
        # Reject hyper-parameters that would silently produce NaNs or a
        # diverging update later on, instead of failing deep inside step().
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= schedule_decay:
            raise ValueError("Invalid schedule_decay value: {}".format(schedule_decay))

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            schedule_decay=schedule_decay,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            # The closure has to build a graph, so re-enable autograd around it.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Hyper-parameters are constant within a group; read them once.
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            weight_decay = group['weight_decay']
            schedule_decay = group['schedule_decay']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Running product of the momentum coefficients used so far.
                    state['m_schedule'] = 1.
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                m_schedule = state['m_schedule']
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1
                t = state['step']
                bias_correction2 = 1 - beta2 ** t

                if weight_decay != 0:
                    # Out-of-place add so p.grad itself is left untouched.
                    grad = grad.add(p, alpha=weight_decay)

                # Warming momentum schedule: coefficient for this step and the next.
                momentum_cache_t = beta1 * (1. - 0.5 * (0.96 ** (t * schedule_decay)))
                momentum_cache_t_1 = beta1 * (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay)))
                m_schedule_new = m_schedule * momentum_cache_t
                m_schedule_next = m_schedule_new * momentum_cache_t_1
                state['m_schedule'] = m_schedule_new

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2)

                # Bias-corrected root of the second moment, shared by both terms.
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

                # Two-part update: one term driven by the current gradient,
                # one by the first-moment estimate using the next-step coefficient.
                p.addcdiv_(grad, denom, value=-lr * (1. - momentum_cache_t) / (1. - m_schedule_new))
                p.addcdiv_(exp_avg, denom, value=-lr * momentum_cache_t_1 / (1. - m_schedule_next))

        return loss
|