# sgdw.py
""" SGD with decoupled weight-decay.

References for added functionality:
    Cautious Optimizers: https://arxiv.org/abs/2411.16085
    Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285

Hacked together by Ross Wightman
"""
  7. from typing import List, Optional
  8. import torch
  9. from torch import Tensor
  10. from torch.optim.optimizer import Optimizer
# Probe for private torch.optim helpers. `has_recent_pt` records whether they could be
# imported; `sgdw()` checks it before considering the foreach (multi-tensor) code path.
# NOTE: `_use_grad_for_differentiable` is imported here but is only mentioned in the FIXME
# above `SGDW.step` in this file; `_default_to_fused_or_foreach` is called by `sgdw()`.
try:
    from torch.optim.optimizer import _use_grad_for_differentiable, _default_to_fused_or_foreach
    has_recent_pt = True
except ImportError:
    has_recent_pt = False

from ._types import ParamsT

# Public names exported by this module: the optimizer class and its functional form.
__all__ = ['SGDW', 'sgdw']
  18. class SGDW(Optimizer):
  19. def __init__(
  20. self,
  21. params: ParamsT,
  22. lr: float = 1e-3,
  23. momentum: float = 0.,
  24. dampening: float = 0.,
  25. weight_decay: float = 0.,
  26. nesterov: bool = False,
  27. *,
  28. caution: bool = False,
  29. corrected_weight_decay: bool = False,
  30. maximize: bool = False,
  31. foreach: Optional[bool] = None,
  32. differentiable: bool = False,
  33. ):
  34. if lr < 0.0:
  35. raise ValueError(f"Invalid learning rate: {lr}")
  36. if momentum < 0.0:
  37. raise ValueError(f"Invalid momentum value: {momentum}")
  38. if weight_decay < 0.0:
  39. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  40. defaults = dict(
  41. lr=lr,
  42. momentum=momentum,
  43. dampening=dampening,
  44. weight_decay=weight_decay,
  45. nesterov=nesterov,
  46. caution=caution,
  47. corrected_weight_decay=corrected_weight_decay,
  48. maximize=maximize,
  49. foreach=foreach,
  50. differentiable=differentiable,
  51. )
  52. if nesterov and (momentum <= 0 or dampening != 0):
  53. raise ValueError("Nesterov momentum requires a momentum and zero dampening")
  54. super().__init__(params, defaults)
  55. def __setstate__(self, state):
  56. super().__setstate__(state)
  57. for group in self.param_groups:
  58. group.setdefault('caution', False)
  59. group.setdefault('corrected_weight_decay', False)
  60. group.setdefault('nesterov', False)
  61. group.setdefault('maximize', False)
  62. group.setdefault('foreach', None)
  63. group.setdefault('differentiable', False)
  64. def _init_group(self, group, params_with_grad, grads, momentum_buffer_list):
  65. has_sparse_grad = False
  66. for p in group['params']:
  67. if p.grad is not None:
  68. params_with_grad.append(p)
  69. grads.append(p.grad)
  70. if p.grad.is_sparse:
  71. has_sparse_grad = True
  72. state = self.state[p]
  73. if 'momentum_buffer' not in state:
  74. momentum_buffer_list.append(None)
  75. else:
  76. momentum_buffer_list.append(state['momentum_buffer'])
  77. return has_sparse_grad
  78. # FIXME figure out how to make _use_grad_for_differentiable interchangeable with no_grad decorator
  79. # without args, for backwards compatibility with old pytorch
  80. @torch.no_grad()
  81. def step(self, closure=None):
  82. """Performs a single optimization step.
  83. Args:
  84. closure (Callable, optional): A closure that reevaluates the model
  85. and returns the loss.
  86. """
  87. loss = None
  88. if closure is not None:
  89. with torch.enable_grad():
  90. loss = closure()
  91. for group in self.param_groups:
  92. params_with_grad = []
  93. grads = []
  94. momentum_buffer_list = []
  95. has_sparse_grad = self._init_group(group, params_with_grad, grads, momentum_buffer_list)
  96. sgdw(
  97. params_with_grad,
  98. grads,
  99. momentum_buffer_list,
  100. weight_decay=group['weight_decay'],
  101. momentum=group['momentum'],
  102. lr=group['lr'],
  103. dampening=group['dampening'],
  104. nesterov=group['nesterov'],
  105. caution=group['caution'],
  106. maximize=group['maximize'],
  107. has_sparse_grad=has_sparse_grad,
  108. foreach=group['foreach'],
  109. max_lr=self.defaults['lr'] if group['corrected_weight_decay'] else None,
  110. )
  111. # update momentum_buffers in state
  112. for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
  113. state = self.state[p]
  114. state['momentum_buffer'] = momentum_buffer
  115. return loss
  116. def sgdw(
  117. params: List[Tensor],
  118. grads: List[Tensor],
  119. momentum_buffer_list: List[Optional[Tensor]],
  120. # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
  121. # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
  122. has_sparse_grad: bool = None,
  123. foreach: Optional[bool] = None,
  124. *,
  125. weight_decay: float,
  126. momentum: float,
  127. lr: float,
  128. dampening: float,
  129. nesterov: bool,
  130. caution: bool,
  131. maximize: bool,
  132. max_lr: Optional[float] = None
  133. ):
  134. r"""Functional API that performs SGD algorithm computation.
  135. See :class:`~torch.optim.SGD` for details.
  136. """
  137. if has_recent_pt and hasattr(Optimizer, '_group_tensors_by_device_and_dtype'):
  138. if foreach is None:
  139. # why must we be explicit about an if statement for torch.jit.is_scripting here?
  140. # because JIT can't handle Optionals nor fancy conditionals when scripting
  141. if not torch.jit.is_scripting():
  142. _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
  143. else:
  144. foreach = False
  145. if foreach and torch.jit.is_scripting():
  146. raise RuntimeError('torch.jit.script not supported with foreach optimizers')
  147. else:
  148. foreach = False # disabling altogether for older pytorch, as using _group_tensors_by_device_and_dtype
  149. if foreach and not torch.jit.is_scripting():
  150. func = _multi_tensor_sgdw
  151. else:
  152. func = _single_tensor_sgdw
  153. func(
  154. params,
  155. grads,
  156. momentum_buffer_list,
  157. weight_decay=weight_decay,
  158. momentum=momentum,
  159. lr=lr,
  160. dampening=dampening,
  161. nesterov=nesterov,
  162. caution=caution,
  163. has_sparse_grad=has_sparse_grad,
  164. maximize=maximize,
  165. max_lr=max_lr,
  166. )
  167. def _single_tensor_sgdw(
  168. params: List[Tensor],
  169. grads: List[Tensor],
  170. momentum_buffer_list: List[Optional[Tensor]],
  171. *,
  172. weight_decay: float,
  173. momentum: float,
  174. lr: float,
  175. dampening: float,
  176. nesterov: bool,
  177. caution: bool,
  178. maximize: bool,
  179. has_sparse_grad: bool,
  180. max_lr: Optional[float]
  181. ):
  182. for i, param in enumerate(params):
  183. grad = grads[i] if not maximize else -grads[i]
  184. wd_scale = lr if max_lr is None else lr ** 2 / max_lr
  185. param.mul_(1. - wd_scale * weight_decay)
  186. if momentum != 0:
  187. buf = momentum_buffer_list[i]
  188. if buf is None:
  189. buf = torch.clone(grad).detach()
  190. momentum_buffer_list[i] = buf
  191. else:
  192. buf.mul_(momentum).add_(grad, alpha=1 - dampening)
  193. if caution:
  194. if nesterov:
  195. buf = grad.add(buf, alpha=momentum)
  196. # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
  197. mask = (buf * grad > 0).to(grad.dtype)
  198. mask.div_(mask.mean().clamp_(min=1e-3))
  199. grad = buf * mask
  200. else:
  201. if nesterov:
  202. grad = grad.add(buf, alpha=momentum)
  203. else:
  204. grad = buf
  205. param.add_(grad, alpha=-lr)
  206. def _multi_tensor_sgdw(
  207. params: List[Tensor],
  208. grads: List[Tensor],
  209. momentum_buffer_list: List[Optional[Tensor]],
  210. *,
  211. weight_decay: float,
  212. momentum: float,
  213. lr: float,
  214. dampening: float,
  215. nesterov: bool,
  216. caution: bool,
  217. maximize: bool,
  218. has_sparse_grad: bool,
  219. max_lr: Optional[float]
  220. ):
  221. if len(params) == 0:
  222. return
  223. grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
  224. [params, grads, momentum_buffer_list], with_indices=True)
  225. for ((device_params, device_grads, device_momentum_buffer_list), indices) in grouped_tensors.values():
  226. device_has_sparse_grad = has_sparse_grad and any(grad.is_sparse for grad in device_grads)
  227. if maximize:
  228. device_grads = torch._foreach_neg(device_grads)
  229. wd_scale = lr if max_lr is None else lr ** 2 / max_lr
  230. torch._foreach_mul_(params, 1. - wd_scale * weight_decay)
  231. if momentum != 0:
  232. bufs = []
  233. all_states_with_momentum_buffer = True
  234. for i in range(len(device_momentum_buffer_list)):
  235. if device_momentum_buffer_list[i] is None:
  236. all_states_with_momentum_buffer = False
  237. break
  238. else:
  239. bufs.append(device_momentum_buffer_list[i])
  240. if all_states_with_momentum_buffer:
  241. torch._foreach_mul_(bufs, momentum)
  242. torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
  243. else:
  244. bufs = []
  245. for i in range(len(device_momentum_buffer_list)):
  246. if device_momentum_buffer_list[i] is None:
  247. buf = device_momentum_buffer_list[i] = momentum_buffer_list[indices[i]] = \
  248. torch.clone(device_grads[i]).detach()
  249. else:
  250. buf = device_momentum_buffer_list[i]
  251. buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)
  252. bufs.append(buf)
  253. if caution:
  254. if nesterov:
  255. # Can't do nesterov in-place if we want to compare against orig grad for caution
  256. bufs = torch._foreach_add(device_grads, bufs, alpha=momentum)
  257. # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
  258. masks = torch._foreach_mul(bufs, device_grads)
  259. masks = [(m > 0).to(g.dtype) for m, g in zip(masks, device_grads)]
  260. mask_scale = [m.mean() for m in masks]
  261. torch._foreach_maximum_(mask_scale, 1e-3)
  262. torch._foreach_div_(masks, mask_scale)
  263. device_grads = torch._foreach_mul(bufs, masks)
  264. else:
  265. if nesterov:
  266. torch._foreach_add_(device_grads, bufs, alpha=momentum)
  267. else:
  268. device_grads = bufs
  269. if not device_has_sparse_grad:
  270. torch._foreach_add_(device_params, device_grads, alpha=-lr)
  271. else:
  272. # foreach APIs don't support sparse
  273. for i in range(len(device_params)):
  274. device_params[i].add_(device_grads[i], alpha=-lr)