scheduler_factory.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. """ Scheduler Factory
  2. Hacked together by / Copyright 2021 Ross Wightman
  3. """
  4. from typing import List, Optional, Union
  5. from torch.optim import Optimizer
  6. from .cosine_lr import CosineLRScheduler
  7. from .multistep_lr import MultiStepLRScheduler
  8. from .plateau_lr import PlateauLRScheduler
  9. from .poly_lr import PolyLRScheduler
  10. from .step_lr import StepLRScheduler
  11. from .tanh_lr import TanhLRScheduler
  12. def scheduler_kwargs(cfg, decreasing_metric: Optional[bool] = None):
  13. """ cfg/argparse to kwargs helper
  14. Convert scheduler args in argparse args or cfg (.dot) like object to keyword args.
  15. """
  16. eval_metric = getattr(cfg, 'eval_metric', 'top1')
  17. if decreasing_metric is not None:
  18. plateau_mode = 'min' if decreasing_metric else 'max'
  19. else:
  20. plateau_mode = 'min' if 'loss' in eval_metric else 'max'
  21. kwargs = dict(
  22. sched=cfg.sched,
  23. num_epochs=getattr(cfg, 'epochs', 100),
  24. decay_epochs=getattr(cfg, 'decay_epochs', 30),
  25. decay_milestones=getattr(cfg, 'decay_milestones', [30, 60]),
  26. warmup_epochs=getattr(cfg, 'warmup_epochs', 5),
  27. cooldown_epochs=getattr(cfg, 'cooldown_epochs', 0),
  28. patience_epochs=getattr(cfg, 'patience_epochs', 10),
  29. decay_rate=getattr(cfg, 'decay_rate', 0.1),
  30. min_lr=getattr(cfg, 'min_lr', 0.),
  31. warmup_lr=getattr(cfg, 'warmup_lr', 1e-5),
  32. warmup_prefix=getattr(cfg, 'warmup_prefix', False),
  33. noise=getattr(cfg, 'lr_noise', None),
  34. noise_pct=getattr(cfg, 'lr_noise_pct', 0.67),
  35. noise_std=getattr(cfg, 'lr_noise_std', 1.),
  36. noise_seed=getattr(cfg, 'seed', 42),
  37. cycle_mul=getattr(cfg, 'lr_cycle_mul', 1.),
  38. cycle_decay=getattr(cfg, 'lr_cycle_decay', 0.1),
  39. cycle_limit=getattr(cfg, 'lr_cycle_limit', 1),
  40. k_decay=getattr(cfg, 'lr_k_decay', 1.0),
  41. plateau_mode=plateau_mode,
  42. step_on_epochs=not getattr(cfg, 'sched_on_updates', False),
  43. )
  44. return kwargs
  45. def create_scheduler(
  46. args,
  47. optimizer: Optimizer,
  48. updates_per_epoch: int = 0,
  49. ):
  50. return create_scheduler_v2(
  51. optimizer=optimizer,
  52. **scheduler_kwargs(args),
  53. updates_per_epoch=updates_per_epoch,
  54. )
  55. def create_scheduler_v2(
  56. optimizer: Optimizer,
  57. sched: str = 'cosine',
  58. num_epochs: int = 300,
  59. decay_epochs: int = 90,
  60. decay_milestones: List[int] = (90, 180, 270),
  61. cooldown_epochs: int = 0,
  62. patience_epochs: int = 10,
  63. decay_rate: float = 0.1,
  64. min_lr: float = 0.,
  65. warmup_lr: float = 1e-5,
  66. warmup_epochs: int = 0,
  67. warmup_prefix: bool = False,
  68. noise: Union[float, List[float]] = None,
  69. noise_pct: float = 0.67,
  70. noise_std: float = 1.,
  71. noise_seed: int = 42,
  72. cycle_mul: float = 1.,
  73. cycle_decay: float = 0.1,
  74. cycle_limit: int = 1,
  75. k_decay: float = 1.0,
  76. plateau_mode: str = 'max',
  77. step_on_epochs: bool = True,
  78. updates_per_epoch: int = 0,
  79. ):
  80. t_initial = num_epochs
  81. warmup_t = warmup_epochs
  82. decay_t = decay_epochs
  83. cooldown_t = cooldown_epochs
  84. if not step_on_epochs:
  85. assert updates_per_epoch > 0, 'updates_per_epoch must be set to number of dataloader batches'
  86. t_initial = t_initial * updates_per_epoch
  87. warmup_t = warmup_t * updates_per_epoch
  88. decay_t = decay_t * updates_per_epoch
  89. decay_milestones = [d * updates_per_epoch for d in decay_milestones]
  90. cooldown_t = cooldown_t * updates_per_epoch
  91. # warmup args
  92. warmup_args = dict(
  93. warmup_lr_init=warmup_lr,
  94. warmup_t=warmup_t,
  95. warmup_prefix=warmup_prefix,
  96. )
  97. # setup noise args for supporting schedulers
  98. if noise is not None:
  99. if isinstance(noise, (list, tuple)):
  100. noise_range = [n * t_initial for n in noise]
  101. if len(noise_range) == 1:
  102. noise_range = noise_range[0]
  103. else:
  104. noise_range = noise * t_initial
  105. else:
  106. noise_range = None
  107. noise_args = dict(
  108. noise_range_t=noise_range,
  109. noise_pct=noise_pct,
  110. noise_std=noise_std,
  111. noise_seed=noise_seed,
  112. )
  113. # setup cycle args for supporting schedulers
  114. cycle_args = dict(
  115. cycle_mul=cycle_mul,
  116. cycle_decay=cycle_decay,
  117. cycle_limit=cycle_limit,
  118. )
  119. lr_scheduler = None
  120. if sched == 'cosine':
  121. lr_scheduler = CosineLRScheduler(
  122. optimizer,
  123. t_initial=t_initial,
  124. lr_min=min_lr,
  125. t_in_epochs=step_on_epochs,
  126. **cycle_args,
  127. **warmup_args,
  128. **noise_args,
  129. k_decay=k_decay,
  130. )
  131. elif sched == 'tanh':
  132. lr_scheduler = TanhLRScheduler(
  133. optimizer,
  134. t_initial=t_initial,
  135. lr_min=min_lr,
  136. t_in_epochs=step_on_epochs,
  137. **cycle_args,
  138. **warmup_args,
  139. **noise_args,
  140. )
  141. elif sched == 'step':
  142. lr_scheduler = StepLRScheduler(
  143. optimizer,
  144. decay_t=decay_t,
  145. decay_rate=decay_rate,
  146. t_in_epochs=step_on_epochs,
  147. **warmup_args,
  148. **noise_args,
  149. )
  150. elif sched == 'multistep':
  151. lr_scheduler = MultiStepLRScheduler(
  152. optimizer,
  153. decay_t=decay_milestones,
  154. decay_rate=decay_rate,
  155. t_in_epochs=step_on_epochs,
  156. **warmup_args,
  157. **noise_args,
  158. )
  159. elif sched == 'plateau':
  160. assert step_on_epochs, 'Plateau LR only supports step per epoch.'
  161. warmup_args.pop('warmup_prefix', False)
  162. lr_scheduler = PlateauLRScheduler(
  163. optimizer,
  164. decay_rate=decay_rate,
  165. patience_t=patience_epochs,
  166. cooldown_t=0,
  167. **warmup_args,
  168. lr_min=min_lr,
  169. mode=plateau_mode,
  170. **noise_args,
  171. )
  172. elif sched == 'poly':
  173. lr_scheduler = PolyLRScheduler(
  174. optimizer,
  175. power=decay_rate, # overloading 'decay_rate' as polynomial power
  176. t_initial=t_initial,
  177. lr_min=min_lr,
  178. t_in_epochs=step_on_epochs,
  179. k_decay=k_decay,
  180. **cycle_args,
  181. **warmup_args,
  182. **noise_args,
  183. )
  184. if hasattr(lr_scheduler, 'get_cycle_length'):
  185. # For cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
  186. # NOTE: Warmup prefix added in get_cycle_lengths() if enabled
  187. t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t
  188. if step_on_epochs:
  189. num_epochs = t_with_cycles_and_cooldown
  190. else:
  191. num_epochs = t_with_cycles_and_cooldown // updates_per_epoch
  192. else:
  193. if warmup_prefix:
  194. num_epochs += warmup_epochs
  195. return lr_scheduler, num_epochs