| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111 |
- """ Plateau Scheduler
- Adapts PyTorch plateau scheduler and allows application of noise, warmup.
- Hacked together by / Copyright 2020 Ross Wightman
- """
- import torch
- from typing import Any, Dict, List, Optional, Tuple, Union
- from .scheduler import Scheduler
class PlateauLRScheduler(Scheduler):
    """Decay the LR by a factor every time the monitored metric plateaus.

    Thin wrapper around ``torch.optim.lr_scheduler.ReduceLROnPlateau`` that
    adds an optional linear warmup phase and optional LR noise on top of it.

    Args:
        optimizer: Optimizer whose param-group ``'lr'`` values are scheduled.
        decay_rate: Passed to ``ReduceLROnPlateau`` as ``factor``.
        patience_t: Passed to ``ReduceLROnPlateau`` as ``patience``.
        threshold: Passed to ``ReduceLROnPlateau`` as ``threshold``.
        cooldown_t: Passed to ``ReduceLROnPlateau`` as ``cooldown``.
        warmup_t: Number of epochs of linear warmup; ``0`` disables warmup.
        warmup_lr_init: LR applied at the start of warmup.
        lr_min: Passed to ``ReduceLROnPlateau`` as ``min_lr``.
        mode: Passed to ``ReduceLROnPlateau`` as ``mode``.
        noise_range_t: Forwarded unchanged to the base ``Scheduler``.
        noise_type: Forwarded unchanged to the base ``Scheduler``.
        noise_pct: Forwarded unchanged to the base ``Scheduler``.
        noise_std: Forwarded unchanged to the base ``Scheduler``.
        noise_seed: Forwarded unchanged to the base ``Scheduler``.
        initialize: Forwarded unchanged to the base ``Scheduler``.
    """

    def __init__(
            self,
            optimizer: torch.optim.Optimizer,
            decay_rate: float = 0.1,
            patience_t: int = 10,
            threshold: float = 1e-4,
            cooldown_t: int = 0,
            warmup_t: int = 0,
            warmup_lr_init: float = 0.,
            lr_min: float = 0.,
            mode: str = 'max',
            noise_range_t: Union[List[int], Tuple[int, int], int, None] = None,
            noise_type: str = 'normal',
            noise_pct: float = 0.67,
            noise_std: float = 1.0,
            noise_seed: Optional[int] = None,
            initialize: bool = True,
    ) -> None:
        super().__init__(
            optimizer,
            'lr',
            noise_range_t=noise_range_t,
            noise_type=noise_type,
            noise_pct=noise_pct,
            noise_std=noise_std,
            noise_seed=noise_seed,
            initialize=initialize,
        )

        # The wrapped PyTorch scheduler performs the actual plateau detection.
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            patience=patience_t,
            factor=decay_rate,
            threshold=threshold,
            cooldown=cooldown_t,
            mode=mode,
            min_lr=lr_min,
        )

        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        if self.warmup_t:
            # Per-epoch LR increment for each entry of base_values, so that
            # warmup_lr_init + warmup_t * step == base value.
            # NOTE(review): base_values is presumably populated by
            # Scheduler.__init__ -- confirm against the base class.
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            # Placeholder increments; never used because warmup is disabled.
            self.warmup_steps = [1 for _ in self.base_values]

        # LRs cached by _apply_noise so the un-noised values can be put back
        # before the wrapped scheduler is stepped again.
        self.restore_lr = None

    def state_dict(self) -> Dict[str, Any]:
        """Return the persisted state of the wrapped plateau scheduler.

        Only ``best`` and ``last_epoch`` are saved; no other attribute of the
        wrapped ``ReduceLROnPlateau`` is included.
        """
        return {
            'best': self.lr_scheduler.best,
            'last_epoch': self.lr_scheduler.last_epoch,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore state produced by :meth:`state_dict`.

        Raises:
            KeyError: If ``'best'`` is missing. ``'last_epoch'`` is optional.
        """
        self.lr_scheduler.best = state_dict['best']
        if 'last_epoch' in state_dict:
            self.lr_scheduler.last_epoch = state_dict['last_epoch']

    # override the base class step fn completely
    def step(self, epoch: int, metric: Optional[float] = None) -> None:
        """Update the optimizer LR for ``epoch``.

        While warmup is active (``warmup_t`` non-zero and
        ``epoch <= warmup_t``) the LR is set by linear interpolation and
        ``metric`` is ignored. Afterwards the wrapped plateau scheduler is
        stepped with ``metric`` (when given) and noise is applied if enabled
        for this epoch.

        Args:
            epoch: Current epoch index.
            metric: Value handed to ``ReduceLROnPlateau.step``; when ``None``
                the wrapped scheduler is not stepped.
        """
        # Only take the warmup path when a warmup was actually configured.
        # Without this guard, warmup_t == 0 and epoch == 0 would overwrite the
        # LR with warmup_lr_init and skip the plateau scheduler for that step.
        if self.warmup_t and epoch <= self.warmup_t:
            lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
            super().update_groups(lrs)
            return

        if self.restore_lr is not None:
            # restore actual LR from before our last noise perturbation before stepping base
            for i, param_group in enumerate(self.optimizer.param_groups):
                param_group['lr'] = self.restore_lr[i]
            self.restore_lr = None

        # step the base scheduler if metric given
        if metric is not None:
            self.lr_scheduler.step(metric)

        if self._is_apply_noise(epoch):
            self._apply_noise(epoch)

    def step_update(self, num_updates: int, metric: Optional[float] = None) -> None:
        """Per-update hook; does nothing and always returns ``None``."""
        return None

    def _apply_noise(self, epoch: int) -> None:
        """Scale every param-group LR by ``1 + noise`` and cache the originals.

        The pre-noise LRs are stored in ``self.restore_lr`` so that
        :meth:`step` can put them back before stepping the wrapped scheduler.
        """
        noise = self._calculate_noise(epoch)

        # apply the noise on top of previous LR, cache the old value so we can restore for normal
        # stepping of base scheduler
        restore_lr = [float(param_group['lr']) for param_group in self.optimizer.param_groups]
        for param_group, old_lr in zip(self.optimizer.param_groups, restore_lr):
            param_group['lr'] = old_lr + old_lr * noise
        self.restore_lr = restore_lr

    def _get_lr(self, t: int) -> List[float]:
        """Unused: :meth:`step` is overridden, so this must never be called.

        Raises:
            AssertionError: Always.
        """
        # Explicit raise instead of `assert False` so the guard is not
        # stripped when Python runs with -O.
        raise AssertionError('should not be called as step is overridden')
|