nvnovograd.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. """ Nvidia NovoGrad Optimizer.
  2. Original impl by Nvidia from Jasper example:
  3. - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper
  4. Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
  5. - https://arxiv.org/abs/1905.11286
  6. """
  7. import torch
  8. from torch.optim.optimizer import Optimizer
  9. import math
  10. class NvNovoGrad(Optimizer):
  11. """
  12. Implements Novograd algorithm.
  13. Args:
  14. params (iterable): iterable of parameters to optimize or dicts defining
  15. parameter groups
  16. lr (float, optional): learning rate (default: 1e-3)
  17. betas (Tuple[float, float], optional): coefficients used for computing
  18. running averages of gradient and its square (default: (0.95, 0.98))
  19. eps (float, optional): term added to the denominator to improve
  20. numerical stability (default: 1e-8)
  21. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  22. grad_averaging: gradient averaging
  23. amsgrad (boolean, optional): whether to use the AMSGrad variant of this
  24. algorithm from the paper `On the Convergence of Adam and Beyond`_
  25. (default: False)
  26. """
  27. def __init__(
  28. self,
  29. params,
  30. lr=1e-3,
  31. betas=(0.95, 0.98),
  32. eps=1e-8,
  33. weight_decay=0,
  34. grad_averaging=False,
  35. amsgrad=False,
  36. ):
  37. if not 0.0 <= lr:
  38. raise ValueError("Invalid learning rate: {}".format(lr))
  39. if not 0.0 <= eps:
  40. raise ValueError("Invalid epsilon value: {}".format(eps))
  41. if not 0.0 <= betas[0] < 1.0:
  42. raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
  43. if not 0.0 <= betas[1] < 1.0:
  44. raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
  45. defaults = dict(
  46. lr=lr,
  47. betas=betas,
  48. eps=eps,
  49. weight_decay=weight_decay,
  50. grad_averaging=grad_averaging,
  51. amsgrad=amsgrad,
  52. )
  53. super(NvNovoGrad, self).__init__(params, defaults)
  54. def __setstate__(self, state):
  55. super(NvNovoGrad, self).__setstate__(state)
  56. for group in self.param_groups:
  57. group.setdefault('amsgrad', False)
  58. @torch.no_grad()
  59. def step(self, closure=None):
  60. """Performs a single optimization step.
  61. Arguments:
  62. closure (callable, optional): A closure that reevaluates the model
  63. and returns the loss.
  64. """
  65. loss = None
  66. if closure is not None:
  67. with torch.enable_grad():
  68. loss = closure()
  69. for group in self.param_groups:
  70. for p in group['params']:
  71. if p.grad is None:
  72. continue
  73. grad = p.grad
  74. if grad.is_sparse:
  75. raise RuntimeError('Sparse gradients are not supported.')
  76. amsgrad = group['amsgrad']
  77. state = self.state[p]
  78. # State initialization
  79. if len(state) == 0:
  80. state['step'] = 0
  81. # Exponential moving average of gradient values
  82. state['exp_avg'] = torch.zeros_like(p)
  83. # Exponential moving average of squared gradient values
  84. state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
  85. if amsgrad:
  86. # Maintains max of all exp. moving avg. of sq. grad. values
  87. state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
  88. exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
  89. if amsgrad:
  90. max_exp_avg_sq = state['max_exp_avg_sq']
  91. beta1, beta2 = group['betas']
  92. state['step'] += 1
  93. norm = torch.sum(torch.pow(grad, 2))
  94. if exp_avg_sq == 0:
  95. exp_avg_sq.copy_(norm)
  96. else:
  97. exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2)
  98. if amsgrad:
  99. # Maintains the maximum of all 2nd moment running avg. till now
  100. torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
  101. # Use the max. for normalizing running avg. of gradient
  102. denom = max_exp_avg_sq.sqrt().add_(group['eps'])
  103. else:
  104. denom = exp_avg_sq.sqrt().add_(group['eps'])
  105. grad.div_(denom)
  106. if group['weight_decay'] != 0:
  107. grad.add_(p, alpha=group['weight_decay'])
  108. if group['grad_averaging']:
  109. grad.mul_(1 - beta1)
  110. exp_avg.mul_(beta1).add_(grad)
  111. p.add_(exp_avg, alpha=-group['lr'])
  112. return loss