multistep_lr.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. """ MultiStep LR Scheduler
  2. Basic multi step LR schedule with warmup, noise.
  3. """
  4. import torch
  5. import bisect
  6. from timm.scheduler.scheduler import Scheduler
  7. from typing import List, Tuple, Union
  8. class MultiStepLRScheduler(Scheduler):
  9. """
  10. """
  11. def __init__(
  12. self,
  13. optimizer: torch.optim.Optimizer,
  14. decay_t: List[int],
  15. decay_rate: float = 1.,
  16. warmup_t: int = 0,
  17. warmup_lr_init: float = 0.,
  18. warmup_prefix: bool = True,
  19. t_in_epochs: bool = True,
  20. noise_range_t: Union[List[int], Tuple[int, int], int, None] = None,
  21. noise_pct: float = 0.67,
  22. noise_std: float = 1.0,
  23. noise_seed: int = 42,
  24. initialize: bool = True,
  25. ) -> None:
  26. super().__init__(
  27. optimizer,
  28. param_group_field="lr",
  29. t_in_epochs=t_in_epochs,
  30. noise_range_t=noise_range_t,
  31. noise_pct=noise_pct,
  32. noise_std=noise_std,
  33. noise_seed=noise_seed,
  34. initialize=initialize,
  35. )
  36. self.decay_t = decay_t
  37. self.decay_rate = decay_rate
  38. self.warmup_t = warmup_t
  39. self.warmup_lr_init = warmup_lr_init
  40. self.warmup_prefix = warmup_prefix
  41. if self.warmup_t:
  42. self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
  43. super().update_groups(self.warmup_lr_init)
  44. else:
  45. self.warmup_steps = [1 for _ in self.base_values]
  46. def get_curr_decay_steps(self, t):
  47. # find where in the array t goes,
  48. # assumes self.decay_t is sorted
  49. return bisect.bisect_right(self.decay_t, t + 1)
  50. def _get_lr(self, t: int) -> List[float]:
  51. if t < self.warmup_t:
  52. lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
  53. else:
  54. if self.warmup_prefix:
  55. t = t - self.warmup_t
  56. lrs = [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values]
  57. return lrs