poly_lr.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
""" Polynomial Scheduler

Polynomial LR schedule with warmup, noise.

Hacked together by / Copyright 2021 Ross Wightman
"""
  5. import math
  6. import logging
  7. from typing import List, Tuple, Union
  8. import torch
  9. from .scheduler import Scheduler
  10. _logger = logging.getLogger(__name__)
  11. class PolyLRScheduler(Scheduler):
  12. """ Polynomial LR Scheduler w/ warmup, noise, and k-decay
  13. k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909
  14. """
  15. def __init__(
  16. self,
  17. optimizer: torch.optim.Optimizer,
  18. t_initial: int,
  19. power: float = 0.5,
  20. lr_min: float = 0.,
  21. cycle_mul: float = 1.,
  22. cycle_decay: float = 1.,
  23. cycle_limit: int = 1,
  24. warmup_t: int = 0,
  25. warmup_lr_init: float = 0.,
  26. warmup_prefix: bool = False,
  27. t_in_epochs: bool = True,
  28. noise_range_t: Union[List[int], Tuple[int, int], int, None] = None,
  29. noise_pct: float = 0.67,
  30. noise_std: float = 1.0,
  31. noise_seed: int = 42,
  32. k_decay: float = 1.0,
  33. initialize: bool = True,
  34. ) -> None:
  35. super().__init__(
  36. optimizer,
  37. param_group_field="lr",
  38. t_in_epochs=t_in_epochs,
  39. noise_range_t=noise_range_t,
  40. noise_pct=noise_pct,
  41. noise_std=noise_std,
  42. noise_seed=noise_seed,
  43. initialize=initialize
  44. )
  45. assert t_initial > 0
  46. assert lr_min >= 0
  47. if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1:
  48. _logger.warning("Cosine annealing scheduler will have no effect on the learning "
  49. "rate since t_initial = t_mul = eta_mul = 1.")
  50. self.t_initial = t_initial
  51. self.power = power
  52. self.lr_min = lr_min
  53. self.cycle_mul = cycle_mul
  54. self.cycle_decay = cycle_decay
  55. self.cycle_limit = cycle_limit
  56. self.warmup_t = warmup_t
  57. self.warmup_lr_init = warmup_lr_init
  58. self.warmup_prefix = warmup_prefix
  59. self.k_decay = k_decay
  60. if self.warmup_t:
  61. self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
  62. super().update_groups(self.warmup_lr_init)
  63. else:
  64. self.warmup_steps = [1 for _ in self.base_values]
  65. def _get_lr(self, t: int) -> List[float]:
  66. if t < self.warmup_t:
  67. lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
  68. else:
  69. if self.warmup_prefix:
  70. t = t - self.warmup_t
  71. if self.cycle_mul != 1:
  72. i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
  73. t_i = self.cycle_mul ** i * self.t_initial
  74. t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
  75. else:
  76. i = t // self.t_initial
  77. t_i = self.t_initial
  78. t_curr = t - (self.t_initial * i)
  79. gamma = self.cycle_decay ** i
  80. lr_max_values = [v * gamma for v in self.base_values]
  81. k = self.k_decay
  82. if i < self.cycle_limit:
  83. lrs = [
  84. self.lr_min + (lr_max - self.lr_min) * (1 - t_curr ** k / t_i ** k) ** self.power
  85. for lr_max in lr_max_values
  86. ]
  87. else:
  88. lrs = [self.lr_min for _ in self.base_values]
  89. return lrs
  90. def get_cycle_length(self, cycles=0):
  91. cycles = max(1, cycles or self.cycle_limit)
  92. if self.cycle_mul == 1.0:
  93. t = self.t_initial * cycles
  94. else:
  95. t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
  96. return t + self.warmup_t if self.warmup_prefix else t