| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- """ Lookahead Optimizer Wrapper.
- Implementation modified from: https://github.com/alphadl/lookahead.pytorch
- Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
- Hacked together by / Copyright 2020 Ross Wightman
- """
- from collections import OrderedDict
- from typing import Callable, Dict
- import torch
- from torch.optim.optimizer import Optimizer
- from collections import defaultdict
class Lookahead(Optimizer):
    """Lookahead wrapper around an existing ("fast") optimizer.

    The wrapped optimizer steps as usual. Every ``lookahead_k`` steps of a param
    group, each parameter in that group that has a gradient is pulled toward a
    per-parameter "slow" copy (``slow += alpha * (fast - slow)``), and the fast
    parameter is then reset to the slow value.

    The wrapper shares ``param_groups``, ``state`` and ``defaults`` with the base
    optimizer. Lookahead settings therefore live in the param groups
    (``lookahead_alpha``, ``lookahead_k``, ``lookahead_step``), and the slow
    buffers live in the base optimizer's per-parameter state under
    ``lookahead_slow_buff``.

    Args:
        base_optimizer: optimizer to wrap.
        alpha: slow-weight update rate, must be in ``[0, 1]``.
        k: number of fast steps between slow-weight updates, must be ``>= 1``.

    Raises:
        ValueError: if ``alpha`` or ``k`` is out of range.
    """

    # Param-group keys owned by this wrapper (values come from ``self.defaults``).
    _LOOKAHEAD_KEYS = ('lookahead_alpha', 'lookahead_k', 'lookahead_step')

    def __init__(self, base_optimizer, alpha=0.5, k=6):
        # NOTE super().__init__() not called on purpose: all optimizer data is
        # delegated to (and shared with) ``base_optimizer``. The step-hook dicts
        # are normally created by ``Optimizer.__init__``, and torch's step-hook
        # wrapper expects them to exist, so they are created by hand here.
        self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict()
        self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict()
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        self._base_optimizer = base_optimizer
        # Share the containers (not copies) so anything inspecting
        # ``optimizer.param_groups`` / ``optimizer.state`` sees the real data.
        self.param_groups = base_optimizer.param_groups
        self.state = base_optimizer.state
        # Intentionally mutates the base optimizer's defaults dict, so groups
        # added later from ``self.defaults`` also get the lookahead keys.
        self.defaults = base_optimizer.defaults
        self.defaults.update(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        self._fill_lookahead_defaults()

    def _fill_lookahead_defaults(self):
        """Add any missing lookahead keys to every param group; existing values win."""
        for group in self._base_optimizer.param_groups:
            for name in self._LOOKAHEAD_KEYS:
                group.setdefault(name, self.defaults[name])

    @torch.no_grad()
    def update_slow(self, group):
        """Move each slow buffer in ``group`` toward its fast param, then copy it back.

        Params with ``grad is None`` are skipped. A param's slow buffer is created
        from its current value on first use, so the first update for a param
        leaves it unchanged.
        """
        for fast_p in group['params']:
            if fast_p.grad is None:
                continue
            param_state = self._base_optimizer.state[fast_p]
            if 'lookahead_slow_buff' not in param_state:
                param_state['lookahead_slow_buff'] = fast_p.detach().clone()
            slow = param_state['lookahead_slow_buff']
            slow.add_(fast_p - slow, alpha=group['lookahead_alpha'])
            fast_p.copy_(slow)

    def sync_lookahead(self):
        """Force a slow-weight update on every param group, ignoring the step count."""
        for group in self._base_optimizer.param_groups:
            self.update_slow(group)

    @torch.no_grad()
    def step(self, closure=None):
        """Step the base optimizer, then run slow updates for groups on a k-th step.

        Returns whatever the base optimizer's ``step(closure)`` returns.
        """
        loss = self._base_optimizer.step(closure)
        for group in self._base_optimizer.param_groups:
            group['lookahead_step'] += 1
            if group['lookahead_step'] % group['lookahead_k'] == 0:
                self.update_slow(group)
        return loss

    def state_dict(self):
        """Return the base optimizer's state dict.

        It includes the lookahead group keys and the slow buffers.
        """
        return self._base_optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the base optimizer and re-share its containers.

        Loading replaces the base optimizer's ``param_groups`` and ``state``
        objects, so the shared references are refreshed here. Lookahead keys
        missing from the loaded groups (e.g. from a state dict saved by the bare
        base optimizer) are filled from ``self.defaults``, which keeps ``step()``
        from raising ``KeyError``.
        """
        self._base_optimizer.load_state_dict(state_dict)
        self.param_groups = self._base_optimizer.param_groups
        self.state = self._base_optimizer.state
        self._fill_lookahead_defaults()
|