step_lr.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. """ Step Scheduler
  2. Basic step LR schedule with warmup, noise.
  3. Hacked together by / Copyright 2020 Ross Wightman
  4. """
  5. import math
  6. import torch
  7. from typing import List, Tuple, Union
  8. from .scheduler import Scheduler
  9. class StepLRScheduler(Scheduler):
  10. """
  11. """
  12. def __init__(
  13. self,
  14. optimizer: torch.optim.Optimizer,
  15. decay_t: float,
  16. decay_rate: float = 1.,
  17. warmup_t: int = 0,
  18. warmup_lr_init: float = 0.,
  19. warmup_prefix: bool = True,
  20. t_in_epochs: bool = True,
  21. noise_range_t: Union[List[int], Tuple[int, int], int, None] = None,
  22. noise_pct: float = 0.67,
  23. noise_std: float = 1.0,
  24. noise_seed: int = 42,
  25. initialize: bool = True,
  26. ) -> None:
  27. super().__init__(
  28. optimizer,
  29. param_group_field="lr",
  30. t_in_epochs=t_in_epochs,
  31. noise_range_t=noise_range_t,
  32. noise_pct=noise_pct,
  33. noise_std=noise_std,
  34. noise_seed=noise_seed,
  35. initialize=initialize,
  36. )
  37. self.decay_t = decay_t
  38. self.decay_rate = decay_rate
  39. self.warmup_t = warmup_t
  40. self.warmup_lr_init = warmup_lr_init
  41. self.warmup_prefix = warmup_prefix
  42. if self.warmup_t:
  43. self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
  44. super().update_groups(self.warmup_lr_init)
  45. else:
  46. self.warmup_steps = [1 for _ in self.base_values]
  47. def _get_lr(self, t: int) -> List[float]:
  48. if t < self.warmup_t:
  49. lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
  50. else:
  51. if self.warmup_prefix:
  52. t = t - self.warmup_t
  53. lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
  54. return lrs