madgrad.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. """ PyTorch MADGRAD optimizer
  2. MADGRAD: https://arxiv.org/abs/2101.11075
  3. Code from: https://github.com/facebookresearch/madgrad
  4. """
  5. # Copyright (c) Facebook, Inc. and its affiliates.
  6. #
  7. # This source code is licensed under the MIT license found in the
  8. # LICENSE file in the root directory of this source tree.
  9. import math
  10. from typing import TYPE_CHECKING, Any, Callable, Optional
  11. import torch
  12. import torch.optim
  13. if TYPE_CHECKING:
  14. from torch.optim.optimizer import _params_t
  15. else:
  16. _params_t = Any
  17. class MADGRAD(torch.optim.Optimizer):
  18. """
  19. MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
  20. Optimization.
  21. .. _MADGRAD: https://arxiv.org/abs/2101.11075
  22. MADGRAD is a general purpose optimizer that can be used in place of SGD or
  23. Adam may converge faster and generalize better. Currently GPU-only.
  24. Typically, the same learning rate schedule that is used for SGD or Adam may
  25. be used. The overall learning rate is not comparable to either method and
  26. should be determined by a hyper-parameter sweep.
  27. MADGRAD requires less weight decay than other methods, often as little as
  28. zero. Momentum values used for SGD or Adam's beta1 should work here also.
  29. On sparse problems both weight_decay and momentum should be set to 0.
  30. Arguments:
  31. params (iterable):
  32. Iterable of parameters to optimize or dicts defining parameter groups.
  33. lr (float):
  34. Learning rate (default: 1e-2).
  35. momentum (float):
  36. Momentum value in the range [0,1) (default: 0.9).
  37. weight_decay (float):
  38. Weight decay, i.e. a L2 penalty (default: 0).
  39. eps (float):
  40. Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-6).
  41. """
  42. def __init__(
  43. self,
  44. params: _params_t,
  45. lr: float = 1e-2,
  46. momentum: float = 0.9,
  47. weight_decay: float = 0,
  48. eps: float = 1e-6,
  49. decoupled_decay: bool = False,
  50. ):
  51. if momentum < 0 or momentum >= 1:
  52. raise ValueError(f"Momentum {momentum} must be in the range [0,1]")
  53. if lr <= 0:
  54. raise ValueError(f"Learning rate {lr} must be positive")
  55. if weight_decay < 0:
  56. raise ValueError(f"Weight decay {weight_decay} must be non-negative")
  57. if eps < 0:
  58. raise ValueError(f"Eps must be non-negative")
  59. defaults = dict(
  60. lr=lr,
  61. eps=eps,
  62. momentum=momentum,
  63. weight_decay=weight_decay,
  64. decoupled_decay=decoupled_decay,
  65. )
  66. super().__init__(params, defaults)
  67. @property
  68. def supports_memory_efficient_fp16(self) -> bool:
  69. return False
  70. @property
  71. def supports_flat_params(self) -> bool:
  72. return True
  73. @torch.no_grad()
  74. def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
  75. """Performs a single optimization step.
  76. Arguments:
  77. closure (callable, optional): A closure that reevaluates the model and returns the loss.
  78. """
  79. loss = None
  80. if closure is not None:
  81. with torch.enable_grad():
  82. loss = closure()
  83. for group in self.param_groups:
  84. eps = group['eps']
  85. lr = group['lr'] + eps
  86. weight_decay = group['weight_decay']
  87. momentum = group['momentum']
  88. ck = 1 - momentum
  89. for p in group["params"]:
  90. if p.grad is None:
  91. continue
  92. grad = p.grad
  93. if momentum != 0.0 and grad.is_sparse:
  94. raise RuntimeError("momentum != 0 is not compatible with sparse gradients")
  95. state = self.state[p]
  96. if len(state) == 0:
  97. state['step'] = 0
  98. state['grad_sum_sq'] = torch.zeros_like(p)
  99. state['s'] = torch.zeros_like(p)
  100. if momentum != 0:
  101. state['x0'] = torch.clone(p).detach()
  102. state['step'] += 1
  103. grad_sum_sq = state['grad_sum_sq']
  104. s = state['s']
  105. lamb = lr * math.sqrt(state['step'])
  106. # Apply weight decay
  107. if weight_decay != 0:
  108. if group['decoupled_decay']:
  109. p.mul_(1.0 - group['lr'] * weight_decay)
  110. else:
  111. if grad.is_sparse:
  112. raise RuntimeError("weight_decay option is not compatible with sparse gradients")
  113. grad.add_(p, alpha=weight_decay)
  114. if grad.is_sparse:
  115. grad = grad.coalesce()
  116. grad_val = grad._values()
  117. p_masked = p.sparse_mask(grad)
  118. grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
  119. s_masked = s.sparse_mask(grad)
  120. # Compute x_0 from other known quantities
  121. rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
  122. x0_masked_vals = p_masked._values().addcdiv(s_masked._values(), rms_masked_vals, value=1)
  123. # Dense + sparse op
  124. grad_sq = grad * grad
  125. grad_sum_sq.add_(grad_sq, alpha=lamb)
  126. grad_sum_sq_masked.add_(grad_sq, alpha=lamb)
  127. rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)
  128. s.add_(grad, alpha=lamb)
  129. s_masked._values().add_(grad_val, alpha=lamb)
  130. # update masked copy of p
  131. p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1)
  132. # Copy updated masked p to dense p using an add operation
  133. p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
  134. p.add_(p_masked, alpha=-1)
  135. else:
  136. if momentum == 0:
  137. # Compute x_0 from other known quantities
  138. rms = grad_sum_sq.pow(1 / 3).add_(eps)
  139. x0 = p.addcdiv(s, rms, value=1)
  140. else:
  141. x0 = state['x0']
  142. # Accumulate second moments
  143. grad_sum_sq.addcmul_(grad, grad, value=lamb)
  144. rms = grad_sum_sq.pow(1 / 3).add_(eps)
  145. # Update s
  146. s.add_(grad, alpha=lamb)
  147. # Step
  148. if momentum == 0:
  149. p.copy_(x0.addcdiv(s, rms, value=-1))
  150. else:
  151. z = x0.addcdiv(s, rms, value=-1)
  152. # p is a moving average of z
  153. p.mul_(1 - ck).add_(z, alpha=ck)
  154. return loss