| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127 |
- import abc
- from abc import ABC
- from typing import Any, Dict, List, Optional, Tuple, Union
- import torch
class Scheduler(ABC):
    """Parameter scheduler base class.

    A scheduler base class that can be used to schedule any optimizer
    parameter-group field (the field name is given by ``param_group_field``).

    Unlike the builtin PyTorch schedulers, this is intended to be consistently called
      * at the END of each epoch, before incrementing the epoch count, to calculate
        the next epoch's value (``step``);
      * at the END of each optimizer update, after incrementing the update count,
        to calculate the next update's value (``step_update``).

    Which of the two actually changes the param groups is selected by ``t_in_epochs``.

    The schedulers built on this should try to remain as stateless as possible
    (for simplicity).

    This family of schedulers is attempting to avoid the confusion of the meaning
    of 'last_epoch' and -1 values for special behaviour. All epoch and update
    counts must be tracked in the training code and explicitly passed in to the
    schedulers on the corresponding ``step`` or ``step_update`` call.

    Subclasses implement ``_get_lr(t)``, returning one value per param group.

    Based on ideas from:
      * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
      * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
    """

    def __init__(
            self,
            optimizer: torch.optim.Optimizer,
            param_group_field: str,
            t_in_epochs: bool = True,
            noise_range_t: Union[List[int], Tuple[int, int], int, None] = None,
            noise_type: str = 'normal',
            noise_pct: float = 0.67,
            noise_std: float = 1.0,
            noise_seed: Optional[int] = None,
            initialize: bool = True,
    ) -> None:
        """Bind the scheduler to ``optimizer`` and apply the base values.

        Args:
            optimizer: Optimizer whose ``param_groups`` are updated in place.
            param_group_field: Key in each param group that is scheduled.
            t_in_epochs: If truthy, values are recomputed by ``step``; otherwise
                by ``step_update``.
            noise_range_t: When noise is applied. ``None`` disables noise. An int
                ``n`` applies it for ``t >= n``. A pair ``(start, end)`` applies
                it for ``start <= t < end``.
            noise_type: ``'normal'`` selects truncated Gaussian noise. Any other
                value selects uniform noise.
            noise_pct: Bound on the relative noise; each value ``v`` becomes
                ``v + v * noise``.
            noise_std: Standard deviation of the Gaussian sample used when
                ``noise_type == 'normal'`` (before truncation to ``noise_pct``).
            noise_seed: Base seed for the per-step noise generator (42 if None).
            initialize: If True, copy each group's current field value to
                ``initial_<param_group_field>``, keeping any value already stored
                there. If False, that key must already exist in every group.

        Raises:
            KeyError: If a param group lacks ``param_group_field`` (when
                ``initialize`` is True) or ``initial_<param_group_field>``
                (when ``initialize`` is False).
        """
        self.optimizer = optimizer
        self.param_group_field = param_group_field
        self._initial_param_group_field = f"initial_{param_group_field}"
        if initialize:
            for i, group in enumerate(self.optimizer.param_groups):
                if param_group_field not in group:
                    raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
                # setdefault: an existing initial value is kept rather than overwritten.
                group.setdefault(self._initial_param_group_field, group[param_group_field])
        else:
            for i, group in enumerate(self.optimizer.param_groups):
                if self._initial_param_group_field not in group:
                    raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
        self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
        self.metric = None  # any point to having this for all?
        self.t_in_epochs = t_in_epochs
        self.noise_range_t = noise_range_t
        self.noise_pct = noise_pct
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.noise_seed = noise_seed if noise_seed is not None else 42
        self.update_groups(self.base_values)

    def state_dict(self) -> Dict[str, Any]:
        """Return the scheduler's instance attributes, excluding the optimizer.

        This is a shallow copy of ``__dict__``. Mutable values such as
        ``base_values`` are shared with the live scheduler.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore attributes (as produced by ``state_dict``) onto this instance.

        The optimizer's param groups are not modified here. They change on the
        next ``step``/``step_update`` call.
        """
        self.__dict__.update(state_dict)

    @abc.abstractmethod
    def _get_lr(self, t: int) -> List[float]:
        """Return the scheduled value for each param group at time ``t``."""
        pass

    def _get_values(self, t: int, on_epoch: bool = True) -> Optional[List[float]]:
        """Return ``_get_lr(t)`` if this call matches the scheduling granularity.

        Returns None for an epoch call when scheduling per update, and vice versa.
        """
        proceed = (on_epoch and self.t_in_epochs) or (not on_epoch and not self.t_in_epochs)
        if not proceed:
            return None
        return self._get_lr(t)

    def step(self, epoch: int, metric: Optional[float] = None) -> None:
        """Recompute values at the end of an epoch (only if ``t_in_epochs``).

        ``metric`` is stored in ``self.metric`` before values are computed, so
        ``_get_lr`` may read it.
        """
        self.metric = metric
        values = self._get_values(epoch, on_epoch=True)
        if values is not None:
            values = self._add_noise(values, epoch)
            self.update_groups(values)

    def step_update(self, num_updates: int, metric: Optional[float] = None) -> None:
        """Recompute values after an optimizer update (only if not ``t_in_epochs``).

        ``metric`` is stored in ``self.metric`` before values are computed, so
        ``_get_lr`` may read it.
        """
        self.metric = metric
        values = self._get_values(num_updates, on_epoch=False)
        if values is not None:
            values = self._add_noise(values, num_updates)
            self.update_groups(values)

    def update_groups(self, values: Union[float, List[float]]) -> None:
        """Write ``values`` into each param group's scheduled field.

        A scalar is broadcast to every group. A list/tuple is paired with groups
        positionally; extra entries on either side are ignored by ``zip``. If a
        group has an ``lr_scale`` key, the value is multiplied by it.
        """
        if not isinstance(values, (list, tuple)):
            values = [values] * len(self.optimizer.param_groups)
        for param_group, value in zip(self.optimizer.param_groups, values):
            # NOTE(review): lr_scale is applied whatever param_group_field is being
            # scheduled; confirm this is intended for non-lr fields.
            if 'lr_scale' in param_group:
                param_group[self.param_group_field] = value * param_group['lr_scale']
            else:
                param_group[self.param_group_field] = value

    def _add_noise(self, lrs: List[float], t: int) -> List[float]:
        """Return ``lrs`` perturbed by the same relative noise if ``t`` is in the noise range."""
        if self._is_apply_noise(t):
            noise = self._calculate_noise(t)
            lrs = [v + v * noise for v in lrs]
        return lrs

    def _is_apply_noise(self, t: int) -> bool:
        """Return True if scheduler in noise range."""
        if self.noise_range_t is None:
            return False
        if isinstance(self.noise_range_t, (list, tuple)):
            return self.noise_range_t[0] <= t < self.noise_range_t[1]
        return t >= self.noise_range_t

    def _calculate_noise(self, t: int) -> float:
        """Return the relative noise for time ``t``.

        A fresh generator seeded with ``noise_seed + t`` makes the result depend
        only on ``t`` and the configuration.

        * ``'normal'``: a Gaussian sample with std ``noise_std``, resampled until
          ``abs(noise) < noise_pct``. Returns 0.0 if ``noise_pct <= 0``, because
          no sample could satisfy the bound.
        * any other type: uniform in ``[-noise_pct, noise_pct)``.
        """
        g = torch.Generator()
        g.manual_seed(self.noise_seed + t)
        if self.noise_type == 'normal':
            if self.noise_pct <= 0:
                # abs(noise) < noise_pct can never hold; the loop below would spin forever.
                return 0.0
            while True:
                # Resample if noise is out of the percent limit. This is brute force but
                # shouldn't spin much, though acceptance falls as noise_std grows
                # relative to noise_pct.
                noise = torch.randn(1, generator=g).item() * self.noise_std
                if abs(noise) < self.noise_pct:
                    return noise
        return 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
|