| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189 |
- """ PyTorch MADGRAD optimizer
- MADGRAD: https://arxiv.org/abs/2101.11075
- Code from: https://github.com/facebookresearch/madgrad
- """
- # Copyright (c) Facebook, Inc. and its affiliates.
- #
- # This source code is licensed under the MIT license found in the
- # LICENSE file in the root directory of this source tree.
- import math
- from typing import TYPE_CHECKING, Any, Callable, Optional
- import torch
- import torch.optim
- if TYPE_CHECKING:
- from torch.optim.optimizer import _params_t
- else:
- _params_t = Any
class MADGRAD(torch.optim.Optimizer):
    """
    MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
    Optimization.

    .. _MADGRAD: https://arxiv.org/abs/2101.11075

    MADGRAD is a general purpose optimizer that can be used in place of SGD or
    Adam, and may converge faster and generalize better.

    NOTE(review): the upstream documentation describes the method as
    "currently GPU-only"; nothing in this implementation is device specific,
    so confirm against your PyTorch build before relying on either claim.

    Typically, the same learning rate schedule that is used for SGD or Adam may
    be used. The overall learning rate is not comparable to either method and
    should be determined by a hyper-parameter sweep.

    MADGRAD requires less weight decay than other methods, often as little as
    zero. Momentum values used for SGD or Adam's beta1 should work here also.

    On sparse problems both weight_decay and momentum should be set to 0
    (``step`` raises ``RuntimeError`` for sparse gradients combined with
    non-zero momentum or with coupled, non-zero weight decay).

    Arguments:
        params (iterable):
            Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float):
            Learning rate (default: 1e-2).
        momentum (float):
            Momentum value in the range [0,1) (default: 0.9).
        weight_decay (float):
            Weight decay, i.e. a L2 penalty (default: 0).
        eps (float):
            Term added to the denominator outside of the root operation to
            improve numerical stability (default: 1e-6).
        decoupled_decay (bool):
            If True, weight decay is applied directly to the parameters by
            scaling them with ``1 - lr * weight_decay`` instead of being added
            to the gradient (default: False).

    Raises:
        ValueError: if ``momentum`` is outside [0,1), ``lr`` is not positive,
            or ``weight_decay`` / ``eps`` is negative.
    """

    def __init__(
            self,
            params: "_params_t",
            lr: float = 1e-2,
            momentum: float = 0.9,
            weight_decay: float = 0,
            eps: float = 1e-6,
            decoupled_decay: bool = False,
    ):
        if momentum < 0 or momentum >= 1:
            # The upper bound is exclusive: momentum == 1 is rejected.
            raise ValueError(f"Momentum {momentum} must be in the range [0,1)")
        if lr <= 0:
            raise ValueError(f"Learning rate {lr} must be positive")
        if weight_decay < 0:
            raise ValueError(f"Weight decay {weight_decay} must be non-negative")
        if eps < 0:
            raise ValueError("Eps must be non-negative")

        defaults = dict(
            lr=lr,
            eps=eps,
            momentum=momentum,
            weight_decay=weight_decay,
            decoupled_decay=decoupled_decay,
        )
        super().__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self) -> bool:
        """Capability flag; always False for this optimizer."""
        return False

    @property
    def supports_flat_params(self) -> bool:
        """Capability flag; always True for this optimizer."""
        return True

    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient while momentum
                is non-zero, or while coupled (non-decoupled) weight decay is
                non-zero.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, so gradients must be re-enabled for
            # the closure to be able to run backward.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            eps = group['eps']
            # eps is folded into the step size used for the accumulators.
            lr = group['lr'] + eps
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            ck = 1 - momentum

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if momentum != 0.0 and grad.is_sparse:
                    raise RuntimeError("momentum != 0 is not compatible with sparse gradients")

                state = self.state[p]
                if len(state) == 0:
                    # Lazy per-parameter state initialisation on first use.
                    state['step'] = 0
                    state['grad_sum_sq'] = torch.zeros_like(p)
                    state['s'] = torch.zeros_like(p)
                    if momentum != 0:
                        # With momentum the initial point is stored explicitly;
                        # without it, x0 is recomputed from p, s and grad_sum_sq.
                        state['x0'] = torch.clone(p).detach()

                state['step'] += 1
                grad_sum_sq = state['grad_sum_sq']
                s = state['s']
                lamb = lr * math.sqrt(state['step'])

                # Apply weight decay
                if weight_decay != 0:
                    if group['decoupled_decay']:
                        p.mul_(1.0 - group['lr'] * weight_decay)
                    else:
                        if grad.is_sparse:
                            raise RuntimeError("weight_decay option is not compatible with sparse gradients")
                        # Out-of-place on purpose: p.grad belongs to the caller
                        # and must not be modified by the optimizer.
                        grad = grad.add(p, alpha=weight_decay)

                if grad.is_sparse:
                    grad = grad.coalesce()
                    grad_val = grad._values()

                    # Restrict the dense tensors to the gradient's sparsity pattern.
                    p_masked = p.sparse_mask(grad)
                    grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
                    s_masked = s.sparse_mask(grad)

                    # Compute x_0 from other known quantities
                    rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
                    x0_masked_vals = p_masked._values().addcdiv(s_masked._values(), rms_masked_vals, value=1)

                    # Dense + sparse op
                    grad_sq = grad * grad
                    grad_sum_sq.add_(grad_sq, alpha=lamb)
                    grad_sum_sq_masked.add_(grad_sq, alpha=lamb)

                    rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)

                    s.add_(grad, alpha=lamb)
                    s_masked._values().add_(grad_val, alpha=lamb)

                    # update masked copy of p
                    p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1)
                    # Copy updated masked p to dense p using an add operation:
                    # p_masked becomes (old - new), which is then subtracted from p.
                    p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
                    p.add_(p_masked, alpha=-1)
                else:
                    if momentum == 0:
                        # Compute x_0 from other known quantities
                        rms = grad_sum_sq.pow(1 / 3).add_(eps)
                        x0 = p.addcdiv(s, rms, value=1)
                    else:
                        x0 = state['x0']

                    # Accumulate second moments
                    grad_sum_sq.addcmul_(grad, grad, value=lamb)
                    rms = grad_sum_sq.pow(1 / 3).add_(eps)

                    # Update s
                    s.add_(grad, alpha=lamb)

                    # Step
                    if momentum == 0:
                        p.copy_(x0.addcdiv(s, rms, value=-1))
                    else:
                        z = x0.addcdiv(s, rms, value=-1)

                        # p is a moving average of z
                        p.mul_(1 - ck).add_(z, alpha=ck)

        return loss
|