| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338 |
- """ Adafactor (Big Vision variant) for PyTorch
- Adapted from the implementation in big vision: https://github.com/google-research/big_vision
- Described in 'Scaling Vision Transformers': https://arxiv.org/abs/2106.04560
- References for added functionality:
- Cautious Optimizers: https://arxiv.org/abs/2411.16085
- Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
- Adaptation and PyTorch modifications by Ross Wightman
- """
- from typing import List, Optional, Tuple, Union
- import torch
- from torch import Tensor
- from torch.optim import Optimizer
- from ._types import ParamsT
- def _get_scalar_dtype():
- """Get the scalar dtype that the optimizer uses for state"""
- return torch.float64
- def _factored_dims(
- shape: Tuple[int, ...],
- factored: bool,
- min_dim_size_to_factor: int
- ) -> Optional[tuple[int, int]]:
- """Whether to use a factored second moment estimator.
- This function returns a tuple with the two largest axes to reduce over.
- If no two dimensions have size >= min_dim_size_to_factor, return None.
- Args:
- shape: an input shape
- factored: whether to use factored second-moment estimator for > 2d vars.
- min_dim_size_to_factor: only factor accumulator if two array dimensions have at least this size.
- Returns:
- None or a tuple of ints
- """
- if not factored or len(shape) < 2:
- return None
- sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
- if shape[sorted_dims[-2][1]] < min_dim_size_to_factor:
- return None
- return int(sorted_dims[-2][1]), int(sorted_dims[-1][1])
class AdafactorBigVision(Optimizer):
    """PyTorch implementation of BigVision's Adafactor variant.

    Adapted from https://github.com/google-research/big_vision by Ross Wightman.

    ``step()`` dispatches to a single-tensor update function, or to a multi-tensor one when the
    group's ``foreach`` option is truthy (the multi-tensor function is a not-implemented stub).

    Args:
        params: parameters or parameter groups, forwarded to ``torch.optim.Optimizer``.
        lr: learning rate. The constructor value is also used as the reference ``max_lr`` when
            ``corrected_weight_decay`` is enabled.
        min_dim_size_to_factor: a tensor's second moment is only factored when its two largest
            dimensions are both at least this size.
        decay_rate: forwarded to the update function as ``beta2_decay``.
        decay_offset: stored in the group options; not read by ``step()``.
        beta2_cap: forwarded to the update function as the upper bound for beta2.
        momentum: EMA coefficient for the first moment, or ``None`` to keep no ``exp_avg`` state.
        momentum_dtype: dtype of the ``exp_avg`` state; a ``torch.dtype`` or one of the strings
            ``'float16'``, ``'bfloat16'``, ``'float32'``.
        eps: second-moment epsilon; ``None`` lets the update function pick a default.
        weight_decay: weight decay coefficient.
        clipping_threshold: RMS clipping threshold for the update, or ``None`` to disable.
        unscaled_wd: forwarded to the update function; selects decay that is not scaled by ``lr``.
        caution: forwarded to the update function; enables the 'Cautious Optimizers' masking.
        corrected_weight_decay: when set, the constructor ``lr`` is passed as ``max_lr``.
        foreach: select the multi-tensor update function when truthy.
    """

    def __init__(
            self,
            params: ParamsT,
            lr: float = 1.0,
            min_dim_size_to_factor: int = 16,
            decay_rate: float = 0.8,
            decay_offset: int = 0,
            beta2_cap: float = 0.999,
            momentum: Optional[float] = 0.9,
            momentum_dtype: Union[str, torch.dtype] = torch.bfloat16,
            eps: Optional[float] = None,
            weight_decay: float = 0.0,
            clipping_threshold: Optional[float] = None,
            unscaled_wd: bool = False,
            caution: bool = False,
            corrected_weight_decay: bool = False,
            *,
            foreach: Optional[bool] = False,
    ):
        # Normalise string dtype names to torch dtypes so the stored default is always a torch.dtype.
        if isinstance(momentum_dtype, str):
            if momentum_dtype == 'float16':
                momentum_dtype = torch.float16
            elif momentum_dtype == 'bfloat16':
                momentum_dtype = torch.bfloat16
            else:
                assert momentum_dtype == 'float32', f'{momentum_dtype} dtype not supported'
                momentum_dtype = torch.float32
        # FIXME try to check if momentum dtype is appropriate for device? Torch API not great for this.

        defaults = dict(
            lr=lr,
            min_dim_size_to_factor=min_dim_size_to_factor,
            decay_rate=decay_rate,
            decay_offset=decay_offset,
            beta2_cap=beta2_cap,
            momentum=momentum,
            momentum_dtype=momentum_dtype,
            eps=eps,
            weight_decay=weight_decay,
            clipping_threshold=clipping_threshold,
            unscaled_wd=unscaled_wd,
            caution=caution,
            corrected_weight_decay=corrected_weight_decay,
            foreach=foreach,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state, back-filling newer group options and fixing up state dtypes."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('caution', False)
            group.setdefault('corrected_weight_decay', False)
            group.setdefault('foreach', None)
            # Use the group's own momentum dtype (falling back to the optimizer default) so the restored
            # state matches what step() passes to the update function for this group.
            momentum_dtype = group.get('momentum_dtype', self.defaults['momentum_dtype'])
            for p in group['params']:
                p_state = self.state.get(p, {})
                # Convert a non-tensor step count to a tensor; tolerate state entries without 'step'.
                if 'step' in p_state and not torch.is_tensor(p_state['step']):
                    p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())

                if 'exp_avg' in p_state and torch.is_tensor(p_state['exp_avg']):
                    # FIXME this is a bit of a hack, optimizer.load_state_dict appears to upcast
                    # the momentum to float32 (it's half precision in the state_dict), need to
                    # look into this further. Better to override _process_value_according_to_param_policy?
                    p_state['exp_avg'] = p_state['exp_avg'].to(dtype=momentum_dtype)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss; it is
                invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avg_sq_rs = []
            exp_avg_sq_cs = []
            exp_avg_sqs = []
            state_steps = []
            exp_avgs = []  # For momentum

            for p in group['params']:
                if p.grad is None:
                    continue

                if p.grad.is_sparse:
                    raise RuntimeError("Sparse gradients not supported")

                params_with_grad.append(p)
                grads.append(p.grad)

                state = self.state[p]

                if len(state) == 0:
                    # NOTE step on CPU, probably need some more thought to make capturable
                    state['step'] = torch.tensor(0.0, dtype=_get_scalar_dtype())

                    # State layout must be derived from the *group's* options, because those are the
                    # values handed to the update function below. Using the optimizer-wide defaults
                    # breaks groups that override min_dim_size_to_factor / momentum / momentum_dtype.
                    factored_dims = _factored_dims(
                        p.grad.shape,
                        factored=True,
                        min_dim_size_to_factor=group['min_dim_size_to_factor'],
                    )

                    if factored_dims is not None:
                        # Factored: keep one statistic reduced over dr and one reduced over dc,
                        # each retaining a size-1 axis in place of the reduced dimension.
                        dc, dr = factored_dims
                        row_shape = list(p.grad.shape)
                        row_shape[dr] = 1
                        col_shape = list(p.grad.shape)
                        col_shape[dc] = 1
                        state['exp_avg_sq_r'] = p.grad.new_zeros(row_shape)
                        state['exp_avg_sq_c'] = p.grad.new_zeros(col_shape)
                    else:
                        state['exp_avg_sq'] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)

                    if group['momentum'] is not None:
                        state['exp_avg'] = torch.zeros_like(p.grad, dtype=group['momentum_dtype'])

                state_steps.append(state['step'])
                # Entries that do not apply to this parameter are passed as None placeholders.
                exp_avg_sq_rs.append(state.get('exp_avg_sq_r', None))
                exp_avg_sq_cs.append(state.get('exp_avg_sq_c', None))
                exp_avg_sqs.append(state.get('exp_avg_sq', None))
                exp_avgs.append(state.get('exp_avg', None))

            if group['foreach']:
                func = _multi_tensor_adafactor
            else:
                func = _single_tensor_adafactor

            func(
                params=params_with_grad,
                grads=grads,
                exp_avg_sq_rs=exp_avg_sq_rs,
                exp_avg_sq_cs=exp_avg_sq_cs,
                exp_avg_sqs=exp_avg_sqs,
                exp_avgs=exp_avgs,
                state_steps=state_steps,
                beta2_decay=group['decay_rate'],
                beta2_cap=group['beta2_cap'],
                min_dim_size_to_factor=group['min_dim_size_to_factor'],
                eps=group['eps'],
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                momentum=group['momentum'],
                momentum_dtype=group['momentum_dtype'],
                clipping_threshold=group['clipping_threshold'],
                unscaled_wd=group['unscaled_wd'],
                caution=group['caution'],
                # The constructor lr (not the group's current lr) is the reference for corrected decay.
                max_lr=self.defaults['lr'] if group['corrected_weight_decay'] else None,
            )

        return loss
- def _single_tensor_adafactor(
- params: List[Tensor],
- grads: List[Tensor],
- exp_avg_sq_rs: List[Optional[Tensor]],
- exp_avg_sq_cs: List[Optional[Tensor]],
- exp_avg_sqs: List[Optional[Tensor]],
- exp_avgs: List[Optional[Tensor]],
- state_steps: List[Tensor],
- *,
- beta2_decay: float,
- beta2_cap: float,
- min_dim_size_to_factor: int,
- eps: float,
- lr: float,
- weight_decay: float,
- momentum: Optional[float],
- momentum_dtype: Union[str, torch.dtype],
- clipping_threshold: Optional[float],
- unscaled_wd: bool,
- caution: bool,
- max_lr: Optional[float],
- ):
- for i, param in enumerate(params):
- grad = grads[i]
- exp_avg_sq_r = exp_avg_sq_rs[i]
- exp_avg_sq_c = exp_avg_sq_cs[i]
- exp_avg_sq = exp_avg_sqs[i]
- exp_avg = exp_avgs[i]
- step_t = state_steps[i]
- if eps is None:
- # default eps for avoiding div by zero, diff from float type eps
- eps = 1e-7 if grad.dtype == torch.float16 else 1e-30
- # Update step
- step_t += 1
- beta2_t = min(beta2_cap, 1.0 - float(step_t) ** (-beta2_decay))
- one_minus_beta2_t = 1 - beta2_t
- grad_sqr = torch.square(grad) + eps
- # NOTE application of eps (epsilon1) mirrors the optax/big vision/t5x approach
- if exp_avg_sq is None:
- # factorized second moment
- dc, dr = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
- exp_avg_sq_r.lerp_(grad_sqr.mean(dim=dr, keepdim=True), one_minus_beta2_t)
- exp_avg_sq_c.lerp_(grad_sqr.mean(dim=dc, keepdim=True), one_minus_beta2_t)
- reduce_dc = dc - 1 if dc > dr else dc
- row_col_mean = exp_avg_sq_r.mean(dim=reduce_dc, keepdim=True)
- row_factor = (exp_avg_sq_r / row_col_mean).rsqrt()
- col_factor = exp_avg_sq_c.rsqrt()
- update = grad * row_factor * col_factor
- else:
- # non-factorized second moment
- assert exp_avg_sq_r is None and exp_avg_sq_c is None
- exp_avg_sq.lerp_(grad_sqr, one_minus_beta2_t)
- update = grad * exp_avg_sq.rsqrt()
- # Clip by RMS value
- if clipping_threshold is not None:
- denom = (update.norm(2) / ((update.numel() ** 0.5) / clipping_threshold)).clamp_(max=1.0)
- update.div_(denom)
- # Apply momentum (in different dtype)
- if momentum is not None and exp_avg is not None:
- if momentum_dtype != grad.dtype:
- exp_avg.lerp_(update.to(momentum_dtype), 1 - momentum) # ema
- update = exp_avg.to(grad.dtype)
- else:
- exp_avg.lerp_(update, 1 - momentum) # ema
- update = exp_avg.clone()
- if caution:
- # apply caution as per 'Cautious Optimizers': https://arxiv.org/abs/2411.16085
- mask = (update * grad > 0).to(grad.dtype)
- mask.div_(mask.mean().clamp_(min=1e-3))
- update.mul_(mask)
- # Scale by learning rate
- update.mul_(lr)
- # Perform weight decay
- if weight_decay != 0:
- if unscaled_wd:
- # match big vision impl, 'fully decoupled' decay w/o LR scaling
- if max_lr is None:
- param.mul_(1. - weight_decay)
- else:
- # corrected weight decay: scale by lr / max_lr
- param.mul_(1. - (lr / max_lr) * weight_decay)
- else:
- # match typical pytorch behaviour for decoupled decay, eg adamw where wd is scaled by LR
- if max_lr is None:
- param.mul_(1. - lr * weight_decay)
- else:
- # corrected weight decay: scale by lr^2 / max_lr
- param.mul_(1. - (lr ** 2 / max_lr) * weight_decay)
- # Update parameters
- param.add_(update, alpha=-1.0)
- def _multi_tensor_adafactor(
- params: List[Tensor],
- grads: List[Tensor],
- exp_avg_sq_rs: List[Optional[Tensor]],
- exp_avg_sq_cs: List[Optional[Tensor]],
- exp_avg_sqs: List[Optional[Tensor]],
- exp_avgs: List[Optional[Tensor]],
- state_steps: List[Tensor],
- *,
- beta2_decay: float,
- beta2_cap: float,
- min_dim_size_to_factor: int,
- eps: float,
- lr: float,
- weight_decay: float,
- momentum: Optional[float],
- momentum_dtype: Union[str, torch.dtype],
- clipping_threshold: Optional[float],
- unscaled_wd: bool,
- caution: bool,
- max_lr: Optional[float],
- ):
- # FIXME TODO
- assert False, 'multi-tensor fn (foreach=True) not implemented yet'
|