adabelief.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. import math
  2. import torch
  3. from torch.optim.optimizer import Optimizer
  4. class AdaBelief(Optimizer):
  5. r"""Implements AdaBelief algorithm. Modified from Adam in PyTorch
  6. Arguments:
  7. params (iterable): iterable of parameters to optimize or dicts defining
  8. parameter groups
  9. lr (float, optional): learning rate (default: 1e-3)
  10. betas (Tuple[float, float], optional): coefficients used for computing
  11. running averages of gradient and its square (default: (0.9, 0.999))
  12. eps (float, optional): term added to the denominator to improve
  13. numerical stability (default: 1e-16)
  14. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  15. amsgrad (boolean, optional): whether to use the AMSGrad variant of this
  16. algorithm from the paper `On the Convergence of Adam and Beyond`_
  17. (default: False)
  18. decoupled_decay (boolean, optional): (default: True) If set as True, then
  19. the optimizer uses decoupled weight decay as in AdamW
  20. fixed_decay (boolean, optional): (default: False) This is used when weight_decouple
  21. is set as True.
  22. When fixed_decay == True, the weight decay is performed as
  23. $W_{new} = W_{old} - W_{old} \times decay$.
  24. When fixed_decay == False, the weight decay is performed as
  25. $W_{new} = W_{old} - W_{old} \times decay \times lr$. Note that in this case, the
  26. weight decay ratio decreases with learning rate (lr).
  27. rectify (boolean, optional): (default: True) If set as True, then perform the rectified
  28. update similar to RAdam
  29. degenerated_to_sgd (boolean, optional) (default:True) If set as True, then perform SGD update
  30. when variance of gradient is high
  31. reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed gradients, NeurIPS 2020
  32. For a complete table of recommended hyperparameters, see https://github.com/juntang-zhuang/Adabelief-Optimizer'
  33. For example train/args for EfficientNet see these gists
  34. - link to train_script: https://gist.github.com/juntang-zhuang/0a501dd51c02278d952cf159bc233037
  35. - link to args.yaml: https://gist.github.com/juntang-zhuang/517ce3c27022b908bb93f78e4f786dc3
  36. """
  37. def __init__(
  38. self,
  39. params,
  40. lr=1e-3,
  41. betas=(0.9, 0.999),
  42. eps=1e-16,
  43. weight_decay=0,
  44. amsgrad=False,
  45. decoupled_decay=True,
  46. fixed_decay=False,
  47. rectify=True,
  48. degenerated_to_sgd=True,
  49. ):
  50. if not 0.0 <= lr:
  51. raise ValueError("Invalid learning rate: {}".format(lr))
  52. if not 0.0 <= eps:
  53. raise ValueError("Invalid epsilon value: {}".format(eps))
  54. if not 0.0 <= betas[0] < 1.0:
  55. raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
  56. if not 0.0 <= betas[1] < 1.0:
  57. raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
  58. if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
  59. for param in params:
  60. if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
  61. param['buffer'] = [[None, None, None] for _ in range(10)]
  62. defaults = dict(
  63. lr=lr,
  64. betas=betas,
  65. eps=eps,
  66. weight_decay=weight_decay,
  67. amsgrad=amsgrad,
  68. degenerated_to_sgd=degenerated_to_sgd,
  69. decoupled_decay=decoupled_decay,
  70. rectify=rectify,
  71. fixed_decay=fixed_decay,
  72. buffer=[[None, None, None] for _ in range(10)]
  73. )
  74. super(AdaBelief, self).__init__(params, defaults)
  75. def __setstate__(self, state):
  76. super(AdaBelief, self).__setstate__(state)
  77. for group in self.param_groups:
  78. group.setdefault('amsgrad', False)
  79. @torch.no_grad()
  80. def reset(self):
  81. for group in self.param_groups:
  82. for p in group['params']:
  83. state = self.state[p]
  84. amsgrad = group['amsgrad']
  85. # State initialization
  86. state['step'] = 0
  87. # Exponential moving average of gradient values
  88. state['exp_avg'] = torch.zeros_like(p)
  89. # Exponential moving average of squared gradient values
  90. state['exp_avg_var'] = torch.zeros_like(p)
  91. if amsgrad:
  92. # Maintains max of all exp. moving avg. of sq. grad. values
  93. state['max_exp_avg_var'] = torch.zeros_like(p)
  94. @torch.no_grad()
  95. def step(self, closure=None):
  96. """Performs a single optimization step.
  97. Arguments:
  98. closure (callable, optional): A closure that reevaluates the model
  99. and returns the loss.
  100. """
  101. loss = None
  102. if closure is not None:
  103. with torch.enable_grad():
  104. loss = closure()
  105. for group in self.param_groups:
  106. for p in group['params']:
  107. if p.grad is None:
  108. continue
  109. grad = p.grad
  110. if grad.dtype in {torch.float16, torch.bfloat16}:
  111. grad = grad.float()
  112. if grad.is_sparse:
  113. raise RuntimeError(
  114. 'AdaBelief does not support sparse gradients, please consider SparseAdam instead')
  115. p_fp32 = p
  116. if p.dtype in {torch.float16, torch.bfloat16}:
  117. p_fp32 = p_fp32.float()
  118. amsgrad = group['amsgrad']
  119. beta1, beta2 = group['betas']
  120. state = self.state[p]
  121. # State initialization
  122. if len(state) == 0:
  123. state['step'] = 0
  124. # Exponential moving average of gradient values
  125. state['exp_avg'] = torch.zeros_like(p_fp32)
  126. # Exponential moving average of squared gradient values
  127. state['exp_avg_var'] = torch.zeros_like(p_fp32)
  128. if amsgrad:
  129. # Maintains max of all exp. moving avg. of sq. grad. values
  130. state['max_exp_avg_var'] = torch.zeros_like(p_fp32)
  131. # perform weight decay, check if decoupled weight decay
  132. if group['decoupled_decay']:
  133. if not group['fixed_decay']:
  134. p_fp32.mul_(1.0 - group['lr'] * group['weight_decay'])
  135. else:
  136. p_fp32.mul_(1.0 - group['weight_decay'])
  137. else:
  138. if group['weight_decay'] != 0:
  139. grad.add_(p_fp32, alpha=group['weight_decay'])
  140. # get current state variable
  141. exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
  142. state['step'] += 1
  143. bias_correction1 = 1 - beta1 ** state['step']
  144. bias_correction2 = 1 - beta2 ** state['step']
  145. # Update first and second moment running average
  146. exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
  147. grad_residual = grad - exp_avg
  148. exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
  149. if amsgrad:
  150. max_exp_avg_var = state['max_exp_avg_var']
  151. # Maintains the maximum of all 2nd moment running avg. till now
  152. torch.max(max_exp_avg_var, exp_avg_var.add_(group['eps']), out=max_exp_avg_var)
  153. # Use the max. for normalizing running avg. of gradient
  154. denom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
  155. else:
  156. denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
  157. # update
  158. if not group['rectify']:
  159. # Default update
  160. step_size = group['lr'] / bias_correction1
  161. p_fp32.addcdiv_(exp_avg, denom, value=-step_size)
  162. else:
  163. # Rectified update, forked from RAdam
  164. buffered = group['buffer'][int(state['step'] % 10)]
  165. if state['step'] == buffered[0]:
  166. num_sma, step_size = buffered[1], buffered[2]
  167. else:
  168. buffered[0] = state['step']
  169. beta2_t = beta2 ** state['step']
  170. num_sma_max = 2 / (1 - beta2) - 1
  171. num_sma = num_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
  172. buffered[1] = num_sma
  173. # more conservative since it's an approximated value
  174. if num_sma >= 5:
  175. step_size = math.sqrt(
  176. (1 - beta2_t) *
  177. (num_sma - 4) / (num_sma_max - 4) *
  178. (num_sma - 2) / num_sma *
  179. num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** state['step'])
  180. elif group['degenerated_to_sgd']:
  181. step_size = 1.0 / (1 - beta1 ** state['step'])
  182. else:
  183. step_size = -1
  184. buffered[2] = step_size
  185. if num_sma >= 5:
  186. denom = exp_avg_var.sqrt().add_(group['eps'])
  187. p_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
  188. elif step_size > 0:
  189. p_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
  190. if p.dtype in {torch.float16, torch.bfloat16}:
  191. p.copy_(p_fp32)
  192. return loss