plateau_lr.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. """ Plateau Scheduler
  2. Adapts PyTorch plateau scheduler and allows application of noise, warmup.
  3. Hacked together by / Copyright 2020 Ross Wightman
  4. """
  5. import torch
  6. from typing import Any, Dict, List, Optional, Tuple, Union
  7. from .scheduler import Scheduler
  8. class PlateauLRScheduler(Scheduler):
  9. """Decay the LR by a factor every time the validation loss plateaus."""
  10. def __init__(
  11. self,
  12. optimizer: torch.optim.Optimizer,
  13. decay_rate: float = 0.1,
  14. patience_t: int = 10,
  15. threshold: float = 1e-4,
  16. cooldown_t: int = 0,
  17. warmup_t: int = 0,
  18. warmup_lr_init: float = 0.,
  19. lr_min: float = 0.,
  20. mode: str = 'max',
  21. noise_range_t: Union[List[int], Tuple[int, int], int, None] = None,
  22. noise_type: str = 'normal',
  23. noise_pct: float = 0.67,
  24. noise_std: float = 1.0,
  25. noise_seed: Optional[int] = None,
  26. initialize: bool = True,
  27. ) -> None:
  28. super().__init__(
  29. optimizer,
  30. 'lr',
  31. noise_range_t=noise_range_t,
  32. noise_type=noise_type,
  33. noise_pct=noise_pct,
  34. noise_std=noise_std,
  35. noise_seed=noise_seed,
  36. initialize=initialize,
  37. )
  38. self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
  39. self.optimizer,
  40. patience=patience_t,
  41. factor=decay_rate,
  42. threshold=threshold,
  43. cooldown=cooldown_t,
  44. mode=mode,
  45. min_lr=lr_min,
  46. )
  47. self.warmup_t = warmup_t
  48. self.warmup_lr_init = warmup_lr_init
  49. if self.warmup_t:
  50. self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
  51. super().update_groups(self.warmup_lr_init)
  52. else:
  53. self.warmup_steps = [1 for _ in self.base_values]
  54. self.restore_lr = None
  55. def state_dict(self) -> Dict[str, Any]:
  56. return {
  57. 'best': self.lr_scheduler.best,
  58. 'last_epoch': self.lr_scheduler.last_epoch,
  59. }
  60. def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
  61. self.lr_scheduler.best = state_dict['best']
  62. if 'last_epoch' in state_dict:
  63. self.lr_scheduler.last_epoch = state_dict['last_epoch']
  64. # override the base class step fn completely
  65. def step(self, epoch: int, metric: Optional[float] = None) -> None:
  66. if epoch <= self.warmup_t:
  67. lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
  68. super().update_groups(lrs)
  69. else:
  70. if self.restore_lr is not None:
  71. # restore actual LR from before our last noise perturbation before stepping base
  72. for i, param_group in enumerate(self.optimizer.param_groups):
  73. param_group['lr'] = self.restore_lr[i]
  74. self.restore_lr = None
  75. # step the base scheduler if metric given
  76. if metric is not None:
  77. self.lr_scheduler.step(metric)
  78. if self._is_apply_noise(epoch):
  79. self._apply_noise(epoch)
  80. def step_update(self, num_updates: int, metric: Optional[float] = None):
  81. return None
  82. def _apply_noise(self, epoch: int) -> None:
  83. noise = self._calculate_noise(epoch)
  84. # apply the noise on top of previous LR, cache the old value so we can restore for normal
  85. # stepping of base scheduler
  86. restore_lr = []
  87. for i, param_group in enumerate(self.optimizer.param_groups):
  88. old_lr = float(param_group['lr'])
  89. restore_lr.append(old_lr)
  90. new_lr = old_lr + old_lr * noise
  91. param_group['lr'] = new_lr
  92. self.restore_lr = restore_lr
  93. def _get_lr(self, t: int) -> List[float]:
  94. assert False, 'should not be called as step is overridden'