tanh_lr.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. """ TanH Scheduler
  2. TanH schedule with warmup, cycle/restarts, noise.
  3. Hacked together by / Copyright 2021 Ross Wightman
  4. """
  5. import logging
  6. import math
  7. import numpy as np
  8. import torch
  9. from typing import List, Tuple, Union
  10. from .scheduler import Scheduler
  11. _logger = logging.getLogger(__name__)
  12. class TanhLRScheduler(Scheduler):
  13. """
  14. Hyberbolic-Tangent decay with restarts.
  15. This is described in the paper https://arxiv.org/abs/1806.01593
  16. """
  17. def __init__(
  18. self,
  19. optimizer: torch.optim.Optimizer,
  20. t_initial: int,
  21. lb: float = -7.,
  22. ub: float = 3.,
  23. lr_min: float = 0.,
  24. cycle_mul: float = 1.,
  25. cycle_decay: float = 1.,
  26. cycle_limit: int = 1,
  27. warmup_t: int = 0,
  28. warmup_lr_init: float = 0.,
  29. warmup_prefix: bool = False,
  30. t_in_epochs: bool = True,
  31. noise_range_t: Union[List[int], Tuple[int, int], int, None] = None,
  32. noise_pct: float = 0.67,
  33. noise_std: float = 1.0,
  34. noise_seed: int = 42,
  35. initialize: bool = True,
  36. ) -> None:
  37. super().__init__(
  38. optimizer,
  39. param_group_field="lr",
  40. t_in_epochs=t_in_epochs,
  41. noise_range_t=noise_range_t,
  42. noise_pct=noise_pct,
  43. noise_std=noise_std,
  44. noise_seed=noise_seed,
  45. initialize=initialize,
  46. )
  47. assert t_initial > 0
  48. assert lr_min >= 0
  49. assert lb < ub
  50. assert cycle_limit >= 0
  51. assert warmup_t >= 0
  52. assert warmup_lr_init >= 0
  53. self.lb = lb
  54. self.ub = ub
  55. self.t_initial = t_initial
  56. self.lr_min = lr_min
  57. self.cycle_mul = cycle_mul
  58. self.cycle_decay = cycle_decay
  59. self.cycle_limit = cycle_limit
  60. self.warmup_t = warmup_t
  61. self.warmup_lr_init = warmup_lr_init
  62. self.warmup_prefix = warmup_prefix
  63. if self.warmup_t:
  64. t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
  65. self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v]
  66. super().update_groups(self.warmup_lr_init)
  67. else:
  68. self.warmup_steps = [1 for _ in self.base_values]
  69. def _get_lr(self, t: int) -> List[float]:
  70. if t < self.warmup_t:
  71. lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
  72. else:
  73. if self.warmup_prefix:
  74. t = t - self.warmup_t
  75. if self.cycle_mul != 1:
  76. i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
  77. t_i = self.cycle_mul ** i * self.t_initial
  78. t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
  79. else:
  80. i = t // self.t_initial
  81. t_i = self.t_initial
  82. t_curr = t - (self.t_initial * i)
  83. if i < self.cycle_limit:
  84. gamma = self.cycle_decay ** i
  85. lr_max_values = [v * gamma for v in self.base_values]
  86. tr = t_curr / t_i
  87. lrs = [
  88. self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 - math.tanh(self.lb * (1. - tr) + self.ub * tr))
  89. for lr_max in lr_max_values
  90. ]
  91. else:
  92. lrs = [self.lr_min for _ in self.base_values]
  93. return lrs
  94. def get_cycle_length(self, cycles=0):
  95. cycles = max(1, cycles or self.cycle_limit)
  96. if self.cycle_mul == 1.0:
  97. t = self.t_initial * cycles
  98. else:
  99. t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
  100. return t + self.warmup_t if self.warmup_prefix else t