cosine_lr.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. """ Cosine Scheduler
  2. Cosine LR schedule with warmup, cycle/restarts, noise, k-decay.
  3. Hacked together by / Copyright 2021 Ross Wightman
  4. """
  5. import logging
  6. import math
  7. import numpy as np
  8. import torch
  9. from typing import Tuple, List, Union
  10. from .scheduler import Scheduler
  11. _logger = logging.getLogger(__name__)
  class CosineLRScheduler(Scheduler):
      """Cosine decay schedule with optional linear warmup and restarts.

      This is described in the paper https://arxiv.org/abs/1608.03983.

      Inspiration from
      https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py

      k-decay option based on `k-decay: A New Method For Learning Rate Schedule`
      - https://arxiv.org/abs/2004.05909

      Args:
          optimizer: Optimizer handed to the base ``Scheduler``.
          t_initial: Length of the first cosine cycle, in schedule steps. Must be > 0.
          lr_min: Lower bound of every cosine cycle, and the value returned once
              ``cycle_limit`` cycles have elapsed. Must be >= 0.
          cycle_mul: Cycle ``i`` has length ``t_initial * cycle_mul ** i``.
          cycle_decay: The peak of cycle ``i`` is the base value times ``cycle_decay ** i``.
          cycle_limit: Number of cycles after which ``lr_min`` is returned.
          warmup_t: Number of linear warmup steps; ``0`` disables warmup.
          warmup_lr_init: Value at ``t == 0`` when warmup is enabled.
          warmup_prefix: When true, ``warmup_t`` is subtracted from ``t`` before the
              cosine schedule is evaluated, and added to ``get_cycle_length()``.
          t_in_epochs: Forwarded unchanged to the base ``Scheduler``.
          noise_range_t: Forwarded unchanged to the base ``Scheduler``.
          noise_pct: Forwarded unchanged to the base ``Scheduler``.
          noise_std: Forwarded unchanged to the base ``Scheduler``.
          noise_seed: Forwarded unchanged to the base ``Scheduler``.
          k_decay: Exponent ``k`` applied to both the position in the cycle and the
              cycle length inside the cosine term.
          initialize: Forwarded unchanged to the base ``Scheduler``.
      """

      def __init__(
              self,
              optimizer: torch.optim.Optimizer,
              t_initial: int,
              lr_min: float = 0.,
              cycle_mul: float = 1.,
              cycle_decay: float = 1.,
              cycle_limit: int = 1,
              warmup_t: int = 0,
              warmup_lr_init: float = 0.,
              warmup_prefix: bool = False,
              t_in_epochs: bool = True,
              noise_range_t: Union[List[int], Tuple[int, int], int, None] = None,
              noise_pct: float = 0.67,
              noise_std: float = 1.0,
              noise_seed: int = 42,
              k_decay: float = 1.0,
              initialize: bool = True,
      ) -> None:
          super().__init__(
              optimizer,
              param_group_field="lr",
              t_in_epochs=t_in_epochs,
              noise_range_t=noise_range_t,
              noise_pct=noise_pct,
              noise_std=noise_std,
              noise_seed=noise_seed,
              initialize=initialize,
          )

          # Kept as assertions so callers that rely on AssertionError still work.
          assert t_initial > 0, "t_initial must be > 0"
          assert lr_min >= 0, "lr_min must be >= 0"
          if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1:
              _logger.warning(
                  "Cosine annealing scheduler will have no effect on the learning "
                  "rate since t_initial = cycle_mul = cycle_decay = 1.")

          self.t_initial = t_initial
          self.lr_min = lr_min
          self.cycle_mul = cycle_mul
          self.cycle_decay = cycle_decay
          self.cycle_limit = cycle_limit
          self.warmup_t = warmup_t
          self.warmup_lr_init = warmup_lr_init
          self.warmup_prefix = warmup_prefix
          self.k_decay = k_decay

          # NOTE(review): ``self.base_values`` is provided by the base class; it
          # presumably holds one base value per optimizer param group - confirm.
          if self.warmup_t:
              # Per-group increment so warmup reaches the base value after warmup_t steps.
              self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
              super().update_groups(self.warmup_lr_init)
          else:
              self.warmup_steps = [1 for _ in self.base_values]

      def _cycle_position(self, t):
          """Locate ``t`` within the restart schedule.

          Returns:
              ``(i, t_i, t_curr)`` - the cycle index, that cycle's length and the
              offset of ``t`` from the start of that cycle - or ``None`` when
              ``0 < cycle_mul < 1`` and ``t`` lies at or beyond the finite total
              length of all (geometrically shrinking) cycles.
          """
          t_initial = self.t_initial
          mul = self.cycle_mul

          if mul == 1:
              i = t // t_initial
              return i, t_initial, t - (t_initial * i)

          def cycle_start(n):
              # Closed form of t_initial * (1 + mul + ... + mul ** (n - 1)).
              return (1 - mul ** n) / (1 - mul) * t_initial

          log_arg = 1 - t / t_initial * (1 - mul)
          if 0 < mul < 1 and log_arg <= 0:
              # Shrinking cycles sum to t_initial / (1 - mul); past that point no
              # cycle contains t and the logarithm below would be undefined.
              return None

          i = math.floor(math.log(log_arg, mul))
          # math.log(x, base) is log(x) / log(base) in floating point and can land
          # just below or above an exact integer (math.log(1000, 10) is
          # 2.9999999999999996), so re-check the index against the cycle boundaries.
          if t < cycle_start(i):
              i -= 1
          elif t >= cycle_start(i + 1):
              i += 1

          return i, mul ** i * t_initial, t - cycle_start(i)

      def _get_lr(self, t: int) -> List[float]:
          """Return one scheduled value per entry of ``self.base_values`` at step ``t``."""
          if t < self.warmup_t:
              # Linear ramp starting at warmup_lr_init.
              return [self.warmup_lr_init + t * s for s in self.warmup_steps]

          if self.warmup_prefix:
              # Warmup does not consume cosine-schedule steps in this mode.
              t = t - self.warmup_t

          position = self._cycle_position(t)
          if position is None or position[0] >= self.cycle_limit:
              # All permitted cycles are finished.
              return [self.lr_min for _ in self.base_values]

          i, t_i, t_curr = position
          gamma = self.cycle_decay ** i
          k = self.k_decay
          # Identical for every group, so evaluate the cosine once.
          cosine_term = 1 + math.cos(math.pi * t_curr ** k / t_i ** k)
          return [
              self.lr_min + 0.5 * (base * gamma - self.lr_min) * cosine_term
              for base in self.base_values
          ]

      def get_cycle_length(self, cycles: int = 0) -> int:
          """Return the total number of steps covered by ``cycles`` cycles.

          Args:
              cycles: Number of cycles to measure. A falsy value (the default ``0``)
                  means ``self.cycle_limit``; the count is clamped to at least 1.

          Returns:
              The summed cycle lengths (floored to an int when ``cycle_mul != 1``),
              plus ``warmup_t`` when ``warmup_prefix`` is set.
          """
          cycles = max(1, cycles or self.cycle_limit)
          if self.cycle_mul == 1.0:
              t = self.t_initial * cycles
          else:
              # Geometric series: t_initial * (cycle_mul ** cycles - 1) / (cycle_mul - 1).
              t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
          return t + self.warmup_t if self.warmup_prefix else t