lookahead.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. """ Lookahead Optimizer Wrapper.
  2. Implementation modified from: https://github.com/alphadl/lookahead.pytorch
  3. Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
  4. Hacked together by / Copyright 2020 Ross Wightman
  5. """
  6. from collections import OrderedDict
  7. from typing import Callable, Dict
  8. import torch
  9. from torch.optim.optimizer import Optimizer
  10. from collections import defaultdict
  class Lookahead(Optimizer):
      """Lookahead wrapper around another optimizer.

      Each call to :meth:`step` runs the wrapped optimizer once. On every
      ``k``-th call for a param group, the per-parameter slow buffer is moved
      toward the current (fast) weights by a factor ``alpha`` and the fast
      weights are then overwritten with that slow buffer.

      Args:
          base_optimizer: Optimizer instance to wrap. Its ``param_groups`` and
              ``defaults`` objects are shared with, and mutated by, this wrapper.
          alpha: Slow update rate; must lie in ``[0, 1]``.
          k: Number of lookahead steps between slow updates; must be ``>= 1``.

      Raises:
          ValueError: If ``alpha`` or ``k`` is outside its valid range.
      """

      # Per-group keys this wrapper adds to the base optimizer's param groups.
      _LOOKAHEAD_KEYS = ('lookahead_alpha', 'lookahead_k', 'lookahead_step')

      def __init__(self, base_optimizer, alpha=0.5, k=6):
          # NOTE super().__init__() not called on purpose
          # The hook registries are therefore created by hand here (presumably
          # because Optimizer hook APIs expect them -- confirm against the
          # installed torch version).
          self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict()
          self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict()
          if not 0.0 <= alpha <= 1.0:
              raise ValueError(f'Invalid slow update rate: {alpha}')
          if not 1 <= k:
              raise ValueError(f'Invalid lookahead steps: {k}')
          defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
          self._base_optimizer = base_optimizer
          # Share (not copy) the base optimizer's groups and defaults so both
          # objects always see the same hyper-parameters.
          self.param_groups = base_optimizer.param_groups
          self.defaults = base_optimizer.defaults
          self.defaults.update(defaults)
          self.state = defaultdict(dict)
          # manually add our defaults to the param groups
          self._ensure_group_defaults()

      def _ensure_group_defaults(self):
          """Add any missing lookahead keys to every base optimizer param group.

          Existing values are kept (``setdefault``), so counters restored from a
          checkpoint are not reset.
          """
          for group in self._base_optimizer.param_groups:
              for name in self._LOOKAHEAD_KEYS:
                  group.setdefault(name, self.defaults[name])

      @torch.no_grad()
      def update_slow(self, group):
          """Pull the slow weights toward the fast weights and copy them back.

          Parameters whose ``grad`` is ``None`` are skipped. The slow buffer is
          stored in the base optimizer's state under ``'lookahead_slow_buff'``
          and is created on first use as a copy of the current fast weights.

          Args:
              group: A param group dict containing ``'params'`` and
                  ``'lookahead_alpha'``.
          """
          for fast_p in group["params"]:
              if fast_p.grad is None:
                  continue
              param_state = self._base_optimizer.state[fast_p]
              if 'lookahead_slow_buff' not in param_state:
                  param_state['lookahead_slow_buff'] = torch.empty_like(fast_p)
                  param_state['lookahead_slow_buff'].copy_(fast_p)
              slow = param_state['lookahead_slow_buff']
              # slow <- slow + alpha * (fast - slow)
              slow.add_(fast_p - slow, alpha=group['lookahead_alpha'])
              fast_p.copy_(slow)

      def sync_lookahead(self):
          """Run a slow-weight update on every param group right now.

          The per-group ``lookahead_step`` counters are neither checked nor
          modified.
          """
          for group in self._base_optimizer.param_groups:
              self.update_slow(group)

      @torch.no_grad()
      def step(self, closure=None):
          """Step the base optimizer, then apply the slow update when due.

          Args:
              closure: Optional closure forwarded unchanged to the base
                  optimizer's ``step``.

          Returns:
              Whatever the base optimizer's ``step`` returns.
          """
          loss = self._base_optimizer.step(closure)
          for group in self._base_optimizer.param_groups:
              group['lookahead_step'] += 1
              if group['lookahead_step'] % group['lookahead_k'] == 0:
                  self.update_slow(group)
          return loss

      def state_dict(self):
          """Return the base optimizer's state dict.

          Slow buffers and the per-group lookahead keys are stored in the base
          optimizer's state and param groups, so they are part of that dict.
          """
          return self._base_optimizer.state_dict()

      def load_state_dict(self, state_dict):
          """Load ``state_dict`` into the base optimizer and re-link the groups.

          A checkpoint saved from the bare base optimizer has no ``lookahead_*``
          keys in its param groups. After loading, missing keys are filled in
          from this wrapper's defaults so that a later :meth:`step` does not
          fail with ``KeyError``. Keys present in the checkpoint are kept as
          loaded.

          Args:
              state_dict: State dict in the format produced by
                  :meth:`state_dict`.
          """
          self._base_optimizer.load_state_dict(state_dict)
          # Re-bind in case the base optimizer replaced its group list on load.
          self.param_groups = self._base_optimizer.param_groups
          self._ensure_group_defaults()