| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210 |
- """ Scheduler Factory
- Hacked together by / Copyright 2021 Ross Wightman
- """
- from typing import List, Optional, Union
- from torch.optim import Optimizer
- from .cosine_lr import CosineLRScheduler
- from .multistep_lr import MultiStepLRScheduler
- from .plateau_lr import PlateauLRScheduler
- from .poly_lr import PolyLRScheduler
- from .step_lr import StepLRScheduler
- from .tanh_lr import TanhLRScheduler
def scheduler_kwargs(cfg, decreasing_metric: Optional[bool] = None):
    """Translate an argparse namespace / dot-access config into scheduler kwargs.

    Only ``cfg.sched`` is required. Every other setting is read with
    ``getattr`` and falls back to a built-in default when the attribute is
    absent. Several ``cfg`` attribute names differ from the kwarg they fill:
    ``epochs`` -> ``num_epochs``, ``lr_noise`` -> ``noise``,
    ``seed`` -> ``noise_seed``. ``sched_on_updates`` is inverted into
    ``step_on_epochs``.

    Args:
        cfg: object exposing scheduler settings as attributes.
        decreasing_metric: explicit plateau direction (truthy -> 'min',
            falsy -> 'max'). When None, the direction is inferred from
            ``cfg.eval_metric`` (default 'top1'): 'min' if the metric name
            contains 'loss', otherwise 'max'.

    Returns:
        dict whose keys match keyword parameters of ``create_scheduler_v2``.

    Raises:
        AttributeError: if ``cfg`` (an ordinary attribute-based object) has
            no ``sched`` attribute.
    """
    def setting(name, fallback):
        return getattr(cfg, name, fallback)

    # Read unconditionally (as before) even when the direction is explicit.
    metric_name = setting('eval_metric', 'top1')
    if decreasing_metric is None:
        lower_is_better = 'loss' in metric_name
    else:
        lower_is_better = decreasing_metric
    plateau_mode = 'min' if lower_is_better else 'max'

    return {
        'sched': cfg.sched,
        'num_epochs': setting('epochs', 100),
        'decay_epochs': setting('decay_epochs', 30),
        'decay_milestones': setting('decay_milestones', [30, 60]),
        'warmup_epochs': setting('warmup_epochs', 5),
        'cooldown_epochs': setting('cooldown_epochs', 0),
        'patience_epochs': setting('patience_epochs', 10),
        'decay_rate': setting('decay_rate', 0.1),
        'min_lr': setting('min_lr', 0.),
        'warmup_lr': setting('warmup_lr', 1e-5),
        'warmup_prefix': setting('warmup_prefix', False),
        'noise': setting('lr_noise', None),
        'noise_pct': setting('lr_noise_pct', 0.67),
        'noise_std': setting('lr_noise_std', 1.),
        'noise_seed': setting('seed', 42),
        'cycle_mul': setting('lr_cycle_mul', 1.),
        'cycle_decay': setting('lr_cycle_decay', 0.1),
        'cycle_limit': setting('lr_cycle_limit', 1),
        'k_decay': setting('lr_k_decay', 1.0),
        'plateau_mode': plateau_mode,
        'step_on_epochs': not setting('sched_on_updates', False),
    }
def create_scheduler(
        args,
        optimizer: Optimizer,
        updates_per_epoch: int = 0,
):
    """Build an LR scheduler directly from argparse-style ``args``.

    ``args`` is converted with ``scheduler_kwargs`` (using its inferred
    plateau direction). The result is forwarded to ``create_scheduler_v2``
    together with ``optimizer`` and ``updates_per_epoch``.

    Args:
        args: argparse namespace or dot-access config; must define ``sched``.
        optimizer: optimizer whose learning rate will be scheduled.
        updates_per_epoch: number of optimizer updates per epoch. Only needed
            when stepping per update rather than per epoch.

    Returns:
        The ``(lr_scheduler, num_epochs)`` tuple from ``create_scheduler_v2``.
    """
    sched_kwargs = scheduler_kwargs(args)
    return create_scheduler_v2(
        optimizer=optimizer,
        updates_per_epoch=updates_per_epoch,
        **sched_kwargs,
    )
def _resolve_noise_range(noise, t_initial):
    """Scale a fractional noise spec by ``t_initial`` into an absolute noise range.

    Args:
        noise: None, a single fraction, or a list/tuple of fractions of
            ``t_initial``.
        t_initial: schedule length that the fractions are relative to.

    Returns:
        None when ``noise`` is None or an empty list/tuple. A scalar for a
        scalar or a one-element sequence. Otherwise, a list of scaled values.
    """
    if noise is None:
        return None
    if isinstance(noise, (list, tuple)):
        if not noise:
            # An empty spec defines no range bounds; treat it as "noise off"
            # instead of forwarding an empty list to the scheduler.
            return None
        scaled = [n * t_initial for n in noise]
        # A single-element spec is collapsed to a scalar.
        return scaled[0] if len(scaled) == 1 else scaled
    return noise * t_initial


def create_scheduler_v2(
        optimizer: Optimizer,
        sched: str = 'cosine',
        num_epochs: int = 300,
        decay_epochs: int = 90,
        decay_milestones: List[int] = (90, 180, 270),
        cooldown_epochs: int = 0,
        patience_epochs: int = 10,
        decay_rate: float = 0.1,
        min_lr: float = 0.,
        warmup_lr: float = 1e-5,
        warmup_epochs: int = 0,
        warmup_prefix: bool = False,
        noise: Optional[Union[float, List[float]]] = None,
        noise_pct: float = 0.67,
        noise_std: float = 1.,
        noise_seed: int = 42,
        cycle_mul: float = 1.,
        cycle_decay: float = 0.1,
        cycle_limit: int = 1,
        k_decay: float = 1.0,
        plateau_mode: str = 'max',
        step_on_epochs: bool = True,
        updates_per_epoch: int = 0,
):
    """Create an LR scheduler and compute the total number of training epochs.

    All durations are given in epochs. When ``step_on_epochs`` is False they
    are multiplied by ``updates_per_epoch``. The scheduler is then built with
    ``t_in_epochs=False``.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        sched: one of 'cosine', 'tanh', 'step', 'multistep', 'plateau',
            'poly'. Any other value creates no scheduler.
        num_epochs: passed as ``t_initial`` to cosine/tanh/poly. Also the
            base for the noise range.
        decay_epochs: passed as ``decay_t`` to 'step'.
        decay_milestones: passed as ``decay_t`` to 'multistep'.
        cooldown_epochs: added to the cycle length when recomputing the
            returned epoch count for schedulers with ``get_cycle_length``.
        patience_epochs: passed as ``patience_t`` to 'plateau' (never scaled).
        decay_rate: decay rate for step/multistep/plateau. Used as the
            polynomial ``power`` for 'poly'.
        min_lr: passed as ``lr_min`` to cosine/tanh/plateau/poly.
        warmup_lr: passed as ``warmup_lr_init``.
        warmup_epochs: warmup length, passed as ``warmup_t``.
        warmup_prefix: forwarded to all schedulers except 'plateau'.
        noise: None, a fraction, or a list/tuple of fractions of the (possibly
            update-scaled) ``num_epochs`` defining the noise range. An empty
            list/tuple is treated like None.
        noise_pct, noise_std, noise_seed: forwarded noise settings.
        cycle_mul, cycle_decay, cycle_limit: forwarded to cosine/tanh/poly.
        k_decay: forwarded to cosine and poly.
        plateau_mode: ``mode`` for 'plateau'.
        step_on_epochs: True to step per epoch; False to step per update.
        updates_per_epoch: updates (dataloader batches) per epoch. Must be
            > 0 when ``step_on_epochs`` is False.

    Returns:
        ``(lr_scheduler, num_epochs)``. ``lr_scheduler`` is None for an
        unrecognized ``sched``. If the scheduler has ``get_cycle_length``,
        ``num_epochs`` is recomputed as cycle length plus cooldown. That sum
        is floor-divided back to epochs when stepping on updates. Otherwise
        ``num_epochs`` is the input value, plus ``warmup_epochs`` when
        ``warmup_prefix`` is set.

    Raises:
        AssertionError: if ``step_on_epochs`` is False and
            ``updates_per_epoch`` is not positive, or if 'plateau' is
            requested with ``step_on_epochs`` False.
    """
    t_initial = num_epochs
    warmup_t = warmup_epochs
    decay_t = decay_epochs
    cooldown_t = cooldown_epochs
    if not step_on_epochs:
        # Schedulers count optimizer updates here; convert every
        # epoch-denominated duration into updates.
        assert updates_per_epoch > 0, 'updates_per_epoch must be set to number of dataloader batches'
        t_initial = t_initial * updates_per_epoch
        warmup_t = warmup_t * updates_per_epoch
        decay_t = decay_t * updates_per_epoch
        decay_milestones = [d * updates_per_epoch for d in decay_milestones]
        cooldown_t = cooldown_t * updates_per_epoch

    # warmup args
    warmup_args = dict(
        warmup_lr_init=warmup_lr,
        warmup_t=warmup_t,
        warmup_prefix=warmup_prefix,
    )

    # noise args for supporting schedulers; fractions are relative to the
    # (possibly update-scaled) t_initial
    noise_args = dict(
        noise_range_t=_resolve_noise_range(noise, t_initial),
        noise_pct=noise_pct,
        noise_std=noise_std,
        noise_seed=noise_seed,
    )

    # cycle args for supporting schedulers
    cycle_args = dict(
        cycle_mul=cycle_mul,
        cycle_decay=cycle_decay,
        cycle_limit=cycle_limit,
    )

    lr_scheduler = None
    if sched == 'cosine':
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=t_initial,
            lr_min=min_lr,
            t_in_epochs=step_on_epochs,
            **cycle_args,
            **warmup_args,
            **noise_args,
            k_decay=k_decay,
        )
    elif sched == 'tanh':
        lr_scheduler = TanhLRScheduler(
            optimizer,
            t_initial=t_initial,
            lr_min=min_lr,
            t_in_epochs=step_on_epochs,
            **cycle_args,
            **warmup_args,
            **noise_args,
        )
    elif sched == 'step':
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=decay_t,
            decay_rate=decay_rate,
            t_in_epochs=step_on_epochs,
            **warmup_args,
            **noise_args,
        )
    elif sched == 'multistep':
        lr_scheduler = MultiStepLRScheduler(
            optimizer,
            decay_t=decay_milestones,
            decay_rate=decay_rate,
            t_in_epochs=step_on_epochs,
            **warmup_args,
            **noise_args,
        )
    elif sched == 'plateau':
        assert step_on_epochs, 'Plateau LR only supports step per epoch.'
        # PlateauLRScheduler is not given warmup_prefix.
        warmup_args.pop('warmup_prefix', False)
        lr_scheduler = PlateauLRScheduler(
            optimizer,
            decay_rate=decay_rate,
            patience_t=patience_epochs,
            # NOTE(review): fixed at 0, not tied to cooldown_epochs; presumably
            # the plateau cooldown is a different concept from the schedule-tail
            # cooldown above -- confirm against PlateauLRScheduler.
            cooldown_t=0,
            **warmup_args,
            lr_min=min_lr,
            mode=plateau_mode,
            **noise_args,
        )
    elif sched == 'poly':
        lr_scheduler = PolyLRScheduler(
            optimizer,
            power=decay_rate,  # overloading 'decay_rate' as polynomial power
            t_initial=t_initial,
            lr_min=min_lr,
            t_in_epochs=step_on_epochs,
            k_decay=k_decay,
            **cycle_args,
            **warmup_args,
            **noise_args,
        )
    # Any other `sched` value leaves lr_scheduler as None.

    if hasattr(lr_scheduler, 'get_cycle_length'):
        # For cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
        # NOTE: Warmup prefix added in get_cycle_lengths() if enabled
        t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t
        if step_on_epochs:
            num_epochs = t_with_cycles_and_cooldown
        else:
            num_epochs = t_with_cycles_and_cooldown // updates_per_epoch
    else:
        if warmup_prefix:
            num_epochs += warmup_epochs

    return lr_scheduler, num_epochs
|