| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- """ Polynomial Scheduler
- Polynomial LR schedule with warmup, noise.
- Hacked together by / Copyright 2021 Ross Wightman
- """
- import math
- import logging
- from typing import List, Tuple, Union
- import torch
- from .scheduler import Scheduler
# Module-level logger; PolyLRScheduler uses it to warn about configurations
# that leave the learning rate effectively unchanged.
_logger = logging.getLogger(__name__)
class PolyLRScheduler(Scheduler):
    """ Polynomial LR Scheduler w/ warmup, noise, and k-decay

    Within cycle ``i`` (length ``t_i``, position ``t_curr`` inside the cycle) the
    learning rate of each param group is::

        lr = lr_min + (lr_max_i - lr_min) * (1 - t_curr**k / t_i**k) ** power

    where ``lr_max_i = base_lr * cycle_decay**i`` and ``t_i = t_initial * cycle_mul**i``.
    Once ``cycle_limit`` cycles have elapsed the learning rate is held at ``lr_min``.

    k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909

    Args:
        optimizer: Wrapped optimizer; the ``lr`` field of its param groups is scheduled.
        t_initial: Length of the first cycle (must be > 0).
        power: Exponent of the polynomial decay.
        lr_min: Lower bound the decay approaches at the end of each cycle (must be >= 0).
        cycle_mul: Length multiplier applied to each successive cycle.
        cycle_decay: Multiplier applied to the peak learning rate of each successive cycle.
        cycle_limit: Number of cycles to run before holding at ``lr_min``.
        warmup_t: Number of linear warmup steps from ``warmup_lr_init`` up to each base lr.
        warmup_lr_init: Learning rate at the start of warmup.
        warmup_prefix: If True, the cycle schedule starts after warmup (time is shifted
            by ``warmup_t``); otherwise warmup overlaps the start of the first cycle.
        t_in_epochs, noise_range_t, noise_pct, noise_std, noise_seed, initialize:
            Forwarded unchanged to the base ``Scheduler``.
        k_decay: The ``k`` exponent from k-decay; 1.0 gives plain polynomial decay.
    """

    def __init__(
            self,
            optimizer: torch.optim.Optimizer,
            t_initial: int,
            power: float = 0.5,
            lr_min: float = 0.,
            cycle_mul: float = 1.,
            cycle_decay: float = 1.,
            cycle_limit: int = 1,
            warmup_t: int = 0,
            warmup_lr_init: float = 0.,
            warmup_prefix: bool = False,
            t_in_epochs: bool = True,
            noise_range_t: Union[List[int], Tuple[int, int], int, None] = None,
            noise_pct: float = 0.67,
            noise_std: float = 1.0,
            noise_seed: int = 42,
            k_decay: float = 1.0,
            initialize: bool = True,
    ) -> None:
        super().__init__(
            optimizer,
            param_group_field="lr",
            t_in_epochs=t_in_epochs,
            noise_range_t=noise_range_t,
            noise_pct=noise_pct,
            noise_std=noise_std,
            noise_seed=noise_seed,
            initialize=initialize
        )

        assert t_initial > 0
        assert lr_min >= 0
        if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1:
            # Message previously named the cosine scheduler and its t_mul/eta_mul
            # parameters; report this scheduler and its actual argument names.
            _logger.warning("Polynomial LR scheduler will have no effect on the learning "
                            "rate since t_initial = cycle_mul = cycle_decay = 1.")
        self.t_initial = t_initial
        self.power = power
        self.lr_min = lr_min
        self.cycle_mul = cycle_mul
        self.cycle_decay = cycle_decay
        self.cycle_limit = cycle_limit
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.warmup_prefix = warmup_prefix
        self.k_decay = k_decay
        if self.warmup_t:
            # Per-group increment so that warmup_lr_init + warmup_t * step == base value.
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            # Start the param groups at the warmup value (base-class helper).
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

    def _get_lr(self, t: int) -> List[float]:
        """Return the learning rate for each param group at time ``t``."""
        if t < self.warmup_t:
            # Linear ramp from warmup_lr_init towards each group's base value.
            return [self.warmup_lr_init + t * s for s in self.warmup_steps]

        if self.warmup_prefix:
            # Cycles start after warmup, so measure time from the end of warmup.
            t = t - self.warmup_t

        if self.cycle_mul != 1:
            # Cycle i starts at t_initial * (1 - cycle_mul**i) / (1 - cycle_mul);
            # invert that geometric sum to find the cycle containing t.
            remaining = 1 - t / self.t_initial * (1 - self.cycle_mul)
            if remaining <= 0:
                # Only reachable for cycle_mul < 1: the cycle lengths form a convergent
                # geometric series and t lies past the end of every cycle. math.log
                # would raise a domain error here; hold at lr_min like any other
                # exhausted schedule.
                return [self.lr_min for _ in self.base_values]
            i = math.floor(math.log(remaining, self.cycle_mul))
            t_i = self.cycle_mul ** i * self.t_initial
            t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
        else:
            i = t // self.t_initial
            t_i = self.t_initial
            t_curr = t - (self.t_initial * i)

        if i >= self.cycle_limit:
            # Checked before computing cycle_decay ** i, which can overflow for
            # cycle_decay > 1 once many cycles have elapsed.
            return [self.lr_min for _ in self.base_values]

        gamma = self.cycle_decay ** i
        lr_max_values = [v * gamma for v in self.base_values]
        k = self.k_decay
        return [
            self.lr_min + (lr_max - self.lr_min) * (1 - t_curr ** k / t_i ** k) ** self.power
            for lr_max in lr_max_values
        ]

    def get_cycle_length(self, cycles: int = 0) -> int:
        """Return the total schedule length spanned by ``cycles`` cycles.

        A falsy ``cycles`` (the default 0) means ``cycle_limit``; results below 1 are
        raised to 1. Warmup length is added only when ``warmup_prefix`` is set, since
        otherwise warmup overlaps the first cycle.
        """
        cycles = max(1, cycles or self.cycle_limit)
        if self.cycle_mul == 1.0:
            t = self.t_initial * cycles
        else:
            # Geometric sum of cycle lengths t_initial * cycle_mul**j for j < cycles.
            t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
        return t + self.warmup_t if self.warmup_prefix else t
|