scheduler.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. import abc
  2. from abc import ABC
  3. from typing import Any, Dict, List, Optional, Tuple, Union
  4. import torch
  5. class Scheduler(ABC):
  6. """ Parameter Scheduler Base Class
  7. A scheduler base class that can be used to schedule any optimizer parameter groups.
  8. Unlike the builtin PyTorch schedulers, this is intended to be consistently called
  9. * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
  10. * At the END of each optimizer update, after incrementing the update count, to calculate next update's value
  11. The schedulers built on this should try to remain as stateless as possible (for simplicity).
  12. This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
  13. and -1 values for special behaviour. All epoch and update counts must be tracked in the training
  14. code and explicitly passed in to the schedulers on the corresponding step or step_update call.
  15. Based on ideas from:
  16. * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
  17. * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
  18. """
  19. def __init__(
  20. self,
  21. optimizer: torch.optim.Optimizer,
  22. param_group_field: str,
  23. t_in_epochs: bool = True,
  24. noise_range_t: Union[List[int], Tuple[int, int], int, None] = None,
  25. noise_type: str = 'normal',
  26. noise_pct: float = 0.67,
  27. noise_std: float = 1.0,
  28. noise_seed: Optional[int] = None,
  29. initialize: bool = True,
  30. ) -> None:
  31. self.optimizer = optimizer
  32. self.param_group_field = param_group_field
  33. self._initial_param_group_field = f"initial_{param_group_field}"
  34. if initialize:
  35. for i, group in enumerate(self.optimizer.param_groups):
  36. if param_group_field not in group:
  37. raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
  38. group.setdefault(self._initial_param_group_field, group[param_group_field])
  39. else:
  40. for i, group in enumerate(self.optimizer.param_groups):
  41. if self._initial_param_group_field not in group:
  42. raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
  43. self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
  44. self.metric = None # any point to having this for all?
  45. self.t_in_epochs = t_in_epochs
  46. self.noise_range_t = noise_range_t
  47. self.noise_pct = noise_pct
  48. self.noise_type = noise_type
  49. self.noise_std = noise_std
  50. self.noise_seed = noise_seed if noise_seed is not None else 42
  51. self.update_groups(self.base_values)
  52. def state_dict(self) -> Dict[str, Any]:
  53. return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
  54. def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
  55. self.__dict__.update(state_dict)
  56. @abc.abstractmethod
  57. def _get_lr(self, t: int) -> List[float]:
  58. pass
  59. def _get_values(self, t: int, on_epoch: bool = True) -> Optional[List[float]]:
  60. proceed = (on_epoch and self.t_in_epochs) or (not on_epoch and not self.t_in_epochs)
  61. if not proceed:
  62. return None
  63. return self._get_lr(t)
  64. def step(self, epoch: int, metric: Optional[float] = None) -> None:
  65. self.metric = metric
  66. values = self._get_values(epoch, on_epoch=True)
  67. if values is not None:
  68. values = self._add_noise(values, epoch)
  69. self.update_groups(values)
  70. def step_update(self, num_updates: int, metric: Optional[float] = None) -> None:
  71. self.metric = metric
  72. values = self._get_values(num_updates, on_epoch=False)
  73. if values is not None:
  74. values = self._add_noise(values, num_updates)
  75. self.update_groups(values)
  76. def update_groups(self, values: Union[float, List[float]]) -> None:
  77. if not isinstance(values, (list, tuple)):
  78. values = [values] * len(self.optimizer.param_groups)
  79. for param_group, value in zip(self.optimizer.param_groups, values):
  80. if 'lr_scale' in param_group:
  81. param_group[self.param_group_field] = value * param_group['lr_scale']
  82. else:
  83. param_group[self.param_group_field] = value
  84. def _add_noise(self, lrs: List[float], t: int) -> List[float]:
  85. if self._is_apply_noise(t):
  86. noise = self._calculate_noise(t)
  87. lrs = [v + v * noise for v in lrs]
  88. return lrs
  89. def _is_apply_noise(self, t: int) -> bool:
  90. """Return True if scheduler in noise range."""
  91. apply_noise = False
  92. if self.noise_range_t is not None:
  93. if isinstance(self.noise_range_t, (list, tuple)):
  94. apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
  95. else:
  96. apply_noise = t >= self.noise_range_t
  97. return apply_noise
  98. def _calculate_noise(self, t) -> float:
  99. g = torch.Generator()
  100. g.manual_seed(self.noise_seed + t)
  101. if self.noise_type == 'normal':
  102. while True:
  103. # resample if noise out of percent limit, brute force but shouldn't spin much
  104. noise = torch.randn(1, generator=g).item()
  105. if abs(noise) < self.noise_pct:
  106. return noise
  107. else:
  108. noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
  109. return noise